refactor: do not call use upon mysql connection (#818)

This commit is contained in:
Ning Sun
2023-01-03 19:15:47 +08:00
committed by GitHub
parent f907a93b97
commit 041cd422a1
3 changed files with 64 additions and 13 deletions

View File

@@ -249,6 +249,9 @@ pub enum Error {
#[snafu(backtrace)]
source: common_grpc::error::Error,
},
#[snafu(display("Cannot find requested database: {}-{}", catalog, schema))]
DatabaseNotFound { catalog: String, schema: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -306,6 +309,8 @@ impl ErrorExt for Error {
| InvalidAuthorizationHeader { .. }
| InvalidBase64Value { .. }
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
DatabaseNotFound { .. } => StatusCode::DatabaseNotFound,
}
}

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_query::Output;
use common_telemetry::{error, trace};
use opensrv_mysql::{
@@ -183,14 +184,21 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
}
async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
let query = format!("USE {}", database.trim());
let output = self.do_query(&query).await.remove(0);
if let Err(e) = output {
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
.await
// TODO(sunng87): set catalog
if self
.query_handler
.is_valid_schema(DEFAULT_CATALOG_NAME, database)?
{
let context = self.session.context();
// TODO(sunng87): set catalog
context.set_current_schema(database);
w.ok().await.map_err(|e| e.into())
} else {
w.ok().await
error::DatabaseNotFoundSnafu {
catalog: DEFAULT_CATALOG_NAME,
schema: database,
}
.fail()
}
.map_err(|e| e.into())
}
}

View File

@@ -16,6 +16,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_recordbatch::RecordBatch;
use common_runtime::Builder as RuntimeBuilder;
use datatypes::schema::Schema;
@@ -91,7 +92,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, false).await {
match create_connection(server_port, None, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -197,7 +198,39 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), client_tls).await;
let r = create_connection(server_addr.port(), None, client_tls).await;
assert!(r.is_err());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_db_name() -> Result<()> {
let server_tls = TlsOption::default();
let client_tls = false;
#[allow(unused)]
let TestingData {
column_schemas,
mysql_columns_def,
columns,
mysql_text_output_rows,
} = all_datatype_testing_data();
let schema = Arc::new(Schema::new(column_schemas.clone()));
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::new("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(table, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), None, client_tls).await;
assert!(r.is_ok());
let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await;
assert!(r.is_ok());
let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await;
assert!(r.is_err());
Ok(())
}
@@ -219,7 +252,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), client_tls)
let mut connection = create_connection(server_addr.port(), None, client_tls)
.await
.unwrap();
@@ -261,7 +294,7 @@ async fn test_query_concurrently() -> Result<()> {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, false).await.unwrap();
let mut connection = create_connection(server_port, None, false).await.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -275,7 +308,7 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, false).await.unwrap();
connection = create_connection(server_port, None, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -289,12 +322,17 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
async fn create_connection(
port: u16,
db_name: Option<&str>,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000))
.db_name(db_name)
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));