mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-03 11:52:54 +00:00
chore: add configurator to http server (#1488)
* chore: add configurator params to start server fun * chore: update plugins type --------- Co-authored-by: paomian <qtang@greptime.com>
This commit is contained in:
@@ -216,7 +216,7 @@ impl StartCommand {
|
||||
}
|
||||
|
||||
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<Plugins> {
|
||||
let mut plugins = Plugins::new();
|
||||
let plugins = Plugins::new();
|
||||
|
||||
if let Some(provider) = user_provider {
|
||||
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;
|
||||
|
||||
@@ -18,6 +18,60 @@ pub mod bytes;
|
||||
#[allow(clippy::all)]
|
||||
pub mod readable_size;
|
||||
|
||||
use core::any::Any;
|
||||
use std::sync::{Arc, Mutex, MutexGuard};
|
||||
|
||||
pub use bit_vec::BitVec;
|
||||
|
||||
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
|
||||
#[derive(Default, Clone)]
|
||||
pub struct Plugins {
|
||||
inner: Arc<Mutex<anymap::Map<dyn Any + Send + Sync>>>,
|
||||
}
|
||||
|
||||
impl Plugins {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: Arc::new(Mutex::new(anymap::Map::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn lock(&self) -> MutexGuard<anymap::Map<dyn Any + Send + Sync>> {
|
||||
self.inner.lock().unwrap()
|
||||
}
|
||||
|
||||
pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
|
||||
self.lock().insert(value);
|
||||
}
|
||||
|
||||
pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
|
||||
let binding = self.lock();
|
||||
binding.get::<T>().cloned()
|
||||
}
|
||||
|
||||
pub fn map_mut<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> R
|
||||
where
|
||||
F: FnOnce(Option<&mut T>) -> R,
|
||||
{
|
||||
let mut binding = self.lock();
|
||||
let opt = binding.get_mut::<T>();
|
||||
mapper(opt)
|
||||
}
|
||||
|
||||
pub fn map<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> Option<R>
|
||||
where
|
||||
F: FnOnce(&T) -> R,
|
||||
{
|
||||
let binding = self.lock();
|
||||
binding.get::<T>().map(mapper)
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
let binding = self.lock();
|
||||
binding.len()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
let binding = self.lock();
|
||||
binding.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +97,7 @@ pub trait FrontendInstance:
|
||||
}
|
||||
|
||||
pub type FrontendInstanceRef = Arc<dyn FrontendInstance>;
|
||||
pub type StatementExecutorRef = Arc<StatementExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Instance {
|
||||
@@ -154,6 +155,7 @@ impl Instance {
|
||||
query_engine.clone(),
|
||||
dist_instance.clone(),
|
||||
));
|
||||
plugins.insert::<Option<StatementExecutorRef>>(Some(statement_executor.clone()));
|
||||
|
||||
Ok(Instance {
|
||||
catalog_manager,
|
||||
@@ -452,8 +454,8 @@ impl SqlQueryHandler for Instance {
|
||||
|
||||
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
|
||||
let _timer = timer!(metrics::METRIC_HANDLE_SQL_ELAPSED);
|
||||
|
||||
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
|
||||
let query_interceptor_opt = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
|
||||
let query_interceptor = query_interceptor_opt.as_ref();
|
||||
let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
|
||||
Ok(q) => q,
|
||||
Err(e) => return vec![Err(e)],
|
||||
@@ -729,7 +731,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_exec_validation() {
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let mut plugins = Plugins::new();
|
||||
let plugins = Plugins::new();
|
||||
plugins.insert(QueryOptions {
|
||||
disallow_cross_schema_query: true,
|
||||
});
|
||||
@@ -794,7 +796,7 @@ mod tests {
|
||||
|
||||
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
|
||||
let stmt = &parse_stmt(sql).unwrap()[0];
|
||||
let re = check_permission(plugins.clone(), stmt, query_ctx);
|
||||
let re = check_permission(plugins, stmt, query_ctx);
|
||||
if is_ok {
|
||||
assert!(re.is_ok());
|
||||
} else {
|
||||
@@ -832,7 +834,7 @@ mod tests {
|
||||
|
||||
// test describe table
|
||||
let sql = "DESC TABLE {catalog}{schema}demo;";
|
||||
replace_test(sql, plugins.clone(), &query_ctx);
|
||||
replace_test(sql, plugins, &query_ctx);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
@@ -1098,7 +1100,7 @@ mod tests {
|
||||
let standalone = tests::create_standalone_instance("test_hook").await;
|
||||
let mut instance = standalone.instance;
|
||||
|
||||
let mut plugins = Plugins::new();
|
||||
let plugins = Plugins::new();
|
||||
let counter_hook = Arc::new(AssertionHook::default());
|
||||
plugins.insert::<SqlQueryInterceptorRef<Error>>(counter_hook.clone());
|
||||
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
|
||||
@@ -1158,7 +1160,7 @@ mod tests {
|
||||
let standalone = tests::create_standalone_instance("test_db_hook").await;
|
||||
let mut instance = standalone.instance;
|
||||
|
||||
let mut plugins = Plugins::new();
|
||||
let plugins = Plugins::new();
|
||||
let hook = Arc::new(DisableDBOpHook::default());
|
||||
plugins.insert::<SqlQueryInterceptorRef<Error>>(hook.clone());
|
||||
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
|
||||
|
||||
@@ -31,7 +31,7 @@ pub mod prom;
|
||||
pub mod prometheus;
|
||||
mod script;
|
||||
mod server;
|
||||
pub(crate) mod statement;
|
||||
pub mod statement;
|
||||
mod table;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
@@ -20,6 +20,7 @@ use common_base::Plugins;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use common_telemetry::info;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::configurator::ConfiguratorRef;
|
||||
use servers::error::Error::InternalIo;
|
||||
use servers::grpc::GrpcServer;
|
||||
use servers::http::HttpServerBuilder;
|
||||
@@ -56,7 +57,7 @@ impl Services {
|
||||
T: FrontendInstance,
|
||||
{
|
||||
let mut result = Vec::<ServerHandler>::with_capacity(plugins.len());
|
||||
let user_provider = plugins.get::<UserProviderRef>().cloned();
|
||||
let user_provider = plugins.get::<UserProviderRef>();
|
||||
|
||||
if let Some(opts) = &opts.grpc_options {
|
||||
let grpc_addr = parse_addr(&opts.addr)?;
|
||||
@@ -179,6 +180,10 @@ impl Services {
|
||||
}
|
||||
http_server_builder.with_metrics_handler(MetricsHandler);
|
||||
http_server_builder.with_script_handler(instance.clone());
|
||||
|
||||
if let Some(configurator) = plugins.get::<Option<ConfiguratorRef>>() {
|
||||
http_server_builder.with_configurator(configurator);
|
||||
}
|
||||
let http_server = http_server_builder.build();
|
||||
result.push((Box::new(http_server), http_addr));
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ use crate::error::{
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct StatementExecutor {
|
||||
pub struct StatementExecutor {
|
||||
catalog_manager: CatalogManagerRef,
|
||||
query_engine: QueryEngineRef,
|
||||
sql_stmt_executor: SqlStatementExecutorRef,
|
||||
@@ -70,7 +70,7 @@ impl StatementExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
pub async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
match stmt {
|
||||
Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => {
|
||||
self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await
|
||||
|
||||
@@ -196,7 +196,7 @@ impl StatementExecutor {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn copy_table_from(&self, req: CopyTableRequest) -> Result<Output> {
|
||||
pub async fn copy_table_from(&self, req: CopyTableRequest) -> Result<Output> {
|
||||
let table_ref = TableReference {
|
||||
catalog: &req.catalog_name,
|
||||
schema: &req.schema_name,
|
||||
|
||||
@@ -110,8 +110,7 @@ impl QueryEngineState {
|
||||
|
||||
pub(crate) fn disallow_cross_schema_query(&self) -> bool {
|
||||
self.plugins
|
||||
.get::<QueryOptions>()
|
||||
.map(|x| x.disallow_cross_schema_query)
|
||||
.map::<QueryOptions, _, _>(|x| x.disallow_cross_schema_query)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ async fn test_query_validate() -> Result<()> {
|
||||
let catalog_list = catalog_list()?;
|
||||
|
||||
// set plugins
|
||||
let mut plugins = Plugins::new();
|
||||
let plugins = Plugins::new();
|
||||
plugins.insert(QueryOptions {
|
||||
disallow_cross_schema_query: true,
|
||||
});
|
||||
|
||||
25
src/servers/src/configurator.rs
Normal file
25
src/servers/src/configurator.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
|
||||
pub trait Configurator: Send + Sync {
|
||||
fn config_http(&self, route: Router) -> Router {
|
||||
route
|
||||
}
|
||||
}
|
||||
|
||||
pub type ConfiguratorRef = Arc<dyn Configurator>;
|
||||
@@ -58,6 +58,7 @@ use tower_http::trace::TraceLayer;
|
||||
use self::authorize::HttpAuth;
|
||||
use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write};
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::configurator::ConfiguratorRef;
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
|
||||
use crate::http::admin::flush;
|
||||
use crate::metrics_handler::MetricsHandler;
|
||||
@@ -112,6 +113,7 @@ pub struct HttpServer {
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
metrics_handler: Option<MetricsHandler>,
|
||||
configurator: Option<ConfiguratorRef>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -256,7 +258,7 @@ pub struct JsonResponse {
|
||||
}
|
||||
|
||||
impl JsonResponse {
|
||||
fn with_error(error: String, error_code: StatusCode) -> Self {
|
||||
pub fn with_error(error: String, error_code: StatusCode) -> Self {
|
||||
JsonResponse {
|
||||
error: Some(error),
|
||||
code: error_code as u32,
|
||||
@@ -280,7 +282,7 @@ impl JsonResponse {
|
||||
}
|
||||
|
||||
/// Create a json response from query result
|
||||
async fn from_output(outputs: Vec<Result<Output>>) -> Self {
|
||||
pub async fn from_output(outputs: Vec<Result<Output>>) -> Self {
|
||||
// TODO(sunng87): this api response structure cannot represent error
|
||||
// well. It hides successful execution results from error response
|
||||
let mut results = Vec::with_capacity(outputs.len());
|
||||
@@ -382,6 +384,7 @@ impl HttpServerBuilder {
|
||||
script_handler: None,
|
||||
metrics_handler: None,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
configurator: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -425,6 +428,12 @@ impl HttpServerBuilder {
|
||||
self.inner.metrics_handler.get_or_insert(handler);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_configurator(&mut self, configurator: Option<ConfiguratorRef>) -> &mut Self {
|
||||
self.inner.configurator = configurator;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(&mut self) -> HttpServer {
|
||||
std::mem::take(self).inner
|
||||
}
|
||||
@@ -512,7 +521,10 @@ impl HttpServer {
|
||||
router = router.nest("/dashboard", dashboard::dashboard());
|
||||
}
|
||||
}
|
||||
router
|
||||
}
|
||||
|
||||
pub fn build(&self, router: Router) -> Router {
|
||||
router
|
||||
// middlewares
|
||||
.layer(
|
||||
@@ -605,7 +617,11 @@ impl Server for HttpServer {
|
||||
AlreadyStartedSnafu { server: "HTTP" }
|
||||
);
|
||||
|
||||
let app = self.make_app();
|
||||
let mut app = self.make_app();
|
||||
if let Some(configurator) = self.configurator.as_ref() {
|
||||
app = configurator.config_http(app);
|
||||
}
|
||||
let app = self.build(app);
|
||||
let server = axum::Server::bind(&listening).serve(app.into_make_service());
|
||||
|
||||
*shutdown_tx = Some(tx);
|
||||
@@ -719,7 +735,7 @@ mod test {
|
||||
.with_sql_handler(sql_instance)
|
||||
.with_grpc_handler(grpc_instance)
|
||||
.build();
|
||||
server.make_app().route(
|
||||
server.build(server.make_app()).route(
|
||||
"/test/timeout",
|
||||
get(forever.layer(
|
||||
ServiceBuilder::new()
|
||||
|
||||
@@ -19,6 +19,7 @@ use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub mod auth;
|
||||
pub mod configurator;
|
||||
pub mod error;
|
||||
pub mod grpc;
|
||||
pub mod http;
|
||||
@@ -62,7 +63,7 @@ pub enum Mode {
|
||||
/// schema name
|
||||
/// - if `[<catalog>-]` is provided, we split database name with `-` and use
|
||||
/// `<catalog>` and `<schema>`.
|
||||
pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) {
|
||||
pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) {
|
||||
let parts = db.splitn(2, '-').collect::<Vec<&str>>();
|
||||
if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
|
||||
@@ -28,7 +28,7 @@ fn make_test_app() -> Router {
|
||||
MemTable::default_numbers_table(),
|
||||
))
|
||||
.build();
|
||||
server.make_app()
|
||||
server.build(server.make_app())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -109,7 +109,7 @@ fn make_test_app(tx: Arc<mpsc::Sender<(String, String)>>, db_name: Option<&str>)
|
||||
|
||||
server_builder.with_influxdb_handler(instance);
|
||||
let server = server_builder.build();
|
||||
server.make_app()
|
||||
server.build(server.make_app())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -97,7 +97,7 @@ fn make_test_app(tx: mpsc::Sender<String>) -> Router {
|
||||
.with_sql_handler(instance.clone())
|
||||
.with_opentsdb_handler(instance)
|
||||
.build();
|
||||
server.make_app()
|
||||
server.build(server.make_app())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -122,7 +122,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
|
||||
.with_sql_handler(instance.clone())
|
||||
.with_prom_handler(instance)
|
||||
.build();
|
||||
server.make_app()
|
||||
server.build(server.make_app())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -306,7 +306,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router
|
||||
.with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(instance.clone()))
|
||||
.with_metrics_handler(MetricsHandler)
|
||||
.build();
|
||||
(http_server.make_app(), guard)
|
||||
(http_server.build(http_server.make_app()), guard)
|
||||
}
|
||||
|
||||
pub async fn setup_test_http_app_with_frontend(
|
||||
@@ -332,7 +332,7 @@ pub async fn setup_test_http_app_with_frontend(
|
||||
.with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(frontend_ref.clone()))
|
||||
.with_script_handler(frontend_ref)
|
||||
.build();
|
||||
let app = http_server.make_app();
|
||||
let app = http_server.build(http_server.make_app());
|
||||
(app, guard)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user