From f293126315d2a19fb1b3ccf25406673a93fd0de7 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Tue, 11 Jul 2023 11:07:18 +0800 Subject: [PATCH] feat: add logical plan based prepare statement for postgresql (#1813) * feat: add logical plan based prepare statement for postgresql * refactor: correct more types * Update src/servers/src/postgres/types.rs Co-authored-by: LFC * fix: address review issues * test: add datetime in integration tests --------- Co-authored-by: LFC --- Cargo.lock | 10 +- src/query/src/parser.rs | 2 +- src/servers/Cargo.toml | 2 +- src/servers/src/lib.rs | 10 + src/servers/src/mysql/handler.rs | 23 +- src/servers/src/postgres.rs | 27 +- src/servers/src/postgres/handler.rs | 494 +++---------------- src/servers/src/postgres/server.rs | 27 +- src/servers/src/postgres/types.rs | 712 ++++++++++++++++++++++++++++ src/sql/src/statements/query.rs | 31 +- tests-integration/tests/sql.rs | 33 +- 11 files changed, 865 insertions(+), 506 deletions(-) create mode 100644 src/servers/src/postgres/types.rs diff --git a/Cargo.lock b/Cargo.lock index 59daaa6ea8..f769e00b50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6408,9 +6408,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.14.1" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd92c65406efd0d621cdece478a41a89e472a559e44a6f2b218df4c14e66a888" +checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc" dependencies = [ "async-trait", "base64 0.21.2", @@ -11252,16 +11252,16 @@ dependencies = [ [[package]] name = "x509-certificate" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf14059fbc1dce14de1d08535c411ba0b18749c2550a12550300da90b7ba350b" +checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05" dependencies = [ "bcder", "bytes", "chrono", "der 0.7.6", "hex", - "pem 1.1.1", + "pem 2.0.1", "ring", "signature", "spki 0.7.2", diff --git a/src/query/src/parser.rs b/src/query/src/parser.rs index df962efc72..b2bc9b6580 100644 --- a/src/query/src/parser.rs +++ b/src/query/src/parser.rs @@ -276,7 +276,7 @@ mod test { having: None, \ named_window: [], \ qualify: None \ - }), order_by: [], limit: None, offset: None, fetch: None, locks: [] }, param_types: [] }))"); + }), order_by: [], limit: None, offset: None, fetch: None, locks: [] } }))"); assert_eq!(format!("{stmt:?}"), expected); } diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index e4d1f8337b..e076940aa4 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -55,7 +55,7 @@ once_cell = "1.16" openmetrics-parser = "0.4" opensrv-mysql = "0.4" parking_lot = "0.12" -pgwire = "0.14.1" +pgwire = "0.15" pin-project = "1.0" postgres-types = { version = "0.2", features = ["with-chrono-0_4"] } promql-parser = "0.1.1" diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 9f7872b485..3546a70e55 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -16,6 +16,8 @@ #![feature(try_blocks)] use common_catalog::consts::DEFAULT_CATALOG_NAME; +use datatypes::schema::Schema; +use query::plan::LogicalPlan; use serde::{Deserialize, Serialize}; pub mod auth; @@ -72,6 +74,14 @@ pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &s } } +/// Cached SQL and logical plan for database interfaces +#[derive(Clone)] +pub struct SqlPlan { + query: String, + plan: Option, + schema: Option, +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index ddd53d64fb..7105ee7684 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -48,13 +48,7 @@ use crate::mysql::helper::{ use crate::mysql::writer; use crate::mysql::writer::create_mysql_column; use crate::query_handler::sql::ServerSqlQueryHandlerRef; - -/// Cached SQL and logical plan -#[derive(Clone)] -struct SqlPlan { - query: String, - plan: Option, -} +use crate::SqlPlan; // An intermediate shim for executing MySQL queries. pub struct MysqlInstanceShim { @@ -214,10 +208,16 @@ impl AsyncMysqlShim for MysqlInstanceShi // in the form of "$i", it can't process "?" right now. let statement = transform_placeholders(statement); - let plan = self - .do_describe(statement.clone()) - .await? - .map(|DescribeResult { logical_plan, .. }| logical_plan); + let describe_result = self.do_describe(statement.clone()).await?; + let (plan, schema) = if let Some(DescribeResult { + logical_plan, + schema, + }) = describe_result + { + (Some(logical_plan), Some(schema)) + } else { + (None, None) + }; let params = if let Some(plan) = &plan { prepared_params( @@ -234,6 +234,7 @@ impl AsyncMysqlShim for MysqlInstanceShi let stmt_id = self.save_plan(SqlPlan { query: query.to_string(), plan, + schema, }); w.reply(stmt_id, ¶ms, &[]).await?; diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 6db0d1a93c..af09b4e7c7 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -15,6 +15,7 @@ mod auth_handler; mod handler; mod server; +mod types; pub(crate) const METADATA_USER: &str = "user"; pub(crate) const METADATA_DATABASE: &str = "database"; @@ -24,21 +25,22 @@ pub(crate) const METADATA_CATALOG: &str = "catalog"; pub(crate) const METADATA_SCHEMA: &str = "schema"; use std::collections::HashMap; +use std::net::SocketAddr; use std::sync::Arc; use derive_builder::Builder; use pgwire::api::auth::ServerParameterProvider; use pgwire::api::store::MemPortalStore; -use pgwire::api::{ClientInfo, MakeHandler}; +use pgwire::api::ClientInfo; pub use server::PostgresServer; use session::context::Channel; use session::Session; -use sql::statements::statement::Statement; use self::auth_handler::PgLoginVerifier; -use self::handler::POCQueryParser; +use self::handler::DefaultQueryParser; use crate::auth::UserProviderRef; use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::SqlPlan; pub(crate) struct GreptimeDBStartupParameters { version: &'static str, @@ -73,9 +75,9 @@ pub struct PostgresServerHandler { force_tls: bool, param_provider: Arc, - session: Session, - portal_store: Arc>, - query_parser: Arc, + session: Arc, + portal_store: Arc>, + query_parser: Arc, } #[derive(Builder)] @@ -84,24 +86,21 @@ pub(crate) struct MakePostgresServerHandler { user_provider: Option, #[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")] param_provider: Arc, - #[builder(default = "Arc::new(POCQueryParser::default())")] - query_parser: Arc, force_tls: bool, } -impl MakeHandler for MakePostgresServerHandler { - type Handler = PostgresServerHandler; - - fn make(&self) -> Self::Handler { +impl MakePostgresServerHandler { + fn make(&self, addr: Option) -> PostgresServerHandler { + let session = Arc::new(Session::new(addr, Channel::Postgres)); PostgresServerHandler { query_handler: self.query_handler.clone(), login_verifier: PgLoginVerifier::new(self.user_provider.clone()), force_tls: self.force_tls, param_provider: self.param_provider.clone(), - session: Session::new(None, Channel::Postgres), + session: session.clone(), portal_store: Arc::new(MemPortalStore::new()), - query_parser: self.query_parser.clone(), + query_parser: Arc::new(DefaultQueryParser::new(self.query_handler.clone(), session)), } } } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 035c2557eb..5bdaa63bd7 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::ops::Deref; use std::sync::Arc; use async_trait::async_trait; @@ -20,26 +19,26 @@ use common_query::Output; use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::RecordBatch; use common_telemetry::timer; -use datatypes::prelude::{ConcreteDataType, Value}; -use datatypes::schema::{Schema, SchemaRef}; +use datatypes::schema::SchemaRef; use futures::{future, stream, Stream, StreamExt}; use metrics::increment_counter; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal}; -use pgwire::api::results::{ - DataRowEncoder, DescribeResponse, FieldInfo, QueryResponse, Response, Tag, -}; +use pgwire::api::results::{DataRowEncoder, DescribeResponse, QueryResponse, Response, Tag}; use pgwire::api::stmt::QueryParser; use pgwire::api::store::MemPortalStore; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use query::query_engine::DescribeResult; +use session::Session; use sql::dialect::PostgreSqlDialect; use sql::parser::ParserContext; -use sql::statements::statement::Statement; +use super::types::*; use super::PostgresServerHandler; -use crate::error::{self, Error, Result}; +use crate::error::Result; +use crate::query_handler::sql::ServerSqlQueryHandlerRef; +use crate::SqlPlan; #[async_trait] impl SimpleQueryHandler for PostgresServerHandler { @@ -141,125 +140,25 @@ where ))) } -fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result> { - origin - .column_schemas() - .iter() - .enumerate() - .map(|(idx, col)| { - Ok(FieldInfo::new( - col.name.clone(), - None, - None, - type_gt_to_pg(&col.data_type)?, - field_formats.format_for(idx), - )) - }) - .collect::>>() +pub struct DefaultQueryParser { + query_handler: ServerSqlQueryHandlerRef, + session: Arc, } -fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> { - match value { - Value::Null => builder.encode_field(&None::<&i8>), - Value::Boolean(v) => builder.encode_field(v), - Value::UInt8(v) => builder.encode_field(&(*v as i8)), - Value::UInt16(v) => builder.encode_field(&(*v as i16)), - Value::UInt32(v) => builder.encode_field(v), - Value::UInt64(v) => builder.encode_field(&(*v as i64)), - Value::Int8(v) => builder.encode_field(v), - Value::Int16(v) => builder.encode_field(v), - Value::Int32(v) => builder.encode_field(v), - Value::Int64(v) => builder.encode_field(v), - Value::Float32(v) => builder.encode_field(&v.0), - Value::Float64(v) => builder.encode_field(&v.0), - Value::String(v) => builder.encode_field(&v.as_utf8()), - Value::Binary(v) => builder.encode_field(&v.deref()), - Value::Date(v) => { - if let Some(date) = v.to_chrono_date() { - builder.encode_field(&date) - } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { - err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) - } +impl DefaultQueryParser { + pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc) -> Self { + DefaultQueryParser { + query_handler, + session, } - Value::DateTime(v) => { - if let Some(datetime) = v.to_chrono_datetime() { - builder.encode_field(&datetime) - } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { - err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) - } - } - Value::Timestamp(v) => { - if let Some(datetime) = v.to_chrono_datetime() { - builder.encode_field(&datetime) - } else { - Err(PgWireError::ApiError(Box::new(Error::Internal { - err_msg: format!("Failed to convert date to postgres type {v:?}",), - }))) - } - } - Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal { - err_msg: format!( - "cannot write value {:?} in postgres protocol: unimplemented", - &value - ), - }))), } } -fn type_gt_to_pg(origin: &ConcreteDataType) -> Result { - match origin { - &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN), - &ConcreteDataType::Boolean(_) => Ok(Type::BOOL), - &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR), - &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2), - &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4), - &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8), - &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4), - &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8), - &ConcreteDataType::Binary(_) => Ok(Type::BYTEA), - &ConcreteDataType::String(_) => Ok(Type::VARCHAR), - &ConcreteDataType::Date(_) => Ok(Type::DATE), - &ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP), - &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP), - &ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu { - err_msg: format!("not implemented for column datatype {origin:?}"), - } - .fail(), - } -} +#[async_trait] +impl QueryParser for DefaultQueryParser { + type Statement = SqlPlan; -fn type_pg_to_gt(origin: &Type) -> Result { - // Note that we only support a small amount of pg data types - match origin { - &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()), - &Type::CHAR => Ok(ConcreteDataType::int8_datatype()), - &Type::INT2 => Ok(ConcreteDataType::int16_datatype()), - &Type::INT4 => Ok(ConcreteDataType::int32_datatype()), - &Type::INT8 => Ok(ConcreteDataType::int64_datatype()), - &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()), - &Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype( - common_time::timestamp::TimeUnit::Millisecond, - )), - &Type::DATE => Ok(ConcreteDataType::date_datatype()), - &Type::TIME => Ok(ConcreteDataType::datetime_datatype()), - _ => error::InternalSnafu { - err_msg: format!("unimplemented datatype {origin:?}"), - } - .fail(), - } -} - -#[derive(Default)] -pub struct POCQueryParser; - -impl QueryParser for POCQueryParser { - type Statement = (Statement, String); - - fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult { + async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult { increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT); let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {}) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; @@ -270,70 +169,36 @@ impl QueryParser for POCQueryParser { "invalid_prepared_statement_definition".to_owned(), )))) } else { - let mut stmt = stmts.remove(0); - if let Statement::Query(qs) = &mut stmt { - for t in types { - let gt_type = - type_pg_to_gt(t).map_err(|e| PgWireError::ApiError(Box::new(e)))?; - qs.param_types_mut().push(gt_type); - } - } + let stmt = stmts.remove(0); + let describe_result = self + .query_handler + .do_describe(stmt, self.session.context()) + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - Ok((stmt, sql.to_owned())) + let (plan, schema) = if let Some(DescribeResult { + logical_plan, + schema, + }) = describe_result + { + (Some(logical_plan), Some(schema)) + } else { + (None, None) + }; + + Ok(SqlPlan { + query: sql.to_owned(), + plan, + schema, + }) } } } -fn parameter_to_string(portal: &Portal<(Statement, String)>, idx: usize) -> PgWireResult { - // the index is managed from portal's parameters count so it's safe to - // unwrap here. - let param_type = portal.statement().parameter_types().get(idx).unwrap(); - match param_type { - &Type::VARCHAR | &Type::TEXT => Ok(format!( - "'{}'", - portal.parameter::(idx)?.as_deref().unwrap_or("") - )), - &Type::BOOL => Ok(portal - .parameter::(idx)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::INT4 => Ok(portal - .parameter::(idx)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::INT8 => Ok(portal - .parameter::(idx)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::FLOAT4 => Ok(portal - .parameter::(idx)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - &Type::FLOAT8 => Ok(portal - .parameter::(idx)? - .map(|v| v.to_string()) - .unwrap_or_else(|| "".to_owned())), - _ => Err(PgWireError::UserError(Box::new(ErrorInfo::new( - "ERROR".to_owned(), - "22023".to_owned(), - "unsupported_parameter_value".to_owned(), - )))), - } -} - -// TODO(sunng87): this is a proof-of-concept implementation of postgres extended -// query. We will choose better `Statement` for caching, a good statement type -// is easy to: -// -// - getting schema from -// - setting parameters in -// -// Datafusion's LogicalPlan is a good candidate for SELECT. But we need to -// confirm it's support for other SQL command like INSERT, UPDATE. #[async_trait] impl ExtendedQueryHandler for PostgresServerHandler { - type Statement = (Statement, String); - type QueryParser = POCQueryParser; + type Statement = SqlPlan; + type QueryParser = DefaultQueryParser; type PortalStore = MemPortalStore; fn portal_store(&self) -> Arc { @@ -366,20 +231,29 @@ impl ExtendedQueryHandler for PostgresServerHandler { ) ] ); - let (_, sql) = portal.statement().statement(); + let sql_plan = portal.statement().statement(); - // manually replace variables in prepared statement - // FIXME(sunng87) - let mut sql = sql.clone(); - for i in 0..portal.parameter_len() { - sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?); - } + let output = if let Some(plan) = &sql_plan.plan { + let plan = plan + .replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref()) + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + self.query_handler + .do_exec_plan(plan, self.session.context()) + .await + } else { + // manually replace variables in prepared statement when no + // logical_plan is generated. This happens when logical plan is not + // supported for certain statements. + let mut sql = sql_plan.query.clone(); + for i in 0..portal.parameter_len() { + sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?); + } - let output = self - .query_handler - .do_query(&sql, self.session.context()) - .await - .remove(0); + self.query_handler + .do_query(&sql, self.session.context()) + .await + .remove(0) + }; output_to_query_response(output, portal.result_column_format()) } @@ -392,9 +266,11 @@ impl ExtendedQueryHandler for PostgresServerHandler { where C: ClientInfo + Unpin + Send + Sync, { - let (param_types, stmt, format) = match target { + let (param_types, sql_plan, format) = match target { StatementOrPortal::Statement(stmt) => { let param_types = Some(stmt.parameter_types().clone()); + // TODO(sunng87): return server inferenced param_types if client + // not specified (param_types, stmt.statement(), &Format::UnifiedBinary) } StatementOrPortal::Portal(portal) => ( @@ -403,16 +279,9 @@ impl ExtendedQueryHandler for PostgresServerHandler { portal.result_column_format(), ), }; - // get Statement part of the tuple - let (stmt, _) = stmt; - if let Some(DescribeResult { schema, .. }) = self - .query_handler - .do_describe(stmt.clone(), self.session.context()) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))? - { - schema_to_pg(&schema, format) + if let Some(schema) = &sql_plan.schema { + schema_to_pg(schema, format) .map(|fields| DescribeResponse::new(param_types, fields)) .map_err(|e| PgWireError::ApiError(Box::new(e))) } else { @@ -420,230 +289,3 @@ impl ExtendedQueryHandler for PostgresServerHandler { } } } - -#[cfg(test)] -mod test { - use datatypes::schema::{ColumnSchema, Schema}; - use datatypes::value::ListValue; - use pgwire::api::results::{FieldFormat, FieldInfo}; - use pgwire::api::Type; - - use super::*; - - #[test] - fn test_schema_convert() { - let column_schemas = vec![ - ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true), - ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true), - ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true), - ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true), - ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true), - ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true), - ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true), - ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true), - ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true), - ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true), - ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true), - ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true), - ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true), - ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), - ColumnSchema::new( - "timestamps", - ConcreteDataType::timestamp_millisecond_datatype(), - true, - ), - ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true), - ]; - let pg_field_info = vec![ - FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), - FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text), - FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new( - "float32s".into(), - None, - None, - Type::FLOAT4, - FieldFormat::Text, - ), - FieldInfo::new( - "float64s".into(), - None, - None, - Type::FLOAT8, - FieldFormat::Text, - ), - FieldInfo::new( - "binaries".into(), - None, - None, - Type::BYTEA, - FieldFormat::Text, - ), - FieldInfo::new( - "strings".into(), - None, - None, - Type::VARCHAR, - FieldFormat::Text, - ), - FieldInfo::new( - "timestamps".into(), - None, - None, - Type::TIMESTAMP, - FieldFormat::Text, - ), - FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text), - ]; - let schema = Schema::new(column_schemas); - let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap(); - assert_eq!(fs, pg_field_info); - } - - #[test] - fn test_encode_text_format_data() { - let schema = vec![ - FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), - FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text), - FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), - FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), - FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), - FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), - FieldInfo::new( - "float32s".into(), - None, - None, - Type::FLOAT4, - FieldFormat::Text, - ), - FieldInfo::new( - "float32s".into(), - None, - None, - Type::FLOAT4, - FieldFormat::Text, - ), - FieldInfo::new( - "float32s".into(), - None, - None, - Type::FLOAT4, - FieldFormat::Text, - ), - FieldInfo::new( - "float64s".into(), - None, - None, - Type::FLOAT8, - FieldFormat::Text, - ), - FieldInfo::new( - "float64s".into(), - None, - None, - Type::FLOAT8, - FieldFormat::Text, - ), - FieldInfo::new( - "float64s".into(), - None, - None, - Type::FLOAT8, - FieldFormat::Text, - ), - FieldInfo::new( - "strings".into(), - None, - None, - Type::VARCHAR, - FieldFormat::Text, - ), - FieldInfo::new( - "binaries".into(), - None, - None, - Type::BYTEA, - FieldFormat::Text, - ), - FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text), - FieldInfo::new( - "datetimes".into(), - None, - None, - Type::TIMESTAMP, - FieldFormat::Text, - ), - FieldInfo::new( - "timestamps".into(), - None, - None, - Type::TIMESTAMP, - FieldFormat::Text, - ), - ]; - - let values = vec![ - Value::Null, - Value::Boolean(true), - Value::UInt8(u8::MAX), - Value::UInt16(u16::MAX), - Value::UInt32(u32::MAX), - Value::UInt64(u64::MAX), - Value::Int8(i8::MAX), - Value::Int8(i8::MIN), - Value::Int16(i16::MAX), - Value::Int16(i16::MIN), - Value::Int32(i32::MAX), - Value::Int32(i32::MIN), - Value::Int64(i64::MAX), - Value::Int64(i64::MIN), - Value::Float32(f32::MAX.into()), - Value::Float32(f32::MIN.into()), - Value::Float32(0f32.into()), - Value::Float64(f64::MAX.into()), - Value::Float64(f64::MIN.into()), - Value::Float64(0f64.into()), - Value::String("greptime".into()), - Value::Binary("greptime".as_bytes().into()), - Value::Date(1001i32.into()), - Value::DateTime(1000001i64.into()), - Value::Timestamp(1000001i64.into()), - ]; - let mut builder = DataRowEncoder::new(Arc::new(schema)); - for i in values.iter() { - encode_value(i, &mut builder).unwrap(); - } - - let err = encode_value( - &Value::List(ListValue::new( - Some(Box::default()), - ConcreteDataType::int16_datatype(), - )), - &mut builder, - ) - .unwrap_err(); - match err { - PgWireError::ApiError(e) => { - assert!(format!("{e}").contains("Internal error:")); - } - _ => { - unreachable!() - } - } - } -} diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 8b6c36c95d..cc83ece444 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -22,7 +22,6 @@ use common_telemetry::logging::error; use common_telemetry::{debug, warn}; use futures::StreamExt; use metrics::{decrement_gauge, increment_gauge}; -use pgwire::api::MakeHandler; use pgwire::tokio::process_socket; use tokio; use tokio_rustls::TlsAcceptor; @@ -69,32 +68,36 @@ impl PostgresServer { accepting_stream: AbortableStream, tls_acceptor: Option>, ) -> impl Future { - let handler = self.make_handler.clone(); + let handler_maker = self.make_handler.clone(); accepting_stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); let tls_acceptor = tls_acceptor.clone(); - let mut handler = handler.make(); + let handler_maker = handler_maker.clone(); + async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { - match io_stream.peer_addr() { + let addr = match io_stream.peer_addr() { Ok(addr) => { - handler.session.mut_conn_info().client_addr = Some(addr); - debug!("PostgreSQL client coming from {}", addr) + debug!("PostgreSQL client coming from {}", addr); + Some(addr) } - Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e), - } + Err(e) => { + warn!("Failed to get PostgreSQL client addr, err: {}", e); + None + } + }; let _handle = io_runtime.spawn(async move { increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0); - let handler = Arc::new(handler); + let pg_handler = Arc::new(handler_maker.make(addr)); let r = process_socket( io_stream, tls_acceptor.clone(), - handler.clone(), - handler.clone(), - handler, + pg_handler.clone(), + pg_handler.clone(), + pg_handler, ) .await; decrement_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0); diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs new file mode 100644 index 0000000000..1e355b7275 --- /dev/null +++ b/src/servers/src/postgres/types.rs @@ -0,0 +1,712 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Deref; + +use chrono::{NaiveDate, NaiveDateTime}; +use datafusion_common::ScalarValue; +use datatypes::prelude::{ConcreteDataType, Value}; +use datatypes::schema::Schema; +use datatypes::types::TimestampType; +use pgwire::api::portal::{Format, Portal}; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; +use pgwire::api::Type; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use query::plan::LogicalPlan; + +use crate::error::{self, Error, Result}; +use crate::SqlPlan; + +pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result> { + origin + .column_schemas() + .iter() + .enumerate() + .map(|(idx, col)| { + Ok(FieldInfo::new( + col.name.clone(), + None, + None, + type_gt_to_pg(&col.data_type)?, + field_formats.format_for(idx), + )) + }) + .collect::>>() +} + +pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> { + match value { + Value::Null => builder.encode_field(&None::<&i8>), + Value::Boolean(v) => builder.encode_field(v), + Value::UInt8(v) => builder.encode_field(&(*v as i8)), + Value::UInt16(v) => builder.encode_field(&(*v as i16)), + Value::UInt32(v) => builder.encode_field(v), + Value::UInt64(v) => builder.encode_field(&(*v as i64)), + Value::Int8(v) => builder.encode_field(v), + Value::Int16(v) => builder.encode_field(v), + Value::Int32(v) => builder.encode_field(v), + Value::Int64(v) => builder.encode_field(v), + Value::Float32(v) => builder.encode_field(&v.0), + Value::Float64(v) => builder.encode_field(&v.0), + Value::String(v) => builder.encode_field(&v.as_utf8()), + Value::Binary(v) => builder.encode_field(&v.deref()), + Value::Date(v) => { + if let Some(date) = v.to_chrono_date() { + builder.encode_field(&date) + } else { + Err(PgWireError::ApiError(Box::new(Error::Internal { + err_msg: format!("Failed to convert date to postgres type {v:?}",), + }))) + } + } + Value::DateTime(v) => { + if let Some(datetime) = v.to_chrono_datetime() { + builder.encode_field(&datetime) + } else { + Err(PgWireError::ApiError(Box::new(Error::Internal { + err_msg: format!("Failed to convert date to postgres type {v:?}",), + }))) + } + } + Value::Timestamp(v) => { + if let Some(datetime) = v.to_chrono_datetime() { + builder.encode_field(&datetime) + } else { + Err(PgWireError::ApiError(Box::new(Error::Internal { + err_msg: format!("Failed to convert date to postgres type {v:?}",), + }))) + } + } + Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal { + err_msg: format!( + "cannot write value {:?} in postgres protocol: unimplemented", + &value + ), + }))), + } +} + +pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result { + match origin { + &ConcreteDataType::Null(_) => Ok(Type::UNKNOWN), + &ConcreteDataType::Boolean(_) => Ok(Type::BOOL), + &ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR), + &ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2), + &ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4), + &ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8), + &ConcreteDataType::Float32(_) => Ok(Type::FLOAT4), + &ConcreteDataType::Float64(_) => Ok(Type::FLOAT8), + &ConcreteDataType::Binary(_) => Ok(Type::BYTEA), + &ConcreteDataType::String(_) => Ok(Type::VARCHAR), + &ConcreteDataType::Date(_) => Ok(Type::DATE), + &ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP), + &ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP), + &ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu { + err_msg: format!("not implemented for column datatype {origin:?}"), + } + .fail(), + } +} + +#[allow(dead_code)] +pub(super) fn type_pg_to_gt(origin: &Type) -> Result { + // Note that we only support a small amount of pg data types + match origin { + &Type::BOOL => Ok(ConcreteDataType::boolean_datatype()), + &Type::CHAR => Ok(ConcreteDataType::int8_datatype()), + &Type::INT2 => Ok(ConcreteDataType::int16_datatype()), + &Type::INT4 => Ok(ConcreteDataType::int32_datatype()), + &Type::INT8 => Ok(ConcreteDataType::int64_datatype()), + &Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()), + &Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype( + common_time::timestamp::TimeUnit::Millisecond, + )), + &Type::DATE => Ok(ConcreteDataType::date_datatype()), + &Type::TIME => Ok(ConcreteDataType::datetime_datatype()), + _ => error::InternalSnafu { + err_msg: format!("unimplemented datatype {origin:?}"), + } + .fail(), + } +} + +pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWireResult { + // the index is managed from portal's parameters count so it's safe to + // unwrap here. + let param_type = portal.statement().parameter_types().get(idx).unwrap(); + match param_type { + &Type::VARCHAR | &Type::TEXT => Ok(format!( + "'{}'", + portal.parameter::(idx)?.as_deref().unwrap_or("") + )), + &Type::BOOL => Ok(portal + .parameter::(idx)? + .map(|v| v.to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::INT4 => Ok(portal + .parameter::(idx)? + .map(|v| v.to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::INT8 => Ok(portal + .parameter::(idx)? + .map(|v| v.to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::FLOAT4 => Ok(portal + .parameter::(idx)? + .map(|v| v.to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::FLOAT8 => Ok(portal + .parameter::(idx)? + .map(|v| v.to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::DATE => Ok(portal + .parameter::(idx)? + .map(|v| v.format("%Y-%m-%d").to_string()) + .unwrap_or_else(|| "".to_owned())), + &Type::TIMESTAMP => Ok(portal + .parameter::(idx)? + .map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string()) + .unwrap_or_else(|| "".to_owned())), + _ => Err(invalid_parameter_error( + "unsupported_parameter_type", + Some(¶m_type.to_string()), + )), + } +} + +pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError { + let mut error_info = ErrorInfo::new("ERROR".to_owned(), "22023".to_owned(), msg.to_owned()); + error_info.set_detail(detail.map(|s| s.to_owned())); + PgWireError::UserError(Box::new(error_info)) +} + +fn to_timestamp_scalar_value( + data: Option, + unit: &TimestampType, + ctype: &ConcreteDataType, +) -> PgWireResult +where + T: Into, +{ + if let Some(n) = data { + Value::Timestamp(unit.create_timestamp(n.into())) + .try_to_scalar_value(ctype) + .map_err(|e| PgWireError::ApiError(Box::new(e))) + } else { + Ok(ScalarValue::Null) + } +} + +pub(super) fn parameters_to_scalar_values( + plan: &LogicalPlan, + portal: &Portal, +) -> PgWireResult> { + let param_count = portal.parameter_len(); + let mut results = Vec::with_capacity(param_count); + + let client_param_types = portal.statement().parameter_types(); + let param_types = plan + .get_param_types() + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + // ensure parameter count consistent for: client parameter types, server + // parameter types and parameter count + if param_types.len() != param_count { + return Err(invalid_parameter_error( + "invalid_parameter_count", + Some(&format!( + "Expected: {}, found: {}", + param_types.len(), + param_count + )), + )); + } + if client_param_types.len() != param_count { + return Err(invalid_parameter_error( + "invalid_parameter_count", + Some(&format!( + "Expected: {}, found: {}", + client_param_types.len(), + param_count + )), + )); + } + + for (idx, client_type) in client_param_types.iter().enumerate() { + let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else { continue }; + let value = match client_type { + &Type::VARCHAR | &Type::TEXT => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::String(_) => ScalarValue::Utf8(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::BOOL => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::INT2 => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::INT4 => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::INT8 => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Timestamp(unit) => { + to_timestamp_scalar_value(data, unit, server_type)? + } + ConcreteDataType::DateTime(_) => ScalarValue::Date64(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::FLOAT4 => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Float32(_) => ScalarValue::Float32(data), + ConcreteDataType::Float64(_) => ScalarValue::Float64(data.map(|n| n as f64)), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::FLOAT8 => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)), + ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)), + ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)), + ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)), + ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)), + ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)), + ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)), + ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)), + ConcreteDataType::Float32(_) => ScalarValue::Float32(data.map(|n| n as f32)), + ConcreteDataType::Float64(_) => ScalarValue::Float64(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::TIMESTAMP => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Timestamp(unit) => match *unit { + TimestampType::Second(_) => { + ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None) + } + TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond( + data.map(|ts| ts.timestamp_millis()), + None, + ), + TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond( + data.map(|ts| ts.timestamp_micros()), + None, + ), + TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond( + data.map(|ts| ts.timestamp_micros()), + None, + ), + }, + ConcreteDataType::DateTime(_) => { + ScalarValue::Date64(data.map(|d| d.timestamp_millis())) + } + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )) + } + } + } + &Type::DATE => { + let data = portal.parameter::(idx)?; + match server_type { + ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| { + (d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32 + })), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )); + } + } + } + &Type::BYTEA => { + let data = portal.parameter::>(idx)?; + match server_type { + ConcreteDataType::String(_) => { + ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string())) + } + ConcreteDataType::Binary(_) => ScalarValue::Binary(data), + _ => { + return Err(invalid_parameter_error( + "invalid_parameter_type", + Some(&format!( + "Expected: {}, found: {}", + server_type, client_type + )), + )); + } + } + } + _ => Err(invalid_parameter_error( + "unsupported_parameter_value", + Some(&format!("Found type: {}", client_type)), + ))?, + }; + results.push(value); + } + + Ok(results) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use datatypes::schema::{ColumnSchema, Schema}; + use datatypes::value::ListValue; + use pgwire::api::results::{FieldFormat, FieldInfo}; + use pgwire::api::Type; + + use super::*; + + #[test] + fn test_schema_convert() { + let column_schemas = vec![ + ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true), + ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true), + ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true), + ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true), + ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true), + ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true), + ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true), + ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true), + ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true), + ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true), + ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true), + ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true), + ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ColumnSchema::new( + "timestamps", + ConcreteDataType::timestamp_millisecond_datatype(), + true, + ), + ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true), + ]; + let pg_field_info = vec![ + FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), + FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text), + FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), + FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), + FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new( + "float32s".into(), + None, + None, + Type::FLOAT4, + FieldFormat::Text, + ), + FieldInfo::new( + "float64s".into(), + None, + None, + Type::FLOAT8, + FieldFormat::Text, + ), + FieldInfo::new( + "binaries".into(), + None, + None, + Type::BYTEA, + FieldFormat::Text, + ), + FieldInfo::new( + "strings".into(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + ), + FieldInfo::new( + "timestamps".into(), + None, + None, + Type::TIMESTAMP, + FieldFormat::Text, + ), + FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text), + ]; + let schema = Schema::new(column_schemas); + let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap(); + assert_eq!(fs, pg_field_info); + } + + #[test] + fn test_encode_text_format_data() { + let schema = vec![ + FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), + FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text), + FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text), + FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), + FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text), + FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text), + FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text), + FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text), + FieldInfo::new( + "float32s".into(), + None, + None, + Type::FLOAT4, + FieldFormat::Text, + ), + FieldInfo::new( + "float32s".into(), + None, + None, + Type::FLOAT4, + FieldFormat::Text, + ), + FieldInfo::new( + "float32s".into(), + None, + None, + Type::FLOAT4, + FieldFormat::Text, + ), + FieldInfo::new( + "float64s".into(), + None, + None, + Type::FLOAT8, + FieldFormat::Text, + ), + FieldInfo::new( + "float64s".into(), + None, + None, + Type::FLOAT8, + FieldFormat::Text, + ), + FieldInfo::new( + "float64s".into(), + None, + None, + Type::FLOAT8, + FieldFormat::Text, + ), + FieldInfo::new( + "strings".into(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + ), + FieldInfo::new( + "binaries".into(), + None, + None, + Type::BYTEA, + FieldFormat::Text, + ), + FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text), + FieldInfo::new( + "datetimes".into(), + None, + None, + Type::TIMESTAMP, + FieldFormat::Text, + ), + FieldInfo::new( + "timestamps".into(), + None, + None, + Type::TIMESTAMP, + FieldFormat::Text, + ), + ]; + + let values = vec![ + Value::Null, + Value::Boolean(true), + Value::UInt8(u8::MAX), + Value::UInt16(u16::MAX), + Value::UInt32(u32::MAX), + Value::UInt64(u64::MAX), + Value::Int8(i8::MAX), + Value::Int8(i8::MIN), + Value::Int16(i16::MAX), + Value::Int16(i16::MIN), + Value::Int32(i32::MAX), + Value::Int32(i32::MIN), + Value::Int64(i64::MAX), + Value::Int64(i64::MIN), + Value::Float32(f32::MAX.into()), + Value::Float32(f32::MIN.into()), + Value::Float32(0f32.into()), + Value::Float64(f64::MAX.into()), + Value::Float64(f64::MIN.into()), + Value::Float64(0f64.into()), + Value::String("greptime".into()), + Value::Binary("greptime".as_bytes().into()), + Value::Date(1001i32.into()), + Value::DateTime(1000001i64.into()), + Value::Timestamp(1000001i64.into()), + ]; + let mut builder = DataRowEncoder::new(Arc::new(schema)); + for i in values.iter() { + encode_value(i, &mut builder).unwrap(); + } + + let err = encode_value( + &Value::List(ListValue::new( + Some(Box::default()), + ConcreteDataType::int16_datatype(), + )), + &mut builder, + ) + .unwrap_err(); + match err { + PgWireError::ApiError(e) => { + assert!(format!("{e}").contains("Internal error:")); + } + _ => { + unreachable!() + } + } + } +} diff --git a/src/sql/src/statements/query.rs b/src/sql/src/statements/query.rs index ef0e17665a..c2b720ce15 100644 --- a/src/sql/src/statements/query.rs +++ b/src/sql/src/statements/query.rs @@ -14,7 +14,6 @@ use std::fmt; -use datatypes::prelude::ConcreteDataType; use sqlparser::ast::Query as SpQuery; use crate::error::Error; @@ -23,7 +22,6 @@ use crate::error::Error; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Query { pub inner: SpQuery, - pub param_types: Vec, } /// Automatically converts from sqlparser Query instance to SqlQuery. @@ -31,10 +29,7 @@ impl TryFrom for Query { type Error = Error; fn try_from(q: SpQuery) -> Result { - Ok(Query { - inner: q, - param_types: vec![], - }) + Ok(Query { inner: q }) } } @@ -46,27 +41,9 @@ impl TryFrom for SpQuery { } } -impl Query { - pub fn param_types(&self) -> &Vec { - &self.param_types - } - - pub fn param_types_mut(&mut self) -> &mut Vec { - &mut self.param_types - } -} - impl fmt::Display for Query { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} ", self.inner)?; - write!(f, "[")?; - for i in 0..self.param_types.len() { - write!(f, "{}", self.param_types[i])?; - if i != self.param_types.len() - 1 { - write!(f, ",")?; - } - } - write!(f, "]")?; + write!(f, "{}", self.inner)?; Ok(()) } } @@ -95,7 +72,7 @@ mod test { create_query("select * from abc where x = 1 and y = 7") .unwrap() .to_string(), - "SELECT * FROM abc WHERE x = 1 AND y = 7 []" + "SELECT * FROM abc WHERE x = 1 AND y = 7" ); assert_eq!( create_query( @@ -103,7 +80,7 @@ mod test { ) .unwrap() .to_string(), - "SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id []" + "SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id" ); } } diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 78b1526363..1d1d2f824e 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -145,22 +145,26 @@ pub async fn test_postgres_crud(store_type: StorageType) { .await .unwrap(); - assert!( - sqlx::query("create table demo(i bigint, ts timestamp time index)") - .execute(&pool) - .await - .is_ok() - ); + sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)") + .execute(&pool) + .await + .unwrap(); + for i in 0..10 { - assert!(sqlx::query("insert into demo values($1, $2)") + let d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + let dt = d.and_hms_opt(0, 0, 0).unwrap().timestamp_millis(); + + sqlx::query("insert into demo values($1, $2, $3, $4)") .bind(i) .bind(i) + .bind(d) + .bind(dt) .execute(&pool) .await - .is_ok()); + .unwrap(); } - let rows = sqlx::query("select i from demo") + let rows = sqlx::query("select i,d,dt from demo") .fetch_all(&pool) .await .unwrap(); @@ -168,7 +172,18 @@ pub async fn test_postgres_crud(store_type: StorageType) { for (i, row) in rows.iter().enumerate() { let ret: i64 = row.get(0); + let d: NaiveDate = row.get(1); + let dt: NaiveDateTime = row.get(2); + assert_eq!(ret, i as i64); + + let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap(); + assert_eq!(expected_d, d); + + let expected_dt = NaiveDate::from_yo_opt(2015, 100) + .and_then(|d| d.and_hms_opt(0, 0, 0)) + .unwrap(); + assert_eq!(expected_dt, dt); } let rows = sqlx::query("select i from demo where i=$1")