diff --git a/Cargo.lock b/Cargo.lock index edb8ce04d4..872095752b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12080,6 +12080,7 @@ dependencies = [ "regex", "reqwest", "rust-embed", + "rust_decimal", "rustls", "rustls-pemfile", "rustls-pki-types", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 2d68f17699..46a51f1280 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -107,6 +107,7 @@ rand.workspace = true regex.workspace = true reqwest.workspace = true rust-embed = { version = "6.6", optional = true, features = ["debug-embed"] } +rust_decimal = { workspace = true, features = ["db-postgres"] } rustls = { workspace = true, default-features = false, features = ["ring", "logging", "std", "tls12"] } rustls-pemfile = "2.0" rustls-pki-types = "1.0" diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 2b84b3aa30..94363b06eb 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -456,16 +456,13 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { .do_exec_plan(sql_plan.statement.clone(), plan, query_ctx.clone()) .await } else { - // manually replace variables in prepared statement when no - // logical_plan is generated. This happens when logical plan is not - // supported for certain statements. - let mut sql = sql_plan.query.clone(); - for i in 0..portal.parameter_len() { - sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?); - } - + // We won't replace params from statement manually any more. + // Newer version of datafusion can generate plan for SELECT/INSERT/UPDATE/DELETE. + // Only CREATE TABLE and others minor statements cannot generate sql plan, + // in this case, we assume these statements will not carry parameters + // and execute them directly. self.query_handler - .do_query(&sql, query_ctx.clone()) + .do_query(&sql_plan.query, query_ctx.clone()) .await .remove(0) }; diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index d4d15ef64a..203e477c6f 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -33,7 +33,7 @@ use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::json::JsonStructureSettings; use datatypes::prelude::{ConcreteDataType, Value}; use datatypes::schema::{Schema, SchemaRef}; -use datatypes::types::{IntervalType, TimestampType, jsonb_to_string}; +use datatypes::types::{Decimal128Type, IntervalType, TimestampType, jsonb_to_string}; use datatypes::value::StructValue; use futures::Stream; use pg_interval::Interval as PgInterval; @@ -43,6 +43,8 @@ use pgwire::api::results::FieldInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::types::format::FormatOptions as PgFormatOptions; use query::planner::DfLogicalPlanner; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; use session::context::QueryContextRef; use snafu::ResultExt; @@ -293,11 +295,11 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { // Note that we only support a small amount of pg data types match origin { &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()), - &Type::CHAR => Ok(ConcreteDataType::int8_datatype()), &Type::INT2 => Ok(ConcreteDataType::int16_datatype()), &Type::INT4 => Ok(ConcreteDataType::int32_datatype()), &Type::INT8 => Ok(ConcreteDataType::int64_datatype()), - &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()), + &Type::NUMERIC => Ok(ConcreteDataType::uint64_datatype()), + &Type::VARCHAR | &Type::CHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()), &Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype( common_time::timestamp::TimeUnit::Millisecond, )), @@ -305,9 +307,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { &Type::TIME => Ok(ConcreteDataType::timestamp_datatype( common_time::timestamp::TimeUnit::Microsecond, )), - &Type::CHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new( - ConcreteDataType::int8_datatype(), - ))), &Type::INT2_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new( ConcreteDataType::int16_datatype(), ))), @@ -317,9 +316,12 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { &Type::INT8_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new( ConcreteDataType::int64_datatype(), ))), - &Type::VARCHAR_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new( - ConcreteDataType::string_datatype(), + &Type::NUMERIC_ARRAY => Ok(ConcreteDataType::list_datatype(Arc::new( + ConcreteDataType::uint64_datatype(), ))), + &Type::VARCHAR_ARRAY | &Type::CHAR_ARRAY | &Type::TEXT_ARRAY => Ok( + ConcreteDataType::list_datatype(Arc::new(ConcreteDataType::string_datatype())), + ), _ => server_error::InternalSnafu { err_msg: format!("unimplemented datatype {origin:?}"), } @@ -327,63 +329,6 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { } } -pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWireResult { - // the index is managed from portal's parameters count so it's safe to - // unwrap here. - let param_type = portal - .statement - .parameter_types - .get(idx) - .unwrap() - .as_ref() - .unwrap_or(&Type::UNKNOWN); - match param_type { - &Type::VARCHAR | &Type::TEXT => Ok(format!( - "'{}'", - portal - .parameter::(idx, param_type)? - .as_deref() - .unwrap_or("") - )), - &Type::BOOL => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::INT4 => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::INT8 => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::FLOAT4 => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::FLOAT8 => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::DATE => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.format("%Y-%m-%d").to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::TIMESTAMP => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::INTERVAL => Ok(portal - .parameter::(idx, param_type)? - .map(|v| v.to_sql()) - .unwrap_or_else(|| "".to_owned())), - _ => Err(invalid_parameter_error( - "unsupported_parameter_type", - Some(param_type.to_string()), - )), - } -} - pub(super) fn invalid_parameter_error(msg: &str, detail: Option) -> PgWireError { let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string()); error_info.detail = detail; @@ -407,6 +352,17 @@ where } } +fn to_decimal_scalar_value(data: Option, ctype: &Decimal128Type) -> ScalarValue { + if let Some(data) = data { + let mut value = data; + value.rescale(ctype.scale() as u32); + + ScalarValue::Decimal128(Some(value.mantissa()), ctype.precision(), ctype.scale()) + } else { + ScalarValue::Decimal128(None, ctype.precision(), ctype.scale()) + } +} + pub(super) fn parameters_to_scalar_values( plan: &LogicalPlan, portal: &Portal, @@ -442,7 +398,7 @@ pub(super) fn parameters_to_scalar_values( }; let value = match &client_type { - &Type::VARCHAR | &Type::TEXT => { + &Type::VARCHAR | &Type::TEXT | &Type::CHAR => { let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { @@ -558,6 +514,24 @@ pub(super) fn parameters_to_scalar_values( ScalarValue::Int64(data) } } + &Type::NUMERIC => { + let data = portal.parameter::(idx, &client_type)?; + match &server_type { + Some(ConcreteDataType::Decimal128(dt)) => to_decimal_scalar_value(data, dt), + Some(st @ ConcreteDataType::Timestamp(unit)) => { + to_timestamp_scalar_value(data.and_then(|n| n.to_i64()), unit, st)? + } + Some(ConcreteDataType::UInt64(_)) | None => { + ScalarValue::UInt64(data.and_then(|n| n.to_u64())) + } + Some(st) => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", st, client_type)), + )); + } + } + } &Type::FLOAT4 => { let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { @@ -837,7 +811,67 @@ pub(super) fn parameters_to_scalar_values( ScalarValue::Null } } - &Type::VARCHAR_ARRAY => { + &Type::NUMERIC_ARRAY => { + let data = portal.parameter::>>(idx, &client_type)?; + if let Some(data) = data { + let build_u64_list = |data: Vec>| { + let values = data + .into_iter() + .map(|n| ScalarValue::UInt64(n.and_then(|n| n.to_u64()))) + .collect::>(); + ScalarValue::List(ScalarValue::new_list( + &values, + &ArrowDataType::UInt64, + true, + )) + }; + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::List(list_type) => match list_type.item_type() { + ConcreteDataType::UInt64(_) => build_u64_list(data), + ConcreteDataType::Decimal128(dt) => { + let values = data + .into_iter() + .map(|n| to_decimal_scalar_value(n, dt)) + .collect::>(); + ScalarValue::List(ScalarValue::new_list( + &values, + &ArrowDataType::Decimal128(dt.precision(), dt.scale()), + true, + )) + } + _ => { + // the server type is not a list of decimal or uint64 + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!( + "Expected: {}, found: {}", + list_type.item_type(), + client_type + )), + )); + } + }, + _ => { + // the server type is not a list + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )); + } + } + } else { + // server type not provided + build_u64_list(data) + } + } else { + ScalarValue::Null + } + } + &Type::VARCHAR_ARRAY | &Type::TEXT_ARRAY | &Type::CHAR_ARRAY => { let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { let values = data.into_iter().map(|i| i.into()).collect::>(); @@ -1098,6 +1132,7 @@ pub fn format_options_from_query_ctx(query_ctx: &QueryContextRef) -> Arc