From aafc26c788cc151710f6416ed4e6d0b121c349b5 Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Sun, 29 Jan 2023 12:09:47 +0800 Subject: [PATCH] feat: add mysql `reject_no_database` (#896) * chore: update opensrv-mysql to main * refactor: change mysql server struct * feat: add option to reject no database mysql connection request * chore: remove unused condition * chore: rebase develop * chore: make reject_no_database optional --- Cargo.lock | 3 +- src/cmd/src/standalone.rs | 4 + src/datanode/src/server.rs | 24 +++- src/frontend/src/mysql.rs | 2 + src/frontend/src/server.rs | 22 ++- src/servers/Cargo.toml | 2 +- src/servers/src/auth/user_provider.rs | 20 ++- src/servers/src/mysql/handler.rs | 2 +- src/servers/src/mysql/server.rs | 138 ++++++++++++------- src/servers/src/postgres/auth_handler.rs | 17 +-- src/servers/tests/mysql/mysql_server_test.rs | 133 +++++++++++++----- 11 files changed, 252 insertions(+), 115 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 69c6fad489..18a8164c01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4448,8 +4448,7 @@ dependencies = [ [[package]] name = "opensrv-mysql" version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc" +source = "git+https://github.com/datafuselabs/opensrv?rev=b44c9d1360da297b305abf33aecfa94888e1554c#b44c9d1360da297b305abf33aecfa94888e1554c" dependencies = [ "async-trait", "byteorder", diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index 6dfd64f5d0..26d30cbbf2 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -323,6 +323,10 @@ mod tests { fe_opts.mysql_options.as_ref().unwrap().addr ); assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size); + assert_eq!( + None, + fe_opts.mysql_options.as_ref().unwrap().reject_no_database + ); assert!(fe_opts.influxdb_options.as_ref().unwrap().enable); } diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index 49da3c52e1..3827138fb3 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -18,15 +18,18 @@ use std::sync::Arc; use common_runtime::Builder as RuntimeBuilder; use common_telemetry::tracing::log::info; +use servers::error::Error::InternalIo; use servers::grpc::GrpcServer; -use servers::mysql::server::MysqlServer; +use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor; use servers::server::Server; +use servers::tls::TlsOption; use servers::Mode; use snafu::ResultExt; use crate::datanode::DatanodeOptions; +use crate::error::Error::StartServer; use crate::error::{ParseAddrSnafu, Result, RuntimeResourceSnafu, StartServerSnafu}; use crate::instance::InstanceRef; @@ -61,11 +64,24 @@ impl Services { .build() .context(RuntimeResourceSnafu)?, ); + let tls = TlsOption::default(); + // default tls config returns None + // but try to think a better way to do this Some(MysqlServer::create_server( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), mysql_io_runtime, - Default::default(), - None, + Arc::new(MysqlSpawnRef::new( + ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + None, + )), + Arc::new(MysqlSpawnConfig::new( + tls.should_force_tls(), + tls.setup() + .map_err(|e| StartServer { + source: InternalIo { source: e }, + })? + .map(Arc::new), + false, + )), )) } }; diff --git a/src/frontend/src/mysql.rs b/src/frontend/src/mysql.rs index 2d0ebcfd6a..5dd00ee9ba 100644 --- a/src/frontend/src/mysql.rs +++ b/src/frontend/src/mysql.rs @@ -21,6 +21,7 @@ pub struct MysqlOptions { pub runtime_size: usize, #[serde(default = "Default::default")] pub tls: TlsOption, + pub reject_no_database: Option, } impl Default for MysqlOptions { @@ -29,6 +30,7 @@ impl Default for MysqlOptions { addr: "127.0.0.1:4002".to_string(), runtime_size: 2, tls: TlsOption::default(), + reject_no_database: None, } } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index d69b6c6fde..e438d0657b 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -18,9 +18,10 @@ use std::sync::Arc; use common_runtime::Builder as RuntimeBuilder; use common_telemetry::info; use servers::auth::UserProviderRef; +use servers::error::Error::InternalIo; use servers::grpc::GrpcServer; use servers::http::HttpServer; -use servers::mysql::server::MysqlServer; +use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::opentsdb::OpentsdbServer; use servers::postgres::PostgresServer; use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor; @@ -29,6 +30,7 @@ use servers::server::Server; use snafu::ResultExt; use tokio::try_join; +use crate::error::Error::StartServer; use crate::error::{self, Result}; use crate::frontend::FrontendOptions; use crate::influxdb::InfluxdbOptions; @@ -81,12 +83,22 @@ impl Services { .build() .context(error::RuntimeResourceSnafu)?, ); - let mysql_server = MysqlServer::create_server( - ServerSqlQueryHandlerAdaptor::arc(instance.clone()), mysql_io_runtime, - opts.tls.clone(), - user_provider.clone(), + Arc::new(MysqlSpawnRef::new( + ServerSqlQueryHandlerAdaptor::arc(instance.clone()), + user_provider.clone(), + )), + Arc::new(MysqlSpawnConfig::new( + opts.tls.should_force_tls(), + opts.tls + .setup() + .map_err(|e| StartServer { + source: InternalIo { source: e }, + })? + .map(Arc::new), + opts.reject_no_database.unwrap_or(false), + )), ); Some((mysql_server, mysql_addr)) diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 52f914ccc0..e2b3808e6e 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -36,7 +36,7 @@ metrics = "0.20" num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" -opensrv-mysql = "0.3" +opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" } pgwire = "0.6.3" pin-project = "1.0" prost.workspace = true diff --git a/src/servers/src/auth/user_provider.rs b/src/servers/src/auth/user_provider.rs index 8e9d570aa8..f425726b49 100644 --- a/src/servers/src/auth/user_provider.rs +++ b/src/servers/src/auth/user_provider.rs @@ -187,6 +187,7 @@ pub mod test { use std::fs::File; use std::io::{LineWriter, Write}; + use session::context::UserInfo; use tempdir::TempDir; use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider}; @@ -216,7 +217,7 @@ pub mod test { assert_eq!(sha1_2, sha1_2_answer); } - async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) { + async fn test_authenticate(provider: &dyn UserProvider, username: &str, password: &str) { let re = provider .authenticate( Identity::UserId(username, None), @@ -226,11 +227,20 @@ pub mod test { assert!(re.is_ok()); } + #[tokio::test] + async fn test_authorize() { + let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap(); + let re = provider + .authorize("catalog", "schema", &UserInfo::new("root")) + .await; + assert!(re.is_ok()); + } + #[tokio::test] async fn test_inline_provider() { let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap(); - test_auth(&provider, "root", "123456").await; - test_auth(&provider, "admin", "654321").await; + test_authenticate(&provider, "root", "123456").await; + test_authenticate(&provider, "admin", "654321").await; } #[tokio::test] @@ -254,7 +264,7 @@ admin=654321", let param = format!("file:{file_path}"); let provider = StaticUserProvider::try_from(param.as_str()).unwrap(); - test_auth(&provider, "root", "123456").await; - test_auth(&provider, "admin", "654321").await; + test_authenticate(&provider, "root", "123456").await; + test_authenticate(&provider, "admin", "654321").await; } } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index ad13053a3b..284b801dbb 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -44,8 +44,8 @@ pub struct MysqlInstanceShim { impl MysqlInstanceShim { pub fn create( query_handler: ServerSqlQueryHandlerRef, - client_addr: SocketAddr, user_provider: Option, + client_addr: SocketAddr, ) -> MysqlInstanceShim { // init a random salt let mut bs = vec![0u8; 20]; diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 830d3f858e..4653e66023 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -33,39 +33,89 @@ use crate::error::{Error, Result}; use crate::mysql::handler::MysqlInstanceShim; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; -use crate::tls::TlsOption; // Default size of ResultSet write buffer: 100KB const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; -struct MysqlRuntimeOption { +/// [`MysqlSpawnRef`] stores arc refs +/// that should be passed to new [`MysqlInstanceShim`]s. +pub struct MysqlSpawnRef { query_handler: ServerSqlQueryHandlerRef, - tls_conf: Option>, - force_tls: bool, user_provider: Option, } -type MysqlRuntimeOptionRef = Arc; +impl MysqlSpawnRef { + pub fn new( + query_handler: ServerSqlQueryHandlerRef, + user_provider: Option, + ) -> MysqlSpawnRef { + MysqlSpawnRef { + query_handler, + user_provider, + } + } + + fn query_handler(&self) -> ServerSqlQueryHandlerRef { + self.query_handler.clone() + } + fn user_provider(&self) -> Option { + self.user_provider.clone() + } +} + +/// [`MysqlSpawnConfig`] stores config values +/// which are used to initialize [`MysqlInstanceShim`]s. +pub struct MysqlSpawnConfig { + // tls config + force_tls: bool, + tls: Option>, + // other shim config + reject_no_database: bool, +} + +impl MysqlSpawnConfig { + pub fn new( + force_tls: bool, + tls: Option>, + reject_no_database: bool, + ) -> MysqlSpawnConfig { + MysqlSpawnConfig { + force_tls, + tls, + reject_no_database, + } + } + + fn tls(&self) -> Option> { + self.tls.clone() + } +} + +impl From<&MysqlSpawnConfig> for IntermediaryOptions { + fn from(value: &MysqlSpawnConfig) -> Self { + IntermediaryOptions { + reject_connection_on_dbname_absence: value.reject_no_database, + ..Default::default() + } + } +} pub struct MysqlServer { base_server: BaseTcpServer, - query_handler: ServerSqlQueryHandlerRef, - tls: TlsOption, - user_provider: Option, + spawn_ref: Arc, + spawn_config: Arc, } impl MysqlServer { pub fn create_server( - query_handler: ServerSqlQueryHandlerRef, io_runtime: Arc, - tls: TlsOption, - user_provider: Option, + spawn_ref: Arc, + spawn_config: Arc, ) -> Box { Box::new(MysqlServer { base_server: BaseTcpServer::create_server("MySQL", io_runtime), - query_handler, - tls, - user_provider, + spawn_ref, + spawn_config, }) } @@ -73,32 +123,21 @@ impl MysqlServer { &self, io_runtime: Arc, stream: AbortableStream, - tls_conf: Option>, ) -> impl Future { - let query_handler = self.query_handler.clone(); - let user_provider = self.user_provider.clone(); - - let force_tls = self.tls.should_force_tls(); + let spawn_ref = self.spawn_ref.clone(); + let spawn_config = self.spawn_config.clone(); stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); - let query_handler = query_handler.clone(); - let user_provider = user_provider.clone(); - let tls_conf = tls_conf.clone(); - - let mysql_runtime_option = Arc::new(MysqlRuntimeOption { - query_handler, - tls_conf, - force_tls, - user_provider, - }); + let spawn_ref = spawn_ref.clone(); + let spawn_config = spawn_config.clone(); async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { if let Err(error) = - Self::handle(io_stream, io_runtime, mysql_runtime_option).await + Self::handle(io_stream, io_runtime, spawn_ref, spawn_config).await { error!(error; "Unexpected error when handling TcpStream"); }; @@ -111,12 +150,13 @@ impl MysqlServer { async fn handle( stream: TcpStream, io_runtime: Arc, - runtime_opts: MysqlRuntimeOptionRef, + spawn_ref: Arc, + spawn_config: Arc, ) -> Result<()> { info!("MySQL connection coming from: {}", stream.peer_addr()?); - io_runtime .spawn(async move { + io_runtime.spawn(async move { // TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client. - if let Err(e) = Self::do_handle(stream, runtime_opts).await { + if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await { // TODO(LFC): Write this error to client as well, in MySQL text protocol. // Looks like we have to expose opensrv-mysql's `PacketWriter`? error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.") @@ -126,31 +166,32 @@ impl MysqlServer { Ok(()) } - async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> { + async fn do_handle( + stream: TcpStream, + spawn_ref: Arc, + spawn_config: Arc, + ) -> Result<()> { let mut shim = MysqlInstanceShim::create( - runtime_opts.query_handler.clone(), + spawn_ref.query_handler(), + spawn_ref.user_provider(), stream.peer_addr()?, - runtime_opts.user_provider.clone(), ); let (mut r, w) = stream.into_split(); let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); - let ops = IntermediaryOptions::default(); - let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl( - &mut shim, - &mut r, - &mut w, - &runtime_opts.tls_conf, - ) - .await?; + let ops = spawn_config.as_ref().into(); - if runtime_opts.force_tls && !client_tls { + let (client_tls, init_params) = + AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls()) + .await?; + + if spawn_config.force_tls && !client_tls { return Err(Error::TlsRequired { server: "mysql".to_owned(), }); } - match runtime_opts.tls_conf.clone() { + match spawn_config.tls() { Some(tls_conf) if client_tls => { secure_run_with_options(shim, w, ops, tls_conf, init_params).await } @@ -167,12 +208,9 @@ impl Server for MysqlServer { async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; - let io_runtime = self.base_server.io_runtime(); - let tls_conf = self.tls.setup()?.map(Arc::new); - - let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf)); + let join_handle = tokio::spawn(self.accept(io_runtime, stream)); self.base_server.start_with(join_handle).await?; Ok(addr) } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 15988589be..075251fd76 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -23,7 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::response::ErrorResponse; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; -use session::context::{UserInfo, DEFAULT_USERNAME}; +use session::context::UserInfo; use snafu::ResultExt; use crate::auth::{Identity, Password, UserProviderRef}; @@ -202,21 +202,6 @@ impl StartupHandler for PgAuthStartupHandler { )) .await?; } else { - // no user is provided, use default user - // and still do authorization - let mut login_info = LoginInfo::from_client_info(client); - login_info.user = Some(DEFAULT_USERNAME.to_string()); - - let authorize_result = self.verifier.authorize(&login_info).await; - if !matches!(authorize_result, Ok(true)) { - return send_error( - client, - "FATAL", - "28P01", - "password authorization failed".to_owned(), - ) - .await; - } auth::finish_authentication(client, &self.param_provider).await; } } diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 97323c366a..7d5a57da31 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -25,7 +25,7 @@ use mysql_async::SslOpts; use rand::rngs::StdRng; use rand::Rng; use servers::error::Result; -use servers::mysql::server::MysqlServer; +use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::server::Server; use servers::tls::TlsOption; use table::test_util::MemTable; @@ -34,11 +34,14 @@ use crate::auth::{DatabaseAuthInfo, MockUserProvider}; use crate::create_testing_sql_query_handler; use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData}; -fn create_mysql_server( - table: MemTable, +#[derive(Default)] +struct MysqlOpts<'a> { tls: TlsOption, - auth_info: Option, -) -> Result> { + auth_info: Option>, + reject_no_database: bool, +} + +fn create_mysql_server(table: MemTable, opts: MysqlOpts<'_>) -> Result> { let query_handler = create_testing_sql_query_handler(table); let io_runtime = Arc::new( RuntimeBuilder::default() @@ -49,15 +52,18 @@ fn create_mysql_server( ); let mut provider = MockUserProvider::default(); - if let Some(auth_info) = auth_info { + if let Some(auth_info) = opts.auth_info { provider.set_authorization_info(auth_info); } Ok(MysqlServer::create_server( - query_handler, io_runtime, - tls, - Some(Arc::new(provider)), + Arc::new(MysqlSpawnRef::new(query_handler, Some(Arc::new(provider)))), + Arc::new(MysqlSpawnConfig::new( + opts.tls.should_force_tls(), + opts.tls.setup()?.map(Arc::new), + opts.reject_no_database, + )), )) } @@ -65,7 +71,7 @@ fn create_mysql_server( async fn test_start_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default(), None)?; + let mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = mysql_server.start(listening).await; assert!(result.is_ok()); @@ -78,11 +84,42 @@ async fn test_start_mysql_server() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_reject_no_database() -> Result<()> { + common_telemetry::init_default_ut_logging(); + let table = MemTable::default_numbers_table(); + let mysql_server = create_mysql_server( + table, + MysqlOpts { + reject_no_database: true, + ..Default::default() + }, + )?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + let server_port = server_addr.port(); + + let fail = create_connection(server_port, None, false).await; + assert!(fail.is_err()); + let pass = create_connection(server_port, Some("public"), false).await; + assert!(pass.is_ok()); + let result = mysql_server.shutdown().await; + assert!(result.is_ok()); + + Ok(()) +} + #[tokio::test] async fn test_schema_validation() -> Result<()> { async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box, u16)> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?; + let mysql_server = create_mysql_server( + table, + MysqlOpts { + auth_info: Some(auth_info), + ..Default::default() + }, + )?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); Ok((mysql_server, server_addr.port())) @@ -96,9 +133,7 @@ async fn test_schema_validation() -> Result<()> { }) .await?; - //TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later. - - let pass = create_connection(server_port, Some("public"), false).await; + let pass = create_connection_default_db_name(server_port, false).await; assert!(pass.is_ok()); let result = mysql_server.shutdown().await; assert!(result.is_ok()); @@ -111,7 +146,7 @@ async fn test_schema_validation() -> Result<()> { }) .await?; - let fail = create_connection(server_port, Some("public"), false).await; + let fail = create_connection_default_db_name(server_port, false).await; assert!(fail.is_err()); let result = mysql_server.shutdown().await; assert!(result.is_ok()); @@ -125,7 +160,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default(), None)?; + let mysql_server = create_mysql_server(table, Default::default())?; let result = mysql_server.shutdown().await; assert!(result .unwrap_err() @@ -140,7 +175,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { for _ in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port, None, false).await { + match create_connection_default_db_name(server_port, false).await { Ok(mut connection) => { let result: u32 = connection .query_first("SELECT uint32s FROM numbers LIMIT 1") @@ -230,7 +265,13 @@ async fn test_server_required_secure_client_plain() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls, None)?; + let mysql_server = create_mysql_server( + table, + MysqlOpts { + tls: server_tls, + ..Default::default() + }, + )?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); @@ -261,12 +302,18 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls, None)?; + let mysql_server = create_mysql_server( + table, + MysqlOpts { + tls: server_tls, + ..Default::default() + }, + )?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let r = create_connection(server_addr.port(), None, client_tls).await; + let r = create_connection_default_db_name(server_addr.port(), client_tls).await; assert!(r.is_err()); Ok(()) } @@ -287,15 +334,19 @@ async fn test_db_name() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls, None)?; + let mysql_server = create_mysql_server( + table, + MysqlOpts { + tls: server_tls, + ..Default::default() + }, + )?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let r = create_connection(server_addr.port(), None, client_tls).await; - assert!(r.is_ok()); - - let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await; + // None actually uses default database name + let r = create_connection_default_db_name(server_addr.port(), client_tls).await; assert!(r.is_ok()); let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await; @@ -315,12 +366,18 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) -> let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table, server_tls, None)?; + let mysql_server = create_mysql_server( + table, + MysqlOpts { + tls: server_tls, + ..Default::default() + }, + )?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let mut connection = create_connection(server_addr.port(), None, client_tls) + let mut connection = create_connection_default_db_name(server_addr.port(), client_tls) .await .unwrap(); @@ -350,7 +407,7 @@ async fn test_query_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default(), None)?; + let mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); let server_port = server_addr.port(); @@ -362,7 +419,9 @@ async fn test_query_concurrently() -> Result<()> { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut connection = create_connection(server_port, None, false).await.unwrap(); + let mut connection = create_connection_default_db_name(server_port, false) + .await + .unwrap(); for _ in 0..expect_executed_queries_per_worker { let expected: u32 = rand.gen_range(0..100); let result: u32 = connection @@ -376,7 +435,9 @@ async fn test_query_concurrently() -> Result<()> { let should_recreate_conn = expected == 1; if should_recreate_conn { - connection = create_connection(server_port, None, false).await.unwrap(); + connection = create_connection_default_db_name(server_port, false) + .await + .unwrap(); } } expect_executed_queries_per_worker @@ -390,6 +451,13 @@ async fn test_query_concurrently() -> Result<()> { Ok(()) } +async fn create_connection_default_db_name( + port: u16, + ssl: bool, +) -> mysql_async::Result { + create_connection(port, Some(DEFAULT_SCHEMA_NAME), ssl).await +} + async fn create_connection( port: u16, db_name: Option<&str>, @@ -400,10 +468,13 @@ async fn create_connection( .tcp_port(port) .prefer_socket(false) .wait_timeout(Some(1000)) - .db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME))) .user(Some("greptime".to_string())) .pass(Some("greptime".to_string())); + if let Some(db_name) = db_name { + opts = opts.db_name(Some(db_name.to_string())); + } + if ssl { let ssl_opts = SslOpts::default() .with_danger_skip_domain_validation(true)