mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
feat: returning warning instead of error on unsupported SET statement (#4761)
* feat: add capability to send warning to pgclient * fix: refactor query context to carry query scope data * feat: return a warning for unsupported postgres statement
This commit is contained in:
@@ -46,7 +46,7 @@ use datafusion_expr::LogicalPlan;
|
||||
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
|
||||
use query::parser::QueryStatement;
|
||||
use query::QueryEngineRef;
|
||||
use session::context::QueryContextRef;
|
||||
use session::context::{Channel, QueryContextRef};
|
||||
use session::table_name::table_idents_to_full_name;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
|
||||
@@ -338,10 +338,18 @@ impl StatementExecutor {
|
||||
|
||||
"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
|
||||
_ => {
|
||||
return NotSupportedSnafu {
|
||||
feat: format!("Unsupported set variable {}", var_name),
|
||||
// for postgres, we give unknown SET statements a warning with
|
||||
// success, this is prevent the SET call becoming a blocker
|
||||
// of connection establishment
|
||||
//
|
||||
if query_ctx.channel() == Channel::Postgres {
|
||||
query_ctx.set_warning(format!("Unsupported set variable {}", var_name));
|
||||
} else {
|
||||
return NotSupportedSnafu {
|
||||
feat: format!("Unsupported set variable {}", var_name),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
Ok(Output::new_with_affected_rows(0))
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -23,7 +24,7 @@ use common_telemetry::{debug, error, tracing};
|
||||
use datafusion_common::ParamValues;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::SchemaRef;
|
||||
use futures::{future, stream, Stream, StreamExt};
|
||||
use futures::{future, stream, Sink, SinkExt, Stream, StreamExt};
|
||||
use pgwire::api::portal::{Format, Portal};
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
|
||||
use pgwire::api::results::{
|
||||
@@ -32,6 +33,7 @@ use pgwire::api::results::{
|
||||
use pgwire::api::stmt::{QueryParser, StoredStatement};
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::PgWireBackendMessage;
|
||||
use query::query_engine::DescribeResult;
|
||||
use session::context::QueryContextRef;
|
||||
use session::Session;
|
||||
@@ -49,11 +51,13 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
|
||||
#[tracing::instrument(skip_all, fields(protocol = "postgres"))]
|
||||
async fn do_query<'a, C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
client: &mut C,
|
||||
query: &'a str,
|
||||
) -> PgWireResult<Vec<Response<'a>>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
let query_ctx = self.session.new_query_context();
|
||||
let db = query_ctx.get_db_string();
|
||||
@@ -67,6 +71,7 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
|
||||
}
|
||||
|
||||
if let Some(resps) = fixtures::process(query, query_ctx.clone()) {
|
||||
send_warning_opt(client, query_ctx).await?;
|
||||
Ok(resps)
|
||||
} else {
|
||||
let outputs = self.query_handler.do_query(query, query_ctx.clone()).await;
|
||||
@@ -79,11 +84,34 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
|
||||
results.push(resp);
|
||||
}
|
||||
|
||||
send_warning_opt(client, query_ctx).await?;
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_warning_opt<C>(client: &mut C, query_context: QueryContextRef) -> PgWireResult<()>
|
||||
where
|
||||
C: Sink<PgWireBackendMessage> + Unpin + Send + Sync,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
if let Some(warning) = query_context.warning() {
|
||||
client
|
||||
.feed(PgWireBackendMessage::NoticeResponse(
|
||||
ErrorInfo::new(
|
||||
PgErrorSeverity::Warning.to_string(),
|
||||
PgErrorCode::Ec01000.code(),
|
||||
warning.to_string(),
|
||||
)
|
||||
.into(),
|
||||
))
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn output_to_query_response<'a>(
|
||||
query_ctx: QueryContextRef,
|
||||
output: Result<Output>,
|
||||
@@ -247,12 +275,14 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
|
||||
|
||||
async fn do_query<'a, C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
client: &mut C,
|
||||
portal: &'a Portal<Self::Statement>,
|
||||
_max_rows: usize,
|
||||
) -> PgWireResult<Response<'a>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
let query_ctx = self.session.new_query_context();
|
||||
let db = query_ctx.get_db_string();
|
||||
@@ -268,6 +298,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
|
||||
}
|
||||
|
||||
if let Some(mut resps) = fixtures::process(&sql_plan.query, query_ctx.clone()) {
|
||||
send_warning_opt(client, query_ctx).await?;
|
||||
// if the statement matches our predefined rules, return it early
|
||||
return Ok(resps.remove(0));
|
||||
}
|
||||
@@ -297,6 +328,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
|
||||
.remove(0)
|
||||
};
|
||||
|
||||
send_warning_opt(client, query_ctx.clone()).await?;
|
||||
output_to_query_response(query_ctx, output, &portal.result_column_format)
|
||||
}
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ use session::session_config::PGByteaOutputValue;
|
||||
|
||||
use self::bytea::{EscapeOutputBytea, HexOutputBytea};
|
||||
use self::datetime::{StylingDate, StylingDateTime};
|
||||
pub use self::error::PgErrorCode;
|
||||
pub use self::error::{PgErrorCode, PgErrorSeverity};
|
||||
use self::interval::PgInterval;
|
||||
use crate::error::{self as server_error, Error, Result};
|
||||
use crate::SqlPlan;
|
||||
|
||||
@@ -19,7 +19,7 @@ use strum::{AsRefStr, Display, EnumIter, EnumMessage};
|
||||
|
||||
#[derive(Display, Debug, PartialEq)]
|
||||
#[allow(dead_code)]
|
||||
enum ErrorSeverity {
|
||||
pub enum PgErrorSeverity {
|
||||
#[strum(serialize = "INFO")]
|
||||
Info,
|
||||
#[strum(serialize = "DEBUG")]
|
||||
@@ -335,23 +335,23 @@ pub enum PgErrorCode {
|
||||
}
|
||||
|
||||
impl PgErrorCode {
|
||||
fn severity(&self) -> ErrorSeverity {
|
||||
fn severity(&self) -> PgErrorSeverity {
|
||||
match self {
|
||||
PgErrorCode::Ec00000 => ErrorSeverity::Info,
|
||||
PgErrorCode::Ec01000 => ErrorSeverity::Warning,
|
||||
PgErrorCode::Ec00000 => PgErrorSeverity::Info,
|
||||
PgErrorCode::Ec01000 => PgErrorSeverity::Warning,
|
||||
|
||||
PgErrorCode::EcXX000 | PgErrorCode::Ec42P14 | PgErrorCode::Ec22023 => {
|
||||
ErrorSeverity::Error
|
||||
PgErrorSeverity::Error
|
||||
}
|
||||
PgErrorCode::Ec28000 | PgErrorCode::Ec28P01 | PgErrorCode::Ec3D000 => {
|
||||
ErrorSeverity::Fatal
|
||||
PgErrorSeverity::Fatal
|
||||
}
|
||||
|
||||
_ => ErrorSeverity::Error,
|
||||
_ => PgErrorSeverity::Error,
|
||||
}
|
||||
}
|
||||
|
||||
fn code(&self) -> String {
|
||||
pub(crate) fn code(&self) -> String {
|
||||
self.as_ref()[2..].to_string()
|
||||
}
|
||||
|
||||
@@ -428,19 +428,19 @@ mod tests {
|
||||
use common_error::status_code::StatusCode;
|
||||
use strum::{EnumMessage, IntoEnumIterator};
|
||||
|
||||
use super::{ErrorInfo, ErrorSeverity, PgErrorCode};
|
||||
use super::{ErrorInfo, PgErrorCode, PgErrorSeverity};
|
||||
|
||||
#[test]
|
||||
fn test_error_severity() {
|
||||
// test for ErrorSeverity enum
|
||||
assert_eq!("INFO", ErrorSeverity::Info.to_string());
|
||||
assert_eq!("DEBUG", ErrorSeverity::Debug.to_string());
|
||||
assert_eq!("NOTICE", ErrorSeverity::Notice.to_string());
|
||||
assert_eq!("WARNING", ErrorSeverity::Warning.to_string());
|
||||
assert_eq!("INFO", PgErrorSeverity::Info.to_string());
|
||||
assert_eq!("DEBUG", PgErrorSeverity::Debug.to_string());
|
||||
assert_eq!("NOTICE", PgErrorSeverity::Notice.to_string());
|
||||
assert_eq!("WARNING", PgErrorSeverity::Warning.to_string());
|
||||
|
||||
assert_eq!("ERROR", ErrorSeverity::Error.to_string());
|
||||
assert_eq!("FATAL", ErrorSeverity::Fatal.to_string());
|
||||
assert_eq!("PANIC", ErrorSeverity::Panic.to_string());
|
||||
assert_eq!("ERROR", PgErrorSeverity::Error.to_string());
|
||||
assert_eq!("FATAL", PgErrorSeverity::Fatal.to_string());
|
||||
assert_eq!("PANIC", PgErrorSeverity::Panic.to_string());
|
||||
|
||||
// test for severity method
|
||||
for code in PgErrorCode::iter() {
|
||||
@@ -448,13 +448,13 @@ mod tests {
|
||||
assert_eq!("Ec", &name[0..2]);
|
||||
|
||||
if name.starts_with("Ec00") {
|
||||
assert_eq!(ErrorSeverity::Info, code.severity());
|
||||
assert_eq!(PgErrorSeverity::Info, code.severity());
|
||||
} else if name.starts_with("Ec01") {
|
||||
assert_eq!(ErrorSeverity::Warning, code.severity());
|
||||
assert_eq!(PgErrorSeverity::Warning, code.severity());
|
||||
} else if name.starts_with("Ec28") || name.starts_with("Ec3D") {
|
||||
assert_eq!(ErrorSeverity::Fatal, code.severity());
|
||||
assert_eq!(PgErrorSeverity::Fatal, code.severity());
|
||||
} else {
|
||||
assert_eq!(ErrorSeverity::Error, code.severity());
|
||||
assert_eq!(PgErrorSeverity::Error, code.severity());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +40,9 @@ pub struct QueryContext {
|
||||
current_catalog: String,
|
||||
// we use Arc<RwLock>> for modifiable fields
|
||||
#[builder(default)]
|
||||
mutable_inner: Arc<RwLock<MutableInner>>,
|
||||
mutable_session_data: Arc<RwLock<MutableInner>>,
|
||||
#[builder(default)]
|
||||
mutable_query_context_data: Arc<RwLock<QueryContextMutableFields>>,
|
||||
sql_dialect: Arc<dyn Dialect + Send + Sync>,
|
||||
#[builder(default)]
|
||||
extensions: HashMap<String, String>,
|
||||
@@ -52,6 +54,12 @@ pub struct QueryContext {
|
||||
channel: Channel,
|
||||
}
|
||||
|
||||
/// This fields hold data that is only valid to current query context
|
||||
#[derive(Debug, Builder, Clone, Default)]
|
||||
pub struct QueryContextMutableFields {
|
||||
warning: Option<String>,
|
||||
}
|
||||
|
||||
impl Display for QueryContext {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
@@ -65,21 +73,26 @@ impl Display for QueryContext {
|
||||
|
||||
impl QueryContextBuilder {
|
||||
pub fn current_schema(mut self, schema: String) -> Self {
|
||||
if self.mutable_inner.is_none() {
|
||||
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
|
||||
if self.mutable_session_data.is_none() {
|
||||
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
|
||||
}
|
||||
|
||||
// safe for unwrap because previous none check
|
||||
self.mutable_inner.as_mut().unwrap().write().unwrap().schema = schema;
|
||||
self.mutable_session_data
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.write()
|
||||
.unwrap()
|
||||
.schema = schema;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn timezone(mut self, timezone: Timezone) -> Self {
|
||||
if self.mutable_inner.is_none() {
|
||||
self.mutable_inner = Some(Arc::new(RwLock::new(MutableInner::default())));
|
||||
if self.mutable_session_data.is_none() {
|
||||
self.mutable_session_data = Some(Arc::new(RwLock::new(MutableInner::default())));
|
||||
}
|
||||
|
||||
self.mutable_inner
|
||||
self.mutable_session_data
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.write()
|
||||
@@ -120,7 +133,7 @@ impl From<QueryContext> for api::v1::QueryContext {
|
||||
fn from(
|
||||
QueryContext {
|
||||
current_catalog,
|
||||
mutable_inner,
|
||||
mutable_session_data: mutable_inner,
|
||||
extensions,
|
||||
channel,
|
||||
..
|
||||
@@ -182,11 +195,11 @@ impl QueryContext {
|
||||
}
|
||||
|
||||
pub fn current_schema(&self) -> String {
|
||||
self.mutable_inner.read().unwrap().schema.clone()
|
||||
self.mutable_session_data.read().unwrap().schema.clone()
|
||||
}
|
||||
|
||||
pub fn set_current_schema(&self, new_schema: &str) {
|
||||
self.mutable_inner.write().unwrap().schema = new_schema.to_string();
|
||||
self.mutable_session_data.write().unwrap().schema = new_schema.to_string();
|
||||
}
|
||||
|
||||
pub fn current_catalog(&self) -> &str {
|
||||
@@ -208,19 +221,19 @@ impl QueryContext {
|
||||
}
|
||||
|
||||
pub fn timezone(&self) -> Timezone {
|
||||
self.mutable_inner.read().unwrap().timezone.clone()
|
||||
self.mutable_session_data.read().unwrap().timezone.clone()
|
||||
}
|
||||
|
||||
pub fn set_timezone(&self, timezone: Timezone) {
|
||||
self.mutable_inner.write().unwrap().timezone = timezone;
|
||||
self.mutable_session_data.write().unwrap().timezone = timezone;
|
||||
}
|
||||
|
||||
pub fn current_user(&self) -> UserInfoRef {
|
||||
self.mutable_inner.read().unwrap().user_info.clone()
|
||||
self.mutable_session_data.read().unwrap().user_info.clone()
|
||||
}
|
||||
|
||||
pub fn set_current_user(&self, user: UserInfoRef) {
|
||||
self.mutable_inner.write().unwrap().user_info = user;
|
||||
self.mutable_session_data.write().unwrap().user_info = user;
|
||||
}
|
||||
|
||||
pub fn set_extension<S1: Into<String>, S2: Into<String>>(&mut self, key: S1, value: S2) {
|
||||
@@ -257,6 +270,18 @@ impl QueryContext {
|
||||
pub fn set_channel(&mut self, channel: Channel) {
|
||||
self.channel = channel;
|
||||
}
|
||||
|
||||
pub fn warning(&self) -> Option<String> {
|
||||
self.mutable_query_context_data
|
||||
.read()
|
||||
.unwrap()
|
||||
.warning
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn set_warning(&self, msg: String) {
|
||||
self.mutable_query_context_data.write().unwrap().warning = Some(msg);
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryContextBuilder {
|
||||
@@ -266,7 +291,8 @@ impl QueryContextBuilder {
|
||||
current_catalog: self
|
||||
.current_catalog
|
||||
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()),
|
||||
mutable_inner: self.mutable_inner.unwrap_or_default(),
|
||||
mutable_session_data: self.mutable_session_data.unwrap_or_default(),
|
||||
mutable_query_context_data: self.mutable_query_context_data.unwrap_or_default(),
|
||||
sql_dialect: self
|
||||
.sql_dialect
|
||||
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
|
||||
|
||||
@@ -76,7 +76,7 @@ impl Session {
|
||||
// catalog is not allowed for update in query context so we use
|
||||
// string here
|
||||
.current_catalog(self.catalog.read().unwrap().clone())
|
||||
.mutable_inner(self.mutable_inner.clone())
|
||||
.mutable_session_data(self.mutable_inner.clone())
|
||||
.sql_dialect(self.conn_info.channel.dialect())
|
||||
.configuration_parameter(self.configuration_variables.clone())
|
||||
.channel(self.conn_info.channel)
|
||||
|
||||
Reference in New Issue
Block a user