feat: use arrow-pg for postgres data encoding (#7591)

* feat: use arrow-pg for encode_row

* refactor: remove bytea and datetime module

* feat: port more encodings to arrow-pg

* feat: implement intervalstyle

* chore: format

* chore: remove error that is no longer used

* chore: use released arrow-pg

* Apply suggestions from code review

Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com>

---------

Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com>
This commit is contained in:
Ning Sun
2026-01-28 10:34:02 +08:00
committed by GitHub
parent 3c915a382b
commit 124478f577
16 changed files with 345 additions and 1445 deletions

50
Cargo.lock generated
View File

@@ -640,6 +640,23 @@ dependencies = [
"arrow-select 57.0.0",
]
[[package]]
name = "arrow-pg"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87bc2eb53228ffb0cffff4a8a99d5311641b6d8ce63ec48b860dab70ec01ae1f"
dependencies = [
"arrow 57.0.0",
"arrow-schema 57.0.0",
"bytes",
"chrono",
"futures",
"pg_interval_2",
"pgwire",
"postgres-types",
"rust_decimal",
]
[[package]]
name = "arrow-row"
version = "56.2.0"
@@ -1553,9 +1570,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.10.1"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3"
dependencies = [
"serde",
]
@@ -9629,10 +9646,10 @@ dependencies = [
]
[[package]]
name = "pg_interval"
version = "0.4.2"
name = "pg_interval_2"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe46640b465e284b048ef065cbed8ef17a622878d310c724578396b4cfd00df2"
checksum = "a055f44628dcf9c4e68f931535dabd3544a239655fdde25a3b0e95d4b36e9260"
dependencies = [
"bytes",
"chrono",
@@ -9641,9 +9658,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.37.0"
version = "0.37.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d86d57e732d40382ceb9bfea80901d839bae8571aa11c06af9177aed9dfb6c"
checksum = "6fcd410bc6990bd8d20b3fe3cd879a3c3ec250bdb1cb12537b528818823b02c9"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -9654,6 +9671,7 @@ dependencies = [
"hex",
"lazy-regex",
"md5",
"pg_interval_2",
"postgres-types",
"rand 0.9.1",
"ring",
@@ -11326,9 +11344,9 @@ dependencies = [
[[package]]
name = "rkyv"
version = "0.7.45"
version = "0.7.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9008cd6385b9e161d8229e1f6549dd23c3d022f132a2ea37ac3a10ac4935779b"
checksum = "2297bf9c81a3f0dc96bc9521370b88f054168c29826a75e89c55ff196e7ed6a1"
dependencies = [
"bitvec",
"bytecheck",
@@ -11344,9 +11362,9 @@ dependencies = [
[[package]]
name = "rkyv_derive"
version = "0.7.45"
version = "0.7.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "503d1d27590a2b0a3a4ca4c94755aa2875657196ecbf401a42eff41d7de532c0"
checksum = "84d7b42d4b8d06048d3ac8db0eb31bcb942cbeb709f0b5f2b2ebde398d3038f5"
dependencies = [
"proc-macro2",
"quote",
@@ -11604,9 +11622,9 @@ dependencies = [
[[package]]
name = "rust_decimal"
version = "1.38.0"
version = "1.40.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8975fc98059f365204d635119cf9c5a60ae67b841ed49b5422a9a7e56cdfac0"
checksum = "61f703d19852dbf87cbc513643fa81428361eb6940f1ac14fd58155d295a3eb0"
dependencies = [
"arrayvec",
"borsh",
@@ -12182,6 +12200,7 @@ dependencies = [
"arrow 57.0.0",
"arrow-flight",
"arrow-ipc 57.0.0",
"arrow-pg",
"arrow-schema 57.0.0",
"async-trait",
"auth",
@@ -12253,7 +12272,7 @@ dependencies = [
"otel-arrow-rust",
"parking_lot 0.12.4",
"permutation",
"pg_interval",
"pg_interval_2",
"pgwire",
"pin-project",
"pipeline",
@@ -12807,6 +12826,7 @@ dependencies = [
"memchr",
"once_cell",
"percent-encoding",
"rust_decimal",
"rustls",
"serde",
"serde_json",
@@ -12890,6 +12910,7 @@ dependencies = [
"percent-encoding",
"rand 0.8.5",
"rsa",
"rust_decimal",
"serde",
"sha1",
"sha2",
@@ -12928,6 +12949,7 @@ dependencies = [
"memchr",
"once_cell",
"rand 0.8.5",
"rust_decimal",
"serde",
"serde_json",
"sha2",

View File

@@ -225,7 +225,13 @@ similar-asserts = "1.6.0"
smallvec = { version = "1", features = ["serde"] }
snafu = "0.8"
sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor", "serde"] }
sqlx = { version = "0.8", default-features = false, features = ["any", "macros", "json", "runtime-tokio-rustls"] }
sqlx = { version = "0.8", default-features = false, features = [
"any",
"macros",
"json",
"runtime-tokio-rustls",
"rust_decimal",
] }
strum = { version = "0.27", features = ["derive"] }
sysinfo = "0.33"
tempfile = "3"

View File

@@ -79,7 +79,8 @@ use table::table_name::TableName;
use table::table_reference::TableReference;
use self::set::{
set_bytea_output, set_datestyle, set_search_path, set_timezone, validate_client_encoding,
set_bytea_output, set_datestyle, set_intervalstyle, set_search_path, set_timezone,
validate_client_encoding,
};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
@@ -483,6 +484,7 @@ impl StatementExecutor {
// Not harmful since it only relates to how date is viewed in client app's output.
// The tracked issue is https://github.com/GreptimeTeam/greptimedb/issues/3442.
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,
"INTERVALSTYLE" => set_intervalstyle(set_var.value, query_ctx)?,
// Allow query to fallback when failed to push down.
"ALLOW_QUERY_FALLBACK" => set_allow_query_fallback(set_var.value, query_ctx)?,

View File

@@ -21,7 +21,7 @@ use regex::Regex;
use session::ReadPreference;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
use snafu::{OptionExt, ResultExt, ensure};
use sql::ast::{Expr, Ident, Value};
use sql::statements::set_variables::SetVariables;
@@ -279,6 +279,25 @@ pub fn set_allow_query_fallback(exprs: Vec<Expr>, ctx: QueryContextRef) -> Resul
}
}
pub fn set_intervalstyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let Some((var_value, [])) = exprs.split_first() else {
return NotSupportedSnafu {
feat: "Set variable value must have one and only one value for intervalstyle",
}
.fail();
};
let Expr::Value(value) = var_value else {
return NotSupportedSnafu {
feat: "Set variable value must be a value",
}
.fail();
};
ctx.configuration_parameter().set_pg_intervalstyle_format(
PGIntervalStyle::try_from(&value.value).context(InvalidConfigValueSnafu)?,
);
Ok(())
}
pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
// ORDER,
// STYLE,

View File

@@ -742,6 +742,12 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"INTERVALSTYLE" => {
let style = *query_ctx
.configuration_parameter()
.pg_intervalstyle_format();
style.to_string()
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()

View File

@@ -23,6 +23,7 @@ api.workspace = true
arrow.workspace = true
arrow-flight.workspace = true
arrow-ipc.workspace = true
arrow-pg = "0.11"
arrow-schema.workspace = true
async-trait.workspace = true
auth.workspace = true
@@ -87,8 +88,8 @@ opentelemetry-proto.workspace = true
operator.workspace = true
otel-arrow-rust.workspace = true
parking_lot.workspace = true
pg_interval = "0.4"
pgwire = { version = "0.37", default-features = false, features = [
pg_interval = { version = "0.5.2", package = "pg_interval_2" }
pgwire = { version = "0.37.3", default-features = false, features = [
"server-api-ring",
"pg-ext-types",
] }

View File

@@ -26,7 +26,6 @@ use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use common_telemetry::{error, warn};
use common_time::Duration;
use datafusion::error::DataFusionError;
use datatypes::prelude::ConcreteDataType;
use headers::ContentType;
@@ -640,9 +639,6 @@ pub enum Error {
location: Location,
},
#[snafu(display("Overflow while casting `{:?}` to Interval", val))]
DurationOverflow { val: Duration },
#[snafu(display("Failed to handle otel-arrow request, error message: {}", err_msg))]
HandleOtelArrowRequest {
err_msg: String,
@@ -792,8 +788,6 @@ impl ErrorExt for Error {
ConvertSqlValue { source, .. } => source.status_code(),
DurationOverflow { .. } => StatusCode::InvalidArguments,
HandleOtelArrowRequest { .. } => StatusCode::Internal,
Cancelled { .. } => StatusCode::Cancelled,

View File

@@ -158,7 +158,10 @@ fn recordbatches_to_query_response<S>(
where
S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
{
let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?);
let format_options = format_options_from_query_ctx(&query_ctx);
let pg_schema = Arc::new(
schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
);
let pg_schema_ref = pg_schema.clone();
let data_row_stream = recordbatches_stream
.map(move |result| match result {
@@ -405,7 +408,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
.collect::<Vec<_>>();
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, &Format::UnifiedBinary)
schema_to_pg(schema, &Format::UnifiedBinary, None)
.map(|fields| DescribeStatementResponse::new(param_types, fields))
.map_err(convert_err)
} else {
@@ -438,7 +441,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
Some(Statement::Query(_)) => {
// if the query has a schema, it is managed by datafusion, use the schema
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, format)
schema_to_pg(schema, format, None)
.map(DescribePortalResponse::new)
.map_err(convert_err)
} else {

View File

@@ -12,66 +12,62 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod bytea;
mod datetime;
mod error;
mod interval;
use std::collections::HashMap;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::datatypes::{
Date32Type, Date64Type, Decimal128Type, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type, Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
use arrow_schema::{DataType, IntervalUnit, TimeUnit};
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime};
use common_decimal::Decimal128;
use arrow::array::{Array, AsArray};
use arrow_pg::encoder::encode_value;
use arrow_pg::list_encoder::encode_list;
use arrow_schema::{DataType, TimeUnit};
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime};
use common_recordbatch::RecordBatch;
use common_time::time::Time;
use common_time::{Date, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth, Timestamp};
use common_time::{IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth};
use datafusion_common::ScalarValue;
use datafusion_expr::LogicalPlan;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::json::JsonStructureSettings;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::schema::{Schema, SchemaRef};
use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
use datatypes::value::StructValue;
use pg_interval::Interval as PgInterval;
use pgwire::api::Type;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::types::format::FormatOptions as PgFormatOptions;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;
use snafu::ResultExt;
use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::datetime::{StylingDate, StylingDateTime};
pub use self::error::{PgErrorCode, PgErrorSeverity};
use self::interval::PgInterval;
use crate::SqlPlan;
use crate::error::{self as server_error, DataFusionSnafu, Error, Result};
use crate::error::{self as server_error, DataFusionSnafu, Result};
use crate::postgres::utils::convert_err;
pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
pub(super) fn schema_to_pg(
origin: &Schema,
field_formats: &Format,
format_options: Option<Arc<PgFormatOptions>>,
) -> Result<Vec<FieldInfo>> {
origin
.column_schemas()
.iter()
.enumerate()
.map(|(idx, col)| {
Ok(FieldInfo::new(
let mut field_info = FieldInfo::new(
col.name.clone(),
None,
None,
type_gt_to_pg(&col.data_type)?,
field_formats.format_for(idx),
))
);
if let Some(format_options) = &format_options {
field_info = field_info.with_format_options(format_options.clone());
}
Ok(field_info)
})
.collect::<Result<Vec<FieldInfo>>>()
}
@@ -98,291 +94,6 @@ fn encode_struct(
builder.encode_field(&json_value)
}
fn encode_array(
query_ctx: &QueryContextRef,
array: ArrayRef,
builder: &mut DataRowEncoder,
) -> PgWireResult<()> {
macro_rules! encode_primitive_array {
($array: ident, $data_type: ty, $lower_type: ty, $upper_type: ty) => {{
let array = $array.iter().collect::<Vec<Option<$data_type>>>();
if array
.iter()
.all(|x| x.is_none_or(|i| i <= <$lower_type>::MAX as $data_type))
{
builder.encode_field(
&array
.into_iter()
.map(|x| x.map(|i| i as $lower_type))
.collect::<Vec<Option<$lower_type>>>(),
)
} else {
builder.encode_field(
&array
.into_iter()
.map(|x| x.map(|i| i as $upper_type))
.collect::<Vec<Option<$upper_type>>>(),
)
}
}};
}
match array.data_type() {
DataType::Boolean => {
let array = array.as_boolean();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Int8 => {
let array = array.as_primitive::<Int8Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Int16 => {
let array = array.as_primitive::<Int16Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Int32 => {
let array = array.as_primitive::<Int32Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Int64 => {
let array = array.as_primitive::<Int64Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::UInt8 => {
let array = array.as_primitive::<UInt8Type>();
encode_primitive_array!(array, u8, i8, i16)
}
DataType::UInt16 => {
let array = array.as_primitive::<UInt16Type>();
encode_primitive_array!(array, u16, i16, i32)
}
DataType::UInt32 => {
let array = array.as_primitive::<UInt32Type>();
encode_primitive_array!(array, u32, i32, i64)
}
DataType::UInt64 => {
let array = array.as_primitive::<UInt64Type>();
let array = array.iter().collect::<Vec<_>>();
if array.iter().all(|x| x.is_none_or(|i| i <= i64::MAX as u64)) {
builder.encode_field(
&array
.into_iter()
.map(|x| x.map(|i| i as i64))
.collect::<Vec<Option<i64>>>(),
)
} else {
builder.encode_field(
&array
.into_iter()
.map(|x| x.map(|i| i.to_string()))
.collect::<Vec<_>>(),
)
}
}
DataType::Float32 => {
let array = array.as_primitive::<Float32Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Float64 => {
let array = array.as_primitive::<Float64Type>();
let array = array.iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Binary => {
let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
let array = array.as_binary::<i32>();
match *bytea_output {
PGByteaOutputValue::ESCAPE => {
let array = array
.iter()
.map(|v| v.map(EscapeOutputBytea))
.collect::<Vec<_>>();
builder.encode_field(&array)
}
PGByteaOutputValue::HEX => {
let array = array
.iter()
.map(|v| v.map(HexOutputBytea))
.collect::<Vec<_>>();
builder.encode_field(&array)
}
}
}
DataType::Utf8 => {
let array = array.as_string::<i32>();
let array = array.into_iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::LargeUtf8 => {
let array = array.as_string::<i64>();
let array = array.into_iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Utf8View => {
let array = array.as_string_view();
let array = array.into_iter().collect::<Vec<_>>();
builder.encode_field(&array)
}
DataType::Date32 | DataType::Date64 => {
let iter: Box<dyn Iterator<Item = Option<i32>>> =
if matches!(array.data_type(), DataType::Date32) {
let array = array.as_primitive::<Date32Type>();
Box::new(array.into_iter())
} else {
let array = array.as_primitive::<Date64Type>();
// `Date64` values are milliseconds representation of `Date32` values, according
// to its specification. So we convert them to `Date32` values to process the
// `Date64` array unified with `Date32` array.
Box::new(
array
.into_iter()
.map(|x| x.map(|i| (i / 86_400_000) as i32)),
)
};
let array = iter
.into_iter()
.map(|v| match v {
None => Ok(None),
Some(v) => {
if let Some(date) = Date::new(v).to_chrono_date() {
let (style, order) =
*query_ctx.configuration_parameter().pg_datetime_style();
Ok(Some(StylingDate(date, style, order)))
} else {
Err(convert_err(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
}))
}
}
})
.collect::<PgWireResult<Vec<Option<StylingDate>>>>()?;
builder.encode_field(&array)
}
DataType::Timestamp(time_unit, _) => {
let array = match time_unit {
TimeUnit::Second => {
let array = array.as_primitive::<TimestampSecondType>();
array.into_iter().collect::<Vec<_>>()
}
TimeUnit::Millisecond => {
let array = array.as_primitive::<TimestampMillisecondType>();
array.into_iter().collect::<Vec<_>>()
}
TimeUnit::Microsecond => {
let array = array.as_primitive::<TimestampMicrosecondType>();
array.into_iter().collect::<Vec<_>>()
}
TimeUnit::Nanosecond => {
let array = array.as_primitive::<TimestampNanosecondType>();
array.into_iter().collect::<Vec<_>>()
}
};
let time_unit = time_unit.into();
let array = array
.into_iter()
.map(|v| match v {
None => Ok(None),
Some(v) => {
let v = Timestamp::new(v, time_unit);
if let Some(datetime) =
v.to_chrono_datetime_with_timezone(Some(&query_ctx.timezone()))
{
let (style, order) =
*query_ctx.configuration_parameter().pg_datetime_style();
Ok(Some(StylingDateTime(datetime, style, order)))
} else {
Err(convert_err(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
}))
}
}
})
.collect::<PgWireResult<Vec<Option<StylingDateTime>>>>()?;
builder.encode_field(&array)
}
DataType::Time32(time_unit) | DataType::Time64(time_unit) => {
let iter: Box<dyn Iterator<Item = Option<Time>>> = match time_unit {
TimeUnit::Second => {
let array = array.as_primitive::<Time32SecondType>();
Box::new(
array
.into_iter()
.map(|v| v.map(|i| Time::new_second(i as i64))),
)
}
TimeUnit::Millisecond => {
let array = array.as_primitive::<Time32MillisecondType>();
Box::new(
array
.into_iter()
.map(|v| v.map(|i| Time::new_millisecond(i as i64))),
)
}
TimeUnit::Microsecond => {
let array = array.as_primitive::<Time64MicrosecondType>();
Box::new(array.into_iter().map(|v| v.map(Time::new_microsecond)))
}
TimeUnit::Nanosecond => {
let array = array.as_primitive::<Time64NanosecondType>();
Box::new(array.into_iter().map(|v| v.map(Time::new_nanosecond)))
}
};
let array = iter
.into_iter()
.map(|v| v.and_then(|v| v.to_chrono_time()))
.collect::<Vec<Option<NaiveTime>>>();
builder.encode_field(&array)
}
DataType::Interval(interval_unit) => {
let array = match interval_unit {
IntervalUnit::YearMonth => {
let array = array.as_primitive::<IntervalYearMonthType>();
array
.into_iter()
.map(|v| v.map(|i| PgInterval::from(IntervalYearMonth::from(i))))
.collect::<Vec<_>>()
}
IntervalUnit::DayTime => {
let array = array.as_primitive::<IntervalDayTimeType>();
array
.into_iter()
.map(|v| v.map(|i| PgInterval::from(IntervalDayTime::from(i))))
.collect::<Vec<_>>()
}
IntervalUnit::MonthDayNano => {
let array = array.as_primitive::<IntervalMonthDayNanoType>();
array
.into_iter()
.map(|v| v.map(|i| PgInterval::from(IntervalMonthDayNano::from(i))))
.collect::<Vec<_>>()
}
};
builder.encode_field(&array)
}
DataType::Decimal128(precision, scale) => {
let array = array.as_primitive::<Decimal128Type>();
let array = array
.into_iter()
.map(|v| v.map(|i| Decimal128::new(i, *precision, *scale).to_string()))
.collect::<Vec<_>>();
builder.encode_field(&array)
}
_ => Err(convert_err(Error::Internal {
err_msg: format!(
"cannot write array type {:?} in postgres protocol: unimplemented",
array.data_type()
),
})),
}
}
pub(crate) struct RecordBatchRowIterator {
query_ctx: QueryContextRef,
pg_schema: Arc<Vec<FieldInfo>>,
@@ -426,175 +137,42 @@ impl RecordBatchRowIterator {
}
fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
let arrow_schema = self.record_batch.schema();
for (j, column) in self.record_batch.columns().iter().enumerate() {
if column.is_null(i) {
encoder.encode_field(&None::<&i8>)?;
continue;
}
let pg_field = &self.pg_schema[j];
match column.data_type() {
DataType::Null => {
encoder.encode_field(&None::<&i8>)?;
}
DataType::Boolean => {
let array = column.as_boolean();
encoder.encode_field(&array.value(i))?;
}
DataType::UInt8 => {
let array = column.as_primitive::<UInt8Type>();
let value = array.value(i);
if value <= i8::MAX as u8 {
encoder.encode_field(&(value as i8))?;
} else {
encoder.encode_field(&(value as i16))?;
}
}
DataType::UInt16 => {
let array = column.as_primitive::<UInt16Type>();
let value = array.value(i);
if value <= i16::MAX as u16 {
encoder.encode_field(&(value as i16))?;
} else {
encoder.encode_field(&(value as i32))?;
}
}
DataType::UInt32 => {
let array = column.as_primitive::<UInt32Type>();
let value = array.value(i);
if value <= i32::MAX as u32 {
encoder.encode_field(&(value as i32))?;
} else {
encoder.encode_field(&(value as i64))?;
}
}
DataType::UInt64 => {
let array = column.as_primitive::<UInt64Type>();
let value = array.value(i);
if value <= i64::MAX as u64 {
encoder.encode_field(&(value as i64))?;
} else {
encoder.encode_field(&value.to_string())?;
}
}
DataType::Int8 => {
let array = column.as_primitive::<Int8Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Int16 => {
let array = column.as_primitive::<Int16Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Int32 => {
let array = column.as_primitive::<Int32Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Int64 => {
let array = column.as_primitive::<Int64Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Float32 => {
let array = column.as_primitive::<Float32Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Float64 => {
let array = column.as_primitive::<Float64Type>();
encoder.encode_field(&array.value(i))?;
}
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => {
let value = datatypes::arrow_array::string_array_value(column, i);
encoder.encode_field(&value)?;
}
// these types are greptimedb specific or custom
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
let v = datatypes::arrow_array::binary_array_value(column, i);
encode_bytes(
&self.schema.column_schemas()[j],
v,
encoder,
&self.query_ctx,
)?;
}
DataType::Date32 | DataType::Date64 => {
let v = if matches!(column.data_type(), DataType::Date32) {
let array = column.as_primitive::<Date32Type>();
array.value(i)
// jsonb
if let ConcreteDataType::Json(_) = &self.schema.column_schemas()[j].data_type {
let v = datatypes::arrow_array::binary_array_value(column, i);
let s = jsonb_to_string(v).map_err(convert_err)?;
encoder.encode_field(&s)?;
} else {
let array = column.as_primitive::<Date64Type>();
// `Date64` values are milliseconds representation of `Date32` values,
// according to its specification. So we convert the `Date64` value here to
// the `Date32` value to process them unified.
(array.value(i) / 86_400_000) as i32
};
let v = Date::new(v);
let date = v.to_chrono_date().map(|v| {
let (style, order) =
*self.query_ctx.configuration_parameter().pg_datetime_style();
StylingDate(v, style, order)
});
encoder.encode_field(&date)?;
}
DataType::Timestamp(_, _) => {
let v = datatypes::arrow_array::timestamp_array_value(column, i);
let datetime = v
.to_chrono_datetime_with_timezone(Some(&self.query_ctx.timezone()))
.map(|v| {
let (style, order) =
*self.query_ctx.configuration_parameter().pg_datetime_style();
StylingDateTime(v, style, order)
});
encoder.encode_field(&datetime)?;
}
DataType::Interval(interval_unit) => match interval_unit {
IntervalUnit::YearMonth => {
let array = column.as_primitive::<IntervalYearMonthType>();
let v: IntervalYearMonth = array.value(i).into();
encoder.encode_field(&PgInterval::from(v))?;
}
IntervalUnit::DayTime => {
let array = column.as_primitive::<IntervalDayTimeType>();
let v: IntervalDayTime = array.value(i).into();
encoder.encode_field(&PgInterval::from(v))?;
}
IntervalUnit::MonthDayNano => {
let array = column.as_primitive::<IntervalMonthDayNanoType>();
let v: IntervalMonthDayNano = array.value(i).into();
encoder.encode_field(&PgInterval::from(v))?;
}
},
DataType::Duration(_) => {
let d = datatypes::arrow_array::duration_array_value(column, i);
match PgInterval::try_from(d) {
Ok(i) => encoder.encode_field(&i)?,
Err(e) => {
return Err(convert_err(Error::Internal {
err_msg: e.to_string(),
}));
}
// bytea
let arrow_field = arrow_schema.field(j);
encode_value(encoder, column, i, arrow_field, pg_field)?;
}
}
DataType::List(_) => {
let array = column.as_list::<i32>();
let items = array.value(i);
encode_array(&self.query_ctx, items, encoder)?;
encode_list(encoder, items, pg_field)?;
}
DataType::Struct(_) => {
encode_struct(&self.query_ctx, Default::default(), encoder)?;
}
DataType::Time32(_) | DataType::Time64(_) => {
let v = datatypes::arrow_array::time_array_value(column, i);
encoder.encode_field(&v.to_chrono_time())?;
}
DataType::Decimal128(precision, scale) => {
let array = column.as_primitive::<Decimal128Type>();
let v = Decimal128::new(array.value(i), *precision, *scale);
encoder.encode_field(&v.to_string())?;
}
_ => {
return Err(convert_err(Error::Internal {
err_msg: format!(
"cannot convert datatype {} to postgres",
column.data_type()
),
}));
// Encode value using arrow-pg
let arrow_field = arrow_schema.field(j);
encode_value(encoder, column, i, arrow_field, pg_field)?;
}
}
}
@@ -602,32 +180,15 @@ impl RecordBatchRowIterator {
}
}
fn encode_bytes(
schema: &ColumnSchema,
v: &[u8],
encoder: &mut DataRowEncoder,
query_ctx: &QueryContextRef,
) -> PgWireResult<()> {
if let ConcreteDataType::Json(_) = &schema.data_type {
let s = jsonb_to_string(v).map_err(convert_err)?;
encoder.encode_field(&s)
} else {
let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
match *bytea_output {
PGByteaOutputValue::ESCAPE => encoder.encode_field(&EscapeOutputBytea(v)),
PGByteaOutputValue::HEX => encoder.encode_field(&HexOutputBytea(v)),
}
}
}
pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
match origin {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
&ConcreteDataType::Int8(_) => Ok(Type::CHAR),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt8(_) => Ok(Type::INT2),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT4),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT8),
&ConcreteDataType::UInt64(_) => Ok(Type::NUMERIC),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
&ConcreteDataType::Binary(_) | &ConcreteDataType::Vector(_) => Ok(Type::BYTEA),
@@ -641,10 +202,11 @@ pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
ConcreteDataType::List(list) => match list.item_type() {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL_ARRAY),
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR_ARRAY),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2_ARRAY),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4_ARRAY),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8_ARRAY),
&ConcreteDataType::Int8(_) => Ok(Type::CHAR_ARRAY),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt8(_) => Ok(Type::INT2_ARRAY),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT4_ARRAY),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT8_ARRAY),
&ConcreteDataType::UInt64(_) => Ok(Type::NUMERIC_ARRAY),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4_ARRAY),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8_ARRAY),
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA_ARRAY),
@@ -762,7 +324,7 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
.unwrap_or_else(|| "".to_owned())),
&Type::INTERVAL => Ok(portal
.parameter::<PgInterval>(idx, param_type)?
.map(|v| v.to_string())
.map(|v| v.to_sql())
.unwrap_or_else(|| "".to_owned())),
_ => Err(invalid_parameter_error(
"unsupported_parameter_type",
@@ -1133,9 +695,14 @@ pub(super) fn parameters_to_scalar_values(
)
}
ConcreteDataType::Interval(IntervalType::MonthDayNano(_)) => {
ScalarValue::IntervalMonthDayNano(
data.map(|i| IntervalMonthDayNano::from(i).into()),
)
ScalarValue::IntervalMonthDayNano(data.map(|i| {
IntervalMonthDayNano::new(
i.months,
i.days,
i.microseconds * 1_000i64,
)
.into()
}))
}
_ => {
return Err(invalid_parameter_error(
@@ -1145,9 +712,10 @@ pub(super) fn parameters_to_scalar_values(
}
}
} else {
ScalarValue::IntervalMonthDayNano(
data.map(|i| IntervalMonthDayNano::from(i).into()),
)
ScalarValue::IntervalMonthDayNano(data.map(|i| {
IntervalMonthDayNano::new(i.months, i.days, i.microseconds * 1_000i64)
.into()
}))
}
}
&Type::BYTEA => {
@@ -1454,6 +1022,19 @@ pub(super) fn param_types_to_pg_types(
Ok(types)
}
pub fn format_options_from_query_ctx(query_ctx: &QueryContextRef) -> Arc<PgFormatOptions> {
let config = query_ctx.configuration_parameter();
let (date_style, date_order) = *config.pg_datetime_style();
let mut format_options = PgFormatOptions::default();
format_options.date_style = format!("{}, {}", date_style, date_order);
format_options.interval_style = config.pg_intervalstyle_format().to_string();
format_options.bytea_output = config.postgres_bytea_output().to_string();
format_options.time_zone = query_ctx.timezone().to_string();
Arc::new(format_options)
}
#[cfg(test)]
mod test {
use std::sync::Arc;
@@ -1461,7 +1042,7 @@ mod test {
use arrow::array::{
Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder,
};
use arrow_schema::Field;
use arrow_schema::{Field, IntervalUnit};
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{
BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector,
@@ -1512,10 +1093,16 @@ mod test {
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"uint64s".into(),
None,
None,
Type::NUMERIC,
FieldFormat::Text,
),
FieldInfo::new(
"float32s".into(),
None,
@@ -1562,7 +1149,7 @@ mod test {
),
];
let schema = Schema::new(column_schemas);
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
let fs = schema_to_pg(&schema, &Format::UnifiedText, None).unwrap();
assert_eq!(fs, pg_field_info);
}
@@ -1571,10 +1158,16 @@ mod test {
let schema = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"uint64s".into(),
None,
None,
Type::NUMERIC,
FieldFormat::Text,
),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),

View File

@@ -1,158 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bytes::BufMut;
use pgwire::types::ToSqlText;
use pgwire::types::format::FormatOptions;
use postgres_types::{IsNull, ToSql, Type};
#[derive(Debug)]
pub struct HexOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for HexOutputBytea<'_> {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
let _ = self.0.to_sql_text(ty, out, format_options);
Ok(IsNull::No)
}
}
impl ToSql for HexOutputBytea<'_> {
fn to_sql(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.to_sql(ty, out)
}
fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
<&[u8] as ToSql>::accepts(ty)
}
fn to_sql_checked(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql_checked(ty, out)
}
}
#[derive(Debug)]
pub struct EscapeOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for EscapeOutputBytea<'_> {
fn to_sql_text(
&self,
_ty: &Type,
out: &mut bytes::BytesMut,
_format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.iter().for_each(|b| match b {
0..=31 | 127..=255 => {
out.put_slice(b"\\");
out.put_slice(format!("{:03o}", b).as_bytes());
}
92 => out.put_slice(b"\\\\"),
32..=126 => out.put_u8(*b),
});
Ok(IsNull::No)
}
}
impl ToSql for EscapeOutputBytea<'_> {
fn to_sql(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.to_sql(ty, out)
}
fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
<&[u8] as ToSql>::accepts(ty)
}
fn to_sql_checked(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql_checked(ty, out)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_escape_output_bytea() {
let input: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
let input = EscapeOutputBytea(input);
let expected = b"abcklm*\\251T";
let mut out = bytes::BytesMut::new();
let is_null = input
.to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
let expected = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
}
#[test]
fn test_hex_output_bytea() {
let input = b"hello, world!";
let input = HexOutputBytea(input);
let expected = b"\\x68656c6c6f2c20776f726c6421";
let mut out = bytes::BytesMut::new();
let is_null = input
.to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
let expected = b"hello, world!";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
}
}

View File

@@ -1,430 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bytes::BufMut;
use chrono::{NaiveDate, NaiveDateTime};
use pgwire::types::ToSqlText;
use pgwire::types::format::FormatOptions;
use postgres_types::{IsNull, ToSql, Type};
use session::session_config::{PGDateOrder, PGDateTimeStyle};
#[derive(Debug)]
pub struct StylingDate(pub NaiveDate, pub PGDateTimeStyle, pub PGDateOrder);
#[derive(Debug)]
pub struct StylingDateTime(pub NaiveDateTime, pub PGDateTimeStyle, pub PGDateOrder);
fn date_format_string(style: PGDateTimeStyle, order: PGDateOrder) -> &'static str {
match style {
PGDateTimeStyle::ISO => "%Y-%m-%d",
PGDateTimeStyle::German => "%d.%m.%Y",
PGDateTimeStyle::Postgres => match order {
PGDateOrder::MDY | PGDateOrder::YMD => "%m-%d-%Y",
PGDateOrder::DMY => "%d-%m-%Y",
},
PGDateTimeStyle::SQL => match order {
PGDateOrder::MDY | PGDateOrder::YMD => "%m/%d/%Y",
PGDateOrder::DMY => "%d/%m/%Y",
},
}
}
fn datetime_format_string(style: PGDateTimeStyle, order: PGDateOrder) -> &'static str {
match style {
PGDateTimeStyle::ISO => "%Y-%m-%d %H:%M:%S%.6f",
PGDateTimeStyle::German => "%d.%m.%Y %H:%M:%S%.6f",
PGDateTimeStyle::Postgres => match order {
PGDateOrder::MDY | PGDateOrder::YMD => "%a %b %d %H:%M:%S%.6f %Y",
PGDateOrder::DMY => "%a %d %b %H:%M:%S%.6f %Y",
},
PGDateTimeStyle::SQL => match order {
PGDateOrder::MDY | PGDateOrder::YMD => "%m/%d/%Y %H:%M:%S%.6f",
PGDateOrder::DMY => "%d/%m/%Y %H:%M:%S%.6f",
},
}
}
impl ToSqlText for StylingDate {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
match *ty {
Type::DATE => {
let fmt = self
.0
.format(date_format_string(self.1, self.2))
.to_string();
out.put_slice(fmt.as_bytes());
}
_ => {
self.0.to_sql_text(ty, out, format_options)?;
}
}
Ok(IsNull::No)
}
}
impl ToSqlText for StylingDateTime {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
match *ty {
Type::TIMESTAMP => {
let fmt = self
.0
.format(datetime_format_string(self.1, self.2))
.to_string();
out.put_slice(fmt.as_bytes());
}
Type::DATE => {
let fmt = self
.0
.format(date_format_string(self.1, self.2))
.to_string();
out.put_slice(fmt.as_bytes());
}
_ => {
self.0.to_sql_text(ty, out, format_options)?;
}
}
Ok(IsNull::No)
}
}
macro_rules! delegate_to_sql {
($delegator:ident, $delegatee:ident) => {
impl ToSql for $delegator {
fn to_sql(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql(ty, out)
}
fn accepts(ty: &Type) -> bool {
<$delegatee as ToSql>::accepts(ty)
}
fn to_sql_checked(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql_checked(ty, out)
}
}
};
}
delegate_to_sql!(StylingDate, NaiveDate);
delegate_to_sql!(StylingDateTime, NaiveDateTime);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_styling_date() {
let naive_date = NaiveDate::from_ymd_opt(1997, 12, 17).unwrap();
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::MDY);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::YMD);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::DMY);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::MDY);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::YMD);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::DMY);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::MDY);
let expected = "12-17-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::YMD);
let expected = "12-17-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::DMY);
let expected = "17-12-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::MDY);
let expected = "12/17/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::YMD);
let expected = "12/17/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::DMY);
let expected = "17/12/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
}
#[test]
fn test_styling_datetime() {
let input =
NaiveDateTime::parse_from_str("2021-09-01 12:34:56.789012", "%Y-%m-%d %H:%M:%S%.f")
.unwrap();
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::MDY);
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::YMD);
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::ISO, PGDateOrder::DMY);
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::MDY);
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::YMD);
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::German, PGDateOrder::DMY);
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::MDY);
let expected = "Wed Sep 01 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::YMD);
let expected = "Wed Sep 01 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime =
StylingDateTime(input, PGDateTimeStyle::Postgres, PGDateOrder::DMY);
let expected = "Wed 01 Sep 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::MDY);
let expected = "09/01/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::YMD);
let expected = "09/01/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
{
let styling_datetime = StylingDateTime(input, PGDateTimeStyle::SQL, PGDateOrder::DMY);
let expected = "01/09/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
}
}

View File

@@ -1,320 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt::Display;
use bytes::{Buf, BufMut};
use common_time::interval::IntervalFormat;
use common_time::timestamp::TimeUnit;
use common_time::{Duration, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth};
use pgwire::types::format::FormatOptions;
use pgwire::types::{FromSqlText, ToSqlText};
use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
use crate::error;
/// On average one month has 30.44 day, which is a common approximation.
const SECONDS_PER_MONTH: i64 = 24 * 6 * 6 * 3044;
const SECONDS_PER_DAY: i64 = 24 * 60 * 60;
const MILLISECONDS_PER_MONTH: i64 = SECONDS_PER_MONTH * 1000;
const MILLISECONDS_PER_DAY: i64 = SECONDS_PER_DAY * 1000;
#[derive(Debug, Clone, Copy, Default)]
pub struct PgInterval {
pub(crate) months: i32,
pub(crate) days: i32,
pub(crate) microseconds: i64,
}
impl From<IntervalYearMonth> for PgInterval {
fn from(interval: IntervalYearMonth) -> Self {
Self {
months: interval.months,
days: 0,
microseconds: 0,
}
}
}
impl From<IntervalDayTime> for PgInterval {
fn from(interval: IntervalDayTime) -> Self {
Self {
months: 0,
days: interval.days,
microseconds: interval.milliseconds as i64 * 1000,
}
}
}
impl From<IntervalMonthDayNano> for PgInterval {
fn from(interval: IntervalMonthDayNano) -> Self {
Self {
months: interval.months,
days: interval.days,
microseconds: interval.nanoseconds / 1000,
}
}
}
impl TryFrom<Duration> for PgInterval {
type Error = error::Error;
fn try_from(duration: Duration) -> error::Result<Self> {
let value = duration.value();
let unit = duration.unit();
// Convert the duration to microseconds
match unit {
TimeUnit::Second => {
let months = i32::try_from(value / SECONDS_PER_MONTH)
.map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?;
let days =
i32::try_from((value - (months as i64) * SECONDS_PER_MONTH) / SECONDS_PER_DAY)
.map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?;
let microseconds =
(value - (months as i64) * SECONDS_PER_MONTH - (days as i64) * SECONDS_PER_DAY)
.checked_mul(1_000_000)
.ok_or(error::DurationOverflowSnafu { val: duration }.build())?;
Ok(Self {
months,
days,
microseconds,
})
}
TimeUnit::Millisecond => {
let months = i32::try_from(value / MILLISECONDS_PER_MONTH)
.map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?;
let days = i32::try_from(
(value - (months as i64) * MILLISECONDS_PER_MONTH) / MILLISECONDS_PER_DAY,
)
.map_err(|_| error::DurationOverflowSnafu { val: duration }.build())?;
let microseconds = ((value - (months as i64) * MILLISECONDS_PER_MONTH)
- (days as i64) * MILLISECONDS_PER_DAY)
* 1_000;
Ok(Self {
months,
days,
microseconds,
})
}
TimeUnit::Microsecond => Ok(Self {
months: 0,
days: 0,
microseconds: value,
}),
TimeUnit::Nanosecond => Ok(Self {
months: 0,
days: 0,
microseconds: value / 1000,
}),
}
}
}
impl From<PgInterval> for IntervalMonthDayNano {
fn from(interval: PgInterval) -> Self {
IntervalMonthDayNano::new(
interval.months,
interval.days,
// Maybe overflow, but most scenarios ok.
interval.microseconds.checked_mul(1000).unwrap_or_else(|| {
if interval.microseconds.is_negative() {
i64::MIN
} else {
i64::MAX
}
}),
)
}
}
impl Display for PgInterval {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
IntervalFormat::from(IntervalMonthDayNano::from(*self)).to_postgres_string()
)
}
}
impl ToSql for PgInterval {
to_sql_checked!();
fn to_sql(
&self,
_: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<postgres_types::IsNull, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L989-L991
out.put_i64(self.microseconds);
out.put_i32(self.days);
out.put_i32(self.months);
Ok(postgres_types::IsNull::No)
}
fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
matches!(ty, &Type::INTERVAL)
}
}
impl<'a> FromSql<'a> for PgInterval {
fn from_sql(
_: &Type,
mut raw: &'a [u8],
) -> std::result::Result<Self, Box<dyn snafu::Error + Sync + Send>> {
// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/timestamp.c#L1007-L1010
let microseconds = raw.get_i64();
let days = raw.get_i32();
let months = raw.get_i32();
Ok(PgInterval {
months,
days,
microseconds,
})
}
fn accepts(ty: &Type) -> bool {
matches!(ty, &Type::INTERVAL)
}
}
impl ToSqlText for PgInterval {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
_format_options: &FormatOptions,
) -> std::result::Result<postgres_types::IsNull, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
let fmt = match ty {
&Type::INTERVAL => self.to_string(),
_ => return Err("unsupported type".into()),
};
out.put_slice(fmt.as_bytes());
Ok(IsNull::No)
}
}
impl<'a> FromSqlText<'a> for PgInterval {
fn from_sql_text(
_ty: &Type,
input: &[u8],
_format_options: &FormatOptions,
) -> std::result::Result<Self, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
// only support parsing interval from postgres format
if let Ok(interval) = pg_interval::Interval::from_postgres(str::from_utf8(input)?) {
Ok(PgInterval {
months: interval.months,
days: interval.days,
microseconds: interval.microseconds,
})
} else {
Err("invalid interval format".into())
}
}
}
#[cfg(test)]
mod tests {
use common_time::Duration;
use common_time::timestamp::TimeUnit;
use super::*;
#[test]
fn test_duration_to_pg_interval() {
// Test with seconds
let duration = Duration::new(86400, TimeUnit::Second); // 1 day
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 1);
assert_eq!(interval.microseconds, 0);
// Test with milliseconds
let duration = Duration::new(86400000, TimeUnit::Millisecond); // 1 day
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 1);
assert_eq!(interval.microseconds, 0);
// Test with microseconds
let duration = Duration::new(86400000000, TimeUnit::Microsecond); // 1 day
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 0);
assert_eq!(interval.microseconds, 86400000000);
// Test with nanoseconds
let duration = Duration::new(86400000000000, TimeUnit::Nanosecond); // 1 day
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 0);
assert_eq!(interval.microseconds, 86400000000);
// Test with partial day
let duration = Duration::new(43200, TimeUnit::Second); // 12 hours
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 0);
assert_eq!(interval.microseconds, 43_200_000_000); // 12 hours in microseconds
// Test with negative duration
let duration = Duration::new(-86400, TimeUnit::Second); // -1 day
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, -1);
assert_eq!(interval.microseconds, 0);
// Test with multiple days
let duration = Duration::new(259200, TimeUnit::Second); // 3 days
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 3);
assert_eq!(interval.microseconds, 0);
// Test with small duration (less than a day)
let duration = Duration::new(3600, TimeUnit::Second); // 1 hour
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 0);
assert_eq!(interval.microseconds, 3600000000); // 1 hour in microseconds
// Test with very small duration
let duration = Duration::new(1, TimeUnit::Microsecond); // 1 microsecond
let interval = PgInterval::try_from(duration).unwrap();
assert_eq!(interval.months, 0);
assert_eq!(interval.days, 0);
assert_eq!(interval.microseconds, 1);
let duration = Duration::new(i64::MAX, TimeUnit::Second);
assert!(PgInterval::try_from(duration).is_err());
let duration = Duration::new(i64::MAX, TimeUnit::Millisecond);
assert!(PgInterval::try_from(duration).is_err());
}
}

View File

@@ -355,8 +355,8 @@ async fn test_extended_query() -> Result<()> {
let rows = client.query(&stmt, &[&1i32]).await.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].len(), 2);
assert_eq!(rows[0].get::<usize, i32>(0usize), 1);
assert_eq!(rows[0].get::<&str, i32>("uint32s"), 1);
assert_eq!(rows[0].get::<usize, i64>(0usize), 1);
assert_eq!(rows[0].get::<&str, i64>("uint32s"), 1);
assert_eq!(rows[0].get::<usize, i64>(1usize), 2);
assert_eq!(rows[0].get::<&str, i64>("numbers.uint32s + Int64(1)"), 2);

View File

@@ -33,7 +33,7 @@ use derive_builder::Builder;
use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
use crate::protocol_ctx::ProtocolCtx;
use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
use crate::{MutableInner, ReadPreference};
pub type QueryContextRef = Arc<QueryContext>;
@@ -614,6 +614,7 @@ impl AsRef<str> for Channel {
pub struct ConfigurationVariables {
postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
pg_datestyle_format: ArcSwap<(PGDateTimeStyle, PGDateOrder)>,
pg_intervalstyle_format: ArcSwap<PGIntervalStyle>,
allow_query_fallback: ArcSwap<bool>,
}
@@ -622,6 +623,7 @@ impl Clone for ConfigurationVariables {
Self {
postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
pg_datestyle_format: ArcSwap::new(self.pg_datestyle_format.load().clone()),
pg_intervalstyle_format: ArcSwap::new(self.pg_intervalstyle_format.load().clone()),
allow_query_fallback: ArcSwap::new(self.allow_query_fallback.load().clone()),
}
}
@@ -648,6 +650,14 @@ impl ConfigurationVariables {
self.pg_datestyle_format.swap(Arc::new((style, order)));
}
pub fn pg_intervalstyle_format(&self) -> Arc<PGIntervalStyle> {
self.pg_intervalstyle_format.load().clone()
}
pub fn set_pg_intervalstyle_format(&self, value: PGIntervalStyle) {
self.pg_intervalstyle_format.swap(Arc::new(value));
}
pub fn allow_query_fallback(&self) -> bool {
**self.allow_query_fallback.load()
}

View File

@@ -66,6 +66,15 @@ impl TryFrom<Value> for PGByteaOutputValue {
}
}
impl Display for PGByteaOutputValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PGByteaOutputValue::HEX => write!(f, "hex"),
PGByteaOutputValue::ESCAPE => write!(f, "escape"),
}
}
}
// Refers to: https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-DATESTYLE
#[derive(Default, PartialEq, Eq, Clone, Copy, Debug)]
pub enum PGDateOrder {
@@ -176,3 +185,60 @@ impl TryFrom<&Value> for PGDateTimeStyle {
}
}
}
#[derive(Default, PartialEq, Eq, Clone, Copy, Debug)]
pub enum PGIntervalStyle {
ISO,
SQL,
#[default]
Postgres,
PostgresVerbose,
}
impl Display for PGIntervalStyle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PGIntervalStyle::ISO => write!(f, "iso_8601"),
PGIntervalStyle::SQL => write!(f, "sql_standard"),
PGIntervalStyle::Postgres => write!(f, "postgres"),
PGIntervalStyle::PostgresVerbose => write!(f, "postgres_verbose"),
}
}
}
impl TryFrom<&str> for PGIntervalStyle {
type Error = Error;
fn try_from(s: &str) -> Result<Self, Self::Error> {
match s.to_uppercase().as_str() {
"ISO" | "ISO_8601" => Ok(PGIntervalStyle::ISO),
"SQL" | "SQL_STANDARD" => Ok(PGIntervalStyle::SQL),
"POSTGRES" => Ok(PGIntervalStyle::Postgres),
"POSTGRES_VERBOSE" | "POSTGRES, VERBOSE" => Ok(PGIntervalStyle::PostgresVerbose),
_ => InvalidConfigValueSnafu {
name: "IntervalStyle",
value: s,
hint: format!("Unrecognized key word: {}", s),
}
.fail(),
}
}
}
impl TryFrom<&Value> for PGIntervalStyle {
type Error = Error;
fn try_from(value: &Value) -> Result<Self, Self::Error> {
match value {
Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => {
Self::try_from(s.as_str())
}
_ => InvalidConfigValueSnafu {
name: "IntervalStyle",
value: value.to_string(),
hint: format!("Unrecognized key word: {}", value),
}
.fail(),
}
}
}

View File

@@ -24,6 +24,7 @@ use common_frontend::slow_query_event::{
};
use sqlx::mysql::{MySqlConnection, MySqlDatabaseError, MySqlPoolOptions};
use sqlx::postgres::{PgDatabaseError, PgPoolOptions};
use sqlx::types::Decimal;
use sqlx::{Connection, Executor, Row};
use tests_integration::test_util::{
StorageType, setup_mysql_server, setup_mysql_server_with_user_provider, setup_pg_server,
@@ -77,6 +78,7 @@ macro_rules! sql_tests {
test_postgres_bytea,
test_postgres_slow_query,
test_postgres_datestyle,
test_postgres_intervalstyle,
test_postgres_parameter_inference,
test_postgres_array_types,
test_mysql_prepare_stmt_insert_timestamp,
@@ -834,12 +836,12 @@ pub async fn test_postgres_slow_query(store_type: StorageType) {
let rows = sqlx::query(&query).fetch_all(&pool).await.unwrap();
assert_eq!(rows.len(), 1);
let row = &rows[0];
let cost: i64 = row.get(0);
let threshold: i64 = row.get(1);
let cost: Decimal = row.get(0);
let threshold: Decimal = row.get(1);
let query: String = row.get(2);
let is_promql: bool = row.get(3);
assert!(cost > 0 && threshold > 0 && cost > threshold);
assert!(cost > 0.into() && threshold > 0.into() && cost > threshold);
assert_eq!(query, slow_query);
assert!(!is_promql);
@@ -1075,6 +1077,90 @@ pub async fn test_postgres_datestyle(store_type: StorageType) {
guard.remove_all().await;
}
pub async fn test_postgres_intervalstyle(store_type: StorageType) {
let (mut guard, fe_pg_server) =
setup_pg_server(store_type, "test_postgres_intervalstyle").await;
let addr = fe_pg_server.bind_addr().unwrap().to_string();
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
.await
.unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
tokio::spawn(async move {
connection.await.unwrap();
tx.send(()).unwrap();
});
let validate_intervalstyle = |client: Client, intervalstyle: &str, is_valid: bool| {
let intervalstyle = intervalstyle.to_string();
async move {
assert_eq!(
client
.simple_query(format!("SET INTERVALSTYLE='{}'", intervalstyle).as_str())
.await
.is_ok(),
is_valid,
"testing intervalstyle {intervalstyle}"
);
client
}
};
let get_row = |mess: Vec<SimpleQueryMessage>| -> String {
match &mess[1] {
SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(),
_ => unreachable!(),
}
};
let client = validate_intervalstyle(client, "iso_8601", true).await;
let client = validate_intervalstyle(client, "sql_standard", true).await;
let client = validate_intervalstyle(client, "postgres", true).await;
let client = validate_intervalstyle(client, "postgres_verbose", true).await;
let client = validate_intervalstyle(client, "invalid_style", false).await;
let expected_formats: HashMap<&str, &str> = HashMap::from([
("iso_8601", "P1DT2H3M"),
("sql_standard", "1 2:03:00"),
("postgres", "1 day 02:03:00"),
("postgres_verbose", "@ 1 day 2 hours 3 mins"),
]);
for (style, expected_format) in expected_formats {
let _ = client
.simple_query(&format!("SET INTERVALSTYLE='{}'", style))
.await
.expect("SET INTERVALSTYLE ERROR");
let interval = get_row(
client
.simple_query("SHOW VARIABLES intervalstyle")
.await
.unwrap(),
);
assert_eq!(interval, style);
let result = get_row(
client
.simple_query("SELECT INTERVAL '1 day 2 hours 3 minutes'")
.await
.unwrap(),
);
assert_eq!(
result, expected_format,
"intervalstyle {}: expected '{}', got '{}'",
style, expected_format, result
);
}
drop(client);
rx.await.unwrap();
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}
pub async fn test_postgres_timezone(store_type: StorageType) {
let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_timezone").await;
let addr = fe_pg_server.bind_addr().unwrap().to_string();