mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 18:00:41 +00:00
feat: add mysql reject_no_database (#896)
* chore: update opensrv-mysql to main * refactor: change mysql server struct * feat: add option to reject no database mysql connection request * chore: remove unused condition * chore: rebase develop * chore: make reject_no_database optional
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -4448,8 +4448,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "opensrv-mysql"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc"
|
||||
source = "git+https://github.com/datafuselabs/opensrv?rev=b44c9d1360da297b305abf33aecfa94888e1554c#b44c9d1360da297b305abf33aecfa94888e1554c"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
|
||||
@@ -323,6 +323,10 @@ mod tests {
|
||||
fe_opts.mysql_options.as_ref().unwrap().addr
|
||||
);
|
||||
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
|
||||
assert_eq!(
|
||||
None,
|
||||
fe_opts.mysql_options.as_ref().unwrap().reject_no_database
|
||||
);
|
||||
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
|
||||
}
|
||||
|
||||
|
||||
@@ -18,15 +18,18 @@ use std::sync::Arc;
|
||||
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use common_telemetry::tracing::log::info;
|
||||
use servers::error::Error::InternalIo;
|
||||
use servers::grpc::GrpcServer;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
|
||||
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor;
|
||||
use servers::query_handler::sql::ServerSqlQueryHandlerAdaptor;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use servers::Mode;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::datanode::DatanodeOptions;
|
||||
use crate::error::Error::StartServer;
|
||||
use crate::error::{ParseAddrSnafu, Result, RuntimeResourceSnafu, StartServerSnafu};
|
||||
use crate::instance::InstanceRef;
|
||||
|
||||
@@ -61,11 +64,24 @@ impl Services {
|
||||
.build()
|
||||
.context(RuntimeResourceSnafu)?,
|
||||
);
|
||||
let tls = TlsOption::default();
|
||||
// default tls config returns None
|
||||
// but try to think a better way to do this
|
||||
Some(MysqlServer::create_server(
|
||||
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
|
||||
mysql_io_runtime,
|
||||
Default::default(),
|
||||
None,
|
||||
Arc::new(MysqlSpawnRef::new(
|
||||
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
|
||||
None,
|
||||
)),
|
||||
Arc::new(MysqlSpawnConfig::new(
|
||||
tls.should_force_tls(),
|
||||
tls.setup()
|
||||
.map_err(|e| StartServer {
|
||||
source: InternalIo { source: e },
|
||||
})?
|
||||
.map(Arc::new),
|
||||
false,
|
||||
)),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
@@ -21,6 +21,7 @@ pub struct MysqlOptions {
|
||||
pub runtime_size: usize,
|
||||
#[serde(default = "Default::default")]
|
||||
pub tls: TlsOption,
|
||||
pub reject_no_database: Option<bool>,
|
||||
}
|
||||
|
||||
impl Default for MysqlOptions {
|
||||
@@ -29,6 +30,7 @@ impl Default for MysqlOptions {
|
||||
addr: "127.0.0.1:4002".to_string(),
|
||||
runtime_size: 2,
|
||||
tls: TlsOption::default(),
|
||||
reject_no_database: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,9 +18,10 @@ use std::sync::Arc;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use common_telemetry::info;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::Error::InternalIo;
|
||||
use servers::grpc::GrpcServer;
|
||||
use servers::http::HttpServer;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
|
||||
use servers::opentsdb::OpentsdbServer;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::query_handler::grpc::ServerGrpcQueryHandlerAdaptor;
|
||||
@@ -29,6 +30,7 @@ use servers::server::Server;
|
||||
use snafu::ResultExt;
|
||||
use tokio::try_join;
|
||||
|
||||
use crate::error::Error::StartServer;
|
||||
use crate::error::{self, Result};
|
||||
use crate::frontend::FrontendOptions;
|
||||
use crate::influxdb::InfluxdbOptions;
|
||||
@@ -81,12 +83,22 @@ impl Services {
|
||||
.build()
|
||||
.context(error::RuntimeResourceSnafu)?,
|
||||
);
|
||||
|
||||
let mysql_server = MysqlServer::create_server(
|
||||
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
|
||||
mysql_io_runtime,
|
||||
opts.tls.clone(),
|
||||
user_provider.clone(),
|
||||
Arc::new(MysqlSpawnRef::new(
|
||||
ServerSqlQueryHandlerAdaptor::arc(instance.clone()),
|
||||
user_provider.clone(),
|
||||
)),
|
||||
Arc::new(MysqlSpawnConfig::new(
|
||||
opts.tls.should_force_tls(),
|
||||
opts.tls
|
||||
.setup()
|
||||
.map_err(|e| StartServer {
|
||||
source: InternalIo { source: e },
|
||||
})?
|
||||
.map(Arc::new),
|
||||
opts.reject_no_database.unwrap_or(false),
|
||||
)),
|
||||
);
|
||||
|
||||
Some((mysql_server, mysql_addr))
|
||||
|
||||
@@ -36,7 +36,7 @@ metrics = "0.20"
|
||||
num_cpus = "1.13"
|
||||
once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = "0.3"
|
||||
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" }
|
||||
pgwire = "0.6.3"
|
||||
pin-project = "1.0"
|
||||
prost.workspace = true
|
||||
|
||||
@@ -187,6 +187,7 @@ pub mod test {
|
||||
use std::fs::File;
|
||||
use std::io::{LineWriter, Write};
|
||||
|
||||
use session::context::UserInfo;
|
||||
use tempdir::TempDir;
|
||||
|
||||
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
|
||||
@@ -216,7 +217,7 @@ pub mod test {
|
||||
assert_eq!(sha1_2, sha1_2_answer);
|
||||
}
|
||||
|
||||
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
|
||||
async fn test_authenticate(provider: &dyn UserProvider, username: &str, password: &str) {
|
||||
let re = provider
|
||||
.authenticate(
|
||||
Identity::UserId(username, None),
|
||||
@@ -226,11 +227,20 @@ pub mod test {
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_authorize() {
|
||||
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
|
||||
let re = provider
|
||||
.authorize("catalog", "schema", &UserInfo::new("root"))
|
||||
.await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_inline_provider() {
|
||||
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
test_authenticate(&provider, "root", "123456").await;
|
||||
test_authenticate(&provider, "admin", "654321").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -254,7 +264,7 @@ admin=654321",
|
||||
|
||||
let param = format!("file:{file_path}");
|
||||
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
|
||||
test_auth(&provider, "root", "123456").await;
|
||||
test_auth(&provider, "admin", "654321").await;
|
||||
test_authenticate(&provider, "root", "123456").await;
|
||||
test_authenticate(&provider, "admin", "654321").await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,8 +44,8 @@ pub struct MysqlInstanceShim {
|
||||
impl MysqlInstanceShim {
|
||||
pub fn create(
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
client_addr: SocketAddr,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
client_addr: SocketAddr,
|
||||
) -> MysqlInstanceShim {
|
||||
// init a random salt
|
||||
let mut bs = vec![0u8; 20];
|
||||
|
||||
@@ -33,39 +33,89 @@ use crate::error::{Error, Result};
|
||||
use crate::mysql::handler::MysqlInstanceShim;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
use crate::tls::TlsOption;
|
||||
|
||||
// Default size of ResultSet write buffer: 100KB
|
||||
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
|
||||
|
||||
struct MysqlRuntimeOption {
|
||||
/// [`MysqlSpawnRef`] stores arc refs
|
||||
/// that should be passed to new [`MysqlInstanceShim`]s.
|
||||
pub struct MysqlSpawnRef {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
|
||||
type MysqlRuntimeOptionRef = Arc<MysqlRuntimeOption>;
|
||||
impl MysqlSpawnRef {
|
||||
pub fn new(
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> MysqlSpawnRef {
|
||||
MysqlSpawnRef {
|
||||
query_handler,
|
||||
user_provider,
|
||||
}
|
||||
}
|
||||
|
||||
fn query_handler(&self) -> ServerSqlQueryHandlerRef {
|
||||
self.query_handler.clone()
|
||||
}
|
||||
fn user_provider(&self) -> Option<UserProviderRef> {
|
||||
self.user_provider.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// [`MysqlSpawnConfig`] stores config values
|
||||
/// which are used to initialize [`MysqlInstanceShim`]s.
|
||||
pub struct MysqlSpawnConfig {
|
||||
// tls config
|
||||
force_tls: bool,
|
||||
tls: Option<Arc<ServerConfig>>,
|
||||
// other shim config
|
||||
reject_no_database: bool,
|
||||
}
|
||||
|
||||
impl MysqlSpawnConfig {
|
||||
pub fn new(
|
||||
force_tls: bool,
|
||||
tls: Option<Arc<ServerConfig>>,
|
||||
reject_no_database: bool,
|
||||
) -> MysqlSpawnConfig {
|
||||
MysqlSpawnConfig {
|
||||
force_tls,
|
||||
tls,
|
||||
reject_no_database,
|
||||
}
|
||||
}
|
||||
|
||||
fn tls(&self) -> Option<Arc<ServerConfig>> {
|
||||
self.tls.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&MysqlSpawnConfig> for IntermediaryOptions {
|
||||
fn from(value: &MysqlSpawnConfig) -> Self {
|
||||
IntermediaryOptions {
|
||||
reject_connection_on_dbname_absence: value.reject_no_database,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MysqlServer {
|
||||
base_server: BaseTcpServer,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
tls: TlsOption,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
}
|
||||
|
||||
impl MysqlServer {
|
||||
pub fn create_server(
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
io_runtime: Arc<Runtime>,
|
||||
tls: TlsOption,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
) -> Box<dyn Server> {
|
||||
Box::new(MysqlServer {
|
||||
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
|
||||
query_handler,
|
||||
tls,
|
||||
user_provider,
|
||||
spawn_ref,
|
||||
spawn_config,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -73,32 +123,21 @@ impl MysqlServer {
|
||||
&self,
|
||||
io_runtime: Arc<Runtime>,
|
||||
stream: AbortableStream,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let query_handler = self.query_handler.clone();
|
||||
let user_provider = self.user_provider.clone();
|
||||
|
||||
let force_tls = self.tls.should_force_tls();
|
||||
let spawn_ref = self.spawn_ref.clone();
|
||||
let spawn_config = self.spawn_config.clone();
|
||||
|
||||
stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let query_handler = query_handler.clone();
|
||||
let user_provider = user_provider.clone();
|
||||
let tls_conf = tls_conf.clone();
|
||||
|
||||
let mysql_runtime_option = Arc::new(MysqlRuntimeOption {
|
||||
query_handler,
|
||||
tls_conf,
|
||||
force_tls,
|
||||
user_provider,
|
||||
});
|
||||
let spawn_ref = spawn_ref.clone();
|
||||
let spawn_config = spawn_config.clone();
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
if let Err(error) =
|
||||
Self::handle(io_stream, io_runtime, mysql_runtime_option).await
|
||||
Self::handle(io_stream, io_runtime, spawn_ref, spawn_config).await
|
||||
{
|
||||
error!(error; "Unexpected error when handling TcpStream");
|
||||
};
|
||||
@@ -111,12 +150,13 @@ impl MysqlServer {
|
||||
async fn handle(
|
||||
stream: TcpStream,
|
||||
io_runtime: Arc<Runtime>,
|
||||
runtime_opts: MysqlRuntimeOptionRef,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
) -> Result<()> {
|
||||
info!("MySQL connection coming from: {}", stream.peer_addr()?);
|
||||
io_runtime .spawn(async move {
|
||||
io_runtime.spawn(async move {
|
||||
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
|
||||
if let Err(e) = Self::do_handle(stream, runtime_opts).await {
|
||||
if let Err(e) = Self::do_handle(stream, spawn_ref, spawn_config).await {
|
||||
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
|
||||
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
|
||||
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
|
||||
@@ -126,31 +166,32 @@ impl MysqlServer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_handle(stream: TcpStream, runtime_opts: MysqlRuntimeOptionRef) -> Result<()> {
|
||||
async fn do_handle(
|
||||
stream: TcpStream,
|
||||
spawn_ref: Arc<MysqlSpawnRef>,
|
||||
spawn_config: Arc<MysqlSpawnConfig>,
|
||||
) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(
|
||||
runtime_opts.query_handler.clone(),
|
||||
spawn_ref.query_handler(),
|
||||
spawn_ref.user_provider(),
|
||||
stream.peer_addr()?,
|
||||
runtime_opts.user_provider.clone(),
|
||||
);
|
||||
let (mut r, w) = stream.into_split();
|
||||
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
let ops = IntermediaryOptions::default();
|
||||
|
||||
let (client_tls, init_params) = AsyncMysqlIntermediary::init_before_ssl(
|
||||
&mut shim,
|
||||
&mut r,
|
||||
&mut w,
|
||||
&runtime_opts.tls_conf,
|
||||
)
|
||||
.await?;
|
||||
let ops = spawn_config.as_ref().into();
|
||||
|
||||
if runtime_opts.force_tls && !client_tls {
|
||||
let (client_tls, init_params) =
|
||||
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &spawn_config.tls())
|
||||
.await?;
|
||||
|
||||
if spawn_config.force_tls && !client_tls {
|
||||
return Err(Error::TlsRequired {
|
||||
server: "mysql".to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
match runtime_opts.tls_conf.clone() {
|
||||
match spawn_config.tls() {
|
||||
Some(tls_conf) if client_tls => {
|
||||
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
|
||||
}
|
||||
@@ -167,12 +208,9 @@ impl Server for MysqlServer {
|
||||
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
|
||||
let tls_conf = self.tls.setup()?.map(Arc::new);
|
||||
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use session::context::{UserInfo, DEFAULT_USERNAME};
|
||||
use session::context::UserInfo;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
@@ -202,21 +202,6 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
// no user is provided, use default user
|
||||
// and still do authorization
|
||||
let mut login_info = LoginInfo::from_client_info(client);
|
||||
login_info.user = Some(DEFAULT_USERNAME.to_string());
|
||||
|
||||
let authorize_result = self.verifier.authorize(&login_info).await;
|
||||
if !matches!(authorize_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"28P01",
|
||||
"password authorization failed".to_owned(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ use mysql_async::SslOpts;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::error::Result;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef};
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
@@ -34,11 +34,14 @@ use crate::auth::{DatabaseAuthInfo, MockUserProvider};
|
||||
use crate::create_testing_sql_query_handler;
|
||||
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
|
||||
|
||||
fn create_mysql_server(
|
||||
table: MemTable,
|
||||
#[derive(Default)]
|
||||
struct MysqlOpts<'a> {
|
||||
tls: TlsOption,
|
||||
auth_info: Option<DatabaseAuthInfo>,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
auth_info: Option<DatabaseAuthInfo<'a>>,
|
||||
reject_no_database: bool,
|
||||
}
|
||||
|
||||
fn create_mysql_server(table: MemTable, opts: MysqlOpts<'_>) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -49,15 +52,18 @@ fn create_mysql_server(
|
||||
);
|
||||
|
||||
let mut provider = MockUserProvider::default();
|
||||
if let Some(auth_info) = auth_info {
|
||||
if let Some(auth_info) = opts.auth_info {
|
||||
provider.set_authorization_info(auth_info);
|
||||
}
|
||||
|
||||
Ok(MysqlServer::create_server(
|
||||
query_handler,
|
||||
io_runtime,
|
||||
tls,
|
||||
Some(Arc::new(provider)),
|
||||
Arc::new(MysqlSpawnRef::new(query_handler, Some(Arc::new(provider)))),
|
||||
Arc::new(MysqlSpawnConfig::new(
|
||||
opts.tls.should_force_tls(),
|
||||
opts.tls.setup()?.map(Arc::new),
|
||||
opts.reject_no_database,
|
||||
)),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -65,7 +71,7 @@ fn create_mysql_server(
|
||||
async fn test_start_mysql_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = mysql_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -78,11 +84,42 @@ async fn test_start_mysql_server() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reject_no_database() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let table = MemTable::default_numbers_table();
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
reject_no_database: true,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let fail = create_connection(server_port, None, false).await;
|
||||
assert!(fail.is_err());
|
||||
let pass = create_connection(server_port, Some("public"), false).await;
|
||||
assert!(pass.is_ok());
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validation() -> Result<()> {
|
||||
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
let mysql_server = create_mysql_server(table, Default::default(), Some(auth_info))?;
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
auth_info: Some(auth_info),
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
Ok((mysql_server, server_addr.port()))
|
||||
@@ -96,9 +133,7 @@ async fn test_schema_validation() -> Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
//TODO(shuiyisong): mysql conn without dbname rejection is not implemented yet, add test later.
|
||||
|
||||
let pass = create_connection(server_port, Some("public"), false).await;
|
||||
let pass = create_connection_default_db_name(server_port, false).await;
|
||||
assert!(pass.is_ok());
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
@@ -111,7 +146,7 @@ async fn test_schema_validation() -> Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
let fail = create_connection(server_port, Some("public"), false).await;
|
||||
let fail = create_connection_default_db_name(server_port, false).await;
|
||||
assert!(fail.is_err());
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result.is_ok());
|
||||
@@ -125,7 +160,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -140,7 +175,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, None, false).await {
|
||||
match create_connection_default_db_name(server_port, false).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -230,7 +265,13 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
tls: server_tls,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
@@ -261,12 +302,18 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
tls: server_tls,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), None, client_tls).await;
|
||||
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
@@ -287,15 +334,19 @@ async fn test_db_name() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
tls: server_tls,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), None, client_tls).await;
|
||||
assert!(r.is_ok());
|
||||
|
||||
let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await;
|
||||
// None actually uses default database name
|
||||
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
|
||||
assert!(r.is_ok());
|
||||
|
||||
let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await;
|
||||
@@ -315,12 +366,18 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls, None)?;
|
||||
let mysql_server = create_mysql_server(
|
||||
table,
|
||||
MysqlOpts {
|
||||
tls: server_tls,
|
||||
..Default::default()
|
||||
},
|
||||
)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), None, client_tls)
|
||||
let mut connection = create_connection_default_db_name(server_addr.port(), client_tls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -350,7 +407,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table, Default::default(), None)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
@@ -362,7 +419,9 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port, None, false).await.unwrap();
|
||||
let mut connection = create_connection_default_db_name(server_port, false)
|
||||
.await
|
||||
.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = connection
|
||||
@@ -376,7 +435,9 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection = create_connection(server_port, None, false).await.unwrap();
|
||||
connection = create_connection_default_db_name(server_port, false)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -390,6 +451,13 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection_default_db_name(
|
||||
port: u16,
|
||||
ssl: bool,
|
||||
) -> mysql_async::Result<mysql_async::Conn> {
|
||||
create_connection(port, Some(DEFAULT_SCHEMA_NAME), ssl).await
|
||||
}
|
||||
|
||||
async fn create_connection(
|
||||
port: u16,
|
||||
db_name: Option<&str>,
|
||||
@@ -400,10 +468,13 @@ async fn create_connection(
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000))
|
||||
.db_name(db_name.or(Some(DEFAULT_SCHEMA_NAME)))
|
||||
.user(Some("greptime".to_string()))
|
||||
.pass(Some("greptime".to_string()));
|
||||
|
||||
if let Some(db_name) = db_name {
|
||||
opts = opts.db_name(Some(db_name.to_string()));
|
||||
}
|
||||
|
||||
if ssl {
|
||||
let ssl_opts = SslOpts::default()
|
||||
.with_danger_skip_domain_validation(true)
|
||||
|
||||
Reference in New Issue
Block a user