fix: postgres extended query paramater parsing and type check (#7276)

* fix: postgres extended query paramater parsing and type check

* test: update sqlness output

* feat: implement FromSqlText for pg_interval

* chore: toml format
This commit is contained in:
Ning Sun
2025-11-24 10:40:35 +08:00
committed by GitHub
parent c9a7b1fd68
commit 2f447e6f91
10 changed files with 150 additions and 54 deletions

19
Cargo.lock generated
View File

@@ -9183,10 +9183,21 @@ dependencies = [
]
[[package]]
name = "pgwire"
version = "0.34.2"
name = "pg_interval"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f56a81b4fcc69016028f657a68f9b8e8a2a4b7d07684ca3298f2d3e7ff199ce"
checksum = "fe46640b465e284b048ef065cbed8ef17a622878d310c724578396b4cfd00df2"
dependencies = [
"bytes",
"chrono",
"postgres-types",
]
[[package]]
name = "pgwire"
version = "0.36.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5cc59678d0c10c73a552d465ce9156995189d1c678f2784dc817fe8623487f5"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -9202,6 +9213,7 @@ dependencies = [
"ring",
"rust_decimal",
"rustls-pki-types",
"ryu",
"serde",
"serde_json",
"stringprep",
@@ -11585,6 +11597,7 @@ dependencies = [
"otel-arrow-rust",
"parking_lot 0.12.4",
"permutation",
"pg_interval",
"pgwire",
"pin-project",
"pipeline",

View File

@@ -86,7 +86,8 @@ opentelemetry-proto.workspace = true
operator.workspace = true
otel-arrow-rust.workspace = true
parking_lot.workspace = true
pgwire = { version = "0.34", default-features = false, features = [
pg_interval = "0.4"
pgwire = { version = "0.36", default-features = false, features = [
"server-api-ring",
"pg-ext-types",
] }

View File

@@ -201,7 +201,7 @@ impl QueryParser for DefaultQueryParser {
&self,
_client: &C,
sql: &str,
_types: &[Type],
_types: &[Option<Type>],
) -> PgWireResult<Self::Statement> {
crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
let query_ctx = self.session.new_query_context();
@@ -341,7 +341,9 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
C: ClientInfo + Unpin + Send + Sync,
{
let sql_plan = &stmt.statement;
let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan {
// client provided parameter types, can be empty if client doesn't try to parse statement
let provided_param_types = &stmt.parameter_types;
let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
let param_types = plan
.get_parameter_types()
.context(DataFusionSnafu)
@@ -352,14 +354,36 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
let types = param_types_to_pg_types(&param_types).map_err(convert_err)?;
(types, sql_plan, &Format::UnifiedBinary)
Some(types)
} else {
let param_types = stmt.parameter_types.clone();
(param_types, sql_plan, &Format::UnifiedBinary)
None
};
let param_count = if provided_param_types.is_empty() {
server_inferenced_types
.as_ref()
.map(|types| types.len())
.unwrap_or(0)
} else {
provided_param_types.len()
};
let param_types = (0..param_count)
.map(|i| {
let client_type = provided_param_types.get(i);
// use server type when client provided type is None (oid: 0 or other invalid values)
match client_type {
Some(Some(client_type)) => client_type.clone(),
_ => server_inferenced_types
.as_ref()
.and_then(|types| types.get(i).cloned())
.unwrap_or(Type::UNKNOWN),
}
})
.collect::<Vec<_>>();
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, format)
schema_to_pg(schema, &Format::UnifiedBinary)
.map(|fields| DescribeStatementResponse::new(param_types, fields))
.map_err(convert_err)
} else {

View File

@@ -749,7 +749,13 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
// the index is managed from portal's parameters count so it's safe to
// unwrap here.
let param_type = portal.statement.parameter_types.get(idx).unwrap();
let param_type = portal
.statement
.parameter_types
.get(idx)
.unwrap()
.as_ref()
.unwrap_or(&Type::UNKNOWN);
match param_type {
&Type::VARCHAR | &Type::TEXT => Ok(format!(
"'{}'",
@@ -828,7 +834,7 @@ pub(super) fn parameters_to_scalar_values(
let mut results = Vec::with_capacity(param_count);
let client_param_types = &portal.statement.parameter_types;
let param_types = plan
let server_param_types = plan
.get_parameter_types()
.context(DataFusionSnafu)
.map_err(convert_err)?
@@ -837,18 +843,12 @@ pub(super) fn parameters_to_scalar_values(
.collect::<HashMap<_, _>>();
for idx in 0..param_count {
let server_type = param_types
let server_type = server_param_types
.get(&format!("${}", idx + 1))
.and_then(|t| t.as_ref());
let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
match (client_given_type, server_type) {
(&Type::UNKNOWN, Some(server_type)) => {
// If client type is unknown, use the server type.
type_gt_to_pg(server_type).map_err(convert_err)?
}
_ => client_given_type.clone(),
}
let client_type = if let Some(Some(client_given_type)) = client_param_types.get(idx) {
client_given_type.clone()
} else if let Some(server_provided_type) = &server_type {
type_gt_to_pg(server_provided_type).map_err(convert_err)?
} else {

View File

@@ -14,6 +14,7 @@
use bytes::BufMut;
use pgwire::types::ToSqlText;
use pgwire::types::format::FormatOptions;
use postgres_types::{IsNull, ToSql, Type};
#[derive(Debug)]
@@ -23,11 +24,12 @@ impl ToSqlText for HexOutputBytea<'_> {
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
let _ = self.0.to_sql_text(ty, out);
let _ = self.0.to_sql_text(ty, out, format_options);
Ok(IsNull::No)
}
}
@@ -66,6 +68,7 @@ impl ToSqlText for EscapeOutputBytea<'_> {
&self,
_ty: &Type,
out: &mut bytes::BytesMut,
_format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
@@ -120,7 +123,9 @@ mod tests {
let expected = b"abcklm*\\251T";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
let is_null = input
.to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
@@ -138,7 +143,9 @@ mod tests {
let expected = b"\\x68656c6c6f2c20776f726c6421";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
let is_null = input
.to_sql_text(&Type::BYTEA, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);

View File

@@ -15,6 +15,7 @@
use bytes::BufMut;
use chrono::{NaiveDate, NaiveDateTime};
use pgwire::types::ToSqlText;
use pgwire::types::format::FormatOptions;
use postgres_types::{IsNull, ToSql, Type};
use session::session_config::{PGDateOrder, PGDateTimeStyle};
@@ -58,6 +59,7 @@ impl ToSqlText for StylingDate {
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
@@ -71,7 +73,7 @@ impl ToSqlText for StylingDate {
out.put_slice(fmt.as_bytes());
}
_ => {
self.0.to_sql_text(ty, out)?;
self.0.to_sql_text(ty, out, format_options)?;
}
}
Ok(IsNull::No)
@@ -83,6 +85,7 @@ impl ToSqlText for StylingDateTime {
&self,
ty: &Type,
out: &mut bytes::BytesMut,
format_options: &FormatOptions,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
@@ -103,7 +106,7 @@ impl ToSqlText for StylingDateTime {
out.put_slice(fmt.as_bytes());
}
_ => {
self.0.to_sql_text(ty, out)?;
self.0.to_sql_text(ty, out, format_options)?;
}
}
Ok(IsNull::No)
@@ -151,7 +154,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::MDY);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -160,7 +165,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::YMD);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -169,7 +176,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::ISO, PGDateOrder::DMY);
let expected = "1997-12-17";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -178,7 +187,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::MDY);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -187,7 +198,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::YMD);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -196,7 +209,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::German, PGDateOrder::DMY);
let expected = "17.12.1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -205,7 +220,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::MDY);
let expected = "12-17-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -214,7 +231,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::YMD);
let expected = "12-17-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -223,7 +242,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::Postgres, PGDateOrder::DMY);
let expected = "17-12-1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -232,7 +253,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::MDY);
let expected = "12/17/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -241,7 +264,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::YMD);
let expected = "12/17/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -250,7 +275,9 @@ mod tests {
let styling_date = StylingDate(naive_date, PGDateTimeStyle::SQL, PGDateOrder::DMY);
let expected = "17/12/1997";
let mut out = bytes::BytesMut::new();
let is_null = styling_date.to_sql_text(&Type::DATE, &mut out).unwrap();
let is_null = styling_date
.to_sql_text(&Type::DATE, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
}
@@ -267,7 +294,7 @@ mod tests {
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -278,7 +305,7 @@ mod tests {
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -289,7 +316,7 @@ mod tests {
let expected = "2021-09-01 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -301,7 +328,7 @@ mod tests {
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -313,7 +340,7 @@ mod tests {
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -325,7 +352,7 @@ mod tests {
let expected = "01.09.2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -337,7 +364,7 @@ mod tests {
let expected = "Wed Sep 01 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -349,7 +376,7 @@ mod tests {
let expected = "Wed Sep 01 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -361,7 +388,7 @@ mod tests {
let expected = "Wed 01 Sep 12:34:56.789012 2021";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -372,7 +399,7 @@ mod tests {
let expected = "09/01/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -383,7 +410,7 @@ mod tests {
let expected = "09/01/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());
@@ -394,7 +421,7 @@ mod tests {
let expected = "01/09/2021 12:34:56.789012";
let mut out = bytes::BytesMut::new();
let is_null = styling_datetime
.to_sql_text(&Type::TIMESTAMP, &mut out)
.to_sql_text(&Type::TIMESTAMP, &mut out, &FormatOptions::default())
.unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(out, expected.as_bytes());

View File

@@ -18,7 +18,8 @@ use bytes::{Buf, BufMut};
use common_time::interval::IntervalFormat;
use common_time::timestamp::TimeUnit;
use common_time::{Duration, IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth};
use pgwire::types::ToSqlText;
use pgwire::types::format::FormatOptions;
use pgwire::types::{FromSqlText, ToSqlText};
use postgres_types::{FromSql, IsNull, ToSql, Type, to_sql_checked};
use crate::error;
@@ -201,6 +202,7 @@ impl ToSqlText for PgInterval {
&self,
ty: &Type,
out: &mut bytes::BytesMut,
_format_options: &FormatOptions,
) -> std::result::Result<postgres_types::IsNull, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
@@ -215,6 +217,28 @@ impl ToSqlText for PgInterval {
}
}
impl<'a> FromSqlText<'a> for PgInterval {
fn from_sql_text(
_ty: &Type,
input: &[u8],
_format_options: &FormatOptions,
) -> std::result::Result<Self, Box<dyn snafu::Error + Sync + Send>>
where
Self: Sized,
{
// only support parsing interval from postgres format
if let Ok(interval) = pg_interval::Interval::from_postgres(str::from_utf8(input)?) {
Ok(PgInterval {
months: interval.months,
days: interval.days,
microseconds: interval.microseconds,
})
} else {
Err("invalid interval format".into())
}
}
}
#[cfg(test)]
mod tests {
use common_time::Duration;

View File

@@ -72,7 +72,7 @@ select host, cpu, memory, jsons, ts from demo where host != 'host3';
+-------+------+--------+------------------------+----------------------------+
| host | cpu | memory | jsons | ts |
+-------+------+--------+------------------------+----------------------------+
| host1 | 66.6 | 1024 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host1 | 66.6 | 1024.0 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host2 | 88.8 | 333.3 | {"a":null,"foo":"bar"} | 2022-06-15 07:02:38.000000 |
+-------+------+--------+------------------------+----------------------------+

View File

@@ -72,7 +72,7 @@ select host, cpu, memory, jsons, ts from demo where host != 'host3';
+-------+------+--------+------------------------+----------------------------+
| host | cpu | memory | jsons | ts |
+-------+------+--------+------------------------+----------------------------+
| host1 | 66.6 | 1024 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host1 | 66.6 | 1024.0 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host2 | 88.8 | 333.3 | {"a":null,"foo":"bar"} | 2022-06-15 07:02:38.000000 |
+-------+------+--------+------------------------+----------------------------+

View File

@@ -110,7 +110,7 @@ select host, cpu, memory, jsons, ts from demo where host != 'host3';
+-------+------+--------+------------------------+----------------------------+
| host | cpu | memory | jsons | ts |
+-------+------+--------+------------------------+----------------------------+
| host1 | 66.6 | 1024 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host1 | 66.6 | 1024.0 | {"foo":"bar"} | 2022-06-15 07:02:37.000000 |
| host2 | 88.8 | 333.3 | {"a":null,"foo":"bar"} | 2022-06-15 07:02:38.000000 |
+-------+------+--------+------------------------+----------------------------+