diff --git a/Cargo.lock b/Cargo.lock index c89802d1f1..e7cd25c366 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9198,9 +9198,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.36.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5cc59678d0c10c73a552d465ce9156995189d1c678f2784dc817fe8623487f5" +checksum = "d331bb0eef5bc83a221c0a85b1f205bccf094d4f72a26ae1d68a1b1c535123b7" dependencies = [ "async-trait", "base64 0.22.1", diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 825bbe8e95..b5e6371785 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -87,7 +87,7 @@ operator.workspace = true otel-arrow-rust.workspace = true parking_lot.workspace = true pg_interval = "0.4" -pgwire = { version = "0.36", default-features = false, features = [ +pgwire = { version = "0.36.1", default-features = false, features = [ "server-api-ring", "pg-ext-types", ] } diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 3013aee4e5..40635cd036 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -56,7 +56,7 @@ pub mod server; pub mod tls; /// Cached SQL and logical plan for database interfaces -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct SqlPlan { query: String, // Store the parsed statement to determine if it is a query and whether to track it. diff --git a/src/servers/src/postgres/fixtures.rs b/src/servers/src/postgres/fixtures.rs index 3b56d99241..dcd7842c95 100644 --- a/src/servers/src/postgres/fixtures.rs +++ b/src/servers/src/postgres/fixtures.rs @@ -22,7 +22,7 @@ use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse use pgwire::error::PgWireResult; use pgwire::messages::data::DataRow; use regex::Regex; -use session::context::QueryContextRef; +use session::context::{QueryContext, QueryContextRef}; fn build_string_data_rows( schema: Arc>, @@ -60,11 +60,7 @@ static ABORT_TRANSACTION_PATTERN: Lazy = /// Test if given query statement matches the patterns pub(crate) fn matches(query: &str) -> bool { - START_TRANSACTION_PATTERN.is_match(query) - || COMMIT_TRANSACTION_PATTERN.is_match(query) - || ABORT_TRANSACTION_PATTERN.is_match(query) - || SHOW_PATTERN.captures(query).is_some() - || SET_TRANSACTION_PATTERN.is_match(query) + process(query, QueryContext::arc()).is_some() } fn set_transaction_warning(query_ctx: QueryContextRef) { diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 50535f2162..ab0ea32b5e 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -28,7 +28,7 @@ use futures::{Sink, SinkExt, Stream, StreamExt, future, stream}; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ - DescribePortalResponse, DescribeStatementResponse, QueryResponse, Response, Tag, + DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag, }; use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, ErrorHandler, Type}; @@ -40,6 +40,7 @@ use session::context::QueryContextRef; use snafu::ResultExt; use sql::dialect::PostgreSqlDialect; use sql::parser::{ParseOptions, ParserContext}; +use sql::statements::statement::Statement; use crate::SqlPlan; use crate::error::{DataFusionSnafu, Result}; @@ -412,21 +413,57 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { let sql_plan = &portal.statement.statement; let format = &portal.result_column_format; - if let Some(schema) = &sql_plan.schema { - schema_to_pg(schema, format) - .map(DescribePortalResponse::new) - .map_err(convert_err) - } else { - if let Some(mut resp) = - fixtures::process(&sql_plan.query, self.session.new_query_context()) - && let Response::Query(query_response) = resp.remove(0) - { - return Ok(DescribePortalResponse::new( - (*query_response.row_schema()).clone(), - )); + match sql_plan.statement.as_ref() { + Some(Statement::Query(_)) => { + // if the query has a schema, it is managed by datafusion, use the schema + if let Some(schema) = &sql_plan.schema { + schema_to_pg(schema, format) + .map(DescribePortalResponse::new) + .map_err(convert_err) + } else { + // fallback to NoData + Ok(DescribePortalResponse::new(vec![])) + } + } + // We can cover only part of show statements + // these show create statements will return 2 columns + Some(Statement::ShowCreateDatabase(_)) + | Some(Statement::ShowCreateTable(_)) + | Some(Statement::ShowCreateFlow(_)) + | Some(Statement::ShowCreateView(_)) => Ok(DescribePortalResponse::new(vec![ + FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + ), + FieldInfo::new( + "create_statement".to_string(), + None, + None, + Type::TEXT, + format.format_for(1), + ), + ])), + // single column show statements + Some(Statement::ShowTables(_)) + | Some(Statement::ShowFlows(_)) + | Some(Statement::ShowViews(_)) => { + Ok(DescribePortalResponse::new(vec![FieldInfo::new( + "name".to_string(), + None, + None, + Type::TEXT, + format.format_for(0), + )])) + } + // we will not support other show statements for extended query protocol at least for now. + // because the return columns is not predictable at this stage + _ => { + // fallback to NoData + Ok(DescribePortalResponse::new(vec![])) } - - Ok(DescribePortalResponse::new(vec![])) } } } diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index fce7646f21..323b4fb558 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -29,7 +29,7 @@ use arrow::datatypes::{ TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, }; use arrow_schema::{DataType, IntervalUnit, TimeUnit}; -use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; use common_decimal::Decimal128; use common_recordbatch::RecordBatch; use common_time::time::Time; @@ -717,7 +717,7 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { &Type::INT4 => Ok(ConcreteDataType::int32_datatype()), &Type::INT8 => Ok(ConcreteDataType::int64_datatype()), &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()), - &Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype( + &Type::TIMESTAMP | &Type::TIMESTAMPTZ => Ok(ConcreteDataType::timestamp_datatype( common_time::timestamp::TimeUnit::Millisecond, )), &Type::DATE => Ok(ConcreteDataType::date_datatype()), @@ -1050,7 +1050,7 @@ pub(super) fn parameters_to_scalar_values( None, ), TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond( - data.map(|ts| ts.and_utc().timestamp_micros()), + data.and_then(|ts| ts.and_utc().timestamp_nanos_opt()), None, ), }, @@ -1068,6 +1068,38 @@ pub(super) fn parameters_to_scalar_values( ) } } + &Type::TIMESTAMPTZ => { + let data = portal.parameter::>(idx, &client_type)?; + if let Some(server_type) = &server_type { + match server_type { + ConcreteDataType::Timestamp(unit) => match *unit { + TimestampType::Second(_) => { + ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None) + } + TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond( + data.map(|ts| ts.timestamp_millis()), + None, + ), + TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond( + data.map(|ts| ts.timestamp_micros()), + None, + ), + TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond( + data.and_then(|ts| ts.timestamp_nanos_opt()), + None, + ), + }, + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(format!("Expected: {}, found: {}", server_type, client_type)), + )); + } + } + } else { + ScalarValue::TimestampMillisecond(data.map(|ts| ts.timestamp_millis()), None) + } + } &Type::DATE => { let data = portal.parameter::(idx, &client_type)?; if let Some(server_type) = &server_type {