mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-17 05:20:37 +00:00
feat: Initial support of postgresql wire protocol (#229)
* feat: initial commit of postgres protocol adapter * initial commit of postgres server * feat: use common_io runtime and correct testcase * fix previous tests * feat: adopt pgwire api changes and add support for text encoded data * feat: initial integration with datanode * test: add feature flag to test * fix: resolve lint warnings * feat: add postgres feature flags for datanode * feat: add support for newly introduced timestamp type * feat: adopt latest datanode changes * fix: address clippy warning for flattern scenario * fix: make clippy great again * fix: address issues found in review * chore: sort dependencies by name * feat: adopt new Output api * fix: return error on unsupported data types * refactor: extract common code dealing with record batches * fix: resolve clippy warnings * test: adds some unit tests postgres handler * test: correct test for cargo update * fix: update query module name * test: add assertion for error content
This commit is contained in:
1104
Cargo.lock
generated
1104
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -5,6 +5,10 @@ wal_dir = '/tmp/greptimedb/wal'
|
||||
mysql_addr = '0.0.0.0:3306'
|
||||
mysql_runtime_size = 4
|
||||
|
||||
# applied when postgres feature enbaled
|
||||
postgres_addr = '0.0.0.0:5432'
|
||||
postgres_runtime_size = 4
|
||||
|
||||
[storage]
|
||||
type = 'File'
|
||||
data_dir = '/tmp/greptimedb/data/'
|
||||
|
||||
@@ -21,3 +21,7 @@ toml = "0.5"
|
||||
[dev-dependencies]
|
||||
serde = "1.0"
|
||||
tempdir = "0.3"
|
||||
|
||||
[features]
|
||||
default = ["postgres"]
|
||||
postgres = ["datanode/postgres"]
|
||||
|
||||
@@ -39,6 +39,9 @@ struct StartCommand {
|
||||
rpc_addr: Option<String>,
|
||||
#[clap(long)]
|
||||
mysql_addr: Option<String>,
|
||||
#[cfg(feature = "postgres")]
|
||||
#[clap(long)]
|
||||
postgres_addr: Option<String>,
|
||||
#[clap(short, long)]
|
||||
config_file: Option<String>,
|
||||
}
|
||||
@@ -78,6 +81,10 @@ impl TryFrom<StartCommand> for DatanodeOptions {
|
||||
if let Some(addr) = cmd.mysql_addr {
|
||||
opts.mysql_addr = addr;
|
||||
}
|
||||
#[cfg(feature = "postgres")]
|
||||
if let Some(addr) = cmd.postgres_addr {
|
||||
opts.postgres_addr = addr;
|
||||
}
|
||||
|
||||
Ok(opts)
|
||||
}
|
||||
@@ -95,6 +102,8 @@ mod tests {
|
||||
http_addr: None,
|
||||
rpc_addr: None,
|
||||
mysql_addr: None,
|
||||
#[cfg(feature = "postgres")]
|
||||
postgres_addr: None,
|
||||
config_file: Some(format!(
|
||||
"{}/../../config/datanode.example.toml",
|
||||
std::env::current_dir().unwrap().as_path().to_str().unwrap()
|
||||
@@ -106,6 +115,13 @@ mod tests {
|
||||
assert_eq!("/tmp/greptimedb/wal".to_string(), options.wal_dir);
|
||||
assert_eq!("0.0.0.0:3306".to_string(), options.mysql_addr);
|
||||
assert_eq!(4, options.mysql_runtime_size);
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
{
|
||||
assert_eq!("0.0.0.0:5432".to_string(), options.postgres_addr);
|
||||
assert_eq!(4, options.postgres_runtime_size);
|
||||
}
|
||||
|
||||
match options.storage {
|
||||
ObjectStoreConfig::File { data_dir } => {
|
||||
assert_eq!("/tmp/greptimedb/data/".to_string(), data_dir)
|
||||
|
||||
@@ -71,6 +71,7 @@ impl RecordBatches {
|
||||
self.schema.clone()
|
||||
}
|
||||
|
||||
// TODO: a new name that to avoid misunderstanding it as an allocation operation
|
||||
pub fn to_vec(self) -> Vec<RecordBatch> {
|
||||
self.batches
|
||||
}
|
||||
|
||||
@@ -4,10 +4,11 @@ version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["python"]
|
||||
default = ["python", "postgres"]
|
||||
python = [
|
||||
"dep:script"
|
||||
]
|
||||
postgres = ["servers/postgres"]
|
||||
|
||||
[dependencies]
|
||||
api = { path = "../api" }
|
||||
|
||||
@@ -26,6 +26,10 @@ pub struct DatanodeOptions {
|
||||
pub rpc_addr: String,
|
||||
pub mysql_addr: String,
|
||||
pub mysql_runtime_size: u32,
|
||||
#[cfg(feature = "postgres")]
|
||||
pub postgres_addr: String,
|
||||
#[cfg(feature = "postgres")]
|
||||
pub postgres_runtime_size: u32,
|
||||
pub wal_dir: String,
|
||||
pub storage: ObjectStoreConfig,
|
||||
}
|
||||
@@ -37,6 +41,10 @@ impl Default for DatanodeOptions {
|
||||
rpc_addr: "0.0.0.0:3001".to_string(),
|
||||
mysql_addr: "0.0.0.0:3306".to_string(),
|
||||
mysql_runtime_size: 2,
|
||||
#[cfg(feature = "postgres")]
|
||||
postgres_addr: "0.0.0.0:5432".to_string(),
|
||||
#[cfg(feature = "postgres")]
|
||||
postgres_runtime_size: 2,
|
||||
wal_dir: "/tmp/greptimedb/wal".to_string(),
|
||||
storage: ObjectStoreConfig::default(),
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use common_runtime::Builder as RuntimeBuilder;
|
||||
use servers::grpc::GrpcServer;
|
||||
use servers::http::HttpServer;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
use snafu::ResultExt;
|
||||
use tokio::try_join;
|
||||
@@ -20,6 +21,8 @@ pub struct Services {
|
||||
http_server: HttpServer,
|
||||
grpc_server: GrpcServer,
|
||||
mysql_server: Box<dyn Server>,
|
||||
#[cfg(feature = "postgres")]
|
||||
postgres_server: Box<dyn Server>,
|
||||
}
|
||||
|
||||
impl Services {
|
||||
@@ -31,34 +34,49 @@ impl Services {
|
||||
.build()
|
||||
.context(error::RuntimeResourceSnafu)?,
|
||||
);
|
||||
#[cfg(feature = "postgres")]
|
||||
let postgres_io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
.worker_threads(opts.postgres_runtime_size as usize)
|
||||
.thread_name("postgres-io-handlers")
|
||||
.build()
|
||||
.context(error::RuntimeResourceSnafu)?,
|
||||
);
|
||||
Ok(Self {
|
||||
http_server: HttpServer::new(instance.clone()),
|
||||
grpc_server: GrpcServer::new(instance.clone(), instance.clone()),
|
||||
mysql_server: MysqlServer::create_server(instance, mysql_io_runtime),
|
||||
mysql_server: MysqlServer::create_server(instance.clone(), mysql_io_runtime),
|
||||
#[cfg(feature = "postgres")]
|
||||
postgres_server: Box::new(PostgresServer::new(instance, postgres_io_runtime)),
|
||||
})
|
||||
}
|
||||
|
||||
// TODO(LFC): make servers started on demand (not starting mysql if no needed, for example)
|
||||
pub async fn start(&mut self, opts: &DatanodeOptions) -> Result<()> {
|
||||
let http_addr = &opts.http_addr;
|
||||
let http_addr: SocketAddr = http_addr
|
||||
.parse()
|
||||
.context(error::ParseAddrSnafu { addr: http_addr })?;
|
||||
let http_addr: SocketAddr = opts.http_addr.parse().context(error::ParseAddrSnafu {
|
||||
addr: &opts.http_addr,
|
||||
})?;
|
||||
|
||||
let grpc_addr = &opts.rpc_addr;
|
||||
let grpc_addr: SocketAddr = grpc_addr
|
||||
.parse()
|
||||
.context(error::ParseAddrSnafu { addr: grpc_addr })?;
|
||||
let grpc_addr: SocketAddr = opts.rpc_addr.parse().context(error::ParseAddrSnafu {
|
||||
addr: &opts.rpc_addr,
|
||||
})?;
|
||||
|
||||
let mysql_addr = &opts.mysql_addr;
|
||||
let mysql_addr: SocketAddr = mysql_addr
|
||||
.parse()
|
||||
.context(error::ParseAddrSnafu { addr: mysql_addr })?;
|
||||
let mysql_addr: SocketAddr = opts.mysql_addr.parse().context(error::ParseAddrSnafu {
|
||||
addr: &opts.mysql_addr,
|
||||
})?;
|
||||
|
||||
#[cfg(feature = "postgres")]
|
||||
let postgres_addr: SocketAddr =
|
||||
opts.postgres_addr.parse().context(error::ParseAddrSnafu {
|
||||
addr: &opts.postgres_addr,
|
||||
})?;
|
||||
|
||||
try_join!(
|
||||
self.http_server.start(http_addr),
|
||||
self.grpc_server.start(grpc_addr),
|
||||
self.mysql_server.start(mysql_addr),
|
||||
#[cfg(feature = "postgres")]
|
||||
self.postgres_server.start(postgres_addr),
|
||||
)
|
||||
.context(error::StartServerSnafu)?;
|
||||
Ok(())
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#[cfg(any(test, feature = "test"))]
|
||||
use crate::data_type::ConcreteDataType;
|
||||
|
||||
/// Unique identifier for logical data type.
|
||||
|
||||
@@ -268,6 +268,7 @@ impl MutableVector for ListVectorBuilder {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use arrow::array::{MutableListArray, MutablePrimitiveArray, TryExtend};
|
||||
use serde_json::json;
|
||||
|
||||
use super::*;
|
||||
use crate::types::ListType;
|
||||
@@ -426,8 +427,8 @@ mod tests {
|
||||
|
||||
let list_vector = ListVector::from(arrow_array);
|
||||
assert_eq!(
|
||||
"Ok([Array([Number(1), Number(2), Number(3)]), Null, Array([Number(4), Null, Number(6)])])",
|
||||
format!("{:?}", list_vector.serialize_to_json())
|
||||
vec![json!([1, 2, 3]), json!(null), json!([4, null, 6]),],
|
||||
list_vector.serialize_to_json().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -16,10 +16,12 @@ common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4", optional = true }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
metrics = "0.20"
|
||||
num_cpus = "1.13"
|
||||
opensrv-mysql = "0.1"
|
||||
pgwire = { version = "0.3", optional = true }
|
||||
query = { path = "../query" }
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
@@ -30,6 +32,10 @@ tonic = "0.8"
|
||||
tower = { version = "0.4", features = ["full"] }
|
||||
tower-http = { version = "0.3", features = ["full"] }
|
||||
|
||||
[features]
|
||||
default = ["postgres"]
|
||||
postgres = ["hex", "pgwire"]
|
||||
|
||||
[dev-dependencies]
|
||||
catalog = { path = "../catalog" }
|
||||
common-base = { path = "../common/base" }
|
||||
@@ -37,3 +43,4 @@ mysql_async = "0.30"
|
||||
rand = "0.8"
|
||||
script = { path = "../script", features = ["python"] }
|
||||
test-util = { path = "../../test-util" }
|
||||
tokio-postgres = "0.7"
|
||||
|
||||
@@ -2,5 +2,7 @@ pub mod error;
|
||||
pub mod grpc;
|
||||
pub mod http;
|
||||
pub mod mysql;
|
||||
#[cfg(feature = "postgres")]
|
||||
pub mod postgres;
|
||||
pub mod query_handler;
|
||||
pub mod server;
|
||||
|
||||
295
src/servers/src/postgres/handler.rs
Normal file
295
src/servers/src/postgres/handler.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use common_time::timestamp::TimeUnit;
|
||||
use datatypes::prelude::{ConcreteDataType, Value};
|
||||
use datatypes::schema::SchemaRef;
|
||||
use pgwire::api::portal::Portal;
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
|
||||
use pgwire::api::results::{FieldInfo, Response, Tag, TextQueryResponseBuilder};
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
|
||||
use crate::error::{self, Error, Result};
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
|
||||
pub struct PostgresServerHandler {
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
}
|
||||
|
||||
impl PostgresServerHandler {
|
||||
pub fn new(query_handler: SqlQueryHandlerRef) -> Self {
|
||||
PostgresServerHandler { query_handler }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Response>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
let output = self
|
||||
.query_handler
|
||||
.do_query(query)
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
match output {
|
||||
Output::AffectedRows(rows) => Ok(Response::Execution(Tag::new_for_execution(
|
||||
"OK",
|
||||
Some(rows),
|
||||
))),
|
||||
Output::Stream(record_stream) => {
|
||||
let schema = record_stream.schema();
|
||||
let recordbatches = util::collect(record_stream)
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
recordbatches_to_query_response(recordbatches.iter(), schema)
|
||||
}
|
||||
Output::RecordBatches(recordbatches) => {
|
||||
let schema = recordbatches.schema();
|
||||
recordbatches_to_query_response(recordbatches.to_vec().iter(), schema)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn recordbatches_to_query_response<'a, I>(
|
||||
recordbatches: I,
|
||||
schema: SchemaRef,
|
||||
) -> PgWireResult<Response>
|
||||
where
|
||||
I: Iterator<Item = &'a RecordBatch>,
|
||||
{
|
||||
let pg_schema = schema_to_pg(schema).map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
let mut builder = TextQueryResponseBuilder::new(pg_schema);
|
||||
|
||||
for recordbatch in recordbatches {
|
||||
for row in recordbatch.rows() {
|
||||
let row = row.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
for value in row.into_iter() {
|
||||
encode_value(&value, &mut builder)?;
|
||||
}
|
||||
builder.finish_row();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Response::Query(builder.build()))
|
||||
}
|
||||
|
||||
fn schema_to_pg(origin: SchemaRef) -> Result<Vec<FieldInfo>> {
|
||||
origin
|
||||
.column_schemas()
|
||||
.iter()
|
||||
.map(|col| {
|
||||
Ok(FieldInfo::new(
|
||||
col.name.clone(),
|
||||
None,
|
||||
None,
|
||||
type_translate(&col.data_type)?,
|
||||
))
|
||||
})
|
||||
.collect::<Result<Vec<FieldInfo>>>()
|
||||
}
|
||||
|
||||
fn encode_value(value: &Value, builder: &mut TextQueryResponseBuilder) -> PgWireResult<()> {
|
||||
match value {
|
||||
Value::Null => builder.append_field(None::<i8>),
|
||||
Value::Boolean(v) => builder.append_field(Some(v)),
|
||||
Value::UInt8(v) => builder.append_field(Some(v)),
|
||||
Value::UInt16(v) => builder.append_field(Some(v)),
|
||||
Value::UInt32(v) => builder.append_field(Some(v)),
|
||||
Value::UInt64(v) => builder.append_field(Some(v)),
|
||||
Value::Int8(v) => builder.append_field(Some(v)),
|
||||
Value::Int16(v) => builder.append_field(Some(v)),
|
||||
Value::Int32(v) => builder.append_field(Some(v)),
|
||||
Value::Int64(v) => builder.append_field(Some(v)),
|
||||
Value::Float32(v) => builder.append_field(Some(v.0)),
|
||||
Value::Float64(v) => builder.append_field(Some(v.0)),
|
||||
Value::String(v) => builder.append_field(Some(v.as_utf8())),
|
||||
Value::Binary(v) => builder.append_field(Some(hex::encode(v.deref()))),
|
||||
Value::Date(v) => builder.append_field(Some(v.val())),
|
||||
Value::DateTime(v) => builder.append_field(Some(v.val())),
|
||||
Value::Timestamp(v) => builder.append_field(Some(v.convert_to(TimeUnit::Millisecond))),
|
||||
Value::List(_) => Err(PgWireError::ApiError(Box::new(Error::Internal {
|
||||
err_msg: format!(
|
||||
"cannot write value {:?} in postgres protocol: unimplemented",
|
||||
&value
|
||||
),
|
||||
}))),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
|
||||
match origin {
|
||||
&ConcreteDataType::Null(_) => Ok(Type::UNKNOWN),
|
||||
&ConcreteDataType::Boolean(_) => Ok(Type::BOOL),
|
||||
&ConcreteDataType::Int8(_) | &ConcreteDataType::UInt8(_) => Ok(Type::CHAR),
|
||||
&ConcreteDataType::Int16(_) | &ConcreteDataType::UInt16(_) => Ok(Type::INT2),
|
||||
&ConcreteDataType::Int32(_) | &ConcreteDataType::UInt32(_) => Ok(Type::INT4),
|
||||
&ConcreteDataType::Int64(_) | &ConcreteDataType::UInt64(_) => Ok(Type::INT8),
|
||||
&ConcreteDataType::Float32(_) => Ok(Type::FLOAT4),
|
||||
&ConcreteDataType::Float64(_) => Ok(Type::FLOAT8),
|
||||
&ConcreteDataType::Binary(_) => Ok(Type::BYTEA),
|
||||
&ConcreteDataType::String(_) => Ok(Type::VARCHAR),
|
||||
&ConcreteDataType::Date(_) => Ok(Type::DATE),
|
||||
&ConcreteDataType::DateTime(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::Timestamp(_) => Ok(Type::TIMESTAMP),
|
||||
&ConcreteDataType::List(_) => error::InternalSnafu {
|
||||
err_msg: format!("not implemented for column datatype {:?}", origin),
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, _client: &mut C, _portal: &Portal) -> PgWireResult<Response>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::value::ListValue;
|
||||
use pgwire::api::results::FieldInfo;
|
||||
use pgwire::api::Type;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_schema_convert() {
|
||||
let column_schemas = vec![
|
||||
ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true),
|
||||
ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true),
|
||||
ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true),
|
||||
ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true),
|
||||
ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true),
|
||||
ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true),
|
||||
ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true),
|
||||
ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true),
|
||||
ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true),
|
||||
ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true),
|
||||
ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true),
|
||||
ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true),
|
||||
ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true),
|
||||
ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
|
||||
ColumnSchema::new(
|
||||
"timestamps",
|
||||
ConcreteDataType::timestamp_millis_datatype(),
|
||||
true,
|
||||
),
|
||||
ColumnSchema::new("dates", ConcreteDataType::date_datatype(), true),
|
||||
];
|
||||
let pg_field_info = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8),
|
||||
FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
|
||||
FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
|
||||
FieldInfo::new("binaries".into(), None, None, Type::BYTEA),
|
||||
FieldInfo::new("strings".into(), None, None, Type::VARCHAR),
|
||||
FieldInfo::new("timestamps".into(), None, None, Type::TIMESTAMP),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE),
|
||||
];
|
||||
let schema = Arc::new(Schema::new(column_schemas));
|
||||
let fs = schema_to_pg(schema).unwrap();
|
||||
assert_eq!(fs, pg_field_info);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_text_format_data() {
|
||||
let schema = vec![
|
||||
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN),
|
||||
FieldInfo::new("bools".into(), None, None, Type::BOOL),
|
||||
FieldInfo::new("uint8s".into(), None, None, Type::CHAR),
|
||||
FieldInfo::new("uint16s".into(), None, None, Type::INT2),
|
||||
FieldInfo::new("uint32s".into(), None, None, Type::INT4),
|
||||
FieldInfo::new("uint64s".into(), None, None, Type::INT8),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR),
|
||||
FieldInfo::new("int8s".into(), None, None, Type::CHAR),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2),
|
||||
FieldInfo::new("int16s".into(), None, None, Type::INT2),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4),
|
||||
FieldInfo::new("int32s".into(), None, None, Type::INT4),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8),
|
||||
FieldInfo::new("int64s".into(), None, None, Type::INT8),
|
||||
FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
|
||||
FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
|
||||
FieldInfo::new("float32s".into(), None, None, Type::FLOAT4),
|
||||
FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
|
||||
FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
|
||||
FieldInfo::new("float64s".into(), None, None, Type::FLOAT8),
|
||||
FieldInfo::new("strings".into(), None, None, Type::VARCHAR),
|
||||
FieldInfo::new("binaries".into(), None, None, Type::BYTEA),
|
||||
FieldInfo::new("dates".into(), None, None, Type::DATE),
|
||||
FieldInfo::new("datetimes".into(), None, None, Type::TIMESTAMP),
|
||||
FieldInfo::new("timestamps".into(), None, None, Type::TIMESTAMP),
|
||||
];
|
||||
|
||||
let values = vec![
|
||||
Value::Null,
|
||||
Value::Boolean(true),
|
||||
Value::UInt8(u8::MAX),
|
||||
Value::UInt16(u16::MAX),
|
||||
Value::UInt32(u32::MAX),
|
||||
Value::UInt64(u64::MAX),
|
||||
Value::Int8(i8::MAX),
|
||||
Value::Int8(i8::MIN),
|
||||
Value::Int16(i16::MAX),
|
||||
Value::Int16(i16::MIN),
|
||||
Value::Int32(i32::MAX),
|
||||
Value::Int32(i32::MIN),
|
||||
Value::Int64(i64::MAX),
|
||||
Value::Int64(i64::MIN),
|
||||
Value::Float32(f32::MAX.into()),
|
||||
Value::Float32(f32::MIN.into()),
|
||||
Value::Float32(0f32.into()),
|
||||
Value::Float64(f64::MAX.into()),
|
||||
Value::Float64(f64::MIN.into()),
|
||||
Value::Float64(0f64.into()),
|
||||
Value::String("greptime".into()),
|
||||
Value::Binary("greptime".as_bytes().into()),
|
||||
Value::Date(1001i32.into()),
|
||||
Value::DateTime(1000001i64.into()),
|
||||
Value::Timestamp(1000001i64.into()),
|
||||
];
|
||||
let mut builder = TextQueryResponseBuilder::new(schema);
|
||||
for i in values {
|
||||
assert!(encode_value(&i, &mut builder).is_ok());
|
||||
}
|
||||
|
||||
let err = encode_value(
|
||||
&Value::List(ListValue::new(
|
||||
Some(Box::new(vec![])),
|
||||
ConcreteDataType::int8_datatype(),
|
||||
)),
|
||||
&mut builder,
|
||||
)
|
||||
.unwrap_err();
|
||||
match err {
|
||||
PgWireError::ApiError(e) => {
|
||||
assert!(format!("{}", e).contains("Internal error:"));
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
4
src/servers/src/postgres/mod.rs
Normal file
4
src/servers/src/postgres/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
mod handler;
|
||||
mod server;
|
||||
|
||||
pub use server::PostgresServer;
|
||||
136
src/servers/src/postgres/server.rs
Normal file
136
src/servers/src/postgres/server.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::{error, info};
|
||||
use futures::future::AbortHandle;
|
||||
use futures::future::AbortRegistration;
|
||||
use futures::future::Abortable;
|
||||
use futures::StreamExt;
|
||||
use pgwire::api::auth::noop::NoopStartupHandler;
|
||||
use pgwire::tokio::process_socket;
|
||||
use snafu::prelude::*;
|
||||
use tokio;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
|
||||
use crate::error::{self, Result};
|
||||
use crate::postgres::handler::PostgresServerHandler;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::server::Server;
|
||||
|
||||
pub struct PostgresServer {
|
||||
// See MySQL module for usage of these types
|
||||
abort_handle: AbortHandle,
|
||||
abort_registration: Option<AbortRegistration>,
|
||||
|
||||
// A handle holding the TCP accepting task.
|
||||
join_handle: Option<JoinHandle<()>>,
|
||||
|
||||
auth_handler: Arc<NoopStartupHandler>,
|
||||
query_handler: Arc<PostgresServerHandler>,
|
||||
io_runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl PostgresServer {
|
||||
/// Creates a new Postgres server with provided query_handler and async runtime
|
||||
pub fn new(query_handler: SqlQueryHandlerRef, io_runtime: Arc<Runtime>) -> PostgresServer {
|
||||
let (abort_handle, registration) = AbortHandle::new_pair();
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
|
||||
let startup_handler = Arc::new(NoopStartupHandler);
|
||||
PostgresServer {
|
||||
abort_handle,
|
||||
abort_registration: Some(registration),
|
||||
join_handle: None,
|
||||
|
||||
auth_handler: startup_handler,
|
||||
query_handler: postgres_handler,
|
||||
|
||||
io_runtime,
|
||||
}
|
||||
}
|
||||
|
||||
async fn bind(addr: SocketAddr) -> Result<(TcpListenerStream, SocketAddr)> {
|
||||
let listener = tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.context(error::TokioIoSnafu {
|
||||
err_msg: format!("Failed to bind addr {}", addr),
|
||||
})?;
|
||||
// get actually bond addr in case input addr use port 0
|
||||
let addr = listener.local_addr()?;
|
||||
info!("Postgres server is bound to {}", addr);
|
||||
Ok((TcpListenerStream::new(listener), addr))
|
||||
}
|
||||
|
||||
fn accept(&self, accepting_stream: Abortable<TcpListenerStream>) -> impl Future<Output = ()> {
|
||||
let io_runtime = self.io_runtime.clone();
|
||||
let auth_handler = self.auth_handler.clone();
|
||||
let query_handler = self.query_handler.clone();
|
||||
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let auth_handler = auth_handler.clone();
|
||||
let query_handler = query_handler.clone();
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
io_runtime.spawn(async move {
|
||||
process_socket(
|
||||
io_stream,
|
||||
auth_handler.clone(),
|
||||
query_handler.clone(),
|
||||
query_handler.clone(),
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Server for PostgresServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
match self.join_handle.take() {
|
||||
Some(join_handle) => {
|
||||
self.abort_handle.abort();
|
||||
|
||||
if let Err(error) = join_handle.await {
|
||||
// Couldn't use `error!(e; xxx)` as JoinError doesn't implement ErrorExt.
|
||||
error!(
|
||||
"Unexpected error during shutdown Postgres server, error: {}",
|
||||
error
|
||||
);
|
||||
} else {
|
||||
info!("Postgres server is shutdown.")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => error::InternalSnafu {
|
||||
err_msg: "Postgres server is not started.",
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
match self.abort_registration.take() {
|
||||
Some(registration) => {
|
||||
let (stream, listener) = Self::bind(listening).await?;
|
||||
let stream = Abortable::new(stream, registration);
|
||||
self.join_handle = Some(tokio::spawn(self.accept(stream)));
|
||||
Ok(listener)
|
||||
}
|
||||
None => error::InternalSnafu {
|
||||
err_msg: "Postgres server has been started.",
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,8 @@ use script::{
|
||||
engine::{CompileContext, EvalContext, Script, ScriptEngine},
|
||||
python::{PyEngine, PyScript},
|
||||
};
|
||||
#[cfg(feature = "postgres")]
|
||||
mod postgres;
|
||||
|
||||
struct DummyInstance {
|
||||
query_engine: QueryEngineRef,
|
||||
|
||||
169
src/servers/tests/postgres/mod.rs
Normal file
169
src/servers/tests/postgres/mod.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
use test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
|
||||
fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
.worker_threads(4)
|
||||
.thread_name("postgres-io-handlers")
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
Ok(Box::new(PostgresServer::new(query_handler, io_runtime)))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut pg_server = create_postgres_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Postgres server has been started."));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_pg_server() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut postgres_server = create_postgres_server(table)?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Postgres server is not started."));
|
||||
|
||||
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = postgres_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port).await {
|
||||
Ok(connection) => {
|
||||
let rows = connection
|
||||
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
|
||||
.await
|
||||
.unwrap();
|
||||
let result_text = unwrap_results(&rows)[0];
|
||||
let result: i32 = result_text.parse().unwrap();
|
||||
assert_eq!(result, 0);
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}))
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
for handle in join_handles.iter_mut() {
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_err());
|
||||
let error = result.unwrap_err().to_string();
|
||||
assert!(error.contains("Connection refused") || error.contains("Connection reset by peer"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_query_pg_concurrently() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut pg_server = create_postgres_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 300;
|
||||
let mut join_handles = vec![];
|
||||
for _i in 0..threads {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut client = create_connection(server_port).await.unwrap();
|
||||
|
||||
for _k in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = unwrap_results(
|
||||
client
|
||||
.simple_query(&format!(
|
||||
"SELECT uint32s FROM numbers WHERE uint32s = {}",
|
||||
expected
|
||||
))
|
||||
.await
|
||||
.unwrap()
|
||||
.as_ref(),
|
||||
)[0]
|
||||
.parse()
|
||||
.unwrap();
|
||||
assert_eq!(result, expected);
|
||||
|
||||
// 1/100 chance to reconnect
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
client = create_connection(server_port).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
}))
|
||||
}
|
||||
let mut total_pending_queries = threads * expect_executed_queries_per_worker;
|
||||
for handle in join_handles.iter_mut() {
|
||||
total_pending_queries -= handle.await.unwrap();
|
||||
}
|
||||
assert_eq!(0, total_pending_queries);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16) -> std::result::Result<Client, PgError> {
|
||||
let url = format!("host=127.0.0.1 port={} connect_timeout=2", port);
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
match resp {
|
||||
&SimpleQueryMessage::Row(ref r) => r.get(col_index),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
|
||||
resp.iter().filter_map(|m| resolve_result(m, 0)).collect()
|
||||
}
|
||||
Reference in New Issue
Block a user