Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2025-12-22 22:20:02 +00:00
feat: Support printing postgresql's bytea data type in its "hex" and "escape" format (#3567)
* feat: support set variable statement of session
* feat: support printing postgresql's bytea data type in its "hex" and "escape" format in an ugly way
* refactor: add 'SessionConfigValue' type and unify the name
* doc: add license header
* refactor: confine coupling with 'sql::ast::Value' in SessionConfigValue
* refactor: move all bytea wrapper into bytea.rs
* fix: remove unused import in context.rs and postgres.rs
* refactor: rename 'set_configuration_parameter' to 'set_session_config'
rename 'set_configuration_parameter' in statement_.rs to 'set_session_config'
* refactor: use mod to organize options via macro
* refactor: re-model the session config value with static type
* test: add integration test
* refactor: move the encode bytea by format type logic into encoder
refactor: use Arc<DashMap> instead of DashMap in QueryContext to avoid an expensive clone
refactor: use unreachable!() instead of unimplemented!()
test: add binary format integration test case
* test: add ut for byte related type
* doc: remove TODO of bytea_output
* refactor: simplify the implementation with simple struct instead of complex typing
* fix: typo of 'Available'
* fix compile
Signed-off-by: tison <wander4096@gmail.com>
---------
Signed-off-by: tison <wander4096@gmail.com>
Co-authored-by: tison <wander4096@gmail.com>
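For context, the two PostgreSQL bytea output formats this change supports work as follows: "hex" prints "\x" followed by two lowercase hex digits per byte, while "escape" keeps printable ASCII as-is, doubles backslashes, and renders every other byte as a three-digit octal escape. Clients switch between them with SET bytea_output='hex' or SET bytea_output='escape', as exercised by the integration test further down. The sketch below is a minimal standalone illustration of those rules, not GreptimeDB code; the function names are invented for this example, and the expected strings mirror the values asserted in the new tests.

// Minimal sketch of PostgreSQL's "hex" and "escape" bytea output rules (illustrative only).
fn bytea_to_hex(v: &[u8]) -> String {
    let mut s = String::from("\\x"); // hex form starts with a literal "\x"
    for b in v {
        s.push_str(&format!("{:02x}", b)); // two lowercase hex digits per byte
    }
    s
}

fn bytea_to_escape(v: &[u8]) -> String {
    let mut s = String::new();
    for &b in v {
        match b {
            b'\\' => s.push_str("\\\\"),              // backslash is doubled
            32..=126 => s.push(b as char),            // printable ASCII stays as-is
            _ => s.push_str(&format!("\\{:03o}", b)), // everything else becomes \nnn (octal)
        }
    }
    s
}

fn main() {
    // Same payload as the integration test: X'6162636b6c6d2aa954'
    let v: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
    assert_eq!(bytea_to_hex(v), "\\x6162636b6c6d2aa954");
    assert_eq!(bytea_to_escape(v), "abcklm*\\251T");
}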
Cargo.lock (generated, 5 lines changed)
@@ -9030,6 +9030,7 @@ dependencies = [
 "common-time",
 "common-version",
 "criterion",
 "dashmap",
 "datafusion",
 "datafusion-common",
 "datatypes",

@@ -9103,8 +9104,12 @@ dependencies = [
 "arc-swap",
 "auth",
 "common-catalog",
 "common-error",
 "common-macro",
 "common-telemetry",
 "common-time",
 "derive_builder 0.12.0",
 "snafu",
 "sql",
]

@@ -20,9 +20,9 @@ use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use datatypes::value::Value;
use servers::define_into_tonic_status;
use snafu::{Location, Snafu};
use sql::ast::Value;

#[derive(Snafu)]
#[snafu(visibility(pub))]

@@ -528,6 +528,12 @@ pub enum Error {

    #[snafu(display("Invalid partition rule: {}", reason))]
    InvalidPartitionRule { reason: String, location: Location },

    #[snafu(display("Invalid configuration value."))]
    InvalidConfigValue {
        source: session::session_config::Error,
        location: Location,
    },
}

pub type Result<T> = std::result::Result<T, Error>;

@@ -536,6 +542,7 @@ impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        match self {
            Error::InvalidSql { .. }
            | Error::InvalidConfigValue { .. }
            | Error::InvalidInsertRequest { .. }
            | Error::InvalidDeleteRequest { .. }
            | Error::IllegalPrimaryKeysDef { .. }

@@ -39,6 +39,7 @@ use query::parser::QueryStatement;
use query::plan::LogicalPlan;
use query::QueryEngineRef;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;
use session::table_name::table_idents_to_full_name;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};

@@ -52,8 +53,8 @@ use table::table_reference::TableReference;
use table::TableRef;

use crate::error::{
    self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
    PlanStatementSnafu, Result, TableNotFoundSnafu,
    self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidConfigValueSnafu,
    InvalidSqlSnafu, NotSupportedSnafu, PlanStatementSnafu, Result, TableNotFoundSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};

@@ -219,8 +220,7 @@ impl StatementExecutor {
            // so we just ignore it here instead of returning an error to break the connection.
            // Since the "bytea_output" only determines the output format of binary values,
            // it won't cause much trouble if we do so.
            // TODO(#3438): Remove this temporary workaround after the feature is implemented.
            "BYTEA_OUTPUT" => (),
            "BYTEA_OUTPUT" => set_bytea_output(set_var.value, query_ctx)?,

            // Same as "bytea_output", we just ignore it here.
            // Not harmful since it only relates to how date is viewed in client app's output.

@@ -339,6 +339,25 @@ fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    }
}

fn set_bytea_output(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
    let Some((var_value, [])) = exprs.split_first() else {
        return (NotSupportedSnafu {
            feat: "Set variable value must have one and only one value for bytea_output",
        })
        .fail();
    };
    let Expr::Value(value) = var_value else {
        return (NotSupportedSnafu {
            feat: "Set variable value must be a value",
        })
        .fail();
    };
    ctx.configuration_parameter().set_postgres_bytea_output(
        PGByteaOutputValue::try_from(value.clone()).context(InvalidConfigValueSnafu)?,
    );
    Ok(())
}

fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {
    let direction = match stmt {
        CopyTable::To(_) => CopyDirection::Export,

@@ -41,6 +41,7 @@ common-recordbatch.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
dashmap.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datatypes.workspace = true

@@ -310,7 +310,7 @@ mod test {

    #[test]
    fn test_check() {
        let session = Arc::new(Session::new(None, Channel::Mysql));
        let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
        let query = "select 1";
        let result = check(query, QueryContext::arc(), session.clone());
        assert!(result.is_none());

@@ -320,7 +320,7 @@ mod test {
        assert!(output.is_none());

        fn test(query: &str, expected: &str) {
            let session = Arc::new(Session::new(None, Channel::Mysql));
            let session = Arc::new(Session::new(None, Channel::Mysql, Default::default()));
            let output = check(query, QueryContext::arc(), session.clone());
            match output.unwrap().data {
                OutputData::RecordBatches(r) => {

@@ -85,7 +85,11 @@ impl MysqlInstanceShim {
        MysqlInstanceShim {
            query_handler,
            salt: scramble,
            session: Arc::new(Session::new(Some(client_addr), Channel::Mysql)),
            session: Arc::new(Session::new(
                Some(client_addr),
                Channel::Mysql,
                Default::default(),
            )),
            user_provider,
            prepared_stmts: Default::default(),
            prepared_stmts_counter: AtomicU32::new(1),

@@ -88,7 +88,7 @@ pub(crate) struct MakePostgresServerHandler {

impl MakePostgresServerHandler {
    fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
        let session = Arc::new(Session::new(addr, Channel::Postgres));
        let session = Arc::new(Session::new(addr, Channel::Postgres, Default::default()));
        PostgresServerHandler {
            query_handler: self.query_handler.clone(),
            login_verifier: PgLoginVerifier::new(self.user_provider.clone()),

@@ -31,6 +31,7 @@ use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::query_engine::DescribeResult;
use session::context::QueryContextRef;
use session::Session;
use sql::dialect::PostgreSqlDialect;
use sql::parser::{ParseOptions, ParserContext};

@@ -63,7 +64,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
        let mut results = Vec::with_capacity(outputs.len());

        for output in outputs {
            let resp = output_to_query_response(output, &Format::UnifiedText)?;
            let resp = output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
            results.push(resp);
        }

@@ -72,6 +73,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
}

fn output_to_query_response<'a>(
    query_ctx: QueryContextRef,
    output: Result<Output>,
    field_format: &Format,
) -> PgWireResult<Response<'a>> {

@@ -82,11 +84,16 @@ fn output_to_query_response<'a>(
        }
        OutputData::Stream(record_stream) => {
            let schema = record_stream.schema();
            recordbatches_to_query_response(record_stream, schema, field_format)
            recordbatches_to_query_response(query_ctx, record_stream, schema, field_format)
        }
        OutputData::RecordBatches(recordbatches) => {
            let schema = recordbatches.schema();
            recordbatches_to_query_response(recordbatches.as_stream(), schema, field_format)
            recordbatches_to_query_response(
                query_ctx,
                recordbatches.as_stream(),
                schema,
                field_format,
            )
        }
    },
    Err(e) => Ok(Response::Error(Box::new(ErrorInfo::new(

@@ -98,6 +105,7 @@ fn output_to_query_response<'a>(
}

fn recordbatches_to_query_response<'a, S>(
    query_ctx: QueryContextRef,
    recordbatches_stream: S,
    schema: SchemaRef,
    field_format: &Format,

@@ -125,7 +133,7 @@ where
        row.and_then(|row| {
            let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
            for value in row.iter() {
                encode_value(value, &mut encoder)?;
                encode_value(&query_ctx, value, &mut encoder)?;
            }
            encoder.finish()
        })

@@ -224,7 +232,9 @@ impl ExtendedQueryHandler for PostgresServerHandler {
            let plan = plan
                .replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
                .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            self.query_handler.do_exec_plan(plan, query_ctx).await
            self.query_handler
                .do_exec_plan(plan, query_ctx.clone())
                .await
        } else {
            // manually replace variables in prepared statement when no
            // logical_plan is generated. This happens when logical plan is not

@@ -234,10 +244,13 @@ impl ExtendedQueryHandler for PostgresServerHandler {
                sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
            }

            self.query_handler.do_query(&sql, query_ctx).await.remove(0)
            self.query_handler
                .do_query(&sql, query_ctx.clone())
                .await
                .remove(0)
        };

        output_to_query_response(output, &portal.result_column_format)
        output_to_query_response(query_ctx, output, &portal.result_column_format)
    }

    async fn do_describe_statement<C>(

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod bytea;
mod interval;

use std::collections::HashMap;

@@ -28,7 +29,10 @@ use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::plan::LogicalPlan;
use session::context::QueryContextRef;
use session::session_config::PGByteaOutputValue;

use self::bytea::{EscapeOutputBytea, HexOutputBytea};
use self::interval::PgInterval;
use crate::error::{self, Error, Result};
use crate::SqlPlan;

@@ -50,7 +54,11 @@ pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Ve
        .collect::<Result<Vec<FieldInfo>>>()
}

pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
pub(super) fn encode_value(
    query_ctx: &QueryContextRef,
    value: &Value,
    builder: &mut DataRowEncoder,
) -> PgWireResult<()> {
    match value {
        Value::Null => builder.encode_field(&None::<&i8>),
        Value::Boolean(v) => builder.encode_field(v),

@@ -65,7 +73,13 @@ pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWir
        Value::Float32(v) => builder.encode_field(&v.0),
        Value::Float64(v) => builder.encode_field(&v.0),
        Value::String(v) => builder.encode_field(&v.as_utf8()),
        Value::Binary(v) => builder.encode_field(&v.deref()),
        Value::Binary(v) => {
            let bytea_output = query_ctx.configuration_parameter().postgres_bytea_output();
            match *bytea_output {
                PGByteaOutputValue::ESCAPE => builder.encode_field(&EscapeOutputBytea(v.deref())),
                PGByteaOutputValue::HEX => builder.encode_field(&HexOutputBytea(v.deref())),
            }
        }
        Value::Date(v) => {
            if let Some(date) = v.to_chrono_date() {
                builder.encode_field(&date)

@@ -563,6 +577,7 @@ mod test {
    use datatypes::value::ListValue;
    use pgwire::api::results::{FieldFormat, FieldInfo};
    use pgwire::api::Type;
    use session::context::QueryContextBuilder;

    use super::*;

@@ -784,12 +799,16 @@ mod test {
            Value::Timestamp(1000001i64.into()),
            Value::Interval(1000001i128.into()),
        ];
        let query_context = QueryContextBuilder::default()
            .configuration_parameter(Default::default())
            .build();
        let mut builder = DataRowEncoder::new(Arc::new(schema));
        for i in values.iter() {
            encode_value(i, &mut builder).unwrap();
            encode_value(&query_context, i, &mut builder).unwrap();
        }

        let err = encode_value(
            &query_context,
            &Value::List(ListValue::new(
                Some(Box::default()),
                ConcreteDataType::int16_datatype(),

src/servers/src/postgres/types/bytea.rs (new file, 152 lines)
@@ -0,0 +1,152 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use bytes::BufMut;
use pgwire::types::ToSqlText;
use postgres_types::{IsNull, ToSql, Type};

#[derive(Debug)]
pub struct HexOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for HexOutputBytea<'_> {
    fn to_sql_text(
        &self,
        ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        out.put_slice(b"\\x");
        let _ = self.0.to_sql_text(ty, out);
        Ok(IsNull::No)
    }
}

impl ToSql for HexOutputBytea<'_> {
    fn to_sql(
        &self,
        ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        self.0.to_sql(ty, out)
    }

    fn accepts(ty: &Type) -> bool
    where
        Self: Sized,
    {
        <&[u8] as ToSql>::accepts(ty)
    }

    fn to_sql_checked(
        &self,
        ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
        self.0.to_sql_checked(ty, out)
    }
}

#[derive(Debug)]
pub struct EscapeOutputBytea<'a>(pub &'a [u8]);
impl ToSqlText for EscapeOutputBytea<'_> {
    fn to_sql_text(
        &self,
        _ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        self.0.iter().for_each(|b| match b {
            0..=31 | 127..=255 => {
                out.put_slice(b"\\");
                out.put_slice(format!("{:03o}", b).as_bytes());
            }
            92 => out.put_slice(b"\\\\"),
            32..=126 => out.put_u8(*b),
        });
        Ok(IsNull::No)
    }
}
impl ToSql for EscapeOutputBytea<'_> {
    fn to_sql(
        &self,
        ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
    where
        Self: Sized,
    {
        self.0.to_sql(ty, out)
    }

    fn accepts(ty: &Type) -> bool
    where
        Self: Sized,
    {
        <&[u8] as ToSql>::accepts(ty)
    }

    fn to_sql_checked(
        &self,
        ty: &Type,
        out: &mut bytes::BytesMut,
    ) -> std::result::Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
        self.0.to_sql_checked(ty, out)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_output_bytea() {
        let input: &[u8] = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
        let input = EscapeOutputBytea(input);

        let expected = b"abcklm*\\251T";
        let mut out = bytes::BytesMut::new();
        let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
        assert!(matches!(is_null, IsNull::No));
        assert_eq!(&out[..], expected);

        let expected = &[97, 98, 99, 107, 108, 109, 42, 169, 84];
        let mut out = bytes::BytesMut::new();
        let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
        assert!(matches!(is_null, IsNull::No));
        assert_eq!(&out[..], expected);
    }

    #[test]
    fn test_hex_output_bytea() {
        let input = b"hello, world!";
        let input = HexOutputBytea(input);

        let expected = b"\\x68656c6c6f2c20776f726c6421";
        let mut out = bytes::BytesMut::new();
        let is_null = input.to_sql_text(&Type::BYTEA, &mut out).unwrap();
        assert!(matches!(is_null, IsNull::No));
        assert_eq!(&out[..], expected);

        let expected = b"hello, world!";
        let mut out = bytes::BytesMut::new();
        let is_null = input.to_sql(&Type::BYTEA, &mut out).unwrap();
        assert!(matches!(is_null, IsNull::No));
        assert_eq!(&out[..], expected);
    }
}

@@ -15,6 +15,10 @@ api.workspace = true
arc-swap = "1.5"
auth.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
derive_builder.workspace = true
snafu.workspace = true
sql.workspace = true

@@ -27,6 +27,7 @@ use common_time::Timezone;
use derive_builder::Builder;
use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};

use crate::session_config::PGByteaOutputValue;
use crate::SessionRef;

pub type QueryContextRef = Arc<QueryContext>;

@@ -44,6 +45,9 @@ pub struct QueryContext {
    sql_dialect: Arc<dyn Dialect + Send + Sync>,
    #[builder(default)]
    extension: HashMap<String, String>,
    // The configuration parameter are used to store the parameters that are set by the user
    #[builder(default)]
    configuration_parameter: Arc<ConfigurationVariables>,
}

impl QueryContextBuilder {

@@ -73,6 +77,7 @@ impl Clone for QueryContext {
            timezone: self.timezone.load().clone().into(),
            sql_dialect: self.sql_dialect.clone(),
            extension: self.extension.clone(),
            configuration_parameter: self.configuration_parameter.clone(),
        }
    }
}

@@ -88,6 +93,7 @@ impl From<&RegionRequestHeader> for QueryContext {
            timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())),
            sql_dialect: Arc::new(GreptimeDbDialect {}),
            extension: Default::default(),
            configuration_parameter: Default::default(),
        }
    }
}

@@ -183,6 +189,10 @@ impl QueryContext {
            '`'
        }
    }

    pub fn configuration_parameter(&self) -> &ConfigurationVariables {
        &self.configuration_parameter
    }
}

impl QueryContextBuilder {

@@ -204,6 +214,7 @@ impl QueryContextBuilder {
                .sql_dialect
                .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
            extension: self.extension.unwrap_or_default(),
            configuration_parameter: self.configuration_parameter.unwrap_or_default(),
        })
    }

@@ -268,6 +279,33 @@ impl Display for Channel {
    }
}

#[derive(Default, Debug)]
pub struct ConfigurationVariables {
    postgres_bytea_output: ArcSwap<PGByteaOutputValue>,
}

impl Clone for ConfigurationVariables {
    fn clone(&self) -> Self {
        Self {
            postgres_bytea_output: ArcSwap::new(self.postgres_bytea_output.load().clone()),
        }
    }
}

impl ConfigurationVariables {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn set_postgres_bytea_output(&self, value: PGByteaOutputValue) {
        let _ = self.postgres_bytea_output.swap(Arc::new(value));
    }

    pub fn postgres_bytea_output(&self) -> Arc<PGByteaOutputValue> {
        self.postgres_bytea_output.load().clone()
    }
}

#[cfg(test)]
mod test {
    use common_catalog::consts::DEFAULT_CATALOG_NAME;

@@ -278,7 +316,11 @@ mod test {

    #[test]
    fn test_session() {
        let session = Session::new(Some("127.0.0.1:9000".parse().unwrap()), Channel::Mysql);
        let session = Session::new(
            Some("127.0.0.1:9000".parse().unwrap()),
            Channel::Mysql,
            Default::default(),
        );
        // test user_info
        assert_eq!(session.user_info().username(), "greptime");

@@ -13,6 +13,7 @@
// limitations under the License.

pub mod context;
pub mod session_config;
pub mod table_name;

use std::net::SocketAddr;

@@ -24,7 +25,7 @@ use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_time::timezone::get_timezone;
use common_time::Timezone;
use context::QueryContextBuilder;
use context::{ConfigurationVariables, QueryContextBuilder};

use crate::context::{Channel, ConnInfo, QueryContextRef};

@@ -36,18 +37,24 @@ pub struct Session {
    user_info: ArcSwap<UserInfoRef>,
    conn_info: ConnInfo,
    timezone: ArcSwap<Timezone>,
    configuration_variables: Arc<ConfigurationVariables>,
}

pub type SessionRef = Arc<Session>;

impl Session {
    pub fn new(addr: Option<SocketAddr>, channel: Channel) -> Self {
    pub fn new(
        addr: Option<SocketAddr>,
        channel: Channel,
        configuration_variables: ConfigurationVariables,
    ) -> Self {
        Session {
            catalog: ArcSwap::new(Arc::new(DEFAULT_CATALOG_NAME.into())),
            schema: ArcSwap::new(Arc::new(DEFAULT_SCHEMA_NAME.into())),
            user_info: ArcSwap::new(Arc::new(auth::userinfo_by_name(None))),
            conn_info: ConnInfo::new(addr, channel),
            timezone: ArcSwap::new(Arc::new(get_timezone(None).clone())),
            configuration_variables: Arc::new(configuration_variables),
        }
    }

@@ -60,6 +67,7 @@ impl Session {
            .current_catalog(self.catalog.load().to_string())
            .current_schema(self.schema.load().to_string())
            .sql_dialect(self.conn_info.channel.dialect())
            .configuration_parameter(self.configuration_variables.clone())
            .timezone(self.timezone())
            .build()
    }

src/session/src/session_config.rs (new file, 64 lines)
@@ -0,0 +1,64 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use common_macro::stack_trace_debug;
use snafu::{Location, Snafu};
use sql::ast::Value;

#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
    #[snafu(display("Invalid value for parameter \"{}\": {}\nHint: {}", name, value, hint,))]
    InvalidConfigValue {
        name: String,
        value: String,
        hint: String,
        location: Location,
    },
}

#[derive(Clone, Copy, Debug, Default)]
pub enum PGByteaOutputValue {
    #[default]
    HEX,
    ESCAPE,
}

impl TryFrom<Value> for PGByteaOutputValue {
    type Error = Error;

    fn try_from(value: Value) -> Result<Self, Self::Error> {
        match &value {
            Value::DoubleQuotedString(s) | Value::SingleQuotedString(s) => {
                match s.to_uppercase().as_str() {
                    "ESCAPE" => Ok(PGByteaOutputValue::ESCAPE),
                    "HEX" => Ok(PGByteaOutputValue::HEX),
                    _ => InvalidConfigValueSnafu {
                        name: "BYTEA_OUTPUT",
                        value: value.to_string(),
                        hint: "Available values: escape, hex",
                    }
                    .fail(),
                }
            }
            _ => InvalidConfigValueSnafu {
                name: "BYTEA_OUTPUT",
                value: value.to_string(),
                hint: "Available values: escape, hex",
            }
            .fail(),
        }
    }
}

@@ -60,6 +60,7 @@ macro_rules! sql_tests {
        test_postgres_auth,
        test_postgres_crud,
        test_postgres_timezone,
        test_postgres_bytea,
        test_postgres_parameter_inference,
        test_mysql_prepare_stmt_insert_timestamp,
    );

@@ -415,7 +416,69 @@ pub async fn test_postgres_crud(store_type: StorageType) {
    let _ = fe_pg_server.shutdown().await;
    guard.remove_all().await;
}

pub async fn test_postgres_bytea(store_type: StorageType) {
    let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_bytea_output").await;

    let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls)
        .await
        .unwrap();
    tokio::spawn(async move {
        connection.await.unwrap();
    });
    let _ = client
        .simple_query("CREATE TABLE test(b BLOB, ts TIMESTAMP TIME INDEX)")
        .await
        .unwrap();
    let _ = client
        .simple_query("INSERT INTO test VALUES(X'6162636b6c6d2aa954', 0)")
        .await
        .unwrap();
    let get_row = |mess: Vec<SimpleQueryMessage>| -> String {
        match &mess[0] {
            SimpleQueryMessage::Row(row) => row.get(0).unwrap().to_string(),
            _ => unreachable!(),
        }
    };

    let r = client.simple_query("SELECT b FROM test").await.unwrap();
    let b = get_row(r);
    assert_eq!(b, "\\x6162636b6c6d2aa954");

    let _ = client.simple_query("SET bytea_output='hex'").await.unwrap();
    let r = client.simple_query("SELECT b FROM test").await.unwrap();
    let b = get_row(r);
    assert_eq!(b, "\\x6162636b6c6d2aa954");

    let _ = client
        .simple_query("SET bytea_output='escape'")
        .await
        .unwrap();
    let r = client.simple_query("SELECT b FROM test").await.unwrap();
    let b = get_row(r);
    assert_eq!(b, "abcklm*\\251T");

    let _e = client
        .simple_query("SET bytea_output='invalid'")
        .await
        .unwrap_err();

    // binary format shall not be affected by bytea_output
    let pool = PgPoolOptions::new()
        .max_connections(2)
        .connect(&format!("postgres://{addr}/public"))
        .await
        .unwrap();

    let row = sqlx::query("select b from test")
        .fetch_one(&pool)
        .await
        .unwrap();
    let val: Vec<u8> = row.get("b");
    assert_eq!(val, [97, 98, 99, 107, 108, 109, 42, 169, 84]);

    let _ = fe_pg_server.shutdown().await;
    guard.remove_all().await;
}

pub async fn test_postgres_timezone(store_type: StorageType) {
    let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await;