feat: stream write for postgresql query results (#472)

This commit is contained in:
Ning Sun
2022-11-14 21:50:11 +08:00
committed by GitHub
parent c673debc89
commit 74c236a308
4 changed files with 67 additions and 44 deletions

5
Cargo.lock generated
View File

@@ -3869,9 +3869,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.4.0"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e30e99a0b8acf60a6815aa8178e9ffb08178ef3ca1366673bb0d6c7ababe4c2"
checksum = "5dacbf864d6cb6a0e676c9a1162ab7b315b5c8e6c87fa9b6e0ba9ba0a569adb1"
dependencies = [
"async-trait",
"bytes",
@@ -3884,6 +3884,7 @@ dependencies = [
"thiserror",
"time 0.3.14",
"tokio",
"tokio-rustls",
"tokio-util",
]

View File

@@ -29,7 +29,7 @@ num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.1"
pgwire = { version = "0.4" }
pgwire = "0.5"
prost = "0.11"
regex = "1.6"
rand = "0.8"

View File

@@ -2,12 +2,14 @@ use std::ops::Deref;
use async_trait::async_trait;
use common_query::Output;
use common_recordbatch::{util, RecordBatch};
use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::RecordBatch;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::SchemaRef;
use futures::{future, stream, Stream, StreamExt};
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{FieldInfo, Response, Tag, TextQueryResponseBuilder};
use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{PgWireError, PgWireResult};
@@ -43,40 +45,57 @@ impl SimpleQueryHandler for PostgresServerHandler {
))]),
Output::Stream(record_stream) => {
let schema = record_stream.schema();
let recordbatches = util::collect(record_stream)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
recordbatches_to_query_response(recordbatches.iter(), schema)
recordbatches_to_query_response(record_stream, schema)
}
Output::RecordBatches(recordbatches) => {
let schema = recordbatches.schema();
recordbatches_to_query_response(recordbatches.take().iter(), schema)
recordbatches_to_query_response(
stream::iter(recordbatches.take().into_iter().map(Ok)),
schema,
)
}
}
}
}
fn recordbatches_to_query_response<'a, I>(
recordbatches: I,
fn recordbatches_to_query_response<S>(
recordbatches_stream: S,
schema: SchemaRef,
) -> PgWireResult<Vec<Response>>
where
I: Iterator<Item = &'a RecordBatch>,
S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
{
let pg_schema = schema_to_pg(schema).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let mut builder = TextQueryResponseBuilder::new(pg_schema);
let ncols = pg_schema.len();
for recordbatch in recordbatches {
for row in recordbatch.rows() {
let row = row.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
for value in row.into_iter() {
encode_value(&value, &mut builder)?;
}
builder.finish_row();
}
}
let data_row_stream = recordbatches_stream
.map(|record_batch_result| match record_batch_result {
Ok(rb) => stream::iter(
// collect rows from a single recordbatch into vector to avoid
// borrowing it
rb.rows()
.map(|row| row.map_err(|e| PgWireError::ApiError(Box::new(e))))
.collect::<Vec<_>>()
.into_iter(),
)
.boxed(),
Err(e) => stream::once(future::err(PgWireError::ApiError(Box::new(e)))).boxed(),
})
.flatten() // flatten into stream<result<row>>
.map(move |row| {
row.and_then(|row| {
let mut encoder = TextDataRowEncoder::new(ncols);
for value in row.into_iter() {
encode_value(&value, &mut encoder)?;
}
encoder.finish()
})
});
Ok(vec![Response::Query(builder.build())])
Ok(vec![Response::Query(text_query_response(
pg_schema,
data_row_stream,
))])
}
fn schema_to_pg(origin: SchemaRef) -> Result<Vec<FieldInfo>> {
@@ -94,9 +113,9 @@ fn schema_to_pg(origin: SchemaRef) -> Result<Vec<FieldInfo>> {
.collect::<Result<Vec<FieldInfo>>>()
}
fn encode_value(value: &Value, builder: &mut TextQueryResponseBuilder) -> PgWireResult<()> {
fn encode_value(value: &Value, builder: &mut TextDataRowEncoder) -> PgWireResult<()> {
match value {
Value::Null => builder.append_field(None::<i8>),
Value::Null => builder.append_field(None::<&i8>),
Value::Boolean(v) => builder.append_field(Some(v)),
Value::UInt8(v) => builder.append_field(Some(v)),
Value::UInt16(v) => builder.append_field(Some(v)),
@@ -106,13 +125,13 @@ fn encode_value(value: &Value, builder: &mut TextQueryResponseBuilder) -> PgWire
Value::Int16(v) => builder.append_field(Some(v)),
Value::Int32(v) => builder.append_field(Some(v)),
Value::Int64(v) => builder.append_field(Some(v)),
Value::Float32(v) => builder.append_field(Some(v.0)),
Value::Float64(v) => builder.append_field(Some(v.0)),
Value::String(v) => builder.append_field(Some(v.as_utf8())),
Value::Binary(v) => builder.append_field(Some(hex::encode(v.deref()))),
Value::Date(v) => builder.append_field(Some(v.to_string())),
Value::DateTime(v) => builder.append_field(Some(v.to_string())),
Value::Timestamp(v) => builder.append_field(Some(v.to_iso8601_string())),
Value::Float32(v) => builder.append_field(Some(&v.0)),
Value::Float64(v) => builder.append_field(Some(&v.0)),
Value::String(v) => builder.append_field(Some(&v.as_utf8())),
Value::Binary(v) => builder.append_field(Some(&hex::encode(v.deref()))),
Value::Date(v) => builder.append_field(Some(&v.to_string())),
Value::DateTime(v) => builder.append_field(Some(&v.to_string())),
Value::Timestamp(v) => builder.append_field(Some(&v.to_iso8601_string())),
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
err_msg: format!(
"cannot write value {:?} in postgres protocol: unimplemented",
@@ -146,7 +165,12 @@ fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandler {
async fn do_query<C>(&self, _client: &mut C, _portal: &Portal) -> PgWireResult<Response>
async fn do_query<C>(
&self,
_client: &mut C,
_portal: &Portal,
_max_rows: usize,
) -> PgWireResult<Response>
where
C: ClientInfo + Unpin + Send + Sync,
{
@@ -269,7 +293,7 @@ mod test {
Value::DateTime(1000001i64.into()),
Value::Timestamp(1000001i64.into()),
];
let mut builder = TextQueryResponseBuilder::new(schema);
let mut builder = TextDataRowEncoder::new(schema.len());
for i in values {
assert!(encode_value(&i, &mut builder).is_ok());
}

View File

@@ -54,15 +54,13 @@ impl PostgresServer {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
io_runtime.spawn(async move {
process_socket(
io_stream,
auth_handler.clone(),
query_handler.clone(),
query_handler.clone(),
)
.await;
});
io_runtime.spawn(process_socket(
io_stream,
None,
auth_handler.clone(),
query_handler.clone(),
query_handler.clone(),
));
}
};
}