feat: improve support for postgres extended protocol (#4721)

* feat: improve support for postgres extended protocol

* fix: lint fix

* fix: test code

* fix: adopt upstream

* refactor: remove dup code

* refactor: avoid copy on error message
This commit is contained in:
Ning Sun
2024-09-19 13:30:56 +08:00
committed by GitHub
parent 52d627e37d
commit 8786624515
4 changed files with 313 additions and 247 deletions

View File

@@ -398,11 +398,19 @@ impl QueryEngine for DatafusionQueryEngine {
query_ctx: QueryContextRef,
) -> Result<DescribeResult> {
let ctx = self.engine_context(query_ctx);
let optimised_plan = self.optimize(&ctx, &plan)?;
Ok(DescribeResult {
schema: optimised_plan.schema()?,
logical_plan: optimised_plan,
})
if let Ok(optimised_plan) = self.optimize(&ctx, &plan) {
Ok(DescribeResult {
schema: optimised_plan.schema()?,
logical_plan: optimised_plan,
})
} else {
// Table's like those in information_schema cannot be optimized when
// it contains parameters. So we fallback to original plans.
Ok(DescribeResult {
schema: plan.schema()?,
logical_plan: plan,
})
}
}
async fn execute(&self, plan: LogicalPlan, query_ctx: QueryContextRef) -> Result<Output> {

View File

@@ -54,17 +54,19 @@ static SET_TRANSACTION_PATTERN: Lazy<Regex> =
static TRANSACTION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(BEGIN|ROLLBACK|COMMIT);?").unwrap());
/// Test if given query statement matches the patterns
pub(crate) fn matches(query: &str) -> bool {
TRANSACTION_PATTERN.captures(query).is_some()
|| SHOW_PATTERN.captures(query).is_some()
|| SET_TRANSACTION_PATTERN.is_match(query)
}
/// Process unsupported SQL and return fixed result as a compatibility solution
pub(crate) fn process<'a>(
query: &str,
_query_ctx: QueryContextRef,
) -> Option<PgWireResult<Vec<Response<'a>>>> {
pub(crate) fn process<'a>(query: &str, _query_ctx: QueryContextRef) -> Option<Vec<Response<'a>>> {
// Transaction directives:
if let Some(tx) = TRANSACTION_PATTERN.captures(query) {
let tx_tag = &tx[1];
Some(Ok(vec![Response::Execution(Tag::new(
&tx_tag.to_uppercase(),
))]))
Some(vec![Response::Execution(Tag::new(&tx_tag.to_uppercase()))])
} else if let Some(show_var) = SHOW_PATTERN.captures(query) {
let show_var = show_var[1].to_lowercase();
if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) {
@@ -81,12 +83,12 @@ pub(crate) fn process<'a>(
vec![vec![value.to_string()]],
));
Some(Ok(vec![Response::Query(QueryResponse::new(schema, data))]))
Some(vec![Response::Query(QueryResponse::new(schema, data))])
} else {
None
}
} else if SET_TRANSACTION_PATTERN.is_match(query) {
Some(Ok(vec![Response::Execution(Tag::new("SET"))]))
Some(vec![Response::Execution(Tag::new("SET"))])
} else {
None
}
@@ -101,7 +103,6 @@ mod test {
fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) {
if let Response::Execution(tag) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
assert_eq!(Tag::new(t), tag);
@@ -113,7 +114,6 @@ mod test {
fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> {
if let Response::Query(resp) = process(q, query_context.clone())
.unwrap_or_else(|| panic!("fail to match {}", q))
.expect("unexpected error")
.remove(0)
{
resp

View File

@@ -59,8 +59,13 @@ impl SimpleQueryHandler for PostgresServerHandler {
.with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()])
.start_timer();
if query.is_empty() {
// early return if query is empty
return Ok(vec![Response::EmptyQuery]);
}
if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
resps
Ok(resps)
} else {
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
@@ -184,6 +189,16 @@ impl QueryParser for DefaultQueryParser {
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
crate::metrics::METRIC_POSTGRES_PREPARED_COUNT.inc();
let query_ctx = self.session.new_query_context();
// do not parse if query is empty or matches rules
if sql.is_empty() || fixtures::matches(sql) {
return Ok(SqlPlan {
query: sql.to_owned(),
plan: None,
schema: None,
});
}
let mut stmts =
ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -193,6 +208,7 @@ impl QueryParser for DefaultQueryParser {
))))
} else {
let stmt = stmts.remove(0);
let describe_result = self
.query_handler
.do_describe(stmt, query_ctx)
@@ -244,6 +260,16 @@ impl ExtendedQueryHandler for PostgresServerHandler {
let sql_plan = &portal.statement.statement;
if sql_plan.query.is_empty() {
// early return if query is empty
return Ok(Response::EmptyQuery);
}
if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
// if the statement matches our predefined rules, return it early
return Ok(resps.remove(0));
}
let output = if let Some(plan) = &sql_plan.plan {
let plan = plan
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
@@ -297,6 +323,17 @@ impl ExtendedQueryHandler for PostgresServerHandler {
.map(|fields| DescribeStatementResponse::new(param_types, fields))
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
{
if let Response::Query(query_response) = resp.remove(0) {
return Ok(DescribeStatementResponse::new(
param_types,
(*query_response.row_schema()).clone(),
));
}
}
Ok(DescribeStatementResponse::new(param_types, vec![]))
}
}
@@ -317,6 +354,16 @@ impl ExtendedQueryHandler for PostgresServerHandler {
.map(DescribePortalResponse::new)
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
if let Some(mut resp) =
fixtures::process(&sql_plan.query, self.session.new_query_context())
{
if let Response::Query(query_response) = resp.remove(0) {
return Ok(DescribePortalResponse::new(
(*query_response.row_schema()).clone(),
));
}
}
Ok(DescribePortalResponse::new(vec![]))
}
}

View File

@@ -239,14 +239,14 @@ pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWir
.unwrap_or_else(|| "".to_owned())),
_ => Err(invalid_parameter_error(
"unsupported_parameter_type",
Some(&param_type.to_string()),
Some(param_type.to_string()),
)),
}
}
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError {
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<String>) -> PgWireError {
let mut error_info = PgErrorCode::Ec22023.to_err_info(msg.to_string());
error_info.detail = detail.map(|s| s.to_owned());
error_info.detail = detail;
PgWireError::UserError(Box::new(error_info))
}
@@ -279,303 +279,314 @@ pub(super) fn parameters_to_scalar_values(
.get_param_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
// ensure parameter count consistent for: client parameter types, server
// parameter types and parameter count
if param_types.len() != param_count {
return Err(invalid_parameter_error(
"invalid_parameter_count",
Some(&format!(
"Expected: {}, found: {}",
param_types.len(),
param_count
)),
));
}
for idx in 0..param_count {
let server_type =
if let Some(Some(server_infer_type)) = param_types.get(&format!("${}", idx + 1)) {
server_infer_type
} else {
// at the moment we require type information inferenced by
// server so here we return error if the type is unknown from
// server-side.
//
// It might be possible to parse the parameter just using client
// specified type, we will implement that if there is a case.
return Err(invalid_parameter_error("unknown_parameter_type", None));
};
let server_type = param_types
.get(&format!("${}", idx + 1))
.and_then(|t| t.as_ref());
let client_type = if let Some(client_given_type) = client_param_types.get(idx) {
client_given_type.clone()
} else if let Some(server_provided_type) = &server_type {
type_gt_to_pg(server_provided_type).map_err(|e| PgWireError::ApiError(Box::new(e)))?
} else {
type_gt_to_pg(server_type).map_err(|e| PgWireError::ApiError(Box::new(e)))?
return Err(invalid_parameter_error(
"unknown_parameter_type",
Some(format!(
"Cannot get parameter type information for parameter {}",
idx
)),
));
};
let value = match &client_type {
&Type::VARCHAR | &Type::TEXT => {
let data = portal.parameter::<String>(idx, &client_type)?;
match server_type {
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Utf8(data)
}
}
&Type::BOOL => {
let data = portal.parameter::<bool>(idx, &client_type)?;
match server_type {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Boolean(data)
}
}
&Type::INT2 => {
let data = portal.parameter::<i16>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => {
ScalarValue::Date64(data.map(|d| d as i64))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Int16(data)
}
}
&Type::INT4 => {
let data = portal.parameter::<i32>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => {
ScalarValue::Date64(data.map(|d| d as i64))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Int32(data)
}
}
&Type::INT8 => {
let data = portal.parameter::<i64>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Int64(data)
}
}
&Type::FLOAT4 => {
let data = portal.parameter::<f32>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
ConcreteDataType::Float64(_) => ScalarValue::Float64(data.map(|n| n as f64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
ConcreteDataType::Float64(_) => {
ScalarValue::Float64(data.map(|n| n as f64))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Float32(data)
}
}
&Type::FLOAT8 => {
let data = portal.parameter::<f64>(idx, &client_type)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => ScalarValue::Float32(data.map(|n| n as f32)),
ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => {
ScalarValue::Float32(data.map(|n| n as f32))
}
ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::Float64(data)
}
}
&Type::TIMESTAMP => {
let data = portal.parameter::<NaiveDateTime>(idx, &client_type)?;
match server_type {
ConcreteDataType::Timestamp(unit) => match *unit {
TimestampType::Second(_) => ScalarValue::TimestampSecond(
data.map(|ts| ts.and_utc().timestamp()),
None,
),
TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
data.map(|ts| ts.and_utc().timestamp_millis()),
None,
),
TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
data.map(|ts| ts.and_utc().timestamp_micros()),
None,
),
TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
data.map(|ts| ts.and_utc().timestamp_micros()),
None,
),
},
ConcreteDataType::DateTime(_) => {
ScalarValue::Date64(data.map(|d| d.and_utc().timestamp_millis()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Timestamp(unit) => match *unit {
TimestampType::Second(_) => ScalarValue::TimestampSecond(
data.map(|ts| ts.and_utc().timestamp()),
None,
),
TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
data.map(|ts| ts.and_utc().timestamp_millis()),
None,
),
TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
data.map(|ts| ts.and_utc().timestamp_micros()),
None,
),
TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
data.map(|ts| ts.and_utc().timestamp_micros()),
None,
),
},
ConcreteDataType::DateTime(_) => {
ScalarValue::Date64(data.map(|d| d.and_utc().timestamp_millis()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
))
}
}
} else {
ScalarValue::TimestampMillisecond(
data.map(|ts| ts.and_utc().timestamp_millis()),
None,
)
}
}
&Type::DATE => {
let data = portal.parameter::<NaiveDate>(idx, &client_type)?;
match server_type {
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
(d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32
})),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
(d - NaiveDate::from(NaiveDateTime::UNIX_EPOCH)).num_days() as i32
})),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
));
}
}
} else {
ScalarValue::Date32(data.map(|d| {
(d - NaiveDate::from(NaiveDateTime::UNIX_EPOCH)).num_days() as i32
}))
}
}
&Type::INTERVAL => {
let data = portal.parameter::<PgInterval>(idx, &client_type)?;
match server_type {
ConcreteDataType::Interval(_) => {
ScalarValue::IntervalMonthDayNano(data.map(|i| Interval::from(i).to_i128()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Interval(_) => ScalarValue::IntervalMonthDayNano(
data.map(|i| Interval::from(i).to_i128()),
),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
));
}
}
} else {
ScalarValue::IntervalMonthDayNano(data.map(|i| Interval::from(i).to_i128()))
}
}
&Type::BYTEA => {
let data = portal.parameter::<Vec<u8>>(idx, &client_type)?;
match server_type {
ConcreteDataType::String(_) => {
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
}
ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::String(_) => {
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
}
ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
));
}
}
} else {
ScalarValue::Binary(data)
}
}
&Type::JSONB => {
let data = portal.parameter::<serde_json::Value>(idx, &client_type)?;
match server_type {
ConcreteDataType::Binary(_) => {
ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
if let Some(server_type) = &server_type {
match server_type {
ConcreteDataType::Binary(_) => {
ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(format!("Expected: {}, found: {}", server_type, client_type)),
));
}
}
} else {
ScalarValue::Binary(data.map(|d| jsonb::Value::from(d).to_vec()))
}
}
_ => Err(invalid_parameter_error(
"unsupported_parameter_value",
Some(&format!("Found type: {}", client_type)),
Some(format!("Found type: {}", client_type)),
))?,
};