fix: using uint64 datatype for postgres prepared statement parameters (#7942)

* feat: add support for decimal parameter type, remove string replacement fallback

* chore: format

* fix: add support for using unsigned bigint in postgres

* chore: format toml

* refactor: cleanup duplicated code

* fix: rescale decimal
This commit is contained in:
Ning Sun
2026-04-10 15:56:33 +08:00
committed by GitHub
parent fd94f55193
commit 59021ce83b
5 changed files with 184 additions and 76 deletions

1
Cargo.lock generated
View File

@@ -12080,6 +12080,7 @@ dependencies = [
"regex",
"reqwest",
"rust-embed",
"rust_decimal",
"rustls",
"rustls-pemfile",
"rustls-pki-types",

View File

@@ -107,6 +107,7 @@ rand.workspace = true
regex.workspace = true
reqwest.workspace = true
rust-embed = { version = "6.6", optional = true, features = ["debug-embed"] }
rust_decimal = { workspace = true, features = ["db-postgres"] }
rustls = { workspace = true, default-features = false, features = ["ring", "logging", "std", "tls12"] }
rustls-pemfile = "2.0"
rustls-pki-types = "1.0"

View File

@@ -456,16 +456,13 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
.do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
.await
} else {
// manually replace variables in prepared statement when no
// logical_plan is generated. This happens when logical plan is not
// supported for certain statements.
let mut sql = sql_plan.query.clone();
for i in 0..portal.parameter_len() {
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
}
// We no longer substitute parameters into the statement text manually.
// Newer versions of DataFusion can generate plans for SELECT/INSERT/UPDATE/DELETE.
// Only CREATE TABLE and other minor statements cannot generate a logical plan;
// in this case, we assume these statements do not carry parameters
// and execute them directly.
self.query_handler
.do_query(&sql, query_ctx.clone())
.do_query(&sql_plan.query, query_ctx.clone())
.await
.remove(0)
};

View File

@@ -33,7 +33,7 @@ use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::json::JsonStructureSettings;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::{Schema, SchemaRef};
use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
use datatypes::types::{Decimal128Type, IntervalType, TimestampType, jsonb_to_string};
use datatypes::value::StructValue;
use futures::Stream;
use pg_interval::Interval as PgInterval;
@@ -43,6 +43,8 @@ use pgwire::api::results::FieldInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::types::format::FormatOptions as PgFormatOptions;
use query::planner::DfLogicalPlanner;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use session::context::QueryContextRef;
use snafu::ResultExt;
@@ -293,11 +295,11 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
// Note that we only support a small amount of pg data types
match origin {
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
&Type::NUMERIC => Ok(ConcreteDataType::uint64_datatype()),
&Type::VARCHAR | &Type::CHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
&Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype(
common_time::timestamp::TimeUnit::Millisecond,
)),
@@ -305,9 +307,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
&Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
common_time::timestamp::TimeUnit::Microsecond,
)),
&Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
ConcreteDataType::int8_datatype(),
))),
&Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
ConcreteDataType::int16_datatype(),
))),
@@ -317,9 +316,12 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
&Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
ConcreteDataType::int64_datatype(),
))),
&Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
ConcreteDataType::string_datatype(),
&Type::NUMERIC_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
ConcreteDataType::uint64_datatype(),
))),
&Type::VARCHAR_ARRAY | &Type::CHAR_ARRAY | &Type::TEXT_ARRAY => Ok(
ConcreteDataType::list_datatype(Arc::new(ConcreteDataType::string_datatype())),
),
_ => server_error::InternalSnafu {
err_msg: format!("unimplemented datatype {origin:?}"),
}
@@ -327,63 +329,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
}
}
/// Renders the bound value of parameter `idx` of `portal` as SQL text, for
/// splicing into a statement in place of its `$n` placeholder.
///
/// The parameter type declared on the statement selects how the value is
/// decoded and formatted; an undeclared type is treated as `Type::UNKNOWN`.
/// `VARCHAR`/`TEXT` values are emitted as single-quoted literals with embedded
/// single quotes doubled, so a value containing `'` cannot terminate the
/// literal early. A missing (NULL) value renders as `''` for strings and as
/// an empty string for every other supported type.
///
/// # Errors
///
/// Returns the decoding error from `portal.parameter` if the value cannot be
/// read as the declared type, or an `unsupported_parameter_type` error when
/// the declared type is not one of the handled types.
///
/// # Panics
///
/// Panics if `idx` is not a valid index into the statement's parameter types.
pub(super) fn parameter_to_string(portal: &Portal<PgSqlPlan>, idx: usize) -> PgWireResult<String> {
    // the index is managed from portal's parameters count so it's safe to
    // unwrap here.
    let param_type = portal
        .statement
        .parameter_types
        .get(idx)
        .unwrap()
        .as_ref()
        .unwrap_or(&Type::UNKNOWN);
    match param_type {
        &Type::VARCHAR | &Type::TEXT => {
            let text = portal
                .parameter::<String>(idx, param_type)?
                .unwrap_or_default();
            // Double embedded quotes: the result is pasted into SQL text, so an
            // unescaped `'` would end the literal and let the rest be parsed as SQL.
            Ok(format!("'{}'", text.replace('\'', "''")))
        }
        &Type::BOOL => Ok(portal
            .parameter::<bool>(idx, param_type)?
            .map(|v| v.to_string())
            .unwrap_or_default()),
        &Type::INT4 => Ok(portal
            .parameter::<i32>(idx, param_type)?
            .map(|v| v.to_string())
            .unwrap_or_default()),
        &Type::INT8 => Ok(portal
            .parameter::<i64>(idx, param_type)?
            .map(|v| v.to_string())
            .unwrap_or_default()),
        &Type::FLOAT4 => Ok(portal
            .parameter::<f32>(idx, param_type)?
            .map(|v| v.to_string())
            .unwrap_or_default()),
        &Type::FLOAT8 => Ok(portal
            .parameter::<f64>(idx, param_type)?
            .map(|v| v.to_string())
            .unwrap_or_default()),
        &Type::DATE => Ok(portal
            .parameter::<NaiveDate>(idx, param_type)?
            .map(|v| v.format("%Y-%m-%d").to_string())
            .unwrap_or_default()),
        &Type::TIMESTAMP => Ok(portal
            .parameter::<NaiveDateTime>(idx, param_type)?
            .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
            .unwrap_or_default()),
        &Type::INTERVAL => Ok(portal
            .parameter::<PgInterval>(idx, param_type)?
            .map(|v| v.to_sql())
            .unwrap_or_default()),
        _ => Err(invalid_parameter_error(
            "unsupported_parameter_type",
            Some(param_type.to_string()),
        )),
    }
}
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
error_info.detail = detail;
@@ -407,6 +352,17 @@ where
}
}
/// Converts an optional `rust_decimal::Decimal` into a `ScalarValue::Decimal128`
/// carrying the precision and scale of `ctype`.
///
/// `None` becomes a NULL decimal of the same precision and scale. For `Some`,
/// the value is rescaled to `ctype.scale()` and its unscaled integer is used
/// as the 128-bit mantissa.
///
/// `Decimal::rescale` is best-effort: `Decimal` has a 96-bit mantissa and a
/// maximum scale of 28, so it may stop at a smaller scale than requested.
/// Decimal128 types can have a larger scale than that, so any remaining
/// scale-up is finished here in `i128`. Without this step the mantissa would
/// be labelled with a scale it was never brought to, and the value would be
/// wrong by a power of ten. If the scaled mantissa does not fit in `i128`, it
/// saturates, which is larger than any valid 38-digit Decimal128 value.
///
/// NOTE(review): `rust_decimal` cannot represent negative scales, so a negative
/// `ctype.scale()` is treated as scale 0 for the rescale. Confirm that
/// `Decimal128Type` never carries a negative scale on this path.
fn to_decimal_scalar_value(data: Option<Decimal>, ctype: &Decimal128Type) -> ScalarValue {
    let unscaled = data.map(|mut value| {
        // `as u32` would wrap a negative scale into a huge request; clamp instead.
        let wanted_scale = u32::try_from(ctype.scale()).unwrap_or(0);
        value.rescale(wanted_scale);

        // Number of decimal places `rescale` could not add within 96 bits.
        let shortfall = wanted_scale.saturating_sub(value.scale());
        let mantissa = value.mantissa();
        if shortfall == 0 {
            return mantissa;
        }
        let factor = 10i128.checked_pow(shortfall).unwrap_or(i128::MAX);
        mantissa.saturating_mul(factor)
    });
    ScalarValue::Decimal128(unscaled, ctype.precision(), ctype.scale())
}
pub(super) fn parameters_to_scalar_values(
plan: &LogicalPlan,
portal: &Portal<PgSqlPlan>,
@@ -442,7 +398,7 @@ pub(super) fn parameters_to_scalar_values(
};
let value = match &client_type {
&Type::VARCHAR | &Type::TEXT => {
&Type::VARCHAR | &Type::TEXT | &Type::CHAR => {
let data = portal.parameter::<String>(idx, &client_type)?;
if let Some(server_type) = &server_type {
match server_type {
@@ -558,6 +514,24 @@ pub(super) fn parameters_to_scalar_values(
ScalarValue::Int64(data)
}
}
&Type::NUMERIC => {
let data = portal.parameter::<Decimal>(idx, &client_type)?;
match &server_type {
Some(ConcreteDataType::Decimal128(dt)) => to_decimal_scalar_value(data, dt),
Some(st @ ConcreteDataType::Timestamp(unit)) => {
to_timestamp_scalar_value(data.and_then(|n| n.to_i64()), unit, st)?
}
Some(ConcreteDataType::UInt64(_)) | None => {
ScalarValue::UInt64(data.and_then(|n| n.to_u64()))
}
Some(st) => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", st, client_type)),
));
}
}
}
&Type::FLOAT4 => {
let data = portal.parameter::<f32>(idx, &client_type)?;
if let Some(server_type) = &server_type {
@@ -837,7 +811,67 @@ pub(super) fn parameters_to_scalar_values(
ScalarValue::Null
}
}
&Type::VARCHAR_ARRAY => {
&Type::NUMERIC_ARRAY => {
let data = portal.parameter::<Vec<Option<Decimal>>>(idx, &client_type)?;
if let Some(data) = data {
let build_u64_list = |data: Vec<Option<Decimal>>| {
let values = data
.into_iter()
.map(|n| ScalarValue::UInt64(n.and_then(|n| n.to_u64())))
.collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(
&values,
&ArrowDataType::UInt64,
true,
))
};
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::List(list_type) => match list_type.item_type() {
ConcreteDataType::UInt64(_) => build_u64_list(data),
ConcreteDataType::Decimal128(dt) => {
let values = data
.into_iter()
.map(|n| to_decimal_scalar_value(n, dt))
.collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(
&values,
&ArrowDataType::Decimal128(dt.precision(), dt.scale()),
true,
))
}
_ => {
// the server type is not a list of decimal or uint64
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!(
"Expected: {}, found: {}",
list_type.item_type(),
client_type
)),
));
}
},
_ => {
// the server type is not a list
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
}
}
} else {
// server type not provided
build_u64_list(data)
}
} else {
ScalarValue::Null
}
}
&Type::VARCHAR_ARRAY | &Type::TEXT_ARRAY | &Type::CHAR_ARRAY => {
let data = portal.parameter::<Vec<Option<String>>>(idx, &client_type)?;
if let Some(data) = data {
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
@@ -1098,6 +1132,7 @@ pub fn format_options_from_query_ctx(query_ctx: &QueryContextRef) -> Arc<PgForma
#[cfg(test)]
mod test {
use std::str::FromStr;
use std::sync::Arc;
use arrow::array::{
@@ -1516,4 +1551,26 @@ mod test {
panic!("test_invalid_parameter failed");
}
}
/// Checks `to_decimal_scalar_value` against a `Decimal128(18, 4)` target:
/// an input already at scale 4, an input that must be scaled up, a negative
/// input, and a NULL input.
#[test]
fn test_to_decimal_scalar_value() {
    let decimal_type = Decimal128Type::new(18, 4);
    let parse = |text: &str| Decimal::from_str(text).unwrap();

    // Already at the target scale: the mantissa passes through unchanged.
    let exact = parse("12345.6789");
    assert_eq!(exact.mantissa(), 123456789i128);
    assert_eq!(
        to_decimal_scalar_value(Some(exact), &decimal_type),
        ScalarValue::Decimal128(Some(123456789), 18, 4)
    );

    // Fewer fractional digits than the target: the mantissa is scaled up.
    let short = parse("100.5");
    assert_eq!(short.mantissa(), 1005);
    assert_eq!(
        to_decimal_scalar_value(Some(short), &decimal_type),
        ScalarValue::Decimal128(Some(1005000), 18, 4)
    );

    // The sign is carried by the mantissa.
    let negative = parse("-9876.5432");
    assert_eq!(
        to_decimal_scalar_value(Some(negative), &decimal_type),
        ScalarValue::Decimal128(Some(-98765432), 18, 4)
    );

    // A missing value maps to a NULL decimal of the same type.
    assert_eq!(
        to_decimal_scalar_value(None, &decimal_type),
        ScalarValue::Decimal128(None, 18, 4)
    );
}
}

View File

@@ -82,6 +82,7 @@ macro_rules! sql_tests {
test_postgres_datestyle,
test_postgres_intervalstyle,
test_postgres_parameter_inference,
test_postgres_uint64_parameter,
test_postgres_array_types,
test_mysql_prepare_stmt_insert_timestamp,
test_declare_fetch_close_cursor,
@@ -1300,6 +1301,57 @@ pub async fn test_postgres_parameter_inference(store_type: StorageType) {
guard.remove_all().await;
}
/// End-to-end check that a `Decimal` bound parameter can be used against a
/// `bigint unsigned` column over the PostgreSQL protocol.
///
/// The test starts a server, creates `demo_u64`, inserts one row through a
/// prepared `INSERT` whose first parameter is a `Decimal`, and then verifies
/// that a prepared `SELECT count(*)` filtering on the same `Decimal` value
/// finds exactly that row.
pub async fn test_postgres_uint64_parameter(store_type: StorageType) {
    let (mut guard, fe_pg_server) =
        setup_pg_server(store_type, "test_postgres_uint64_parameter").await;

    let server_addr = fe_pg_server.bind_addr().unwrap().to_string();
    let url = format!("postgres://{server_addr}/public");
    let (client, connection) = tokio_postgres::connect(&url, NoTls).await.unwrap();

    // Drive the connection in the background and signal when it has finished,
    // so the test can wait for a clean disconnect before shutting down.
    let (done_tx, done_rx) = tokio::sync::oneshot::channel();
    tokio::spawn(async move {
        connection.await.unwrap();
        done_tx.send(()).unwrap();
    });

    let _ = client
        .simple_query("create table demo_u64(v bigint unsigned, ts timestamp time index)")
        .await
        .unwrap();

    let ts = NaiveDate::from_yo_opt(2015, 100)
        .unwrap()
        .and_hms_opt(0, 0, 0)
        .unwrap();
    let unsigned_value = Decimal::from(123456u64);

    let _ = client
        .execute(
            "INSERT INTO demo_u64 VALUES($1, $2)",
            &[&unsigned_value, &ts],
        )
        .await
        .unwrap();

    let rows = client
        .query(
            "SELECT count(*) FROM demo_u64 WHERE v = $1",
            &[&unsigned_value],
        )
        .await
        .unwrap();
    assert_eq!(1, rows.len());
    let matched: i64 = rows[0].get(0);
    assert_eq!(matched, 1);

    // Dropping the client closes the connection, which lets the spawned task finish.
    drop(client);
    done_rx.await.unwrap();

    let _ = fe_pg_server.shutdown().await;
    guard.remove_all().await;
}
pub async fn test_mysql_async_timestamp(store_type: StorageType) {
use mysql_async::prelude::*;
use time::PrimitiveDateTime;