From 34c7f78861614123338649d2cb57a3a045c7c00e Mon Sep 17 00:00:00 2001 From: localhost Date: Mon, 8 May 2023 18:55:03 +0800 Subject: [PATCH] chore: add configurator to http server (#1488) * chore: add configurator params to start server fun * chore: update plugins type --------- Co-authored-by: paomian --- src/cmd/src/frontend.rs | 2 +- src/common/base/src/lib.rs | 56 ++++++++++++++++++- src/frontend/src/instance.rs | 16 +++--- src/frontend/src/lib.rs | 2 +- src/frontend/src/server.rs | 7 ++- src/frontend/src/statement.rs | 4 +- src/frontend/src/statement/copy_table_from.rs | 2 +- src/query/src/query_engine/state.rs | 3 +- src/query/src/tests/query_engine_test.rs | 2 +- src/servers/src/configurator.rs | 25 +++++++++ src/servers/src/http.rs | 24 ++++++-- src/servers/src/lib.rs | 3 +- src/servers/tests/http/http_test.rs | 2 +- src/servers/tests/http/influxdb_test.rs | 2 +- src/servers/tests/http/opentsdb_test.rs | 2 +- src/servers/tests/http/prometheus_test.rs | 2 +- tests-integration/src/test_util.rs | 4 +- 17 files changed, 130 insertions(+), 28 deletions(-) create mode 100644 src/servers/src/configurator.rs diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index 9bd2aae64b..ef58f0deca 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -216,7 +216,7 @@ impl StartCommand { } pub fn load_frontend_plugins(user_provider: &Option) -> Result { - let mut plugins = Plugins::new(); + let plugins = Plugins::new(); if let Some(provider) = user_provider { let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?; diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs index cf3a19a771..1b57c3a2f6 100644 --- a/src/common/base/src/lib.rs +++ b/src/common/base/src/lib.rs @@ -18,6 +18,60 @@ pub mod bytes; #[allow(clippy::all)] pub mod readable_size; +use core::any::Any; +use std::sync::{Arc, Mutex, MutexGuard}; + pub use bit_vec::BitVec; -pub type Plugins = anymap::Map; +#[derive(Default, Clone)] +pub struct Plugins { + inner: Arc>>, +} + +impl Plugins { + pub fn new() -> Self { + Self { + inner: Arc::new(Mutex::new(anymap::Map::new())), + } + } + + fn lock(&self) -> MutexGuard> { + self.inner.lock().unwrap() + } + + pub fn insert(&self, value: T) { + self.lock().insert(value); + } + + pub fn get(&self) -> Option { + let binding = self.lock(); + binding.get::().cloned() + } + + pub fn map_mut(&self, mapper: F) -> R + where + F: FnOnce(Option<&mut T>) -> R, + { + let mut binding = self.lock(); + let opt = binding.get_mut::(); + mapper(opt) + } + + pub fn map(&self, mapper: F) -> Option + where + F: FnOnce(&T) -> R, + { + let binding = self.lock(); + binding.get::().map(mapper) + } + + pub fn len(&self) -> usize { + let binding = self.lock(); + binding.len() + } + + pub fn is_empty(&self) -> bool { + let binding = self.lock(); + binding.is_empty() + } +} diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 2f7cd8092e..68fc207e05 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -97,6 +97,7 @@ pub trait FrontendInstance: } pub type FrontendInstanceRef = Arc; +pub type StatementExecutorRef = Arc; #[derive(Clone)] pub struct Instance { @@ -154,6 +155,7 @@ impl Instance { query_engine.clone(), dist_instance.clone(), )); + plugins.insert::>(Some(statement_executor.clone())); Ok(Instance { catalog_manager, @@ -452,8 +454,8 @@ impl SqlQueryHandler for Instance { async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { let _timer = timer!(metrics::METRIC_HANDLE_SQL_ELAPSED); - - let query_interceptor = self.plugins.get::>(); + let query_interceptor_opt = self.plugins.get::>(); + let query_interceptor = query_interceptor_opt.as_ref(); let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) { Ok(q) => q, Err(e) => return vec![Err(e)], @@ -729,7 +731,7 @@ mod tests { #[test] fn test_exec_validation() { let query_ctx = Arc::new(QueryContext::new()); - let mut plugins = Plugins::new(); + let plugins = Plugins::new(); plugins.insert(QueryOptions { disallow_cross_schema_query: true, }); @@ -794,7 +796,7 @@ mod tests { fn do_test(sql: &str, plugins: Arc, query_ctx: &QueryContextRef, is_ok: bool) { let stmt = &parse_stmt(sql).unwrap()[0]; - let re = check_permission(plugins.clone(), stmt, query_ctx); + let re = check_permission(plugins, stmt, query_ctx); if is_ok { assert!(re.is_ok()); } else { @@ -832,7 +834,7 @@ mod tests { // test describe table let sql = "DESC TABLE {catalog}{schema}demo;"; - replace_test(sql, plugins.clone(), &query_ctx); + replace_test(sql, plugins, &query_ctx); } #[tokio::test(flavor = "multi_thread")] @@ -1098,7 +1100,7 @@ mod tests { let standalone = tests::create_standalone_instance("test_hook").await; let mut instance = standalone.instance; - let mut plugins = Plugins::new(); + let plugins = Plugins::new(); let counter_hook = Arc::new(AssertionHook::default()); plugins.insert::>(counter_hook.clone()); Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); @@ -1158,7 +1160,7 @@ mod tests { let standalone = tests::create_standalone_instance("test_db_hook").await; let mut instance = standalone.instance; - let mut plugins = Plugins::new(); + let plugins = Plugins::new(); let hook = Arc::new(DisableDBOpHook::default()); plugins.insert::>(hook.clone()); Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 371b6eadd1..cbcdd5bbe4 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -31,7 +31,7 @@ pub mod prom; pub mod prometheus; mod script; mod server; -pub(crate) mod statement; +pub mod statement; mod table; #[cfg(test)] mod tests; diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 37de9fd883..5f23431410 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -20,6 +20,7 @@ use common_base::Plugins; use common_runtime::Builder as RuntimeBuilder; use common_telemetry::info; use servers::auth::UserProviderRef; +use servers::configurator::ConfiguratorRef; use servers::error::Error::InternalIo; use servers::grpc::GrpcServer; use servers::http::HttpServerBuilder; @@ -56,7 +57,7 @@ impl Services { T: FrontendInstance, { let mut result = Vec::::with_capacity(plugins.len()); - let user_provider = plugins.get::().cloned(); + let user_provider = plugins.get::(); if let Some(opts) = &opts.grpc_options { let grpc_addr = parse_addr(&opts.addr)?; @@ -179,6 +180,10 @@ impl Services { } http_server_builder.with_metrics_handler(MetricsHandler); http_server_builder.with_script_handler(instance.clone()); + + if let Some(configurator) = plugins.get::>() { + http_server_builder.with_configurator(configurator); + } let http_server = http_server_builder.build(); result.push((Box::new(http_server), http_addr)); } diff --git a/src/frontend/src/statement.rs b/src/frontend/src/statement.rs index 28f7f177b6..4e1fbeebe4 100644 --- a/src/frontend/src/statement.rs +++ b/src/frontend/src/statement.rs @@ -40,7 +40,7 @@ use crate::error::{ }; #[derive(Clone)] -pub(crate) struct StatementExecutor { +pub struct StatementExecutor { catalog_manager: CatalogManagerRef, query_engine: QueryEngineRef, sql_stmt_executor: SqlStatementExecutorRef, @@ -70,7 +70,7 @@ impl StatementExecutor { } } - async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { + pub async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { match stmt { Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => { self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await diff --git a/src/frontend/src/statement/copy_table_from.rs b/src/frontend/src/statement/copy_table_from.rs index 1ef4ffb758..7c4e442719 100644 --- a/src/frontend/src/statement/copy_table_from.rs +++ b/src/frontend/src/statement/copy_table_from.rs @@ -196,7 +196,7 @@ impl StatementExecutor { } } - pub(crate) async fn copy_table_from(&self, req: CopyTableRequest) -> Result { + pub async fn copy_table_from(&self, req: CopyTableRequest) -> Result { let table_ref = TableReference { catalog: &req.catalog_name, schema: &req.schema_name, diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 617620a924..4348543c9d 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -110,8 +110,7 @@ impl QueryEngineState { pub(crate) fn disallow_cross_schema_query(&self) -> bool { self.plugins - .get::() - .map(|x| x.disallow_cross_schema_query) + .map::(|x| x.disallow_cross_schema_query) .unwrap_or(false) } diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs index 61753a314f..08d81e53c6 100644 --- a/src/query/src/tests/query_engine_test.rs +++ b/src/query/src/tests/query_engine_test.rs @@ -127,7 +127,7 @@ async fn test_query_validate() -> Result<()> { let catalog_list = catalog_list()?; // set plugins - let mut plugins = Plugins::new(); + let plugins = Plugins::new(); plugins.insert(QueryOptions { disallow_cross_schema_query: true, }); diff --git a/src/servers/src/configurator.rs b/src/servers/src/configurator.rs new file mode 100644 index 0000000000..a86aa43e60 --- /dev/null +++ b/src/servers/src/configurator.rs @@ -0,0 +1,25 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use axum::Router; + +pub trait Configurator: Send + Sync { + fn config_http(&self, route: Router) -> Router { + route + } +} + +pub type ConfiguratorRef = Arc; diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 31b2b75b1e..73906d6dc0 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -58,6 +58,7 @@ use tower_http::trace::TraceLayer; use self::authorize::HttpAuth; use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write}; use crate::auth::UserProviderRef; +use crate::configurator::ConfiguratorRef; use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu}; use crate::http::admin::flush; use crate::metrics_handler::MetricsHandler; @@ -112,6 +113,7 @@ pub struct HttpServer { shutdown_tx: Mutex>>, user_provider: Option, metrics_handler: Option, + configurator: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -256,7 +258,7 @@ pub struct JsonResponse { } impl JsonResponse { - fn with_error(error: String, error_code: StatusCode) -> Self { + pub fn with_error(error: String, error_code: StatusCode) -> Self { JsonResponse { error: Some(error), code: error_code as u32, @@ -280,7 +282,7 @@ impl JsonResponse { } /// Create a json response from query result - async fn from_output(outputs: Vec>) -> Self { + pub async fn from_output(outputs: Vec>) -> Self { // TODO(sunng87): this api response structure cannot represent error // well. It hides successful execution results from error response let mut results = Vec::with_capacity(outputs.len()); @@ -382,6 +384,7 @@ impl HttpServerBuilder { script_handler: None, metrics_handler: None, shutdown_tx: Mutex::new(None), + configurator: None, }, } } @@ -425,6 +428,12 @@ impl HttpServerBuilder { self.inner.metrics_handler.get_or_insert(handler); self } + + pub fn with_configurator(&mut self, configurator: Option) -> &mut Self { + self.inner.configurator = configurator; + self + } + pub fn build(&mut self) -> HttpServer { std::mem::take(self).inner } @@ -512,7 +521,10 @@ impl HttpServer { router = router.nest("/dashboard", dashboard::dashboard()); } } + router + } + pub fn build(&self, router: Router) -> Router { router // middlewares .layer( @@ -605,7 +617,11 @@ impl Server for HttpServer { AlreadyStartedSnafu { server: "HTTP" } ); - let app = self.make_app(); + let mut app = self.make_app(); + if let Some(configurator) = self.configurator.as_ref() { + app = configurator.config_http(app); + } + let app = self.build(app); let server = axum::Server::bind(&listening).serve(app.into_make_service()); *shutdown_tx = Some(tx); @@ -719,7 +735,7 @@ mod test { .with_sql_handler(sql_instance) .with_grpc_handler(grpc_instance) .build(); - server.make_app().route( + server.build(server.make_app()).route( "/test/timeout", get(forever.layer( ServiceBuilder::new() diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index cef577be69..9f7872b485 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -19,6 +19,7 @@ use common_catalog::consts::DEFAULT_CATALOG_NAME; use serde::{Deserialize, Serialize}; pub mod auth; +pub mod configurator; pub mod error; pub mod grpc; pub mod http; @@ -62,7 +63,7 @@ pub enum Mode { /// schema name /// - if `[-]` is provided, we split database name with `-` and use /// `` and ``. -pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) { +pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) { let parts = db.splitn(2, '-').collect::>(); if parts.len() == 2 { (parts[0], parts[1]) diff --git a/src/servers/tests/http/http_test.rs b/src/servers/tests/http/http_test.rs index b6b3a9fa13..9c974fde6d 100644 --- a/src/servers/tests/http/http_test.rs +++ b/src/servers/tests/http/http_test.rs @@ -28,7 +28,7 @@ fn make_test_app() -> Router { MemTable::default_numbers_table(), )) .build(); - server.make_app() + server.build(server.make_app()) } #[tokio::test] diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index ab5c86afc7..c3fc3b50ee 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -109,7 +109,7 @@ fn make_test_app(tx: Arc>, db_name: Option<&str>) server_builder.with_influxdb_handler(instance); let server = server_builder.build(); - server.make_app() + server.build(server.make_app()) } #[tokio::test] diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 399d0d7833..64033e2b1b 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -97,7 +97,7 @@ fn make_test_app(tx: mpsc::Sender) -> Router { .with_sql_handler(instance.clone()) .with_opentsdb_handler(instance) .build(); - server.make_app() + server.build(server.make_app()) } #[tokio::test] diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index a1046d59e9..395f7fe836 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -122,7 +122,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { .with_sql_handler(instance.clone()) .with_prom_handler(instance) .build(); - server.make_app() + server.build(server.make_app()) } #[tokio::test] diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 0b3c521dcb..a63c2fb7aa 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -306,7 +306,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router .with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(instance.clone())) .with_metrics_handler(MetricsHandler) .build(); - (http_server.make_app(), guard) + (http_server.build(http_server.make_app()), guard) } pub async fn setup_test_http_app_with_frontend( @@ -332,7 +332,7 @@ pub async fn setup_test_http_app_with_frontend( .with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(frontend_ref.clone())) .with_script_handler(frontend_ref) .build(); - let app = http_server.make_app(); + let app = http_server.build(http_server.make_app()); (app, guard) }