feat: add mysql reject_no_database (#896)

* chore: update opensrv-mysql to main

* refactor: change mysql server struct

* feat: add option to reject no database mysql connection request

* chore: remove unused condition

* chore: rebase develop

* chore: make reject_no_database optional
This commit is contained in:
shuiyisong
2023-01-29 12:09:47 +08:00
committed by GitHub
parent 64243e3a7d
commit aafc26c788
11 changed files with 252 additions and 115 deletions

3
Cargo.lock generated
View File

@@ -4448,8 +4448,7 @@ dependencies = [
[[package]]
name = "opensrv-mysql"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc"
source = "git+https://github.com/datafuselabs/opensrv?rev=b44c9d1360da297b305abf33aecfa94888e1554c#b44c9d1360da297b305abf33aecfa94888e1554c"
dependencies = [
"async-trait",
"byteorder",

View File

@@ -323,6 +323,10 @@ mod tests {
fe_opts.mysql_options.as_ref().unwrap().addr
);
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
assert_eq!(
None,
fe_opts.mysql_options.as_ref().unwrap().reject_no_database
);
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
}

View File

@@ -18,15 +18,18 @@ use std::sync::Arc;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::tracing::log::info;
use servers::error::Error::InternalIo;
use servers::grpc::GrpcServer;
use servers::mysql::server::MysqlServer;
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor;
use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor;
use servers::server::Server;
use servers::tls::TlsOption;
use servers::Mode;
use snafu::ResultExt;
use crate::datanode::DatanodeOptions;
use crate::error::Error::StartServer;
use crate::error::{ParseAddrSnafu, Result, RuntimeResourceSnafu, StartServerSnafu};
use crate::instance::InstanceRef;
@@ -61,11 +64,24 @@ impl Services {
.build()
.context(RuntimeResourceSnafu)?,
);
let tls = TlsOption::default();
// default tls config returns None
// but try to think a better way to do this
Some(MysqlServer::create_server(
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
mysql_io_runtime,
Default::default(),
None,
Arc::new(MysqlSpawnRef::new(
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
None,
)),
Arc::new(MysqlSpawnConfig::new(
tls.should_force_tls(),
tls.setup()
.map_err(|e| StartServer {
source: InternalIo { source: e },
})?
.map(Arc::new),
false,
)),
))
}
};

View File

@@ -21,6 +21,7 @@ pub struct MysqlOptions {
pub runtime_size: usize,
#[serde(default = "Default::default")]
pub tls: TlsOption,
pub reject_no_database: Option<bool>,
}
impl Default for MysqlOptions {
@@ -29,6 +30,7 @@ impl Default for MysqlOptions {
addr: "127.0.0.1:4002".to_string(),
runtime_size: 2,
tls: TlsOption::default(),
reject_no_database: None,
}
}
}

View File

@@ -18,9 +18,10 @@ use std::sync::Arc;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
use servers::error::Error::InternalIo;
use servers::grpc::GrpcServer;
use servers::http::HttpServer;
use servers::mysql::server::MysqlServer;
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
use servers::opentsdb::OpentsdbServer;
use servers::postgres::PostgresServer;
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor;
@@ -29,6 +30,7 @@ use servers::server::Server;
use snafu::ResultExt;
use tokio::try_join;
use crate::error::Error::StartServer;
use crate::error::{self, Result};
use crate::frontend::FrontendOptions;
use crate::influxdb::InfluxdbOptions;
@@ -81,12 +83,22 @@ impl Services {
.build()
.context(error::RuntimeResourceSnafu)?,
);
let mysql_server = MysqlServer::create_server(
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
mysql_io_runtime,
opts.tls.clone(),
user_provider.clone(),
Arc::new(MysqlSpawnRef::new(
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
user_provider.clone(),
)),
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
opts.tls
.setup()
.map_err(|e| StartServer {
source: InternalIo { source: e },
})?
.map(Arc::new),
opts.reject_no_database.unwrap_or(false),
)),
);
Some((mysql_server, mysql_addr))

View File

@@ -36,7 +36,7 @@ metrics = "0.20"
num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.3"
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" }
pgwire = "0.6.3"
pin-project = "1.0"
prost.workspace = true

View File

@@ -187,6 +187,7 @@ pub mod test {
use std::fs::File;
use std::io::{LineWriter, Write};
use session::context::UserInfo;
use tempdir::TempDir;
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
@@ -216,7 +217,7 @@ pub mod test {
assert_eq!(sha1_2, sha1_2_answer);
}
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
async fn test_authenticate(provider: &dyn UserProvider, username: &str, password: &str) {
let re = provider
.authenticate(
Identity::UserId(username, None),
@@ -226,11 +227,20 @@ pub mod test {
assert!(re.is_ok());
}
#[tokio::test]
async fn test_authorize() {
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
let re = provider
.authorize("catalog", "schema", &UserInfo::new("root"))
.await;
assert!(re.is_ok());
}
#[tokio::test]
async fn test_inline_provider() {
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
test_authenticate(&provider, "root", "123456").await;
test_authenticate(&provider, "admin", "654321").await;
}
#[tokio::test]
@@ -254,7 +264,7 @@ admin=654321",
let param = format!("file:{file_path}");
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
test_authenticate(&provider, "root", "123456").await;
test_authenticate(&provider, "admin", "654321").await;
}
}

View File

@@ -44,8 +44,8 @@ pub struct MysqlInstanceShim {
impl MysqlInstanceShim {
pub fn create(
query_handler: ServerSqlQueryHandlerRef,
client_addr: SocketAddr,
user_provider: Option<UserProviderRef>,
client_addr: SocketAddr,
) -> MysqlInstanceShim {
// init a random salt
let mut bs = vec![0u8; 20];

View File

@@ -33,39 +33,89 @@ use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
// Default size of ResultSet write buffer: 100KB
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
struct MysqlRuntimeOption {
/// [`MysqlSpawnRef`] stores arc refs
/// that should be passed to new [`MysqlInstanceShim`]s.
pub struct MysqlSpawnRef {
query_handler: ServerSqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
}
type MysqlRuntimeOptionRef = Arc<MysqlRuntimeOption>;
impl MysqlSpawnRef {
pub fn new(
query_handler: ServerSqlQueryHandlerRef,
user_provider: Option<UserProviderRef>,
) -> MysqlSpawnRef {
MysqlSpawnRef {
query_handler,
user_provider,
}
}
fn query_handler(&self) -> ServerSqlQueryHandlerRef {
self.query_handler.clone()
}
fn user_provider(&self) -> Option<UserProviderRef> {
self.user_provider.clone()
}
}
/// [`MysqlSpawnConfig`] stores config values
/// which are used to initialize [`MysqlInstanceShim`]s.
pub struct MysqlSpawnConfig {
// tls config
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
// other shim config
reject_no_database: bool,
}
impl MysqlSpawnConfig {
pub fn new(
force_tls: bool,
tls: Option<Arc<ServerConfig>>,
reject_no_database: bool,
) -> MysqlSpawnConfig {
MysqlSpawnConfig {
force_tls,
tls,
reject_no_database,
}
}
fn tls(&self) -> Option<Arc<ServerConfig>> {
self.tls.clone()
}
}
impl From<&MysqlSpawnConfig> for IntermediaryOptions {
fn from(value: &MysqlSpawnConfig) -> Self {
IntermediaryOptions {
reject_connection_on_dbname_absence: value.reject_no_database,
..Default::default()
}
}
}
pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: ServerSqlQueryHandlerRef,
tls: TlsOption,
user_provider: Option<UserProviderRef>,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
}
impl MysqlServer {
pub fn create_server(
query_handler: ServerSqlQueryHandlerRef,
io_runtime: Arc<Runtime>,
tls: TlsOption,
user_provider: Option<UserProviderRef>,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
query_handler,
tls,
user_provider,
spawn_ref,
spawn_config,
})
}
@@ -73,32 +123,21 @@ impl MysqlServer {
&self,
io_runtime: Arc<Runtime>,
stream: AbortableStream,
tls_conf: Option<Arc<ServerConfig>>,
) -> impl Future<Output = ()> {
let query_handler = self.query_handler.clone();
let user_provider = self.user_provider.clone();
let force_tls = self.tls.should_force_tls();
let spawn_ref = self.spawn_ref.clone();
let spawn_config = self.spawn_config.clone();
stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let query_handler = query_handler.clone();
let user_provider = user_provider.clone();
let tls_conf = tls_conf.clone();
let mysql_runtime_option = Arc::new(MysqlRuntimeOption {
query_handler,
tls_conf,
force_tls,
user_provider,
});
let spawn_ref = spawn_ref.clone();
let spawn_config = spawn_config.clone();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
if let Err(error) =
Self::handle(io_stream, io_runtime, mysql_runtime_option).await
Self::handle(io_stream, io_runtime, spawn_ref, spawn_config).await
{
error!(error; "Unexpected error when handling TcpStream");
};
@@ -111,12 +150,13 @@ impl MysqlServer {
async fn handle(
stream: TcpStream,
io_runtime: Arc<Runtime>,
runtime_opts: MysqlRuntimeOptionRef,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
io_runtime .spawn(async move {
io_runtime.spawn(async move {
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
if let Err(e) = Self::do_handle(stream, runtime_opts).await {
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
@@ -126,31 +166,32 @@ impl MysqlServer {
Ok(())
}
async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> {
async fn do_handle(
stream: TcpStream,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(
runtime_opts.query_handler.clone(),
spawn_ref.query_handler(),
spawn_ref.user_provider(),
stream.peer_addr()?,
runtime_opts.user_provider.clone(),
);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let ops = IntermediaryOptions::default();
let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl(
&mut shim,
&mut r,
&mut w,
&runtime_opts.tls_conf,
)
.await?;
let ops = spawn_config.as_ref().into();
if runtime_opts.force_tls && !client_tls {
let (client_tls, init_params) =
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
.await?;
if spawn_config.force_tls && !client_tls {
return Err(Error::TlsRequired {
server: "mysql".to_owned(),
});
}
match runtime_opts.tls_conf.clone() {
match spawn_config.tls() {
Some(tls_conf) if client_tls => {
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
}
@@ -167,12 +208,9 @@ impl Server for MysqlServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
let io_runtime = self.base_server.io_runtime();
let tls_conf = self.tls.setup()?.map(Arc::new);
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
self.base_server.start_with(join_handle).await?;
Ok(addr)
}

View File

@@ -23,7 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::{UserInfo, DEFAULT_USERNAME};
use session::context::UserInfo;
use snafu::ResultExt;
use crate::auth::{Identity, Password, UserProviderRef};
@@ -202,21 +202,6 @@ impl StartupHandler for PgAuthStartupHandler {
))
.await?;
} else {
// no user is provided, use default user
// and still do authorization
let mut login_info = LoginInfo::from_client_info(client);
login_info.user = Some(DEFAULT_USERNAME.to_string());
let authorize_result = self.verifier.authorize(&login_info).await;
if !matches!(authorize_result, Ok(true)) {
return send_error(
client,
"FATAL",
"28P01",
"password authorization failed".to_owned(),
)
.await;
}
auth::finish_authentication(client, &self.param_provider).await;
}
}

View File

@@ -25,7 +25,7 @@ use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
use servers::server::Server;
use servers::tls::TlsOption;
use table::test_util::MemTable;
@@ -34,11 +34,14 @@ use crate::auth::{DatabaseAuthInfo, MockUserProvider};
use crate::create_testing_sql_query_handler;
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
fn create_mysql_server(
table: MemTable,
#[derive(Default)]
struct MysqlOpts<'a> {
tls: TlsOption,
auth_info: Option<DatabaseAuthInfo>,
) -> Result<Box<dyn Server>> {
auth_info: Option<DatabaseAuthInfo<'a>>,
reject_no_database: bool,
}
fn create_mysql_server(table: MemTable, opts: MysqlOpts<'_>) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
@@ -49,15 +52,18 @@ fn create_mysql_server(
);
let mut provider = MockUserProvider::default();
if let Some(auth_info) = auth_info {
if let Some(auth_info) = opts.auth_info {
provider.set_authorization_info(auth_info);
}
Ok(MysqlServer::create_server(
query_handler,
io_runtime,
tls,
Some(Arc::new(provider)),
Arc::new(MysqlSpawnRef::new(query_handler, Some(Arc::new(provider)))),
Arc::new(MysqlSpawnConfig::new(
opts.tls.should_force_tls(),
opts.tls.setup()?.map(Arc::new),
opts.reject_no_database,
)),
))
}
@@ -65,7 +71,7 @@ fn create_mysql_server(
async fn test_start_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = mysql_server.start(listening).await;
assert!(result.is_ok());
@@ -78,11 +84,42 @@ async fn test_start_mysql_server() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn test_reject_no_database() -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(
table,
MysqlOpts {
reject_no_database: true,
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let server_port = server_addr.port();
let fail = create_connection(server_port, None, false).await;
assert!(fail.is_err());
let pass = create_connection(server_port, Some("public"), false).await;
assert!(pass.is_ok());
let result = mysql_server.shutdown().await;
assert!(result.is_ok());
Ok(())
}
#[tokio::test]
async fn test_schema_validation() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?;
let mysql_server = create_mysql_server(
table,
MysqlOpts {
auth_info: Some(auth_info),
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
Ok((mysql_server, server_addr.port()))
@@ -96,9 +133,7 @@ async fn test_schema_validation() -> Result<()> {
})
.await?;
//TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later.
let pass = create_connection(server_port, Some("public"), false).await;
let pass = create_connection_default_db_name(server_port, false).await;
assert!(pass.is_ok());
let result = mysql_server.shutdown().await;
assert!(result.is_ok());
@@ -111,7 +146,7 @@ async fn test_schema_validation() -> Result<()> {
})
.await?;
let fail = create_connection(server_port, Some("public"), false).await;
let fail = create_connection_default_db_name(server_port, false).await;
assert!(fail.is_err());
let result = mysql_server.shutdown().await;
assert!(result.is_ok());
@@ -125,7 +160,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let result = mysql_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -140,7 +175,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, None, false).await {
match create_connection_default_db_name(server_port, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -230,7 +265,13 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls, None)?;
let mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
@@ -261,12 +302,18 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls, None)?;
let mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), None, client_tls).await;
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
assert!(r.is_err());
Ok(())
}
@@ -287,15 +334,19 @@ async fn test_db_name() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls, None)?;
let mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), None, client_tls).await;
assert!(r.is_ok());
let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await;
// None actually uses default database name
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
assert!(r.is_ok());
let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await;
@@ -315,12 +366,18 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls, None)?;
let mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
..Default::default()
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), None, client_tls)
let mut connection = create_connection_default_db_name(server_addr.port(), client_tls)
.await
.unwrap();
@@ -350,7 +407,7 @@ async fn test_query_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default(), None)?;
let mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let server_port = server_addr.port();
@@ -362,7 +419,9 @@ async fn test_query_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, None, false).await.unwrap();
let mut connection = create_connection_default_db_name(server_port, false)
.await
.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -376,7 +435,9 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, None, false).await.unwrap();
connection = create_connection_default_db_name(server_port, false)
.await
.unwrap();
}
}
expect_executed_queries_per_worker
@@ -390,6 +451,13 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection_default_db_name(
port: u16,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
create_connection(port, Some(DEFAULT_SCHEMA_NAME), ssl).await
}
async fn create_connection(
port: u16,
db_name: Option<&str>,
@@ -400,10 +468,13 @@ async fn create_connection(
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000))
.db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME)))
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));
if let Some(db_name) = db_name {
opts = opts.db_name(Some(db_name.to_string()));
}
if ssl {
let ssl_opts = SslOpts::default()
.with_danger_skip_domain_validation(true)