From 041cd422a1b284b3a267613eab79031da8ff571a Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Tue, 3 Jan 2023 19:15:47 +0800 Subject: [PATCH] refactor: do not call use upon mysql connection (#818) --- src/servers/src/error.rs | 5 ++ src/servers/src/mysql/handler.rs | 22 ++++++--- src/servers/tests/mysql/mysql_server_test.rs | 50 +++++++++++++++++--- 3 files changed, 64 insertions(+), 13 deletions(-) diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index fa790739a0..8e110e6b0e 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -249,6 +249,9 @@ pub enum Error { #[snafu(backtrace)] source: common_grpc::error::Error, }, + + #[snafu(display("Cannot find requested database: {}-{}", catalog, schema))] + DatabaseNotFound { catalog: String, schema: String }, } pub type Result = std::result::Result; @@ -306,6 +309,8 @@ impl ErrorExt for Error { | InvalidAuthorizationHeader { .. } | InvalidBase64Value { .. } | InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader, + + DatabaseNotFound { .. } => StatusCode::DatabaseNotFound, } } diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index c237ddc5d6..0e9b3a1b2b 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use std::time::Instant; use async_trait::async_trait; +use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_query::Output; use common_telemetry::{error, trace}; use opensrv_mysql::{ @@ -183,14 +184,21 @@ impl AsyncMysqlShim for MysqlInstanceShi } async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { - let query = format!("USE {}", database.trim()); - let output = self.do_query(&query).await.remove(0); - if let Err(e) = output { - w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) - .await + // TODO(sunng87): set catalog + if self + .query_handler + .is_valid_schema(DEFAULT_CATALOG_NAME, database)? + { + let context = self.session.context(); + // TODO(sunng87): set catalog + context.set_current_schema(database); + w.ok().await.map_err(|e| e.into()) } else { - w.ok().await + error::DatabaseNotFoundSnafu { + catalog: DEFAULT_CATALOG_NAME, + schema: database, + } + .fail() } - .map_err(|e| e.into()) } } diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 27dfa27f25..280d0a1dc2 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -16,6 +16,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_recordbatch::RecordBatch; use common_runtime::Builder as RuntimeBuilder; use datatypes::schema::Schema; @@ -91,7 +92,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { for _ in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port, false).await { + match create_connection(server_port, None, false).await { Ok(mut connection) => { let result: u32 = connection .query_first("SELECT uint32s FROM numbers LIMIT 1") @@ -197,7 +198,39 @@ async fn test_server_required_secure_client_plain() -> Result<()> { let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let r = create_connection(server_addr.port(), client_tls).await; + let r = create_connection(server_addr.port(), None, client_tls).await; + assert!(r.is_err()); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_db_name() -> Result<()> { + let server_tls = TlsOption::default(); + let client_tls = false; + + #[allow(unused)] + let TestingData { + column_schemas, + mysql_columns_def, + columns, + mysql_text_output_rows, + } = all_datatype_testing_data(); + let schema = Arc::new(Schema::new(column_schemas.clone())); + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + let table = MemTable::new("all_datatypes", recordbatch); + + let mysql_server = create_mysql_server(table, server_tls)?; + + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + + let r = create_connection(server_addr.port(), None, client_tls).await; + assert!(r.is_ok()); + + let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await; + assert!(r.is_ok()); + + let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await; assert!(r.is_err()); Ok(()) } @@ -219,7 +252,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) -> let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let mut connection = create_connection(server_addr.port(), client_tls) + let mut connection = create_connection(server_addr.port(), None, client_tls) .await .unwrap(); @@ -261,7 +294,7 @@ async fn test_query_concurrently() -> Result<()> { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut connection = create_connection(server_port, false).await.unwrap(); + let mut connection = create_connection(server_port, None, false).await.unwrap(); for _ in 0..expect_executed_queries_per_worker { let expected: u32 = rand.gen_range(0..100); let result: u32 = connection @@ -275,7 +308,7 @@ async fn test_query_concurrently() -> Result<()> { let should_recreate_conn = expected == 1; if should_recreate_conn { - connection = create_connection(server_port, false).await.unwrap(); + connection = create_connection(server_port, None, false).await.unwrap(); } } expect_executed_queries_per_worker @@ -289,12 +322,17 @@ async fn test_query_concurrently() -> Result<()> { Ok(()) } -async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result { +async fn create_connection( + port: u16, + db_name: Option<&str>, + ssl: bool, +) -> mysql_async::Result { let mut opts = mysql_async::OptsBuilder::default() .ip_or_hostname("127.0.0.1") .tcp_port(port) .prefer_socket(false) .wait_timeout(Some(1000)) + .db_name(db_name) .user(Some("greptime".to_string())) .pass(Some("greptime".to_string()));