feat: generating context in http middleware & mysql auth method (#453)

This commit is contained in:
shuiyisong
2022-11-14 17:24:11 +08:00
committed by GitHub
parent 7e49493e34
commit dcd5e34dbd
14 changed files with 422 additions and 146 deletions

View File

@@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
}
impl Default for PostgresOptions {
@@ -11,6 +12,7 @@ impl Default for PostgresOptions {
Self {
addr: "0.0.0.0:4003".to_string(),
runtime_size: 2,
check_pwd: false,
}
}
}

View File

@@ -73,8 +73,11 @@ impl Services {
.context(error::RuntimeResourceSnafu)?,
);
let pg_server =
Box::new(PostgresServer::new(instance.clone(), pg_io_runtime)) as Box<dyn Server>;
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
opts.check_pwd,
pg_io_runtime,
)) as Box<dyn Server>;
Some((pg_server, pg_addr))
} else {

View File

@@ -32,6 +32,7 @@ opensrv-mysql = "0.1"
pgwire = { version = "0.4" }
prost = "0.11"
regex = "1.6"
rand = "0.8"
schemars = "0.8"
serde = "1.0"
serde_json = "1.0"

View File

@@ -2,13 +2,13 @@ use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use snafu::OptionExt;
use crate::context::AuthMethod::Token;
use crate::context::Channel::HTTP;
use crate::error::{BuildingContextSnafu, Result};
type CtxFnRef = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
#[derive(Default, Serialize, Deserialize)]
#[derive(Serialize, Deserialize)]
pub struct Context {
pub exec_info: ExecInfo,
pub client_info: ClientInfo,
@@ -19,16 +19,70 @@ pub struct Context {
}
impl Context {
pub fn new() -> Self {
Context::default()
}
pub fn add_predicate(&mut self, predicate: CtxFnRef) {
self.predicates.push(predicate);
}
}
#[derive(Default, Serialize, Deserialize)]
#[derive(Default)]
pub struct CtxBuilder {
client_addr: Option<String>,
username: Option<String>,
from_channel: Option<Channel>,
auth_method: Option<AuthMethod>,
}
impl CtxBuilder {
pub fn new() -> CtxBuilder {
CtxBuilder::default()
}
pub fn client_addr(mut self, addr: Option<String>) -> CtxBuilder {
self.client_addr = addr;
self
}
pub fn set_channel(mut self, channel: Option<Channel>) -> CtxBuilder {
self.from_channel = channel;
self
}
pub fn set_auth_method(mut self, auth_method: Option<AuthMethod>) -> CtxBuilder {
self.auth_method = auth_method;
self
}
pub fn set_username(mut self, username: Option<String>) -> CtxBuilder {
self.username = username;
self
}
pub fn build(self) -> Result<Context> {
Ok(Context {
client_info: ClientInfo {
client_host: self.client_addr.context(BuildingContextSnafu {
err_msg: "unknown client addr while building ctx",
})?,
},
user_info: UserInfo {
username: self.username,
from_channel: self.from_channel.context(BuildingContextSnafu {
err_msg: "unknown channel while building ctx",
})?,
auth_method: self.auth_method.context(BuildingContextSnafu {
err_msg: "unknown auth method while building ctx",
})?,
},
exec_info: ExecInfo::default(),
quota: Quota::default(),
predicates: vec![],
})
}
}
#[derive(Serialize, Deserialize)]
pub struct ExecInfo {
pub catalog: Option<String>,
pub schema: Option<String>,
@@ -37,34 +91,29 @@ pub struct ExecInfo {
pub trace_id: Option<String>,
}
#[derive(Default, Serialize, Deserialize)]
pub struct ClientInfo {
pub client_host: Option<String>,
}
impl ClientInfo {
pub fn new(host: Option<String>) -> Self {
ClientInfo { client_host: host }
}
}
#[derive(Default, Serialize, Deserialize)]
pub struct UserInfo {
pub username: Option<String>,
pub from_channel: Option<Channel>,
pub auth_method: Option<AuthMethod>,
}
impl UserInfo {
pub fn with_http_token(token: String) -> Self {
UserInfo {
username: None,
from_channel: Some(HTTP),
auth_method: Some(Token(token)),
impl Default for ExecInfo {
fn default() -> Self {
ExecInfo {
catalog: Some("greptime".to_string()),
schema: Some("public".to_string()),
extra_opts: HashMap::new(),
trace_id: None,
}
}
}
#[derive(Default, Serialize, Deserialize)]
pub struct ClientInfo {
pub client_host: String,
}
#[derive(Serialize, Deserialize)]
pub struct UserInfo {
pub username: Option<String>,
pub from_channel: Channel,
pub auth_method: AuthMethod,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum Channel {
GRPC,
@@ -78,10 +127,17 @@ pub enum AuthMethod {
Password {
hash_method: AuthHashMethod,
hashed_value: Vec<u8>,
salt: Vec<u8>,
},
Token(String),
}
impl Default for AuthMethod {
fn default() -> Self {
AuthMethod::None
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthHashMethod {
DoubleSha1,
@@ -97,16 +153,26 @@ pub struct Quota {
#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::sync::Arc;
use crate::context::AuthMethod::Token;
use crate::context::Channel::HTTP;
use crate::context::{ClientInfo, Context, ExecInfo, Quota, UserInfo};
use crate::context::{Channel, Context, CtxBuilder, UserInfo};
#[test]
fn test_predicate() {
let mut ctx = Context::default();
let mut ctx = Context {
exec_info: Default::default(),
client_info: Default::default(),
user_info: UserInfo {
username: None,
from_channel: Channel::GRPC,
auth_method: Default::default(),
},
quota: Default::default(),
predicates: vec![],
};
ctx.add_predicate(Arc::new(|ctx: &Context| {
ctx.quota.total > ctx.quota.consumed
}));
@@ -123,43 +189,27 @@ mod test {
#[test]
fn test_build() {
let ctx = Context {
exec_info: ExecInfo {
catalog: Some(String::from("greptime")),
schema: Some(String::from("public")),
extra_opts: HashMap::new(),
trace_id: None,
},
client_info: ClientInfo::new(Some(String::from("127.0.0.1:4001"))),
user_info: UserInfo::with_http_token(String::from("HELLO")),
quota: Quota {
total: 10,
consumed: 5,
estimated: 2,
},
predicates: vec![],
};
let ctx = CtxBuilder::new()
.client_addr(Some("127.0.0.1:4001".to_string()))
.set_channel(Some(HTTP))
.set_auth_method(Some(Token("HELLO".to_string())))
.build()
.unwrap();
assert_eq!(ctx.exec_info.catalog.unwrap(), String::from("greptime"));
assert_eq!(ctx.exec_info.schema.unwrap(), String::from("public"));
assert_eq!(ctx.exec_info.extra_opts.capacity(), 0);
assert_eq!(ctx.exec_info.extra_opts.len(), 0);
assert_eq!(ctx.exec_info.trace_id, None);
assert_eq!(
ctx.client_info.client_host.unwrap(),
String::from("127.0.0.1:4001")
);
assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001"));
assert_eq!(ctx.user_info.username, None);
assert_eq!(ctx.user_info.from_channel.unwrap(), HTTP);
assert_eq!(
ctx.user_info.auth_method.unwrap(),
Token(String::from("HELLO"))
);
assert_eq!(ctx.user_info.from_channel, HTTP);
assert_eq!(ctx.user_info.auth_method, Token(String::from("HELLO")));
assert!(ctx.quota.total > 0);
assert!(ctx.quota.consumed > 0);
assert!(ctx.quota.estimated > 0);
assert_eq!(ctx.quota.total, 0);
assert_eq!(ctx.quota.consumed, 0);
assert_eq!(ctx.quota.estimated, 0);
assert_eq!(ctx.predicates.capacity(), 0);
}

View File

@@ -174,6 +174,12 @@ pub enum Error {
#[snafu(backtrace)]
source: BoxedError,
},
#[snafu(display("Failed to build context, msg: {}", err_msg))]
BuildingContext {
err_msg: String,
backtrace: Backtrace,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -192,7 +198,8 @@ impl ErrorExt for Error {
| AlreadyStarted { .. }
| InvalidPromRemoteReadQueryResult { .. }
| TcpBind { .. }
| GrpcReflectionService { .. } => StatusCode::Internal,
| GrpcReflectionService { .. }
| BuildingContext { .. } => StatusCode::Internal,
InsertScript { source, .. }
| ExecuteScript { source, .. }

View File

@@ -1,3 +1,4 @@
mod context;
pub mod handler;
pub mod influxdb;
pub mod opentsdb;
@@ -11,6 +12,7 @@ use aide::axum::routing as apirouting;
use aide::axum::{ApiRouter, IntoApiResponse};
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
use async_trait::async_trait;
use axum::middleware::{self};
use axum::response::Html;
use axum::Extension;
use axum::{error_handling::HandleErrorLayer, response::Json, routing, BoxError, Router};
@@ -313,7 +315,9 @@ impl HttpServer {
.layer(HandleErrorLayer::new(handle_error))
.layer(TraceLayer::new_for_http())
// TODO(LFC): make timeout configurable
.layer(TimeoutLayer::new(Duration::from_secs(30))),
.layer(TimeoutLayer::new(Duration::from_secs(30)))
// custom layer
.layer(middleware::from_fn(context::build_ctx)),
)
}
}

View File

@@ -0,0 +1,48 @@
use axum::{
http,
http::{Request, StatusCode},
middleware::Next,
response::Response,
};
use common_telemetry::error;
use crate::context::{AuthMethod, Channel, CtxBuilder};
pub async fn build_ctx<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let auth_option = req
.headers()
.get(http::header::AUTHORIZATION)
.map(|header| {
header
.to_str()
.map(|header_str| match header_str.split_once(' ') {
Some((name, content)) if name == "Bearer" || name == "TOKEN" => {
AuthMethod::Token(String::from(content))
}
_ => AuthMethod::None,
})
.unwrap_or(AuthMethod::None)
})
.or(Some(AuthMethod::None));
match CtxBuilder::new()
.client_addr(
req.headers()
.get(http::header::HOST)
.and_then(|h| h.to_str().ok())
.map(|h| h.to_string()),
)
.set_channel(Some(Channel::HTTP))
.set_auth_method(auth_option)
.build()
{
Ok(ctx) => {
req.extensions_mut().insert(ctx);
Ok(next.run(req).await)
}
Err(e) => {
error!(e; "fail to create context");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

View File

@@ -1,12 +1,19 @@
use std::io;
use std::sync::Arc;
use async_trait::async_trait;
use common_telemetry::error;
use opensrv_mysql::AsyncMysqlShim;
use opensrv_mysql::ErrorKind;
use opensrv_mysql::ParamParser;
use opensrv_mysql::QueryResultWriter;
use opensrv_mysql::StatementMetaWriter;
use rand::RngCore;
use tokio::sync::RwLock;
use crate::context::AuthHashMethod::DoubleSha1;
use crate::context::Channel::MYSQL;
use crate::context::{AuthMethod, Context, CtxBuilder};
use crate::error::{self, Result};
use crate::mysql::writer::MysqlResultWriter;
use crate::query_handler::SqlQueryHandlerRef;
@@ -14,11 +21,32 @@ use crate::query_handler::SqlQueryHandlerRef;
// An intermediate shim for executing MySQL queries.
pub struct MysqlInstanceShim {
query_handler: SqlQueryHandlerRef,
salt: [u8; 20],
client_addr: String,
ctx: Arc<RwLock<Option<Context>>>,
}
impl MysqlInstanceShim {
pub fn create(query_handler: SqlQueryHandlerRef) -> MysqlInstanceShim {
MysqlInstanceShim { query_handler }
pub fn create(query_handler: SqlQueryHandlerRef, client_addr: String) -> MysqlInstanceShim {
// init a random salt
let mut bs = vec![0u8; 20];
let mut rng = rand::thread_rng();
rng.fill_bytes(bs.as_mut());
let mut scramble: [u8; 20] = [0; 20];
for i in 0..20 {
scramble[i] = bs[i] & 0x7fu8;
if scramble[i] == b'\0' || scramble[i] == b'$' {
scramble[i] += 1;
}
}
MysqlInstanceShim {
query_handler,
salt: scramble,
client_addr,
ctx: Arc::new(RwLock::new(None)),
}
}
}
@@ -26,6 +54,48 @@ impl MysqlInstanceShim {
impl<W: io::Write + Send + Sync> AsyncMysqlShim<W> for MysqlInstanceShim {
type Error = error::Error;
fn salt(&self) -> [u8; 20] {
self.salt
}
async fn authenticate(
&self,
_auth_plugin: &str,
username: &[u8],
salt: &[u8],
auth_data: &[u8],
) -> bool {
// if not specified then **root** will be used
let username = String::from_utf8_lossy(username);
let client_addr = self.client_addr.clone();
let auth_method = match auth_data.len() {
0 => AuthMethod::None,
_ => AuthMethod::Password {
hash_method: DoubleSha1,
hashed_value: auth_data.to_vec(),
salt: salt.to_vec(),
},
};
return match CtxBuilder::new()
.client_addr(Some(client_addr))
.set_channel(Some(MYSQL))
.set_username(Some(username.to_string()))
.set_auth_method(Some(auth_method))
.build()
{
Ok(ctx) => {
let mut a = self.ctx.write().await;
*a = Some(ctx);
true
}
Err(e) => {
error!(e; "create ctx failed when authing mysql conn");
false
}
};
}
async fn on_prepare<'a>(
&'a mut self,
_: &'a str,

View File

@@ -59,7 +59,7 @@ impl MysqlServer {
query_handler: SqlQueryHandlerRef,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
let shim = MysqlInstanceShim::create(query_handler);
let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
// TODO(LFC): Relate "handler" with MySQL session; also deal with panics there.
let _handler = io_runtime.spawn(AsyncMysqlIntermediary::run_on(shim, stream));
Ok(())

View File

@@ -0,0 +1,110 @@
use std::collections::HashMap;
use std::fmt::Debug;
use async_trait::async_trait;
use futures::{Sink, SinkExt};
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
use pgwire::error::ErrorInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
struct PgPwdVerifier;
impl PgPwdVerifier {
async fn verify_pwd(&self, _pwd: &str, _meta: HashMap<String, String>) -> PgWireResult<bool> {
Ok(true)
}
}
struct GreptimeDBStartupParameters {
version: &'static str,
}
impl GreptimeDBStartupParameters {
fn new() -> GreptimeDBStartupParameters {
GreptimeDBStartupParameters {
version: env!("CARGO_PKG_VERSION"),
}
}
}
impl ServerParameterProvider for GreptimeDBStartupParameters {
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
where
C: ClientInfo,
{
let mut params = HashMap::with_capacity(1);
params.insert("server_version".to_owned(), self.version.to_owned());
Some(params)
}
}
pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
with_pwd: bool,
}
impl PgAuthStartupHandler {
pub fn new(with_pwd: bool) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters::new(),
with_pwd,
}
}
}
#[async_trait]
impl StartupHandler for PgAuthStartupHandler {
async fn on_startup<C>(
&self,
client: &mut C,
message: &PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
match message {
PgWireFrontendMessage::Startup(ref startup) => {
auth::save_startup_parameters_to_metadata(client, startup);
if self.with_pwd {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
.send(PgWireBackendMessage::Authentication(
Authentication::CleartextPassword,
))
.await?;
} else {
auth::finish_authentication(client, &self.param_provider).await;
}
}
PgWireFrontendMessage::Password(ref pwd) => {
let meta = client.metadata().clone();
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), meta).await {
auth::finish_authentication(client, &self.param_provider).await
} else {
let error_info = ErrorInfo::new(
"FATAL".to_owned(),
"28P01".to_owned(),
"Password authentication failed".to_owned(),
);
let error = ErrorResponse::from(error_info);
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
client.close().await?;
}
}
_ => {}
}
Ok(())
}
}

View File

@@ -1,3 +1,4 @@
mod auth_handler;
mod handler;
mod server;

View File

@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
@@ -6,77 +5,31 @@ use std::sync::Arc;
use async_trait::async_trait;
use common_runtime::Runtime;
use common_telemetry::logging::error;
use futures::{Sink, StreamExt};
use pgwire::api::auth::{self, ServerParameterProvider, StartupHandler};
use pgwire::api::ClientInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use futures::StreamExt;
use pgwire::tokio::process_socket;
use tokio;
use crate::error::Result;
use crate::postgres::auth_handler::PgAuthStartupHandler;
use crate::postgres::handler::PostgresServerHandler;
use crate::query_handler::SqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
struct SimpleStartupHandler;
#[async_trait]
impl StartupHandler for SimpleStartupHandler {
async fn on_startup<C>(
&self,
client: &mut C,
message: &PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: std::fmt::Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
if let PgWireFrontendMessage::Startup(ref startup) = message {
auth::save_startup_parameters_to_metadata(client, startup);
auth::finish_authentication(client, &GreptimeDBStartupParameters::new()).await;
}
Ok(())
}
}
struct GreptimeDBStartupParameters {
version: &'static str,
}
impl GreptimeDBStartupParameters {
fn new() -> GreptimeDBStartupParameters {
GreptimeDBStartupParameters {
version: env!("CARGO_PKG_VERSION"),
}
}
}
impl ServerParameterProvider for GreptimeDBStartupParameters {
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
where
C: ClientInfo,
{
let mut params = HashMap::with_capacity(1);
params.insert("server_version".to_owned(), self.version.to_owned());
Some(params)
}
}
pub struct PostgresServer {
base_server: BaseTcpServer,
auth_handler: Arc<SimpleStartupHandler>,
auth_handler: Arc<PgAuthStartupHandler>,
query_handler: Arc<PostgresServerHandler>,
}
impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(query_handler: SqlQueryHandlerRef, io_runtime: Arc<Runtime>) -> PostgresServer {
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
io_runtime: Arc<Runtime>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
let startup_handler = Arc::new(SimpleStartupHandler);
let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd));
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
auth_handler: startup_handler,

View File

@@ -63,10 +63,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let server_port = server_addr.port();
let mut join_handles = vec![];
for _ in 0..2 {
for index in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port).await {
match create_connection(server_port, index == 1).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -114,7 +114,7 @@ async fn test_query_all_datatypes() -> Result<()> {
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port()).await.unwrap();
let mut connection = create_connection(server_addr.port(), false).await.unwrap();
let mut result = connection
.query_iter("SELECT * FROM all_datatypes LIMIT 3")
.await
@@ -149,11 +149,13 @@ async fn test_query_concurrently() -> Result<()> {
let threads = 4;
let expect_executed_queries_per_worker = 1000;
let mut join_handles = vec![];
for _ in 0..threads {
for index in 0..threads {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port).await.unwrap();
let mut connection = create_connection(server_port, index % 2 == 0)
.await
.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -168,7 +170,9 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port).await.unwrap();
connection = create_connection(server_port, index % 2 == 0)
.await
.unwrap();
}
}
expect_executed_queries_per_worker
@@ -182,11 +186,16 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16) -> mysql_async::Result<mysql_async::Conn> {
let opts = mysql_async::OptsBuilder::default()
async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000));
if with_pwd {
opts = opts.pass(Some("default_pwd".to_string()));
}
mysql_async::Conn::new(opts).await
}

View File

@@ -13,7 +13,7 @@ use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::create_testing_sql_query_handler;
fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -22,14 +22,18 @@ fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
.build()
.unwrap(),
);
Ok(Box::new(PostgresServer::new(query_handler, io_runtime)))
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
io_runtime,
)))
}
#[tokio::test]
pub async fn test_start_postgres_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table)?;
let pg_server = create_postgres_server(table, false)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = pg_server.start(listening).await;
assert!(result.is_ok());
@@ -43,12 +47,19 @@ pub async fn test_start_postgres_server() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shutdown_pg_server() -> Result<()> {
async fn test_shutdown_pg_server_range() -> Result<()> {
assert!(test_shutdown_pg_server(false).await.is_ok());
assert!(test_shutdown_pg_server(true).await.is_ok());
Ok(())
}
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table)?;
let postgres_server = create_postgres_server(table, with_pwd)?;
let result = postgres_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -63,7 +74,7 @@ async fn test_shutdown_pg_server() -> Result<()> {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port).await {
match create_connection(server_port, with_pwd).await {
Ok(connection) => {
match connection
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
@@ -107,7 +118,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table)?;
let pg_server = create_postgres_server(table, false)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -119,7 +130,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut client = create_connection(server_port).await.unwrap();
let mut client = create_connection(server_port, false).await.unwrap();
for _k in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
@@ -140,7 +151,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
// 1/100 chance to reconnect
let should_recreate_conn = expected == 1;
if should_recreate_conn {
client = create_connection(server_port).await.unwrap();
client = create_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -154,8 +165,15 @@ async fn test_query_pg_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16) -> std::result::Result<Client, PgError> {
let url = format!("host=127.0.0.1 port={} connect_timeout=2", port);
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
port
)
} else {
format!("host=127.0.0.1 port={} connect_timeout=2", port)
};
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
tokio::spawn(conn);
Ok(client)