diff --git a/Cargo.lock b/Cargo.lock index a9493edf04..69f82d5986 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -640,6 +640,23 @@ dependencies = [ "arrow-select 57.0.0", ] +[[package]] +name = "arrow-pg" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87bc2eb53228ffb0cffff4a8a99d5311641b6d8ce63ec48b860dab70ec01ae1f" +dependencies = [ + "arrow 57.0.0", + "arrow-schema 57.0.0", + "bytes", + "chrono", + "futures", + "pg_interval_2", + "pgwire", + "postgres-types", + "rust_decimal", +] + [[package]] name = "arrow-row" version = "56.2.0" @@ -1553,9 +1570,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" dependencies = [ "serde", ] @@ -9629,10 +9646,10 @@ dependencies = [ ] [[package]] -name = "pg_interval" -version = "0.4.2" +name = "pg_interval_2" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe46640b465e284b048ef065cbed8ef17a622878d310c724578396b4cfd00df2" +checksum = "a055f44628dcf9c4e68f931535dabd3544a239655fdde25a3b0e95d4b36e9260" dependencies = [ "bytes", "chrono", @@ -9641,9 +9658,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.37.0" +version = "0.37.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d86d57e732d40382ceb9bfea80901d839bae8571aa11c06af9177aed9dfb6c" +checksum = "6fcd410bc6990bd8d20b3fe3cd879a3c3ec250bdb1cb12537b528818823b02c9" dependencies = [ "async-trait", "base64 0.22.1", @@ -9654,6 +9671,7 @@ dependencies = [ "hex", "lazy-regex", "md5", + "pg_interval_2", "postgres-types", "rand 0.9.1", "ring", @@ -11326,9 +11344,9 @@ dependencies = [ [[package]] name = "rkyv" -version = "0.7.45" +version = "0.7.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b" +checksum = "2297bf9c81a3f0dc96bc9521370b88f054168c29826a75e89c55ff196e7ed6a1" dependencies = [ "bitvec", "bytecheck", @@ -11344,9 +11362,9 @@ dependencies = [ [[package]] name = "rkyv_derive" -version = "0.7.45" +version = "0.7.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0" +checksum = "84d7b42d4b8d06048d3ac8db0eb31bcb942cbeb709f0b5f2b2ebde398d3038f5" dependencies = [ "proc-macro2", "quote", @@ -11604,9 +11622,9 @@ dependencies = [ [[package]] name = "rust_decimal" -version = "1.38.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8975fc98059f365204d635119cf9c5a60ae67b841ed49b5422a9a7e56cdfac0" +checksum = "61f703d19852dbf87cbc513643fa81428361eb6940f1ac14fd58155d295a3eb0" dependencies = [ "arrayvec", "borsh", @@ -12182,6 +12200,7 @@ dependencies = [ "arrow 57.0.0", "arrow-flight", "arrow-ipc 57.0.0", + "arrow-pg", "arrow-schema 57.0.0", "async-trait", "auth", @@ -12253,7 +12272,7 @@ dependencies = [ "otel-arrow-rust", "parking_lot 0.12.4", "permutation", - "pg_interval", + "pg_interval_2", "pgwire", "pin-project", "pipeline", @@ -12807,6 +12826,7 @@ dependencies = [ "memchr", "once_cell", "percent-encoding", + "rust_decimal", "rustls", "serde", "serde_json", @@ -12890,6 +12910,7 @@ dependencies = [ "percent-encoding", "rand 0.8.5", "rsa", + "rust_decimal", "serde", "sha1", "sha2", @@ -12928,6 +12949,7 @@ dependencies = [ "memchr", "once_cell", "rand 0.8.5", + "rust_decimal", "serde", "serde_json", "sha2", diff --git a/Cargo.toml b/Cargo.toml index d55fedd621..5dc943d2ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -225,7 +225,13 @@ similar-asserts = "1.6.0" smallvec = { version = "1", features = ["serde"] } snafu = "0.8" sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor", "serde"] } -sqlx = { version = "0.8", default-features = false, features = ["any", "macros", "json", "runtime-tokio-rustls"] } +sqlx = { version = "0.8", default-features = false, features = [ + "any", + "macros", + "json", + "runtime-tokio-rustls", + "rust_decimal", +] } strum = { version = "0.27", features = ["derive"] } sysinfo = "0.33" tempfile = "3" diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 54efd80369..890808d0f0 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -79,7 +79,8 @@ use table::table_name::TableName; use table::table_reference::TableReference; use self::set::{ - set_bytea_output, set_datestyle, set_search_path, set_timezone, validate_client_encoding, + set_bytea_output, set_datestyle, set_intervalstyle, set_search_path, set_timezone, + validate_client_encoding, }; use crate::error::{ self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu, @@ -483,6 +484,7 @@ impl StatementExecutor { // Not harmful since it only relates to how date is viewed in client app's output. // The tracked issue is https://github.com/GreptimeTeam/greptimedb/issues/3442. "DATESTYLE" => set_datestyle(set_var.value, query_ctx)?, + "INTERVALSTYLE" => set_intervalstyle(set_var.value, query_ctx)?, // Allow query to fallback when failed to push down. "ALLOW_QUERY_FALLBACK" => set_allow_query_fallback(set_var.value, query_ctx)?, diff --git a/src/operator/src/statement/set.rs b/src/operator/src/statement/set.rs index c4bf3758a0..38a16e1b05 100644 --- a/src/operator/src/statement/set.rs +++ b/src/operator/src/statement/set.rs @@ -21,7 +21,7 @@ use regex::Regex; use session::ReadPreference; use session::context::Channel::Postgres; use session::context::QueryContextRef; -use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; +use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle}; use snafu::{OptionExt, ResultExt, ensure}; use sql::ast::{Expr, Ident, Value}; use sql::statements::set_variables::SetVariables; @@ -279,6 +279,25 @@ pub fn set_allow_query_fallback(exprs: Vec, ctx: QueryContextRef) -> Resul } } +pub fn set_intervalstyle(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let Some((var_value, [])) = exprs.split_first() else { + return NotSupportedSnafu { + feat: "Set variable value must have one and only one value for intervalstyle", + } + .fail(); + }; + let Expr::Value(value) = var_value else { + return NotSupportedSnafu { + feat: "Set variable value must be a value", + } + .fail(); + }; + ctx.configuration_parameter().set_pg_intervalstyle_format( + PGIntervalStyle::try_from(&value.value).context(InvalidConfigValueSnafu)?, + ); + Ok(()) +} + pub fn set_datestyle(exprs: Vec, ctx: QueryContextRef) -> Result<()> { // ORDER, // STYLE, diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index b36f9e4df5..fb927d2901 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -742,6 +742,12 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result< let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); format!("{}, {}", style, order) } + "INTERVALSTYLE" => { + let style = *query_ctx + .configuration_parameter() + .pg_intervalstyle_format(); + style.to_string() + } "MAX_EXECUTION_TIME" => { if query_ctx.channel() == Channel::Mysql { query_ctx.query_timeout_as_millis().to_string() diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index cbef4bdad7..428a96e15b 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -23,6 +23,7 @@ api.workspace = true arrow.workspace = true arrow-flight.workspace = true arrow-ipc.workspace = true +arrow-pg = "0.11" arrow-schema.workspace = true async-trait.workspace = true auth.workspace = true @@ -87,8 +88,8 @@ opentelemetry-proto.workspace = true operator.workspace = true otel-arrow-rust.workspace = true parking_lot.workspace = true -pg_interval = "0.4" -pgwire = { version = "0.37", default-features = false, features = [ +pg_interval = { version = "0.5.2", package = "pg_interval_2" } +pgwire = { version = "0.37.3", default-features = false, features = [ "server-api-ring", "pg-ext-types", ] } diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index cdcb91bd13..92fb89af03 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -26,7 +26,6 @@ use common_error::ext::{BoxedError, ErrorExt}; use common_error::status_code::StatusCode; use common_macro::stack_trace_debug; use common_telemetry::{error, warn}; -use common_time::Duration; use datafusion::error::DataFusionError; use datatypes::prelude::ConcreteDataType; use headers::ContentType; @@ -640,9 +639,6 @@ pub enum Error { location: Location, }, - #[snafu(display("Overflow while casting `{:?}` to Interval", val))] - DurationOverflow { val: Duration }, - #[snafu(display("Failed to handle otel-arrow request, error message: {}", err_msg))] HandleOtelArrowRequest { err_msg: String, @@ -792,8 +788,6 @@ impl ErrorExt for Error { ConvertSqlValue { source, .. } => source.status_code(), - DurationOverflow { .. } => StatusCode::InvalidArguments, - HandleOtelArrowRequest { .. } => StatusCode::Internal, Cancelled { .. } => StatusCode::Cancelled, diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 7d428e5c45..5fb7281472 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -158,7 +158,10 @@ fn recordbatches_to_query_response( where S: Stream> + Send + Unpin + 'static, { - let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?); + let format_options = format_options_from_query_ctx(&query_ctx); + let pg_schema = Arc::new( + schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?, + ); let pg_schema_ref = pg_schema.clone(); let data_row_stream = recordbatches_stream .map(move |result| match result { @@ -405,7 +408,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { .collect::>(); if let Some(schema) = &sql_plan.schema { - schema_to_pg(schema, &Format::UnifiedBinary) + schema_to_pg(schema, &Format::UnifiedBinary, None) .map(|fields| DescribeStatementResponse::new(param_types, fields)) .map_err(convert_err) } else { @@ -438,7 +441,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { Some(Statement::Query(_)) => { // if the query has a schema, it is managed by datafusion, use the schema if let Some(schema) = &sql_plan.schema { - schema_to_pg(schema, format) + schema_to_pg(schema, format, None) .map(DescribePortalResponse::new) .map_err(convert_err) } else { diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 866e49738d..b11735015f 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -12,66 +12,62 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod bytea; -mod datetime; mod error; -mod interval; use std::collections::HashMap; use std::sync::Arc; -use arrow::array::{Array, ArrayRef, AsArray}; -use arrow::datatypes::{ - Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type, - Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, - Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, -}; -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; -use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; -use common_decimal::Decimal128; +use arrow::array::{Array, AsArray}; +use arrow_pg::encoder::encode_value; +use arrow_pg::list_encoder::encode_list; +use arrow_schema::{DataType, TimeUnit}; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime}; use common_recordbatch::RecordBatch; -use common_time::time::Time; -use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp}; +use common_time::{IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth}; use datafusion_common::ScalarValue; use datafusion_expr::LogicalPlan; use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::json::JsonStructureSettings; use datatypes::prelude::{ConcreteDataType, Value}; -use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; +use datatypes::schema::{Schema, SchemaRef}; use datatypes::types::{IntervalType, TimestampType, jsonb_to_string}; use datatypes::value::StructValue; +use pg_interval::Interval as PgInterval; use pgwire::api::Type; use pgwire::api::portal::{Format, Portal}; use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; +use pgwire::types::format::FormatOptions as PgFormatOptions; use session::context::QueryContextRef; -use session::session_config::PGByteaOutputValue; use snafu::ResultExt; -use self::bytea::{EscapeOutputBytea, HexOutputBytea}; -use self::datetime::{StylingDate, StylingDateTime}; pub use self::error::{PgErrorCode, PgErrorSeverity}; -use self::interval::PgInterval; use crate::SqlPlan; -use crate::error::{self as server_error, DataFusionSnafu, Error, Result}; +use crate::error::{self as server_error, DataFusionSnafu, Result}; use crate::postgres::utils::convert_err; -pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result> { +pub(super) fn schema_to_pg( + origin: &Schema, + field_formats: &Format, + format_options: Option>, +) -> Result> { origin .column_schemas() .iter() .enumerate() .map(|(idx, col)| { - Ok(FieldInfo::new( + let mut field_info = FieldInfo::new( col.name.clone(), None, None, type_gt_to_pg(&col.data_type)?, field_formats.format_for(idx), - )) + ); + if let Some(format_options) = &format_options { + field_info = field_info.with_format_options(format_options.clone()); + } + Ok(field_info) }) .collect::>>() } @@ -98,291 +94,6 @@ fn encode_struct( builder.encode_field(&json_value) } -fn encode_array( - query_ctx: &QueryContextRef, - array: ArrayRef, - builder: &mut DataRowEncoder, -) -> PgWireResult<()> { - macro_rules! encode_primitive_array { - ($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{ - let array = $array.iter().collect::>>(); - if array - .iter() - .all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type)) - { - builder.encode_field( - &array - .into_iter() - .map(|x| x.map(|i| i as $lower_type)) - .collect::>>(), - ) - } else { - builder.encode_field( - &array - .into_iter() - .map(|x| x.map(|i| i as $upper_type)) - .collect::>>(), - ) - } - }}; - } - - match array.data_type() { - DataType::Boolean => { - let array = array.as_boolean(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Int8 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Int16 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Int32 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Int64 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::UInt8 => { - let array = array.as_primitive::(); - encode_primitive_array!(array, u8, i8, i16) - } - DataType::UInt16 => { - let array = array.as_primitive::(); - encode_primitive_array!(array, u16, i16, i32) - } - DataType::UInt32 => { - let array = array.as_primitive::(); - encode_primitive_array!(array, u32, i32, i64) - } - DataType::UInt64 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) { - builder.encode_field( - &array - .into_iter() - .map(|x| x.map(|i| i as i64)) - .collect::>>(), - ) - } else { - builder.encode_field( - &array - .into_iter() - .map(|x| x.map(|i| i.to_string())) - .collect::>(), - ) - } - } - DataType::Float32 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Float64 => { - let array = array.as_primitive::(); - let array = array.iter().collect::>(); - builder.encode_field(&array) - } - DataType::Binary => { - let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output(); - - let array = array.as_binary::(); - match *bytea_output { - PGByteaOutputValue::ESCAPE => { - let array = array - .iter() - .map(|v| v.map(EscapeOutputBytea)) - .collect::>(); - builder.encode_field(&array) - } - PGByteaOutputValue::HEX => { - let array = array - .iter() - .map(|v| v.map(HexOutputBytea)) - .collect::>(); - builder.encode_field(&array) - } - } - } - DataType::Utf8 => { - let array = array.as_string::(); - let array = array.into_iter().collect::>(); - builder.encode_field(&array) - } - DataType::LargeUtf8 => { - let array = array.as_string::(); - let array = array.into_iter().collect::>(); - builder.encode_field(&array) - } - DataType::Utf8View => { - let array = array.as_string_view(); - let array = array.into_iter().collect::>(); - builder.encode_field(&array) - } - DataType::Date32 | DataType::Date64 => { - let iter: Box>> = - if matches!(array.data_type(), DataType::Date32) { - let array = array.as_primitive::(); - Box::new(array.into_iter()) - } else { - let array = array.as_primitive::(); - // `Date64` values are milliseconds representation of `Date32` values, according - // to its specification. So we convert them to `Date32` values to process the - // `Date64` array unified with `Date32` array. - Box::new( - array - .into_iter() - .map(|x| x.map(|i| (i / 86_400_000) as i32)), - ) - }; - let array = iter - .into_iter() - .map(|v| match v { - None => Ok(None), - Some(v) => { - if let Some(date) = Date::new(v).to_chrono_date() { - let (style, order) = - *query_ctx.configuration_parameter().pg_datetime_style(); - Ok(Some(StylingDate(date, style, order))) - } else { - Err(convert_err(Error::Internal { - err_msg: format!("Failed to convert date to postgres type {v:?}",), - })) - } - } - }) - .collect::>>>()?; - builder.encode_field(&array) - } - DataType::Timestamp(time_unit, _) => { - let array = match time_unit { - TimeUnit::Second => { - let array = array.as_primitive::(); - array.into_iter().collect::>() - } - TimeUnit::Millisecond => { - let array = array.as_primitive::(); - array.into_iter().collect::>() - } - TimeUnit::Microsecond => { - let array = array.as_primitive::(); - array.into_iter().collect::>() - } - TimeUnit::Nanosecond => { - let array = array.as_primitive::(); - array.into_iter().collect::>() - } - }; - let time_unit = time_unit.into(); - let array = array - .into_iter() - .map(|v| match v { - None => Ok(None), - Some(v) => { - let v = Timestamp::new(v, time_unit); - if let Some(datetime) = - v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone())) - { - let (style, order) = - *query_ctx.configuration_parameter().pg_datetime_style(); - Ok(Some(StylingDateTime(datetime, style, order))) - } else { - Err(convert_err(Error::Internal { - err_msg: format!("Failed to convert date to postgres type {v:?}",), - })) - } - } - }) - .collect::>>>()?; - builder.encode_field(&array) - } - DataType::Time32(time_unit) | DataType::Time64(time_unit) => { - let iter: Box>> = match time_unit { - TimeUnit::Second => { - let array = array.as_primitive::(); - Box::new( - array - .into_iter() - .map(|v| v.map(|i| Time::new_second(i as i64))), - ) - } - TimeUnit::Millisecond => { - let array = array.as_primitive::(); - Box::new( - array - .into_iter() - .map(|v| v.map(|i| Time::new_millisecond(i as i64))), - ) - } - TimeUnit::Microsecond => { - let array = array.as_primitive::(); - Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond))) - } - TimeUnit::Nanosecond => { - let array = array.as_primitive::(); - Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond))) - } - }; - let array = iter - .into_iter() - .map(|v| v.and_then(|v| v.to_chrono_time())) - .collect::>>(); - builder.encode_field(&array) - } - DataType::Interval(interval_unit) => { - let array = match interval_unit { - IntervalUnit::YearMonth => { - let array = array.as_primitive::(); - array - .into_iter() - .map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i)))) - .collect::>() - } - IntervalUnit::DayTime => { - let array = array.as_primitive::(); - array - .into_iter() - .map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i)))) - .collect::>() - } - IntervalUnit::MonthDayNano => { - let array = array.as_primitive::(); - array - .into_iter() - .map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i)))) - .collect::>() - } - }; - builder.encode_field(&array) - } - DataType::Decimal128(precision, scale) => { - let array = array.as_primitive::(); - let array = array - .into_iter() - .map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string())) - .collect::>(); - builder.encode_field(&array) - } - _ => Err(convert_err(Error::Internal { - err_msg: format!( - "cannot write array type {:?} in postgres protocol: unimplemented", - array.data_type() - ), - })), - } -} - pub(crate) struct RecordBatchRowIterator { query_ctx: QueryContextRef, pg_schema: Arc>, @@ -426,175 +137,42 @@ impl RecordBatchRowIterator { } fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> { + let arrow_schema = self.record_batch.schema(); for (j, column) in self.record_batch.columns().iter().enumerate() { if column.is_null(i) { encoder.encode_field(&None::<&i8>)?; continue; } + let pg_field = &self.pg_schema[j]; match column.data_type() { - DataType::Null => { - encoder.encode_field(&None::<&i8>)?; - } - DataType::Boolean => { - let array = column.as_boolean(); - encoder.encode_field(&array.value(i))?; - } - DataType::UInt8 => { - let array = column.as_primitive::(); - let value = array.value(i); - if value <= i8::MAX as u8 { - encoder.encode_field(&(value as i8))?; - } else { - encoder.encode_field(&(value as i16))?; - } - } - DataType::UInt16 => { - let array = column.as_primitive::(); - let value = array.value(i); - if value <= i16::MAX as u16 { - encoder.encode_field(&(value as i16))?; - } else { - encoder.encode_field(&(value as i32))?; - } - } - DataType::UInt32 => { - let array = column.as_primitive::(); - let value = array.value(i); - if value <= i32::MAX as u32 { - encoder.encode_field(&(value as i32))?; - } else { - encoder.encode_field(&(value as i64))?; - } - } - DataType::UInt64 => { - let array = column.as_primitive::(); - let value = array.value(i); - if value <= i64::MAX as u64 { - encoder.encode_field(&(value as i64))?; - } else { - encoder.encode_field(&value.to_string())?; - } - } - DataType::Int8 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Int16 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Int32 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Int64 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Float32 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Float64 => { - let array = column.as_primitive::(); - encoder.encode_field(&array.value(i))?; - } - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { - let value = datatypes::arrow_array::string_array_value(column, i); - encoder.encode_field(&value)?; - } + // these types are greptimedb specific or custom DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { - let v = datatypes::arrow_array::binary_array_value(column, i); - encode_bytes( - &self.schema.column_schemas()[j], - v, - encoder, - &self.query_ctx, - )?; - } - DataType::Date32 | DataType::Date64 => { - let v = if matches!(column.data_type(), DataType::Date32) { - let array = column.as_primitive::(); - array.value(i) + // jsonb + if let ConcreteDataType::Json(_) = &self.schema.column_schemas()[j].data_type { + let v = datatypes::arrow_array::binary_array_value(column, i); + let s = jsonb_to_string(v).map_err(convert_err)?; + encoder.encode_field(&s)?; } else { - let array = column.as_primitive::(); - // `Date64` values are milliseconds representation of `Date32` values, - // according to its specification. So we convert the `Date64` value here to - // the `Date32` value to process them unified. - (array.value(i) / 86_400_000) as i32 - }; - let v = Date::new(v); - let date = v.to_chrono_date().map(|v| { - let (style, order) = - *self.query_ctx.configuration_parameter().pg_datetime_style(); - StylingDate(v, style, order) - }); - encoder.encode_field(&date)?; - } - DataType::Timestamp(_, _) => { - let v = datatypes::arrow_array::timestamp_array_value(column, i); - let datetime = v - .to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone())) - .map(|v| { - let (style, order) = - *self.query_ctx.configuration_parameter().pg_datetime_style(); - StylingDateTime(v, style, order) - }); - encoder.encode_field(&datetime)?; - } - DataType::Interval(interval_unit) => match interval_unit { - IntervalUnit::YearMonth => { - let array = column.as_primitive::(); - let v: IntervalYearMonth = array.value(i).into(); - encoder.encode_field(&PgInterval::from(v))?; - } - IntervalUnit::DayTime => { - let array = column.as_primitive::(); - let v: IntervalDayTime = array.value(i).into(); - encoder.encode_field(&PgInterval::from(v))?; - } - IntervalUnit::MonthDayNano => { - let array = column.as_primitive::(); - let v: IntervalMonthDayNano = array.value(i).into(); - encoder.encode_field(&PgInterval::from(v))?; - } - }, - DataType::Duration(_) => { - let d = datatypes::arrow_array::duration_array_value(column, i); - match PgInterval::try_from(d) { - Ok(i) => encoder.encode_field(&i)?, - Err(e) => { - return Err(convert_err(Error::Internal { - err_msg: e.to_string(), - })); - } + // bytea + let arrow_field = arrow_schema.field(j); + encode_value(encoder, column, i, arrow_field, pg_field)?; } } + DataType::List(_) => { let array = column.as_list::(); let items = array.value(i); - encode_array(&self.query_ctx, items, encoder)?; + + encode_list(encoder, items, pg_field)?; } DataType::Struct(_) => { encode_struct(&self.query_ctx, Default::default(), encoder)?; } - DataType::Time32(_) | DataType::Time64(_) => { - let v = datatypes::arrow_array::time_array_value(column, i); - encoder.encode_field(&v.to_chrono_time())?; - } - DataType::Decimal128(precision, scale) => { - let array = column.as_primitive::(); - let v = Decimal128::new(array.value(i), *precision, *scale); - encoder.encode_field(&v.to_string())?; - } _ => { - return Err(convert_err(Error::Internal { - err_msg: format!( - "cannot convert datatype {} to postgres", - column.data_type() - ), - })); + // Encode value using arrow-pg + let arrow_field = arrow_schema.field(j); + encode_value(encoder, column, i, arrow_field, pg_field)?; } } } @@ -602,32 +180,15 @@ impl RecordBatchRowIterator { } } -fn encode_bytes( - schema: &ColumnSchema, - v: &[u8], - encoder: &mut DataRowEncoder, - query_ctx: &QueryContextRef, -) -> PgWireResult<()> { - if let ConcreteDataType::Json(_) = &schema.data_type { - let s = jsonb_to_string(v).map_err(convert_err)?; - encoder.encode_field(&s) - } else { - let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output(); - match *bytea_output { - PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)), - PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)), - } - } -} - pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result { match origin { &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN), &ConcreteDataType::Boolean(_) => Ok(Type::BOOL), - &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR), - &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2), - &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4), - &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8), + &ConcreteDataType::Int8(_) => Ok(Type::CHAR), + &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt8(_) => Ok(Type::INT2), + &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT4), + &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT8), + &ConcreteDataType::UInt64(_) => Ok(Type::NUMERIC), &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4), &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8), &ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA), @@ -641,10 +202,11 @@ pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result { ConcreteDataType::List(list) => match list.item_type() { &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN), &ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY), - &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY), - &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY), - &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY), - &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY), + &ConcreteDataType::Int8(_) => Ok(Type::CHAR_ARRAY), + &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt8(_) => Ok(Type::INT2_ARRAY), + &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT4_ARRAY), + &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT8_ARRAY), + &ConcreteDataType::UInt64(_) => Ok(Type::NUMERIC_ARRAY), &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY), &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY), &ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY), @@ -762,7 +324,7 @@ pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWir .unwrap_or_else(|| "".to_owned())), &Type::INTERVAL => Ok(portal .parameter::(idx, param_type)? - .map(|v| v.to_string()) + .map(|v| v.to_sql()) .unwrap_or_else(|| "".to_owned())), _ => Err(invalid_parameter_error( "unsupported_parameter_type", @@ -1133,9 +695,14 @@ pub(super) fn parameters_to_scalar_values( ) } ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => { - ScalarValue::IntervalMonthDayNano( - data.map(|i| IntervalMonthDayNano::from(i).into()), - ) + ScalarValue::IntervalMonthDayNano(data.map(|i| { + IntervalMonthDayNano::new( + i.months, + i.days, + i.microseconds * 1_000i64, + ) + .into() + })) } _ => { return Err(invalid_parameter_error( @@ -1145,9 +712,10 @@ pub(super) fn parameters_to_scalar_values( } } } else { - ScalarValue::IntervalMonthDayNano( - data.map(|i| IntervalMonthDayNano::from(i).into()), - ) + ScalarValue::IntervalMonthDayNano(data.map(|i| { + IntervalMonthDayNano::new(i.months, i.days, i.microseconds * 1_000i64) + .into() + })) } } &Type::BYTEA => { @@ -1454,6 +1022,19 @@ pub(super) fn param_types_to_pg_types( Ok(types) } +pub fn format_options_from_query_ctx(query_ctx: &QueryContextRef) -> Arc { + let config = query_ctx.configuration_parameter(); + let (date_style, date_order) = *config.pg_datetime_style(); + + let mut format_options = PgFormatOptions::default(); + format_options.date_style = format!("{}, {}", date_style, date_order); + format_options.interval_style = config.pg_intervalstyle_format().to_string(); + format_options.bytea_output = config.postgres_bytea_output().to_string(); + format_options.time_zone = query_ctx.timezone().to_string(); + + Arc::new(format_options) +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -1461,7 +1042,7 @@ mod test { use arrow::array::{ Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder, }; - use arrow_schema::Field; + use arrow_schema::{Field, IntervalUnit}; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{ BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector, @@ -1512,10 +1093,16 @@ mod test { FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("uint16s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("uint32s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new( + "uint64s".into(), + None, + None, + Type::NUMERIC, + FieldFormat::Text, + ), FieldInfo::new( "float32s".into(), None, @@ -1562,7 +1149,7 @@ mod test { ), ]; let schema = Schema::new(column_schemas); - let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap(); + let fs = schema_to_pg(&schema, &Format::UnifiedText, None).unwrap(); assert_eq!(fs, pg_field_info); } @@ -1571,10 +1158,16 @@ mod test { let schema = vec![ FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text), - FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("uint16s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("uint32s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new( + "uint64s".into(), + None, + None, + Type::NUMERIC, + FieldFormat::Text, + ), FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), diff --git a/src/servers/src/postgres/types/bytea.rs b/src/servers/src/postgres/types/bytea.rs deleted file mode 100644 index 7b8b42f754..0000000000 --- a/src/servers/src/postgres/types/bytea.rs +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use bytes::BufMut; -use pgwire::types::ToSqlText; -use pgwire::types::format::FormatOptions; -use postgres_types::{IsNull, ToSql, Type}; - -#[derive(Debug)] -pub struct HexOutputBytea<'a>(pub &'a [u8]); -impl ToSqlText for HexOutputBytea<'_> { - fn to_sql_text( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - format_options: &FormatOptions, - ) -> std::result::Result> - where - Self: Sized, - { - let _ = self.0.to_sql_text(ty, out, format_options); - Ok(IsNull::No) - } -} - -impl ToSql for HexOutputBytea<'_> { - fn to_sql( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> - where - Self: Sized, - { - self.0.to_sql(ty, out) - } - - fn accepts(ty: &Type) -> bool - where - Self: Sized, - { - <&[u8] as ToSql>::accepts(ty) - } - - fn to_sql_checked( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> { - self.0.to_sql_checked(ty, out) - } -} -#[derive(Debug)] -pub struct EscapeOutputBytea<'a>(pub &'a [u8]); -impl ToSqlText for EscapeOutputBytea<'_> { - fn to_sql_text( - &self, - _ty: &Type, - out: &mut bytes::BytesMut, - _format_options: &FormatOptions, - ) -> std::result::Result> - where - Self: Sized, - { - self.0.iter().for_each(|b| match b { - 0..=31 | 127..=255 => { - out.put_slice(b"\\"); - out.put_slice(format!("{:03o}", b).as_bytes()); - } - 92 => out.put_slice(b"\\\\"), - 32..=126 => out.put_u8(*b), - }); - Ok(IsNull::No) - } -} -impl ToSql for EscapeOutputBytea<'_> { - fn to_sql( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> - where - Self: Sized, - { - self.0.to_sql(ty, out) - } - - fn accepts(ty: &Type) -> bool - where - Self: Sized, - { - <&[u8] as ToSql>::accepts(ty) - } - - fn to_sql_checked( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> { - self.0.to_sql_checked(ty, out) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_escape_output_bytea() { - let input: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84]; - let input = EscapeOutputBytea(input); - - let expected = b"abcklm*\\251T"; - let mut out = bytes::BytesMut::new(); - let is_null = input - .to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(&out[..], expected); - - let expected = &[97, 98, 99, 107, 108, 109, 42, 169, 84]; - let mut out = bytes::BytesMut::new(); - let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(&out[..], expected); - } - - #[test] - fn test_hex_output_bytea() { - let input = b"hello, world!"; - let input = HexOutputBytea(input); - - let expected = b"\\x68656c6c6f2c20776f726c6421"; - let mut out = bytes::BytesMut::new(); - let is_null = input - .to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(&out[..], expected); - - let expected = b"hello, world!"; - let mut out = bytes::BytesMut::new(); - let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(&out[..], expected); - } -} diff --git a/src/servers/src/postgres/types/datetime.rs b/src/servers/src/postgres/types/datetime.rs deleted file mode 100644 index 5fdd87decf..0000000000 --- a/src/servers/src/postgres/types/datetime.rs +++ /dev/null @@ -1,430 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use bytes::BufMut; -use chrono::{NaiveDate, NaiveDateTime}; -use pgwire::types::ToSqlText; -use pgwire::types::format::FormatOptions; -use postgres_types::{IsNull, ToSql, Type}; -use session::session_config::{PGDateOrder, PGDateTimeStyle}; - -#[derive(Debug)] -pub struct StylingDate(pub NaiveDate, pub PGDateTimeStyle, pub PGDateOrder); - -#[derive(Debug)] -pub struct StylingDateTime(pub NaiveDateTime, pub PGDateTimeStyle, pub PGDateOrder); - -fn date_format_string(style: PGDateTimeStyle, order: PGDateOrder) -> &'static str { - match style { - PGDateTimeStyle::ISO => "%Y-%m-%d", - PGDateTimeStyle::German => "%d.%m.%Y", - PGDateTimeStyle::Postgres => match order { - PGDateOrder::MDY | PGDateOrder::YMD => "%m-%d-%Y", - PGDateOrder::DMY => "%d-%m-%Y", - }, - PGDateTimeStyle::SQL => match order { - PGDateOrder::MDY | PGDateOrder::YMD => "%m/%d/%Y", - PGDateOrder::DMY => "%d/%m/%Y", - }, - } -} - -fn datetime_format_string(style: PGDateTimeStyle, order: PGDateOrder) -> &'static str { - match style { - PGDateTimeStyle::ISO => "%Y-%m-%d %H:%M:%S%.6f", - PGDateTimeStyle::German => "%d.%m.%Y %H:%M:%S%.6f", - PGDateTimeStyle::Postgres => match order { - PGDateOrder::MDY | PGDateOrder::YMD => "%a %b %d %H:%M:%S%.6f %Y", - PGDateOrder::DMY => "%a %d %b %H:%M:%S%.6f %Y", - }, - PGDateTimeStyle::SQL => match order { - PGDateOrder::MDY | PGDateOrder::YMD => "%m/%d/%Y %H:%M:%S%.6f", - PGDateOrder::DMY => "%d/%m/%Y %H:%M:%S%.6f", - }, - } -} -impl ToSqlText for StylingDate { - fn to_sql_text( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - format_options: &FormatOptions, - ) -> std::result::Result> - where - Self: Sized, - { - match *ty { - Type::DATE => { - let fmt = self - .0 - .format(date_format_string(self.1, self.2)) - .to_string(); - out.put_slice(fmt.as_bytes()); - } - _ => { - self.0.to_sql_text(ty, out, format_options)?; - } - } - Ok(IsNull::No) - } -} - -impl ToSqlText for StylingDateTime { - fn to_sql_text( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - format_options: &FormatOptions, - ) -> Result> - where - Self: Sized, - { - match *ty { - Type::TIMESTAMP => { - let fmt = self - .0 - .format(datetime_format_string(self.1, self.2)) - .to_string(); - out.put_slice(fmt.as_bytes()); - } - Type::DATE => { - let fmt = self - .0 - .format(date_format_string(self.1, self.2)) - .to_string(); - out.put_slice(fmt.as_bytes()); - } - _ => { - self.0.to_sql_text(ty, out, format_options)?; - } - } - Ok(IsNull::No) - } -} - -macro_rules! delegate_to_sql { - ($delegator:ident, $delegatee:ident) => { - impl ToSql for $delegator { - fn to_sql( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> { - self.0.to_sql(ty, out) - } - - fn accepts(ty: &Type) -> bool { - <$delegatee as ToSql>::accepts(ty) - } - - fn to_sql_checked( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - ) -> Result> { - self.0.to_sql_checked(ty, out) - } - } - }; -} - -delegate_to_sql!(StylingDate, NaiveDate); -delegate_to_sql!(StylingDateTime, NaiveDateTime); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_styling_date() { - let naive_date = NaiveDate::from_ymd_opt(1997, 12, 17).unwrap(); - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::MDY); - let expected = "1997-12-17"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::YMD); - let expected = "1997-12-17"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::DMY); - let expected = "1997-12-17"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::MDY); - let expected = "17.12.1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::YMD); - let expected = "17.12.1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::DMY); - let expected = "17.12.1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::MDY); - let expected = "12-17-1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::YMD); - let expected = "12-17-1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::DMY); - let expected = "17-12-1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::MDY); - let expected = "12/17/1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::YMD); - let expected = "12/17/1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::DMY); - let expected = "17/12/1997"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_date - .to_sql_text(&Type::DATE, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - } - - #[test] - fn test_styling_datetime() { - let input = - NaiveDateTime::parse_from_str("2021-09-01 12:34:56.789012", "%Y-%m-%d %H:%M:%S%.f") - .unwrap(); - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::MDY); - let expected = "2021-09-01 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::YMD); - let expected = "2021-09-01 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::DMY); - let expected = "2021-09-01 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::MDY); - let expected = "01.09.2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::YMD); - let expected = "01.09.2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::DMY); - let expected = "01.09.2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::MDY); - let expected = "Wed Sep 01 12:34:56.789012 2021"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::YMD); - let expected = "Wed Sep 01 12:34:56.789012 2021"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = - StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::DMY); - let expected = "Wed 01 Sep 12:34:56.789012 2021"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::MDY); - let expected = "09/01/2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::YMD); - let expected = "09/01/2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - - { - let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::DMY); - let expected = "01/09/2021 12:34:56.789012"; - let mut out = bytes::BytesMut::new(); - let is_null = styling_datetime - .to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default()) - .unwrap(); - assert!(matches!(is_null, IsNull::No)); - assert_eq!(out, expected.as_bytes()); - } - } -} diff --git a/src/servers/src/postgres/types/interval.rs b/src/servers/src/postgres/types/interval.rs deleted file mode 100644 index 2734d449b0..0000000000 --- a/src/servers/src/postgres/types/interval.rs +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::fmt::Display; - -use bytes::{Buf, BufMut}; -use common_time::interval::IntervalFormat; -use common_time::timestamp::TimeUnit; -use common_time::{Duration, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth}; -use pgwire::types::format::FormatOptions; -use pgwire::types::{FromSqlText, ToSqlText}; -use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked}; - -use crate::error; - -/// On average one month has 30.44 day, which is a common approximation. -const SECONDS_PER_MONTH: i64 = 24 * 6 * 6 * 3044; -const SECONDS_PER_DAY: i64 = 24 * 60 * 60; -const MILLISECONDS_PER_MONTH: i64 = SECONDS_PER_MONTH * 1000; -const MILLISECONDS_PER_DAY: i64 = SECONDS_PER_DAY * 1000; - -#[derive(Debug, Clone, Copy, Default)] -pub struct PgInterval { - pub(crate) months: i32, - pub(crate) days: i32, - pub(crate) microseconds: i64, -} - -impl From for PgInterval { - fn from(interval: IntervalYearMonth) -> Self { - Self { - months: interval.months, - days: 0, - microseconds: 0, - } - } -} - -impl From for PgInterval { - fn from(interval: IntervalDayTime) -> Self { - Self { - months: 0, - days: interval.days, - microseconds: interval.milliseconds as i64 * 1000, - } - } -} - -impl From for PgInterval { - fn from(interval: IntervalMonthDayNano) -> Self { - Self { - months: interval.months, - days: interval.days, - microseconds: interval.nanoseconds / 1000, - } - } -} - -impl TryFrom for PgInterval { - type Error = error::Error; - - fn try_from(duration: Duration) -> error::Result { - let value = duration.value(); - let unit = duration.unit(); - - // Convert the duration to microseconds - match unit { - TimeUnit::Second => { - let months = i32::try_from(value / SECONDS_PER_MONTH) - .map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?; - let days = - i32::try_from((value - (months as i64) * SECONDS_PER_MONTH) / SECONDS_PER_DAY) - .map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?; - let microseconds = - (value - (months as i64) * SECONDS_PER_MONTH - (days as i64) * SECONDS_PER_DAY) - .checked_mul(1_000_000) - .ok_or(error::DurationOverflowSnafu { val: duration }.build())?; - - Ok(Self { - months, - days, - microseconds, - }) - } - TimeUnit::Millisecond => { - let months = i32::try_from(value / MILLISECONDS_PER_MONTH) - .map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?; - let days = i32::try_from( - (value - (months as i64) * MILLISECONDS_PER_MONTH) / MILLISECONDS_PER_DAY, - ) - .map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?; - let microseconds = ((value - (months as i64) * MILLISECONDS_PER_MONTH) - - (days as i64) * MILLISECONDS_PER_DAY) - * 1_000; - Ok(Self { - months, - days, - microseconds, - }) - } - TimeUnit::Microsecond => Ok(Self { - months: 0, - days: 0, - microseconds: value, - }), - TimeUnit::Nanosecond => Ok(Self { - months: 0, - days: 0, - microseconds: value / 1000, - }), - } - } -} - -impl From for IntervalMonthDayNano { - fn from(interval: PgInterval) -> Self { - IntervalMonthDayNano::new( - interval.months, - interval.days, - // Maybe overflow, but most scenarios ok. - interval.microseconds.checked_mul(1000).unwrap_or_else(|| { - if interval.microseconds.is_negative() { - i64::MIN - } else { - i64::MAX - } - }), - ) - } -} - -impl Display for PgInterval { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - IntervalFormat::from(IntervalMonthDayNano::from(*self)).to_postgres_string() - ) - } -} - -impl ToSql for PgInterval { - to_sql_checked!(); - - fn to_sql( - &self, - _: &Type, - out: &mut bytes::BytesMut, - ) -> std::result::Result> - where - Self: Sized, - { - // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L989-L991 - out.put_i64(self.microseconds); - out.put_i32(self.days); - out.put_i32(self.months); - Ok(postgres_types::IsNull::No) - } - - fn accepts(ty: &Type) -> bool - where - Self: Sized, - { - matches!(ty, &Type::INTERVAL) - } -} - -impl<'a> FromSql<'a> for PgInterval { - fn from_sql( - _: &Type, - mut raw: &'a [u8], - ) -> std::result::Result> { - // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1007-L1010 - let microseconds = raw.get_i64(); - let days = raw.get_i32(); - let months = raw.get_i32(); - Ok(PgInterval { - months, - days, - microseconds, - }) - } - - fn accepts(ty: &Type) -> bool { - matches!(ty, &Type::INTERVAL) - } -} - -impl ToSqlText for PgInterval { - fn to_sql_text( - &self, - ty: &Type, - out: &mut bytes::BytesMut, - _format_options: &FormatOptions, - ) -> std::result::Result> - where - Self: Sized, - { - let fmt = match ty { - &Type::INTERVAL => self.to_string(), - _ => return Err("unsupported type".into()), - }; - - out.put_slice(fmt.as_bytes()); - Ok(IsNull::No) - } -} - -impl<'a> FromSqlText<'a> for PgInterval { - fn from_sql_text( - _ty: &Type, - input: &[u8], - _format_options: &FormatOptions, - ) -> std::result::Result> - where - Self: Sized, - { - // only support parsing interval from postgres format - if let Ok(interval) = pg_interval::Interval::from_postgres(str::from_utf8(input)?) { - Ok(PgInterval { - months: interval.months, - days: interval.days, - microseconds: interval.microseconds, - }) - } else { - Err("invalid interval format".into()) - } - } -} - -#[cfg(test)] -mod tests { - use common_time::Duration; - use common_time::timestamp::TimeUnit; - - use super::*; - - #[test] - fn test_duration_to_pg_interval() { - // Test with seconds - let duration = Duration::new(86400, TimeUnit::Second); // 1 day - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 1); - assert_eq!(interval.microseconds, 0); - - // Test with milliseconds - let duration = Duration::new(86400000, TimeUnit::Millisecond); // 1 day - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 1); - assert_eq!(interval.microseconds, 0); - - // Test with microseconds - let duration = Duration::new(86400000000, TimeUnit::Microsecond); // 1 day - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 0); - assert_eq!(interval.microseconds, 86400000000); - - // Test with nanoseconds - let duration = Duration::new(86400000000000, TimeUnit::Nanosecond); // 1 day - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 0); - assert_eq!(interval.microseconds, 86400000000); - - // Test with partial day - let duration = Duration::new(43200, TimeUnit::Second); // 12 hours - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 0); - assert_eq!(interval.microseconds, 43_200_000_000); // 12 hours in microseconds - - // Test with negative duration - let duration = Duration::new(-86400, TimeUnit::Second); // -1 day - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, -1); - assert_eq!(interval.microseconds, 0); - - // Test with multiple days - let duration = Duration::new(259200, TimeUnit::Second); // 3 days - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 3); - assert_eq!(interval.microseconds, 0); - - // Test with small duration (less than a day) - let duration = Duration::new(3600, TimeUnit::Second); // 1 hour - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 0); - assert_eq!(interval.microseconds, 3600000000); // 1 hour in microseconds - - // Test with very small duration - let duration = Duration::new(1, TimeUnit::Microsecond); // 1 microsecond - let interval = PgInterval::try_from(duration).unwrap(); - assert_eq!(interval.months, 0); - assert_eq!(interval.days, 0); - assert_eq!(interval.microseconds, 1); - - let duration = Duration::new(i64::MAX, TimeUnit::Second); - assert!(PgInterval::try_from(duration).is_err()); - - let duration = Duration::new(i64::MAX, TimeUnit::Millisecond); - assert!(PgInterval::try_from(duration).is_err()); - } -} diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 3d73d84e27..311e630e0e 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -355,8 +355,8 @@ async fn test_extended_query() -> Result<()> { let rows = client.query(&stmt, &[&1i32]).await.unwrap(); assert_eq!(rows.len(), 1); assert_eq!(rows[0].len(), 2); - assert_eq!(rows[0].get::(0usize), 1); - assert_eq!(rows[0].get::<&str, i32>("uint32s"), 1); + assert_eq!(rows[0].get::(0usize), 1); + assert_eq!(rows[0].get::<&str, i64>("uint32s"), 1); assert_eq!(rows[0].get::(1usize), 2); assert_eq!(rows[0].get::<&str, i64>("numbers.uint32s + Int64(1)"), 2); diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 7f83bfd509..2b9483aca8 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -33,7 +33,7 @@ use derive_builder::Builder; use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; use crate::protocol_ctx::ProtocolCtx; -use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; +use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle}; use crate::{MutableInner, ReadPreference}; pub type QueryContextRef = Arc; @@ -614,6 +614,7 @@ impl AsRef for Channel { pub struct ConfigurationVariables { postgres_bytea_output: ArcSwap, pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>, + pg_intervalstyle_format: ArcSwap, allow_query_fallback: ArcSwap, } @@ -622,6 +623,7 @@ impl Clone for ConfigurationVariables { Self { postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()), pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()), + pg_intervalstyle_format: ArcSwap::new(self.pg_intervalstyle_format.load().clone()), allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()), } } @@ -648,6 +650,14 @@ impl ConfigurationVariables { self.pg_datestyle_format.swap(Arc::new((style, order))); } + pub fn pg_intervalstyle_format(&self) -> Arc { + self.pg_intervalstyle_format.load().clone() + } + + pub fn set_pg_intervalstyle_format(&self, value: PGIntervalStyle) { + self.pg_intervalstyle_format.swap(Arc::new(value)); + } + pub fn allow_query_fallback(&self) -> bool { **self.allow_query_fallback.load() } diff --git a/src/session/src/session_config.rs b/src/session/src/session_config.rs index 8b93ce2a2c..cc13d47f44 100644 --- a/src/session/src/session_config.rs +++ b/src/session/src/session_config.rs @@ -66,6 +66,15 @@ impl TryFrom for PGByteaOutputValue { } } +impl Display for PGByteaOutputValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PGByteaOutputValue::HEX => write!(f, "hex"), + PGByteaOutputValue::ESCAPE => write!(f, "escape"), + } + } +} + // Refers to: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-DATESTYLE #[derive(Default, PartialEq, Eq, Clone, Copy, Debug)] pub enum PGDateOrder { @@ -176,3 +185,60 @@ impl TryFrom<&Value> for PGDateTimeStyle { } } } + +#[derive(Default, PartialEq, Eq, Clone, Copy, Debug)] +pub enum PGIntervalStyle { + ISO, + SQL, + #[default] + Postgres, + PostgresVerbose, +} + +impl Display for PGIntervalStyle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PGIntervalStyle::ISO => write!(f, "iso_8601"), + PGIntervalStyle::SQL => write!(f, "sql_standard"), + PGIntervalStyle::Postgres => write!(f, "postgres"), + PGIntervalStyle::PostgresVerbose => write!(f, "postgres_verbose"), + } + } +} + +impl TryFrom<&str> for PGIntervalStyle { + type Error = Error; + + fn try_from(s: &str) -> Result { + match s.to_uppercase().as_str() { + "ISO" | "ISO_8601" => Ok(PGIntervalStyle::ISO), + "SQL" | "SQL_STANDARD" => Ok(PGIntervalStyle::SQL), + "POSTGRES" => Ok(PGIntervalStyle::Postgres), + "POSTGRES_VERBOSE" | "POSTGRES, VERBOSE" => Ok(PGIntervalStyle::PostgresVerbose), + _ => InvalidConfigValueSnafu { + name: "IntervalStyle", + value: s, + hint: format!("Unrecognized key word: {}", s), + } + .fail(), + } + } +} + +impl TryFrom<&Value> for PGIntervalStyle { + type Error = Error; + + fn try_from(value: &Value) -> Result { + match value { + Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => { + Self::try_from(s.as_str()) + } + _ => InvalidConfigValueSnafu { + name: "IntervalStyle", + value: value.to_string(), + hint: format!("Unrecognized key word: {}", value), + } + .fail(), + } + } +} diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index b3d981b1b0..41bbf2ce4d 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -24,6 +24,7 @@ use common_frontend::slow_query_event::{ }; use sqlx::mysql::{MySqlConnection, MySqlDatabaseError, MySqlPoolOptions}; use sqlx::postgres::{PgDatabaseError, PgPoolOptions}; +use sqlx::types::Decimal; use sqlx::{Connection, Executor, Row}; use tests_integration::test_util::{ StorageType, setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server, @@ -77,6 +78,7 @@ macro_rules! sql_tests { test_postgres_bytea, test_postgres_slow_query, test_postgres_datestyle, + test_postgres_intervalstyle, test_postgres_parameter_inference, test_postgres_array_types, test_mysql_prepare_stmt_insert_timestamp, @@ -834,12 +836,12 @@ pub async fn test_postgres_slow_query(store_type: StorageType) { let rows = sqlx::query(&query).fetch_all(&pool).await.unwrap(); assert_eq!(rows.len(), 1); let row = &rows[0]; - let cost: i64 = row.get(0); - let threshold: i64 = row.get(1); + let cost: Decimal = row.get(0); + let threshold: Decimal = row.get(1); let query: String = row.get(2); let is_promql: bool = row.get(3); - assert!(cost > 0 && threshold > 0 && cost > threshold); + assert!(cost > 0.into() && threshold > 0.into() && cost > threshold); assert_eq!(query, slow_query); assert!(!is_promql); @@ -1075,6 +1077,90 @@ pub async fn test_postgres_datestyle(store_type: StorageType) { guard.remove_all().await; } +pub async fn test_postgres_intervalstyle(store_type: StorageType) { + let (mut guard, fe_pg_server) = + setup_pg_server(store_type, "test_postgres_intervalstyle").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); + + let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) + .await + .unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + connection.await.unwrap(); + tx.send(()).unwrap(); + }); + + let validate_intervalstyle = |client: Client, intervalstyle: &str, is_valid: bool| { + let intervalstyle = intervalstyle.to_string(); + async move { + assert_eq!( + client + .simple_query(format!("SET INTERVALSTYLE='{}'", intervalstyle).as_str()) + .await + .is_ok(), + is_valid, + "testing intervalstyle {intervalstyle}" + ); + client + } + }; + + let get_row = |mess: Vec| -> String { + match &mess[1] { + SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(), + _ => unreachable!(), + } + }; + + let client = validate_intervalstyle(client, "iso_8601", true).await; + let client = validate_intervalstyle(client, "sql_standard", true).await; + let client = validate_intervalstyle(client, "postgres", true).await; + let client = validate_intervalstyle(client, "postgres_verbose", true).await; + let client = validate_intervalstyle(client, "invalid_style", false).await; + + let expected_formats: HashMap<&str, &str> = HashMap::from([ + ("iso_8601", "P1DT2H3M"), + ("sql_standard", "1 2:03:00"), + ("postgres", "1 day 02:03:00"), + ("postgres_verbose", "@ 1 day 2 hours 3 mins"), + ]); + + for (style, expected_format) in expected_formats { + let _ = client + .simple_query(&format!("SET INTERVALSTYLE='{}'", style)) + .await + .expect("SET INTERVALSTYLE ERROR"); + + let interval = get_row( + client + .simple_query("SHOW VARIABLES intervalstyle") + .await + .unwrap(), + ); + assert_eq!(interval, style); + + let result = get_row( + client + .simple_query("SELECT INTERVAL '1 day 2 hours 3 minutes'") + .await + .unwrap(), + ); + assert_eq!( + result, expected_format, + "intervalstyle {}: expected '{}', got '{}'", + style, expected_format, result + ); + } + + drop(client); + rx.await.unwrap(); + + let _ = fe_pg_server.shutdown().await; + guard.remove_all().await; +} + pub async fn test_postgres_timezone(store_type: StorageType) { let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_timezone").await; let addr = fe_pg_server.bind_addr().unwrap().to_string();