feat: Initial support of postgresql wire protocol (#229)

* feat: initial commit of postgres protocol adapter

* initial commit of postgres server

* feat: use common_io runtime and correct testcase

* fix previous tests

* feat: adopt pgwire api changes and add support for text encoded data

* feat: initial integration with datanode

* test: add feature flag to test

* fix: resolve lint warnings

* feat: add postgres feature flags for datanode

* feat: add support for newly introduced timestamp type

* feat: adopt latest datanode changes

* fix: address clippy warning for flattern scenario

* fix: make clippy great again

* fix: address issues found in review

* chore: sort dependencies by name

* feat: adopt new Output api

* fix: return error on unsupported data types

* refactor: extract common code dealing with record batches

* fix: resolve clippy warnings

* test: adds some unit tests postgres handler

* test: correct test for cargo update

* fix: update query module name

* test: add assertion for error content
This commit is contained in:
Ning Sun
2022-09-15 21:39:05 +08:00
committed by GitHub
parent fb6153f7e0
commit e67b0eb259
17 changed files with 1277 additions and 528 deletions

1104
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,10 @@ wal_dir = '/tmp/greptimedb/wal'
mysql_addr = '0.0.0.0:3306'
mysql_runtime_size = 4
# applied when postgres feature enabled
postgres_addr = '0.0.0.0:5432'
postgres_runtime_size = 4
[storage]
type = 'File'
data_dir = '/tmp/greptimedb/data/'

View File

@@ -21,3 +21,7 @@ toml = "0.5"
[dev-dependencies]
serde = "1.0"
tempdir = "0.3"
[features]
default = ["postgres"]
postgres = ["datanode/postgres"]

View File

@@ -39,6 +39,9 @@ struct StartCommand {
rpc_addr: Option<String>,
#[clap(long)]
mysql_addr: Option<String>,
#[cfg(feature = "postgres")]
#[clap(long)]
postgres_addr: Option<String>,
#[clap(short, long)]
config_file: Option<String>,
}
@@ -78,6 +81,10 @@ impl TryFrom<StartCommand> for DatanodeOptions {
if let Some(addr) = cmd.mysql_addr {
opts.mysql_addr = addr;
}
#[cfg(feature = "postgres")]
if let Some(addr) = cmd.postgres_addr {
opts.postgres_addr = addr;
}
Ok(opts)
}
@@ -95,6 +102,8 @@ mod tests {
http_addr: None,
rpc_addr: None,
mysql_addr: None,
#[cfg(feature = "postgres")]
postgres_addr: None,
config_file: Some(format!(
"{}/../../config/datanode.example.toml",
std::env::current_dir().unwrap().as_path().to_str().unwrap()
@@ -106,6 +115,13 @@ mod tests {
assert_eq!("/tmp/greptimedb/wal".to_string(), options.wal_dir);
assert_eq!("0.0.0.0:3306".to_string(), options.mysql_addr);
assert_eq!(4, options.mysql_runtime_size);
#[cfg(feature = "postgres")]
{
assert_eq!("0.0.0.0:5432".to_string(), options.postgres_addr);
assert_eq!(4, options.postgres_runtime_size);
}
match options.storage {
ObjectStoreConfig::File { data_dir } => {
assert_eq!("/tmp/greptimedb/data/".to_string(), data_dir)

View File

@@ -71,6 +71,7 @@ impl RecordBatches {
self.schema.clone()
}
// TODO: a new name that to avoid misunderstanding it as an allocation operation
pub fn to_vec(self) -> Vec<RecordBatch> {
self.batches
}

View File

@@ -4,10 +4,11 @@ version = "0.1.0"
edition = "2021"
[features]
default = ["python"]
default = ["python", "postgres"]
python = [
"dep:script"
]
postgres = ["servers/postgres"]
[dependencies]
api = { path = "../api" }

View File

@@ -26,6 +26,10 @@ pub struct DatanodeOptions {
pub rpc_addr: String,
pub mysql_addr: String,
pub mysql_runtime_size: u32,
#[cfg(feature = "postgres")]
pub postgres_addr: String,
#[cfg(feature = "postgres")]
pub postgres_runtime_size: u32,
pub wal_dir: String,
pub storage: ObjectStoreConfig,
}
@@ -37,6 +41,10 @@ impl Default for DatanodeOptions {
rpc_addr: "0.0.0.0:3001".to_string(),
mysql_addr: "0.0.0.0:3306".to_string(),
mysql_runtime_size: 2,
#[cfg(feature = "postgres")]
postgres_addr: "0.0.0.0:5432".to_string(),
#[cfg(feature = "postgres")]
postgres_runtime_size: 2,
wal_dir: "/tmp/greptimedb/wal".to_string(),
storage: ObjectStoreConfig::default(),
}

View File

@@ -7,6 +7,7 @@ use common_runtime::Builder as RuntimeBuilder;
use servers::grpc::GrpcServer;
use servers::http::HttpServer;
use servers::mysql::server::MysqlServer;
use servers::postgres::PostgresServer;
use servers::server::Server;
use snafu::ResultExt;
use tokio::try_join;
@@ -20,6 +21,8 @@ pub struct Services {
http_server: HttpServer,
grpc_server: GrpcServer,
mysql_server: Box<dyn Server>,
#[cfg(feature = "postgres")]
postgres_server: Box<dyn Server>,
}
impl Services {
@@ -31,34 +34,49 @@ impl Services {
.build()
.context(error::RuntimeResourceSnafu)?,
);
#[cfg(feature = "postgres")]
let postgres_io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(opts.postgres_runtime_size as usize)
.thread_name("postgres-io-handlers")
.build()
.context(error::RuntimeResourceSnafu)?,
);
Ok(Self {
http_server: HttpServer::new(instance.clone()),
grpc_server: GrpcServer::new(instance.clone(), instance.clone()),
mysql_server: MysqlServer::create_server(instance, mysql_io_runtime),
mysql_server: MysqlServer::create_server(instance.clone(), mysql_io_runtime),
#[cfg(feature = "postgres")]
postgres_server: Box::new(PostgresServer::new(instance, postgres_io_runtime)),
})
}
// TODO(LFC): make servers started on demand (not starting mysql if not needed, for example)
pub async fn start(&mut self, opts: &DatanodeOptions) -> Result<()> {
let http_addr = &opts.http_addr;
let http_addr: SocketAddr = http_addr
.parse()
.context(error::ParseAddrSnafu { addr: http_addr })?;
let http_addr: SocketAddr = opts.http_addr.parse().context(error::ParseAddrSnafu {
addr: &opts.http_addr,
})?;
let grpc_addr = &opts.rpc_addr;
let grpc_addr: SocketAddr = grpc_addr
.parse()
.context(error::ParseAddrSnafu { addr: grpc_addr })?;
let grpc_addr: SocketAddr = opts.rpc_addr.parse().context(error::ParseAddrSnafu {
addr: &opts.rpc_addr,
})?;
let mysql_addr = &opts.mysql_addr;
let mysql_addr: SocketAddr = mysql_addr
.parse()
.context(error::ParseAddrSnafu { addr: mysql_addr })?;
let mysql_addr: SocketAddr = opts.mysql_addr.parse().context(error::ParseAddrSnafu {
addr: &opts.mysql_addr,
})?;
#[cfg(feature = "postgres")]
let postgres_addr: SocketAddr =
opts.postgres_addr.parse().context(error::ParseAddrSnafu {
addr: &opts.postgres_addr,
})?;
try_join!(
self.http_server.start(http_addr),
self.grpc_server.start(grpc_addr),
self.mysql_server.start(mysql_addr),
#[cfg(feature = "postgres")]
self.postgres_server.start(postgres_addr),
)
.context(error::StartServerSnafu)?;
Ok(())

View File

@@ -1,3 +1,4 @@
#[cfg(any(test, feature = "test"))]
use crate::data_type::ConcreteDataType;
/// Unique identifier for logical data type.

View File

@@ -268,6 +268,7 @@ impl MutableVector for ListVectorBuilder {
#[cfg(test)]
mod tests {
use arrow::array::{MutableListArray, MutablePrimitiveArray, TryExtend};
use serde_json::json;
use super::*;
use crate::types::ListType;
@@ -426,8 +427,8 @@ mod tests {
let list_vector = ListVector::from(arrow_array);
assert_eq!(
"Ok([Array([Number(1), Number(2), Number(3)]), Null, Array([Number(4), Null, Number(6)])])",
format!("{:?}", list_vector.serialize_to_json())
vec![json!([1, 2, 3]), json!(null), json!([4, null, 6]),],
list_vector.serialize_to_json().unwrap()
);
}

View File

@@ -16,10 +16,12 @@ common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
futures = "0.3"
hex = { version = "0.4", optional = true }
hyper = { version = "0.14", features = ["full"] }
metrics = "0.20"
num_cpus = "1.13"
opensrv-mysql = "0.1"
pgwire = { version = "0.3", optional = true }
query = { path = "../query" }
serde = "1.0"
serde_json = "1.0"
@@ -30,6 +32,10 @@ tonic = "0.8"
tower = { version = "0.4", features = ["full"] }
tower-http = { version = "0.3", features = ["full"] }
[features]
default = ["postgres"]
postgres = ["hex", "pgwire"]
[dev-dependencies]
catalog = { path = "../catalog" }
common-base = { path = "../common/base" }
@@ -37,3 +43,4 @@ mysql_async = "0.30"
rand = "0.8"
script = { path = "../script", features = ["python"] }
test-util = { path = "../../test-util" }
tokio-postgres = "0.7"

View File

@@ -2,5 +2,7 @@ pub mod error;
pub mod grpc;
pub mod http;
pub mod mysql;
#[cfg(feature = "postgres")]
pub mod postgres;
pub mod query_handler;
pub mod server;

View File

@@ -0,0 +1,295 @@
use std::ops::Deref;
use async_trait::async_trait;
use common_query::Output;
use common_recordbatch::{util, RecordBatch};
use common_time::timestamp::TimeUnit;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::SchemaRef;
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{FieldInfo, Response, Tag, TextQueryResponseBuilder};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{PgWireError, PgWireResult};
use crate::error::{self, Error, Result};
use crate::query_handler::SqlQueryHandlerRef;
/// Postgres wire-protocol handler that forwards SQL text to GreptimeDB's
/// query engine and translates the results back into pgwire responses.
pub struct PostgresServerHandler {
    // Shared handle to the SQL engine that actually executes queries.
    query_handler: SqlQueryHandlerRef,
}

impl PostgresServerHandler {
    /// Creates a handler that delegates all SQL execution to `query_handler`.
    pub fn new(query_handler: SqlQueryHandlerRef) -> Self {
        PostgresServerHandler { query_handler }
    }
}
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandler {
    /// Executes a simple (text-protocol) query and converts the output into
    /// a pgwire `Response`:
    /// - `AffectedRows` becomes an execution tag ("OK" plus the row count);
    /// - `Stream` is fully collected, then text-encoded row by row;
    /// - `RecordBatches` is text-encoded directly.
    ///
    /// Every internal error is wrapped in `PgWireError::ApiError` so pgwire
    /// can report it to the client.
    async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Response>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        let output = self
            .query_handler
            .do_query(query)
            .await
            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
        match output {
            Output::AffectedRows(rows) => Ok(Response::Execution(Tag::new_for_execution(
                "OK",
                Some(rows),
            ))),
            Output::Stream(record_stream) => {
                // Capture the schema before the stream is consumed by collect.
                let schema = record_stream.schema();
                let recordbatches = util::collect(record_stream)
                    .await
                    .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
                recordbatches_to_query_response(recordbatches.iter(), schema)
            }
            Output::RecordBatches(recordbatches) => {
                let schema = recordbatches.schema();
                recordbatches_to_query_response(recordbatches.to_vec().iter(), schema)
            }
        }
    }
}
/// Converts an iterator of record batches into a text-encoded pgwire query
/// response, using `schema` to build the Postgres field descriptions.
fn recordbatches_to_query_response<'a, I>(
    recordbatches: I,
    schema: SchemaRef,
) -> PgWireResult<Response>
where
    I: Iterator<Item = &'a RecordBatch>,
{
    // Translate the schema first; a failure here aborts the whole response.
    let field_infos = schema_to_pg(schema).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
    let mut response_builder = TextQueryResponseBuilder::new(field_infos);

    for batch in recordbatches {
        for maybe_row in batch.rows() {
            let current_row = maybe_row.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
            for cell in current_row.into_iter() {
                encode_value(&cell, &mut response_builder)?;
            }
            response_builder.finish_row();
        }
    }

    Ok(Response::Query(response_builder.build()))
}
/// Maps every column in `origin` to a pgwire `FieldInfo`, failing on the
/// first column whose data type has no Postgres representation.
fn schema_to_pg(origin: SchemaRef) -> Result<Vec<FieldInfo>> {
    let column_schemas = origin.column_schemas();
    let mut fields = Vec::with_capacity(column_schemas.len());
    for column in column_schemas.iter() {
        // Table and column OIDs are unknown here, hence the two `None`s.
        fields.push(FieldInfo::new(
            column.name.clone(),
            None,
            None,
            type_translate(&column.data_type)?,
        ));
    }
    Ok(fields)
}
/// Appends a single scalar value to the text response builder.
///
/// Floats are unwrapped from their ordered-float wrappers, binary data is
/// hex-encoded, and timestamps are normalized to millisecond precision.
/// `List` values are not supported and yield an internal error.
fn encode_value(value: &Value, builder: &mut TextQueryResponseBuilder) -> PgWireResult<()> {
    match value {
        // NULL needs a concrete (but otherwise irrelevant) type parameter.
        Value::Null => builder.append_field(None::<i8>),
        Value::Boolean(v) => builder.append_field(Some(v)),
        Value::UInt8(v) => builder.append_field(Some(v)),
        Value::UInt16(v) => builder.append_field(Some(v)),
        Value::UInt32(v) => builder.append_field(Some(v)),
        Value::UInt64(v) => builder.append_field(Some(v)),
        Value::Int8(v) => builder.append_field(Some(v)),
        Value::Int16(v) => builder.append_field(Some(v)),
        Value::Int32(v) => builder.append_field(Some(v)),
        Value::Int64(v) => builder.append_field(Some(v)),
        // `.0` unwraps the ordered-float wrapper.
        Value::Float32(v) => builder.append_field(Some(v.0)),
        Value::Float64(v) => builder.append_field(Some(v.0)),
        Value::String(v) => builder.append_field(Some(v.as_utf8())),
        // Binary payloads are sent as a hex string in text encoding.
        Value::Binary(v) => builder.append_field(Some(hex::encode(v.deref()))),
        Value::Date(v) => builder.append_field(Some(v.val())),
        Value::DateTime(v) => builder.append_field(Some(v.val())),
        Value::Timestamp(v) => builder.append_field(Some(v.convert_to(TimeUnit::Millisecond))),
        Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
            err_msg: format!(
                "cannot write value {:?} in postgres protocol: unimplemented",
                &value
            ),
        }))),
    }
}
fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
match origin {
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
&ConcreteDataType::Date(_) => Ok(Type::DATE),
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
&ConcreteDataType::List(_) => error::InternalSnafu {
err_msg: format!("not implemented for column datatype {:?}", origin),
}
.fail(),
}
}
#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandler {
    /// The extended (prepared-statement) protocol is not supported yet; any
    /// client issuing a parse/bind/execute flow will hit this panic.
    async fn do_query<C>(&self, _client: &mut C, _portal: &Portal) -> PgWireResult<Response>
    where
        C: ClientInfo + Unpin + Send + Sync,
    {
        unimplemented!()
    }
}
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::value::ListValue;
    use pgwire::api::results::FieldInfo;
    use pgwire::api::Type;

    use super::*;

    /// Verifies `schema_to_pg` maps every supported column type onto the
    /// expected Postgres wire type (unsigned ints share the signed type of
    /// the same width; DateTime and Timestamp both become TIMESTAMP).
    #[test]
    fn test_schema_convert() {
        let column_schemas = vec![
            ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
            ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
            ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
            ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
            ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
            ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
            ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
            ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
            ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
            ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
            ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
            ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
            ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
            ColumnSchema::new(
                "timestamps",
                ConcreteDataType::timestamp_millis_datatype(),
                true,
            ),
            ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
        ];
        // Expected Postgres field metadata, in the same column order.
        let pg_field_info = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN),
            FieldInfo::new("bools".into(), None, None, Type::BOOL),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR),
            FieldInfo::new("int16s".into(), None, None, Type::INT2),
            FieldInfo::new("int32s".into(), None, None, Type::INT4),
            FieldInfo::new("int64s".into(), None, None, Type::INT8),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8),
            FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
            FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
            FieldInfo::new("binaries".into(), None, None, Type::BYTEA),
            FieldInfo::new("strings".into(), None, None, Type::VARCHAR),
            FieldInfo::new("timestamps".into(), None, None, Type::TIMESTAMP),
            FieldInfo::new("dates".into(), None, None, Type::DATE),
        ];
        let schema = Arc::new(Schema::new(column_schemas));
        let fs = schema_to_pg(schema).unwrap();
        assert_eq!(fs, pg_field_info);
    }

    /// Exercises `encode_value` across all scalar variants (extremes, zero,
    /// and representative payloads) and checks that a `List` value is
    /// rejected with an internal error surfaced as `PgWireError::ApiError`.
    #[test]
    fn test_encode_text_format_data() {
        let schema = vec![
            FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN),
            FieldInfo::new("bools".into(), None, None, Type::BOOL),
            FieldInfo::new("uint8s".into(), None, None, Type::CHAR),
            FieldInfo::new("uint16s".into(), None, None, Type::INT2),
            FieldInfo::new("uint32s".into(), None, None, Type::INT4),
            FieldInfo::new("uint64s".into(), None, None, Type::INT8),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR),
            FieldInfo::new("int8s".into(), None, None, Type::CHAR),
            FieldInfo::new("int16s".into(), None, None, Type::INT2),
            FieldInfo::new("int16s".into(), None, None, Type::INT2),
            FieldInfo::new("int32s".into(), None, None, Type::INT4),
            FieldInfo::new("int32s".into(), None, None, Type::INT4),
            FieldInfo::new("int64s".into(), None, None, Type::INT8),
            FieldInfo::new("int64s".into(), None, None, Type::INT8),
            FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
            FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
            FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
            FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
            FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
            FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
            FieldInfo::new("strings".into(), None, None, Type::VARCHAR),
            FieldInfo::new("binaries".into(), None, None, Type::BYTEA),
            FieldInfo::new("dates".into(), None, None, Type::DATE),
            FieldInfo::new("datetimes".into(), None, None, Type::TIMESTAMP),
            FieldInfo::new("timestamps".into(), None, None, Type::TIMESTAMP),
        ];
        // One value per field above, in the same order.
        let values = vec![
            Value::Null,
            Value::Boolean(true),
            Value::UInt8(u8::MAX),
            Value::UInt16(u16::MAX),
            Value::UInt32(u32::MAX),
            Value::UInt64(u64::MAX),
            Value::Int8(i8::MAX),
            Value::Int8(i8::MIN),
            Value::Int16(i16::MAX),
            Value::Int16(i16::MIN),
            Value::Int32(i32::MAX),
            Value::Int32(i32::MIN),
            Value::Int64(i64::MAX),
            Value::Int64(i64::MIN),
            Value::Float32(f32::MAX.into()),
            Value::Float32(f32::MIN.into()),
            Value::Float32(0f32.into()),
            Value::Float64(f64::MAX.into()),
            Value::Float64(f64::MIN.into()),
            Value::Float64(0f64.into()),
            Value::String("greptime".into()),
            Value::Binary("greptime".as_bytes().into()),
            Value::Date(1001i32.into()),
            Value::DateTime(1000001i64.into()),
            Value::Timestamp(1000001i64.into()),
        ];
        let mut builder = TextQueryResponseBuilder::new(schema);
        for i in values {
            assert!(encode_value(&i, &mut builder).is_ok());
        }

        // List values are unsupported and must surface an internal error.
        let err = encode_value(
            &Value::List(ListValue::new(
                Some(Box::new(vec![])),
                ConcreteDataType::int8_datatype(),
            )),
            &mut builder,
        )
        .unwrap_err();
        match err {
            PgWireError::ApiError(e) => {
                assert!(format!("{}", e).contains("Internal error:"));
            }
            _ => {
                unreachable!()
            }
        }
    }
}

View File

@@ -0,0 +1,4 @@
// Postgres wire protocol support: query/response translation lives in
// `handler`, TCP accept/session management in `server`.
mod handler;
mod server;

// Only the server entry point is public; handler details stay private.
pub use server::PostgresServer;

View File

@@ -0,0 +1,136 @@
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use async_trait::async_trait;
use common_runtime::Runtime;
use common_telemetry::logging::{error, info};
use futures::future::AbortHandle;
use futures::future::AbortRegistration;
use futures::future::Abortable;
use futures::StreamExt;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::tokio::process_socket;
use snafu::prelude::*;
use tokio;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::TcpListenerStream;
use crate::error::{self, Result};
use crate::postgres::handler::PostgresServerHandler;
use crate::query_handler::SqlQueryHandlerRef;
use crate::server::Server;
/// TCP server speaking the Postgres wire protocol.
pub struct PostgresServer {
    // See MySQL module for usage of these types.
    abort_handle: AbortHandle,
    // Consumed by `start`; `None` afterwards means the server is running.
    abort_registration: Option<AbortRegistration>,
    // A handle holding the TCP accepting task.
    join_handle: Option<JoinHandle<()>>,
    // Startup handler that performs no authentication.
    auth_handler: Arc<NoopStartupHandler>,
    // Per-connection query handler shared across accepted sockets.
    query_handler: Arc<PostgresServerHandler>,
    // Runtime on which per-connection I/O tasks are spawned.
    io_runtime: Arc<Runtime>,
}
impl PostgresServer {
    /// Creates a new Postgres server with provided query_handler and async runtime.
    pub fn new(query_handler: SqlQueryHandlerRef, io_runtime: Arc<Runtime>) -> PostgresServer {
        let (abort_handle, registration) = AbortHandle::new_pair();
        let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
        let startup_handler = Arc::new(NoopStartupHandler);
        PostgresServer {
            abort_handle,
            abort_registration: Some(registration),
            join_handle: None,
            auth_handler: startup_handler,
            query_handler: postgres_handler,
            io_runtime,
        }
    }

    /// Binds a TCP listener on `addr` and returns the accept stream together
    /// with the address actually bound.
    async fn bind(addr: SocketAddr) -> Result<(TcpListenerStream, SocketAddr)> {
        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .context(error::TokioIoSnafu {
                err_msg: format!("Failed to bind addr {}", addr),
            })?;
        // Get the actual bound address in case the input addr used port 0.
        let addr = listener.local_addr()?;
        info!("Postgres server is bound to {}", addr);
        Ok((TcpListenerStream::new(listener), addr))
    }

    /// Returns a future that accepts connections from `accepting_stream`
    /// until aborted, spawning one pgwire session task per socket onto the
    /// IO runtime.
    fn accept(&self, accepting_stream: Abortable<TcpListenerStream>) -> impl Future<Output = ()> {
        // Clone the shared handles once so the `for_each` closure owns them.
        let io_runtime = self.io_runtime.clone();
        let auth_handler = self.auth_handler.clone();
        let query_handler = self.query_handler.clone();
        accepting_stream.for_each(move |tcp_stream| {
            // Per-connection clones, moved into the spawned task below.
            let io_runtime = io_runtime.clone();
            let auth_handler = auth_handler.clone();
            let query_handler = query_handler.clone();
            async move {
                match tcp_stream {
                    Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
                    Ok(io_stream) => {
                        io_runtime.spawn(async move {
                            // The handler serves both the simple and the
                            // extended query protocol, hence passed twice.
                            process_socket(
                                io_stream,
                                auth_handler.clone(),
                                query_handler.clone(),
                                query_handler.clone(),
                            )
                            .await;
                        });
                    }
                };
            }
        })
    }
}
#[async_trait]
impl Server for PostgresServer {
    /// Stops accepting new connections and waits for the accept task to
    /// finish. Fails if the server was never started.
    async fn shutdown(&mut self) -> Result<()> {
        match self.join_handle.take() {
            Some(join_handle) => {
                // Abort the accept loop; in-flight connection tasks on the
                // IO runtime are not awaited here.
                self.abort_handle.abort();
                if let Err(error) = join_handle.await {
                    // Couldn't use `error!(e; xxx)` as JoinError doesn't implement ErrorExt.
                    error!(
                        "Unexpected error during shutdown Postgres server, error: {}",
                        error
                    );
                } else {
                    info!("Postgres server is shutdown.")
                }
                Ok(())
            }
            None => error::InternalSnafu {
                err_msg: "Postgres server is not started.",
            }
            .fail()?,
        }
    }

    /// Binds `listening` and spawns the accept loop, returning the address
    /// actually bound. Can only be called once: the second call fails
    /// because the abort registration has already been consumed.
    async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
        match self.abort_registration.take() {
            Some(registration) => {
                let (stream, listener) = Self::bind(listening).await?;
                let stream = Abortable::new(stream, registration);
                self.join_handle = Some(tokio::spawn(self.accept(stream)));
                Ok(listener)
            }
            None => error::InternalSnafu {
                err_msg: "Postgres server has been started.",
            }
            .fail()?,
        }
    }
}

View File

@@ -18,6 +18,8 @@ use script::{
engine::{CompileContext, EvalContext, Script, ScriptEngine},
python::{PyEngine, PyScript},
};
#[cfg(feature = "postgres")]
mod postgres;
struct DummyInstance {
query_engine: QueryEngineRef,

View File

@@ -0,0 +1,169 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use common_runtime::Builder as RuntimeBuilder;
use rand::rngs::StdRng;
use rand::Rng;
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
use test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::create_testing_sql_query_handler;
/// Builds a Postgres server over a fresh 4-thread IO runtime, backed by the
/// testing SQL query handler for `table`.
fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
    let runtime = RuntimeBuilder::default()
        .worker_threads(4)
        .thread_name("postgres-io-handlers")
        .build()
        .unwrap();
    let handler = create_testing_sql_query_handler(table);
    Ok(Box::new(PostgresServer::new(handler, Arc::new(runtime))))
}
/// Starting twice must fail: `start` consumes the abort registration.
#[tokio::test]
pub async fn test_start_postgres_server() -> Result<()> {
    let table = MemTable::default_numbers_table();
    let mut pg_server = create_postgres_server(table)?;
    let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
    // First start succeeds (port 0 picks an ephemeral port).
    let result = pg_server.start(listening).await;
    assert!(result.is_ok());
    // Second start must be rejected.
    let result = pg_server.start(listening).await;
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("Postgres server has been started."));
    Ok(())
}
/// Shuts the server down while client tasks keep querying: pre-start
/// shutdown must fail, and after shutdown every client loop must end with
/// a connection error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shutdown_pg_server() -> Result<()> {
    common_telemetry::init_default_ut_logging();
    let table = MemTable::default_numbers_table();
    let mut postgres_server = create_postgres_server(table)?;
    // Shutting down before start is an error.
    let result = postgres_server.shutdown().await;
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("Postgres server is not started."));
    let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
    let server_addr = postgres_server.start(listening).await.unwrap();
    let server_port = server_addr.port();
    // Spawn client tasks that reconnect and query in a loop until the
    // server goes away.
    let mut join_handles = vec![];
    for _ in 0..2 {
        join_handles.push(tokio::spawn(async move {
            for _ in 0..1000 {
                match create_connection(server_port).await {
                    Ok(connection) => {
                        let rows = connection
                            .simple_query("SELECT uint32s FROM numbers LIMIT 1")
                            .await
                            .unwrap();
                        let result_text = unwrap_results(&rows)[0];
                        let result: i32 = result_text.parse().unwrap();
                        assert_eq!(result, 0);
                        tokio::time::sleep(Duration::from_millis(10)).await;
                    }
                    Err(e) => {
                        return Err(e);
                    }
                }
            }
            Ok(())
        }))
    }
    // Let the clients run briefly, then stop accepting.
    tokio::time::sleep(Duration::from_millis(100)).await;
    let result = postgres_server.shutdown().await;
    assert!(result.is_ok());
    // Every client loop must eventually fail with a connection-level error.
    for handle in join_handles.iter_mut() {
        let result = handle.await.unwrap();
        assert!(result.is_err());
        let error = result.unwrap_err().to_string();
        assert!(error.contains("Connection refused") || error.contains("Connection reset by peer"));
    }
    Ok(())
}
/// Runs several client tasks issuing random filtered queries concurrently,
/// occasionally reconnecting, and verifies every query returned the
/// expected value.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_query_pg_concurrently() -> Result<()> {
    common_telemetry::init_default_ut_logging();
    let table = MemTable::default_numbers_table();
    let mut pg_server = create_postgres_server(table)?;
    let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
    let server_addr = pg_server.start(listening).await.unwrap();
    let server_port = server_addr.port();

    let threads = 4;
    let expect_executed_queries_per_worker = 300;
    let mut join_handles = vec![];
    for _i in 0..threads {
        join_handles.push(tokio::spawn(async move {
            let mut rand: StdRng = rand::SeedableRng::from_entropy();
            let mut client = create_connection(server_port).await.unwrap();
            for _k in 0..expect_executed_queries_per_worker {
                // Query for a random value; the numbers table contains it,
                // so the first result column must echo it back.
                let expected: u32 = rand.gen_range(0..100);
                let result: u32 = unwrap_results(
                    client
                        .simple_query(&format!(
                            "SELECT uint32s FROM numbers WHERE uint32s = {}",
                            expected
                        ))
                        .await
                        .unwrap()
                        .as_ref(),
                )[0]
                .parse()
                .unwrap();
                assert_eq!(result, expected);
                // 1/100 chance to reconnect
                let should_recreate_conn = expected == 1;
                if should_recreate_conn {
                    client = create_connection(server_port).await.unwrap();
                }
            }
            expect_executed_queries_per_worker
        }))
    }
    // Each worker returns its query count; the sum must match exactly.
    let mut total_pending_queries = threads * expect_executed_queries_per_worker;
    for handle in join_handles.iter_mut() {
        total_pending_queries -= handle.await.unwrap();
    }
    assert_eq!(0, total_pending_queries);
    Ok(())
}
/// Opens a tokio-postgres client to the local test server on `port`,
/// driving its connection task in the background.
async fn create_connection(port: u16) -> std::result::Result<Client, PgError> {
    let conn_str = format!("host=127.0.0.1 port={} connect_timeout=2", port);
    let (client, connection) = tokio_postgres::connect(&conn_str, NoTls).await?;
    // The connection future must be polled for the client to make progress;
    // run it on its own task.
    tokio::spawn(connection);
    Ok(client)
}
/// Extracts column `col_index` from a simple-query row message, or `None`
/// for non-row messages (e.g. command completion).
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
    // Match ergonomics: no need for the legacy `&...(ref r)` pattern when
    // matching on a reference (clippy::match_ref_pats).
    match resp {
        SimpleQueryMessage::Row(row) => row.get(col_index),
        _ => None,
    }
}
/// Collects the first column of every row message in `resp`, skipping
/// non-row messages.
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
    let mut values = Vec::new();
    for message in resp {
        if let Some(text) = resolve_result(message, 0) {
            values.push(text);
        }
    }
    values
}