diff --git a/src/cmd/src/cli.rs b/src/cmd/src/cli.rs index 022588057d..a514fa0d3c 100644 --- a/src/cmd/src/cli.rs +++ b/src/cmd/src/cli.rs @@ -58,7 +58,7 @@ impl App for Instance { false } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { Ok(()) } } diff --git a/src/cmd/src/datanode.rs b/src/cmd/src/datanode.rs index 2a422064d7..6603fbc2bf 100644 --- a/src/cmd/src/datanode.rs +++ b/src/cmd/src/datanode.rs @@ -77,7 +77,7 @@ impl App for Instance { self.datanode.start().await.context(StartDatanodeSnafu) } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { self.datanode .shutdown() .await diff --git a/src/cmd/src/datanode/builder.rs b/src/cmd/src/datanode/builder.rs index 40180563a2..64acc29a3d 100644 --- a/src/cmd/src/datanode/builder.rs +++ b/src/cmd/src/datanode/builder.rs @@ -129,7 +129,6 @@ impl InstanceBuilder { .with_default_grpc_server(&datanode.region_server()) .enable_http_service() .build() - .await .context(StartDatanodeSnafu)?; datanode.setup_services(services); diff --git a/src/cmd/src/flownode.rs b/src/cmd/src/flownode.rs index 3fc8249349..e990c871f0 100644 --- a/src/cmd/src/flownode.rs +++ b/src/cmd/src/flownode.rs @@ -85,7 +85,7 @@ impl App for Instance { self.flownode.start().await.context(StartFlownodeSnafu) } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { self.flownode .shutdown() .await @@ -331,7 +331,6 @@ impl StartCommand { .with_grpc_server(flownode.flownode_server().clone()) .enable_http_service() .build() - .await .context(StartFlownodeSnafu)?; flownode.setup_services(services); let flownode = flownode; diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index 408d5831ad..a28f4cd8f9 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -89,7 +89,7 @@ impl App for Instance { .context(error::StartFrontendSnafu) } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { self.frontend .shutdown() .await @@ -382,7 +382,6 @@ impl StartCommand { let servers = Services::new(opts, instance.clone(), plugins) .build() - .await .context(error::StartFrontendSnafu)?; let frontend = Frontend { diff --git a/src/cmd/src/lib.rs b/src/cmd/src/lib.rs index acd27f46d7..9ebb4629bf 100644 --- a/src/cmd/src/lib.rs +++ b/src/cmd/src/lib.rs @@ -74,7 +74,7 @@ pub trait App: Send { true } - async fn stop(&self) -> Result<()>; + async fn stop(&mut self) -> Result<()>; async fn run(&mut self) -> Result<()> { info!("Starting app: {}", self.name()); diff --git a/src/cmd/src/metasrv.rs b/src/cmd/src/metasrv.rs index fcd8ca8fa9..87e632c660 100644 --- a/src/cmd/src/metasrv.rs +++ b/src/cmd/src/metasrv.rs @@ -69,7 +69,7 @@ impl App for Instance { self.instance.start().await.context(StartMetaServerSnafu) } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { self.instance .shutdown() .await diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index d81abedb9a..dd9e7926c4 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -255,8 +255,8 @@ pub struct Instance { impl Instance { /// Find the socket addr of a server by its `name`. - pub async fn server_addr(&self, name: &str) -> Option { - self.frontend.server_handlers().addr(name).await + pub fn server_addr(&self, name: &str) -> Option { + self.frontend.server_handlers().addr(name) } } @@ -293,7 +293,7 @@ impl App for Instance { Ok(()) } - async fn stop(&self) -> Result<()> { + async fn stop(&mut self) -> Result<()> { self.frontend .shutdown() .await @@ -630,7 +630,6 @@ impl StartCommand { let servers = Services::new(opts, fe_instance.clone(), plugins) .build() - .await .context(error::StartFrontendSnafu)?; let frontend = Frontend { diff --git a/src/datanode/src/datanode.rs b/src/datanode/src/datanode.rs index d11de6c48e..032f3ffb2d 100644 --- a/src/datanode/src/datanode.rs +++ b/src/datanode/src/datanode.rs @@ -129,7 +129,7 @@ impl Datanode { self.services = services; } - pub async fn shutdown(&self) -> Result<()> { + pub async fn shutdown(&mut self) -> Result<()> { self.services .shutdown_all() .await diff --git a/src/datanode/src/service.rs b/src/datanode/src/service.rs index a5c2cbac66..570602cc63 100644 --- a/src/datanode/src/service.rs +++ b/src/datanode/src/service.rs @@ -62,7 +62,7 @@ impl<'a> DatanodeServiceBuilder<'a> { } } - pub async fn build(mut self) -> Result { + pub fn build(mut self) -> Result { let handlers = ServerHandlers::default(); if let Some(grpc_server) = self.grpc_server.take() { @@ -70,7 +70,7 @@ impl<'a> DatanodeServiceBuilder<'a> { addr: &self.opts.grpc.bind_addr, })?; let handler: ServerHandler = (Box::new(grpc_server), addr); - handlers.insert(handler).await; + handlers.insert(handler); } if self.enable_http_service { @@ -82,7 +82,7 @@ impl<'a> DatanodeServiceBuilder<'a> { addr: &self.opts.http.addr, })?; let handler: ServerHandler = (Box::new(http_server), addr); - handlers.insert(handler).await; + handlers.insert(handler); } Ok(handlers) diff --git a/src/flow/src/server.rs b/src/flow/src/server.rs index e82a20dd81..358712496d 100644 --- a/src/flow/src/server.rs +++ b/src/flow/src/server.rs @@ -231,10 +231,10 @@ impl servers::server::Server for FlownodeServer { Ok(()) } - async fn start(&self, addr: SocketAddr) -> Result { + async fn start(&mut self, addr: SocketAddr) -> Result<(), servers::error::Error> { let mut rx_server = self.inner.server_shutdown_tx.lock().await.subscribe(); - let (incoming, addr) = { + let incoming = { let listener = TcpListener::bind(addr) .await .context(TcpBindSnafu { addr })?; @@ -243,7 +243,7 @@ impl servers::server::Server for FlownodeServer { TcpIncoming::from_listener(listener, true, None).context(TcpIncomingSnafu)?; info!("flow server is bound to {}", addr); - (incoming, addr) + incoming }; let builder = tonic::transport::Server::builder().add_service(self.create_flow_service()); @@ -255,7 +255,7 @@ impl servers::server::Server for FlownodeServer { .context(StartGrpcSnafu); }); - Ok(addr) + Ok(()) } fn name(&self) -> &str { @@ -282,7 +282,7 @@ impl FlownodeInstance { Ok(()) } - pub async fn shutdown(&self) -> Result<(), crate::Error> { + pub async fn shutdown(&mut self) -> Result<(), Error> { self.services .shutdown_all() .await @@ -391,7 +391,7 @@ impl FlownodeBuilder { let instance = FlownodeInstance { flownode_server: server, - services: ServerHandlers::new(), + services: ServerHandlers::default(), heartbeat_task, }; Ok(instance) @@ -572,14 +572,14 @@ impl<'a> FlownodeServiceBuilder<'a> { } } - pub async fn build(mut self) -> Result { + pub fn build(mut self) -> Result { let handlers = ServerHandlers::default(); if let Some(grpc_server) = self.grpc_server.take() { let addr: SocketAddr = self.opts.grpc.bind_addr.parse().context(ParseAddrSnafu { addr: &self.opts.grpc.bind_addr, })?; let handler: ServerHandler = (Box::new(grpc_server), addr); - handlers.insert(handler).await; + handlers.insert(handler); } if self.enable_http_service { @@ -590,7 +590,7 @@ impl<'a> FlownodeServiceBuilder<'a> { addr: &self.opts.http.addr, })?; let handler: ServerHandler = (Box::new(http_server), addr); - handlers.insert(handler).await; + handlers.insert(handler); } Ok(handlers) } diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index ba795730c4..5574da2169 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -106,7 +106,7 @@ pub struct Frontend { } impl Frontend { - pub async fn start(&self) -> Result<()> { + pub async fn start(&mut self) -> Result<()> { if let Some(t) = &self.heartbeat_task { t.start().await?; } @@ -128,7 +128,7 @@ impl Frontend { .context(error::StartServerSnafu) } - pub async fn shutdown(&self) -> Result<()> { + pub async fn shutdown(&mut self) -> Result<()> { self.servers .shutdown_all() .await diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index b3a34368fa..c014d28e7e 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -179,7 +179,7 @@ where Ok(http_server) } - pub async fn build(mut self) -> Result { + pub fn build(mut self) -> Result { let opts = self.opts.clone(); let instance = self.instance.clone(); @@ -194,7 +194,7 @@ where // Always init GRPC server let grpc_addr = parse_addr(&opts.grpc.bind_addr)?; let grpc_server = self.build_grpc_server(&opts)?; - handlers.insert((Box::new(grpc_server), grpc_addr)).await; + handlers.insert((Box::new(grpc_server), grpc_addr)); } { @@ -202,7 +202,7 @@ where let http_options = &opts.http; let http_addr = parse_addr(&http_options.addr)?; let http_server = self.build_http_server(&opts, toml)?; - handlers.insert((Box::new(http_server), http_addr)).await; + handlers.insert((Box::new(http_server), http_addr)); } if opts.mysql.enable { @@ -230,7 +230,7 @@ where opts.reject_no_database.unwrap_or(false), )), ); - handlers.insert((mysql_server, mysql_addr)).await; + handlers.insert((mysql_server, mysql_addr)); } if opts.postgres.enable { @@ -253,7 +253,7 @@ where user_provider.clone(), )) as Box; - handlers.insert((pg_server, pg_addr)).await; + handlers.insert((pg_server, pg_addr)); } Ok(handlers) diff --git a/src/meta-srv/src/bootstrap.rs b/src/meta-srv/src/bootstrap.rs index 40e41bb815..c669e2d546 100644 --- a/src/meta-srv/src/bootstrap.rs +++ b/src/meta-srv/src/bootstrap.rs @@ -79,7 +79,7 @@ use crate::{error, Result}; pub struct MetasrvInstance { metasrv: Arc, - httpsrv: Arc, + http_server: HttpServer, opts: MetasrvOptions, @@ -96,12 +96,11 @@ impl MetasrvInstance { plugins: Plugins, metasrv: Metasrv, ) -> Result { - let httpsrv = Arc::new( - HttpServerBuilder::new(opts.http.clone()) - .with_metrics_handler(MetricsHandler) - .with_greptime_config_options(opts.to_toml().context(error::TomlFormatSnafu)?) - .build(), - ); + let http_server = HttpServerBuilder::new(opts.http.clone()) + .with_metrics_handler(MetricsHandler) + .with_greptime_config_options(opts.to_toml().context(error::TomlFormatSnafu)?) + .build(); + let metasrv = Arc::new(metasrv); // put metasrv into plugins for later use plugins.insert::>(metasrv.clone()); @@ -109,7 +108,7 @@ impl MetasrvInstance { .context(error::InitExportMetricsTaskSnafu)?; Ok(MetasrvInstance { metasrv, - httpsrv, + http_server, opts, signal_sender: None, plugins, @@ -138,10 +137,9 @@ impl MetasrvInstance { addr: &self.opts.http.addr, })?; let http_srv = async { - self.httpsrv + self.http_server .start(addr) .await - .map(|_| ()) .context(error::StartHttpSnafu) }; future::try_join(metasrv, http_srv).await?; @@ -156,11 +154,11 @@ impl MetasrvInstance { .context(error::SendShutdownSignalSnafu)?; } self.metasrv.shutdown().await?; - self.httpsrv + self.http_server .shutdown() .await .context(error::ShutdownServerSnafu { - server: self.httpsrv.name(), + server: self.http_server.name(), })?; Ok(()) } diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index dd591d7805..bad5dd9ae7 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -151,6 +151,7 @@ pub struct GrpcServer { >, >, >, + bind_addr: Option, } /// Grpc Server configuration @@ -236,7 +237,7 @@ impl Server for GrpcServer { Ok(()) } - async fn start(&self, addr: SocketAddr) -> Result { + async fn start(&mut self, addr: SocketAddr) -> Result<()> { let routes = { let mut routes = self.routes.lock().await; let Some(routes) = routes.take() else { @@ -298,10 +299,16 @@ impl Server for GrpcServer { .context(StartGrpcSnafu); serve_state_tx.send(result) }); - Ok(addr) + + self.bind_addr = Some(addr); + Ok(()) } fn name(&self) -> &str { GRPC_SERVER } + + fn bind_addr(&self) -> Option { + self.bind_addr + } } diff --git a/src/servers/src/grpc/builder.rs b/src/servers/src/grpc/builder.rs index 65d439fada..cee125b0c6 100644 --- a/src/servers/src/grpc/builder.rs +++ b/src/servers/src/grpc/builder.rs @@ -181,6 +181,7 @@ impl GrpcServerBuilder { serve_state: Mutex::new(None), tls_config: self.tls_config, otel_arrow_service: Mutex::new(self.otel_arrow_service), + bind_addr: None, } } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index bff38d980f..b2c854485d 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -130,6 +130,7 @@ pub struct HttpServer { // server configs options: HttpOptions, + bind_addr: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -687,6 +688,7 @@ impl HttpServerBuilder { shutdown_tx: Mutex::new(None), plugins: self.plugins, router: StdMutex::new(self.router), + bind_addr: None, } } } @@ -1099,7 +1101,7 @@ impl Server for HttpServer { Ok(()) } - async fn start(&self, listening: SocketAddr) -> Result { + async fn start(&mut self, listening: SocketAddr) -> Result<()> { let (tx, rx) = oneshot::channel(); let serve = { let mut shutdown_tx = self.shutdown_tx.lock().await; @@ -1155,12 +1157,18 @@ impl Server for HttpServer { error!(e; "Failed to shutdown http server"); } }); - Ok(listening) + + self.bind_addr = Some(listening); + Ok(()) } fn name(&self) -> &str { HTTP_SERVER } + + fn bind_addr(&self) -> Option { + self.bind_addr + } } #[cfg(test)] diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index 12b9c689a1..f274ccce32 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -111,6 +111,7 @@ pub struct MysqlServer { base_server: BaseTcpServer, spawn_ref: Arc, spawn_config: Arc, + bind_addr: Option, } impl MysqlServer { @@ -123,6 +124,7 @@ impl MysqlServer { base_server: BaseTcpServer::create_server("MySQL", io_runtime), spawn_ref, spawn_config, + bind_addr: None, }) } @@ -221,7 +223,7 @@ impl Server for MysqlServer { self.base_server.shutdown().await } - async fn start(&self, listening: SocketAddr) -> Result { + async fn start(&mut self, listening: SocketAddr) -> Result<()> { let (stream, addr) = self .base_server .bind(listening, self.spawn_config.keep_alive_secs) @@ -230,10 +232,16 @@ impl Server for MysqlServer { let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream)); self.base_server.start_with(join_handle).await?; - Ok(addr) + + self.bind_addr = Some(addr); + Ok(()) } fn name(&self) -> &str { MYSQL_SERVER } + + fn bind_addr(&self) -> Option { + self.bind_addr + } } diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 06f7b0ef06..3064871bb7 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -36,6 +36,7 @@ pub struct PostgresServer { make_handler: Arc, tls_server_config: Arc, keep_alive_secs: u64, + bind_addr: Option, } impl PostgresServer { @@ -61,6 +62,7 @@ impl PostgresServer { make_handler, tls_server_config, keep_alive_secs, + bind_addr: None, } } @@ -118,7 +120,7 @@ impl Server for PostgresServer { self.base_server.shutdown().await } - async fn start(&self, listening: SocketAddr) -> Result { + async fn start(&mut self, listening: SocketAddr) -> Result<()> { let (stream, addr) = self .base_server .bind(listening, self.keep_alive_secs) @@ -128,10 +130,16 @@ impl Server for PostgresServer { let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream)); self.base_server.start_with(join_handle).await?; - Ok(addr) + + self.bind_addr = Some(addr); + Ok(()) } fn name(&self) -> &str { POSTGRES_SERVER } + + fn bind_addr(&self) -> Option { + self.bind_addr + } } diff --git a/src/servers/src/server.rs b/src/servers/src/server.rs index 35a5c61859..0d655e5348 100644 --- a/src/servers/src/server.rs +++ b/src/servers/src/server.rs @@ -21,7 +21,7 @@ use common_runtime::Runtime; use common_telemetry::{error, info}; use futures::future::{try_join_all, AbortHandle, AbortRegistration, Abortable}; use snafu::{ensure, ResultExt}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio_stream::wrappers::TcpListenerStream; @@ -32,20 +32,31 @@ pub(crate) type AbortableStream = Abortable; pub type ServerHandler = (Box, SocketAddr); /// [ServerHandlers] is used to manage the lifecycle of all the services like http or grpc in the GreptimeDB server. -#[derive(Clone, Default)] -pub struct ServerHandlers { - handlers: Arc>>, +#[derive(Clone)] +pub enum ServerHandlers { + Init(Arc>>), + Started(Arc>>), +} + +impl Default for ServerHandlers { + fn default() -> Self { + Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new()))) + } } impl ServerHandlers { - pub fn new() -> Self { - Self { - handlers: Arc::new(RwLock::new(HashMap::new())), - } - } - - pub async fn insert(&self, handler: ServerHandler) { - let mut handlers = self.handlers.write().await; + /// Inserts a [ServerHandler] **before** the [ServerHandlers] is started. + pub fn insert(&self, handler: ServerHandler) { + // Inserts more to ServerHandlers while it is not in the initialization state + // is considered a bug. + assert!( + matches!(self, ServerHandlers::Init(_)), + "unexpected: insert when `ServerHandlers` is not during initialization" + ); + let ServerHandlers::Init(handlers) = self else { + unreachable!("guarded by the assertion above"); + }; + let mut handlers = handlers.lock().unwrap(); handlers.insert(handler.0.name().to_string(), handler); } @@ -55,33 +66,59 @@ impl ServerHandlers { /// the server to get the real bound port number. This way we avoid doing careful assignment of /// the port number to the service in the test. /// - /// Note that the address is guaranteed to be correct only after the `start_all` method is - /// successfully invoked. Otherwise you may find the address to be what you configured before. - pub async fn addr(&self, name: &str) -> Option { - let handlers = self.handlers.read().await; - handlers.get(name).map(|x| x.1) + /// Note that the address is only retrievable after the [ServerHandlers] is started (the + /// `start_all` method is called successfully). Otherwise you may find the address still be + /// `None` even if you are certain the server was inserted before. + pub fn addr(&self, name: &str) -> Option { + let ServerHandlers::Started(handlers) = self else { + return None; + }; + handlers.get(name).and_then(|x| x.bind_addr()) } /// Starts all the managed services. It will block until all the services are started. /// And it will set the actual bound address to the service. - pub async fn start_all(&self) -> Result<()> { - let mut handlers = self.handlers.write().await; + pub async fn start_all(&mut self) -> Result<()> { + let ServerHandlers::Init(handlers) = self else { + // If already started, do nothing. + return Ok(()); + }; + + let mut handlers = { + let mut handlers = handlers.lock().unwrap(); + std::mem::take(&mut *handlers) + }; + try_join_all(handlers.values_mut().map(|(server, addr)| async move { - let bind_addr = server.start(*addr).await?; - *addr = bind_addr; - info!("Service {} is started at {}", server.name(), bind_addr); + server.start(*addr).await?; + + let bind_addr = server.bind_addr(); + info!( + "Server {} is started and bind to {:?}", + server.name(), + bind_addr, + ); Ok::<(), error::Error>(()) })) .await?; + + let handlers = handlers + .into_iter() + .map(|(k, v)| (k, v.0)) + .collect::>(); + *self = ServerHandlers::Started(Arc::new(handlers)); Ok(()) } /// Shutdown all the managed services. It will block until all the services are shutdown. - pub async fn shutdown_all(&self) -> Result<()> { - // Even though the `shutdown` method in server does not require mut self, we still acquire - // write lock to pair with `start_all` method. - let handlers = self.handlers.write().await; - try_join_all(handlers.values().map(|(server, _)| async move { + pub async fn shutdown_all(&mut self) -> Result<()> { + let ServerHandlers::Started(handlers) = self else { + // If not started, do nothing. + return Ok(()); + }; + + let handlers = std::mem::take(handlers); + try_join_all(handlers.values().map(|server| async move { server.shutdown().await?; info!("Service {} is shutdown!", server.name()); Ok::<(), error::Error>(()) @@ -99,9 +136,15 @@ pub trait Server: Send + Sync { /// Starts the server and binds on `listening`. /// /// Caller should ensure `start()` is only invoked once. - async fn start(&self, listening: SocketAddr) -> Result; + async fn start(&mut self, listening: SocketAddr) -> Result<()>; fn name(&self) -> &str; + + /// Finds the actual bind address of this server. + /// If not found (returns `None`), maybe it's not started yet, or just don't have it. + fn bind_addr(&self) -> Option { + None + } } struct AcceptTask { diff --git a/src/servers/tests/grpc/mod.rs b/src/servers/tests/grpc/mod.rs deleted file mode 100644 index 30bd168bc1..0000000000 --- a/src/servers/tests/grpc/mod.rs +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2023 Greptime Team -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::net::SocketAddr; -use std::sync::Arc; - -use api::v1::auth_header::AuthScheme; -use api::v1::Basic; -use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; -use async_trait::async_trait; -use auth::tests::MockUserProvider; -use auth::UserProviderRef; -use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_runtime::runtime::BuilderBuild; -use common_runtime::{Builder as RuntimeBuilder, Runtime}; -use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu}; -use servers::grpc::flight::FlightCraftWrapper; -use servers::grpc::greptime_handler::GreptimeRequestHandler; -use servers::query_handler::grpc::ServerGrpcQueryHandlerRef; -use servers::server::Server; -use snafu::ResultExt; -use table::test_util::MemTable; -use table::TableRef; -use tokio::net::TcpListener; -use tokio_stream::wrappers::TcpListenerStream; -use tonic::codec::CompressionEncoding; - -use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0}; - -struct MockGrpcServer { - query_handler: ServerGrpcQueryHandlerRef, - user_provider: Option, - runtime: Runtime, -} - -impl MockGrpcServer { - fn new( - query_handler: ServerGrpcQueryHandlerRef, - user_provider: Option, - runtime: Runtime, - ) -> Self { - Self { - query_handler, - user_provider, - runtime, - } - } - - fn create_service(&self) -> FlightServiceServer { - let service: FlightCraftWrapper<_> = GreptimeRequestHandler::new( - self.query_handler.clone(), - self.user_provider.clone(), - Some(self.runtime.clone()), - ) - .into(); - FlightServiceServer::new(service) - .accept_compressed(CompressionEncoding::Gzip) - .accept_compressed(CompressionEncoding::Zstd) - .send_compressed(CompressionEncoding::Gzip) - .send_compressed(CompressionEncoding::Zstd) - } -} - -#[async_trait] -impl Server for MockGrpcServer { - async fn shutdown(&self) -> Result<()> { - Ok(()) - } - - async fn start(&self, addr: SocketAddr) -> Result { - let (listener, addr) = { - let listener = TcpListener::bind(addr) - .await - .context(TcpBindSnafu { addr })?; - let addr = listener.local_addr().context(TcpBindSnafu { addr })?; - (listener, addr) - }; - - let service = self.create_service(); - // Would block to serve requests. - let _handle = tokio::spawn(async move { - tonic::transport::Server::builder() - .add_service(service) - .serve_with_incoming(TcpListenerStream::new(listener)) - .await - .context(StartGrpcSnafu) - .unwrap() - }); - - Ok(addr) - } - - fn name(&self) -> &str { - "MockGrpcServer" - } -} - -fn create_grpc_server(table: TableRef) -> Result> { - let query_handler = create_testing_grpc_query_handler(table); - let io_runtime = RuntimeBuilder::default() - .worker_threads(4) - .thread_name("grpc-io-handlers") - .build() - .unwrap(); - - let provider = MockUserProvider::default(); - - Ok(Arc::new(MockGrpcServer::new( - query_handler, - Some(Arc::new(provider)), - io_runtime, - ))) -} - -#[tokio::test] -async fn test_grpc_server_startup() { - let server = create_grpc_server(MemTable::default_numbers_table()).unwrap(); - let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await; - let _ = re.unwrap(); -} - -#[tokio::test] -async fn test_grpc_query() { - let server = create_grpc_server(MemTable::default_numbers_table()).unwrap(); - let re = server - .start(LOCALHOST_WITH_0.parse().unwrap()) - .await - .unwrap(); - let grpc_client = Client::with_urls(vec![re.to_string()]); - let mut db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); - - let re = db.sql("select * from numbers").await; - assert!(re.is_err()); - - let greptime = "greptime".to_string(); - db.set_auth(AuthScheme::Basic(Basic { - username: greptime.clone(), - password: greptime.clone(), - })); - let re = db.sql("select * from numbers").await; - let _ = re.unwrap(); -} diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 752145b7cb..1c5049e3ba 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -29,7 +29,7 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::query_engine::DescribeResult; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error::{Error, NotSupportedSnafu, Result}; -use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef}; +use servers::query_handler::grpc::GrpcQueryHandler; use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; use session::context::QueryContextRef; use snafu::ensure; @@ -38,14 +38,11 @@ use table::metadata::TableId; use table::table_name::TableName; use table::TableRef; -mod grpc; mod http; mod interceptor; mod mysql; mod postgres; -const LOCALHOST_WITH_0: &str = "127.0.0.1:0"; - pub struct DummyInstance { query_engine: QueryEngineRef, } @@ -194,7 +191,3 @@ fn create_testing_instance(table: TableRef) -> DummyInstance { fn create_testing_sql_query_handler(table: TableRef) -> ServerSqlQueryHandlerRef { Arc::new(create_testing_instance(table)) as _ } - -fn create_testing_grpc_query_handler(table: TableRef) -> ServerGrpcQueryHandlerRef { - Arc::new(create_testing_instance(table)) as _ -} diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index 542f73ce69..56ddb574cc 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -80,10 +80,9 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mut mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); - let result = mysql_server.start(listening).await; - let _ = result.unwrap(); + mysql_server.start(listening).await.unwrap(); let result = mysql_server.start(listening).await; assert!(result @@ -97,7 +96,7 @@ async fn test_start_mysql_server() -> Result<()> { async fn test_reject_no_database() -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { reject_no_database: true, @@ -105,7 +104,8 @@ async fn test_reject_no_database() -> Result<()> { }, )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let server_port = server_addr.port(); let fail = create_connection(server_port, None, false).await; @@ -122,7 +122,7 @@ async fn test_reject_no_database() -> Result<()> { async fn test_schema_validation() -> Result<()> { async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box, u16)> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { auth_info: Some(auth_info), @@ -130,7 +130,8 @@ async fn test_schema_validation() -> Result<()> { }, )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); Ok((mysql_server, server_addr.port())) } @@ -168,7 +169,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mut mysql_server = create_mysql_server(table, Default::default())?; let result = mysql_server.shutdown().await; assert!(result .unwrap_err() @@ -176,7 +177,8 @@ async fn test_shutdown_mysql_server() -> Result<()> { .contains("MySQL server is not started.")); let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let server_port = server_addr.port(); let mut join_handles = vec![]; @@ -272,7 +274,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::table("all_datatypes", recordbatch); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { tls: server_tls, @@ -281,7 +283,8 @@ async fn test_server_required_secure_client_plain() -> Result<()> { )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let r = create_connection(server_addr.port(), None, client_tls).await; assert!(r.is_err()); @@ -310,7 +313,7 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::table("all_datatypes", recordbatch); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { tls: server_tls, @@ -319,7 +322,8 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let r = create_connection_default_db_name(server_addr.port(), client_tls).await; assert!(r.is_err()); @@ -342,7 +346,7 @@ async fn test_db_name() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::table("all_datatypes", recordbatch); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { tls: server_tls, @@ -351,7 +355,8 @@ async fn test_db_name() -> Result<()> { )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); // None actually uses default database name let r = create_connection_default_db_name(server_addr.port(), client_tls).await; @@ -374,7 +379,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) -> let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::table("all_datatypes", recordbatch); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { tls: server_tls, @@ -383,7 +388,8 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) -> )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let mut connection = create_connection_default_db_name(server_addr.port(), client_tls) .await @@ -415,9 +421,10 @@ async fn test_query_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table, Default::default())?; + let mut mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let server_port = server_addr.port(); let threads = 4; @@ -470,7 +477,7 @@ async fn test_query_prepared() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns.clone()).unwrap(); let table = MemTable::table("all_datatypes", recordbatch); - let mysql_server = create_mysql_server( + let mut mysql_server = create_mysql_server( table, MysqlOpts { ..Default::default() @@ -478,7 +485,8 @@ async fn test_query_prepared() -> Result<()> { )?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = mysql_server.start(listening).await.unwrap(); + mysql_server.start(listening).await.unwrap(); + let server_addr = mysql_server.bind_addr().unwrap(); let mut connection = create_connection_default_db_name(server_addr.port(), false) .await diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index c321fc07ff..e23a48210d 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -78,10 +78,9 @@ fn create_postgres_server( pub async fn test_start_postgres_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false, Default::default(), None)?; + let mut pg_server = create_postgres_server(table, false, Default::default(), None)?; let listening = "127.0.0.1:0".parse::().unwrap(); - let result = pg_server.start(listening).await; - let _ = result.unwrap(); + pg_server.start(listening).await.unwrap(); let result = pg_server.start(listening).await; assert!(result @@ -102,10 +101,11 @@ async fn test_shutdown_pg_server_range() -> Result<()> { async fn test_schema_validating() -> Result<()> { async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box, u16)> { let table = MemTable::default_numbers_table(); - let postgres_server = + let mut postgres_server = create_postgres_server(table, true, Default::default(), Some(auth_info))?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = postgres_server.start(listening).await.unwrap(); + postgres_server.start(listening).await.unwrap(); + let server_addr = postgres_server.bind_addr().unwrap(); let server_port = server_addr.port(); Ok((postgres_server, server_port)) } @@ -140,7 +140,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - let postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?; + let mut postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?; let result = postgres_server.shutdown().await; assert!(result .unwrap_err() @@ -148,7 +148,8 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { .contains("Postgres server is not started.")); let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = postgres_server.start(listening).await.unwrap(); + postgres_server.start(listening).await.unwrap(); + let server_addr = postgres_server.bind_addr().unwrap(); let server_port = server_addr.port(); let mut join_handles = vec![]; @@ -360,9 +361,10 @@ async fn start_test_server(server_tls: TlsOption) -> Result { let _ = install_ring_crypto_provider(); let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false, server_tls, None)?; + let mut pg_server = create_postgres_server(table, false, server_tls, None)?; let listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = pg_server.start(listening).await.unwrap(); + pg_server.start(listening).await.unwrap(); + let server_addr = pg_server.bind_addr().unwrap(); Ok(server_addr.port()) } diff --git a/tests-integration/src/cluster.rs b/tests-integration/src/cluster.rs index adb047c823..4693dc3e84 100644 --- a/tests-integration/src/cluster.rs +++ b/tests-integration/src/cluster.rs @@ -213,7 +213,7 @@ impl GreptimeDbClusterBuilder { self.wait_datanodes_alive(metasrv.metasrv.meta_peer_client(), datanodes) .await; - let frontend = self.build_frontend(metasrv.clone(), datanode_clients).await; + let mut frontend = self.build_frontend(metasrv.clone(), datanode_clients).await; test_util::prepare_another_catalog_and_schema(&frontend.instance).await; @@ -225,7 +225,7 @@ impl GreptimeDbClusterBuilder { datanode_instances, kv_backend: self.kv_backend.clone(), metasrv: metasrv.metasrv, - frontend, + frontend: Arc::new(frontend), } } @@ -347,7 +347,7 @@ impl GreptimeDbClusterBuilder { &self, metasrv: MockInfo, datanode_clients: Arc, - ) -> Arc { + ) -> Frontend { let mut meta_client = MetaClientBuilder::frontend_default_options() .channel_manager(metasrv.channel_manager) .enable_access_cluster_info() @@ -413,11 +413,10 @@ impl GreptimeDbClusterBuilder { Frontend { instance, - servers: ServerHandlers::new(), + servers: ServerHandlers::default(), heartbeat_task: Some(heartbeat_task), export_metrics_task: None, } - .into() } } diff --git a/tests-integration/src/grpc/flight.rs b/tests-integration/src/grpc/flight.rs index 5e079bb037..6c50a90e11 100644 --- a/tests-integration/src/grpc/flight.rs +++ b/tests-integration/src/grpc/flight.rs @@ -16,7 +16,6 @@ mod test { use std::net::SocketAddr; use std::sync::Arc; - use std::time::Duration; use api::v1::auth_header::AuthScheme; use api::v1::{Basic, ColumnDataType, ColumnDef, CreateTableExpr, SemanticType}; @@ -48,8 +47,9 @@ mod test { async fn test_standalone_flight_do_put() { common_telemetry::init_default_ut_logging(); - let (addr, db, _server) = + let (db, server) = setup_grpc_server(StorageType::File, "test_standalone_flight_do_put").await; + let addr = server.bind_addr().unwrap().to_string(); let client = Client::with_urls(vec![addr]); let client = Database::new_with_dbname("greptime-public", client); @@ -95,17 +95,14 @@ mod test { .ok(), Some(runtime.clone()), ); - let grpc_server = GrpcServerBuilder::new(GrpcServerConfig::default(), runtime) + let mut grpc_server = GrpcServerBuilder::new(GrpcServerConfig::default(), runtime) .flight_handler(Arc::new(greptime_request_handler)) .build(); - let addr = grpc_server + grpc_server .start("127.0.0.1:0".parse::().unwrap()) .await - .unwrap() - .to_string(); - - // wait for GRPC server to start - tokio::time::sleep(Duration::from_secs(1)).await; + .unwrap(); + let addr = grpc_server.bind_addr().unwrap().to_string(); let client = Client::with_urls(vec![addr]); let mut client = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, client); diff --git a/tests-integration/src/standalone.rs b/tests-integration/src/standalone.rs index e8224688ba..b14ebafb3f 100644 --- a/tests-integration/src/standalone.rs +++ b/tests-integration/src/standalone.rs @@ -274,18 +274,17 @@ impl GreptimeDbStandaloneBuilder { test_util::prepare_another_catalog_and_schema(&instance).await; - let frontend = Frontend { + let mut frontend = Frontend { instance, - servers: ServerHandlers::new(), + servers: ServerHandlers::default(), heartbeat_task: None, export_metrics_task: None, }; - let frontend = Arc::new(frontend); frontend.start().await.unwrap(); GreptimeDbStandalone { - frontend, + frontend: Arc::new(frontend), opts, guard, kv_backend, diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 0e1d5cc261..6b8e56b89b 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -16,7 +16,6 @@ use std::env; use std::fmt::Display; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; use auth::UserProviderRef; use axum::Router; @@ -549,7 +548,7 @@ pub async fn setup_test_prom_app_with_frontend( pub async fn setup_grpc_server( store_type: StorageType, name: &str, -) -> (String, GreptimeDbStandalone, Arc) { +) -> (GreptimeDbStandalone, Arc) { setup_grpc_server_with(store_type, name, None, None).await } @@ -557,7 +556,7 @@ pub async fn setup_grpc_server_with_user_provider( store_type: StorageType, name: &str, user_provider: Option, -) -> (String, GreptimeDbStandalone, Arc) { +) -> (GreptimeDbStandalone, Arc) { setup_grpc_server_with(store_type, name, user_provider, None).await } @@ -566,7 +565,7 @@ pub async fn setup_grpc_server_with( name: &str, user_provider: Option, grpc_config: Option, -) -> (String, GreptimeDbStandalone, Arc) { +) -> (GreptimeDbStandalone, Arc) { let instance = setup_standalone_instance(name, store_type).await; let runtime: Runtime = RuntimeBuilder::default() @@ -593,25 +592,18 @@ pub async fn setup_grpc_server_with( .with_tls_config(grpc_config.tls) .unwrap(); - let fe_grpc_server = Arc::new(grpc_builder.build()); + let mut grpc_server = grpc_builder.build(); let fe_grpc_addr = "127.0.0.1:0".parse::().unwrap(); - let fe_grpc_addr = fe_grpc_server - .start(fe_grpc_addr) - .await - .unwrap() - .to_string(); + grpc_server.start(fe_grpc_addr).await.unwrap(); - // wait for GRPC server to start - tokio::time::sleep(Duration::from_secs(1)).await; - - (fe_grpc_addr, instance, fe_grpc_server) + (instance, Arc::new(grpc_server)) } pub async fn setup_mysql_server( store_type: StorageType, name: &str, -) -> (String, TestGuard, Arc>) { +) -> (TestGuard, Arc>) { setup_mysql_server_with_user_provider(store_type, name, None).await } @@ -619,7 +611,7 @@ pub async fn setup_mysql_server_with_user_provider( store_type: StorageType, name: &str, user_provider: Option, -) -> (String, TestGuard, Arc>) { +) -> (TestGuard, Arc>) { let instance = setup_standalone_instance(name, store_type).await; let runtime = RuntimeBuilder::default() @@ -635,7 +627,7 @@ pub async fn setup_mysql_server_with_user_provider( addr: fe_mysql_addr.clone(), ..Default::default() }; - let fe_mysql_server = Arc::new(MysqlServer::create_server( + let mut mysql_server = MysqlServer::create_server( runtime, Arc::new(MysqlSpawnRef::new( ServerSqlQueryHandlerAdapter::arc(fe_instance_ref), @@ -650,24 +642,20 @@ pub async fn setup_mysql_server_with_user_provider( 0, opts.reject_no_database.unwrap_or(false), )), - )); + ); - let fe_mysql_addr_clone = fe_mysql_addr.clone(); - let fe_mysql_server_clone = fe_mysql_server.clone(); - let _handle = tokio::spawn(async move { - let addr = fe_mysql_addr_clone.parse::().unwrap(); - fe_mysql_server_clone.start(addr).await.unwrap() - }); + mysql_server + .start(fe_mysql_addr.parse::().unwrap()) + .await + .unwrap(); - tokio::time::sleep(Duration::from_secs(1)).await; - - (fe_mysql_addr, instance.guard, fe_mysql_server) + (instance.guard, Arc::new(mysql_server)) } pub async fn setup_pg_server( store_type: StorageType, name: &str, -) -> (String, TestGuard, Arc>) { +) -> (TestGuard, Arc>) { setup_pg_server_with_user_provider(store_type, name, None).await } @@ -675,7 +663,7 @@ pub async fn setup_pg_server_with_user_provider( store_type: StorageType, name: &str, user_provider: Option, -) -> (String, TestGuard, Arc>) { +) -> (TestGuard, Arc>) { let instance = setup_standalone_instance(name, store_type).await; let runtime = RuntimeBuilder::default() @@ -696,25 +684,21 @@ pub async fn setup_pg_server_with_user_provider( .expect("Failed to load certificates and keys"), ); - let fe_pg_server = Arc::new(Box::new(PostgresServer::new( + let mut pg_server = Box::new(PostgresServer::new( ServerSqlQueryHandlerAdapter::arc(fe_instance_ref), opts.tls.should_force_tls(), tls_server_config, 0, runtime, user_provider, - )) as Box); + )); - let fe_pg_addr_clone = fe_pg_addr.clone(); - let fe_pg_server_clone = fe_pg_server.clone(); - let _handle = tokio::spawn(async move { - let addr = fe_pg_addr_clone.parse::().unwrap(); - fe_pg_server_clone.start(addr).await.unwrap() - }); + pg_server + .start(fe_pg_addr.parse::().unwrap()) + .await + .unwrap(); - tokio::time::sleep(Duration::from_secs(1)).await; - - (fe_pg_addr, instance.guard, fe_pg_server) + (instance.guard, Arc::new(pg_server)) } pub(crate) async fn prepare_another_catalog_and_schema(instance: &Instance) { diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 0a7fffa82d..d37759ece7 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -90,7 +90,8 @@ macro_rules! grpc_tests { } pub async fn test_invalid_dbname(store_type: StorageType) { - let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_invalid_dbname").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_invalid_dbname").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname("tom", grpc_client); @@ -117,7 +118,8 @@ pub async fn test_invalid_dbname(store_type: StorageType) { } pub async fn test_dbname(store_type: StorageType) { - let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_dbname").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_dbname").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -134,8 +136,9 @@ pub async fn test_grpc_message_size_ok(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, _db, fe_grpc_server) = + let (_db, fe_grpc_server) = setup_grpc_server_with(store_type, "test_grpc_message_size_ok", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -153,8 +156,9 @@ pub async fn test_grpc_zstd_compression(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, _db, fe_grpc_server) = + let (_db, fe_grpc_server) = setup_grpc_server_with(store_type, "test_grpc_zstd_compression", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -171,13 +175,14 @@ pub async fn test_grpc_message_size_limit_send(store_type: StorageType) { max_send_message_size: 50, ..Default::default() }; - let (addr, _db, fe_grpc_server) = setup_grpc_server_with( + let (_db, fe_grpc_server) = setup_grpc_server_with( store_type, "test_grpc_message_size_limit_send", None, Some(config), ) .await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -195,13 +200,14 @@ pub async fn test_grpc_message_size_limit_recv(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, _db, fe_grpc_server) = setup_grpc_server_with( + let (_db, fe_grpc_server) = setup_grpc_server_with( store_type, "test_grpc_message_size_limit_recv", None, Some(config), ) .await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new_with_dbname( @@ -222,9 +228,10 @@ pub async fn test_grpc_auth(store_type: StorageType) { &"static_user_provider:cmd:greptime_user=greptime_pwd".to_string(), ) .unwrap(); - let (addr, _db, fe_grpc_server) = + let (_db, fe_grpc_server) = setup_grpc_server_with_user_provider(store_type, "auto_create_table", Some(user_provider)) .await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let mut db = Database::new_with_dbname( @@ -270,7 +277,8 @@ pub async fn test_grpc_auth(store_type: StorageType) { } pub async fn test_auto_create_table(store_type: StorageType) { - let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_auto_create_table").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_auto_create_table").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); @@ -279,8 +287,9 @@ pub async fn test_auto_create_table(store_type: StorageType) { } pub async fn test_auto_create_table_with_hints(store_type: StorageType) { - let (addr, _db, fe_grpc_server) = - setup_grpc_server(store_type, "auto_create_table_with_hints").await; + let (_db, fe_grpc_server) = + setup_grpc_server(store_type, "test_auto_create_table_with_hints").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); @@ -346,7 +355,8 @@ fn expect_data() -> (Column, Column, Column, Column) { pub async fn test_insert_and_select(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_insert_and_select").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_insert_and_select").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); @@ -587,7 +597,8 @@ fn testing_create_expr() -> CreateTableExpr { } pub async fn test_health_check(store_type: StorageType) { - let (addr, _db, fe_grpc_server) = setup_grpc_server(store_type, "test_health_check").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_health_check").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); grpc_client.health_check().await.unwrap(); @@ -599,8 +610,9 @@ pub async fn test_prom_gateway_query(store_type: StorageType) { common_telemetry::init_default_ut_logging(); // prepare connection - let (addr, _db, fe_grpc_server) = - setup_grpc_server(store_type, "test_prom_gateway_query").await; + let (_db, fe_grpc_server) = setup_grpc_server(store_type, "test_prom_gateway_query").await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); + let grpc_client = Client::with_urls(vec![addr]); let db = Database::new( DEFAULT_CATALOG_NAME, @@ -775,8 +787,9 @@ pub async fn test_grpc_timezone(store_type: StorageType) { max_send_message_size: 1024, ..Default::default() }; - let (addr, _db, fe_grpc_server) = + let (_db, fe_grpc_server) = setup_grpc_server_with(store_type, "auto_create_table", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let grpc_client = Client::with_urls(vec![addr]); let mut db = Database::new_with_dbname( @@ -849,8 +862,9 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { max_send_message_size: 1024, tls, }; - let (addr, _db, fe_grpc_server) = + let (_db, fe_grpc_server) = setup_grpc_server_with(store_type, "tls_create_table", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); let mut client_tls = ClientTlsOption { server_ca_cert_path: ca_path, diff --git a/tests-integration/tests/sql.rs b/tests-integration/tests/sql.rs index 06ab0226e2..9c66f33ead 100644 --- a/tests-integration/tests/sql.rs +++ b/tests-integration/tests/sql.rs @@ -84,8 +84,9 @@ pub async fn test_mysql_auth(store_type: StorageType) { ) .unwrap(); - let (addr, mut guard, fe_mysql_server) = + let (mut guard, fe_mysql_server) = setup_mysql_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); // 1. no auth let conn_re = MySqlPoolOptions::new() @@ -138,7 +139,8 @@ pub async fn test_mysql_auth(store_type: StorageType) { pub async fn test_mysql_stmts(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "sql_crud").await; + let (mut guard, fe_mysql_server) = setup_mysql_server(store_type, "test_mysql_stmts").await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); let mut conn = MySqlConnection::connect(&format!("mysql://{addr}/public")) .await @@ -157,7 +159,8 @@ pub async fn test_mysql_stmts(store_type: StorageType) { pub async fn test_mysql_crud(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "sql_crud").await; + let (mut guard, fe_mysql_server) = setup_mysql_server(store_type, "test_mysql_crud").await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); let pool = MySqlPoolOptions::new() .max_connections(2) @@ -322,7 +325,9 @@ pub async fn test_mysql_crud(store_type: StorageType) { pub async fn test_mysql_timezone(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "mysql_timezone").await; + let (mut guard, fe_mysql_server) = setup_mysql_server(store_type, "test_mysql_timezone").await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); + let mut conn = MySqlConnection::connect(&format!("mysql://{addr}/public")) .await .unwrap(); @@ -378,8 +383,9 @@ pub async fn test_postgres_auth(store_type: StorageType) { ) .unwrap(); - let (addr, mut guard, fe_pg_server) = + let (mut guard, fe_pg_server) = setup_pg_server_with_user_provider(store_type, "sql_crud", Some(user_provider)).await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); // 1. no auth let conn_re = PgPoolOptions::new() @@ -432,7 +438,8 @@ pub async fn test_postgres_auth(store_type: StorageType) { } pub async fn test_postgres_crud(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_crud").await; + let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_crud").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let pool = PgPoolOptions::new() .max_connections(2) @@ -539,7 +546,8 @@ pub async fn test_postgres_crud(store_type: StorageType) { guard.remove_all().await; } pub async fn test_postgres_bytea(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_bytea_output").await; + let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_bytea").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await @@ -608,7 +616,8 @@ pub async fn test_postgres_bytea(store_type: StorageType) { } pub async fn test_postgres_datestyle(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "various datestyle").await; + let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_datestyle").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await @@ -835,7 +844,8 @@ pub async fn test_postgres_datestyle(store_type: StorageType) { } pub async fn test_postgres_timezone(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_timezone").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await @@ -896,7 +906,9 @@ pub async fn test_postgres_timezone(store_type: StorageType) { } pub async fn test_postgres_parameter_inference(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + let (mut guard, fe_pg_server) = + setup_pg_server(store_type, "test_postgres_parameter_inference").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await @@ -974,7 +986,10 @@ pub async fn test_mysql_async_timestamp(store_type: StorageType) { } common_telemetry::init_default_ut_logging(); - let (addr, mut guard, fe_mysql_server) = setup_mysql_server(store_type, "sql_timestamp").await; + let (mut guard, fe_mysql_server) = + setup_mysql_server(store_type, "test_mysql_async_timestamp").await; + let addr = fe_mysql_server.bind_addr().unwrap().to_string(); + let url = format!("mysql://{addr}/public"); let opts = mysql_async::Opts::from_url(&url).unwrap(); let mut conn = mysql_async::Conn::new(opts) @@ -1095,8 +1110,9 @@ pub async fn test_mysql_async_timestamp(store_type: StorageType) { } pub async fn test_mysql_prepare_stmt_insert_timestamp(store_type: StorageType) { - let (addr, mut guard, server) = + let (mut guard, server) = setup_mysql_server(store_type, "test_mysql_prepare_stmt_insert_timestamp").await; + let addr = server.bind_addr().unwrap().to_string(); let pool = MySqlPoolOptions::new() .max_connections(2) @@ -1170,7 +1186,8 @@ pub async fn test_mysql_prepare_stmt_insert_timestamp(store_type: StorageType) { } pub async fn test_postgres_array_types(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + let (mut guard, fe_pg_server) = setup_pg_server(store_type, "test_postgres_array_types").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await @@ -1201,7 +1218,9 @@ pub async fn test_postgres_array_types(store_type: StorageType) { } pub async fn test_declare_fetch_close_cursor(store_type: StorageType) { - let (addr, mut guard, fe_pg_server) = setup_pg_server(store_type, "sql_inference").await; + let (mut guard, fe_pg_server) = + setup_pg_server(store_type, "test_declare_fetch_close_cursor").await; + let addr = fe_pg_server.bind_addr().unwrap().to_string(); let (client, connection) = tokio_postgres::connect(&format!("postgres://{addr}/public"), NoTls) .await