mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 05:42:57 +00:00
refactor: do not call use upon mysql connection (#818)
This commit is contained in:
@@ -249,6 +249,9 @@ pub enum Error {
|
||||
#[snafu(backtrace)]
|
||||
source: common_grpc::error::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Cannot find requested database: {}-{}", catalog, schema))]
|
||||
DatabaseNotFound { catalog: String, schema: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -306,6 +309,8 @@ impl ErrorExt for Error {
|
||||
| InvalidAuthorizationHeader { .. }
|
||||
| InvalidBase64Value { .. }
|
||||
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
|
||||
|
||||
DatabaseNotFound { .. } => StatusCode::DatabaseNotFound,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use common_query::Output;
|
||||
use common_telemetry::{error, trace};
|
||||
use opensrv_mysql::{
|
||||
@@ -183,14 +184,21 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
}
|
||||
|
||||
async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
|
||||
let query = format!("USE {}", database.trim());
|
||||
let output = self.do_query(&query).await.remove(0);
|
||||
if let Err(e) = output {
|
||||
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
|
||||
.await
|
||||
// TODO(sunng87): set catalog
|
||||
if self
|
||||
.query_handler
|
||||
.is_valid_schema(DEFAULT_CATALOG_NAME, database)?
|
||||
{
|
||||
let context = self.session.context();
|
||||
// TODO(sunng87): set catalog
|
||||
context.set_current_schema(database);
|
||||
w.ok().await.map_err(|e| e.into())
|
||||
} else {
|
||||
w.ok().await
|
||||
error::DatabaseNotFoundSnafu {
|
||||
catalog: DEFAULT_CATALOG_NAME,
|
||||
schema: database,
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use datatypes::schema::Schema;
|
||||
@@ -91,7 +92,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, false).await {
|
||||
match create_connection(server_port, None, false).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -197,7 +198,39 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), client_tls).await;
|
||||
let r = create_connection(server_addr.port(), None, client_tls).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_db_name() -> Result<()> {
|
||||
let server_tls = TlsOption::default();
|
||||
let client_tls = false;
|
||||
|
||||
#[allow(unused)]
|
||||
let TestingData {
|
||||
column_schemas,
|
||||
mysql_columns_def,
|
||||
columns,
|
||||
mysql_text_output_rows,
|
||||
} = all_datatype_testing_data();
|
||||
let schema = Arc::new(Schema::new(column_schemas.clone()));
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let r = create_connection(server_addr.port(), None, client_tls).await;
|
||||
assert!(r.is_ok());
|
||||
|
||||
let r = create_connection(server_addr.port(), Some(DEFAULT_SCHEMA_NAME), client_tls).await;
|
||||
assert!(r.is_ok());
|
||||
|
||||
let r = create_connection(server_addr.port(), Some("tomcat"), client_tls).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
@@ -219,7 +252,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), client_tls)
|
||||
let mut connection = create_connection(server_addr.port(), None, client_tls)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -261,7 +294,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port, false).await.unwrap();
|
||||
let mut connection = create_connection(server_port, None, false).await.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = connection
|
||||
@@ -275,7 +308,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection = create_connection(server_port, false).await.unwrap();
|
||||
connection = create_connection(server_port, None, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -289,12 +322,17 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
|
||||
async fn create_connection(
|
||||
port: u16,
|
||||
db_name: Option<&str>,
|
||||
ssl: bool,
|
||||
) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let mut opts = mysql_async::OptsBuilder::default()
|
||||
.ip_or_hostname("127.0.0.1")
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000))
|
||||
.db_name(db_name)
|
||||
.user(Some("greptime".to_string()))
|
||||
.pass(Some("greptime".to_string()));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user