feat: adopt pgwire 0.12 and simplify encoding apis (#1250)

* feat: adopt pgwire 0.12 and simplify encoding apis

* refactor: remove duplicated format match clause
This commit is contained in:
Ning Sun
2023-03-27 18:16:43 +08:00
committed by GitHub
parent 8ba0741c81
commit 7eb4d81929
3 changed files with 329 additions and 343 deletions

459
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,7 @@ once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = { git = "https://github.com/sunng87/opensrv", branch = "fix/buffer-overread" }
parking_lot = "0.12"
pgwire = "0.10"
pgwire = "0.12"
pin-project = "1.0"
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
promql-parser = "0.1.0"

View File

@@ -14,21 +14,21 @@
use std::ops::Deref;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use chrono::LocalResult;
use common_query::Output;
use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::RecordBatch;
use common_time::timestamp::TimeUnit;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::{Schema, SchemaRef};
use futures::{future, stream, Stream, StreamExt};
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{query_response, DataRowEncoder, FieldFormat, FieldInfo, Response, Tag};
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal};
use pgwire::api::results::{
DataRowEncoder, DescribeResponse, FieldInfo, QueryResponse, Response, Tag,
};
use pgwire::api::stmt::QueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
@@ -41,7 +41,7 @@ use crate::error::{self, Error, Result};
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandler {
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
async fn do_query<'a, C>(&self, _client: &C, query: &'a str) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Unpin + Send + Sync,
{
@@ -53,7 +53,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
let mut results = Vec::with_capacity(outputs.len());
for output in outputs {
let resp = output_to_query_response(output, FieldFormat::Text)?;
let resp = output_to_query_response(output, &Format::UnifiedText)?;
results.push(resp);
}
@@ -61,10 +61,10 @@ impl SimpleQueryHandler for PostgresServerHandler {
}
}
fn output_to_query_response(
fn output_to_query_response<'a>(
output: Result<Output>,
field_format: FieldFormat,
) -> PgWireResult<Response> {
field_format: &Format,
) -> PgWireResult<Response<'a>> {
match output {
Ok(Output::AffectedRows(rows)) => Ok(Response::Execution(Tag::new_for_execution(
"OK",
@@ -86,11 +86,11 @@ fn output_to_query_response(
}
}
fn recordbatches_to_query_response<S>(
fn recordbatches_to_query_response<'a, S>(
recordbatches_stream: S,
schema: SchemaRef,
field_format: FieldFormat,
) -> PgWireResult<Response>
field_format: &Format,
) -> PgWireResult<Response<'a>>
where
S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
{
@@ -98,8 +98,6 @@ where
schema_to_pg(schema.as_ref(), field_format)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?,
);
let ncols = pg_schema.len();
let pg_schema_ref = pg_schema.clone();
let data_row_stream = recordbatches_stream
.map(|record_batch_result| match record_batch_result {
@@ -113,68 +111,57 @@ where
})
.flatten() // flatten into stream<result<row>>
.map(move |row| {
row.and_then(|row| match field_format {
FieldFormat::Text => {
let mut encoder = DataRowEncoder::new(ncols);
for value in row.into_iter() {
encode_text_value(&value, &mut encoder)?;
}
encoder.finish()
}
FieldFormat::Binary => {
let mut encoder = DataRowEncoder::new(ncols);
for (idx, value) in row.into_iter().enumerate() {
encode_binary_value(&value, pg_schema_ref[idx].datatype(), &mut encoder)?;
}
encoder.finish()
row.and_then(|row| {
let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
for value in row.iter() {
encode_value(value, &mut encoder)?;
}
encoder.finish()
})
});
match field_format {
FieldFormat::Text => Ok(Response::Query(query_response(
Some(pg_schema.deref().clone()),
data_row_stream,
))),
FieldFormat::Binary => Ok(Response::Query(query_response(None, data_row_stream))),
}
Ok(Response::Query(QueryResponse::new(
pg_schema,
data_row_stream,
)))
}
fn schema_to_pg(origin: &Schema, field_format: FieldFormat) -> Result<Vec<FieldInfo>> {
fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
origin
.column_schemas()
.iter()
.map(|col| {
.enumerate()
.map(|(idx, col)| {
Ok(FieldInfo::new(
col.name.clone(),
None,
None,
type_gt_to_pg(&col.data_type)?,
field_format,
field_formats.format_for(idx),
))
})
.collect::<Result<Vec<FieldInfo>>>()
}
fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
match value {
Value::Null => builder.encode_text_format_field(None::<&i8>),
Value::Boolean(v) => builder.encode_text_format_field(Some(v)),
Value::UInt8(v) => builder.encode_text_format_field(Some(v)),
Value::UInt16(v) => builder.encode_text_format_field(Some(v)),
Value::UInt32(v) => builder.encode_text_format_field(Some(v)),
Value::UInt64(v) => builder.encode_text_format_field(Some(v)),
Value::Int8(v) => builder.encode_text_format_field(Some(v)),
Value::Int16(v) => builder.encode_text_format_field(Some(v)),
Value::Int32(v) => builder.encode_text_format_field(Some(v)),
Value::Int64(v) => builder.encode_text_format_field(Some(v)),
Value::Float32(v) => builder.encode_text_format_field(Some(&v.0)),
Value::Float64(v) => builder.encode_text_format_field(Some(&v.0)),
Value::String(v) => builder.encode_text_format_field(Some(&v.as_utf8())),
Value::Binary(v) => builder.encode_text_format_field(Some(&hex::encode(v.deref()))),
Value::Null => builder.encode_field(&None::<&i8>),
Value::Boolean(v) => builder.encode_field(v),
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
Value::UInt32(v) => builder.encode_field(v),
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
Value::Int8(v) => builder.encode_field(v),
Value::Int16(v) => builder.encode_field(v),
Value::Int32(v) => builder.encode_field(v),
Value::Int64(v) => builder.encode_field(v),
Value::Float32(v) => builder.encode_field(&v.0),
Value::Float64(v) => builder.encode_field(&v.0),
Value::String(v) => builder.encode_field(&v.as_utf8()),
Value::Binary(v) => builder.encode_field(&v.deref()),
Value::Date(v) => {
if let Some(date) = v.to_chrono_date() {
builder.encode_text_format_field(Some(&date.format("%Y-%m-%d").to_string()))
builder.encode_field(&date)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
@@ -183,9 +170,7 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
}
Value::DateTime(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_text_format_field(Some(
&datetime.format("%Y-%m-%d %H:%M:%S%.6f").to_string(),
))
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
@@ -194,9 +179,7 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
}
Value::Timestamp(v) => {
if let LocalResult::Single(datetime) = v.to_chrono_datetime() {
builder.encode_text_format_field(Some(
&datetime.format("%Y-%m-%d %H:%M:%S%.6f").to_string(),
))
builder.encode_field(&datetime)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
@@ -212,64 +195,6 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
}
}
fn encode_binary_value(
value: &Value,
datatype: &Type,
builder: &mut DataRowEncoder,
) -> PgWireResult<()> {
match value {
Value::Null => builder.encode_binary_format_field(&None::<&i8>, datatype),
Value::Boolean(v) => builder.encode_binary_format_field(v, datatype),
Value::UInt8(v) => builder.encode_binary_format_field(&(*v as i8), datatype),
Value::UInt16(v) => builder.encode_binary_format_field(&(*v as i16), datatype),
Value::UInt32(v) => builder.encode_binary_format_field(&(*v as i32), datatype),
Value::UInt64(v) => builder.encode_binary_format_field(&(*v as i64), datatype),
Value::Int8(v) => builder.encode_binary_format_field(v, datatype),
Value::Int16(v) => builder.encode_binary_format_field(v, datatype),
Value::Int32(v) => builder.encode_binary_format_field(v, datatype),
Value::Int64(v) => builder.encode_binary_format_field(v, datatype),
Value::Float32(v) => builder.encode_binary_format_field(&v.0, datatype),
Value::Float64(v) => builder.encode_binary_format_field(&v.0, datatype),
Value::String(v) => builder.encode_binary_format_field(&v.as_utf8(), datatype),
Value::Binary(v) => builder.encode_binary_format_field(&v.deref(), datatype),
Value::Date(v) => {
if let Some(date) = v.to_chrono_date() {
builder.encode_binary_format_field(&date, datatype)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert date to postgres type {v:?}",),
})))
}
}
Value::DateTime(v) => {
if let Some(datetime) = v.to_chrono_datetime() {
builder.encode_binary_format_field(&datetime, datatype)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert datetime to postgres type {v:?}",),
})))
}
}
Value::Timestamp(v) => {
// convert timestamp to SystemTime
if let Some(ts) = v.convert_to(TimeUnit::Microsecond) {
let sys_time = std::time::UNIX_EPOCH + Duration::from_micros(ts.value() as u64);
builder.encode_binary_format_field(&sys_time, datatype)
} else {
Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!("Failed to convert timestamp to postgres type {v:?}",),
})))
}
}
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
&value
),
}))),
}
}
fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
match origin {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
@@ -403,12 +328,12 @@ impl ExtendedQueryHandler for PostgresServerHandler {
self.query_parser.clone()
}
async fn do_query<C>(
async fn do_query<'a, C>(
&self,
_client: &mut C,
portal: &Portal<Self::Statement>,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response>
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
{
@@ -426,28 +351,42 @@ impl ExtendedQueryHandler for PostgresServerHandler {
.await
.remove(0);
output_to_query_response(output, FieldFormat::Binary)
output_to_query_response(output, portal.result_column_format())
}
async fn do_describe<C>(
&self,
_client: &mut C,
statement: &StoredStatement<Self::Statement>,
) -> PgWireResult<Vec<FieldInfo>>
target: StatementOrPortal<'_, Self::Statement>,
) -> PgWireResult<DescribeResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let (stmt, _) = statement.statement();
let (param_types, stmt, format) = match target {
StatementOrPortal::Statement(stmt) => {
let param_types = Some(stmt.parameter_types().clone());
(param_types, stmt.statement(), &Format::UnifiedBinary)
}
StatementOrPortal::Portal(portal) => (
None,
portal.statement().statement(),
portal.result_column_format(),
),
};
// get Statement part of the tuple
let (stmt, _) = stmt;
if let Some(schema) = self
.query_handler
.do_describe(stmt.clone(), self.query_ctx.clone())
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
schema_to_pg(&schema, FieldFormat::Binary)
schema_to_pg(&schema, format)
.map(|fields| DescribeResponse::new(param_types, fields))
.map_err(|e| PgWireError::ApiError(Box::new(e)))
} else {
Ok(vec![])
Ok(DescribeResponse::new(param_types, vec![]))
}
}
}
@@ -456,7 +395,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
mod test {
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::value::ListValue;
use pgwire::api::results::FieldInfo;
use pgwire::api::results::{FieldFormat, FieldInfo};
use pgwire::api::Type;
use super::*;
@@ -534,7 +473,7 @@ mod test {
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
];
let schema = Schema::new(column_schemas);
let fs = schema_to_pg(&schema, FieldFormat::Text).unwrap();
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
assert_eq!(fs, pg_field_info);
}
@@ -655,15 +594,15 @@ mod test {
Value::DateTime(1000001i64.into()),
Value::Timestamp(1000001i64.into()),
];
let mut builder = DataRowEncoder::new(schema.len());
for i in values {
assert!(encode_text_value(&i, &mut builder).is_ok());
let mut builder = DataRowEncoder::new(Arc::new(schema));
for i in values.iter() {
assert!(encode_value(i, &mut builder).is_ok());
}
let err = encode_text_value(
let err = encode_value(
&Value::List(ListValue::new(
Some(Box::default()),
ConcreteDataType::int8_datatype(),
ConcreteDataType::int16_datatype(),
)),
&mut builder,
)