refactor: remove some async in ServerHandlers (#6057)

* refactor: remove some async in ServerHandlers

* address PR comments
This commit is contained in:
LFC
2025-05-07 11:57:16 +08:00
committed by GitHub
parent 56f31d5933
commit 4b5ab75312
30 changed files with 291 additions and 360 deletions

View File

@@ -58,7 +58,7 @@ impl App for Instance {
false
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
Ok(())
}
}

View File

@@ -77,7 +77,7 @@ impl App for Instance {
self.datanode.start().await.context(StartDatanodeSnafu)
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
self.datanode
.shutdown()
.await

View File

@@ -129,7 +129,6 @@ impl InstanceBuilder {
.with_default_grpc_server(&datanode.region_server())
.enable_http_service()
.build()
.await
.context(StartDatanodeSnafu)?;
datanode.setup_services(services);

View File

@@ -85,7 +85,7 @@ impl App for Instance {
self.flownode.start().await.context(StartFlownodeSnafu)
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
self.flownode
.shutdown()
.await
@@ -331,7 +331,6 @@ impl StartCommand {
.with_grpc_server(flownode.flownode_server().clone())
.enable_http_service()
.build()
.await
.context(StartFlownodeSnafu)?;
flownode.setup_services(services);
let flownode = flownode;

View File

@@ -89,7 +89,7 @@ impl App for Instance {
.context(error::StartFrontendSnafu)
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
self.frontend
.shutdown()
.await
@@ -382,7 +382,6 @@ impl StartCommand {
let servers = Services::new(opts, instance.clone(), plugins)
.build()
.await
.context(error::StartFrontendSnafu)?;
let frontend = Frontend {

View File

@@ -74,7 +74,7 @@ pub trait App: Send {
true
}
async fn stop(&self) -> Result<()>;
async fn stop(&mut self) -> Result<()>;
async fn run(&mut self) -> Result<()> {
info!("Starting app: {}", self.name());

View File

@@ -69,7 +69,7 @@ impl App for Instance {
self.instance.start().await.context(StartMetaServerSnafu)
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
self.instance
.shutdown()
.await

View File

@@ -255,8 +255,8 @@ pub struct Instance {
impl Instance {
/// Find the socket addr of a server by its `name`.
pub async fn server_addr(&self, name: &str) -> Option<SocketAddr> {
self.frontend.server_handlers().addr(name).await
pub fn server_addr(&self, name: &str) -> Option<SocketAddr> {
self.frontend.server_handlers().addr(name)
}
}
@@ -293,7 +293,7 @@ impl App for Instance {
Ok(())
}
async fn stop(&self) -> Result<()> {
async fn stop(&mut self) -> Result<()> {
self.frontend
.shutdown()
.await
@@ -630,7 +630,6 @@ impl StartCommand {
let servers = Services::new(opts, fe_instance.clone(), plugins)
.build()
.await
.context(error::StartFrontendSnafu)?;
let frontend = Frontend {

View File

@@ -129,7 +129,7 @@ impl Datanode {
self.services = services;
}
pub async fn shutdown(&self) -> Result<()> {
pub async fn shutdown(&mut self) -> Result<()> {
self.services
.shutdown_all()
.await

View File

@@ -62,7 +62,7 @@ impl<'a> DatanodeServiceBuilder<'a> {
}
}
pub async fn build(mut self) -> Result<ServerHandlers> {
pub fn build(mut self) -> Result<ServerHandlers> {
let handlers = ServerHandlers::default();
if let Some(grpc_server) = self.grpc_server.take() {
@@ -70,7 +70,7 @@ impl<'a> DatanodeServiceBuilder<'a> {
addr: &self.opts.grpc.bind_addr,
})?;
let handler: ServerHandler = (Box::new(grpc_server), addr);
handlers.insert(handler).await;
handlers.insert(handler);
}
if self.enable_http_service {
@@ -82,7 +82,7 @@ impl<'a> DatanodeServiceBuilder<'a> {
addr: &self.opts.http.addr,
})?;
let handler: ServerHandler = (Box::new(http_server), addr);
handlers.insert(handler).await;
handlers.insert(handler);
}
Ok(handlers)

View File

@@ -231,10 +231,10 @@ impl servers::server::Server for FlownodeServer {
Ok(())
}
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr, servers::error::Error> {
async fn start(&mut self, addr: SocketAddr) -> Result<(), servers::error::Error> {
let mut rx_server = self.inner.server_shutdown_tx.lock().await.subscribe();
let (incoming, addr) = {
let incoming = {
let listener = TcpListener::bind(addr)
.await
.context(TcpBindSnafu { addr })?;
@@ -243,7 +243,7 @@ impl servers::server::Server for FlownodeServer {
TcpIncoming::from_listener(listener, true, None).context(TcpIncomingSnafu)?;
info!("flow server is bound to {}", addr);
(incoming, addr)
incoming
};
let builder = tonic::transport::Server::builder().add_service(self.create_flow_service());
@@ -255,7 +255,7 @@ impl servers::server::Server for FlownodeServer {
.context(StartGrpcSnafu);
});
Ok(addr)
Ok(())
}
fn name(&self) -> &str {
@@ -282,7 +282,7 @@ impl FlownodeInstance {
Ok(())
}
pub async fn shutdown(&self) -> Result<(), crate::Error> {
pub async fn shutdown(&mut self) -> Result<(), Error> {
self.services
.shutdown_all()
.await
@@ -391,7 +391,7 @@ impl FlownodeBuilder {
let instance = FlownodeInstance {
flownode_server: server,
services: ServerHandlers::new(),
services: ServerHandlers::default(),
heartbeat_task,
};
Ok(instance)
@@ -572,14 +572,14 @@ impl<'a> FlownodeServiceBuilder<'a> {
}
}
pub async fn build(mut self) -> Result<ServerHandlers, Error> {
pub fn build(mut self) -> Result<ServerHandlers, Error> {
let handlers = ServerHandlers::default();
if let Some(grpc_server) = self.grpc_server.take() {
let addr: SocketAddr = self.opts.grpc.bind_addr.parse().context(ParseAddrSnafu {
addr: &self.opts.grpc.bind_addr,
})?;
let handler: ServerHandler = (Box::new(grpc_server), addr);
handlers.insert(handler).await;
handlers.insert(handler);
}
if self.enable_http_service {
@@ -590,7 +590,7 @@ impl<'a> FlownodeServiceBuilder<'a> {
addr: &self.opts.http.addr,
})?;
let handler: ServerHandler = (Box::new(http_server), addr);
handlers.insert(handler).await;
handlers.insert(handler);
}
Ok(handlers)
}

View File

@@ -106,7 +106,7 @@ pub struct Frontend {
}
impl Frontend {
pub async fn start(&self) -> Result<()> {
pub async fn start(&mut self) -> Result<()> {
if let Some(t) = &self.heartbeat_task {
t.start().await?;
}
@@ -128,7 +128,7 @@ impl Frontend {
.context(error::StartServerSnafu)
}
pub async fn shutdown(&self) -> Result<()> {
pub async fn shutdown(&mut self) -> Result<()> {
self.servers
.shutdown_all()
.await

View File

@@ -179,7 +179,7 @@ where
Ok(http_server)
}
pub async fn build(mut self) -> Result<ServerHandlers> {
pub fn build(mut self) -> Result<ServerHandlers> {
let opts = self.opts.clone();
let instance = self.instance.clone();
@@ -194,7 +194,7 @@ where
// Always init GRPC server
let grpc_addr = parse_addr(&opts.grpc.bind_addr)?;
let grpc_server = self.build_grpc_server(&opts)?;
handlers.insert((Box::new(grpc_server), grpc_addr)).await;
handlers.insert((Box::new(grpc_server), grpc_addr));
}
{
@@ -202,7 +202,7 @@ where
let http_options = &opts.http;
let http_addr = parse_addr(&http_options.addr)?;
let http_server = self.build_http_server(&opts, toml)?;
handlers.insert((Box::new(http_server), http_addr)).await;
handlers.insert((Box::new(http_server), http_addr));
}
if opts.mysql.enable {
@@ -230,7 +230,7 @@ where
opts.reject_no_database.unwrap_or(false),
)),
);
handlers.insert((mysql_server, mysql_addr)).await;
handlers.insert((mysql_server, mysql_addr));
}
if opts.postgres.enable {
@@ -253,7 +253,7 @@ where
user_provider.clone(),
)) as Box<dyn Server>;
handlers.insert((pg_server, pg_addr)).await;
handlers.insert((pg_server, pg_addr));
}
Ok(handlers)

View File

@@ -79,7 +79,7 @@ use crate::{error, Result};
pub struct MetasrvInstance {
metasrv: Arc<Metasrv>,
httpsrv: Arc<HttpServer>,
http_server: HttpServer,
opts: MetasrvOptions,
@@ -96,12 +96,11 @@ impl MetasrvInstance {
plugins: Plugins,
metasrv: Metasrv,
) -> Result<MetasrvInstance> {
let httpsrv = Arc::new(
HttpServerBuilder::new(opts.http.clone())
.with_metrics_handler(MetricsHandler)
.with_greptime_config_options(opts.to_toml().context(error::TomlFormatSnafu)?)
.build(),
);
let http_server = HttpServerBuilder::new(opts.http.clone())
.with_metrics_handler(MetricsHandler)
.with_greptime_config_options(opts.to_toml().context(error::TomlFormatSnafu)?)
.build();
let metasrv = Arc::new(metasrv);
// put metasrv into plugins for later use
plugins.insert::<Arc<Metasrv>>(metasrv.clone());
@@ -109,7 +108,7 @@ impl MetasrvInstance {
.context(error::InitExportMetricsTaskSnafu)?;
Ok(MetasrvInstance {
metasrv,
httpsrv,
http_server,
opts,
signal_sender: None,
plugins,
@@ -138,10 +137,9 @@ impl MetasrvInstance {
addr: &self.opts.http.addr,
})?;
let http_srv = async {
self.httpsrv
self.http_server
.start(addr)
.await
.map(|_| ())
.context(error::StartHttpSnafu)
};
future::try_join(metasrv, http_srv).await?;
@@ -156,11 +154,11 @@ impl MetasrvInstance {
.context(error::SendShutdownSignalSnafu)?;
}
self.metasrv.shutdown().await?;
self.httpsrv
self.http_server
.shutdown()
.await
.context(error::ShutdownServerSnafu {
server: self.httpsrv.name(),
server: self.http_server.name(),
})?;
Ok(())
}

View File

@@ -151,6 +151,7 @@ pub struct GrpcServer {
>,
>,
>,
bind_addr: Option<SocketAddr>,
}
/// Grpc Server configuration
@@ -236,7 +237,7 @@ impl Server for GrpcServer {
Ok(())
}
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr> {
async fn start(&mut self, addr: SocketAddr) -> Result<()> {
let routes = {
let mut routes = self.routes.lock().await;
let Some(routes) = routes.take() else {
@@ -298,10 +299,16 @@ impl Server for GrpcServer {
.context(StartGrpcSnafu);
serve_state_tx.send(result)
});
Ok(addr)
self.bind_addr = Some(addr);
Ok(())
}
fn name(&self) -> &str {
GRPC_SERVER
}
fn bind_addr(&self) -> Option<SocketAddr> {
self.bind_addr
}
}

View File

@@ -181,6 +181,7 @@ impl GrpcServerBuilder {
serve_state: Mutex::new(None),
tls_config: self.tls_config,
otel_arrow_service: Mutex::new(self.otel_arrow_service),
bind_addr: None,
}
}
}

View File

@@ -130,6 +130,7 @@ pub struct HttpServer {
// server configs
options: HttpOptions,
bind_addr: Option<SocketAddr>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -687,6 +688,7 @@ impl HttpServerBuilder {
shutdown_tx: Mutex::new(None),
plugins: self.plugins,
router: StdMutex::new(self.router),
bind_addr: None,
}
}
}
@@ -1099,7 +1101,7 @@ impl Server for HttpServer {
Ok(())
}
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
async fn start(&mut self, listening: SocketAddr) -> Result<()> {
let (tx, rx) = oneshot::channel();
let serve = {
let mut shutdown_tx = self.shutdown_tx.lock().await;
@@ -1155,12 +1157,18 @@ impl Server for HttpServer {
error!(e; "Failed to shutdown http server");
}
});
Ok(listening)
self.bind_addr = Some(listening);
Ok(())
}
fn name(&self) -> &str {
HTTP_SERVER
}
fn bind_addr(&self) -> Option<SocketAddr> {
self.bind_addr
}
}
#[cfg(test)]

View File

@@ -111,6 +111,7 @@ pub struct MysqlServer {
base_server: BaseTcpServer,
spawn_ref: Arc<MysqlSpawnRef>,
spawn_config: Arc<MysqlSpawnConfig>,
bind_addr: Option<SocketAddr>,
}
impl MysqlServer {
@@ -123,6 +124,7 @@ impl MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
spawn_ref,
spawn_config,
bind_addr: None,
})
}
@@ -221,7 +223,7 @@ impl Server for MysqlServer {
self.base_server.shutdown().await
}
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
async fn start(&mut self, listening: SocketAddr) -> Result<()> {
let (stream, addr) = self
.base_server
.bind(listening, self.spawn_config.keep_alive_secs)
@@ -230,10 +232,16 @@ impl Server for MysqlServer {
let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
self.base_server.start_with(join_handle).await?;
Ok(addr)
self.bind_addr = Some(addr);
Ok(())
}
fn name(&self) -> &str {
MYSQL_SERVER
}
fn bind_addr(&self) -> Option<SocketAddr> {
self.bind_addr
}
}

View File

@@ -36,6 +36,7 @@ pub struct PostgresServer {
make_handler: Arc<MakePostgresServerHandler>,
tls_server_config: Arc<ReloadableTlsServerConfig>,
keep_alive_secs: u64,
bind_addr: Option<SocketAddr>,
}
impl PostgresServer {
@@ -61,6 +62,7 @@ impl PostgresServer {
make_handler,
tls_server_config,
keep_alive_secs,
bind_addr: None,
}
}
@@ -118,7 +120,7 @@ impl Server for PostgresServer {
self.base_server.shutdown().await
}
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
async fn start(&mut self, listening: SocketAddr) -> Result<()> {
let (stream, addr) = self
.base_server
.bind(listening, self.keep_alive_secs)
@@ -128,10 +130,16 @@ impl Server for PostgresServer {
let join_handle = common_runtime::spawn_global(self.accept(io_runtime, stream));
self.base_server.start_with(join_handle).await?;
Ok(addr)
self.bind_addr = Some(addr);
Ok(())
}
fn name(&self) -> &str {
POSTGRES_SERVER
}
fn bind_addr(&self) -> Option<SocketAddr> {
self.bind_addr
}
}

View File

@@ -21,7 +21,7 @@ use common_runtime::Runtime;
use common_telemetry::{error, info};
use futures::future::{try_join_all, AbortHandle, AbortRegistration, Abortable};
use snafu::{ensure, ResultExt};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio_stream::wrappers::TcpListenerStream;
@@ -32,20 +32,31 @@ pub(crate) type AbortableStream = Abortable<TcpListenerStream>;
pub type ServerHandler = (Box<dyn Server>, SocketAddr);
/// [ServerHandlers] is used to manage the lifecycle of all the services like http or grpc in the GreptimeDB server.
#[derive(Clone, Default)]
pub struct ServerHandlers {
handlers: Arc<RwLock<HashMap<String, ServerHandler>>>,
#[derive(Clone)]
pub enum ServerHandlers {
Init(Arc<std::sync::Mutex<HashMap<String, ServerHandler>>>),
Started(Arc<HashMap<String, Box<dyn Server>>>),
}
impl Default for ServerHandlers {
fn default() -> Self {
Self::Init(Arc::new(std::sync::Mutex::new(HashMap::new())))
}
}
impl ServerHandlers {
pub fn new() -> Self {
Self {
handlers: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn insert(&self, handler: ServerHandler) {
let mut handlers = self.handlers.write().await;
/// Inserts a [ServerHandler] **before** the [ServerHandlers] is started.
pub fn insert(&self, handler: ServerHandler) {
// Inserts more to ServerHandlers while it is not in the initialization state
// is considered a bug.
assert!(
matches!(self, ServerHandlers::Init(_)),
"unexpected: insert when `ServerHandlers` is not during initialization"
);
let ServerHandlers::Init(handlers) = self else {
unreachable!("guarded by the assertion above");
};
let mut handlers = handlers.lock().unwrap();
handlers.insert(handler.0.name().to_string(), handler);
}
@@ -55,33 +66,59 @@ impl ServerHandlers {
/// the server to get the real bound port number. This way we avoid doing careful assignment of
/// the port number to the service in the test.
///
/// Note that the address is guaranteed to be correct only after the `start_all` method is
/// successfully invoked. Otherwise you may find the address to be what you configured before.
pub async fn addr(&self, name: &str) -> Option<SocketAddr> {
let handlers = self.handlers.read().await;
handlers.get(name).map(|x| x.1)
/// Note that the address is only retrievable after the [ServerHandlers] is started (the
/// `start_all` method is called successfully). Otherwise you may find the address still be
/// `None` even if you are certain the server was inserted before.
pub fn addr(&self, name: &str) -> Option<SocketAddr> {
let ServerHandlers::Started(handlers) = self else {
return None;
};
handlers.get(name).and_then(|x| x.bind_addr())
}
/// Starts all the managed services. It will block until all the services are started.
/// And it will set the actual bound address to the service.
pub async fn start_all(&self) -> Result<()> {
let mut handlers = self.handlers.write().await;
pub async fn start_all(&mut self) -> Result<()> {
let ServerHandlers::Init(handlers) = self else {
// If already started, do nothing.
return Ok(());
};
let mut handlers = {
let mut handlers = handlers.lock().unwrap();
std::mem::take(&mut *handlers)
};
try_join_all(handlers.values_mut().map(|(server, addr)| async move {
let bind_addr = server.start(*addr).await?;
*addr = bind_addr;
info!("Service {} is started at {}", server.name(), bind_addr);
server.start(*addr).await?;
let bind_addr = server.bind_addr();
info!(
"Server {} is started and bind to {:?}",
server.name(),
bind_addr,
);
Ok::<(), error::Error>(())
}))
.await?;
let handlers = handlers
.into_iter()
.map(|(k, v)| (k, v.0))
.collect::<HashMap<_, _>>();
*self = ServerHandlers::Started(Arc::new(handlers));
Ok(())
}
/// Shutdown all the managed services. It will block until all the services are shutdown.
pub async fn shutdown_all(&self) -> Result<()> {
// Even though the `shutdown` method in server does not require mut self, we still acquire
// write lock to pair with `start_all` method.
let handlers = self.handlers.write().await;
try_join_all(handlers.values().map(|(server, _)| async move {
pub async fn shutdown_all(&mut self) -> Result<()> {
let ServerHandlers::Started(handlers) = self else {
// If not started, do nothing.
return Ok(());
};
let handlers = std::mem::take(handlers);
try_join_all(handlers.values().map(|server| async move {
server.shutdown().await?;
info!("Service {} is shutdown!", server.name());
Ok::<(), error::Error>(())
@@ -99,9 +136,15 @@ pub trait Server: Send + Sync {
/// Starts the server and binds on `listening`.
///
/// Caller should ensure `start()` is only invoked once.
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr>;
async fn start(&mut self, listening: SocketAddr) -> Result<()>;
fn name(&self) -> &str;
/// Finds the actual bind address of this server.
/// If not found (returns `None`), maybe it's not started yet, or just don't have it.
fn bind_addr(&self) -> Option<SocketAddr> {
None
}
}
struct AcceptTask {

View File

@@ -1,153 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::SocketAddr;
use std::sync::Arc;
use api::v1::auth_header::AuthScheme;
use api::v1::Basic;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use async_trait::async_trait;
use auth::tests::MockUserProvider;
use auth::UserProviderRef;
use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_runtime::runtime::BuilderBuild;
use common_runtime::{Builder as RuntimeBuilder, Runtime};
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
use servers::grpc::flight::FlightCraftWrapper;
use servers::grpc::greptime_handler::GreptimeRequestHandler;
use servers::query_handler::grpc::ServerGrpcQueryHandlerRef;
use servers::server::Server;
use snafu::ResultExt;
use table::test_util::MemTable;
use table::TableRef;
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::codec::CompressionEncoding;
use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0};
struct MockGrpcServer {
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Runtime,
}
impl MockGrpcServer {
fn new(
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Runtime,
) -> Self {
Self {
query_handler,
user_provider,
runtime,
}
}
fn create_service(&self) -> FlightServiceServer<impl FlightService> {
let service: FlightCraftWrapper<_> = GreptimeRequestHandler::new(
self.query_handler.clone(),
self.user_provider.clone(),
Some(self.runtime.clone()),
)
.into();
FlightServiceServer::new(service)
.accept_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Zstd)
.send_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Zstd)
}
}
#[async_trait]
impl Server for MockGrpcServer {
async fn shutdown(&self) -> Result<()> {
Ok(())
}
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr> {
let (listener, addr) = {
let listener = TcpListener::bind(addr)
.await
.context(TcpBindSnafu { addr })?;
let addr = listener.local_addr().context(TcpBindSnafu { addr })?;
(listener, addr)
};
let service = self.create_service();
// Would block to serve requests.
let _handle = tokio::spawn(async move {
tonic::transport::Server::builder()
.add_service(service)
.serve_with_incoming(TcpListenerStream::new(listener))
.await
.context(StartGrpcSnafu)
.unwrap()
});
Ok(addr)
}
fn name(&self) -> &str {
"MockGrpcServer"
}
}
fn create_grpc_server(table: TableRef) -> Result<Arc<dyn Server>> {
let query_handler = create_testing_grpc_query_handler(table);
let io_runtime = RuntimeBuilder::default()
.worker_threads(4)
.thread_name("grpc-io-handlers")
.build()
.unwrap();
let provider = MockUserProvider::default();
Ok(Arc::new(MockGrpcServer::new(
query_handler,
Some(Arc::new(provider)),
io_runtime,
)))
}
#[tokio::test]
async fn test_grpc_server_startup() {
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await;
let _ = re.unwrap();
}
#[tokio::test]
async fn test_grpc_query() {
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
let re = server
.start(LOCALHOST_WITH_0.parse().unwrap())
.await
.unwrap();
let grpc_client = Client::with_urls(vec![re.to_string()]);
let mut db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client);
let re = db.sql("select * from numbers").await;
assert!(re.is_err());
let greptime = "greptime".to_string();
db.set_auth(AuthScheme::Basic(Basic {
username: greptime.clone(),
password: greptime.clone(),
}));
let re = db.sql("select * from numbers").await;
let _ = re.unwrap();
}

View File

@@ -29,7 +29,7 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
use query::query_engine::DescribeResult;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error::{Error, NotSupportedSnafu, Result};
use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
use session::context::QueryContextRef;
use snafu::ensure;
@@ -38,14 +38,11 @@ use table::metadata::TableId;
use table::table_name::TableName;
use table::TableRef;
mod grpc;
mod http;
mod interceptor;
mod mysql;
mod postgres;
const LOCALHOST_WITH_0: &str = "127.0.0.1:0";
pub struct DummyInstance {
query_engine: QueryEngineRef,
}
@@ -194,7 +191,3 @@ fn create_testing_instance(table: TableRef) -> DummyInstance {
fn create_testing_sql_query_handler(table: TableRef) -> ServerSqlQueryHandlerRef {
Arc::new(create_testing_instance(table)) as _
}
fn create_testing_grpc_query_handler(table: TableRef) -> ServerGrpcQueryHandlerRef {
Arc::new(create_testing_instance(table)) as _
}

View File

@@ -80,10 +80,9 @@ fn create_mysql_server(table: TableRef, opts: MysqlOpts<'_>) -> Result<Box<dyn S
async fn test_start_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mut mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = mysql_server.start(listening).await;
let _ = result.unwrap();
mysql_server.start(listening).await.unwrap();
let result = mysql_server.start(listening).await;
assert!(result
@@ -97,7 +96,7 @@ async fn test_start_mysql_server() -> Result<()> {
async fn test_reject_no_database() -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
reject_no_database: true,
@@ -105,7 +104,8 @@ async fn test_reject_no_database() -> Result<()> {
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let server_port = server_addr.port();
let fail = create_connection(server_port, None, false).await;
@@ -122,7 +122,7 @@ async fn test_reject_no_database() -> Result<()> {
async fn test_schema_validation() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
auth_info: Some(auth_info),
@@ -130,7 +130,8 @@ async fn test_schema_validation() -> Result<()> {
},
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
Ok((mysql_server, server_addr.port()))
}
@@ -168,7 +169,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mut mysql_server = create_mysql_server(table, Default::default())?;
let result = mysql_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -176,7 +177,8 @@ async fn test_shutdown_mysql_server() -> Result<()> {
.contains("MySQL server is not started."));
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let server_port = server_addr.port();
let mut join_handles = vec![];
@@ -272,7 +274,7 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::table("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
@@ -281,7 +283,8 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let r = create_connection(server_addr.port(), None, client_tls).await;
assert!(r.is_err());
@@ -310,7 +313,7 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::table("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
@@ -319,7 +322,8 @@ async fn test_server_required_secure_client_plain_with_pkcs8_priv_key() -> Resul
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
assert!(r.is_err());
@@ -342,7 +346,7 @@ async fn test_db_name() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::table("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
@@ -351,7 +355,8 @@ async fn test_db_name() -> Result<()> {
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
// None actually uses default database name
let r = create_connection_default_db_name(server_addr.port(), client_tls).await;
@@ -374,7 +379,7 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
let recordbatch = RecordBatch::new(schema, columns).unwrap();
let table = MemTable::table("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
tls: server_tls,
@@ -383,7 +388,8 @@ async fn do_test_query_all_datatypes(server_tls: TlsOption, client_tls: bool) ->
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let mut connection = create_connection_default_db_name(server_addr.port(), client_tls)
.await
@@ -415,9 +421,10 @@ async fn test_query_concurrently() -> Result<()> {
let table = MemTable::default_numbers_table();
let mysql_server = create_mysql_server(table, Default::default())?;
let mut mysql_server = create_mysql_server(table, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let server_port = server_addr.port();
let threads = 4;
@@ -470,7 +477,7 @@ async fn test_query_prepared() -> Result<()> {
let recordbatch = RecordBatch::new(schema, columns.clone()).unwrap();
let table = MemTable::table("all_datatypes", recordbatch);
let mysql_server = create_mysql_server(
let mut mysql_server = create_mysql_server(
table,
MysqlOpts {
..Default::default()
@@ -478,7 +485,8 @@ async fn test_query_prepared() -> Result<()> {
)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
mysql_server.start(listening).await.unwrap();
let server_addr = mysql_server.bind_addr().unwrap();
let mut connection = create_connection_default_db_name(server_addr.port(), false)
.await

View File

@@ -78,10 +78,9 @@ fn create_postgres_server(
pub async fn test_start_postgres_server() -> Result<()> {
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, Default::default(), None)?;
let mut pg_server = create_postgres_server(table, false, Default::default(), None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let result = pg_server.start(listening).await;
let _ = result.unwrap();
pg_server.start(listening).await.unwrap();
let result = pg_server.start(listening).await;
assert!(result
@@ -102,10 +101,11 @@ async fn test_shutdown_pg_server_range() -> Result<()> {
async fn test_schema_validating() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let postgres_server =
let mut postgres_server =
create_postgres_server(table, true, Default::default(), Some(auth_info))?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = postgres_server.start(listening).await.unwrap();
postgres_server.start(listening).await.unwrap();
let server_addr = postgres_server.bind_addr().unwrap();
let server_port = server_addr.port();
Ok((postgres_server, server_port))
}
@@ -140,7 +140,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?;
let mut postgres_server = create_postgres_server(table, with_pwd, Default::default(), None)?;
let result = postgres_server.shutdown().await;
assert!(result
.unwrap_err()
@@ -148,7 +148,8 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
.contains("Postgres server is not started."));
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = postgres_server.start(listening).await.unwrap();
postgres_server.start(listening).await.unwrap();
let server_addr = postgres_server.bind_addr().unwrap();
let server_port = server_addr.port();
let mut join_handles = vec![];
@@ -360,9 +361,10 @@ async fn start_test_server(server_tls: TlsOption) -> Result<u16> {
let _ = install_ring_crypto_provider();
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls, None)?;
let mut pg_server = create_postgres_server(table, false, server_tls, None)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
pg_server.start(listening).await.unwrap();
let server_addr = pg_server.bind_addr().unwrap();
Ok(server_addr.port())
}