feat: add logical plan based prepare statement for postgresql (#1813)

* feat: add logical plan based prepare statement for postgresql

* refactor: correct more types

* Update src/servers/src/postgres/types.rs

Co-authored-by: LFC <bayinamine@gmail.com>

* fix: address review issues

* test: add datetime in integration tests

---------

Co-authored-by: LFC <bayinamine@gmail.com>
This commit is contained in:
Ning Sun
2023-07-11 11:07:18 +08:00
committed by GitHub
parent c615fb2a93
commit f293126315
11 changed files with 865 additions and 506 deletions

10
Cargo.lock generated
View File

@@ -6408,9 +6408,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.14.1"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd92c65406efd0d621cdece478a41a89e472a559e44a6f2b218df4c14e66a888"
checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc"
dependencies = [
"async-trait",
"base64 0.21.2",
@@ -11252,16 +11252,16 @@ dependencies = [
[[package]]
name = "x509-certificate"
version = "0.19.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf14059fbc1dce14de1d08535c411ba0b18749c2550a12550300da90b7ba350b"
checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05"
dependencies = [
"bcder",
"bytes",
"chrono",
"der 0.7.6",
"hex",
"pem 1.1.1",
"pem 2.0.1",
"ring",
"signature",
"spki 0.7.2",

View File

@@ -276,7 +276,7 @@ mod test {
having: None, \
named_window: [], \
qualify: None \
}), order_by: [], limit: None, offset: None, fetch: None, locks: [] }, param_types: [] }))");
}), order_by: [], limit: None, offset: None, fetch: None, locks: [] } }))");
assert_eq!(format!("{stmt:?}"), expected);
}

View File

@@ -55,7 +55,7 @@ once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.4"
parking_lot = "0.12"
pgwire = "0.14.1"
pgwire = "0.15"
pin-project = "1.0"
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
promql-parser = "0.1.1"

View File

@@ -16,6 +16,8 @@
#![feature(try_blocks)]
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use datatypes::schema::Schema;
use query::plan::LogicalPlan;
use serde::{Deserialize, Serialize};
pub mod auth;
@@ -72,6 +74,14 @@ pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &s
}
}
/// Cached SQL and logical plan for database interfaces
#[derive(Clone)]
pub struct SqlPlan {
query: String,
plan: Option<LogicalPlan>,
schema: Option<Schema>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -48,13 +48,7 @@ use crate::mysql::helper::{
use crate::mysql::writer;
use crate::mysql::writer::create_mysql_column;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
/// Cached SQL and logical plan
#[derive(Clone)]
struct SqlPlan {
query: String,
plan: Option<LogicalPlan>,
}
use crate::SqlPlan;
// An intermediate shim for executing MySQL queries.
pub struct MysqlInstanceShim {
@@ -214,10 +208,16 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
// in the form of "$i", it can't process "?" right now.
let statement = transform_placeholders(statement);
let plan = self
.do_describe(statement.clone())
.await?
.map(|DescribeResult { logical_plan, .. }| logical_plan);
let describe_result = self.do_describe(statement.clone()).await?;
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
}) = describe_result
{
(Some(logical_plan), Some(schema))
} else {
(None, None)
};
let params = if let Some(plan) = &plan {
prepared_params(
@@ -234,6 +234,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
let stmt_id = self.save_plan(SqlPlan {
query: query.to_string(),
plan,
schema,
});
w.reply(stmt_id, &params, &[]).await?;

View File

@@ -15,6 +15,7 @@
mod auth_handler;
mod handler;
mod server;
mod types;
pub(crate) const METADATA_USER: &str = "user";
pub(crate) const METADATA_DATABASE: &str = "database";
@@ -24,21 +25,22 @@ pub(crate) const METADATA_CATALOG: &str = "catalog";
pub(crate) const METADATA_SCHEMA: &str = "schema";
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use derive_builder::Builder;
use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, MakeHandler};
use pgwire::api::ClientInfo;
pub use server::PostgresServer;
use session::context::Channel;
use session::Session;
use sql::statements::statement::Statement;
use self::auth_handler::PgLoginVerifier;
use self::handler::POCQueryParser;
use self::handler::DefaultQueryParser;
use crate::auth::UserProviderRef;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::SqlPlan;
pub(crate) struct GreptimeDBStartupParameters {
version: &'static str,
@@ -73,9 +75,9 @@ pub struct PostgresServerHandler {
force_tls: bool,
param_provider: Arc<GreptimeDBStartupParameters>,
session: Session,
portal_store: Arc<MemPortalStore<(Statement, String)>>,
query_parser: Arc<POCQueryParser>,
session: Arc<Session>,
portal_store: Arc<MemPortalStore<SqlPlan>>,
query_parser: Arc<DefaultQueryParser>,
}
#[derive(Builder)]
@@ -84,24 +86,21 @@ pub(crate) struct MakePostgresServerHandler {
user_provider: Option<UserProviderRef>,
#[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")]
param_provider: Arc<GreptimeDBStartupParameters>,
#[builder(default = "Arc::new(POCQueryParser::default())")]
query_parser: Arc<POCQueryParser>,
force_tls: bool,
}
impl MakeHandler for MakePostgresServerHandler {
type Handler = PostgresServerHandler;
fn make(&self) -> Self::Handler {
impl MakePostgresServerHandler {
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
let session = Arc::new(Session::new(addr, Channel::Postgres));
PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
force_tls: self.force_tls,
param_provider: self.param_provider.clone(),
session: Session::new(None, Channel::Postgres),
session: session.clone(),
portal_store: Arc::new(MemPortalStore::new()),
query_parser: self.query_parser.clone(),
query_parser: Arc::new(DefaultQueryParser::new(self.query_handler.clone(), session)),
}
}
}

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
@@ -20,26 +19,26 @@ use common_query::Output;
use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::RecordBatch;
use common_telemetry::timer;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::{Schema, SchemaRef};
use datatypes::schema::SchemaRef;
use futures::{future, stream, Stream, StreamExt};
use metrics::increment_counter;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal};
use pgwire::api::results::{
DataRowEncoder, DescribeResponse, FieldInfo, QueryResponse, Response, Tag,
};
use pgwire::api::results::{DataRowEncoder, DescribeResponse, QueryResponse, Response, Tag};
use pgwire::api::stmt::QueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::query_engine::DescribeResult;
use session::Session;
use sql::dialect::PostgreSqlDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use super::types::*;
use super::PostgresServerHandler;
use crate::error::{self, Error, Result};
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::SqlPlan;
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandler {
@@ -141,125 +140,25 @@ where
)))
}
fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
origin
.column_schemas()
.iter()
.enumerate()
.map(|(idx, col)| {
Ok(FieldInfo::new(
col.name.clone(),
None,
None,
type_gt_to_pg(&col.data_type)?,
field_formats.format_for(idx),
))
})
.collect::<Result<Vec<FieldInfo>>>()
pub struct DefaultQueryParser {
query_handler: ServerSqlQueryHandlerRef,
session: Arc<Session>,
}
fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
match value {
Value::Null => builder.encode_field(&None::<&i8>),
Value::Boolean(v) => builder.encode_field(v),
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
Value::UInt32(v) => builder.encode_field(v),
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
Value::Int8(v) => builder.encode_field(v),
Value::Int16(v) => builder.encode_field(v),
Value::Int32(v) => builder.encode_field(v),
Value::Int64(v) => builder.encode_field(v),
Value::Float32(v) => builder.encode_field(&v.0),
Value::Float64(v) => builder.encode_field(&v.0),
Value::String(v) => builder.encode_field(&v.as_utf8()),
Value::Binary(v) => builder.encode_field(&v.deref()),
Value::Date(v) => {
if let Some(date) = v.to_chrono_date() {
builder.encode_field(&date)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
impl DefaultQueryParser {
pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
DefaultQueryParser {
query_handler,
session,
}
Value::DateTime(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::Timestamp(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
&value
),
}))),
}
}
fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
match origin {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
&ConcreteDataType::Date(_) => Ok(Type::DATE),
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu {
err_msg: format!("not implemented for column datatype {origin:?}"),
}
.fail(),
}
}
#[async_trait]
impl QueryParser for DefaultQueryParser {
type Statement = SqlPlan;
fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
// Note that we only support a small amount of pg data types
match origin {
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
&Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
common_time::timestamp::TimeUnit::Millisecond,
)),
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
_ => error::InternalSnafu {
err_msg: format!("unimplemented datatype {origin:?}"),
}
.fail(),
}
}
#[derive(Default)]
pub struct POCQueryParser;
impl QueryParser for POCQueryParser {
type Statement = (Statement, String);
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -270,70 +169,36 @@ impl QueryParser for POCQueryParser {
"invalid_prepared_statement_definition".to_owned(),
))))
} else {
let mut stmt = stmts.remove(0);
if let Statement::Query(qs) = &mut stmt {
for t in types {
let gt_type =
type_pg_to_gt(t).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
qs.param_types_mut().push(gt_type);
}
}
let stmt = stmts.remove(0);
let describe_result = self
.query_handler
.do_describe(stmt, self.session.context())
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok((stmt, sql.to_owned()))
let (plan, schema) = if let Some(DescribeResult {
logical_plan,
schema,
}) = describe_result
{
(Some(logical_plan), Some(schema))
} else {
(None, None)
};
Ok(SqlPlan {
query: sql.to_owned(),
plan,
schema,
})
}
}
}
fn parameter_to_string(portal: &Portal<(Statement, String)>, idx: usize) -> PgWireResult<String> {
// the index is managed from portal's parameters count so it's safe to
// unwrap here.
let param_type = portal.statement().parameter_types().get(idx).unwrap();
match param_type {
&Type::VARCHAR | &Type::TEXT => Ok(format!(
"'{}'",
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
)),
&Type::BOOL => Ok(portal
.parameter::<bool>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT4 => Ok(portal
.parameter::<i32>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT8 => Ok(portal
.parameter::<i64>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT4 => Ok(portal
.parameter::<f32>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT8 => Ok(portal
.parameter::<f64>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
_ => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"22023".to_owned(),
"unsupported_parameter_value".to_owned(),
)))),
}
}
// TODO(sunng87): this is a proof-of-concept implementation of postgres extended
// query. We will choose better `Statement` for caching, a good statement type
// is easy to:
//
// - getting schema from
// - setting parameters in
//
// Datafusion's LogicalPlan is a good candidate for SELECT. But we need to
// confirm it's support for other SQL command like INSERT, UPDATE.
#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandler {
type Statement = (Statement, String);
type QueryParser = POCQueryParser;
type Statement = SqlPlan;
type QueryParser = DefaultQueryParser;
type PortalStore = MemPortalStore<Self::Statement>;
fn portal_store(&self) -> Arc<Self::PortalStore> {
@@ -366,20 +231,29 @@ impl ExtendedQueryHandler for PostgresServerHandler {
)
]
);
let (_, sql) = portal.statement().statement();
let sql_plan = portal.statement().statement();
// manually replace variables in prepared statement
// FIXME(sunng87)
let mut sql = sql.clone();
for i in 0..portal.parameter_len() {
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
}
let output = if let Some(plan) = &sql_plan.plan {
let plan = plan
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
self.query_handler
.do_exec_plan(plan, self.session.context())
.await
} else {
// manually replace variables in prepared statement when no
// logical_plan is generated. This happens when logical plan is not
// supported for certain statements.
let mut sql = sql_plan.query.clone();
for i in 0..portal.parameter_len() {
sql = sql.replace(&format!("${}", i + 1), &parameter_to_string(portal, i)?);
}
let output = self
.query_handler
.do_query(&sql, self.session.context())
.await
.remove(0);
self.query_handler
.do_query(&sql, self.session.context())
.await
.remove(0)
};
output_to_query_response(output, portal.result_column_format())
}
@@ -392,9 +266,11 @@ impl ExtendedQueryHandler for PostgresServerHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
let (param_types, stmt, format) = match target {
let (param_types, sql_plan, format) = match target {
StatementOrPortal::Statement(stmt) => {
let param_types = Some(stmt.parameter_types().clone());
// TODO(sunng87): return server inferenced param_types if client
// not specified
(param_types, stmt.statement(), &Format::UnifiedBinary)
}
StatementOrPortal::Portal(portal) => (
@@ -403,16 +279,9 @@ impl ExtendedQueryHandler for PostgresServerHandler {
portal.result_column_format(),
),
};
// get Statement part of the tuple
let (stmt, _) = stmt;
if let Some(DescribeResult { schema, .. }) = self
.query_handler
.do_describe(stmt.clone(), self.session.context())
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
schema_to_pg(&schema, format)
if let Some(schema) = &sql_plan.schema {
schema_to_pg(schema, format)
.map(|fields| DescribeResponse::new(param_types, fields))
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
@@ -420,230 +289,3 @@ impl ExtendedQueryHandler for PostgresServerHandler {
}
}
}
#[cfg(test)]
mod test {
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::ListValue;
use pgwire::api::results::{FieldFormat, FieldInfo};
use pgwire::api::Type;
use super::*;
#[test]
fn test_schema_convert() {
let column_schemas = vec![
ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
ColumnSchema::new(
"timestamps",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
),
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
];
let pg_field_info = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"binaries".into(),
None,
None,
Type::BYTEA,
FieldFormat::Text,
),
FieldInfo::new(
"strings".into(),
None,
None,
Type::VARCHAR,
FieldFormat::Text,
),
FieldInfo::new(
"timestamps".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
];
let schema = Schema::new(column_schemas);
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
assert_eq!(fs, pg_field_info);
}
#[test]
fn test_encode_text_format_data() {
let schema = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"strings".into(),
None,
None,
Type::VARCHAR,
FieldFormat::Text,
),
FieldInfo::new(
"binaries".into(),
None,
None,
Type::BYTEA,
FieldFormat::Text,
),
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
FieldInfo::new(
"datetimes".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
FieldInfo::new(
"timestamps".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
];
let values = vec![
Value::Null,
Value::Boolean(true),
Value::UInt8(u8::MAX),
Value::UInt16(u16::MAX),
Value::UInt32(u32::MAX),
Value::UInt64(u64::MAX),
Value::Int8(i8::MAX),
Value::Int8(i8::MIN),
Value::Int16(i16::MAX),
Value::Int16(i16::MIN),
Value::Int32(i32::MAX),
Value::Int32(i32::MIN),
Value::Int64(i64::MAX),
Value::Int64(i64::MIN),
Value::Float32(f32::MAX.into()),
Value::Float32(f32::MIN.into()),
Value::Float32(0f32.into()),
Value::Float64(f64::MAX.into()),
Value::Float64(f64::MIN.into()),
Value::Float64(0f64.into()),
Value::String("greptime".into()),
Value::Binary("greptime".as_bytes().into()),
Value::Date(1001i32.into()),
Value::DateTime(1000001i64.into()),
Value::Timestamp(1000001i64.into()),
];
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
encode_value(i, &mut builder).unwrap();
}
let err = encode_value(
&Value::List(ListValue::new(
Some(Box::default()),
ConcreteDataType::int16_datatype(),
)),
&mut builder,
)
.unwrap_err();
match err {
PgWireError::ApiError(e) => {
assert!(format!("{e}").contains("Internal error:"));
}
_ => {
unreachable!()
}
}
}
}

View File

@@ -22,7 +22,6 @@ use common_telemetry::logging::error;
use common_telemetry::{debug, warn};
use futures::StreamExt;
use metrics::{decrement_gauge, increment_gauge};
use pgwire::api::MakeHandler;
use pgwire::tokio::process_socket;
use tokio;
use tokio_rustls::TlsAcceptor;
@@ -69,32 +68,36 @@ impl PostgresServer {
accepting_stream: AbortableStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> impl Future<Output = ()> {
let handler = self.make_handler.clone();
let handler_maker = self.make_handler.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let tls_acceptor = tls_acceptor.clone();
let mut handler = handler.make();
let handler_maker = handler_maker.clone();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
match io_stream.peer_addr() {
let addr = match io_stream.peer_addr() {
Ok(addr) => {
handler.session.mut_conn_info().client_addr = Some(addr);
debug!("PostgreSQL client coming from {}", addr)
debug!("PostgreSQL client coming from {}", addr);
Some(addr)
}
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
}
Err(e) => {
warn!("Failed to get PostgreSQL client addr, err: {}", e);
None
}
};
let _handle = io_runtime.spawn(async move {
increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
let handler = Arc::new(handler);
let pg_handler = Arc::new(handler_maker.make(addr));
let r = process_socket(
io_stream,
tls_acceptor.clone(),
handler.clone(),
handler.clone(),
handler,
pg_handler.clone(),
pg_handler.clone(),
pg_handler,
)
.await;
decrement_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);

View File

@@ -0,0 +1,712 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use chrono::{NaiveDate, NaiveDateTime};
use datafusion_common::ScalarValue;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::Schema;
use datatypes::types::TimestampType;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use query::plan::LogicalPlan;
use crate::error::{self, Error, Result};
use crate::SqlPlan;
pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
origin
.column_schemas()
.iter()
.enumerate()
.map(|(idx, col)| {
Ok(FieldInfo::new(
col.name.clone(),
None,
None,
type_gt_to_pg(&col.data_type)?,
field_formats.format_for(idx),
))
})
.collect::<Result<Vec<FieldInfo>>>()
}
pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
match value {
Value::Null => builder.encode_field(&None::<&i8>),
Value::Boolean(v) => builder.encode_field(v),
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
Value::UInt32(v) => builder.encode_field(v),
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
Value::Int8(v) => builder.encode_field(v),
Value::Int16(v) => builder.encode_field(v),
Value::Int32(v) => builder.encode_field(v),
Value::Int64(v) => builder.encode_field(v),
Value::Float32(v) => builder.encode_field(&v.0),
Value::Float64(v) => builder.encode_field(&v.0),
Value::String(v) => builder.encode_field(&v.as_utf8()),
Value::Binary(v) => builder.encode_field(&v.deref()),
Value::Date(v) => {
if let Some(date) = v.to_chrono_date() {
builder.encode_field(&date)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::DateTime(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::Timestamp(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
&value
),
}))),
}
}
pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
match origin {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
&ConcreteDataType::Date(_) => Ok(Type::DATE),
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu {
err_msg: format!("not implemented for column datatype {origin:?}"),
}
.fail(),
}
}
#[allow(dead_code)]
pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
// Note that we only support a small amount of pg data types
match origin {
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
&Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
common_time::timestamp::TimeUnit::Millisecond,
)),
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
_ => error::InternalSnafu {
err_msg: format!("unimplemented datatype {origin:?}"),
}
.fail(),
}
}
pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
// the index is managed from portal's parameters count so it's safe to
// unwrap here.
let param_type = portal.statement().parameter_types().get(idx).unwrap();
match param_type {
&Type::VARCHAR | &Type::TEXT => Ok(format!(
"'{}'",
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
)),
&Type::BOOL => Ok(portal
.parameter::<bool>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT4 => Ok(portal
.parameter::<i32>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::INT8 => Ok(portal
.parameter::<i64>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT4 => Ok(portal
.parameter::<f32>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::FLOAT8 => Ok(portal
.parameter::<f64>(idx)?
.map(|v| v.to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::DATE => Ok(portal
.parameter::<NaiveDate>(idx)?
.map(|v| v.format("%Y-%m-%d").to_string())
.unwrap_or_else(|| "".to_owned())),
&Type::TIMESTAMP => Ok(portal
.parameter::<NaiveDateTime>(idx)?
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
.unwrap_or_else(|| "".to_owned())),
_ => Err(invalid_parameter_error(
"unsupported_parameter_type",
Some(&param_type.to_string()),
)),
}
}
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError {
let mut error_info = ErrorInfo::new("ERROR".to_owned(), "22023".to_owned(), msg.to_owned());
error_info.set_detail(detail.map(|s| s.to_owned()));
PgWireError::UserError(Box::new(error_info))
}
fn to_timestamp_scalar_value<T>(
data: Option<T>,
unit: &TimestampType,
ctype: &ConcreteDataType,
) -> PgWireResult<ScalarValue>
where
T: Into<i64>,
{
if let Some(n) = data {
Value::Timestamp(unit.create_timestamp(n.into()))
.try_to_scalar_value(ctype)
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
Ok(ScalarValue::Null)
}
}
pub(super) fn parameters_to_scalar_values(
plan: &LogicalPlan,
portal: &Portal<SqlPlan>,
) -> PgWireResult<Vec<ScalarValue>> {
let param_count = portal.parameter_len();
let mut results = Vec::with_capacity(param_count);
let client_param_types = portal.statement().parameter_types();
let param_types = plan
.get_param_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
// ensure parameter count consistent for: client parameter types, server
// parameter types and parameter count
if param_types.len() != param_count {
return Err(invalid_parameter_error(
"invalid_parameter_count",
Some(&format!(
"Expected: {}, found: {}",
param_types.len(),
param_count
)),
));
}
if client_param_types.len() != param_count {
return Err(invalid_parameter_error(
"invalid_parameter_count",
Some(&format!(
"Expected: {}, found: {}",
client_param_types.len(),
param_count
)),
));
}
for (idx, client_type) in client_param_types.iter().enumerate() {
let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else { continue };
let value = match client_type {
&Type::VARCHAR | &Type::TEXT => {
let data = portal.parameter::<String>(idx)?;
match server_type {
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::BOOL => {
let data = portal.parameter::<bool>(idx)?;
match server_type {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::INT2 => {
let data = portal.parameter::<i16>(idx)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::INT4 => {
let data = portal.parameter::<i32>(idx)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::INT8 => {
let data = portal.parameter::<i64>(idx)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Timestamp(unit) => {
to_timestamp_scalar_value(data, unit, server_type)?
}
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::FLOAT4 => {
let data = portal.parameter::<f32>(idx)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
ConcreteDataType::Float64(_) => ScalarValue::Float64(data.map(|n| n as f64)),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::FLOAT8 => {
let data = portal.parameter::<f64>(idx)?;
match server_type {
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
ConcreteDataType::Float32(_) => ScalarValue::Float32(data.map(|n| n as f32)),
ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::TIMESTAMP => {
let data = portal.parameter::<NaiveDateTime>(idx)?;
match server_type {
ConcreteDataType::Timestamp(unit) => match *unit {
TimestampType::Second(_) => {
ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None)
}
TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
data.map(|ts| ts.timestamp_millis()),
None,
),
TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
data.map(|ts| ts.timestamp_micros()),
None,
),
TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
data.map(|ts| ts.timestamp_micros()),
None,
),
},
ConcreteDataType::DateTime(_) => {
ScalarValue::Date64(data.map(|d| d.timestamp_millis()))
}
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
))
}
}
}
&Type::DATE => {
let data = portal.parameter::<NaiveDate>(idx)?;
match server_type {
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
(d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32
})),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
}
}
}
&Type::BYTEA => {
let data = portal.parameter::<Vec<u8>>(idx)?;
match server_type {
ConcreteDataType::String(_) => {
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
}
ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
_ => {
return Err(invalid_parameter_error(
"invalid_parameter_type",
Some(&format!(
"Expected: {}, found: {}",
server_type, client_type
)),
));
}
}
}
_ => Err(invalid_parameter_error(
"unsupported_parameter_value",
Some(&format!("Found type: {}", client_type)),
))?,
};
results.push(value);
}
Ok(results)
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::ListValue;
use pgwire::api::results::{FieldFormat, FieldInfo};
use pgwire::api::Type;
use super::*;
#[test]
fn test_schema_convert() {
let column_schemas = vec![
ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
ColumnSchema::new(
"timestamps",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
),
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
];
let pg_field_info = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"binaries".into(),
None,
None,
Type::BYTEA,
FieldFormat::Text,
),
FieldInfo::new(
"strings".into(),
None,
None,
Type::VARCHAR,
FieldFormat::Text,
),
FieldInfo::new(
"timestamps".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
];
let schema = Schema::new(column_schemas);
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
assert_eq!(fs, pg_field_info);
}
#[test]
fn test_encode_text_format_data() {
let schema = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float32s".into(),
None,
None,
Type::FLOAT4,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"float64s".into(),
None,
None,
Type::FLOAT8,
FieldFormat::Text,
),
FieldInfo::new(
"strings".into(),
None,
None,
Type::VARCHAR,
FieldFormat::Text,
),
FieldInfo::new(
"binaries".into(),
None,
None,
Type::BYTEA,
FieldFormat::Text,
),
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
FieldInfo::new(
"datetimes".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
FieldInfo::new(
"timestamps".into(),
None,
None,
Type::TIMESTAMP,
FieldFormat::Text,
),
];
let values = vec![
Value::Null,
Value::Boolean(true),
Value::UInt8(u8::MAX),
Value::UInt16(u16::MAX),
Value::UInt32(u32::MAX),
Value::UInt64(u64::MAX),
Value::Int8(i8::MAX),
Value::Int8(i8::MIN),
Value::Int16(i16::MAX),
Value::Int16(i16::MIN),
Value::Int32(i32::MAX),
Value::Int32(i32::MIN),
Value::Int64(i64::MAX),
Value::Int64(i64::MIN),
Value::Float32(f32::MAX.into()),
Value::Float32(f32::MIN.into()),
Value::Float32(0f32.into()),
Value::Float64(f64::MAX.into()),
Value::Float64(f64::MIN.into()),
Value::Float64(0f64.into()),
Value::String("greptime".into()),
Value::Binary("greptime".as_bytes().into()),
Value::Date(1001i32.into()),
Value::DateTime(1000001i64.into()),
Value::Timestamp(1000001i64.into()),
];
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
encode_value(i, &mut builder).unwrap();
}
let err = encode_value(
&Value::List(ListValue::new(
Some(Box::default()),
ConcreteDataType::int16_datatype(),
)),
&mut builder,
)
.unwrap_err();
match err {
PgWireError::ApiError(e) => {
assert!(format!("{e}").contains("Internal error:"));
}
_ => {
unreachable!()
}
}
}
}

View File

@@ -14,7 +14,6 @@
use std::fmt;
use datatypes::prelude::ConcreteDataType;
use sqlparser::ast::Query as SpQuery;
use crate::error::Error;
@@ -23,7 +22,6 @@ use crate::error::Error;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query {
pub inner: SpQuery,
pub param_types: Vec<ConcreteDataType>,
}
/// Automatically converts from sqlparser Query instance to SqlQuery.
@@ -31,10 +29,7 @@ impl TryFrom<SpQuery> for Query {
type Error = Error;
fn try_from(q: SpQuery) -> Result<Self, Self::Error> {
Ok(Query {
inner: q,
param_types: vec![],
})
Ok(Query { inner: q })
}
}
@@ -46,27 +41,9 @@ impl TryFrom<Query> for SpQuery {
}
}
impl Query {
pub fn param_types(&self) -> &Vec<ConcreteDataType> {
&self.param_types
}
pub fn param_types_mut(&mut self) -> &mut Vec<ConcreteDataType> {
&mut self.param_types
}
}
impl fmt::Display for Query {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} ", self.inner)?;
write!(f, "[")?;
for i in 0..self.param_types.len() {
write!(f, "{}", self.param_types[i])?;
if i != self.param_types.len() - 1 {
write!(f, ",")?;
}
}
write!(f, "]")?;
write!(f, "{}", self.inner)?;
Ok(())
}
}
@@ -95,7 +72,7 @@ mod test {
create_query("select * from abc where x = 1 and y = 7")
.unwrap()
.to_string(),
"SELECT * FROM abc WHERE x = 1 AND y = 7 []"
"SELECT * FROM abc WHERE x = 1 AND y = 7"
);
assert_eq!(
create_query(
@@ -103,7 +80,7 @@ mod test {
)
.unwrap()
.to_string(),
"SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id []"
"SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id"
);
}
}

View File

@@ -145,22 +145,26 @@ pub async fn test_postgres_crud(store_type: StorageType) {
.await
.unwrap();
assert!(
sqlx::query("create table demo(i bigint, ts timestamp time index)")
.execute(&pool)
.await
.is_ok()
);
sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)")
.execute(&pool)
.await
.unwrap();
for i in 0..10 {
assert!(sqlx::query("insert into demo values($1, $2)")
let d = NaiveDate::from_yo_opt(2015, 100).unwrap();
let dt = d.and_hms_opt(0, 0, 0).unwrap().timestamp_millis();
sqlx::query("insert into demo values($1, $2, $3, $4)")
.bind(i)
.bind(i)
.bind(d)
.bind(dt)
.execute(&pool)
.await
.is_ok());
.unwrap();
}
let rows = sqlx::query("select i from demo")
let rows = sqlx::query("select i,d,dt from demo")
.fetch_all(&pool)
.await
.unwrap();
@@ -168,7 +172,18 @@ pub async fn test_postgres_crud(store_type: StorageType) {
for (i, row) in rows.iter().enumerate() {
let ret: i64 = row.get(0);
let d: NaiveDate = row.get(1);
let dt: NaiveDateTime = row.get(2);
assert_eq!(ret, i as i64);
let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap();
assert_eq!(expected_d, d);
let expected_dt = NaiveDate::from_yo_opt(2015, 100)
.and_then(|d| d.and_hms_opt(0, 0, 0))
.unwrap();
assert_eq!(expected_dt, dt);
}
let rows = sqlx::query("select i from demo where i=$1")