feat: Support printing postgresql's bytea data type in its "hex" and "escape" format (#3567)

* feat: support set variable statement of session

* feat: support printing postgresql's bytea data type in its "hex" and "escape" format in ugly way

* refactor: add 'SessionConfigValue' type and unify the name

* doc: add license header

* refactor: confine coupling with 'sql::ast::Value' in SessionConfigValue

* refactor: move all bytea wrapper into bytea.rs

* fix: remove unused import in context.rs and postgres.rs

* refactor: rename 'set_configuration_parameter' to 'set_session_config'

rename 'set_configuration_parameter' in statement_.rs to 'set_session_config'

* refactor: use mod to organize options via macro

* refactor: re-model the session config value with static type

* test: add integration test

* refactor: move the encode bytea by format type logic into encoder

refactor: use Arc<DashMap> instead of DashMap in QueryContext

refactor: use Arc<DashMap> instead of DashMap in QueryContext

    Avoid expensive clone

refactor: use unreachable!() instead of unimplemented!()

refactor: move the encode bytea by format type logic into encoder

test: add binary format integration test case

* test: add ut for byte related type

* doc: remove TODO of bytea_output

* refactor: simplify the implementation with simple struct instead of complex typing

* fix: typo of 'Available'

* fix compile

Signed-off-by: tison <wander4096@gmail.com>

---------

Signed-off-by: tison <wander4096@gmail.com>
Co-authored-by: tison <wander4096@gmail.com>
This commit is contained in:
JohnsonLee
2024-03-27 09:54:41 +08:00
committed by GitHub
parent d83279567b
commit 83643eb195
15 changed files with 423 additions and 22 deletions

5
Cargo.lock generated
View File

@@ -9030,6 +9030,7 @@ dependencies = [
"common-time",
"common-version",
"criterion",
"dashmap",
"datafusion",
"datafusion-common",
"datatypes",
@@ -9103,8 +9104,12 @@ dependencies = [
"arc-swap",
"auth",
"common-catalog",
"common-error",
"common-macro",
"common-telemetry",
"common-time",
"derive_builder 0.12.0",
"snafu",
"sql",
]

View File

@@ -20,9 +20,9 @@ use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use datatypes::value::Value;
use servers::define_into_tonic_status;
use snafu::{Location, Snafu};
use sql::ast::Value;
#[derive(Snafu)]
#[snafu(visibility(pub))]
@@ -528,6 +528,12 @@ pub enum Error {
#[snafu(display("Invalid partition rule: {}", reason))]
InvalidPartitionRule { reason: String, location: Location },
#[snafu(display("Invalid configuration value."))]
InvalidConfigValue {
source: session::session_config::Error,
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -536,6 +542,7 @@ impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidSql { .. }
| Error::InvalidConfigValue { .. }
| Error::InvalidInsertRequest { .. }
| Error::InvalidDeleteRequest { .. }
| Error::IllegalPrimaryKeysDef { .. }

View File

@@ -39,6 +39,7 @@ use query::parser::QueryStatement;
use query::plan::LogicalPlan;
use query::QueryEngineRef;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;
use session::table_name::table_idents_to_full_name;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
@@ -52,8 +53,8 @@ use table::table_reference::TableReference;
use table::TableRef;
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, TableNotFoundSnafu,
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidConfigValueSnafu,
InvalidSqlSnafu, NotSupportedSnafu, PlanStatementSnafu, Result, TableNotFoundSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
@@ -219,8 +220,7 @@ impl StatementExecutor {
// so we just ignore it here instead of returning an error to break the connection.
// Since the "bytea_output" only determines the output format of binary values,
// it won't cause much trouble if we do so.
// TODO(#3438): Remove this temporary workaround after the feature is implemented.
"BYTEA_OUTPUT" => (),
"BYTEA_OUTPUT" => set_bytea_output(set_var.value, query_ctx)?,
// Same as "bytea_output", we just ignore it here.
// Not harmful since it only relates to how date is viewed in client app's output.
@@ -339,6 +339,25 @@ fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
}
}
fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let Some((var_value, [])) = exprs.split_first() else {
return (NotSupportedSnafu {
feat: "Set variable value must have one and only one value for bytea_output",
})
.fail();
};
let Expr::Value(value) = var_value else {
return (NotSupportedSnafu {
feat: "Set variable value must be a value",
})
.fail();
};
ctx.configuration_parameter().set_postgres_bytea_output(
PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
);
Ok(())
}
fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {
let direction = match stmt {
CopyTable::To(_) => CopyDirection::Export,

View File

@@ -41,6 +41,7 @@ common-recordbatch.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
dashmap.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datatypes.workspace = true

View File

@@ -310,7 +310,7 @@ mod test {
#[test]
fn test_check() {
let session = Arc::new(Session::new(None, Channel::Mysql));
let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
let query = "select 1";
let result = check(query, QueryContext::arc(), session.clone());
assert!(result.is_none());
@@ -320,7 +320,7 @@ mod test {
assert!(output.is_none());
fn test(query: &str, expected: &str) {
let session = Arc::new(Session::new(None, Channel::Mysql));
let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
let output = check(query, QueryContext::arc(), session.clone());
match output.unwrap().data {
OutputData::RecordBatches(r) => {

View File

@@ -85,7 +85,11 @@ impl MysqlInstanceShim {
MysqlInstanceShim {
query_handler,
salt: scramble,
session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
session: Arc::new(Session::new(
Some(client_addr),
Channel::Mysql,
Default::default(),
)),
user_provider,
prepared_stmts: Default::default(),
prepared_stmts_counter: AtomicU32::new(1),

View File

@@ -88,7 +88,7 @@ pub(crate) struct MakePostgresServerHandler {
impl MakePostgresServerHandler {
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
let session = Arc::new(Session::new(addr, Channel::Postgres));
let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default()));
PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),

View File

@@ -31,6 +31,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use session::Session;
use sql::dialect::PostgreSqlDialect;
use sql::parser::{ParseOptions, ParserContext};
@@ -63,7 +64,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
let mut results = Vec::with_capacity(outputs.len());
for output in outputs {
let resp = output_to_query_response(output, &Format::UnifiedText)?;
let resp = output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
results.push(resp);
}
@@ -72,6 +73,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
}
fn output_to_query_response<'a>(
query_ctx: QueryContextRef,
output: Result<Output>,
field_format: &Format,
) -> PgWireResult<Response<'a>> {
@@ -82,11 +84,16 @@ fn output_to_query_response<'a>(
}
OutputData::Stream(record_stream) => {
let schema = record_stream.schema();
recordbatches_to_query_response(record_stream, schema, field_format)
recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
}
OutputData::RecordBatches(recordbatches) => {
let schema = recordbatches.schema();
recordbatches_to_query_response(recordbatches.as_stream(), schema, field_format)
recordbatches_to_query_response(
query_ctx,
recordbatches.as_stream(),
schema,
field_format,
)
}
},
Err(e) => Ok(Response::Error(Box::new(ErrorInfo::new(
@@ -98,6 +105,7 @@ fn output_to_query_response<'a>(
}
fn recordbatches_to_query_response<'a, S>(
query_ctx: QueryContextRef,
recordbatches_stream: S,
schema: SchemaRef,
field_format: &Format,
@@ -125,7 +133,7 @@ where
row.and_then(|row| {
let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
for value in row.iter() {
encode_value(value, &mut encoder)?;
encode_value(&query_ctx, value, &mut encoder)?;
}
encoder.finish()
})
@@ -224,7 +232,9 @@ impl ExtendedQueryHandler for PostgresServerHandler {
let plan = plan
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
self.query_handler.do_exec_plan(plan, query_ctx).await
self.query_handler
.do_exec_plan(plan, query_ctx.clone())
.await
} else {
// manually replace variables in prepared statement when no
// logical_plan is generated. This happens when logical plan is not
@@ -234,10 +244,13 @@ impl ExtendedQueryHandler for PostgresServerHandler {
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
}
self.query_handler.do_query(&sql, query_ctx).await.remove(0)
self.query_handler
.do_query(&sql, query_ctx.clone())
.await
.remove(0)
};
output_to_query_response(output, &portal.result_column_format)
output_to_query_response(query_ctx, output, &portal.result_column_format)
}
async fn do_describe_statement<C>(

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod bytea;
mod interval;
use std::collections::HashMap;
@@ -28,7 +29,10 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::plan::LogicalPlan;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;
use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::interval::PgInterval;
use crate::error::{self, Error, Result};
use crate::SqlPlan;
@@ -50,7 +54,11 @@ pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Ve
.collect::<Result<Vec<FieldInfo>>>()
}
pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
pub(super) fn encode_value(
query_ctx: &QueryContextRef,
value: &Value,
builder: &mut DataRowEncoder,
) -> PgWireResult<()> {
match value {
Value::Null => builder.encode_field(&None::<&i8>),
Value::Boolean(v) => builder.encode_field(v),
@@ -65,7 +73,13 @@ pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWir
Value::Float32(v) => builder.encode_field(&v.0),
Value::Float64(v) => builder.encode_field(&v.0),
Value::String(v) => builder.encode_field(&v.as_utf8()),
Value::Binary(v) => builder.encode_field(&v.deref()),
Value::Binary(v) => {
let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
match *bytea_output {
PGByteaOutputValue::ESCAPE => builder.encode_field(&EscapeOutputBytea(v.deref())),
PGByteaOutputValue::HEX => builder.encode_field(&HexOutputBytea(v.deref())),
}
}
Value::Date(v) => {
if let Some(date) = v.to_chrono_date() {
builder.encode_field(&date)
@@ -563,6 +577,7 @@ mod test {
use datatypes::value::ListValue;
use pgwire::api::results::{FieldFormat, FieldInfo};
use pgwire::api::Type;
use session::context::QueryContextBuilder;
use super::*;
@@ -784,12 +799,16 @@ mod test {
Value::Timestamp(1000001i64.into()),
Value::Interval(1000001i128.into()),
];
let query_context = QueryContextBuilder::default()
.configuration_parameter(Default::default())
.build();
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
encode_value(i, &mut builder).unwrap();
encode_value(&query_context, i, &mut builder).unwrap();
}
let err = encode_value(
&query_context,
&Value::List(ListValue::new(
Some(Box::default()),
ConcreteDataType::int16_datatype(),

View File

@@ -0,0 +1,152 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bytes::BufMut;
use pgwire::types::ToSqlText;
use postgres_types::{IsNull, ToSql, Type};
#[derive(Debug)]
pub struct HexOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for HexOutputBytea<'_> {
fn to_sql_text(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
out.put_slice(b"\\x");
let _ = self.0.to_sql_text(ty, out);
Ok(IsNull::No)
}
}
impl ToSql for HexOutputBytea<'_> {
fn to_sql(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.to_sql(ty, out)
}
fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
<&[u8] as ToSql>::accepts(ty)
}
fn to_sql_checked(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql_checked(ty, out)
}
}
#[derive(Debug)]
pub struct EscapeOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for EscapeOutputBytea<'_> {
fn to_sql_text(
&self,
_ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.iter().for_each(|b| match b {
0..=31 | 127..=255 => {
out.put_slice(b"\\");
out.put_slice(format!("{:03o}", b).as_bytes());
}
92 => out.put_slice(b"\\\\"),
32..=126 => out.put_u8(*b),
});
Ok(IsNull::No)
}
}
impl ToSql for EscapeOutputBytea<'_> {
fn to_sql(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
self.0.to_sql(ty, out)
}
fn accepts(ty: &Type) -> bool
where
Self: Sized,
{
<&[u8] as ToSql>::accepts(ty)
}
fn to_sql_checked(
&self,
ty: &Type,
out: &mut bytes::BytesMut,
) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
self.0.to_sql_checked(ty, out)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_escape_output_bytea() {
let input: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
let input = EscapeOutputBytea(input);
let expected = b"abcklm*\\251T";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
let expected = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
}
#[test]
fn test_hex_output_bytea() {
let input = b"hello, world!";
let input = HexOutputBytea(input);
let expected = b"\\x68656c6c6f2c20776f726c6421";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
let expected = b"hello, world!";
let mut out = bytes::BytesMut::new();
let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
assert!(matches!(is_null, IsNull::No));
assert_eq!(&out[..], expected);
}
}

View File

@@ -15,6 +15,10 @@ api.workspace = true
arc-swap = "1.5"
auth.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
derive_builder.workspace = true
snafu.workspace = true
sql.workspace = true

View File

@@ -27,6 +27,7 @@ use common_time::Timezone;
use derive_builder::Builder;
use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
use crate::session_config::PGByteaOutputValue;
use crate::SessionRef;
pub type QueryContextRef = Arc<QueryContext>;
@@ -44,6 +45,9 @@ pub struct QueryContext {
sql_dialect: Arc<dyn Dialect + Send + Sync>,
#[builder(default)]
extension: HashMap<String, String>,
// The configuration parameter are used to store the parameters that are set by the user
#[builder(default)]
configuration_parameter: Arc<ConfigurationVariables>,
}
impl QueryContextBuilder {
@@ -73,6 +77,7 @@ impl Clone for QueryContext {
timezone: self.timezone.load().clone().into(),
sql_dialect: self.sql_dialect.clone(),
extension: self.extension.clone(),
configuration_parameter: self.configuration_parameter.clone(),
}
}
}
@@ -88,6 +93,7 @@ impl From<&RegionRequestHeader> for QueryContext {
timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())),
sql_dialect: Arc::new(GreptimeDbDialect {}),
extension: Default::default(),
configuration_parameter: Default::default(),
}
}
}
@@ -183,6 +189,10 @@ impl QueryContext {
'`'
}
}
pub fn configuration_parameter(&self) -> &ConfigurationVariables {
&self.configuration_parameter
}
}
impl QueryContextBuilder {
@@ -204,6 +214,7 @@ impl QueryContextBuilder {
.sql_dialect
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
extension: self.extension.unwrap_or_default(),
configuration_parameter: self.configuration_parameter.unwrap_or_default(),
})
}
@@ -268,6 +279,33 @@ impl Display for Channel {
}
}
#[derive(Default, Debug)]
pub struct ConfigurationVariables {
postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
}
impl Clone for ConfigurationVariables {
fn clone(&self) -> Self {
Self {
postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
}
}
}
impl ConfigurationVariables {
pub fn new() -> Self {
Self::default()
}
pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
let _ = self.postgres_bytea_output.swap(Arc::new(value));
}
pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
self.postgres_bytea_output.load().clone()
}
}
#[cfg(test)]
mod test {
use common_catalog::consts::DEFAULT_CATALOG_NAME;
@@ -278,7 +316,11 @@ mod test {
#[test]
fn test_session() {
let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);
let session = Session::new(
Some("127.0.0.1:9000".parse().unwrap()),
Channel::Mysql,
Default::default(),
);
// test user_info
assert_eq!(session.user_info().username(), "greptime");

View File

@@ -13,6 +13,7 @@
// limitations under the License.
pub mod context;
pub mod session_config;
pub mod table_name;
use std::net::SocketAddr;
@@ -24,7 +25,7 @@ use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_time::timezone::get_timezone;
use common_time::Timezone;
use context::QueryContextBuilder;
use context::{ConfigurationVariables, QueryContextBuilder};
use crate::context::{Channel, ConnInfo, QueryContextRef};
@@ -36,18 +37,24 @@ pub struct Session {
user_info: ArcSwap<UserInfoRef>,
conn_info: ConnInfo,
timezone: ArcSwap<Timezone>,
configuration_variables: Arc<ConfigurationVariables>,
}
pub type SessionRef = Arc<Session>;
impl Session {
pub fn new(addr: Option<SocketAddr>, channel: Channel) -> Self {
pub fn new(
addr: Option<SocketAddr>,
channel: Channel,
configuration_variables: ConfigurationVariables,
) -> Self {
Session {
catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())),
schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))),
conn_info: ConnInfo::new(addr, channel),
timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())),
configuration_variables: Arc::new(configuration_variables),
}
}
@@ -60,6 +67,7 @@ impl Session {
.current_catalog(self.catalog.load().to_string())
.current_schema(self.schema.load().to_string())
.sql_dialect(self.conn_info.channel.dialect())
.configuration_parameter(self.configuration_variables.clone())
.timezone(self.timezone())
.build()
}

View File

@@ -0,0 +1,64 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_macro::stack_trace_debug;
use snafu::{Location, Snafu};
use sql::ast::Value;
#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
#[snafu(display("Invalid value for parameter \"{}\": {}\nHint: {}", name, value, hint,))]
InvalidConfigValue {
name: String,
value: String,
hint: String,
location: Location,
},
}
#[derive(Clone, Copy, Debug, Default)]
pub enum PGByteaOutputValue {
#[default]
HEX,
ESCAPE,
}
impl TryFrom<Value> for PGByteaOutputValue {
type Error = Error;
fn try_from(value: Value) -> Result<Self, Self::Error> {
match &value {
Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => {
match s.to_uppercase().as_str() {
"ESCAPE" => Ok(PGByteaOutputValue::ESCAPE),
"HEX" => Ok(PGByteaOutputValue::HEX),
_ => InvalidConfigValueSnafu {
name: "BYTEA_OUTPUT",
value: value.to_string(),
hint: "Available values: escape, hex",
}
.fail(),
}
}
_ => InvalidConfigValueSnafu {
name: "BYTEA_OUTPUT",
value: value.to_string(),
hint: "Available values: escape, hex",
}
.fail(),
}
}
}

View File

@@ -60,6 +60,7 @@ macro_rules! sql_tests {
test_postgres_auth,
test_postgres_crud,
test_postgres_timezone,
test_postgres_bytea,
test_postgres_parameter_inference,
test_mysql_prepare_stmt_insert_timestamp,
);
@@ -415,7 +416,69 @@ pub async fn test_postgres_crud(store_type: StorageType) {
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}
pub async fn test_postgres_bytea(store_type: StorageType) {
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_bytea_output").await;
let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
.await
.unwrap();
tokio::spawn(async move {
connection.await.unwrap();
});
let _ = client
.simple_query("CREATE TABLE test(b BLOB, ts TIMESTAMP TIME INDEX)")
.await
.unwrap();
let _ = client
.simple_query("INSERT INTO test VALUES(X'6162636b6c6d2aa954', 0)")
.await
.unwrap();
let get_row = |mess: Vec<SimpleQueryMessage>| -> String {
match &mess[0] {
SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(),
_ => unreachable!(),
}
};
let r = client.simple_query("SELECT b FROM test").await.unwrap();
let b = get_row(r);
assert_eq!(b, "\\x6162636b6c6d2aa954");
let _ = client.simple_query("SET bytea_output='hex'").await.unwrap();
let r = client.simple_query("SELECT b FROM test").await.unwrap();
let b = get_row(r);
assert_eq!(b, "\\x6162636b6c6d2aa954");
let _ = client
.simple_query("SET bytea_output='escape'")
.await
.unwrap();
let r = client.simple_query("SELECT b FROM test").await.unwrap();
let b = get_row(r);
assert_eq!(b, "abcklm*\\251T");
let _e = client
.simple_query("SET bytea_output='invalid'")
.await
.unwrap_err();
// binary format shall not be affected by bytea_output
let pool = PgPoolOptions::new()
.max_connections(2)
.connect(&format!("postgres://{addr}/public"))
.await
.unwrap();
let row = sqlx::query("select b from test")
.fetch_one(&pool)
.await
.unwrap();
let val: Vec<u8> = row.get("b");
assert_eq!(val, [97, 98, 99, 107, 108, 109, 42, 169, 84]);
let _ = fe_pg_server.shutdown().await;
guard.remove_all().await;
}
pub async fn test_postgres_timezone(store_type: StorageType) {
let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;