feat: time_zone variable for mysql connections (#1607)

* feat: add timezone info to query context

* feat: parse mysql compatible time zone string

* feat: add method to timestamp for rendering timezone aware string

* feat: use timezone from session for time string rendering

* refactor: use querycontectref

* feat: implement session/timezone variable read/write

* style: resolve toml format

* test: update tests

* Apply suggestions from code review

Co-authored-by: dennis zhuang <killme2008@gmail.com>

* Update src/session/src/context.rs

Co-authored-by: dennis zhuang <killme2008@gmail.com>

* refactor: address review issues

---------

Co-authored-by: dennis zhuang <killme2008@gmail.com>
This commit is contained in:
Ning Sun
2023-05-22 18:30:23 +08:00
committed by GitHub
parent 32ad358323
commit 067c5ee7ce
11 changed files with 380 additions and 46 deletions

2
Cargo.lock generated
View File

@@ -1874,6 +1874,7 @@ name = "common-time"
version = "0.2.0"
dependencies = [
"chrono",
"chrono-tz 0.8.2",
"common-error",
"rand",
"serde",
@@ -7995,6 +7996,7 @@ dependencies = [
"arc-swap",
"common-catalog",
"common-telemetry",
"common-time",
]
[[package]]

View File

@@ -6,6 +6,7 @@ license.workspace = true
[dependencies]
chrono.workspace = true
chrono-tz = "0.8"
common-error = { path = "../error" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use std::any::Any;
use std::num::TryFromIntError;
use std::num::{ParseIntError, TryFromIntError};
use chrono::ParseError;
use common_error::ext::ErrorExt;
@@ -40,14 +40,33 @@ pub enum Error {
#[snafu(display("Timestamp arithmetic overflow, msg: {}", msg))]
ArithmeticOverflow { msg: String, location: Location },
#[snafu(display("Invalid time zone offset: {hours}:{minutes}"))]
InvalidTimeZoneOffset {
hours: i32,
minutes: u32,
location: Location,
},
#[snafu(display("Invalid offset string {raw}: {source}"))]
ParseOffsetStr {
raw: String,
source: ParseIntError,
location: Location,
},
#[snafu(display("Invalid time zone string {raw}"))]
ParseTimeZoneName { raw: String, location: Location },
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::ParseDateStr { .. } | Error::ParseTimestamp { .. } => {
StatusCode::InvalidArguments
}
Error::ParseDateStr { .. }
| Error::ParseTimestamp { .. }
| Error::InvalidTimeZoneOffset { .. }
| Error::ParseOffsetStr { .. }
| Error::ParseTimeZoneName { .. } => StatusCode::InvalidArguments,
Error::TimestampOverflow { .. } => StatusCode::Internal,
Error::InvalidDateStr { .. } | Error::ArithmeticOverflow { .. } => {
StatusCode::InvalidArguments
@@ -64,7 +83,10 @@ impl ErrorExt for Error {
Error::ParseTimestamp { location, .. }
| Error::TimestampOverflow { location, .. }
| Error::ArithmeticOverflow { location, .. } => Some(*location),
Error::ParseDateStr { .. } => None,
Error::ParseDateStr { .. }
| Error::InvalidTimeZoneOffset { .. }
| Error::ParseOffsetStr { .. }
| Error::ParseTimeZoneName { .. } => None,
Error::InvalidDateStr { location, .. } => Some(*location),
}
}

View File

@@ -18,6 +18,7 @@ pub mod error;
pub mod range;
pub mod timestamp;
pub mod timestamp_millis;
pub mod timezone;
pub mod util;
pub use date::Date;
@@ -25,3 +26,4 @@ pub use datetime::DateTime;
pub use range::RangeMillis;
pub use timestamp::Timestamp;
pub use timestamp_millis::TimestampMillis;
pub use timezone::TimeZone;

View File

@@ -20,12 +20,13 @@ use std::str::FromStr;
use std::time::Duration;
use chrono::offset::Local;
use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone, Utc};
use chrono::{DateTime, LocalResult, NaiveDateTime, TimeZone as ChronoTimeZone, Utc};
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt};
use crate::error;
use crate::error::{ArithmeticOverflowSnafu, Error, ParseTimestampSnafu, TimestampOverflowSnafu};
use crate::timezone::TimeZone;
use crate::util::div_ceil;
#[derive(Debug, Clone, Default, Copy, Serialize, Deserialize)]
@@ -171,17 +172,33 @@ impl Timestamp {
/// Format timestamp to ISO8601 string. If the timestamp exceeds what chrono timestamp can
/// represent, this function simply print the timestamp unit and value in plain string.
pub fn to_iso8601_string(&self) -> String {
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z")
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f%z", None)
}
pub fn to_local_string(&self) -> String {
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f")
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", None)
}
fn as_formatted_string(self, pattern: &str) -> String {
/// Format timestamp for given timezone.
/// When timezone is None, using local time by default.
pub fn to_timezone_aware_string(&self, tz: Option<TimeZone>) -> String {
self.as_formatted_string("%Y-%m-%d %H:%M:%S%.f", tz)
}
fn as_formatted_string(self, pattern: &str, timezone: Option<TimeZone>) -> String {
if let Some(v) = self.to_chrono_datetime() {
let local = Local {};
format!("{}", local.from_utc_datetime(&v).format(pattern))
match timezone {
Some(TimeZone::Offset(offset)) => {
format!("{}", offset.from_utc_datetime(&v).format(pattern))
}
Some(TimeZone::Named(tz)) => {
format!("{}", tz.from_utc_datetime(&v).format(pattern))
}
None => {
let local = Local {};
format!("{}", local.from_utc_datetime(&v).format(pattern))
}
}
} else {
format!("[Timestamp{}: {}]", self.unit, self.value)
}
@@ -934,4 +951,54 @@ mod tests {
Timestamp::new_millisecond(58).sub(Timestamp::new_millisecond(100))
);
}
#[test]
fn test_to_timezone_aware_string() {
std::env::set_var("TZ", "Asia/Shanghai");
assert_eq!(
"1970-01-01 08:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond).to_timezone_aware_string(None)
);
assert_eq!(
"1970-01-01 08:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("SYSTEM").unwrap())
);
assert_eq!(
"1970-01-01 08:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("+08:00").unwrap())
);
assert_eq!(
"1970-01-01 07:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("+07:00").unwrap())
);
assert_eq!(
"1969-12-31 23:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("-01:00").unwrap())
);
assert_eq!(
"1970-01-01 08:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("Asia/Shanghai").unwrap())
);
assert_eq!(
"1970-01-01 00:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("UTC").unwrap())
);
assert_eq!(
"1970-01-01 01:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("Europe/Berlin").unwrap())
);
assert_eq!(
"1970-01-01 03:00:00.001",
Timestamp::new(1, TimeUnit::Millisecond)
.to_timezone_aware_string(TimeZone::from_tz_string("Europe/Moscow").unwrap())
);
}
}

View File

@@ -0,0 +1,158 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use std::str::FromStr;
use chrono::{FixedOffset, Local};
use chrono_tz::Tz;
use snafu::{OptionExt, ResultExt};
use crate::error::{
InvalidTimeZoneOffsetSnafu, ParseOffsetStrSnafu, ParseTimeZoneNameSnafu, Result,
};
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TimeZone {
Offset(FixedOffset),
Named(Tz),
}
impl TimeZone {
/// Compute timezone from given offset hours and minutes
/// Return `None` if given offset exceeds scope
pub fn hours_mins_opt(offset_hours: i32, offset_mins: u32) -> Result<Self> {
let offset_secs = if offset_hours > 0 {
offset_hours * 3600 + offset_mins as i32 * 60
} else {
offset_hours * 3600 - offset_mins as i32 * 60
};
FixedOffset::east_opt(offset_secs)
.map(Self::Offset)
.context(InvalidTimeZoneOffsetSnafu {
hours: offset_hours,
minutes: offset_mins,
})
}
/// Parse timezone offset string and return None if given offset exceeds
/// scope.
///
/// String examples are available as described in
/// https://dev.mysql.com/doc/refman/8.0/en/time-zone-support.html
///
/// - `SYSTEM`
/// - Offset to UTC: `+08:00` , `-11:30`
/// - Named zones: `Asia/Shanghai`, `Europe/Berlin`
pub fn from_tz_string(tz_string: &str) -> Result<Option<Self>> {
// Use system timezone
if tz_string.eq_ignore_ascii_case("SYSTEM") {
Ok(None)
} else if let Some((hrs, mins)) = tz_string.split_once(':') {
let hrs = hrs
.parse::<i32>()
.context(ParseOffsetStrSnafu { raw: tz_string })?;
let mins = mins
.parse::<u32>()
.context(ParseOffsetStrSnafu { raw: tz_string })?;
Self::hours_mins_opt(hrs, mins).map(Some)
} else if let Ok(tz) = Tz::from_str(tz_string) {
Ok(Some(Self::Named(tz)))
} else {
ParseTimeZoneNameSnafu { raw: tz_string }.fail()
}
}
}
impl Display for TimeZone {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Named(tz) => write!(f, "{}", tz.name()),
Self::Offset(offset) => write!(f, "{}", offset),
}
}
}
#[inline]
pub fn system_time_zone_name() -> String {
Local::now().offset().to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_tz_string() {
assert_eq!(None, TimeZone::from_tz_string("SYSTEM").unwrap());
let utc_plus_8 = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 8).unwrap()));
assert_eq!(utc_plus_8, TimeZone::from_tz_string("+8:00").unwrap());
assert_eq!(utc_plus_8, TimeZone::from_tz_string("+08:00").unwrap());
assert_eq!(utc_plus_8, TimeZone::from_tz_string("08:00").unwrap());
let utc_minus_8 = Some(TimeZone::Offset(FixedOffset::west_opt(3600 * 8).unwrap()));
assert_eq!(utc_minus_8, TimeZone::from_tz_string("-08:00").unwrap());
assert_eq!(utc_minus_8, TimeZone::from_tz_string("-8:00").unwrap());
let utc_minus_8_5 = Some(TimeZone::Offset(
FixedOffset::west_opt(3600 * 8 + 60 * 30).unwrap(),
));
assert_eq!(utc_minus_8_5, TimeZone::from_tz_string("-8:30").unwrap());
let utc_plus_max = Some(TimeZone::Offset(FixedOffset::east_opt(3600 * 14).unwrap()));
assert_eq!(utc_plus_max, TimeZone::from_tz_string("14:00").unwrap());
let utc_minus_max = Some(TimeZone::Offset(
FixedOffset::west_opt(3600 * 13 + 60 * 59).unwrap(),
));
assert_eq!(utc_minus_max, TimeZone::from_tz_string("-13:59").unwrap());
assert_eq!(
Some(TimeZone::Named(Tz::Asia__Shanghai)),
TimeZone::from_tz_string("Asia/Shanghai").unwrap()
);
assert_eq!(
Some(TimeZone::Named(Tz::UTC)),
TimeZone::from_tz_string("UTC").unwrap()
);
assert!(TimeZone::from_tz_string("WORLD_PEACE").is_err());
assert!(TimeZone::from_tz_string("A0:01").is_err());
assert!(TimeZone::from_tz_string("20:0A").is_err());
assert!(TimeZone::from_tz_string(":::::").is_err());
assert!(TimeZone::from_tz_string("Asia/London").is_err());
assert!(TimeZone::from_tz_string("Unknown").is_err());
}
#[test]
fn test_timezone_to_string() {
assert_eq!("UTC", TimeZone::Named(Tz::UTC).to_string());
assert_eq!(
"+01:00",
TimeZone::from_tz_string("01:00")
.unwrap()
.unwrap()
.to_string()
);
assert_eq!(
"Asia/Shanghai",
TimeZone::from_tz_string("Asia/Shanghai")
.unwrap()
.unwrap()
.to_string()
);
}
}

View File

@@ -20,6 +20,8 @@ use std::sync::Arc;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_time::timezone::system_time_zone_name;
use common_time::TimeZone;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::StringVector;
@@ -54,6 +56,10 @@ static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
static SHOW_SQL_MODE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'sql_mode'(.*))").unwrap());
// Time zone settings
static SET_TIME_ZONE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^SET TIME_ZONE\s*=\s*'(\S+)'").unwrap());
static OTHER_NOT_SUPPORTED_STMT: Lazy<RegexSet> = Lazy::new(|| {
RegexSet::new([
// Txn.
@@ -124,8 +130,6 @@ static VAR_VALUES: Lazy<HashMap<&str, &str>> = Lazy::new(|| {
("transaction_isolation", "REPEATABLE-READ"),
("session.transaction_isolation", "REPEATABLE-READ"),
("session.transaction_read_only", "0"),
("time_zone", "UTC"),
("system_time_zone", "UTC"),
("max_allowed_packet", "134217728"),
("interactive_timeout", "31536000"),
("wait_timeout", "31536000"),
@@ -168,7 +172,7 @@ fn show_variables(name: &str, value: &str) -> RecordBatches {
.unwrap()
}
fn select_variable(query: &str) -> Option<Output> {
fn select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
let mut fields = vec![];
let mut values = vec![];
@@ -191,12 +195,24 @@ fn select_variable(query: &str) -> Option<Output> {
.unwrap_or("")
})
.collect();
// get value of variables from known sources or fallback to defaults
let value = match var_as[0] {
"time_zone" => query_context
.time_zone()
.map(|tz| tz.to_string())
.unwrap_or_else(|| "".to_owned()),
"system_time_zone" => system_time_zone_name(),
_ => VAR_VALUES
.get(var_as[0])
.map(|v| v.to_string())
.unwrap_or_else(|| "0".to_owned()),
};
values.push(Arc::new(StringVector::from(vec![value])) as _);
match var_as.len() {
1 => {
// @@aa
let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0");
values.push(Arc::new(StringVector::from(vec![*value])) as _);
// field is '@@aa'
fields.push(ColumnSchema::new(
&format!("@@{}", var_as[0]),
@@ -207,9 +223,6 @@ fn select_variable(query: &str) -> Option<Output> {
2 => {
// @@bb as cc:
// var is 'bb'.
let value = VAR_VALUES.get(var_as[0]).unwrap_or(&"0");
values.push(Arc::new(StringVector::from(vec![*value])) as _);
// field is 'cc'.
fields.push(ColumnSchema::new(
var_as[1],
@@ -227,12 +240,12 @@ fn select_variable(query: &str) -> Option<Output> {
Some(Output::RecordBatches(batches))
}
fn check_select_variable(query: &str) -> Option<Output> {
fn check_select_variable(query: &str, query_context: QueryContextRef) -> Option<Output> {
if vec![&SELECT_VAR_PATTERN, &MYSQL_CONN_JAVA_PATTERN]
.iter()
.any(|r| r.is_match(query))
{
select_variable(query)
select_variable(query, query_context)
} else {
None
}
@@ -251,6 +264,20 @@ fn check_show_variables(query: &str) -> Option<Output> {
recordbatches.map(Output::RecordBatches)
}
// TODO(sunng87): extract this to use sqlparser for more variables
fn check_set_variables(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
if let Some(captures) = SET_TIME_ZONE_PATTERN.captures(query) {
// get the capture
let tz = captures.get(1).unwrap();
if let Ok(timezone) = TimeZone::from_tz_string(tz.as_str()) {
query_ctx.set_time_zone(timezone);
return Some(Output::AffectedRows(0));
}
}
None
}
// Check for SET or others query, this is the final check of the federated query.
fn check_others(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
@@ -283,19 +310,12 @@ pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
}
// First to check the query is like "select @@variables".
let output = check_select_variable(query);
if output.is_some() {
return output;
}
// Then to check "show variables like ...".
let output = check_show_variables(query);
if output.is_some() {
return output;
}
// Last check.
check_others(query, query_ctx)
check_select_variable(query, query_ctx.clone())
// Then to check "show variables like ...".
.or_else(|| check_show_variables(query))
.or_else(|| check_set_variables(query, query_ctx.clone()))
// Last check
.or_else(|| check_others(query, query_ctx))
}
#[cfg(test)]
@@ -352,13 +372,15 @@ mod test {
+-----------------+------------------------+";
test(query, expected);
// set sysstem timezone
std::env::set_var("TZ", "Asia/Shanghai");
// complex variables
let query = "/* mysql-connector-java-8.0.17 (Revision: 16a712ddb3f826a1933ab42b0039f7fb9eebc6ec) */SELECT @@session.auto_increment_increment AS auto_increment_increment, @@character_set_client AS character_set_client, @@character_set_connection AS character_set_connection, @@character_set_results AS character_set_results, @@character_set_server AS character_set_server, @@collation_server AS collation_server, @@collation_connection AS collation_connection, @@init_connect AS init_connect, @@interactive_timeout AS interactive_timeout, @@license AS license, @@lower_case_table_names AS lower_case_table_names, @@max_allowed_packet AS max_allowed_packet, @@net_write_timeout AS net_write_timeout, @@performance_schema AS performance_schema, @@sql_mode AS sql_mode, @@system_time_zone AS system_time_zone, @@time_zone AS time_zone, @@transaction_isolation AS transaction_isolation, @@wait_timeout AS wait_timeout;";
let expected = "\
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+
| auto_increment_increment | character_set_client | character_set_connection | character_set_results | character_set_server | collation_server | collation_connection | init_connect | interactive_timeout | license | lower_case_table_names | max_allowed_packet | net_write_timeout | performance_schema | sql_mode | system_time_zone | time_zone | transaction_isolation | wait_timeout; |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | UTC | UTC | REPEATABLE-READ | 31536000 |
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 31536000 | 0 | 0 | 134217728 | 31536000 | 0 | 0 | +08:00 | | REPEATABLE-READ | 31536000 |
+--------------------------+----------------------+--------------------------+-----------------------+----------------------+------------------+----------------------+--------------+---------------------+---------+------------------------+--------------------+-------------------+--------------------+----------+------------------+-----------+-----------------------+---------------+";
test(query, expected);
@@ -395,4 +417,31 @@ mod test {
+----------------------------------+";
test(query, expected);
}
#[test]
fn test_set_time_zone() {
let query_context = Arc::new(QueryContext::new());
let output = check("set time_zone = 'UTC'", query_context.clone());
match output.unwrap() {
Output::AffectedRows(rows) => {
assert_eq!(rows, 0)
}
_ => unreachable!(),
}
assert_eq!("UTC", query_context.time_zone().unwrap().to_string());
let output = check("select @@time_zone", query_context);
match output.unwrap() {
Output::RecordBatches(r) => {
let expected = "\
+-------------+
| @@time_zone |
+-------------+
| UTC |
+-------------+";
assert_eq!(r.pretty_print().unwrap(), expected);
}
_ => unreachable!(),
}
}
}

View File

@@ -231,7 +231,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
log::debug!("execute replaced query: {}", query);
let outputs = self.do_query(&query).await;
writer::write_output(w, &query, outputs).await?;
writer::write_output(w, &query, self.session.context(), outputs).await?;
Ok(())
}
@@ -263,7 +263,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
]
);
let outputs = self.do_query(query).await;
writer::write_output(writer, query, outputs).await?;
writer::write_output(writer, query, self.session.context(), outputs).await?;
Ok(())
}

View File

@@ -22,6 +22,7 @@ use datatypes::schema::{ColumnSchema, SchemaRef};
use opensrv_mysql::{
Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter,
};
use session::context::QueryContextRef;
use snafu::prelude::*;
use tokio::io::AsyncWrite;
@@ -31,9 +32,10 @@ use crate::error::{self, Error, Result};
pub async fn write_output<'a, W: AsyncWrite + Send + Sync + Unpin>(
w: QueryResultWriter<'a, W>,
query: &str,
query_context: QueryContextRef,
outputs: Vec<Result<Output>>,
) -> Result<()> {
let mut writer = Some(MysqlResultWriter::new(w));
let mut writer = Some(MysqlResultWriter::new(w, query_context.clone()));
for output in outputs {
let result_writer = writer.take().context(error::InternalSnafu {
err_msg: "Sending multiple result set is unsupported",
@@ -54,11 +56,18 @@ struct QueryResult {
pub struct MysqlResultWriter<'a, W: AsyncWrite + Unpin> {
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
}
impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
pub fn new(writer: QueryResultWriter<'a, W>) -> MysqlResultWriter<'a, W> {
MysqlResultWriter::<'a, W> { writer }
pub fn new(
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
) -> MysqlResultWriter<'a, W> {
MysqlResultWriter::<'a, W> {
writer,
query_context,
}
}
/// Try to write one result set. If there are more than one result set, return `Some`.
@@ -80,18 +89,23 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
recordbatches,
schema,
};
Self::write_query_result(query, query_result, self.writer).await?;
Self::write_query_result(query, query_result, self.writer, self.query_context)
.await?;
}
Output::RecordBatches(recordbatches) => {
let query_result = QueryResult {
schema: recordbatches.schema(),
recordbatches: recordbatches.take(),
};
Self::write_query_result(query, query_result, self.writer).await?;
Self::write_query_result(query, query_result, self.writer, self.query_context)
.await?;
}
Output::AffectedRows(rows) => {
let next_writer = Self::write_affected_rows(self.writer, rows).await?;
return Ok(Some(MysqlResultWriter::new(next_writer)));
return Ok(Some(MysqlResultWriter::new(
next_writer,
self.query_context,
)));
}
},
Err(error) => Self::write_query_error(query, error, self.writer).await?,
@@ -122,6 +136,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
query: &str,
query_result: QueryResult,
writer: QueryResultWriter<'a, W>,
query_context: QueryContextRef,
) -> Result<()> {
match create_mysql_column_def(&query_result.schema) {
Ok(column_def) => {
@@ -129,7 +144,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
// to return a new QueryResultWriter.
let mut row_writer = writer.start(&column_def).await?;
for recordbatch in &query_result.recordbatches {
Self::write_recordbatch(&mut row_writer, recordbatch).await?;
Self::write_recordbatch(&mut row_writer, recordbatch, query_context.clone())
.await?;
}
row_writer.finish().await?;
Ok(())
@@ -141,6 +157,7 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
async fn write_recordbatch(
row_writer: &mut RowWriter<'_, W>,
recordbatch: &RecordBatch,
query_context: QueryContextRef,
) -> Result<()> {
for row in recordbatch.rows() {
for value in row.into_iter() {
@@ -161,7 +178,8 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> {
Value::Binary(v) => row_writer.write_col(v.deref())?,
Value::Date(v) => row_writer.write_col(v.val())?,
Value::DateTime(v) => row_writer.write_col(v.val())?,
Value::Timestamp(v) => row_writer.write_col(v.to_local_string())?,
Value::Timestamp(v) => row_writer
.write_col(v.to_timezone_aware_string(query_context.time_zone()))?,
Value::List(_) => {
return Err(Error::Internal {
err_msg: format!(

View File

@@ -8,3 +8,4 @@ license.workspace = true
arc-swap = "1.5"
common-catalog = { path = "../common/catalog" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }

View File

@@ -20,6 +20,7 @@ use arc_swap::ArcSwap;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_telemetry::debug;
use common_time::TimeZone;
pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
@@ -28,6 +29,7 @@ pub type ConnInfoRef = Arc<ConnInfo>;
pub struct QueryContext {
current_catalog: ArcSwap<String>,
current_schema: ArcSwap<String>,
time_zone: ArcSwap<Option<TimeZone>>,
}
impl Default for QueryContext {
@@ -56,6 +58,7 @@ impl QueryContext {
Self {
current_catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.to_string())),
current_schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.to_string())),
time_zone: ArcSwap::new(Arc::new(None)),
}
}
@@ -63,6 +66,7 @@ impl QueryContext {
Self {
current_catalog: ArcSwap::new(Arc::new(catalog.to_string())),
current_schema: ArcSwap::new(Arc::new(schema.to_string())),
time_zone: ArcSwap::new(Arc::new(None)),
}
}
@@ -99,6 +103,16 @@ impl QueryContext {
let schema = self.current_schema();
build_db_string(&catalog, &schema)
}
#[inline]
pub fn time_zone(&self) -> Option<TimeZone> {
self.time_zone.load().as_ref().clone()
}
#[inline]
pub fn set_time_zone(&self, tz: Option<TimeZone>) {
self.time_zone.swap(Arc::new(tz));
}
}
pub const DEFAULT_USERNAME: &str = "greptime";