mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
feat: add logical plan based prepare statement for postgresql (#1813)
* feat: add logical plan based prepare statement for postgresql * refactor: correct more types * Update src/servers/src/postgres/types.rs Co-authored-by: LFC <bayinamine@gmail.com> * fix: address review issues * test: add datetime in integration tests --------- Co-authored-by: LFC <bayinamine@gmail.com>
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -6408,9 +6408,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pgwire"
|
||||
version = "0.14.1"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd92c65406efd0d621cdece478a41a89e472a559e44a6f2b218df4c14e66a888"
|
||||
checksum = "e2de42ee35f9694def25c37c15f564555411d9904b48e33680618ee7359080dc"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.21.2",
|
||||
@@ -11252,16 +11252,16 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "x509-certificate"
|
||||
version = "0.19.0"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bf14059fbc1dce14de1d08535c411ba0b18749c2550a12550300da90b7ba350b"
|
||||
checksum = "2133ce6c08c050a5b368730a67c53a603ffd4a4a6c577c5218675a19f7782c05"
|
||||
dependencies = [
|
||||
"bcder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"der 0.7.6",
|
||||
"hex",
|
||||
"pem 1.1.1",
|
||||
"pem 2.0.1",
|
||||
"ring",
|
||||
"signature",
|
||||
"spki 0.7.2",
|
||||
|
||||
@@ -276,7 +276,7 @@ mod test {
|
||||
having: None, \
|
||||
named_window: [], \
|
||||
qualify: None \
|
||||
}), order_by: [], limit: None, offset: None, fetch: None, locks: [] }, param_types: [] }))");
|
||||
}), order_by: [], limit: None, offset: None, fetch: None, locks: [] } }))");
|
||||
|
||||
assert_eq!(format!("{stmt:?}"), expected);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = "0.4"
|
||||
parking_lot = "0.12"
|
||||
pgwire = "0.14.1"
|
||||
pgwire = "0.15"
|
||||
pin-project = "1.0"
|
||||
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
|
||||
promql-parser = "0.1.1"
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
#![feature(try_blocks)]
|
||||
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use datatypes::schema::Schema;
|
||||
use query::plan::LogicalPlan;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod auth;
|
||||
@@ -72,6 +74,14 @@ pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &s
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached SQL and logical plan for database interfaces
|
||||
#[derive(Clone)]
|
||||
pub struct SqlPlan {
|
||||
query: String,
|
||||
plan: Option<LogicalPlan>,
|
||||
schema: Option<Schema>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -48,13 +48,7 @@ use crate::mysql::helper::{
|
||||
use crate::mysql::writer;
|
||||
use crate::mysql::writer::create_mysql_column;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
/// Cached SQL and logical plan
|
||||
#[derive(Clone)]
|
||||
struct SqlPlan {
|
||||
query: String,
|
||||
plan: Option<LogicalPlan>,
|
||||
}
|
||||
use crate::SqlPlan;
|
||||
|
||||
// An intermediate shim for executing MySQL queries.
|
||||
pub struct MysqlInstanceShim {
|
||||
@@ -214,10 +208,16 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
// in the form of "$i", it can't process "?" right now.
|
||||
let statement = transform_placeholders(statement);
|
||||
|
||||
let plan = self
|
||||
.do_describe(statement.clone())
|
||||
.await?
|
||||
.map(|DescribeResult { logical_plan, .. }| logical_plan);
|
||||
let describe_result = self.do_describe(statement.clone()).await?;
|
||||
let (plan, schema) = if let Some(DescribeResult {
|
||||
logical_plan,
|
||||
schema,
|
||||
}) = describe_result
|
||||
{
|
||||
(Some(logical_plan), Some(schema))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let params = if let Some(plan) = &plan {
|
||||
prepared_params(
|
||||
@@ -234,6 +234,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
let stmt_id = self.save_plan(SqlPlan {
|
||||
query: query.to_string(),
|
||||
plan,
|
||||
schema,
|
||||
});
|
||||
|
||||
w.reply(stmt_id, ¶ms, &[]).await?;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
mod auth_handler;
|
||||
mod handler;
|
||||
mod server;
|
||||
mod types;
|
||||
|
||||
pub(crate) const METADATA_USER: &str = "user";
|
||||
pub(crate) const METADATA_DATABASE: &str = "database";
|
||||
@@ -24,21 +25,22 @@ pub(crate) const METADATA_CATALOG: &str = "catalog";
|
||||
pub(crate) const METADATA_SCHEMA: &str = "schema";
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use derive_builder::Builder;
|
||||
use pgwire::api::auth::ServerParameterProvider;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, MakeHandler};
|
||||
use pgwire::api::ClientInfo;
|
||||
pub use server::PostgresServer;
|
||||
use session::context::Channel;
|
||||
use session::Session;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use self::auth_handler::PgLoginVerifier;
|
||||
use self::handler::POCQueryParser;
|
||||
use self::handler::DefaultQueryParser;
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::SqlPlan;
|
||||
|
||||
pub(crate) struct GreptimeDBStartupParameters {
|
||||
version: &'static str,
|
||||
@@ -73,9 +75,9 @@ pub struct PostgresServerHandler {
|
||||
force_tls: bool,
|
||||
param_provider: Arc<GreptimeDBStartupParameters>,
|
||||
|
||||
session: Session,
|
||||
portal_store: Arc<MemPortalStore<(Statement, String)>>,
|
||||
query_parser: Arc<POCQueryParser>,
|
||||
session: Arc<Session>,
|
||||
portal_store: Arc<MemPortalStore<SqlPlan>>,
|
||||
query_parser: Arc<DefaultQueryParser>,
|
||||
}
|
||||
|
||||
#[derive(Builder)]
|
||||
@@ -84,24 +86,21 @@ pub(crate) struct MakePostgresServerHandler {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
#[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")]
|
||||
param_provider: Arc<GreptimeDBStartupParameters>,
|
||||
#[builder(default = "Arc::new(POCQueryParser::default())")]
|
||||
query_parser: Arc<POCQueryParser>,
|
||||
force_tls: bool,
|
||||
}
|
||||
|
||||
impl MakeHandler for MakePostgresServerHandler {
|
||||
type Handler = PostgresServerHandler;
|
||||
|
||||
fn make(&self) -> Self::Handler {
|
||||
impl MakePostgresServerHandler {
|
||||
fn make(&self, addr: Option<SocketAddr>) -> PostgresServerHandler {
|
||||
let session = Arc::new(Session::new(addr, Channel::Postgres));
|
||||
PostgresServerHandler {
|
||||
query_handler: self.query_handler.clone(),
|
||||
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
|
||||
force_tls: self.force_tls,
|
||||
param_provider: self.param_provider.clone(),
|
||||
|
||||
session: Session::new(None, Channel::Postgres),
|
||||
session: session.clone(),
|
||||
portal_store: Arc::new(MemPortalStore::new()),
|
||||
query_parser: self.query_parser.clone(),
|
||||
query_parser: Arc::new(DefaultQueryParser::new(self.query_handler.clone(), session)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
@@ -20,26 +19,26 @@ use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordBatchResult;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::timer;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::{Schema, SchemaRef};
|
||||
use datatypes::schema::SchemaRef;
|
||||
use futures::{future, stream, Stream, StreamExt};
|
||||
use metrics::increment_counter;
|
||||
use pgwire::api::portal::{Format, Portal};
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal};
|
||||
use pgwire::api::results::{
|
||||
DataRowEncoder, DescribeResponse, FieldInfo, QueryResponse, Response, Tag,
|
||||
};
|
||||
use pgwire::api::results::{DataRowEncoder, DescribeResponse, QueryResponse, Response, Tag};
|
||||
use pgwire::api::stmt::QueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use query::query_engine::DescribeResult;
|
||||
use session::Session;
|
||||
use sql::dialect::PostgreSqlDialect;
|
||||
use sql::parser::ParserContext;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use super::types::*;
|
||||
use super::PostgresServerHandler;
|
||||
use crate::error::{self, Error, Result};
|
||||
use crate::error::Result;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::SqlPlan;
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
@@ -141,125 +140,25 @@ where
|
||||
)))
|
||||
}
|
||||
|
||||
fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
|
||||
origin
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, col)| {
|
||||
Ok(FieldInfo::new(
|
||||
col.name.clone(),
|
||||
None,
|
||||
None,
|
||||
type_gt_to_pg(&col.data_type)?,
|
||||
field_formats.format_for(idx),
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<FieldInfo>>>()
|
||||
pub struct DefaultQueryParser {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
session: Arc<Session>,
|
||||
}
|
||||
|
||||
fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
|
||||
match value {
|
||||
Value::Null => builder.encode_field(&None::<&i8>),
|
||||
Value::Boolean(v) => builder.encode_field(v),
|
||||
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
|
||||
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
|
||||
Value::UInt32(v) => builder.encode_field(v),
|
||||
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
|
||||
Value::Int8(v) => builder.encode_field(v),
|
||||
Value::Int16(v) => builder.encode_field(v),
|
||||
Value::Int32(v) => builder.encode_field(v),
|
||||
Value::Int64(v) => builder.encode_field(v),
|
||||
Value::Float32(v) => builder.encode_field(&v.0),
|
||||
Value::Float64(v) => builder.encode_field(&v.0),
|
||||
Value::String(v) => builder.encode_field(&v.as_utf8()),
|
||||
Value::Binary(v) => builder.encode_field(&v.deref()),
|
||||
Value::Date(v) => {
|
||||
if let Some(date) = v.to_chrono_date() {
|
||||
builder.encode_field(&date)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
impl DefaultQueryParser {
|
||||
pub fn new(query_handler: ServerSqlQueryHandlerRef, session: Arc<Session>) -> Self {
|
||||
DefaultQueryParser {
|
||||
query_handler,
|
||||
session,
|
||||
}
|
||||
Value::DateTime(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::Timestamp(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!(
|
||||
"cannot write value {:?} in postgres protocol: unimplemented",
|
||||
&value
|
||||
),
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
|
||||
match origin {
|
||||
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
|
||||
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
|
||||
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
|
||||
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
|
||||
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
|
||||
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
|
||||
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
|
||||
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
|
||||
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
|
||||
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
|
||||
&ConcreteDataType::Date(_) => Ok(Type::DATE),
|
||||
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu {
|
||||
err_msg: format!("not implemented for column datatype {origin:?}"),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
#[async_trait]
|
||||
impl QueryParser for DefaultQueryParser {
|
||||
type Statement = SqlPlan;
|
||||
|
||||
fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
// Note that we only support a small amount of pg data types
|
||||
match origin {
|
||||
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
|
||||
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
|
||||
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
|
||||
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
|
||||
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
|
||||
&Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
|
||||
common_time::timestamp::TimeUnit::Millisecond,
|
||||
)),
|
||||
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
|
||||
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
|
||||
_ => error::InternalSnafu {
|
||||
err_msg: format!("unimplemented datatype {origin:?}"),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct POCQueryParser;
|
||||
|
||||
impl QueryParser for POCQueryParser {
|
||||
type Statement = (Statement, String);
|
||||
|
||||
fn parse_sql(&self, sql: &str, types: &[Type]) -> PgWireResult<Self::Statement> {
|
||||
async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
|
||||
increment_counter!(crate::metrics::METRIC_POSTGRES_PREPARED_COUNT);
|
||||
let mut stmts = ParserContext::create_with_dialect(sql, &PostgreSqlDialect {})
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
@@ -270,70 +169,36 @@ impl QueryParser for POCQueryParser {
|
||||
"invalid_prepared_statement_definition".to_owned(),
|
||||
))))
|
||||
} else {
|
||||
let mut stmt = stmts.remove(0);
|
||||
if let Statement::Query(qs) = &mut stmt {
|
||||
for t in types {
|
||||
let gt_type =
|
||||
type_pg_to_gt(t).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
qs.param_types_mut().push(gt_type);
|
||||
}
|
||||
}
|
||||
let stmt = stmts.remove(0);
|
||||
let describe_result = self
|
||||
.query_handler
|
||||
.do_describe(stmt, self.session.context())
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
Ok((stmt, sql.to_owned()))
|
||||
let (plan, schema) = if let Some(DescribeResult {
|
||||
logical_plan,
|
||||
schema,
|
||||
}) = describe_result
|
||||
{
|
||||
(Some(logical_plan), Some(schema))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
Ok(SqlPlan {
|
||||
query: sql.to_owned(),
|
||||
plan,
|
||||
schema,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parameter_to_string(portal: &Portal<(Statement, String)>, idx: usize) -> PgWireResult<String> {
|
||||
// the index is managed from portal's parameters count so it's safe to
|
||||
// unwrap here.
|
||||
let param_type = portal.statement().parameter_types().get(idx).unwrap();
|
||||
match param_type {
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(format!(
|
||||
"'{}'",
|
||||
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
|
||||
)),
|
||||
&Type::BOOL => Ok(portal
|
||||
.parameter::<bool>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT4 => Ok(portal
|
||||
.parameter::<i32>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT8 => Ok(portal
|
||||
.parameter::<i64>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT4 => Ok(portal
|
||||
.parameter::<f32>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT8 => Ok(portal
|
||||
.parameter::<f64>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
_ => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
|
||||
"ERROR".to_owned(),
|
||||
"22023".to_owned(),
|
||||
"unsupported_parameter_value".to_owned(),
|
||||
)))),
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(sunng87): this is a proof-of-concept implementation of postgres extended
|
||||
// query. We will choose better `Statement` for caching, a good statement type
|
||||
// is easy to:
|
||||
//
|
||||
// - getting schema from
|
||||
// - setting parameters in
|
||||
//
|
||||
// Datafusion's LogicalPlan is a good candidate for SELECT. But we need to
|
||||
// confirm it's support for other SQL command like INSERT, UPDATE.
|
||||
#[async_trait]
|
||||
impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
type Statement = (Statement, String);
|
||||
type QueryParser = POCQueryParser;
|
||||
type Statement = SqlPlan;
|
||||
type QueryParser = DefaultQueryParser;
|
||||
type PortalStore = MemPortalStore<Self::Statement>;
|
||||
|
||||
fn portal_store(&self) -> Arc<Self::PortalStore> {
|
||||
@@ -366,20 +231,29 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
)
|
||||
]
|
||||
);
|
||||
let (_, sql) = portal.statement().statement();
|
||||
let sql_plan = portal.statement().statement();
|
||||
|
||||
// manually replace variables in prepared statement
|
||||
// FIXME(sunng87)
|
||||
let mut sql = sql.clone();
|
||||
for i in 0..portal.parameter_len() {
|
||||
sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
|
||||
}
|
||||
let output = if let Some(plan) = &sql_plan.plan {
|
||||
let plan = plan
|
||||
.replace_params_with_values(parameters_to_scalar_values(plan, portal)?.as_ref())
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
self.query_handler
|
||||
.do_exec_plan(plan, self.session.context())
|
||||
.await
|
||||
} else {
|
||||
// manually replace variables in prepared statement when no
|
||||
// logical_plan is generated. This happens when logical plan is not
|
||||
// supported for certain statements.
|
||||
let mut sql = sql_plan.query.clone();
|
||||
for i in 0..portal.parameter_len() {
|
||||
sql = sql.replace(&format!("${}", i + 1), ¶meter_to_string(portal, i)?);
|
||||
}
|
||||
|
||||
let output = self
|
||||
.query_handler
|
||||
.do_query(&sql, self.session.context())
|
||||
.await
|
||||
.remove(0);
|
||||
self.query_handler
|
||||
.do_query(&sql, self.session.context())
|
||||
.await
|
||||
.remove(0)
|
||||
};
|
||||
|
||||
output_to_query_response(output, portal.result_column_format())
|
||||
}
|
||||
@@ -392,9 +266,11 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
let (param_types, stmt, format) = match target {
|
||||
let (param_types, sql_plan, format) = match target {
|
||||
StatementOrPortal::Statement(stmt) => {
|
||||
let param_types = Some(stmt.parameter_types().clone());
|
||||
// TODO(sunng87): return server inferenced param_types if client
|
||||
// not specified
|
||||
(param_types, stmt.statement(), &Format::UnifiedBinary)
|
||||
}
|
||||
StatementOrPortal::Portal(portal) => (
|
||||
@@ -403,16 +279,9 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
portal.result_column_format(),
|
||||
),
|
||||
};
|
||||
// get Statement part of the tuple
|
||||
let (stmt, _) = stmt;
|
||||
|
||||
if let Some(DescribeResult { schema, .. }) = self
|
||||
.query_handler
|
||||
.do_describe(stmt.clone(), self.session.context())
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
schema_to_pg(&schema, format)
|
||||
if let Some(schema) = &sql_plan.schema {
|
||||
schema_to_pg(schema, format)
|
||||
.map(|fields| DescribeResponse::new(param_types, fields))
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))
|
||||
} else {
|
||||
@@ -420,230 +289,3 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::value::ListValue;
|
||||
use pgwire::api::results::{FieldFormat, FieldInfo};
|
||||
use pgwire::api::Type;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_schema_convert() {
|
||||
let column_schemas = vec![
|
||||
ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
|
||||
ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
|
||||
ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
|
||||
ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
|
||||
ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
|
||||
ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
|
||||
ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
|
||||
ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
|
||||
ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
|
||||
ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
|
||||
ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
|
||||
ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
|
||||
ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
|
||||
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
|
||||
ColumnSchema::new(
|
||||
"timestamps",
|
||||
ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
true,
|
||||
),
|
||||
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
|
||||
];
|
||||
let pg_field_info = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"binaries".into(),
|
||||
None,
|
||||
None,
|
||||
Type::BYTEA,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"strings".into(),
|
||||
None,
|
||||
None,
|
||||
Type::VARCHAR,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"timestamps".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
|
||||
];
|
||||
let schema = Schema::new(column_schemas);
|
||||
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
|
||||
assert_eq!(fs, pg_field_info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_text_format_data() {
|
||||
let schema = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"strings".into(),
|
||||
None,
|
||||
None,
|
||||
Type::VARCHAR,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"binaries".into(),
|
||||
None,
|
||||
None,
|
||||
Type::BYTEA,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"datetimes".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"timestamps".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
];
|
||||
|
||||
let values = vec![
|
||||
Value::Null,
|
||||
Value::Boolean(true),
|
||||
Value::UInt8(u8::MAX),
|
||||
Value::UInt16(u16::MAX),
|
||||
Value::UInt32(u32::MAX),
|
||||
Value::UInt64(u64::MAX),
|
||||
Value::Int8(i8::MAX),
|
||||
Value::Int8(i8::MIN),
|
||||
Value::Int16(i16::MAX),
|
||||
Value::Int16(i16::MIN),
|
||||
Value::Int32(i32::MAX),
|
||||
Value::Int32(i32::MIN),
|
||||
Value::Int64(i64::MAX),
|
||||
Value::Int64(i64::MIN),
|
||||
Value::Float32(f32::MAX.into()),
|
||||
Value::Float32(f32::MIN.into()),
|
||||
Value::Float32(0f32.into()),
|
||||
Value::Float64(f64::MAX.into()),
|
||||
Value::Float64(f64::MIN.into()),
|
||||
Value::Float64(0f64.into()),
|
||||
Value::String("greptime".into()),
|
||||
Value::Binary("greptime".as_bytes().into()),
|
||||
Value::Date(1001i32.into()),
|
||||
Value::DateTime(1000001i64.into()),
|
||||
Value::Timestamp(1000001i64.into()),
|
||||
];
|
||||
let mut builder = DataRowEncoder::new(Arc::new(schema));
|
||||
for i in values.iter() {
|
||||
encode_value(i, &mut builder).unwrap();
|
||||
}
|
||||
|
||||
let err = encode_value(
|
||||
&Value::List(ListValue::new(
|
||||
Some(Box::default()),
|
||||
ConcreteDataType::int16_datatype(),
|
||||
)),
|
||||
&mut builder,
|
||||
)
|
||||
.unwrap_err();
|
||||
match err {
|
||||
PgWireError::ApiError(e) => {
|
||||
assert!(format!("{e}").contains("Internal error:"));
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ use common_telemetry::logging::error;
|
||||
use common_telemetry::{debug, warn};
|
||||
use futures::StreamExt;
|
||||
use metrics::{decrement_gauge, increment_gauge};
|
||||
use pgwire::api::MakeHandler;
|
||||
use pgwire::tokio::process_socket;
|
||||
use tokio;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
@@ -69,32 +68,36 @@ impl PostgresServer {
|
||||
accepting_stream: AbortableStream,
|
||||
tls_acceptor: Option<Arc<TlsAcceptor>>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let handler = self.make_handler.clone();
|
||||
let handler_maker = self.make_handler.clone();
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let mut handler = handler.make();
|
||||
let handler_maker = handler_maker.clone();
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
match io_stream.peer_addr() {
|
||||
let addr = match io_stream.peer_addr() {
|
||||
Ok(addr) => {
|
||||
handler.session.mut_conn_info().client_addr = Some(addr);
|
||||
debug!("PostgreSQL client coming from {}", addr)
|
||||
debug!("PostgreSQL client coming from {}", addr);
|
||||
Some(addr)
|
||||
}
|
||||
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to get PostgreSQL client addr, err: {}", e);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let _handle = io_runtime.spawn(async move {
|
||||
increment_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
|
||||
let handler = Arc::new(handler);
|
||||
let pg_handler = Arc::new(handler_maker.make(addr));
|
||||
let r = process_socket(
|
||||
io_stream,
|
||||
tls_acceptor.clone(),
|
||||
handler.clone(),
|
||||
handler.clone(),
|
||||
handler,
|
||||
pg_handler.clone(),
|
||||
pg_handler.clone(),
|
||||
pg_handler,
|
||||
)
|
||||
.await;
|
||||
decrement_gauge!(crate::metrics::METRIC_POSTGRES_CONNECTIONS, 1.0);
|
||||
|
||||
712
src/servers/src/postgres/types.rs
Normal file
712
src/servers/src/postgres/types.rs
Normal file
@@ -0,0 +1,712 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
|
||||
use chrono::{NaiveDate, NaiveDateTime};
|
||||
use datafusion_common::ScalarValue;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::Schema;
|
||||
use datatypes::types::TimestampType;
|
||||
use pgwire::api::portal::{Format, Portal};
|
||||
use pgwire::api::results::{DataRowEncoder, FieldInfo};
|
||||
use pgwire::api::Type;
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use query::plan::LogicalPlan;
|
||||
|
||||
use crate::error::{self, Error, Result};
|
||||
use crate::SqlPlan;
|
||||
|
||||
pub(super) fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
|
||||
origin
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, col)| {
|
||||
Ok(FieldInfo::new(
|
||||
col.name.clone(),
|
||||
None,
|
||||
None,
|
||||
type_gt_to_pg(&col.data_type)?,
|
||||
field_formats.format_for(idx),
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<FieldInfo>>>()
|
||||
}
|
||||
|
||||
pub(super) fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
|
||||
match value {
|
||||
Value::Null => builder.encode_field(&None::<&i8>),
|
||||
Value::Boolean(v) => builder.encode_field(v),
|
||||
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
|
||||
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
|
||||
Value::UInt32(v) => builder.encode_field(v),
|
||||
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
|
||||
Value::Int8(v) => builder.encode_field(v),
|
||||
Value::Int16(v) => builder.encode_field(v),
|
||||
Value::Int32(v) => builder.encode_field(v),
|
||||
Value::Int64(v) => builder.encode_field(v),
|
||||
Value::Float32(v) => builder.encode_field(&v.0),
|
||||
Value::Float64(v) => builder.encode_field(&v.0),
|
||||
Value::String(v) => builder.encode_field(&v.as_utf8()),
|
||||
Value::Binary(v) => builder.encode_field(&v.deref()),
|
||||
Value::Date(v) => {
|
||||
if let Some(date) = v.to_chrono_date() {
|
||||
builder.encode_field(&date)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::DateTime(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::Timestamp(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!(
|
||||
"cannot write value {:?} in postgres protocol: unimplemented",
|
||||
&value
|
||||
),
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
|
||||
match origin {
|
||||
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
|
||||
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
|
||||
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
|
||||
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
|
||||
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
|
||||
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
|
||||
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
|
||||
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
|
||||
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
|
||||
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
|
||||
&ConcreteDataType::Date(_) => Ok(Type::DATE),
|
||||
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::List(_) | &ConcreteDataType::Dictionary(_) => error::InternalSnafu {
|
||||
err_msg: format!("not implemented for column datatype {origin:?}"),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
|
||||
// Note that we only support a small amount of pg data types
|
||||
match origin {
|
||||
&Type::BOOL => Ok(ConcreteDataType::boolean_datatype()),
|
||||
&Type::CHAR => Ok(ConcreteDataType::int8_datatype()),
|
||||
&Type::INT2 => Ok(ConcreteDataType::int16_datatype()),
|
||||
&Type::INT4 => Ok(ConcreteDataType::int32_datatype()),
|
||||
&Type::INT8 => Ok(ConcreteDataType::int64_datatype()),
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(ConcreteDataType::string_datatype()),
|
||||
&Type::TIMESTAMP => Ok(ConcreteDataType::timestamp_datatype(
|
||||
common_time::timestamp::TimeUnit::Millisecond,
|
||||
)),
|
||||
&Type::DATE => Ok(ConcreteDataType::date_datatype()),
|
||||
&Type::TIME => Ok(ConcreteDataType::datetime_datatype()),
|
||||
_ => error::InternalSnafu {
|
||||
err_msg: format!("unimplemented datatype {origin:?}"),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
|
||||
// the index is managed from portal's parameters count so it's safe to
|
||||
// unwrap here.
|
||||
let param_type = portal.statement().parameter_types().get(idx).unwrap();
|
||||
match param_type {
|
||||
&Type::VARCHAR | &Type::TEXT => Ok(format!(
|
||||
"'{}'",
|
||||
portal.parameter::<String>(idx)?.as_deref().unwrap_or("")
|
||||
)),
|
||||
&Type::BOOL => Ok(portal
|
||||
.parameter::<bool>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT4 => Ok(portal
|
||||
.parameter::<i32>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::INT8 => Ok(portal
|
||||
.parameter::<i64>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT4 => Ok(portal
|
||||
.parameter::<f32>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::FLOAT8 => Ok(portal
|
||||
.parameter::<f64>(idx)?
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::DATE => Ok(portal
|
||||
.parameter::<NaiveDate>(idx)?
|
||||
.map(|v| v.format("%Y-%m-%d").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
&Type::TIMESTAMP => Ok(portal
|
||||
.parameter::<NaiveDateTime>(idx)?
|
||||
.map(|v| v.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
|
||||
.unwrap_or_else(|| "".to_owned())),
|
||||
_ => Err(invalid_parameter_error(
|
||||
"unsupported_parameter_type",
|
||||
Some(¶m_type.to_string()),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn invalid_parameter_error(msg: &str, detail: Option<&str>) -> PgWireError {
|
||||
let mut error_info = ErrorInfo::new("ERROR".to_owned(), "22023".to_owned(), msg.to_owned());
|
||||
error_info.set_detail(detail.map(|s| s.to_owned()));
|
||||
PgWireError::UserError(Box::new(error_info))
|
||||
}
|
||||
|
||||
fn to_timestamp_scalar_value<T>(
|
||||
data: Option<T>,
|
||||
unit: &TimestampType,
|
||||
ctype: &ConcreteDataType,
|
||||
) -> PgWireResult<ScalarValue>
|
||||
where
|
||||
T: Into<i64>,
|
||||
{
|
||||
if let Some(n) = data {
|
||||
Value::Timestamp(unit.create_timestamp(n.into()))
|
||||
.try_to_scalar_value(ctype)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))
|
||||
} else {
|
||||
Ok(ScalarValue::Null)
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn parameters_to_scalar_values(
|
||||
plan: &LogicalPlan,
|
||||
portal: &Portal<SqlPlan>,
|
||||
) -> PgWireResult<Vec<ScalarValue>> {
|
||||
let param_count = portal.parameter_len();
|
||||
let mut results = Vec::with_capacity(param_count);
|
||||
|
||||
let client_param_types = portal.statement().parameter_types();
|
||||
let param_types = plan
|
||||
.get_param_types()
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
// ensure parameter count consistent for: client parameter types, server
|
||||
// parameter types and parameter count
|
||||
if param_types.len() != param_count {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_count",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
param_types.len(),
|
||||
param_count
|
||||
)),
|
||||
));
|
||||
}
|
||||
if client_param_types.len() != param_count {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_count",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
client_param_types.len(),
|
||||
param_count
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
||||
for (idx, client_type) in client_param_types.iter().enumerate() {
|
||||
let Some(Some(server_type)) = param_types.get(&format!("${}", idx + 1)) else { continue };
|
||||
let value = match client_type {
|
||||
&Type::VARCHAR | &Type::TEXT => {
|
||||
let data = portal.parameter::<String>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::String(_) => ScalarValue::Utf8(data),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::BOOL => {
|
||||
let data = portal.parameter::<bool>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(data),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::INT2 => {
|
||||
let data = portal.parameter::<i16>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data),
|
||||
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
|
||||
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
|
||||
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
|
||||
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
|
||||
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
|
||||
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
|
||||
ConcreteDataType::Timestamp(unit) => {
|
||||
to_timestamp_scalar_value(data, unit, server_type)?
|
||||
}
|
||||
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::INT4 => {
|
||||
let data = portal.parameter::<i32>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
ConcreteDataType::Int32(_) => ScalarValue::Int32(data),
|
||||
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
|
||||
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
|
||||
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
|
||||
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
|
||||
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
|
||||
ConcreteDataType::Timestamp(unit) => {
|
||||
to_timestamp_scalar_value(data, unit, server_type)?
|
||||
}
|
||||
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data.map(|d| d as i64)),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::INT8 => {
|
||||
let data = portal.parameter::<i64>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
|
||||
ConcreteDataType::Int64(_) => ScalarValue::Int64(data),
|
||||
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
|
||||
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
|
||||
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
|
||||
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
|
||||
ConcreteDataType::Timestamp(unit) => {
|
||||
to_timestamp_scalar_value(data, unit, server_type)?
|
||||
}
|
||||
ConcreteDataType::DateTime(_) => ScalarValue::Date64(data),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::FLOAT4 => {
|
||||
let data = portal.parameter::<f32>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
|
||||
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
|
||||
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
|
||||
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
|
||||
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
|
||||
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
|
||||
ConcreteDataType::Float32(_) => ScalarValue::Float32(data),
|
||||
ConcreteDataType::Float64(_) => ScalarValue::Float64(data.map(|n| n as f64)),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::FLOAT8 => {
|
||||
let data = portal.parameter::<f64>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Int8(_) => ScalarValue::Int8(data.map(|n| n as i8)),
|
||||
ConcreteDataType::Int16(_) => ScalarValue::Int16(data.map(|n| n as i16)),
|
||||
ConcreteDataType::Int32(_) => ScalarValue::Int32(data.map(|n| n as i32)),
|
||||
ConcreteDataType::Int64(_) => ScalarValue::Int64(data.map(|n| n as i64)),
|
||||
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(data.map(|n| n as u8)),
|
||||
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(data.map(|n| n as u16)),
|
||||
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(data.map(|n| n as u32)),
|
||||
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(data.map(|n| n as u64)),
|
||||
ConcreteDataType::Float32(_) => ScalarValue::Float32(data.map(|n| n as f32)),
|
||||
ConcreteDataType::Float64(_) => ScalarValue::Float64(data),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::TIMESTAMP => {
|
||||
let data = portal.parameter::<NaiveDateTime>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Timestamp(unit) => match *unit {
|
||||
TimestampType::Second(_) => {
|
||||
ScalarValue::TimestampSecond(data.map(|ts| ts.timestamp()), None)
|
||||
}
|
||||
TimestampType::Millisecond(_) => ScalarValue::TimestampMillisecond(
|
||||
data.map(|ts| ts.timestamp_millis()),
|
||||
None,
|
||||
),
|
||||
TimestampType::Microsecond(_) => ScalarValue::TimestampMicrosecond(
|
||||
data.map(|ts| ts.timestamp_micros()),
|
||||
None,
|
||||
),
|
||||
TimestampType::Nanosecond(_) => ScalarValue::TimestampNanosecond(
|
||||
data.map(|ts| ts.timestamp_micros()),
|
||||
None,
|
||||
),
|
||||
},
|
||||
ConcreteDataType::DateTime(_) => {
|
||||
ScalarValue::Date64(data.map(|d| d.timestamp_millis()))
|
||||
}
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::DATE => {
|
||||
let data = portal.parameter::<NaiveDate>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::Date(_) => ScalarValue::Date32(data.map(|d| {
|
||||
(d - NaiveDate::from_ymd_opt(1970, 1, 1).unwrap()).num_days() as i32
|
||||
})),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
&Type::BYTEA => {
|
||||
let data = portal.parameter::<Vec<u8>>(idx)?;
|
||||
match server_type {
|
||||
ConcreteDataType::String(_) => {
|
||||
ScalarValue::Utf8(data.map(|d| String::from_utf8_lossy(&d).to_string()))
|
||||
}
|
||||
ConcreteDataType::Binary(_) => ScalarValue::Binary(data),
|
||||
_ => {
|
||||
return Err(invalid_parameter_error(
|
||||
"invalid_parameter_type",
|
||||
Some(&format!(
|
||||
"Expected: {}, found: {}",
|
||||
server_type, client_type
|
||||
)),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => Err(invalid_parameter_error(
|
||||
"unsupported_parameter_value",
|
||||
Some(&format!("Found type: {}", client_type)),
|
||||
))?,
|
||||
};
|
||||
results.push(value);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::value::ListValue;
|
||||
use pgwire::api::results::{FieldFormat, FieldInfo};
|
||||
use pgwire::api::Type;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_schema_convert() {
|
||||
let column_schemas = vec![
|
||||
ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
|
||||
ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
|
||||
ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
|
||||
ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
|
||||
ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
|
||||
ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
|
||||
ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
|
||||
ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
|
||||
ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
|
||||
ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
|
||||
ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
|
||||
ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
|
||||
ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
|
||||
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
|
||||
ColumnSchema::new(
|
||||
"timestamps",
|
||||
ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
true,
|
||||
),
|
||||
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
|
||||
];
|
||||
let pg_field_info = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"binaries".into(),
|
||||
None,
|
||||
None,
|
||||
Type::BYTEA,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"strings".into(),
|
||||
None,
|
||||
None,
|
||||
Type::VARCHAR,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"timestamps".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
|
||||
];
|
||||
let schema = Schema::new(column_schemas);
|
||||
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
|
||||
assert_eq!(fs, pg_field_info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_text_format_data() {
|
||||
let schema = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float32s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT4,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"float64s".into(),
|
||||
None,
|
||||
None,
|
||||
Type::FLOAT8,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"strings".into(),
|
||||
None,
|
||||
None,
|
||||
Type::VARCHAR,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"binaries".into(),
|
||||
None,
|
||||
None,
|
||||
Type::BYTEA,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
|
||||
FieldInfo::new(
|
||||
"datetimes".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
FieldInfo::new(
|
||||
"timestamps".into(),
|
||||
None,
|
||||
None,
|
||||
Type::TIMESTAMP,
|
||||
FieldFormat::Text,
|
||||
),
|
||||
];
|
||||
|
||||
let values = vec![
|
||||
Value::Null,
|
||||
Value::Boolean(true),
|
||||
Value::UInt8(u8::MAX),
|
||||
Value::UInt16(u16::MAX),
|
||||
Value::UInt32(u32::MAX),
|
||||
Value::UInt64(u64::MAX),
|
||||
Value::Int8(i8::MAX),
|
||||
Value::Int8(i8::MIN),
|
||||
Value::Int16(i16::MAX),
|
||||
Value::Int16(i16::MIN),
|
||||
Value::Int32(i32::MAX),
|
||||
Value::Int32(i32::MIN),
|
||||
Value::Int64(i64::MAX),
|
||||
Value::Int64(i64::MIN),
|
||||
Value::Float32(f32::MAX.into()),
|
||||
Value::Float32(f32::MIN.into()),
|
||||
Value::Float32(0f32.into()),
|
||||
Value::Float64(f64::MAX.into()),
|
||||
Value::Float64(f64::MIN.into()),
|
||||
Value::Float64(0f64.into()),
|
||||
Value::String("greptime".into()),
|
||||
Value::Binary("greptime".as_bytes().into()),
|
||||
Value::Date(1001i32.into()),
|
||||
Value::DateTime(1000001i64.into()),
|
||||
Value::Timestamp(1000001i64.into()),
|
||||
];
|
||||
let mut builder = DataRowEncoder::new(Arc::new(schema));
|
||||
for i in values.iter() {
|
||||
encode_value(i, &mut builder).unwrap();
|
||||
}
|
||||
|
||||
let err = encode_value(
|
||||
&Value::List(ListValue::new(
|
||||
Some(Box::default()),
|
||||
ConcreteDataType::int16_datatype(),
|
||||
)),
|
||||
&mut builder,
|
||||
)
|
||||
.unwrap_err();
|
||||
match err {
|
||||
PgWireError::ApiError(e) => {
|
||||
assert!(format!("{e}").contains("Internal error:"));
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use sqlparser::ast::Query as SpQuery;
|
||||
|
||||
use crate::error::Error;
|
||||
@@ -23,7 +22,6 @@ use crate::error::Error;
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Query {
|
||||
pub inner: SpQuery,
|
||||
pub param_types: Vec<ConcreteDataType>,
|
||||
}
|
||||
|
||||
/// Automatically converts from sqlparser Query instance to SqlQuery.
|
||||
@@ -31,10 +29,7 @@ impl TryFrom<SpQuery> for Query {
|
||||
type Error = Error;
|
||||
|
||||
fn try_from(q: SpQuery) -> Result<Self, Self::Error> {
|
||||
Ok(Query {
|
||||
inner: q,
|
||||
param_types: vec![],
|
||||
})
|
||||
Ok(Query { inner: q })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,27 +41,9 @@ impl TryFrom<Query> for SpQuery {
|
||||
}
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn param_types(&self) -> &Vec<ConcreteDataType> {
|
||||
&self.param_types
|
||||
}
|
||||
|
||||
pub fn param_types_mut(&mut self) -> &mut Vec<ConcreteDataType> {
|
||||
&mut self.param_types
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Query {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{} ", self.inner)?;
|
||||
write!(f, "[")?;
|
||||
for i in 0..self.param_types.len() {
|
||||
write!(f, "{}", self.param_types[i])?;
|
||||
if i != self.param_types.len() - 1 {
|
||||
write!(f, ",")?;
|
||||
}
|
||||
}
|
||||
write!(f, "]")?;
|
||||
write!(f, "{}", self.inner)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -95,7 +72,7 @@ mod test {
|
||||
create_query("select * from abc where x = 1 and y = 7")
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
"SELECT * FROM abc WHERE x = 1 AND y = 7 []"
|
||||
"SELECT * FROM abc WHERE x = 1 AND y = 7"
|
||||
);
|
||||
assert_eq!(
|
||||
create_query(
|
||||
@@ -103,7 +80,7 @@ mod test {
|
||||
)
|
||||
.unwrap()
|
||||
.to_string(),
|
||||
"SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id []"
|
||||
"SELECT * FROM abc LEFT JOIN bcd WHERE abc.a = 1 AND bcd.d = 7 AND abc.id = bcd.id"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,22 +145,26 @@ pub async fn test_postgres_crud(store_type: StorageType) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
sqlx::query("create table demo(i bigint, ts timestamp time index)")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
sqlx::query("create table demo(i bigint, ts timestamp time index, d date, dt datetime)")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
assert!(sqlx::query("insert into demo values($1, $2)")
|
||||
let d = NaiveDate::from_yo_opt(2015, 100).unwrap();
|
||||
let dt = d.and_hms_opt(0, 0, 0).unwrap().timestamp_millis();
|
||||
|
||||
sqlx::query("insert into demo values($1, $2, $3, $4)")
|
||||
.bind(i)
|
||||
.bind(i)
|
||||
.bind(d)
|
||||
.bind(dt)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.is_ok());
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let rows = sqlx::query("select i from demo")
|
||||
let rows = sqlx::query("select i,d,dt from demo")
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -168,7 +172,18 @@ pub async fn test_postgres_crud(store_type: StorageType) {
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let ret: i64 = row.get(0);
|
||||
let d: NaiveDate = row.get(1);
|
||||
let dt: NaiveDateTime = row.get(2);
|
||||
|
||||
assert_eq!(ret, i as i64);
|
||||
|
||||
let expected_d = NaiveDate::from_yo_opt(2015, 100).unwrap();
|
||||
assert_eq!(expected_d, d);
|
||||
|
||||
let expected_dt = NaiveDate::from_yo_opt(2015, 100)
|
||||
.and_then(|d| d.and_hms_opt(0, 0, 0))
|
||||
.unwrap();
|
||||
assert_eq!(expected_dt, dt);
|
||||
}
|
||||
|
||||
let rows = sqlx::query("select i from demo where i=$1")
|
||||
|
||||
Reference in New Issue
Block a user