diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 44587783be..7221cfc6e2 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -57,11 +57,10 @@ pub mod tls; /// Cached SQL and logical plan for database interfaces #[derive(Clone, Debug)] pub struct SqlPlan { - query: String, - // Store the parsed statement to determine if it is a query and whether to track it. - statement: Option, - plan: Option, - schema: Option, + pub(crate) query: String, + pub(crate) statement: Option, + pub(crate) plan: Option, + pub(crate) schema: Option, } /// Install the ring crypto provider for rustls process-wide. see: diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 94363b06eb..3275ebfdb8 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -300,8 +300,8 @@ impl DefaultQueryParser { /// A container type of parse result types #[derive(Clone, Debug)] pub struct PgSqlPlan { - plan: SqlPlan, - copy_to_stdout_format: Option, + pub(crate) plan: SqlPlan, + pub(crate) copy_to_stdout_format: Option, } #[async_trait] diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 203e477c6f..33fc41164b 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -363,6 +363,13 @@ fn to_decimal_scalar_value(data: Option, ctype: &Decimal128Type) -> Sca } } +fn numeric_out_of_range_error(value: impl std::fmt::Display) -> PgWireError { + invalid_parameter_error( + "numeric_value_out_of_range", + Some(format!("value {} is out of range for target type", value)), + ) +} + pub(super) fn parameters_to_scalar_values( plan: &LogicalPlan, portal: &Portal, @@ -440,14 +447,29 @@ pub(super) fn parameters_to_scalar_values( let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int8(_) => ScalarValue::Int8( + data.map(|n| n.to_i8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Int16(_) => ScalarValue::Int16(data), ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8( + data.map(|n| n.to_u8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16( + data.map(|n| n.to_u16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32( + data.map(|n| n.to_u32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64( + data.map(|n| n.to_u64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Timestamp(unit) => { to_timestamp_scalar_value(data, unit, server_type)? } @@ -466,14 +488,32 @@ pub(super) fn parameters_to_scalar_values( let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int8(_) => ScalarValue::Int8( + data.map(|n| n.to_i8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int16(_) => ScalarValue::Int16( + data.map(|n| n.to_i16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Int32(_) => ScalarValue::Int32(data), ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8( + data.map(|n| n.to_u8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16( + data.map(|n| n.to_u16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32( + data.map(|n| n.to_u32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64( + data.map(|n| n.to_u64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Timestamp(unit) => { to_timestamp_scalar_value(data, unit, server_type)? } @@ -492,14 +532,35 @@ pub(super) fn parameters_to_scalar_values( let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int8(_) => ScalarValue::Int8( + data.map(|n| n.to_i8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int16(_) => ScalarValue::Int16( + data.map(|n| n.to_i16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int32(_) => ScalarValue::Int32( + data.map(|n| n.to_i32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Int64(_) => ScalarValue::Int64(data), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8( + data.map(|n| n.to_u8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16( + data.map(|n| n.to_u16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32( + data.map(|n| n.to_u32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64( + data.map(|n| n.to_u64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Timestamp(unit) => { to_timestamp_scalar_value(data, unit, server_type)? } @@ -536,14 +597,38 @@ pub(super) fn parameters_to_scalar_values( let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Int8(_) => ScalarValue::Int8( + data.map(|n| n.to_i8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int16(_) => ScalarValue::Int16( + data.map(|n| n.to_i16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int32(_) => ScalarValue::Int32( + data.map(|n| n.to_i32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int64(_) => ScalarValue::Int64( + data.map(|n| n.to_i64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8( + data.map(|n| n.to_u8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16( + data.map(|n| n.to_u16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32( + data.map(|n| n.to_u32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64( + data.map(|n| n.to_u64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Float32(_) => ScalarValue::Float32(data), ConcreteDataType::Float64(_) => { ScalarValue::Float64(data.map(|n| n as f64)) @@ -563,17 +648,42 @@ pub(super) fn parameters_to_scalar_values( let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type { match server_type { - ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), - ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), - ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), - ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), - ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), - ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), - ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), - ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), - ConcreteDataType::Float32(_) => { - ScalarValue::Float32(data.map(|n| n as f32)) - } + ConcreteDataType::Int8(_) => ScalarValue::Int8( + data.map(|n| n.to_i8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int16(_) => ScalarValue::Int16( + data.map(|n| n.to_i16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int32(_) => ScalarValue::Int32( + data.map(|n| n.to_i32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Int64(_) => ScalarValue::Int64( + data.map(|n| n.to_i64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8( + data.map(|n| n.to_u8().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16( + data.map(|n| n.to_u16().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32( + data.map(|n| n.to_u32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64( + data.map(|n| n.to_u64().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), + ConcreteDataType::Float32(_) => ScalarValue::Float32( + data.map(|n| n.to_f32().ok_or_else(|| numeric_out_of_range_error(n))) + .transpose()?, + ), ConcreteDataType::Float64(_) => ScalarValue::Float64(data), _ => { return Err(invalid_parameter_error( @@ -1139,6 +1249,9 @@ mod test { Float64Builder, Int64Builder, ListBuilder, StringBuilder, TimestampSecondBuilder, }; use arrow_schema::{Field, IntervalUnit}; + use bytes::Bytes; + use datafusion_expr::expr::Placeholder; + use datafusion_expr::{Expr, LogicalPlanBuilder}; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{ BinaryVector, BooleanVector, DateVector, Float32Vector, Float64Vector, Int8Vector, @@ -1148,10 +1261,15 @@ mod test { }; use futures::{StreamExt as FuturesStreamExt, stream}; use pgwire::api::Type; + use pgwire::api::portal::{Format, Portal}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo}; + use pgwire::api::stmt::StoredStatement; + use pgwire::messages::extendedquery::Bind; use session::context::QueryContextBuilder; use super::*; + use crate::SqlPlan; + use crate::postgres::handler::PgSqlPlan; #[test] fn test_schema_convert() { @@ -1573,4 +1691,343 @@ mod test { let scalar = to_decimal_scalar_value(None, &dt); assert_eq!(scalar, ScalarValue::Decimal128(None, 18, 4)); } + + fn s(v: &str) -> Option { + Some(v.to_string()) + } + + fn typed_param(id: &str, dt: DataType) -> Expr { + Expr::Placeholder(Placeholder::new_with_field( + id.to_string(), + Some(Arc::new(arrow_schema::Field::new(id, dt, true))), + )) + } + + fn build_plan_with_params(params: Vec<(&str, DataType)>) -> LogicalPlan { + let exprs: Vec = params + .into_iter() + .map(|(id, dt)| typed_param(id, dt)) + .collect(); + LogicalPlanBuilder::empty(true) + .project(exprs) + .unwrap() + .build() + .unwrap() + } + + fn make_portal( + client_param_types: Vec>, + param_data: Vec>, + ) -> Portal { + let bind = Bind::new( + None, + None, + vec![], + param_data + .into_iter() + .map(|opt| opt.map(Bytes::from)) + .collect(), + vec![], + ); + let statement = Arc::new(StoredStatement::new( + String::new(), + PgSqlPlan { + plan: SqlPlan { + query: String::new(), + statement: None, + plan: None, + schema: None, + }, + copy_to_stdout_format: None, + }, + client_param_types, + )); + Portal::try_new(&bind, statement).unwrap() + } + + #[test] + fn test_int2_coerce_in_range() { + let plan = build_plan_with_params(vec![ + ("$1", DataType::Int8), + ("$2", DataType::Int16), + ("$3", DataType::Int32), + ("$4", DataType::Int64), + ("$5", DataType::UInt8), + ("$6", DataType::UInt16), + ("$7", DataType::UInt32), + ("$8", DataType::UInt64), + ]); + let portal = make_portal( + vec![ + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + Some(Type::INT2), + ], + vec![ + s("100"), + s("100"), + s("100"), + s("100"), + s("100"), + s("100"), + s("100"), + s("100"), + ], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(Some(100))); + assert_eq!(values[1], ScalarValue::Int16(Some(100))); + assert_eq!(values[2], ScalarValue::Int32(Some(100))); + assert_eq!(values[3], ScalarValue::Int64(Some(100))); + assert_eq!(values[4], ScalarValue::UInt8(Some(100))); + assert_eq!(values[5], ScalarValue::UInt16(Some(100))); + assert_eq!(values[6], ScalarValue::UInt32(Some(100))); + assert_eq!(values[7], ScalarValue::UInt64(Some(100))); + } + + #[test] + fn test_int2_coerce_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int8)]); + let portal = make_portal(vec![Some(Type::INT2)], vec![s("200")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_int2_coerce_negative_to_unsigned_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::UInt64)]); + let portal = make_portal(vec![Some(Type::INT2)], vec![s("-1")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_int4_coerce_in_range() { + let plan = build_plan_with_params(vec![ + ("$1", DataType::Int8), + ("$2", DataType::Int16), + ("$3", DataType::Int32), + ("$4", DataType::Int64), + ("$5", DataType::UInt8), + ("$6", DataType::UInt16), + ("$7", DataType::UInt32), + ("$8", DataType::UInt64), + ]); + let portal = make_portal( + vec![ + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + Some(Type::INT4), + ], + vec![ + s("100"), + s("1000"), + s("100000"), + s("100000"), + s("200"), + s("1000"), + s("100000"), + s("100000"), + ], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(Some(100))); + assert_eq!(values[1], ScalarValue::Int16(Some(1000))); + assert_eq!(values[2], ScalarValue::Int32(Some(100000))); + assert_eq!(values[3], ScalarValue::Int64(Some(100000))); + assert_eq!(values[4], ScalarValue::UInt8(Some(200))); + assert_eq!(values[5], ScalarValue::UInt16(Some(1000))); + assert_eq!(values[6], ScalarValue::UInt32(Some(100000))); + assert_eq!(values[7], ScalarValue::UInt64(Some(100000))); + } + + #[test] + fn test_int4_coerce_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int8)]); + let portal = make_portal(vec![Some(Type::INT4)], vec![s("200")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_int4_coerce_i32_max_to_i16_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int16)]); + let portal = make_portal(vec![Some(Type::INT4)], vec![Some(i32::MAX.to_string())]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_int8_coerce_in_range() { + let plan = build_plan_with_params(vec![ + ("$1", DataType::Int8), + ("$2", DataType::Int16), + ("$3", DataType::Int32), + ("$4", DataType::Int64), + ("$5", DataType::UInt8), + ("$6", DataType::UInt16), + ("$7", DataType::UInt32), + ("$8", DataType::UInt64), + ]); + let portal = make_portal( + vec![ + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + Some(Type::INT8), + ], + vec![ + s("100"), + s("1000"), + s("100000"), + s("100000"), + s("200"), + s("1000"), + s("3000000000"), + s("3000000000"), + ], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(Some(100))); + assert_eq!(values[1], ScalarValue::Int16(Some(1000))); + assert_eq!(values[2], ScalarValue::Int32(Some(100000))); + assert_eq!(values[3], ScalarValue::Int64(Some(100000))); + assert_eq!(values[4], ScalarValue::UInt8(Some(200))); + assert_eq!(values[5], ScalarValue::UInt16(Some(1000))); + assert_eq!(values[6], ScalarValue::UInt32(Some(3000000000))); + assert_eq!(values[7], ScalarValue::UInt64(Some(3000000000))); + } + + #[test] + fn test_int8_coerce_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int32)]); + let portal = make_portal( + vec![Some(Type::INT8)], + vec![Some((i32::MAX as i64 + 1).to_string())], + ); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_int8_coerce_negative_to_unsigned_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::UInt64)]); + let portal = make_portal(vec![Some(Type::INT8)], vec![s("-1")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_float4_coerce_in_range() { + let plan = + build_plan_with_params(vec![("$1", DataType::Float32), ("$2", DataType::Float64)]); + let portal = make_portal( + vec![Some(Type::FLOAT4), Some(Type::FLOAT4)], + vec![s("1.5"), s("2.5")], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Float32(Some(1.5))); + assert_eq!(values[1], ScalarValue::Float64(Some(2.5))); + } + + #[test] + fn test_float4_coerce_to_int_in_range() { + let plan = build_plan_with_params(vec![ + ("$1", DataType::Int8), + ("$2", DataType::Int32), + ("$3", DataType::UInt64), + ]); + let portal = make_portal( + vec![Some(Type::FLOAT4), Some(Type::FLOAT4), Some(Type::FLOAT4)], + vec![s("100"), s("1000"), s("200")], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(Some(100))); + assert_eq!(values[1], ScalarValue::Int32(Some(1000))); + assert_eq!(values[2], ScalarValue::UInt64(Some(200))); + } + + #[test] + fn test_float4_coerce_to_int_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int8)]); + let portal = make_portal(vec![Some(Type::FLOAT4)], vec![s("200")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_float8_coerce_in_range() { + let plan = + build_plan_with_params(vec![("$1", DataType::Float32), ("$2", DataType::Float64)]); + let portal = make_portal( + vec![Some(Type::FLOAT8), Some(Type::FLOAT8)], + vec![s("1.5"), s("2.5")], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Float32(Some(1.5))); + assert_eq!(values[1], ScalarValue::Float64(Some(2.5))); + } + + #[test] + fn test_float8_coerce_to_int_in_range() { + let plan = build_plan_with_params(vec![ + ("$1", DataType::Int8), + ("$2", DataType::Int64), + ("$3", DataType::UInt64), + ]); + let portal = make_portal( + vec![Some(Type::FLOAT8), Some(Type::FLOAT8), Some(Type::FLOAT8)], + vec![s("100"), s("1000000"), s("200")], + ); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(Some(100))); + assert_eq!(values[1], ScalarValue::Int64(Some(1000000))); + assert_eq!(values[2], ScalarValue::UInt64(Some(200))); + } + + #[test] + fn test_float8_coerce_to_int_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::Int8)]); + let portal = make_portal(vec![Some(Type::FLOAT8)], vec![s("200")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_float8_coerce_negative_to_unsigned_out_of_range() { + let plan = build_plan_with_params(vec![("$1", DataType::UInt64)]); + let portal = make_portal(vec![Some(Type::FLOAT8)], vec![s("-1")]); + let result = parameters_to_scalar_values(&plan, &portal); + assert!(result.is_err()); + } + + #[test] + fn test_null_parameter() { + let plan = build_plan_with_params(vec![("$1", DataType::Int8)]); + let portal = make_portal(vec![Some(Type::INT2)], vec![None]); + + let values = parameters_to_scalar_values(&plan, &portal).unwrap(); + assert_eq!(values[0], ScalarValue::Int8(None)); + } }