mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 05:42:57 +00:00
feat: Implements shutdown for GrpcServer and HttpServer (#372)
* fix: Fix TestGuard being dropped before grpc test starts * feat: Let start and shutdown takes immutable reference to self Also implement shutdown for GrpcServer * feat: Implement shutdown for HttpServer * style: Fix clippy * chore: Add name to AlreadyStarted error
This commit is contained in:
@@ -181,7 +181,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_expr_to_request() {
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("create_expr_to_request");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
|
||||
@@ -15,37 +15,42 @@ use servers::grpc::GrpcServer;
|
||||
use servers::server::Server;
|
||||
|
||||
use crate::instance::Instance;
|
||||
use crate::tests::test_util;
|
||||
use crate::tests::test_util::{self, TestGuard};
|
||||
|
||||
async fn setup_grpc_server(port: usize) -> String {
|
||||
async fn setup_grpc_server(name: &str, port: usize) -> (String, TestGuard, Arc<GrpcServer>) {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let (mut opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (mut opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name);
|
||||
let addr = format!("127.0.0.1:{}", port);
|
||||
opts.rpc_addr = addr.clone();
|
||||
let instance = Arc::new(Instance::new(&opts).await.unwrap());
|
||||
instance.start().await.unwrap();
|
||||
|
||||
let addr_cloned = addr.clone();
|
||||
let grpc_server = Arc::new(GrpcServer::new(instance.clone(), instance));
|
||||
|
||||
let grpc_server_clone = grpc_server.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut grpc_server = GrpcServer::new(instance.clone(), instance);
|
||||
let addr = addr_cloned.parse::<SocketAddr>().unwrap();
|
||||
grpc_server.start(addr).await.unwrap()
|
||||
grpc_server_clone.start(addr).await.unwrap()
|
||||
});
|
||||
|
||||
// wait for GRPC server to start
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
addr
|
||||
|
||||
(addr, guard, grpc_server)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_table() {
|
||||
let addr = setup_grpc_server(3991).await;
|
||||
let (addr, _guard, grpc_server) = setup_grpc_server("auto_create_table", 3991).await;
|
||||
|
||||
let grpc_client = Client::connect(format!("http://{}", addr)).await.unwrap();
|
||||
let db = Database::new("greptime", grpc_client);
|
||||
|
||||
insert_and_assert(&db).await;
|
||||
|
||||
grpc_server.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
fn expect_data() -> (Column, Column, Column, Column) {
|
||||
@@ -104,7 +109,7 @@ fn expect_data() -> (Column, Column, Column, Column) {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_insert_and_select() {
|
||||
let addr = setup_grpc_server(3990).await;
|
||||
let (addr, _guard, grpc_server) = setup_grpc_server("insert_and_select", 3990).await;
|
||||
|
||||
let grpc_client = Client::connect(format!("http://{}", addr)).await.unwrap();
|
||||
|
||||
@@ -143,6 +148,8 @@ async fn test_insert_and_select() {
|
||||
|
||||
// insert
|
||||
insert_and_assert(&db).await;
|
||||
|
||||
grpc_server.shutdown().await.unwrap();
|
||||
}
|
||||
|
||||
async fn insert_and_assert(db: &Database) {
|
||||
|
||||
@@ -13,8 +13,8 @@ use test_util::TestGuard;
|
||||
use crate::instance::Instance;
|
||||
use crate::tests::test_util;
|
||||
|
||||
async fn make_test_app() -> (Router, TestGuard) {
|
||||
let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
async fn make_test_app(name: &str) -> (Router, TestGuard) {
|
||||
let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name);
|
||||
let instance = Arc::new(Instance::new(&opts).await.unwrap());
|
||||
instance.start().await.unwrap();
|
||||
test_util::create_test_table(&instance, ConcreteDataType::timestamp_millis_datatype())
|
||||
@@ -27,7 +27,7 @@ async fn make_test_app() -> (Router, TestGuard) {
|
||||
#[tokio::test]
|
||||
async fn test_sql_api() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (app, _guard) = make_test_app().await;
|
||||
let (app, _guard) = make_test_app("sql_api").await;
|
||||
let client = TestClient::new(app);
|
||||
let res = client.get("/v1/sql").send().await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
@@ -88,7 +88,7 @@ async fn test_sql_api() {
|
||||
async fn test_metrics_api() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
common_telemetry::init_default_metrics_recorder();
|
||||
let (app, _guard) = make_test_app().await;
|
||||
let (app, _guard) = make_test_app("metrics_api").await;
|
||||
let client = TestClient::new(app);
|
||||
|
||||
// Send a sql
|
||||
@@ -108,7 +108,7 @@ async fn test_metrics_api() {
|
||||
#[tokio::test]
|
||||
async fn test_scripts_api() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (app, _guard) = make_test_app().await;
|
||||
let (app, _guard) = make_test_app("scripts_api").await;
|
||||
let client = TestClient::new(app);
|
||||
let res = client
|
||||
.post("/v1/scripts")
|
||||
@@ -140,10 +140,10 @@ def test(n):
|
||||
}
|
||||
|
||||
async fn start_test_app(addr: &str) -> (SocketAddr, TestGuard) {
|
||||
let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts("py_side_scripts_api");
|
||||
let instance = Arc::new(Instance::new(&opts).await.unwrap());
|
||||
instance.start().await.unwrap();
|
||||
let mut http_server = HttpServer::new(instance);
|
||||
let http_server = HttpServer::new(instance);
|
||||
(
|
||||
http_server.start(addr.parse().unwrap()).await.unwrap(),
|
||||
guard,
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::tests::test_util;
|
||||
async fn test_execute_insert() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_insert");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
@@ -37,7 +37,7 @@ async fn test_execute_insert() {
|
||||
async fn test_execute_insert_query_with_i64_timestamp() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("insert_query_i64_timestamp");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
@@ -74,7 +74,7 @@ async fn test_execute_insert_query_with_i64_timestamp() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_query() {
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_query");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
@@ -100,7 +100,8 @@ async fn test_execute_query() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_show_databases_tables() {
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) =
|
||||
test_util::create_tmp_dir_and_datanode_opts("execute_show_databases_tables");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
@@ -191,7 +192,7 @@ async fn test_execute_show_databases_tables() {
|
||||
pub async fn test_execute_create() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_create");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
@@ -215,7 +216,8 @@ pub async fn test_execute_create() {
|
||||
pub async fn test_create_table_illegal_timestamp_type() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts();
|
||||
let (opts, _guard) =
|
||||
test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type");
|
||||
let instance = Instance::new(&opts).await.unwrap();
|
||||
instance.start().await.unwrap();
|
||||
|
||||
|
||||
@@ -24,9 +24,9 @@ pub struct TestGuard {
|
||||
_data_tmp_dir: TempDir,
|
||||
}
|
||||
|
||||
pub fn create_tmp_dir_and_datanode_opts() -> (DatanodeOptions, TestGuard) {
|
||||
let wal_tmp_dir = TempDir::new("/tmp/greptimedb_test_wal").unwrap();
|
||||
let data_tmp_dir = TempDir::new("/tmp/greptimedb_test_data").unwrap();
|
||||
pub fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
|
||||
let wal_tmp_dir = TempDir::new(&format!("gt_wal_{}", name)).unwrap();
|
||||
let data_tmp_dir = TempDir::new(&format!("gt_data_{}", name)).unwrap();
|
||||
let opts = DatanodeOptions {
|
||||
wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(),
|
||||
storage: ObjectStoreConfig::File {
|
||||
|
||||
@@ -131,7 +131,7 @@ fn parse_addr(addr: &str) -> Result<SocketAddr> {
|
||||
async fn start_server(
|
||||
server_and_addr: Option<(Box<dyn Server>, SocketAddr)>,
|
||||
) -> servers::error::Result<Option<SocketAddr>> {
|
||||
if let Some((mut server, addr)) = server_and_addr {
|
||||
if let Some((server, addr)) = server_and_addr {
|
||||
server.start(addr).await.map(Some)
|
||||
} else {
|
||||
Ok(None)
|
||||
|
||||
@@ -42,6 +42,12 @@ pub enum Error {
|
||||
#[snafu(display("Failed to start gRPC server, source: {}", source))]
|
||||
StartGrpc { source: tonic::transport::Error },
|
||||
|
||||
#[snafu(display("{} server is already started", server))]
|
||||
AlreadyStarted {
|
||||
server: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to bind address {}, source: {}", addr, source))]
|
||||
TcpBind {
|
||||
addr: SocketAddr,
|
||||
@@ -161,6 +167,7 @@ impl ErrorExt for Error {
|
||||
| CollectRecordbatch { .. }
|
||||
| StartHttp { .. }
|
||||
| StartGrpc { .. }
|
||||
| AlreadyStarted { .. }
|
||||
| InvalidPromRemoteReadQueryResult { .. }
|
||||
| TcpBind { .. } => StatusCode::Internal,
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@ use std::net::SocketAddr;
|
||||
use api::v1::{greptime_server, BatchRequest, BatchResponse};
|
||||
use async_trait::async_trait;
|
||||
use common_telemetry::logging::info;
|
||||
use futures::FutureExt;
|
||||
use snafu::ensure;
|
||||
use snafu::ResultExt;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::error::{Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use crate::grpc::handler::BatchHandler;
|
||||
use crate::query_handler::{GrpcAdminHandlerRef, GrpcQueryHandlerRef};
|
||||
use crate::server::Server;
|
||||
@@ -18,6 +22,7 @@ use crate::server::Server;
|
||||
pub struct GrpcServer {
|
||||
query_handler: GrpcQueryHandlerRef,
|
||||
admin_handler: GrpcAdminHandlerRef,
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
}
|
||||
|
||||
impl GrpcServer {
|
||||
@@ -25,6 +30,7 @@ impl GrpcServer {
|
||||
Self {
|
||||
query_handler,
|
||||
admin_handler,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,23 +60,45 @@ impl greptime_server::Greptime for GrpcService {
|
||||
|
||||
#[async_trait]
|
||||
impl Server for GrpcServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
// TODO(LFC): shutdown grpc server
|
||||
unimplemented!()
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
let mut shutdown_tx = self.shutdown_tx.lock().await;
|
||||
if let Some(tx) = shutdown_tx.take() {
|
||||
if tx.send(()).is_err() {
|
||||
info!("Receiver dropped, the grpc server has already existed");
|
||||
}
|
||||
}
|
||||
info!("Shutdown grpc server");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start(&mut self, addr: SocketAddr) -> Result<SocketAddr> {
|
||||
let listener = TcpListener::bind(addr)
|
||||
.await
|
||||
.context(TcpBindSnafu { addr })?;
|
||||
let addr = listener.local_addr().context(TcpBindSnafu { addr })?;
|
||||
info!("GRPC server is bound to {}", addr);
|
||||
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let (listener, addr) = {
|
||||
let mut shutdown_tx = self.shutdown_tx.lock().await;
|
||||
ensure!(
|
||||
shutdown_tx.is_none(),
|
||||
AlreadyStartedSnafu { server: "gRPC" }
|
||||
);
|
||||
|
||||
let listener = TcpListener::bind(addr)
|
||||
.await
|
||||
.context(TcpBindSnafu { addr })?;
|
||||
let addr = listener.local_addr().context(TcpBindSnafu { addr })?;
|
||||
info!("GRPC server is bound to {}", addr);
|
||||
|
||||
*shutdown_tx = Some(tx);
|
||||
|
||||
(listener, addr)
|
||||
};
|
||||
|
||||
// Would block to serve requests.
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(self.create_service())
|
||||
.serve_with_incoming(TcpListenerStream::new(listener))
|
||||
.serve_with_incoming_shutdown(TcpListenerStream::new(listener), rx.map(drop))
|
||||
.await
|
||||
.context(StartGrpcSnafu)?;
|
||||
|
||||
Ok(addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,15 +20,19 @@ use common_query::Output;
|
||||
use common_recordbatch::{util, RecordBatch};
|
||||
use common_telemetry::logging::info;
|
||||
use datatypes::data_type::DataType;
|
||||
use futures::FutureExt;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use snafu::ensure;
|
||||
use snafu::ResultExt;
|
||||
use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tower::{timeout::TimeoutLayer, ServiceBuilder};
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
use self::influxdb::influxdb_write;
|
||||
use crate::error::{Result, StartHttpSnafu};
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::query_handler::{
|
||||
InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef,
|
||||
@@ -42,6 +46,7 @@ pub struct HttpServer {
|
||||
influxdb_handler: Option<InfluxdbLineProtocolHandlerRef>,
|
||||
opentsdb_handler: Option<OpentsdbProtocolHandlerRef>,
|
||||
prom_handler: Option<PrometheusProtocolHandlerRef>,
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, JsonSchema)]
|
||||
@@ -195,14 +200,6 @@ impl JsonResponse {
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
// Wait for the CTRL+C signal
|
||||
// It has an issue on chrome: https://github.com/sigp/lighthouse/issues/478
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install CTRL+C signal handler");
|
||||
}
|
||||
|
||||
async fn serve_api(Extension(api): Extension<Arc<OpenApi>>) -> impl IntoApiResponse {
|
||||
Json(api)
|
||||
}
|
||||
@@ -218,6 +215,7 @@ impl HttpServer {
|
||||
opentsdb_handler: None,
|
||||
influxdb_handler: None,
|
||||
prom_handler: None,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,18 +321,40 @@ impl HttpServer {
|
||||
|
||||
#[async_trait]
|
||||
impl Server for HttpServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
// TODO(LFC): shutdown http server, and remove `shutdown_signal` above
|
||||
unimplemented!()
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
let mut shutdown_tx = self.shutdown_tx.lock().await;
|
||||
if let Some(tx) = shutdown_tx.take() {
|
||||
if tx.send(()).is_err() {
|
||||
info!("Receiver dropped, the HTTP server has already existed");
|
||||
}
|
||||
}
|
||||
info!("Shutdown HTTP server");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let app = self.make_app();
|
||||
let server = axum::Server::bind(&listening).serve(app.into_make_service());
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let server = {
|
||||
let mut shutdown_tx = self.shutdown_tx.lock().await;
|
||||
ensure!(
|
||||
shutdown_tx.is_none(),
|
||||
AlreadyStartedSnafu { server: "HTTP" }
|
||||
);
|
||||
|
||||
let app = self.make_app();
|
||||
let server = axum::Server::bind(&listening).serve(app.into_make_service());
|
||||
|
||||
*shutdown_tx = Some(tx);
|
||||
|
||||
server
|
||||
};
|
||||
let listening = server.local_addr();
|
||||
info!("HTTP server is bound to {}", listening);
|
||||
let graceful = server.with_graceful_shutdown(shutdown_signal());
|
||||
|
||||
let graceful = server.with_graceful_shutdown(rx.map(drop));
|
||||
graceful.await.context(StartHttpSnafu)?;
|
||||
|
||||
Ok(listening)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,16 +68,16 @@ impl MysqlServer {
|
||||
|
||||
#[async_trait]
|
||||
impl Server for MysqlServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
self.base_server.shutdown().await
|
||||
}
|
||||
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
self.base_server.start_with(join_handle)?;
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,18 +85,22 @@ impl OpentsdbServer {
|
||||
|
||||
#[async_trait]
|
||||
impl Server for OpentsdbServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
if let Some(tx) = &self.notify_shutdown {
|
||||
// Err of broadcast sender does not mean that future calls to send will fail, so
|
||||
// its return value is ignored here.
|
||||
let _ = tx.send(());
|
||||
}
|
||||
self.base_server.shutdown().await?;
|
||||
drop(self.notify_shutdown.take());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
self.base_server.start_with(join_handle)?;
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,16 +68,16 @@ impl PostgresServer {
|
||||
|
||||
#[async_trait]
|
||||
impl Server for PostgresServer {
|
||||
async fn shutdown(&mut self) -> Result<()> {
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
self.base_server.shutdown().await
|
||||
}
|
||||
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
self.base_server.start_with(join_handle)?;
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use common_telemetry::logging::{error, info};
|
||||
use futures::future::AbortRegistration;
|
||||
use futures::future::{AbortHandle, Abortable};
|
||||
use snafu::ResultExt;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
|
||||
@@ -16,13 +17,16 @@ pub(crate) type AbortableStream = Abortable<TcpListenerStream>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Server: Send {
|
||||
async fn shutdown(&mut self) -> Result<()>;
|
||||
async fn start(&mut self, listening: SocketAddr) -> Result<SocketAddr>;
|
||||
/// Shutdown the server gracefully.
|
||||
async fn shutdown(&self) -> Result<()>;
|
||||
|
||||
/// Starts the server and binds on `listening`.
|
||||
///
|
||||
/// Caller should ensure `start()` is only invoked once.
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr>;
|
||||
}
|
||||
|
||||
pub(crate) struct BaseTcpServer {
|
||||
name: String,
|
||||
|
||||
struct AccpetTask {
|
||||
// `abort_handle` and `abort_registration` are used in pairs in shutting down the server.
|
||||
// They work like sender and receiver for aborting stream. When the server is shutting down,
|
||||
// calling `abort_handle.abort()` will "notify" `abort_registration` to stop emitting new
|
||||
@@ -32,23 +36,10 @@ pub(crate) struct BaseTcpServer {
|
||||
|
||||
// A handle holding the TCP accepting task.
|
||||
join_handle: Option<JoinHandle<()>>,
|
||||
|
||||
io_runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl BaseTcpServer {
|
||||
pub(crate) fn create_server(name: impl Into<String>, io_runtime: Arc<Runtime>) -> Self {
|
||||
let (abort_handle, registration) = AbortHandle::new_pair();
|
||||
Self {
|
||||
name: name.into(),
|
||||
abort_handle,
|
||||
abort_registration: Some(registration),
|
||||
join_handle: None,
|
||||
io_runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&mut self) -> Result<()> {
|
||||
impl AccpetTask {
|
||||
async fn shutdown(&mut self, name: &str) -> Result<()> {
|
||||
match self.join_handle.take() {
|
||||
Some(join_handle) => {
|
||||
self.abort_handle.abort();
|
||||
@@ -57,23 +48,24 @@ impl BaseTcpServer {
|
||||
// Couldn't use `error!(e; xxx)` because JoinError doesn't implement ErrorExt.
|
||||
error!(
|
||||
"Unexpected error during shutdown {} server, error: {}",
|
||||
&self.name, error
|
||||
name, error
|
||||
);
|
||||
} else {
|
||||
info!("{} server is shutdown.", &self.name);
|
||||
info!("{} server is shutdown.", name);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
None => error::InternalSnafu {
|
||||
err_msg: format!("{} server is not started.", &self.name),
|
||||
err_msg: format!("{} server is not started.", name),
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn bind(
|
||||
async fn bind(
|
||||
&mut self,
|
||||
addr: SocketAddr,
|
||||
name: &str,
|
||||
) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
|
||||
match self.abort_registration.take() {
|
||||
Some(registration) => {
|
||||
@@ -81,33 +73,73 @@ impl BaseTcpServer {
|
||||
tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.context(error::TokioIoSnafu {
|
||||
err_msg: format!("Failed to bind addr {}", addr),
|
||||
err_msg: format!("{} failed to bind addr {}", name, addr),
|
||||
})?;
|
||||
// get actually bond addr in case input addr use port 0
|
||||
let addr = listener.local_addr()?;
|
||||
info!("{} server started at {}", &self.name, addr);
|
||||
info!("{} server started at {}", name, addr);
|
||||
|
||||
let stream = TcpListenerStream::new(listener);
|
||||
let stream = Abortable::new(stream, registration);
|
||||
Ok((stream, addr))
|
||||
}
|
||||
None => error::InternalSnafu {
|
||||
err_msg: format!("{} server has been started.", &self.name),
|
||||
err_msg: format!("{} server has been started.", name),
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn start_with(&mut self, join_handle: JoinHandle<()>) -> Result<()> {
|
||||
fn start_with(&mut self, join_handle: JoinHandle<()>, name: &str) -> Result<()> {
|
||||
if self.join_handle.is_some() {
|
||||
return error::InternalSnafu {
|
||||
err_msg: format!("{} server has been started.", &self.name),
|
||||
err_msg: format!("{} server has been started.", name),
|
||||
}
|
||||
.fail();
|
||||
}
|
||||
let _ = self.join_handle.insert(join_handle);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct BaseTcpServer {
|
||||
name: String,
|
||||
accept_task: Mutex<AccpetTask>,
|
||||
io_runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl BaseTcpServer {
|
||||
pub(crate) fn create_server(name: impl Into<String>, io_runtime: Arc<Runtime>) -> Self {
|
||||
let (abort_handle, registration) = AbortHandle::new_pair();
|
||||
Self {
|
||||
name: name.into(),
|
||||
accept_task: Mutex::new(AccpetTask {
|
||||
abort_handle,
|
||||
abort_registration: Some(registration),
|
||||
join_handle: None,
|
||||
}),
|
||||
io_runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn shutdown(&self) -> Result<()> {
|
||||
let mut task = self.accept_task.lock().await;
|
||||
task.shutdown(&self.name).await
|
||||
}
|
||||
|
||||
pub(crate) async fn bind(
|
||||
&self,
|
||||
addr: SocketAddr,
|
||||
) -> Result<(Abortable<TcpListenerStream>, SocketAddr)> {
|
||||
let mut task = self.accept_task.lock().await;
|
||||
task.bind(addr, &self.name).await
|
||||
}
|
||||
|
||||
pub(crate) async fn start_with(&self, join_handle: JoinHandle<()>) -> Result<()> {
|
||||
let mut task = self.accept_task.lock().await;
|
||||
task.start_with(join_handle, &self.name)
|
||||
}
|
||||
|
||||
pub(crate) fn io_runtime(&self) -> Arc<Runtime> {
|
||||
self.io_runtime.clone()
|
||||
|
||||
@@ -32,7 +32,7 @@ fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
async fn test_start_mysql_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = mysql_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -51,7 +51,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -110,7 +110,7 @@ async fn test_query_all_datatypes() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mut mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
@@ -141,7 +141,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
@@ -52,7 +52,7 @@ fn create_opentsdb_server(tx: mpsc::Sender<i32>) -> Result<Box<dyn Server>> {
|
||||
#[tokio::test]
|
||||
async fn test_start_opentsdb_server() -> Result<()> {
|
||||
let (tx, _) = mpsc::channel(100);
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -68,7 +68,7 @@ async fn test_start_opentsdb_server() -> Result<()> {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_opentsdb_server_concurrently() -> Result<()> {
|
||||
let (tx, _) = mpsc::channel(100);
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let result = server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -133,7 +133,7 @@ async fn test_shutdown_opentsdb_server_concurrently() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn test_opentsdb_connection_shutdown() -> Result<()> {
|
||||
let (tx, _) = mpsc::channel(100);
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let result = server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -178,7 +178,7 @@ async fn test_opentsdb_connection_shutdown() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn test_opentsdb_connect_after_shutdown() -> Result<()> {
|
||||
let (tx, _) = mpsc::channel(100);
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let result = server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -198,7 +198,7 @@ async fn test_opentsdb_connect_after_shutdown() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn test_query() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(10);
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let addr = server.start(listening).await?;
|
||||
|
||||
@@ -225,7 +225,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
let expect_executed_queries_per_worker = 1000;
|
||||
let (tx, mut rx) = mpsc::channel(threads * expect_executed_queries_per_worker);
|
||||
|
||||
let mut server = create_opentsdb_server(tx)?;
|
||||
let server = create_opentsdb_server(tx)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let addr = server.start(listening).await?;
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut pg_server = create_postgres_server(table)?;
|
||||
let pg_server = create_postgres_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -48,7 +48,7 @@ async fn test_shutdown_pg_server() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut postgres_server = create_postgres_server(table)?;
|
||||
let postgres_server = create_postgres_server(table)?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -107,7 +107,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mut pg_server = create_postgres_server(table)?;
|
||||
let pg_server = create_postgres_server(table)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
Reference in New Issue
Block a user