diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 047fcab98c..112d6b39dd 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -17,6 +17,7 @@ mod fixtures; mod handler; mod server; mod types; +mod utils; pub(crate) const METADATA_USER: &str = "user"; pub(crate) const METADATA_DATABASE: &str = "database"; diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 9505c11956..134ee54b19 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -32,6 +32,7 @@ use snafu::IntoError; use crate::error::{AuthSnafu, Result}; use crate::metrics::METRIC_AUTH_FAILURE; use crate::postgres::types::PgErrorCode; +use crate::postgres::utils::convert_err; use crate::postgres::PostgresServerHandlerInner; use crate::query_handler::sql::ServerSqlQueryHandlerRef; @@ -247,7 +248,7 @@ where if query_handler .is_valid_schema(&catalog, &schema) .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))? + .map_err(convert_err)? { Ok(DbResolution::Resolved(catalog, schema)) } else { diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 97c48a8ac9..fd0eb223a4 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -16,11 +16,10 @@ use std::fmt::Debug; use std::sync::Arc; use async_trait::async_trait; -use common_error::ext::ErrorExt; use common_query::{Output, OutputData}; use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::RecordBatch; -use common_telemetry::{debug, error, tracing}; +use common_telemetry::{debug, tracing}; use datafusion_common::ParamValues; use datatypes::prelude::ConcreteDataType; use datatypes::schema::SchemaRef; @@ -37,11 +36,13 @@ use pgwire::messages::PgWireBackendMessage; use query::query_engine::DescribeResult; use session::context::QueryContextRef; use session::Session; +use snafu::ResultExt; use sql::dialect::PostgreSqlDialect; use sql::parser::{ParseOptions, ParserContext}; -use crate::error::Result; +use crate::error::{DataFusionSnafu, Result}; use crate::postgres::types::*; +use crate::postgres::utils::convert_err; use crate::postgres::{fixtures, PostgresServerHandlerInner}; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::SqlPlan; @@ -135,24 +136,7 @@ pub(crate) fn output_to_query_response<'a>( ) } }, - Err(e) => { - let status_code = e.status_code(); - - if status_code.should_log_error() { - let root_error = e.root_cause().unwrap_or(&e); - error!(e; "Failed to handle postgres query, code: {}, db: {}, error: {}", status_code, query_ctx.get_db_string(), root_error.to_string()); - } else { - debug!( - "Failed to handle postgres query, code: {}, db: {}, error: {:?}", - status_code, - query_ctx.get_db_string(), - e - ); - }; - Ok(Response::Error(Box::new( - PgErrorCode::from(status_code).to_err_info(e.output_msg()), - ))) - } + Err(e) => Err(convert_err(e)), } } @@ -165,10 +149,7 @@ fn recordbatches_to_query_response<'a, S>( where S: Stream> + Send + Unpin + 'static, { - let pg_schema = Arc::new( - schema_to_pg(schema.as_ref(), field_format) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?, - ); + let pg_schema = Arc::new(schema_to_pg(schema.as_ref(), field_format).map_err(convert_err)?); let pg_schema_ref = pg_schema.clone(); let data_row_stream = recordbatches_stream .map(|record_batch_result| match record_batch_result { @@ -178,7 +159,7 @@ where rb.rows().map(Ok).collect::>(), ) .boxed(), - Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(), + Err(e) => stream::once(future::err(convert_err(e))).boxed(), }) .flatten() // flatten into stream> .map(move |row| { @@ -238,7 +219,7 @@ impl QueryParser for DefaultQueryParser { let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}, ParseOptions::default()) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + .map_err(convert_err)?; if stmts.len() != 1 { Err(PgWireError::UserError(Box::new(ErrorInfo::from( PgErrorCode::Ec42P14, @@ -250,7 +231,7 @@ impl QueryParser for DefaultQueryParser { .query_handler .do_describe(stmt, query_ctx) .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + .map_err(convert_err)?; let (plan, schema) = if let Some(DescribeResult { logical_plan, @@ -316,7 +297,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { .replace_params_with_values(&ParamValues::List(parameters_to_scalar_values( plan, portal, )?)) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + .context(DataFusionSnafu) + .map_err(convert_err)?; self.query_handler .do_exec_plan(plan, query_ctx.clone()) .await @@ -351,13 +333,13 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { let (param_types, sql_plan, format) = if let Some(plan) = &sql_plan.plan { let param_types = plan .get_parameter_types() - .map_err(|e| PgWireError::ApiError(Box::new(e)))? + .context(DataFusionSnafu) + .map_err(convert_err)? .into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) .collect(); - let types = param_types_to_pg_types(¶m_types) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let types = param_types_to_pg_types(¶m_types).map_err(convert_err)?; (types, sql_plan, &Format::UnifiedBinary) } else { @@ -368,7 +350,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { if let Some(schema) = &sql_plan.schema { schema_to_pg(schema, format) .map(|fields| DescribeStatementResponse::new(param_types, fields)) - .map_err(|e| PgWireError::ApiError(Box::new(e))) + .map_err(convert_err) } else { if let Some(mut resp) = fixtures::process(&sql_plan.query, self.session.new_query_context()) @@ -399,7 +381,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { if let Some(schema) = &sql_plan.schema { schema_to_pg(schema, format) .map(DescribePortalResponse::new) - .map_err(|e| PgWireError::ApiError(Box::new(e))) + .map_err(convert_err) } else { if let Some(mut resp) = fixtures::process(&sql_plan.query, self.session.new_query_context()) diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 68bdb9d2e9..f3c2781c29 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -35,12 +35,14 @@ use pgwire::api::Type; use pgwire::error::{PgWireError, PgWireResult}; use session::context::QueryContextRef; use session::session_config::PGByteaOutputValue; +use snafu::ResultExt; use self::bytea::{EscapeOutputBytea, HexOutputBytea}; use self::datetime::{StylingDate, StylingDateTime}; pub use self::error::{PgErrorCode, PgErrorSeverity}; use self::interval::PgInterval; -use crate::error::{self as server_error, Error, Result}; +use crate::error::{self as server_error, DataFusionSnafu, Error, Result}; +use crate::postgres::utils::convert_err; use crate::SqlPlan; pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result> { @@ -73,9 +75,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Boolean(v) => Ok(Some(*v)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected bool",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -88,11 +90,11 @@ fn encode_array( Value::Null => Ok(None), Value::Int8(v) => Ok(Some(*v)), Value::UInt8(v) => Ok(Some(*v as i8)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected int8 or uint8", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -105,11 +107,11 @@ fn encode_array( Value::Null => Ok(None), Value::Int16(v) => Ok(Some(*v)), Value::UInt16(v) => Ok(Some(*v as i16)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected int16 or uint16", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -122,11 +124,11 @@ fn encode_array( Value::Null => Ok(None), Value::Int32(v) => Ok(Some(*v)), Value::UInt32(v) => Ok(Some(*v as i32)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected int32 or uint32", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -139,11 +141,11 @@ fn encode_array( Value::Null => Ok(None), Value::Int64(v) => Ok(Some(*v)), Value::UInt64(v) => Ok(Some(*v as i64)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected int64 or uint64", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -155,9 +157,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Float32(v) => Ok(Some(v.0)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected float32",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -169,9 +171,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Float64(v) => Ok(Some(v.0)), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected float64",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -188,11 +190,11 @@ fn encode_array( Value::Null => Ok(None), Value::Binary(v) => Ok(Some(EscapeOutputBytea(v.deref()))), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected binary", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -205,11 +207,11 @@ fn encode_array( Value::Null => Ok(None), Value::Binary(v) => Ok(Some(HexOutputBytea(v.deref()))), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "Invalid list item type, find {v:?}, expected binary", ), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -223,9 +225,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::String(v) => Ok(Some(v.as_utf8())), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected string",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -242,14 +244,14 @@ fn encode_array( *query_ctx.configuration_parameter().pg_datetime_style(); Ok(Some(StylingDate(date, style, order))) } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(convert_err(Error::Internal { err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) + })) } } - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected date",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -268,14 +270,14 @@ fn encode_array( *query_ctx.configuration_parameter().pg_datetime_style(); Ok(Some(StylingDateTime(datetime, style, order))) } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(convert_err(Error::Internal { err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) + })) } } - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected timestamp",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -287,9 +289,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Time(v) => Ok(v.to_chrono_time()), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected time",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -303,9 +305,9 @@ fn encode_array( Value::IntervalYearMonth(v) => Ok(Some(PgInterval::from(*v))), Value::IntervalDayTime(v) => Ok(Some(PgInterval::from(*v))), Value::IntervalMonthDayNano(v) => Ok(Some(PgInterval::from(*v))), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected interval",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -317,9 +319,9 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Decimal128(v) => Ok(Some(v.to_string())), - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected decimal",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) @@ -331,23 +333,22 @@ fn encode_array( .map(|v| match v { Value::Null => Ok(None), Value::Binary(v) => { - let s = json_type_value_to_string(v, &j.format) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let s = json_type_value_to_string(v, &j.format).map_err(convert_err)?; Ok(Some(s)) } - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!("Invalid list item type, find {v:?}, expected json",), - }))), + })), }) .collect::>>>()?; builder.encode_field(&array) } - _ => Err(PgWireError::ApiError(Box::new(Error::Internal { + _ => Err(convert_err(Error::Internal { err_msg: format!( "cannot write array type {:?} in postgres protocol: unimplemented", value_list.datatype() ), - }))), + })), } } @@ -373,8 +374,7 @@ pub(super) fn encode_value( Value::String(v) => builder.encode_field(&v.as_utf8()), Value::Binary(v) => match datatype { ConcreteDataType::Json(j) => { - let s = json_type_value_to_string(v, &j.format) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let s = json_type_value_to_string(v, &j.format).map_err(convert_err)?; builder.encode_field(&s) } _ => { @@ -392,9 +392,9 @@ pub(super) fn encode_value( let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); builder.encode_field(&StylingDate(date, style, order)) } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(convert_err(Error::Internal { err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) + })) } } Value::Timestamp(v) => { @@ -403,18 +403,18 @@ pub(super) fn encode_value( let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); builder.encode_field(&StylingDateTime(datetime, style, order)) } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(convert_err(Error::Internal { err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) + })) } } Value::Time(v) => { if let Some(time) = v.to_chrono_time() { builder.encode_field(&time) } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(convert_err(Error::Internal { err_msg: format!("Failed to convert time to postgres type {v:?}",), - }))) + })) } } Value::IntervalYearMonth(v) => builder.encode_field(&PgInterval::from(*v)), @@ -423,9 +423,9 @@ pub(super) fn encode_value( Value::Decimal128(v) => builder.encode_field(&v.to_string()), Value::Duration(d) => match PgInterval::try_from(*d) { Ok(i) => builder.encode_field(&i), - Err(e) => Err(PgWireError::ApiError(Box::new(Error::Internal { + Err(e) => Err(convert_err(Error::Internal { err_msg: e.to_string(), - }))), + })), }, Value::List(values) => encode_array(query_ctx, values, builder), } @@ -591,7 +591,7 @@ where if let Some(n) = data { Value::Timestamp(unit.create_timestamp(n.into())) .try_to_scalar_value(ctype) - .map_err(|e| PgWireError::ApiError(Box::new(e))) + .map_err(convert_err) } else { Ok(ScalarValue::Null) } @@ -607,7 +607,8 @@ pub(super) fn parameters_to_scalar_values( let client_param_types = &portal.statement.parameter_types; let param_types = plan .get_parameter_types() - .map_err(|e| PgWireError::ApiError(Box::new(e)))? + .context(DataFusionSnafu) + .map_err(convert_err)? .into_iter() .map(|(k, v)| (k, v.map(|v| ConcreteDataType::from_arrow_type(&v)))) .collect::>(); @@ -620,7 +621,7 @@ pub(super) fn parameters_to_scalar_values( let client_type = if let Some(client_given_type) = client_param_types.get(idx) { client_given_type.clone() } else if let Some(server_provided_type) = &server_type { - type_gt_to_pg(server_provided_type).map_err(|e| PgWireError::ApiError(Box::new(e)))? + type_gt_to_pg(server_provided_type).map_err(convert_err)? } else { return Err(invalid_parameter_error( "unknown_parameter_type", diff --git a/src/servers/src/postgres/utils.rs b/src/servers/src/postgres/utils.rs new file mode 100644 index 0000000000..2c84bd8ef5 --- /dev/null +++ b/src/servers/src/postgres/utils.rs @@ -0,0 +1,36 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use common_error::ext::ErrorExt; +use common_telemetry::{debug, error}; +use pgwire::error::PgWireError; + +use crate::postgres::types::PgErrorCode; + +pub fn convert_err(e: impl ErrorExt) -> PgWireError { + let status_code = e.status_code(); + if status_code.should_log_error() { + let root_error = e.root_cause().unwrap_or(&e); + error!(e; "Failed to handle postgres query, code: {}, error: {}", status_code, root_error.to_string()); + } else { + debug!( + "Failed to handle postgres query, code: {}, error: {:?}", + status_code, e + ); + } + + PgWireError::UserError(Box::new( + PgErrorCode::from(status_code).to_err_info(e.output_msg()), + )) +}