diff --git a/Cargo.lock b/Cargo.lock index cb732d1ab0..2cd6d46c46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9030,6 +9030,7 @@ dependencies = [ "common-time", "common-version", "criterion", + "dashmap", "datafusion", "datafusion-common", "datatypes", @@ -9103,8 +9104,12 @@ dependencies = [ "arc-swap", "auth", "common-catalog", + "common-error", + "common-macro", + "common-telemetry", "common-time", "derive_builder 0.12.0", + "snafu", "sql", ] diff --git a/src/operator/src/error.rs b/src/operator/src/error.rs index 731fbe288b..43bd75865b 100644 --- a/src/operator/src/error.rs +++ b/src/operator/src/error.rs @@ -20,9 +20,9 @@ use common_error::status_code::StatusCode; use common_macro::stack_trace_debug; use datafusion::parquet; use datatypes::arrow::error::ArrowError; -use datatypes::value::Value; use servers::define_into_tonic_status; use snafu::{Location, Snafu}; +use sql::ast::Value; #[derive(Snafu)] #[snafu(visibility(pub))] @@ -528,6 +528,12 @@ pub enum Error { #[snafu(display("Invalid partition rule: {}", reason))] InvalidPartitionRule { reason: String, location: Location }, + + #[snafu(display("Invalid configuration value."))] + InvalidConfigValue { + source: session::session_config::Error, + location: Location, + }, } pub type Result = std::result::Result; @@ -536,6 +542,7 @@ impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { Error::InvalidSql { .. } + | Error::InvalidConfigValue { .. } | Error::InvalidInsertRequest { .. } | Error::InvalidDeleteRequest { .. } | Error::IllegalPrimaryKeysDef { .. } diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 1a30c596ce..2fb9267f79 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -39,6 +39,7 @@ use query::parser::QueryStatement; use query::plan::LogicalPlan; use query::QueryEngineRef; use session::context::QueryContextRef; +use session::session_config::PGByteaOutputValue; use session::table_name::table_idents_to_full_name; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument}; @@ -52,8 +53,8 @@ use table::table_reference::TableReference; use table::TableRef; use crate::error::{ - self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu, - PlanStatementSnafu, Result, TableNotFoundSnafu, + self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidConfigValueSnafu, + InvalidSqlSnafu, NotSupportedSnafu, PlanStatementSnafu, Result, TableNotFoundSnafu, }; use crate::insert::InserterRef; use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY}; @@ -219,8 +220,7 @@ impl StatementExecutor { // so we just ignore it here instead of returning an error to break the connection. // Since the "bytea_output" only determines the output format of binary values, // it won't cause much trouble if we do so. - // TODO(#3438): Remove this temporary workaround after the feature is implemented. - "BYTEA_OUTPUT" => (), + "BYTEA_OUTPUT" => set_bytea_output(set_var.value, query_ctx)?, // Same as "bytea_output", we just ignore it here. // Not harmful since it only relates to how date is viewed in client app's output. @@ -339,6 +339,25 @@ fn set_timezone(exprs: Vec, ctx: QueryContextRef) -> Result<()> { } } +fn set_bytea_output(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let Some((var_value, [])) = exprs.split_first() else { + return (NotSupportedSnafu { + feat: "Set variable value must have one and only one value for bytea_output", + }) + .fail(); + }; + let Expr::Value(value) = var_value else { + return (NotSupportedSnafu { + feat: "Set variable value must be a value", + }) + .fail(); + }; + ctx.configuration_parameter().set_postgres_bytea_output( + PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?, + ); + Ok(()) +} + fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result { let direction = match stmt { CopyTable::To(_) => CopyDirection::Export, diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 43dbc55703..5cfbdd5462 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -41,6 +41,7 @@ common-recordbatch.workspace = true common-runtime.workspace = true common-telemetry.workspace = true common-time.workspace = true +dashmap.workspace = true datafusion.workspace = true datafusion-common.workspace = true datatypes.workspace = true diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index a9c9d630b8..d9f16bd654 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -310,7 +310,7 @@ mod test { #[test] fn test_check() { - let session = Arc::new(Session::new(None, Channel::Mysql)); + let session = Arc::new(Session::new(None, Channel::Mysql, Default::default())); let query = "select 1"; let result = check(query, QueryContext::arc(), session.clone()); assert!(result.is_none()); @@ -320,7 +320,7 @@ mod test { assert!(output.is_none()); fn test(query: &str, expected: &str) { - let session = Arc::new(Session::new(None, Channel::Mysql)); + let session = Arc::new(Session::new(None, Channel::Mysql, Default::default())); let output = check(query, QueryContext::arc(), session.clone()); match output.unwrap().data { OutputData::RecordBatches(r) => { diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 8c1814580f..9fe088cb66 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -85,7 +85,11 @@ impl MysqlInstanceShim { MysqlInstanceShim { query_handler, salt: scramble, - session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)), + session: Arc::new(Session::new( + Some(client_addr), + Channel::Mysql, + Default::default(), + )), user_provider, prepared_stmts: Default::default(), prepared_stmts_counter: AtomicU32::new(1), diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 0836ea51bb..c6e10ad8db 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -88,7 +88,7 @@ pub(crate) struct MakePostgresServerHandler { impl MakePostgresServerHandler { fn make(&self, addr: Option) -> PostgresServerHandler { - let session = Arc::new(Session::new(addr, Channel::Postgres)); + let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default())); PostgresServerHandler { query_handler: self.query_handler.clone(), login_verifier: PgLoginVerifier::new(self.user_provider.clone()), diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 711c2b30db..bb3db5bc9c 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -31,6 +31,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use query::query_engine::DescribeResult; +use session::context::QueryContextRef; use session::Session; use sql::dialect::PostgreSqlDialect; use sql::parser::{ParseOptions, ParserContext}; @@ -63,7 +64,7 @@ impl SimpleQueryHandler for PostgresServerHandler { let mut results = Vec::with_capacity(outputs.len()); for output in outputs { - let resp = output_to_query_response(output, &Format::UnifiedText)?; + let resp = output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?; results.push(resp); } @@ -72,6 +73,7 @@ impl SimpleQueryHandler for PostgresServerHandler { } fn output_to_query_response<'a>( + query_ctx: QueryContextRef, output: Result, field_format: &Format, ) -> PgWireResult> { @@ -82,11 +84,16 @@ fn output_to_query_response<'a>( } OutputData::Stream(record_stream) => { let schema = record_stream.schema(); - recordbatches_to_query_response(record_stream, schema, field_format) + recordbatches_to_query_response(query_ctx, record_stream, schema, field_format) } OutputData::RecordBatches(recordbatches) => { let schema = recordbatches.schema(); - recordbatches_to_query_response(recordbatches.as_stream(), schema, field_format) + recordbatches_to_query_response( + query_ctx, + recordbatches.as_stream(), + schema, + field_format, + ) } }, Err(e) => Ok(Response::Error(Box::new(ErrorInfo::new( @@ -98,6 +105,7 @@ fn output_to_query_response<'a>( } fn recordbatches_to_query_response<'a, S>( + query_ctx: QueryContextRef, recordbatches_stream: S, schema: SchemaRef, field_format: &Format, @@ -125,7 +133,7 @@ where row.and_then(|row| { let mut encoder = DataRowEncoder::new(pg_schema_ref.clone()); for value in row.iter() { - encode_value(value, &mut encoder)?; + encode_value(&query_ctx, value, &mut encoder)?; } encoder.finish() }) @@ -224,7 +232,9 @@ impl ExtendedQueryHandler for PostgresServerHandler { let plan = plan .replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref()) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - self.query_handler.do_exec_plan(plan, query_ctx).await + self.query_handler + .do_exec_plan(plan, query_ctx.clone()) + .await } else { // manually replace variables in prepared statement when no // logical_plan is generated. This happens when logical plan is not @@ -234,10 +244,13 @@ impl ExtendedQueryHandler for PostgresServerHandler { sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?); } - self.query_handler.do_query(&sql, query_ctx).await.remove(0) + self.query_handler + .do_query(&sql, query_ctx.clone()) + .await + .remove(0) }; - output_to_query_response(output, &portal.result_column_format) + output_to_query_response(query_ctx, output, &portal.result_column_format) } async fn do_describe_statement( diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index d863c4d4e0..01351f8541 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod bytea; mod interval; use std::collections::HashMap; @@ -28,7 +29,10 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::api::Type; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use query::plan::LogicalPlan; +use session::context::QueryContextRef; +use session::session_config::PGByteaOutputValue; +use self::bytea::{EscapeOutputBytea, HexOutputBytea}; use self::interval::PgInterval; use crate::error::{self, Error, Result}; use crate::SqlPlan; @@ -50,7 +54,11 @@ pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result>>() } -pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> { +pub(super) fn encode_value( + query_ctx: &QueryContextRef, + value: &Value, + builder: &mut DataRowEncoder, +) -> PgWireResult<()> { match value { Value::Null => builder.encode_field(&None::<&i8>), Value::Boolean(v) => builder.encode_field(v), @@ -65,7 +73,13 @@ pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWir Value::Float32(v) => builder.encode_field(&v.0), Value::Float64(v) => builder.encode_field(&v.0), Value::String(v) => builder.encode_field(&v.as_utf8()), - Value::Binary(v) => builder.encode_field(&v.deref()), + Value::Binary(v) => { + let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output(); + match *bytea_output { + PGByteaOutputValue::ESCAPE => builder.encode_field(&EscapeOutputBytea(v.deref())), + PGByteaOutputValue::HEX => builder.encode_field(&HexOutputBytea(v.deref())), + } + } Value::Date(v) => { if let Some(date) = v.to_chrono_date() { builder.encode_field(&date) @@ -563,6 +577,7 @@ mod test { use datatypes::value::ListValue; use pgwire::api::results::{FieldFormat, FieldInfo}; use pgwire::api::Type; + use session::context::QueryContextBuilder; use super::*; @@ -784,12 +799,16 @@ mod test { Value::Timestamp(1000001i64.into()), Value::Interval(1000001i128.into()), ]; + let query_context = QueryContextBuilder::default() + .configuration_parameter(Default::default()) + .build(); let mut builder = DataRowEncoder::new(Arc::new(schema)); for i in values.iter() { - encode_value(i, &mut builder).unwrap(); + encode_value(&query_context, i, &mut builder).unwrap(); } let err = encode_value( + &query_context, &Value::List(ListValue::new( Some(Box::default()), ConcreteDataType::int16_datatype(), diff --git a/src/servers/src/postgres/types/bytea.rs b/src/servers/src/postgres/types/bytea.rs new file mode 100644 index 0000000000..975d670f9c --- /dev/null +++ b/src/servers/src/postgres/types/bytea.rs @@ -0,0 +1,152 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use bytes::BufMut; +use pgwire::types::ToSqlText; +use postgres_types::{IsNull, ToSql, Type}; + +#[derive(Debug)] +pub struct HexOutputBytea<'a>(pub &'a [u8]); +impl ToSqlText for HexOutputBytea<'_> { + fn to_sql_text( + &self, + ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> + where + Self: Sized, + { + out.put_slice(b"\\x"); + let _ = self.0.to_sql_text(ty, out); + Ok(IsNull::No) + } +} + +impl ToSql for HexOutputBytea<'_> { + fn to_sql( + &self, + ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> + where + Self: Sized, + { + self.0.to_sql(ty, out) + } + + fn accepts(ty: &Type) -> bool + where + Self: Sized, + { + <&[u8] as ToSql>::accepts(ty) + } + + fn to_sql_checked( + &self, + ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> { + self.0.to_sql_checked(ty, out) + } +} +#[derive(Debug)] +pub struct EscapeOutputBytea<'a>(pub &'a [u8]); +impl ToSqlText for EscapeOutputBytea<'_> { + fn to_sql_text( + &self, + _ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> + where + Self: Sized, + { + self.0.iter().for_each(|b| match b { + 0..=31 | 127..=255 => { + out.put_slice(b"\\"); + out.put_slice(format!("{:03o}", b).as_bytes()); + } + 92 => out.put_slice(b"\\\\"), + 32..=126 => out.put_u8(*b), + }); + Ok(IsNull::No) + } +} +impl ToSql for EscapeOutputBytea<'_> { + fn to_sql( + &self, + ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> + where + Self: Sized, + { + self.0.to_sql(ty, out) + } + + fn accepts(ty: &Type) -> bool + where + Self: Sized, + { + <&[u8] as ToSql>::accepts(ty) + } + + fn to_sql_checked( + &self, + ty: &Type, + out: &mut bytes::BytesMut, + ) -> std::result::Result> { + self.0.to_sql_checked(ty, out) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_escape_output_bytea() { + let input: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84]; + let input = EscapeOutputBytea(input); + + let expected = b"abcklm*\\251T"; + let mut out = bytes::BytesMut::new(); + let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap(); + assert!(matches!(is_null, IsNull::No)); + assert_eq!(&out[..], expected); + + let expected = &[97, 98, 99, 107, 108, 109, 42, 169, 84]; + let mut out = bytes::BytesMut::new(); + let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap(); + assert!(matches!(is_null, IsNull::No)); + assert_eq!(&out[..], expected); + } + + #[test] + fn test_hex_output_bytea() { + let input = b"hello, world!"; + let input = HexOutputBytea(input); + + let expected = b"\\x68656c6c6f2c20776f726c6421"; + let mut out = bytes::BytesMut::new(); + let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap(); + assert!(matches!(is_null, IsNull::No)); + assert_eq!(&out[..], expected); + + let expected = b"hello, world!"; + let mut out = bytes::BytesMut::new(); + let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap(); + assert!(matches!(is_null, IsNull::No)); + assert_eq!(&out[..], expected); + } +} diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 85697d4ca7..8e0baeaa0f 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -15,6 +15,10 @@ api.workspace = true arc-swap = "1.5" auth.workspace = true common-catalog.workspace = true +common-error.workspace = true +common-macro.workspace = true +common-telemetry.workspace = true common-time.workspace = true derive_builder.workspace = true +snafu.workspace = true sql.workspace = true diff --git a/src/session/src/context.rs b/src/session/src/context.rs index cc41af3744..d401b03316 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -27,6 +27,7 @@ use common_time::Timezone; use derive_builder::Builder; use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; +use crate::session_config::PGByteaOutputValue; use crate::SessionRef; pub type QueryContextRef = Arc; @@ -44,6 +45,9 @@ pub struct QueryContext { sql_dialect: Arc, #[builder(default)] extension: HashMap, + // The configuration parameter are used to store the parameters that are set by the user + #[builder(default)] + configuration_parameter: Arc, } impl QueryContextBuilder { @@ -73,6 +77,7 @@ impl Clone for QueryContext { timezone: self.timezone.load().clone().into(), sql_dialect: self.sql_dialect.clone(), extension: self.extension.clone(), + configuration_parameter: self.configuration_parameter.clone(), } } } @@ -88,6 +93,7 @@ impl From<&RegionRequestHeader> for QueryContext { timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())), sql_dialect: Arc::new(GreptimeDbDialect {}), extension: Default::default(), + configuration_parameter: Default::default(), } } } @@ -183,6 +189,10 @@ impl QueryContext { '`' } } + + pub fn configuration_parameter(&self) -> &ConfigurationVariables { + &self.configuration_parameter + } } impl QueryContextBuilder { @@ -204,6 +214,7 @@ impl QueryContextBuilder { .sql_dialect .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})), extension: self.extension.unwrap_or_default(), + configuration_parameter: self.configuration_parameter.unwrap_or_default(), }) } @@ -268,6 +279,33 @@ impl Display for Channel { } } +#[derive(Default, Debug)] +pub struct ConfigurationVariables { + postgres_bytea_output: ArcSwap, +} + +impl Clone for ConfigurationVariables { + fn clone(&self) -> Self { + Self { + postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()), + } + } +} + +impl ConfigurationVariables { + pub fn new() -> Self { + Self::default() + } + + pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) { + let _ = self.postgres_bytea_output.swap(Arc::new(value)); + } + + pub fn postgres_bytea_output(&self) -> Arc { + self.postgres_bytea_output.load().clone() + } +} + #[cfg(test)] mod test { use common_catalog::consts::DEFAULT_CATALOG_NAME; @@ -278,7 +316,11 @@ mod test { #[test] fn test_session() { - let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql); + let session = Session::new( + Some("127.0.0.1:9000".parse().unwrap()), + Channel::Mysql, + Default::default(), + ); // test user_info assert_eq!(session.user_info().username(), "greptime"); diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index b51511ce6d..e89a733553 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod context; +pub mod session_config; pub mod table_name; use std::net::SocketAddr; @@ -24,7 +25,7 @@ use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_time::timezone::get_timezone; use common_time::Timezone; -use context::QueryContextBuilder; +use context::{ConfigurationVariables, QueryContextBuilder}; use crate::context::{Channel, ConnInfo, QueryContextRef}; @@ -36,18 +37,24 @@ pub struct Session { user_info: ArcSwap, conn_info: ConnInfo, timezone: ArcSwap, + configuration_variables: Arc, } pub type SessionRef = Arc; impl Session { - pub fn new(addr: Option, channel: Channel) -> Self { + pub fn new( + addr: Option, + channel: Channel, + configuration_variables: ConfigurationVariables, + ) -> Self { Session { catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())), schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())), user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))), conn_info: ConnInfo::new(addr, channel), timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())), + configuration_variables: Arc::new(configuration_variables), } } @@ -60,6 +67,7 @@ impl Session { .current_catalog(self.catalog.load().to_string()) .current_schema(self.schema.load().to_string()) .sql_dialect(self.conn_info.channel.dialect()) + .configuration_parameter(self.configuration_variables.clone()) .timezone(self.timezone()) .build() } diff --git a/src/session/src/session_config.rs b/src/session/src/session_config.rs new file mode 100644 index 0000000000..aad50e70c1 --- /dev/null +++ b/src/session/src/session_config.rs @@ -0,0 +1,64 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_macro::stack_trace_debug; +use snafu::{Location, Snafu}; +use sql::ast::Value; + +#[derive(Snafu)] +#[snafu(visibility(pub))] +#[stack_trace_debug] +pub enum Error { + #[snafu(display("Invalid value for parameter \"{}\": {}\nHint: {}", name, value, hint,))] + InvalidConfigValue { + name: String, + value: String, + hint: String, + location: Location, + }, +} + +#[derive(Clone, Copy, Debug, Default)] +pub enum PGByteaOutputValue { + #[default] + HEX, + ESCAPE, +} + +impl TryFrom for PGByteaOutputValue { + type Error = Error; + + fn try_from(value: Value) -> Result { + match &value { + Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => { + match s.to_uppercase().as_str() { + "ESCAPE" => Ok(PGByteaOutputValue::ESCAPE), + "HEX" => Ok(PGByteaOutputValue::HEX), + _ => InvalidConfigValueSnafu { + name: "BYTEA_OUTPUT", + value: value.to_string(), + hint: "Available values: escape, hex", + } + .fail(), + } + } + _ => InvalidConfigValueSnafu { + name: "BYTEA_OUTPUT", + value: value.to_string(), + hint: "Available values: escape, hex", + } + .fail(), + } + } +} diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 7cff590c39..4628bc372f 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -60,6 +60,7 @@ macro_rules! sql_tests { test_postgres_auth, test_postgres_crud, test_postgres_timezone, + test_postgres_bytea, test_postgres_parameter_inference, test_mysql_prepare_stmt_insert_timestamp, ); @@ -415,7 +416,69 @@ pub async fn test_postgres_crud(store_type: StorageType) { let _ = fe_pg_server.shutdown().await; guard.remove_all().await; } +pub async fn test_postgres_bytea(store_type: StorageType) { + let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_bytea_output").await; + let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) + .await + .unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let _ = client + .simple_query("CREATE TABLE test(b BLOB, ts TIMESTAMP TIME INDEX)") + .await + .unwrap(); + let _ = client + .simple_query("INSERT INTO test VALUES(X'6162636b6c6d2aa954', 0)") + .await + .unwrap(); + let get_row = |mess: Vec| -> String { + match &mess[0] { + SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(), + _ => unreachable!(), + } + }; + + let r = client.simple_query("SELECT b FROM test").await.unwrap(); + let b = get_row(r); + assert_eq!(b, "\\x6162636b6c6d2aa954"); + + let _ = client.simple_query("SET bytea_output='hex'").await.unwrap(); + let r = client.simple_query("SELECT b FROM test").await.unwrap(); + let b = get_row(r); + assert_eq!(b, "\\x6162636b6c6d2aa954"); + + let _ = client + .simple_query("SET bytea_output='escape'") + .await + .unwrap(); + let r = client.simple_query("SELECT b FROM test").await.unwrap(); + let b = get_row(r); + assert_eq!(b, "abcklm*\\251T"); + + let _e = client + .simple_query("SET bytea_output='invalid'") + .await + .unwrap_err(); + + // binary format shall not be affected by bytea_output + let pool = PgPoolOptions::new() + .max_connections(2) + .connect(&format!("postgres://{addr}/public")) + .await + .unwrap(); + + let row = sqlx::query("select b from test") + .fetch_one(&pool) + .await + .unwrap(); + let val: Vec = row.get("b"); + assert_eq!(val, [97, 98, 99, 107, 108, 109, 42, 169, 84]); + + let _ = fe_pg_server.shutdown().await; + guard.remove_all().await; +} pub async fn test_postgres_timezone(store_type: StorageType) { let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;