mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-05 21:02:58 +00:00
feat: adopt pgwire 0.12 and simplify encoding apis (#1250)
* feat: adopt pgwire 0.12 and simplify encoding apis * refactor: remove duplicated format match clause
This commit is contained in:
459
Cargo.lock
generated
459
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -46,7 +46,7 @@ once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = { git = "https://github.com/sunng87/opensrv", branch = "fix/buffer-overread" }
|
||||
parking_lot = "0.12"
|
||||
pgwire = "0.10"
|
||||
pgwire = "0.12"
|
||||
pin-project = "1.0"
|
||||
postgres-types = { version = "0.2", features = ["with-chrono-0_4"] }
|
||||
promql-parser = "0.1.0"
|
||||
|
||||
@@ -14,21 +14,21 @@
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::LocalResult;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::Result as RecordBatchResult;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_time::timestamp::TimeUnit;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::{Schema, SchemaRef};
|
||||
use futures::{future, stream, Stream, StreamExt};
|
||||
use pgwire::api::portal::Portal;
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
|
||||
use pgwire::api::results::{query_response, DataRowEncoder, FieldFormat, FieldInfo, Response, Tag};
|
||||
use pgwire::api::stmt::{QueryParser, StoredStatement};
|
||||
use pgwire::api::portal::{Format, Portal};
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal};
|
||||
use pgwire::api::results::{
|
||||
DataRowEncoder, DescribeResponse, FieldInfo, QueryResponse, Response, Tag,
|
||||
};
|
||||
use pgwire::api::stmt::QueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
@@ -41,7 +41,7 @@ use crate::error::{self, Error, Result};
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
async fn do_query<'a, C>(&self, _client: &C, query: &'a str) -> PgWireResult<Vec<Response<'a>>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
@@ -53,7 +53,7 @@ impl SimpleQueryHandler for PostgresServerHandler {
|
||||
let mut results = Vec::with_capacity(outputs.len());
|
||||
|
||||
for output in outputs {
|
||||
let resp = output_to_query_response(output, FieldFormat::Text)?;
|
||||
let resp = output_to_query_response(output, &Format::UnifiedText)?;
|
||||
results.push(resp);
|
||||
}
|
||||
|
||||
@@ -61,10 +61,10 @@ impl SimpleQueryHandler for PostgresServerHandler {
|
||||
}
|
||||
}
|
||||
|
||||
fn output_to_query_response(
|
||||
fn output_to_query_response<'a>(
|
||||
output: Result<Output>,
|
||||
field_format: FieldFormat,
|
||||
) -> PgWireResult<Response> {
|
||||
field_format: &Format,
|
||||
) -> PgWireResult<Response<'a>> {
|
||||
match output {
|
||||
Ok(Output::AffectedRows(rows)) => Ok(Response::Execution(Tag::new_for_execution(
|
||||
"OK",
|
||||
@@ -86,11 +86,11 @@ fn output_to_query_response(
|
||||
}
|
||||
}
|
||||
|
||||
fn recordbatches_to_query_response<S>(
|
||||
fn recordbatches_to_query_response<'a, S>(
|
||||
recordbatches_stream: S,
|
||||
schema: SchemaRef,
|
||||
field_format: FieldFormat,
|
||||
) -> PgWireResult<Response>
|
||||
field_format: &Format,
|
||||
) -> PgWireResult<Response<'a>>
|
||||
where
|
||||
S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
|
||||
{
|
||||
@@ -98,8 +98,6 @@ where
|
||||
schema_to_pg(schema.as_ref(), field_format)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?,
|
||||
);
|
||||
let ncols = pg_schema.len();
|
||||
|
||||
let pg_schema_ref = pg_schema.clone();
|
||||
let data_row_stream = recordbatches_stream
|
||||
.map(|record_batch_result| match record_batch_result {
|
||||
@@ -113,68 +111,57 @@ where
|
||||
})
|
||||
.flatten() // flatten into stream<result<row>>
|
||||
.map(move |row| {
|
||||
row.and_then(|row| match field_format {
|
||||
FieldFormat::Text => {
|
||||
let mut encoder = DataRowEncoder::new(ncols);
|
||||
for value in row.into_iter() {
|
||||
encode_text_value(&value, &mut encoder)?;
|
||||
}
|
||||
encoder.finish()
|
||||
}
|
||||
FieldFormat::Binary => {
|
||||
let mut encoder = DataRowEncoder::new(ncols);
|
||||
for (idx, value) in row.into_iter().enumerate() {
|
||||
encode_binary_value(&value, pg_schema_ref[idx].datatype(), &mut encoder)?;
|
||||
}
|
||||
encoder.finish()
|
||||
row.and_then(|row| {
|
||||
let mut encoder = DataRowEncoder::new(pg_schema_ref.clone());
|
||||
for value in row.iter() {
|
||||
encode_value(value, &mut encoder)?;
|
||||
}
|
||||
encoder.finish()
|
||||
})
|
||||
});
|
||||
|
||||
match field_format {
|
||||
FieldFormat::Text => Ok(Response::Query(query_response(
|
||||
Some(pg_schema.deref().clone()),
|
||||
data_row_stream,
|
||||
))),
|
||||
FieldFormat::Binary => Ok(Response::Query(query_response(None, data_row_stream))),
|
||||
}
|
||||
Ok(Response::Query(QueryResponse::new(
|
||||
pg_schema,
|
||||
data_row_stream,
|
||||
)))
|
||||
}
|
||||
|
||||
fn schema_to_pg(origin: &Schema, field_format: FieldFormat) -> Result<Vec<FieldInfo>> {
|
||||
fn schema_to_pg(origin: &Schema, field_formats: &Format) -> Result<Vec<FieldInfo>> {
|
||||
origin
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.map(|col| {
|
||||
.enumerate()
|
||||
.map(|(idx, col)| {
|
||||
Ok(FieldInfo::new(
|
||||
col.name.clone(),
|
||||
None,
|
||||
None,
|
||||
type_gt_to_pg(&col.data_type)?,
|
||||
field_format,
|
||||
field_formats.format_for(idx),
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<FieldInfo>>>()
|
||||
}
|
||||
|
||||
fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
|
||||
fn encode_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResult<()> {
|
||||
match value {
|
||||
Value::Null => builder.encode_text_format_field(None::<&i8>),
|
||||
Value::Boolean(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::UInt8(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::UInt16(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::UInt32(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::UInt64(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::Int8(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::Int16(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::Int32(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::Int64(v) => builder.encode_text_format_field(Some(v)),
|
||||
Value::Float32(v) => builder.encode_text_format_field(Some(&v.0)),
|
||||
Value::Float64(v) => builder.encode_text_format_field(Some(&v.0)),
|
||||
Value::String(v) => builder.encode_text_format_field(Some(&v.as_utf8())),
|
||||
Value::Binary(v) => builder.encode_text_format_field(Some(&hex::encode(v.deref()))),
|
||||
Value::Null => builder.encode_field(&None::<&i8>),
|
||||
Value::Boolean(v) => builder.encode_field(v),
|
||||
Value::UInt8(v) => builder.encode_field(&(*v as i8)),
|
||||
Value::UInt16(v) => builder.encode_field(&(*v as i16)),
|
||||
Value::UInt32(v) => builder.encode_field(v),
|
||||
Value::UInt64(v) => builder.encode_field(&(*v as i64)),
|
||||
Value::Int8(v) => builder.encode_field(v),
|
||||
Value::Int16(v) => builder.encode_field(v),
|
||||
Value::Int32(v) => builder.encode_field(v),
|
||||
Value::Int64(v) => builder.encode_field(v),
|
||||
Value::Float32(v) => builder.encode_field(&v.0),
|
||||
Value::Float64(v) => builder.encode_field(&v.0),
|
||||
Value::String(v) => builder.encode_field(&v.as_utf8()),
|
||||
Value::Binary(v) => builder.encode_field(&v.deref()),
|
||||
Value::Date(v) => {
|
||||
if let Some(date) = v.to_chrono_date() {
|
||||
builder.encode_text_format_field(Some(&date.format("%Y-%m-%d").to_string()))
|
||||
builder.encode_field(&date)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
@@ -183,9 +170,7 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
|
||||
}
|
||||
Value::DateTime(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_text_format_field(Some(
|
||||
&datetime.format("%Y-%m-%d %H:%M:%S%.6f").to_string(),
|
||||
))
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
@@ -194,9 +179,7 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
|
||||
}
|
||||
Value::Timestamp(v) => {
|
||||
if let LocalResult::Single(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_text_format_field(Some(
|
||||
&datetime.format("%Y-%m-%d %H:%M:%S%.6f").to_string(),
|
||||
))
|
||||
builder.encode_field(&datetime)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
@@ -212,64 +195,6 @@ fn encode_text_value(value: &Value, builder: &mut DataRowEncoder) -> PgWireResul
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_binary_value(
|
||||
value: &Value,
|
||||
datatype: &Type,
|
||||
builder: &mut DataRowEncoder,
|
||||
) -> PgWireResult<()> {
|
||||
match value {
|
||||
Value::Null => builder.encode_binary_format_field(&None::<&i8>, datatype),
|
||||
Value::Boolean(v) => builder.encode_binary_format_field(v, datatype),
|
||||
Value::UInt8(v) => builder.encode_binary_format_field(&(*v as i8), datatype),
|
||||
Value::UInt16(v) => builder.encode_binary_format_field(&(*v as i16), datatype),
|
||||
Value::UInt32(v) => builder.encode_binary_format_field(&(*v as i32), datatype),
|
||||
Value::UInt64(v) => builder.encode_binary_format_field(&(*v as i64), datatype),
|
||||
Value::Int8(v) => builder.encode_binary_format_field(v, datatype),
|
||||
Value::Int16(v) => builder.encode_binary_format_field(v, datatype),
|
||||
Value::Int32(v) => builder.encode_binary_format_field(v, datatype),
|
||||
Value::Int64(v) => builder.encode_binary_format_field(v, datatype),
|
||||
Value::Float32(v) => builder.encode_binary_format_field(&v.0, datatype),
|
||||
Value::Float64(v) => builder.encode_binary_format_field(&v.0, datatype),
|
||||
Value::String(v) => builder.encode_binary_format_field(&v.as_utf8(), datatype),
|
||||
Value::Binary(v) => builder.encode_binary_format_field(&v.deref(), datatype),
|
||||
Value::Date(v) => {
|
||||
if let Some(date) = v.to_chrono_date() {
|
||||
builder.encode_binary_format_field(&date, datatype)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert date to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::DateTime(v) => {
|
||||
if let Some(datetime) = v.to_chrono_datetime() {
|
||||
builder.encode_binary_format_field(&datetime, datatype)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert datetime to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::Timestamp(v) => {
|
||||
// convert timestamp to SystemTime
|
||||
if let Some(ts) = v.convert_to(TimeUnit::Microsecond) {
|
||||
let sys_time = std::time::UNIX_EPOCH + Duration::from_micros(ts.value() as u64);
|
||||
builder.encode_binary_format_field(&sys_time, datatype)
|
||||
} else {
|
||||
Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!("Failed to convert timestamp to postgres type {v:?}",),
|
||||
})))
|
||||
}
|
||||
}
|
||||
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!(
|
||||
"cannot write value {:?} in postgres protocol: unimplemented",
|
||||
&value
|
||||
),
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_gt_to_pg(origin: &ConcreteDataType) -> Result<Type> {
|
||||
match origin {
|
||||
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
|
||||
@@ -403,12 +328,12 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
self.query_parser.clone()
|
||||
}
|
||||
|
||||
async fn do_query<C>(
|
||||
async fn do_query<'a, C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
portal: &Portal<Self::Statement>,
|
||||
portal: &'a Portal<Self::Statement>,
|
||||
_max_rows: usize,
|
||||
) -> PgWireResult<Response>
|
||||
) -> PgWireResult<Response<'a>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
@@ -426,28 +351,42 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
.await
|
||||
.remove(0);
|
||||
|
||||
output_to_query_response(output, FieldFormat::Binary)
|
||||
output_to_query_response(output, portal.result_column_format())
|
||||
}
|
||||
|
||||
async fn do_describe<C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
statement: &StoredStatement<Self::Statement>,
|
||||
) -> PgWireResult<Vec<FieldInfo>>
|
||||
target: StatementOrPortal<'_, Self::Statement>,
|
||||
) -> PgWireResult<DescribeResponse>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
let (stmt, _) = statement.statement();
|
||||
let (param_types, stmt, format) = match target {
|
||||
StatementOrPortal::Statement(stmt) => {
|
||||
let param_types = Some(stmt.parameter_types().clone());
|
||||
(param_types, stmt.statement(), &Format::UnifiedBinary)
|
||||
}
|
||||
StatementOrPortal::Portal(portal) => (
|
||||
None,
|
||||
portal.statement().statement(),
|
||||
portal.result_column_format(),
|
||||
),
|
||||
};
|
||||
// get Statement part of the tuple
|
||||
let (stmt, _) = stmt;
|
||||
|
||||
if let Some(schema) = self
|
||||
.query_handler
|
||||
.do_describe(stmt.clone(), self.query_ctx.clone())
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
schema_to_pg(&schema, FieldFormat::Binary)
|
||||
schema_to_pg(&schema, format)
|
||||
.map(|fields| DescribeResponse::new(param_types, fields))
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))
|
||||
} else {
|
||||
Ok(vec![])
|
||||
Ok(DescribeResponse::new(param_types, vec![]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -456,7 +395,7 @@ impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
mod test {
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::value::ListValue;
|
||||
use pgwire::api::results::FieldInfo;
|
||||
use pgwire::api::results::{FieldFormat, FieldInfo};
|
||||
use pgwire::api::Type;
|
||||
|
||||
use super::*;
|
||||
@@ -534,7 +473,7 @@ mod test {
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE, FieldFormat::Text),
|
||||
];
|
||||
let schema = Schema::new(column_schemas);
|
||||
let fs = schema_to_pg(&schema, FieldFormat::Text).unwrap();
|
||||
let fs = schema_to_pg(&schema, &Format::UnifiedText).unwrap();
|
||||
assert_eq!(fs, pg_field_info);
|
||||
}
|
||||
|
||||
@@ -655,15 +594,15 @@ mod test {
|
||||
Value::DateTime(1000001i64.into()),
|
||||
Value::Timestamp(1000001i64.into()),
|
||||
];
|
||||
let mut builder = DataRowEncoder::new(schema.len());
|
||||
for i in values {
|
||||
assert!(encode_text_value(&i, &mut builder).is_ok());
|
||||
let mut builder = DataRowEncoder::new(Arc::new(schema));
|
||||
for i in values.iter() {
|
||||
assert!(encode_value(i, &mut builder).is_ok());
|
||||
}
|
||||
|
||||
let err = encode_text_value(
|
||||
let err = encode_value(
|
||||
&Value::List(ListValue::new(
|
||||
Some(Box::default()),
|
||||
ConcreteDataType::int8_datatype(),
|
||||
ConcreteDataType::int16_datatype(),
|
||||
)),
|
||||
&mut builder,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user