diff --git a/src/frontend/src/postgres.rs b/src/frontend/src/postgres.rs index 1c949318f8..930764cc40 100644 --- a/src/frontend/src/postgres.rs +++ b/src/frontend/src/postgres.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; pub struct PostgresOptions { pub addr: String, pub runtime_size: usize, + pub check_pwd: bool, } impl Default for PostgresOptions { @@ -11,6 +12,7 @@ impl Default for PostgresOptions { Self { addr: "0.0.0.0:4003".to_string(), runtime_size: 2, + check_pwd: false, } } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 0efd828b8d..2b95278143 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -73,8 +73,11 @@ impl Services { .context(error::RuntimeResourceSnafu)?, ); - let pg_server = - Box::new(PostgresServer::new(instance.clone(), pg_io_runtime)) as Box; + let pg_server = Box::new(PostgresServer::new( + instance.clone(), + opts.check_pwd, + pg_io_runtime, + )) as Box; Some((pg_server, pg_addr)) } else { diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index e62b3abb09..022c3a647d 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -32,6 +32,7 @@ opensrv-mysql = "0.1" pgwire = { version = "0.4" } prost = "0.11" regex = "1.6" +rand = "0.8" schemars = "0.8" serde = "1.0" serde_json = "1.0" diff --git a/src/servers/src/context.rs b/src/servers/src/context.rs index 8d694d5097..5599e1a8e4 100644 --- a/src/servers/src/context.rs +++ b/src/servers/src/context.rs @@ -2,13 +2,13 @@ use std::collections::HashMap; use std::sync::Arc; use serde::{Deserialize, Serialize}; +use snafu::OptionExt; -use crate::context::AuthMethod::Token; -use crate::context::Channel::HTTP; +use crate::error::{BuildingContextSnafu, Result}; type CtxFnRef = Arc bool + Send + Sync>; -#[derive(Default, Serialize, Deserialize)] +#[derive(Serialize, Deserialize)] pub struct Context { pub exec_info: ExecInfo, pub client_info: ClientInfo, @@ -19,16 +19,70 @@ pub struct Context { } impl Context { - pub fn new() -> Self { - Context::default() - } - pub fn add_predicate(&mut self, predicate: CtxFnRef) { self.predicates.push(predicate); } } -#[derive(Default, Serialize, Deserialize)] +#[derive(Default)] +pub struct CtxBuilder { + client_addr: Option, + + username: Option, + from_channel: Option, + auth_method: Option, +} + +impl CtxBuilder { + pub fn new() -> CtxBuilder { + CtxBuilder::default() + } + + pub fn client_addr(mut self, addr: Option) -> CtxBuilder { + self.client_addr = addr; + self + } + + pub fn set_channel(mut self, channel: Option) -> CtxBuilder { + self.from_channel = channel; + self + } + + pub fn set_auth_method(mut self, auth_method: Option) -> CtxBuilder { + self.auth_method = auth_method; + self + } + + pub fn set_username(mut self, username: Option) -> CtxBuilder { + self.username = username; + self + } + + pub fn build(self) -> Result { + Ok(Context { + client_info: ClientInfo { + client_host: self.client_addr.context(BuildingContextSnafu { + err_msg: "unknown client addr while building ctx", + })?, + }, + user_info: UserInfo { + username: self.username, + from_channel: self.from_channel.context(BuildingContextSnafu { + err_msg: "unknown channel while building ctx", + })?, + auth_method: self.auth_method.context(BuildingContextSnafu { + err_msg: "unknown auth method while building ctx", + })?, + }, + + exec_info: ExecInfo::default(), + quota: Quota::default(), + predicates: vec![], + }) + } +} + +#[derive(Serialize, Deserialize)] pub struct ExecInfo { pub catalog: Option, pub schema: Option, @@ -37,34 +91,29 @@ pub struct ExecInfo { pub trace_id: Option, } -#[derive(Default, Serialize, Deserialize)] -pub struct ClientInfo { - pub client_host: Option, -} - -impl ClientInfo { - pub fn new(host: Option) -> Self { - ClientInfo { client_host: host } - } -} - -#[derive(Default, Serialize, Deserialize)] -pub struct UserInfo { - pub username: Option, - pub from_channel: Option, - pub auth_method: Option, -} - -impl UserInfo { - pub fn with_http_token(token: String) -> Self { - UserInfo { - username: None, - from_channel: Some(HTTP), - auth_method: Some(Token(token)), +impl Default for ExecInfo { + fn default() -> Self { + ExecInfo { + catalog: Some("greptime".to_string()), + schema: Some("public".to_string()), + extra_opts: HashMap::new(), + trace_id: None, } } } +#[derive(Default, Serialize, Deserialize)] +pub struct ClientInfo { + pub client_host: String, +} + +#[derive(Serialize, Deserialize)] +pub struct UserInfo { + pub username: Option, + pub from_channel: Channel, + pub auth_method: AuthMethod, +} + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum Channel { GRPC, @@ -78,10 +127,17 @@ pub enum AuthMethod { Password { hash_method: AuthHashMethod, hashed_value: Vec, + salt: Vec, }, Token(String), } +impl Default for AuthMethod { + fn default() -> Self { + AuthMethod::None + } +} + #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum AuthHashMethod { DoubleSha1, @@ -97,16 +153,26 @@ pub struct Quota { #[cfg(test)] mod test { - use std::collections::HashMap; + use std::sync::Arc; use crate::context::AuthMethod::Token; use crate::context::Channel::HTTP; - use crate::context::{ClientInfo, Context, ExecInfo, Quota, UserInfo}; + use crate::context::{Channel, Context, CtxBuilder, UserInfo}; #[test] fn test_predicate() { - let mut ctx = Context::default(); + let mut ctx = Context { + exec_info: Default::default(), + client_info: Default::default(), + user_info: UserInfo { + username: None, + from_channel: Channel::GRPC, + auth_method: Default::default(), + }, + quota: Default::default(), + predicates: vec![], + }; ctx.add_predicate(Arc::new(|ctx: &Context| { ctx.quota.total > ctx.quota.consumed })); @@ -123,43 +189,27 @@ mod test { #[test] fn test_build() { - let ctx = Context { - exec_info: ExecInfo { - catalog: Some(String::from("greptime")), - schema: Some(String::from("public")), - extra_opts: HashMap::new(), - trace_id: None, - }, - client_info: ClientInfo::new(Some(String::from("127.0.0.1:4001"))), - user_info: UserInfo::with_http_token(String::from("HELLO")), - quota: Quota { - total: 10, - consumed: 5, - estimated: 2, - }, - predicates: vec![], - }; + let ctx = CtxBuilder::new() + .client_addr(Some("127.0.0.1:4001".to_string())) + .set_channel(Some(HTTP)) + .set_auth_method(Some(Token("HELLO".to_string()))) + .build() + .unwrap(); assert_eq!(ctx.exec_info.catalog.unwrap(), String::from("greptime")); assert_eq!(ctx.exec_info.schema.unwrap(), String::from("public")); - assert_eq!(ctx.exec_info.extra_opts.capacity(), 0); + assert_eq!(ctx.exec_info.extra_opts.len(), 0); assert_eq!(ctx.exec_info.trace_id, None); - assert_eq!( - ctx.client_info.client_host.unwrap(), - String::from("127.0.0.1:4001") - ); + assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001")); assert_eq!(ctx.user_info.username, None); - assert_eq!(ctx.user_info.from_channel.unwrap(), HTTP); - assert_eq!( - ctx.user_info.auth_method.unwrap(), - Token(String::from("HELLO")) - ); + assert_eq!(ctx.user_info.from_channel, HTTP); + assert_eq!(ctx.user_info.auth_method, Token(String::from("HELLO"))); - assert!(ctx.quota.total > 0); - assert!(ctx.quota.consumed > 0); - assert!(ctx.quota.estimated > 0); + assert_eq!(ctx.quota.total, 0); + assert_eq!(ctx.quota.consumed, 0); + assert_eq!(ctx.quota.estimated, 0); assert_eq!(ctx.predicates.capacity(), 0); } diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 2702232617..9c144f0d09 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -174,6 +174,12 @@ pub enum Error { #[snafu(backtrace)] source: BoxedError, }, + + #[snafu(display("Failed to build context, msg: {}", err_msg))] + BuildingContext { + err_msg: String, + backtrace: Backtrace, + }, } pub type Result = std::result::Result; @@ -192,7 +198,8 @@ impl ErrorExt for Error { | AlreadyStarted { .. } | InvalidPromRemoteReadQueryResult { .. } | TcpBind { .. } - | GrpcReflectionService { .. } => StatusCode::Internal, + | GrpcReflectionService { .. } + | BuildingContext { .. } => StatusCode::Internal, InsertScript { source, .. } | ExecuteScript { source, .. } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 5c0d6261e2..a2491dec92 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -1,3 +1,4 @@ +mod context; pub mod handler; pub mod influxdb; pub mod opentsdb; @@ -11,6 +12,7 @@ use aide::axum::routing as apirouting; use aide::axum::{ApiRouter, IntoApiResponse}; use aide::openapi::{Info, OpenApi, Server as OpenAPIServer}; use async_trait::async_trait; +use axum::middleware::{self}; use axum::response::Html; use axum::Extension; use axum::{error_handling::HandleErrorLayer, response::Json, routing, BoxError, Router}; @@ -313,7 +315,9 @@ impl HttpServer { .layer(HandleErrorLayer::new(handle_error)) .layer(TraceLayer::new_for_http()) // TODO(LFC): make timeout configurable - .layer(TimeoutLayer::new(Duration::from_secs(30))), + .layer(TimeoutLayer::new(Duration::from_secs(30))) + // custom layer + .layer(middleware::from_fn(context::build_ctx)), ) } } diff --git a/src/servers/src/http/context.rs b/src/servers/src/http/context.rs new file mode 100644 index 0000000000..dd1a84315d --- /dev/null +++ b/src/servers/src/http/context.rs @@ -0,0 +1,48 @@ +use axum::{ + http, + http::{Request, StatusCode}, + middleware::Next, + response::Response, +}; +use common_telemetry::error; + +use crate::context::{AuthMethod, Channel, CtxBuilder}; + +pub async fn build_ctx(mut req: Request, next: Next) -> Result { + let auth_option = req + .headers() + .get(http::header::AUTHORIZATION) + .map(|header| { + header + .to_str() + .map(|header_str| match header_str.split_once(' ') { + Some((name, content)) if name == "Bearer" || name == "TOKEN" => { + AuthMethod::Token(String::from(content)) + } + _ => AuthMethod::None, + }) + .unwrap_or(AuthMethod::None) + }) + .or(Some(AuthMethod::None)); + + match CtxBuilder::new() + .client_addr( + req.headers() + .get(http::header::HOST) + .and_then(|h| h.to_str().ok()) + .map(|h| h.to_string()), + ) + .set_channel(Some(Channel::HTTP)) + .set_auth_method(auth_option) + .build() + { + Ok(ctx) => { + req.extensions_mut().insert(ctx); + Ok(next.run(req).await) + } + Err(e) => { + error!(e; "fail to create context"); + Err(StatusCode::INTERNAL_SERVER_ERROR) + } + } +} diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 74083de14e..620fd799ba 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -1,12 +1,19 @@ use std::io; +use std::sync::Arc; use async_trait::async_trait; +use common_telemetry::error; use opensrv_mysql::AsyncMysqlShim; use opensrv_mysql::ErrorKind; use opensrv_mysql::ParamParser; use opensrv_mysql::QueryResultWriter; use opensrv_mysql::StatementMetaWriter; +use rand::RngCore; +use tokio::sync::RwLock; +use crate::context::AuthHashMethod::DoubleSha1; +use crate::context::Channel::MYSQL; +use crate::context::{AuthMethod, Context, CtxBuilder}; use crate::error::{self, Result}; use crate::mysql::writer::MysqlResultWriter; use crate::query_handler::SqlQueryHandlerRef; @@ -14,11 +21,32 @@ use crate::query_handler::SqlQueryHandlerRef; // An intermediate shim for executing MySQL queries. pub struct MysqlInstanceShim { query_handler: SqlQueryHandlerRef, + salt: [u8; 20], + client_addr: String, + ctx: Arc>>, } impl MysqlInstanceShim { - pub fn create(query_handler: SqlQueryHandlerRef) -> MysqlInstanceShim { - MysqlInstanceShim { query_handler } + pub fn create(query_handler: SqlQueryHandlerRef, client_addr: String) -> MysqlInstanceShim { + // init a random salt + let mut bs = vec![0u8; 20]; + let mut rng = rand::thread_rng(); + rng.fill_bytes(bs.as_mut()); + + let mut scramble: [u8; 20] = [0; 20]; + for i in 0..20 { + scramble[i] = bs[i] & 0x7fu8; + if scramble[i] == b'\0' || scramble[i] == b'$' { + scramble[i] += 1; + } + } + + MysqlInstanceShim { + query_handler, + salt: scramble, + client_addr, + ctx: Arc::new(RwLock::new(None)), + } } } @@ -26,6 +54,48 @@ impl MysqlInstanceShim { impl AsyncMysqlShim for MysqlInstanceShim { type Error = error::Error; + fn salt(&self) -> [u8; 20] { + self.salt + } + + async fn authenticate( + &self, + _auth_plugin: &str, + username: &[u8], + salt: &[u8], + auth_data: &[u8], + ) -> bool { + // if not specified then **root** will be used + let username = String::from_utf8_lossy(username); + let client_addr = self.client_addr.clone(); + let auth_method = match auth_data.len() { + 0 => AuthMethod::None, + _ => AuthMethod::Password { + hash_method: DoubleSha1, + hashed_value: auth_data.to_vec(), + salt: salt.to_vec(), + }, + }; + + return match CtxBuilder::new() + .client_addr(Some(client_addr)) + .set_channel(Some(MYSQL)) + .set_username(Some(username.to_string())) + .set_auth_method(Some(auth_method)) + .build() + { + Ok(ctx) => { + let mut a = self.ctx.write().await; + *a = Some(ctx); + true + } + Err(e) => { + error!(e; "create ctx failed when authing mysql conn"); + false + } + }; + } + async fn on_prepare<'a>( &'a mut self, _: &'a str, diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 5cfa2c279a..03159f3181 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -59,7 +59,7 @@ impl MysqlServer { query_handler: SqlQueryHandlerRef, ) -> Result<()> { info!("MySQL connection coming from: {}", stream.peer_addr()?); - let shim = MysqlInstanceShim::create(query_handler); + let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string()); // TODO(LFC): Relate "handler" with MySQL session; also deal with panics there. let _handler = io_runtime.spawn(AsyncMysqlIntermediary::run_on(shim, stream)); Ok(()) diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs new file mode 100644 index 0000000000..46cf8f6ac8 --- /dev/null +++ b/src/servers/src/postgres/auth_handler.rs @@ -0,0 +1,110 @@ +use std::collections::HashMap; +use std::fmt::Debug; + +use async_trait::async_trait; +use futures::{Sink, SinkExt}; +use pgwire::api::auth::{ServerParameterProvider, StartupHandler}; +use pgwire::api::{auth, ClientInfo, PgWireConnectionState}; +use pgwire::error::ErrorInfo; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::response::ErrorResponse; +use pgwire::messages::startup::Authentication; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; + +struct PgPwdVerifier; + +impl PgPwdVerifier { + async fn verify_pwd(&self, _pwd: &str, _meta: HashMap) -> PgWireResult { + Ok(true) + } +} + +struct GreptimeDBStartupParameters { + version: &'static str, +} + +impl GreptimeDBStartupParameters { + fn new() -> GreptimeDBStartupParameters { + GreptimeDBStartupParameters { + version: env!("CARGO_PKG_VERSION"), + } + } +} + +impl ServerParameterProvider for GreptimeDBStartupParameters { + fn server_parameters(&self, _client: &C) -> Option> + where + C: ClientInfo, + { + let mut params = HashMap::with_capacity(1); + params.insert("server_version".to_owned(), self.version.to_owned()); + + Some(params) + } +} + +pub struct PgAuthStartupHandler { + verifier: PgPwdVerifier, + param_provider: GreptimeDBStartupParameters, + with_pwd: bool, +} + +impl PgAuthStartupHandler { + pub fn new(with_pwd: bool) -> Self { + PgAuthStartupHandler { + verifier: PgPwdVerifier, + param_provider: GreptimeDBStartupParameters::new(), + with_pwd, + } + } +} + +#[async_trait] +impl StartupHandler for PgAuthStartupHandler { + async fn on_startup( + &self, + client: &mut C, + message: &PgWireFrontendMessage, + ) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + { + match message { + PgWireFrontendMessage::Startup(ref startup) => { + auth::save_startup_parameters_to_metadata(client, startup); + if self.with_pwd { + client.set_state(PgWireConnectionState::AuthenticationInProgress); + client + .send(PgWireBackendMessage::Authentication( + Authentication::CleartextPassword, + )) + .await?; + } else { + auth::finish_authentication(client, &self.param_provider).await; + } + } + PgWireFrontendMessage::Password(ref pwd) => { + let meta = client.metadata().clone(); + if let Ok(true) = self.verifier.verify_pwd(pwd.password(), meta).await { + auth::finish_authentication(client, &self.param_provider).await + } else { + let error_info = ErrorInfo::new( + "FATAL".to_owned(), + "28P01".to_owned(), + "Password authentication failed".to_owned(), + ); + let error = ErrorResponse::from(error_info); + + client + .feed(PgWireBackendMessage::ErrorResponse(error)) + .await?; + client.close().await?; + } + } + _ => {} + } + Ok(()) + } +} diff --git a/src/servers/src/postgres/mod.rs b/src/servers/src/postgres/mod.rs index b7e04fe869..fe38358555 100644 --- a/src/servers/src/postgres/mod.rs +++ b/src/servers/src/postgres/mod.rs @@ -1,3 +1,4 @@ +mod auth_handler; mod handler; mod server; diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 3df4b5af1b..4394e03851 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::future::Future; use std::net::SocketAddr; use std::sync::Arc; @@ -6,77 +5,31 @@ use std::sync::Arc; use async_trait::async_trait; use common_runtime::Runtime; use common_telemetry::logging::error; -use futures::{Sink, StreamExt}; -use pgwire::api::auth::{self, ServerParameterProvider, StartupHandler}; -use pgwire::api::ClientInfo; -use pgwire::error::{PgWireError, PgWireResult}; -use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use futures::StreamExt; use pgwire::tokio::process_socket; use tokio; use crate::error::Result; +use crate::postgres::auth_handler::PgAuthStartupHandler; use crate::postgres::handler::PostgresServerHandler; use crate::query_handler::SqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; -struct SimpleStartupHandler; - -#[async_trait] -impl StartupHandler for SimpleStartupHandler { - async fn on_startup( - &self, - client: &mut C, - message: &PgWireFrontendMessage, - ) -> PgWireResult<()> - where - C: ClientInfo + Sink + Unpin + Send, - C::Error: std::fmt::Debug, - PgWireError: From<>::Error>, - { - if let PgWireFrontendMessage::Startup(ref startup) = message { - auth::save_startup_parameters_to_metadata(client, startup); - auth::finish_authentication(client, &GreptimeDBStartupParameters::new()).await; - } - - Ok(()) - } -} - -struct GreptimeDBStartupParameters { - version: &'static str, -} - -impl GreptimeDBStartupParameters { - fn new() -> GreptimeDBStartupParameters { - GreptimeDBStartupParameters { - version: env!("CARGO_PKG_VERSION"), - } - } -} - -impl ServerParameterProvider for GreptimeDBStartupParameters { - fn server_parameters(&self, _client: &C) -> Option> - where - C: ClientInfo, - { - let mut params = HashMap::with_capacity(1); - params.insert("server_version".to_owned(), self.version.to_owned()); - - Some(params) - } -} - pub struct PostgresServer { base_server: BaseTcpServer, - auth_handler: Arc, + auth_handler: Arc, query_handler: Arc, } impl PostgresServer { /// Creates a new Postgres server with provided query_handler and async runtime - pub fn new(query_handler: SqlQueryHandlerRef, io_runtime: Arc) -> PostgresServer { + pub fn new( + query_handler: SqlQueryHandlerRef, + check_pwd: bool, + io_runtime: Arc, + ) -> PostgresServer { let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler)); - let startup_handler = Arc::new(SimpleStartupHandler); + let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd)); PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), auth_handler: startup_handler, diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 6cdea2509f..1c3f365d8a 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -63,10 +63,10 @@ async fn test_shutdown_mysql_server() -> Result<()> { let server_port = server_addr.port(); let mut join_handles = vec![]; - for _ in 0..2 { + for index in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port).await { + match create_connection(server_port, index == 1).await { Ok(mut connection) => { let result: u32 = connection .query_first("SELECT uint32s FROM numbers LIMIT 1") @@ -114,7 +114,7 @@ async fn test_query_all_datatypes() -> Result<()> { let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let mut connection = create_connection(server_addr.port()).await.unwrap(); + let mut connection = create_connection(server_addr.port(), false).await.unwrap(); let mut result = connection .query_iter("SELECT * FROM all_datatypes LIMIT 3") .await @@ -149,11 +149,13 @@ async fn test_query_concurrently() -> Result<()> { let threads = 4; let expect_executed_queries_per_worker = 1000; let mut join_handles = vec![]; - for _ in 0..threads { + for index in 0..threads { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut connection = create_connection(server_port).await.unwrap(); + let mut connection = create_connection(server_port, index % 2 == 0) + .await + .unwrap(); for _ in 0..expect_executed_queries_per_worker { let expected: u32 = rand.gen_range(0..100); let result: u32 = connection @@ -168,7 +170,9 @@ async fn test_query_concurrently() -> Result<()> { let should_recreate_conn = expected == 1; if should_recreate_conn { - connection = create_connection(server_port).await.unwrap(); + connection = create_connection(server_port, index % 2 == 0) + .await + .unwrap(); } } expect_executed_queries_per_worker @@ -182,11 +186,16 @@ async fn test_query_concurrently() -> Result<()> { Ok(()) } -async fn create_connection(port: u16) -> mysql_async::Result { - let opts = mysql_async::OptsBuilder::default() +async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result { + let mut opts = mysql_async::OptsBuilder::default() .ip_or_hostname("127.0.0.1") .tcp_port(port) .prefer_socket(false) .wait_timeout(Some(1000)); + + if with_pwd { + opts = opts.pass(Some("default_pwd".to_string())); + } + mysql_async::Conn::new(opts).await } diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 3e69863702..8789b5e4e4 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -13,7 +13,7 @@ use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage}; use crate::create_testing_sql_query_handler; -fn create_postgres_server(table: MemTable) -> Result> { +fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result> { let query_handler = create_testing_sql_query_handler(table); let io_runtime = Arc::new( RuntimeBuilder::default() @@ -22,14 +22,18 @@ fn create_postgres_server(table: MemTable) -> Result> { .build() .unwrap(), ); - Ok(Box::new(PostgresServer::new(query_handler, io_runtime))) + Ok(Box::new(PostgresServer::new( + query_handler, + check_pwd, + io_runtime, + ))) } #[tokio::test] pub async fn test_start_postgres_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table)?; + let pg_server = create_postgres_server(table, false)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = pg_server.start(listening).await; assert!(result.is_ok()); @@ -43,12 +47,19 @@ pub async fn test_start_postgres_server() -> Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_shutdown_pg_server() -> Result<()> { +async fn test_shutdown_pg_server_range() -> Result<()> { + assert!(test_shutdown_pg_server(false).await.is_ok()); + assert!(test_shutdown_pg_server(true).await.is_ok()); + Ok(()) +} + +// #[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let postgres_server = create_postgres_server(table)?; + let postgres_server = create_postgres_server(table, with_pwd)?; let result = postgres_server.shutdown().await; assert!(result .unwrap_err() @@ -63,7 +74,7 @@ async fn test_shutdown_pg_server() -> Result<()> { for _ in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port).await { + match create_connection(server_port, with_pwd).await { Ok(connection) => { match connection .simple_query("SELECT uint32s FROM numbers LIMIT 1") @@ -107,7 +118,7 @@ async fn test_query_pg_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table)?; + let pg_server = create_postgres_server(table, false)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = pg_server.start(listening).await.unwrap(); let server_port = server_addr.port(); @@ -119,7 +130,7 @@ async fn test_query_pg_concurrently() -> Result<()> { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut client = create_connection(server_port).await.unwrap(); + let mut client = create_connection(server_port, false).await.unwrap(); for _k in 0..expect_executed_queries_per_worker { let expected: u32 = rand.gen_range(0..100); @@ -140,7 +151,7 @@ async fn test_query_pg_concurrently() -> Result<()> { // 1/100 chance to reconnect let should_recreate_conn = expected == 1; if should_recreate_conn { - client = create_connection(server_port).await.unwrap(); + client = create_connection(server_port, false).await.unwrap(); } } expect_executed_queries_per_worker @@ -154,8 +165,15 @@ async fn test_query_pg_concurrently() -> Result<()> { Ok(()) } -async fn create_connection(port: u16) -> std::result::Result { - let url = format!("host=127.0.0.1 port={} connect_timeout=2", port); +async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result { + let url = if with_pwd { + format!( + "host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2", + port + ) + } else { + format!("host=127.0.0.1 port={} connect_timeout=2", port) + }; let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; tokio::spawn(conn); Ok(client)