mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 22:40:40 +00:00
fix: using uint64 datatype for postgres prepared statement parameters (#7942)
* feat: add support for decimal parameter type, remove string replacement fallback * chore: format * fix: add support for using unsigned bigint in postgres * chore: format toml * refactor: cleanup duplicated code * fix: rescale decimal
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -12080,6 +12080,7 @@ dependencies = [
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rust-embed",
|
||||
"rust_decimal",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
|
||||
@@ -107,6 +107,7 @@ rand.workspace = true
|
||||
regex.workspace = true
|
||||
reqwest.workspace = true
|
||||
rust-embed = { version = "6.6", optional = true, features = ["debug-embed"] }
|
||||
rust_decimal = { workspace = true, features = ["db-postgres"] }
|
||||
rustls = { workspace = true, default-features = false, features = ["ring", "logging", "std", "tls12"] }
|
||||
rustls-pemfile = "2.0"
|
||||
rustls-pki-types = "1.0"
|
||||
|
||||
@@ -456,16 +456,13 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
|
||||
.do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone())
|
||||
.await
|
||||
} else {
|
||||
// manually replace variables in prepared statement when no
|
||||
// logical_plan is generated. This happens when logical plan is not
|
||||
// supported for certain statements.
|
||||
let mut sql = sql_plan.query.clone();
|
||||
for i in 0..portal.parameter_len() {
|
||||
sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
|
||||
}
|
||||
|
||||
// We won't replace params from statement manually any more.
|
||||
// Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE.
|
||||
// Only CREATE TABLE and others minor statements cannot generate sql plan,
|
||||
// in this case, we assume these statements will not carry parameters
|
||||
// and execute them directly.
|
||||
self.query_handler
|
||||
.do_query(&sql, query_ctx.clone())
|
||||
.do_query(&sql_plan.query, query_ctx.clone())
|
||||
.await
|
||||
.remove(0)
|
||||
};
|
||||
|
||||
@@ -33,7 +33,7 @@ use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::json::JsonStructureSettings;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::{Schema, SchemaRef};
|
||||
use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
|
||||
use datatypes::types::{Decimal128Type, IntervalType, TimestampType, jsonb_to_string};
|
||||
use datatypes::value::StructValue;
|
||||
use futures::Stream;
|
||||
use pg_interval::Interval as PgInterval;
|
||||
@@ -43,6 +43,8 @@ use pgwire::api::results::FieldInfo;
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
use pgwire::types::format::FormatOptions as PgFormatOptions;
|
||||
use query::planner::DfLogicalPlanner;
|
||||
use rust_decimal::Decimal;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ResultExt;
|
||||
|
||||
@@ -293,11 +295,11 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
// Note that we only support a small amount of pg data types
|
||||
match origin {
|
||||
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
|
||||
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
|
||||
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
|
||||
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
|
||||
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
|
||||
&Type::NUMERIC => Ok(ConcreteDataType::uint64_datatype()),
|
||||
&Type::VARCHAR | &Type::CHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
|
||||
&Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype(
|
||||
common_time::timestamp::TimeUnit::Millisecond,
|
||||
)),
|
||||
@@ -305,9 +307,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
&Type::TIME => Ok(ConcreteDataType::timestamp_datatype(
|
||||
common_time::timestamp::TimeUnit::Microsecond,
|
||||
)),
|
||||
&Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
|
||||
ConcreteDataType::int8_datatype(),
|
||||
))),
|
||||
&Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
|
||||
ConcreteDataType::int16_datatype(),
|
||||
))),
|
||||
@@ -317,9 +316,12 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
&Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
|
||||
ConcreteDataType::int64_datatype(),
|
||||
))),
|
||||
&Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
|
||||
ConcreteDataType::string_datatype(),
|
||||
&Type::NUMERIC_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new(
|
||||
ConcreteDataType::uint64_datatype(),
|
||||
))),
|
||||
&Type::VARCHAR_ARRAY | &Type::CHAR_ARRAY | &Type::TEXT_ARRAY => Ok(
|
||||
ConcreteDataType::list_datatype(Arc::new(ConcreteDataType::string_datatype())),
|
||||
),
|
||||
_ => server_error::InternalSnafu {
|
||||
err_msg: format!("unimplemented datatype {origin:?}"),
|
||||
}
|
||||
@@ -327,63 +329,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn parameter_to_string(portal: &Portal<PgSqlPlan>, idx: usize) -> PgWireResult<String> {
|
||||
// the index is managed from portal's parameters count so it's safe to
|
||||
// unwrap here.
|
||||
let param_type = portal
|
||||
.statement
|
||||
.parameter_types
|
||||
.get(idx)
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.unwrap_or(&Type::UNKNOWN);
|
||||
match param_type {
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(format!(
|
||||
"'{}'",
|
||||
portal
|
||||
.parameter::<String>(idx, param_type)?
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
)),
|
||||
&Type::BOOL => Ok(portal
|
||||
.parameter::<bool>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT4 => Ok(portal
|
||||
.parameter::<i32>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT8 => Ok(portal
|
||||
.parameter::<i64>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT4 => Ok(portal
|
||||
.parameter::<f32>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT8 => Ok(portal
|
||||
.parameter::<f64>(idx, param_type)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::DATE => Ok(portal
|
||||
.parameter::<NaiveDate>(idx, param_type)?
|
||||
.map(|v| v.format("%Y-%m-%d").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::TIMESTAMP => Ok(portal
|
||||
.parameter::<NaiveDateTime>(idx, param_type)?
|
||||
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INTERVAL => Ok(portal
|
||||
.parameter::<PgInterval>(idx, param_type)?
|
||||
.map(|v| v.to_sql())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
_ => Err(invalid_parameter_error(
|
||||
"unsupported_parameter_type",
|
||||
Some(param_type.to_string()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
|
||||
let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
|
||||
error_info.detail = detail;
|
||||
@@ -407,6 +352,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn to_decimal_scalar_value(data: Option<Decimal>, ctype: &Decimal128Type) -> ScalarValue {
|
||||
if let Some(data) = data {
|
||||
let mut value = data;
|
||||
value.rescale(ctype.scale() as u32);
|
||||
|
||||
ScalarValue::Decimal128(Some(value.mantissa()), ctype.precision(), ctype.scale())
|
||||
} else {
|
||||
ScalarValue::Decimal128(None, ctype.precision(), ctype.scale())
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn parameters_to_scalar_values(
|
||||
plan: &LogicalPlan,
|
||||
portal: &Portal<PgSqlPlan>,
|
||||
@@ -442,7 +398,7 @@ pub(super) fn parameters_to_scalar_values(
|
||||
};
|
||||
|
||||
let value = match &client_type {
|
||||
&Type::VARCHAR | &Type::TEXT => {
|
||||
&Type::VARCHAR | &Type::TEXT | &Type::CHAR => {
|
||||
let data = portal.parameter::<String>(idx, &client_type)?;
|
||||
if let Some(server_type) = &server_type {
|
||||
match server_type {
|
||||
@@ -558,6 +514,24 @@ pub(super) fn parameters_to_scalar_values(
|
||||
ScalarValue::Int64(data)
|
||||
}
|
||||
}
|
||||
&Type::NUMERIC => {
|
||||
let data = portal.parameter::<Decimal>(idx, &client_type)?;
|
||||
match &server_type {
|
||||
Some(ConcreteDataType::Decimal128(dt)) => to_decimal_scalar_value(data, dt),
|
||||
Some(st @ ConcreteDataType::Timestamp(unit)) => {
|
||||
to_timestamp_scalar_value(data.and_then(|n| n.to_i64()), unit, st)?
|
||||
}
|
||||
Some(ConcreteDataType::UInt64(_)) | None => {
|
||||
ScalarValue::UInt64(data.and_then(|n| n.to_u64()))
|
||||
}
|
||||
Some(st) => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(format!("Expected: {}, found: {}", st, client_type)),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::FLOAT4 => {
|
||||
let data = portal.parameter::<f32>(idx, &client_type)?;
|
||||
if let Some(server_type) = &server_type {
|
||||
@@ -837,7 +811,67 @@ pub(super) fn parameters_to_scalar_values(
|
||||
ScalarValue::Null
|
||||
}
|
||||
}
|
||||
&Type::VARCHAR_ARRAY => {
|
||||
&Type::NUMERIC_ARRAY => {
|
||||
let data = portal.parameter::<Vec<Option<Decimal>>>(idx, &client_type)?;
|
||||
if let Some(data) = data {
|
||||
let build_u64_list = |data: Vec<Option<Decimal>>| {
|
||||
let values = data
|
||||
.into_iter()
|
||||
.map(|n| ScalarValue::UInt64(n.and_then(|n| n.to_u64())))
|
||||
.collect::<Vec<_>>();
|
||||
ScalarValue::List(ScalarValue::new_list(
|
||||
&values,
|
||||
&ArrowDataType::UInt64,
|
||||
true,
|
||||
))
|
||||
};
|
||||
if let Some(server_type) = &server_type {
|
||||
match server_type {
|
||||
ConcreteDataType::List(list_type) => match list_type.item_type() {
|
||||
ConcreteDataType::UInt64(_) => build_u64_list(data),
|
||||
ConcreteDataType::Decimal128(dt) => {
|
||||
let values = data
|
||||
.into_iter()
|
||||
.map(|n| to_decimal_scalar_value(n, dt))
|
||||
.collect::<Vec<_>>();
|
||||
ScalarValue::List(ScalarValue::new_list(
|
||||
&values,
|
||||
&ArrowDataType::Decimal128(dt.precision(), dt.scale()),
|
||||
true,
|
||||
))
|
||||
}
|
||||
_ => {
|
||||
// the server type is not a list of decimal or uint64
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(format!(
|
||||
"Expected: {}, found: {}",
|
||||
list_type.item_type(),
|
||||
client_type
|
||||
)),
|
||||
));
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
// the server type is not a list
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// server type not provided
|
||||
build_u64_list(data)
|
||||
}
|
||||
} else {
|
||||
ScalarValue::Null
|
||||
}
|
||||
}
|
||||
&Type::VARCHAR_ARRAY | &Type::TEXT_ARRAY | &Type::CHAR_ARRAY => {
|
||||
let data = portal.parameter::<Vec<Option<String>>>(idx, &client_type)?;
|
||||
if let Some(data) = data {
|
||||
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
|
||||
@@ -1098,6 +1132,7 @@ pub fn format_options_from_query_ctx(query_ctx: &QueryContextRef) -> Arc<PgForma
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::array::{
|
||||
@@ -1516,4 +1551,26 @@ mod test {
|
||||
panic!("test_invalid_parameter failed");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_decimal_scalar_value() {
|
||||
let dt = Decimal128Type::new(18, 4);
|
||||
|
||||
let d = Decimal::from_str("12345.6789").unwrap();
|
||||
assert_eq!(d.mantissa(), 123456789i128);
|
||||
let scalar = to_decimal_scalar_value(Some(d), &dt);
|
||||
assert_eq!(scalar, ScalarValue::Decimal128(Some(123456789), 18, 4));
|
||||
|
||||
let d = Decimal::from_str("100.5").unwrap();
|
||||
assert_eq!(d.mantissa(), 1005);
|
||||
let scalar = to_decimal_scalar_value(Some(d), &dt);
|
||||
assert_eq!(scalar, ScalarValue::Decimal128(Some(1005000), 18, 4));
|
||||
|
||||
let d = Decimal::from_str("-9876.5432").unwrap();
|
||||
let scalar = to_decimal_scalar_value(Some(d), &dt);
|
||||
assert_eq!(scalar, ScalarValue::Decimal128(Some(-98765432), 18, 4));
|
||||
|
||||
let scalar = to_decimal_scalar_value(None, &dt);
|
||||
assert_eq!(scalar, ScalarValue::Decimal128(None, 18, 4));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,6 +82,7 @@ macro_rules! sql_tests {
|
||||
test_postgres_datestyle,
|
||||
test_postgres_intervalstyle,
|
||||
test_postgres_parameter_inference,
|
||||
test_postgres_uint64_parameter,
|
||||
test_postgres_array_types,
|
||||
test_mysql_prepare_stmt_insert_timestamp,
|
||||
test_declare_fetch_close_cursor,
|
||||
@@ -1300,6 +1301,57 @@ pub async fn test_postgres_parameter_inference(store_type: StorageType) {
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
pub async fn test_postgres_uint64_parameter(store_type: StorageType) {
|
||||
let (mut guard, fe_pg_server) =
|
||||
setup_pg_server(store_type, "test_postgres_uint64_parameter").await;
|
||||
let addr = fe_pg_server.bind_addr().unwrap().to_string();
|
||||
|
||||
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
tokio::spawn(async move {
|
||||
connection.await.unwrap();
|
||||
tx.send(()).unwrap();
|
||||
});
|
||||
|
||||
let _ = client
|
||||
.simple_query("create table demo_u64(v bigint unsigned, ts timestamp time index)")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let dt = NaiveDate::from_yo_opt(2015, 100)
|
||||
.unwrap()
|
||||
.and_hms_opt(0, 0, 0)
|
||||
.unwrap();
|
||||
let _ = client
|
||||
.execute(
|
||||
"INSERT INTO demo_u64 VALUES($1, $2)",
|
||||
&[&Decimal::from(123456u64), &dt],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let rows = client
|
||||
.query(
|
||||
"SELECT count(*) FROM demo_u64 WHERE v = $1",
|
||||
&[&Decimal::from(123456u64)],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(1, rows.len());
|
||||
let count: i64 = rows[0].get(0);
|
||||
assert_eq!(count, 1);
|
||||
|
||||
drop(client);
|
||||
rx.await.unwrap();
|
||||
|
||||
let _ = fe_pg_server.shutdown().await;
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
pub async fn test_mysql_async_timestamp(store_type: StorageType) {
|
||||
use mysql_async::prelude::*;
|
||||
use time::PrimitiveDateTime;
|
||||
|
||||
Reference in New Issue
Block a user