diff --git a/src/datanode/src/server/grpc/ddl.rs b/src/datanode/src/server/grpc/ddl.rs index bb57008f92..579c0d96a6 100644 --- a/src/datanode/src/server/grpc/ddl.rs +++ b/src/datanode/src/server/grpc/ddl.rs @@ -181,7 +181,7 @@ mod tests { #[tokio::test] async fn test_create_expr_to_request() { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("create_expr_to_request"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); diff --git a/src/datanode/src/tests/grpc_test.rs b/src/datanode/src/tests/grpc_test.rs index 0ef48218ff..b0455b8ce7 100644 --- a/src/datanode/src/tests/grpc_test.rs +++ b/src/datanode/src/tests/grpc_test.rs @@ -15,37 +15,42 @@ use servers::grpc::GrpcServer; use servers::server::Server; use crate::instance::Instance; -use crate::tests::test_util; +use crate::tests::test_util::{self, TestGuard}; -async fn setup_grpc_server(port: usize) -> String { +async fn setup_grpc_server(name: &str, port: usize) -> (String, TestGuard, Arc) { common_telemetry::init_default_ut_logging(); - let (mut opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (mut opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name); let addr = format!("127.0.0.1:{}", port); opts.rpc_addr = addr.clone(); let instance = Arc::new(Instance::new(&opts).await.unwrap()); instance.start().await.unwrap(); let addr_cloned = addr.clone(); + let grpc_server = Arc::new(GrpcServer::new(instance.clone(), instance)); + + let grpc_server_clone = grpc_server.clone(); tokio::spawn(async move { - let mut grpc_server = GrpcServer::new(instance.clone(), instance); let addr = addr_cloned.parse::().unwrap(); - grpc_server.start(addr).await.unwrap() + grpc_server_clone.start(addr).await.unwrap() }); // wait for GRPC server to start tokio::time::sleep(Duration::from_secs(1)).await; - addr + + (addr, guard, grpc_server) } #[tokio::test] async fn test_auto_create_table() { - let addr = setup_grpc_server(3991).await; + let (addr, _guard, grpc_server) = setup_grpc_server("auto_create_table", 3991).await; let grpc_client = Client::connect(format!("http://{}", addr)).await.unwrap(); let db = Database::new("greptime", grpc_client); insert_and_assert(&db).await; + + grpc_server.shutdown().await.unwrap(); } fn expect_data() -> (Column, Column, Column, Column) { @@ -104,7 +109,7 @@ fn expect_data() -> (Column, Column, Column, Column) { #[tokio::test] async fn test_insert_and_select() { - let addr = setup_grpc_server(3990).await; + let (addr, _guard, grpc_server) = setup_grpc_server("insert_and_select", 3990).await; let grpc_client = Client::connect(format!("http://{}", addr)).await.unwrap(); @@ -143,6 +148,8 @@ async fn test_insert_and_select() { // insert insert_and_assert(&db).await; + + grpc_server.shutdown().await.unwrap(); } async fn insert_and_assert(db: &Database) { diff --git a/src/datanode/src/tests/http_test.rs b/src/datanode/src/tests/http_test.rs index b9194178d8..25a83d3a28 100644 --- a/src/datanode/src/tests/http_test.rs +++ b/src/datanode/src/tests/http_test.rs @@ -13,8 +13,8 @@ use test_util::TestGuard; use crate::instance::Instance; use crate::tests::test_util; -async fn make_test_app() -> (Router, TestGuard) { - let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(); +async fn make_test_app(name: &str) -> (Router, TestGuard) { + let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name); let instance = Arc::new(Instance::new(&opts).await.unwrap()); instance.start().await.unwrap(); test_util::create_test_table(&instance, ConcreteDataType::timestamp_millis_datatype()) @@ -27,7 +27,7 @@ async fn make_test_app() -> (Router, TestGuard) { #[tokio::test] async fn test_sql_api() { common_telemetry::init_default_ut_logging(); - let (app, _guard) = make_test_app().await; + let (app, _guard) = make_test_app("sql_api").await; let client = TestClient::new(app); let res = client.get("/v1/sql").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -88,7 +88,7 @@ async fn test_sql_api() { async fn test_metrics_api() { common_telemetry::init_default_ut_logging(); common_telemetry::init_default_metrics_recorder(); - let (app, _guard) = make_test_app().await; + let (app, _guard) = make_test_app("metrics_api").await; let client = TestClient::new(app); // Send a sql @@ -108,7 +108,7 @@ async fn test_metrics_api() { #[tokio::test] async fn test_scripts_api() { common_telemetry::init_default_ut_logging(); - let (app, _guard) = make_test_app().await; + let (app, _guard) = make_test_app("scripts_api").await; let client = TestClient::new(app); let res = client .post("/v1/scripts") @@ -140,10 +140,10 @@ def test(n): } async fn start_test_app(addr: &str) -> (SocketAddr, TestGuard) { - let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts("py_side_scripts_api"); let instance = Arc::new(Instance::new(&opts).await.unwrap()); instance.start().await.unwrap(); - let mut http_server = HttpServer::new(instance); + let http_server = HttpServer::new(instance); ( http_server.start(addr.parse().unwrap()).await.unwrap(), guard, diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 8c23867fd5..a65e867780 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -13,7 +13,7 @@ use crate::tests::test_util; async fn test_execute_insert() { common_telemetry::init_default_ut_logging(); - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_insert"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); @@ -37,7 +37,7 @@ async fn test_execute_insert() { async fn test_execute_insert_query_with_i64_timestamp() { common_telemetry::init_default_ut_logging(); - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("insert_query_i64_timestamp"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); @@ -74,7 +74,7 @@ async fn test_execute_insert_query_with_i64_timestamp() { #[tokio::test] async fn test_execute_query() { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_query"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); @@ -100,7 +100,8 @@ async fn test_execute_query() { #[tokio::test] async fn test_execute_show_databases_tables() { - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = + test_util::create_tmp_dir_and_datanode_opts("execute_show_databases_tables"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); @@ -191,7 +192,7 @@ async fn test_execute_show_databases_tables() { pub async fn test_execute_create() { common_telemetry::init_default_ut_logging(); - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); @@ -215,7 +216,8 @@ pub async fn test_execute_create() { pub async fn test_create_table_illegal_timestamp_type() { common_telemetry::init_default_ut_logging(); - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(); + let (opts, _guard) = + test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type"); let instance = Instance::new(&opts).await.unwrap(); instance.start().await.unwrap(); diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 2866443769..7e284c8f66 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -24,9 +24,9 @@ pub struct TestGuard { _data_tmp_dir: TempDir, } -pub fn create_tmp_dir_and_datanode_opts() -> (DatanodeOptions, TestGuard) { - let wal_tmp_dir = TempDir::new("/tmp/greptimedb_test_wal").unwrap(); - let data_tmp_dir = TempDir::new("/tmp/greptimedb_test_data").unwrap(); +pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) { + let wal_tmp_dir = TempDir::new(&format!("gt_wal_{}", name)).unwrap(); + let data_tmp_dir = TempDir::new(&format!("gt_data_{}", name)).unwrap(); let opts = DatanodeOptions { wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(), storage: ObjectStoreConfig::File { diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index b34c7de5cb..2ee90c2dd0 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -131,7 +131,7 @@ fn parse_addr(addr: &str) -> Result { async fn start_server( server_and_addr: Option<(Box, SocketAddr)>, ) -> servers::error::Result> { - if let Some((mut server, addr)) = server_and_addr { + if let Some((server, addr)) = server_and_addr { server.start(addr).await.map(Some) } else { Ok(None) diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index af7f71f686..9060280093 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -42,6 +42,12 @@ pub enum Error { #[snafu(display("Failed to start gRPC server, source: {}", source))] StartGrpc { source: tonic::transport::Error }, + #[snafu(display("{} server is already started", server))] + AlreadyStarted { + server: String, + backtrace: Backtrace, + }, + #[snafu(display("Failed to bind address {}, source: {}", addr, source))] TcpBind { addr: SocketAddr, @@ -161,6 +167,7 @@ impl ErrorExt for Error { | CollectRecordbatch { .. } | StartHttp { .. } | StartGrpc { .. } + | AlreadyStarted { .. } | InvalidPromRemoteReadQueryResult { .. } | TcpBind { .. } => StatusCode::Internal, diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 6410cbce0e..e000c580f5 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -5,12 +5,16 @@ use std::net::SocketAddr; use api::v1::{greptime_server, BatchRequest, BatchResponse}; use async_trait::async_trait; use common_telemetry::logging::info; +use futures::FutureExt; +use snafu::ensure; use snafu::ResultExt; use tokio::net::TcpListener; +use tokio::sync::oneshot::{self, Sender}; +use tokio::sync::Mutex; use tokio_stream::wrappers::TcpListenerStream; use tonic::{Request, Response, Status}; -use crate::error::{Result, StartGrpcSnafu, TcpBindSnafu}; +use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu}; use crate::grpc::handler::BatchHandler; use crate::query_handler::{GrpcAdminHandlerRef, GrpcQueryHandlerRef}; use crate::server::Server; @@ -18,6 +22,7 @@ use crate::server::Server; pub struct GrpcServer { query_handler: GrpcQueryHandlerRef, admin_handler: GrpcAdminHandlerRef, + shutdown_tx: Mutex>>, } impl GrpcServer { @@ -25,6 +30,7 @@ impl GrpcServer { Self { query_handler, admin_handler, + shutdown_tx: Mutex::new(None), } } @@ -54,23 +60,45 @@ impl greptime_server::Greptime for GrpcService { #[async_trait] impl Server for GrpcServer { - async fn shutdown(&mut self) -> Result<()> { - // TODO(LFC): shutdown grpc server - unimplemented!() + async fn shutdown(&self) -> Result<()> { + let mut shutdown_tx = self.shutdown_tx.lock().await; + if let Some(tx) = shutdown_tx.take() { + if tx.send(()).is_err() { + info!("Receiver dropped, the grpc server has already existed"); + } + } + info!("Shutdown grpc server"); + + Ok(()) } - async fn start(&mut self, addr: SocketAddr) -> Result { - let listener = TcpListener::bind(addr) - .await - .context(TcpBindSnafu { addr })?; - let addr = listener.local_addr().context(TcpBindSnafu { addr })?; - info!("GRPC server is bound to {}", addr); + async fn start(&self, addr: SocketAddr) -> Result { + let (tx, rx) = oneshot::channel(); + let (listener, addr) = { + let mut shutdown_tx = self.shutdown_tx.lock().await; + ensure!( + shutdown_tx.is_none(), + AlreadyStartedSnafu { server: "gRPC" } + ); + let listener = TcpListener::bind(addr) + .await + .context(TcpBindSnafu { addr })?; + let addr = listener.local_addr().context(TcpBindSnafu { addr })?; + info!("GRPC server is bound to {}", addr); + + *shutdown_tx = Some(tx); + + (listener, addr) + }; + + // Would block to serve requests. tonic::transport::Server::builder() .add_service(self.create_service()) - .serve_with_incoming(TcpListenerStream::new(listener)) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), rx.map(drop)) .await .context(StartGrpcSnafu)?; + Ok(addr) } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 525b6f22c4..697ecb1519 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -20,15 +20,19 @@ use common_query::Output; use common_recordbatch::{util, RecordBatch}; use common_telemetry::logging::info; use datatypes::data_type::DataType; +use futures::FutureExt; use schemars::JsonSchema; use serde::Serialize; use serde_json::Value; +use snafu::ensure; use snafu::ResultExt; +use tokio::sync::oneshot::{self, Sender}; +use tokio::sync::Mutex; use tower::{timeout::TimeoutLayer, ServiceBuilder}; use tower_http::trace::TraceLayer; use self::influxdb::influxdb_write; -use crate::error::{Result, StartHttpSnafu}; +use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu}; use crate::query_handler::SqlQueryHandlerRef; use crate::query_handler::{ InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef, @@ -42,6 +46,7 @@ pub struct HttpServer { influxdb_handler: Option, opentsdb_handler: Option, prom_handler: Option, + shutdown_tx: Mutex>>, } #[derive(Debug, Serialize, JsonSchema)] @@ -195,14 +200,6 @@ impl JsonResponse { } } -async fn shutdown_signal() { - // Wait for the CTRL+C signal - // It has an issue on chrome: https://github.com/sigp/lighthouse/issues/478 - tokio::signal::ctrl_c() - .await - .expect("failed to install CTRL+C signal handler"); -} - async fn serve_api(Extension(api): Extension>) -> impl IntoApiResponse { Json(api) } @@ -218,6 +215,7 @@ impl HttpServer { opentsdb_handler: None, influxdb_handler: None, prom_handler: None, + shutdown_tx: Mutex::new(None), } } @@ -323,18 +321,40 @@ impl HttpServer { #[async_trait] impl Server for HttpServer { - async fn shutdown(&mut self) -> Result<()> { - // TODO(LFC): shutdown http server, and remove `shutdown_signal` above - unimplemented!() + async fn shutdown(&self) -> Result<()> { + let mut shutdown_tx = self.shutdown_tx.lock().await; + if let Some(tx) = shutdown_tx.take() { + if tx.send(()).is_err() { + info!("Receiver dropped, the HTTP server has already existed"); + } + } + info!("Shutdown HTTP server"); + + Ok(()) } - async fn start(&mut self, listening: SocketAddr) -> Result { - let app = self.make_app(); - let server = axum::Server::bind(&listening).serve(app.into_make_service()); + async fn start(&self, listening: SocketAddr) -> Result { + let (tx, rx) = oneshot::channel(); + let server = { + let mut shutdown_tx = self.shutdown_tx.lock().await; + ensure!( + shutdown_tx.is_none(), + AlreadyStartedSnafu { server: "HTTP" } + ); + + let app = self.make_app(); + let server = axum::Server::bind(&listening).serve(app.into_make_service()); + + *shutdown_tx = Some(tx); + + server + }; let listening = server.local_addr(); info!("HTTP server is bound to {}", listening); - let graceful = server.with_graceful_shutdown(shutdown_signal()); + + let graceful = server.with_graceful_shutdown(rx.map(drop)); graceful.await.context(StartHttpSnafu)?; + Ok(listening) } } diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index ff1b28145d..5cfa2c279a 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -68,16 +68,16 @@ impl MysqlServer { #[async_trait] impl Server for MysqlServer { - async fn shutdown(&mut self) -> Result<()> { + async fn shutdown(&self) -> Result<()> { self.base_server.shutdown().await } - async fn start(&mut self, listening: SocketAddr) -> Result { + async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; let io_runtime = self.base_server.io_runtime(); let join_handle = tokio::spawn(self.accept(io_runtime, stream)); - self.base_server.start_with(join_handle)?; + self.base_server.start_with(join_handle).await?; Ok(addr) } } diff --git a/src/servers/src/opentsdb.rs b/src/servers/src/opentsdb.rs index aaf04b0bd2..9129a2c669 100644 --- a/src/servers/src/opentsdb.rs +++ b/src/servers/src/opentsdb.rs @@ -85,18 +85,22 @@ impl OpentsdbServer { #[async_trait] impl Server for OpentsdbServer { - async fn shutdown(&mut self) -> Result<()> { + async fn shutdown(&self) -> Result<()> { + if let Some(tx) = &self.notify_shutdown { + // Err of broadcast sender does not mean that future calls to send will fail, so + // its return value is ignored here. + let _ = tx.send(()); + } self.base_server.shutdown().await?; - drop(self.notify_shutdown.take()); Ok(()) } - async fn start(&mut self, listening: SocketAddr) -> Result { + async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; let io_runtime = self.base_server.io_runtime(); let join_handle = tokio::spawn(self.accept(io_runtime, stream)); - self.base_server.start_with(join_handle)?; + self.base_server.start_with(join_handle).await?; Ok(addr) } } diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 43aec90e4b..655e37c4b3 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -68,16 +68,16 @@ impl PostgresServer { #[async_trait] impl Server for PostgresServer { - async fn shutdown(&mut self) -> Result<()> { + async fn shutdown(&self) -> Result<()> { self.base_server.shutdown().await } - async fn start(&mut self, listening: SocketAddr) -> Result { + async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; let io_runtime = self.base_server.io_runtime(); let join_handle = tokio::spawn(self.accept(io_runtime, stream)); - self.base_server.start_with(join_handle)?; + self.base_server.start_with(join_handle).await?; Ok(addr) } } diff --git a/src/servers/src/server.rs b/src/servers/src/server.rs index 1075cc6d2f..a8f7bfefcb 100644 --- a/src/servers/src/server.rs +++ b/src/servers/src/server.rs @@ -7,6 +7,7 @@ use common_telemetry::logging::{error, info}; use futures::future::AbortRegistration; use futures::future::{AbortHandle, Abortable}; use snafu::ResultExt; +use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio_stream::wrappers::TcpListenerStream; @@ -16,13 +17,16 @@ pub(crate) type AbortableStream = Abortable; #[async_trait] pub trait Server: Send { - async fn shutdown(&mut self) -> Result<()>; - async fn start(&mut self, listening: SocketAddr) -> Result; + /// Shutdown the server gracefully. + async fn shutdown(&self) -> Result<()>; + + /// Starts the server and binds on `listening`. + /// + /// Caller should ensure `start()` is only invoked once. + async fn start(&self, listening: SocketAddr) -> Result; } -pub(crate) struct BaseTcpServer { - name: String, - +struct AccpetTask { // `abort_handle` and `abort_registration` are used in pairs in shutting down the server. // They work like sender and receiver for aborting stream. When the server is shutting down, // calling `abort_handle.abort()` will "notify" `abort_registration` to stop emitting new @@ -32,23 +36,10 @@ pub(crate) struct BaseTcpServer { // A handle holding the TCP accepting task. join_handle: Option>, - - io_runtime: Arc, } -impl BaseTcpServer { - pub(crate) fn create_server(name: impl Into, io_runtime: Arc) -> Self { - let (abort_handle, registration) = AbortHandle::new_pair(); - Self { - name: name.into(), - abort_handle, - abort_registration: Some(registration), - join_handle: None, - io_runtime, - } - } - - pub(crate) async fn shutdown(&mut self) -> Result<()> { +impl AccpetTask { + async fn shutdown(&mut self, name: &str) -> Result<()> { match self.join_handle.take() { Some(join_handle) => { self.abort_handle.abort(); @@ -57,23 +48,24 @@ impl BaseTcpServer { // Couldn't use `error!(e; xxx)` because JoinError doesn't implement ErrorExt. error!( "Unexpected error during shutdown {} server, error: {}", - &self.name, error + name, error ); } else { - info!("{} server is shutdown.", &self.name); + info!("{} server is shutdown.", name); } Ok(()) } None => error::InternalSnafu { - err_msg: format!("{} server is not started.", &self.name), + err_msg: format!("{} server is not started.", name), } .fail()?, } } - pub(crate) async fn bind( + async fn bind( &mut self, addr: SocketAddr, + name: &str, ) -> Result<(Abortable, SocketAddr)> { match self.abort_registration.take() { Some(registration) => { @@ -81,33 +73,73 @@ impl BaseTcpServer { tokio::net::TcpListener::bind(addr) .await .context(error::TokioIoSnafu { - err_msg: format!("Failed to bind addr {}", addr), + err_msg: format!("{} failed to bind addr {}", name, addr), })?; // get actually bond addr in case input addr use port 0 let addr = listener.local_addr()?; - info!("{} server started at {}", &self.name, addr); + info!("{} server started at {}", name, addr); let stream = TcpListenerStream::new(listener); let stream = Abortable::new(stream, registration); Ok((stream, addr)) } None => error::InternalSnafu { - err_msg: format!("{} server has been started.", &self.name), + err_msg: format!("{} server has been started.", name), } .fail()?, } } - pub(crate) fn start_with(&mut self, join_handle: JoinHandle<()>) -> Result<()> { + fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> { if self.join_handle.is_some() { return error::InternalSnafu { - err_msg: format!("{} server has been started.", &self.name), + err_msg: format!("{} server has been started.", name), } .fail(); } let _ = self.join_handle.insert(join_handle); + Ok(()) } +} + +pub(crate) struct BaseTcpServer { + name: String, + accept_task: Mutex, + io_runtime: Arc, +} + +impl BaseTcpServer { + pub(crate) fn create_server(name: impl Into, io_runtime: Arc) -> Self { + let (abort_handle, registration) = AbortHandle::new_pair(); + Self { + name: name.into(), + accept_task: Mutex::new(AccpetTask { + abort_handle, + abort_registration: Some(registration), + join_handle: None, + }), + io_runtime, + } + } + + pub(crate) async fn shutdown(&self) -> Result<()> { + let mut task = self.accept_task.lock().await; + task.shutdown(&self.name).await + } + + pub(crate) async fn bind( + &self, + addr: SocketAddr, + ) -> Result<(Abortable, SocketAddr)> { + let mut task = self.accept_task.lock().await; + task.bind(addr, &self.name).await + } + + pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> { + let mut task = self.accept_task.lock().await; + task.start_with(join_handle, &self.name) + } pub(crate) fn io_runtime(&self) -> Arc { self.io_runtime.clone() diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 91e5f29367..6cdea2509f 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -32,7 +32,7 @@ fn create_mysql_server(table: MemTable) -> Result> { async fn test_start_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = mysql_server.start(listening).await; assert!(result.is_ok()); @@ -51,7 +51,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table)?; let result = mysql_server.shutdown().await; assert!(result .unwrap_err() @@ -110,7 +110,7 @@ async fn test_query_all_datatypes() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mut mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); @@ -141,7 +141,7 @@ async fn test_query_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); let server_port = server_addr.port(); diff --git a/src/servers/tests/opentsdb.rs b/src/servers/tests/opentsdb.rs index a02f795d49..7a4fffb8e4 100644 --- a/src/servers/tests/opentsdb.rs +++ b/src/servers/tests/opentsdb.rs @@ -52,7 +52,7 @@ fn create_opentsdb_server(tx: mpsc::Sender) -> Result> { #[tokio::test] async fn test_start_opentsdb_server() -> Result<()> { let (tx, _) = mpsc::channel(100); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = server.start(listening).await; assert!(result.is_ok()); @@ -68,7 +68,7 @@ async fn test_start_opentsdb_server() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_shutdown_opentsdb_server_concurrently() -> Result<()> { let (tx, _) = mpsc::channel(100); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let result = server.shutdown().await; assert!(result .unwrap_err() @@ -133,7 +133,7 @@ async fn test_shutdown_opentsdb_server_concurrently() -> Result<()> { #[tokio::test] async fn test_opentsdb_connection_shutdown() -> Result<()> { let (tx, _) = mpsc::channel(100); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let result = server.shutdown().await; assert!(result .unwrap_err() @@ -178,7 +178,7 @@ async fn test_opentsdb_connection_shutdown() -> Result<()> { #[tokio::test] async fn test_opentsdb_connect_after_shutdown() -> Result<()> { let (tx, _) = mpsc::channel(100); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let result = server.shutdown().await; assert!(result .unwrap_err() @@ -198,7 +198,7 @@ async fn test_opentsdb_connect_after_shutdown() -> Result<()> { #[tokio::test] async fn test_query() -> Result<()> { let (tx, mut rx) = mpsc::channel(10); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let listening = "127.0.0.1:0".parse::().unwrap(); let addr = server.start(listening).await?; @@ -225,7 +225,7 @@ async fn test_query_concurrently() -> Result<()> { let expect_executed_queries_per_worker = 1000; let (tx, mut rx) = mpsc::channel(threads * expect_executed_queries_per_worker); - let mut server = create_opentsdb_server(tx)?; + let server = create_opentsdb_server(tx)?; let listening = "127.0.0.1:0".parse::().unwrap(); let addr = server.start(listening).await?; diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 2935479286..3e69863702 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -29,7 +29,7 @@ fn create_postgres_server(table: MemTable) -> Result> { pub async fn test_start_postgres_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut pg_server = create_postgres_server(table)?; + let pg_server = create_postgres_server(table)?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = pg_server.start(listening).await; assert!(result.is_ok()); @@ -48,7 +48,7 @@ async fn test_shutdown_pg_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut postgres_server = create_postgres_server(table)?; + let postgres_server = create_postgres_server(table)?; let result = postgres_server.shutdown().await; assert!(result .unwrap_err() @@ -107,7 +107,7 @@ async fn test_query_pg_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mut pg_server = create_postgres_server(table)?; + let pg_server = create_postgres_server(table)?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = pg_server.start(listening).await.unwrap(); let server_port = server_addr.port();