chore: add configurator to http server (#1488)

* chore: add configurator params to start server fun

* chore: update plugins type

---------

Co-authored-by: paomian <qtang@greptime.com>
This commit is contained in:
localhost
2023-05-08 18:55:03 +08:00
committed by GitHub
parent 610651fa8f
commit 34c7f78861
17 changed files with 130 additions and 28 deletions

View File

@@ -216,7 +216,7 @@ impl StartCommand {
}
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<Plugins> {
let mut plugins = Plugins::new();
let plugins = Plugins::new();
if let Some(provider) = user_provider {
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;

View File

@@ -18,6 +18,60 @@ pub mod bytes;
#[allow(clippy::all)]
pub mod readable_size;
use core::any::Any;
use std::sync::{Arc, Mutex, MutexGuard};
pub use bit_vec::BitVec;
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
#[derive(Default, Clone)]
pub struct Plugins {
inner: Arc<Mutex<anymap::Map<dyn Any + Send + Sync>>>,
}
impl Plugins {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(anymap::Map::new())),
}
}
fn lock(&self) -> MutexGuard<anymap::Map<dyn Any + Send + Sync>> {
self.inner.lock().unwrap()
}
pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
self.lock().insert(value);
}
pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
let binding = self.lock();
binding.get::<T>().cloned()
}
pub fn map_mut<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> R
where
F: FnOnce(Option<&mut T>) -> R,
{
let mut binding = self.lock();
let opt = binding.get_mut::<T>();
mapper(opt)
}
pub fn map<T: 'static + Send + Sync, F, R>(&self, mapper: F) -> Option<R>
where
F: FnOnce(&T) -> R,
{
let binding = self.lock();
binding.get::<T>().map(mapper)
}
pub fn len(&self) -> usize {
let binding = self.lock();
binding.len()
}
pub fn is_empty(&self) -> bool {
let binding = self.lock();
binding.is_empty()
}
}

View File

@@ -97,6 +97,7 @@ pub trait FrontendInstance:
}
pub type FrontendInstanceRef = Arc<dyn FrontendInstance>;
pub type StatementExecutorRef = Arc<StatementExecutor>;
#[derive(Clone)]
pub struct Instance {
@@ -154,6 +155,7 @@ impl Instance {
query_engine.clone(),
dist_instance.clone(),
));
plugins.insert::<Option<StatementExecutorRef>>(Some(statement_executor.clone()));
Ok(Instance {
catalog_manager,
@@ -452,8 +454,8 @@ impl SqlQueryHandler for Instance {
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec<Result<Output>> {
let _timer = timer!(metrics::METRIC_HANDLE_SQL_ELAPSED);
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
let query_interceptor_opt = self.plugins.get::<SqlQueryInterceptorRef<Error>>();
let query_interceptor = query_interceptor_opt.as_ref();
let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
Ok(q) => q,
Err(e) => return vec![Err(e)],
@@ -729,7 +731,7 @@ mod tests {
#[test]
fn test_exec_validation() {
let query_ctx = Arc::new(QueryContext::new());
let mut plugins = Plugins::new();
let plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
});
@@ -794,7 +796,7 @@ mod tests {
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
let stmt = &parse_stmt(sql).unwrap()[0];
let re = check_permission(plugins.clone(), stmt, query_ctx);
let re = check_permission(plugins, stmt, query_ctx);
if is_ok {
assert!(re.is_ok());
} else {
@@ -832,7 +834,7 @@ mod tests {
// test describe table
let sql = "DESC TABLE {catalog}{schema}demo;";
replace_test(sql, plugins.clone(), &query_ctx);
replace_test(sql, plugins, &query_ctx);
}
#[tokio::test(flavor = "multi_thread")]
@@ -1098,7 +1100,7 @@ mod tests {
let standalone = tests::create_standalone_instance("test_hook").await;
let mut instance = standalone.instance;
let mut plugins = Plugins::new();
let plugins = Plugins::new();
let counter_hook = Arc::new(AssertionHook::default());
plugins.insert::<SqlQueryInterceptorRef<Error>>(counter_hook.clone());
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
@@ -1158,7 +1160,7 @@ mod tests {
let standalone = tests::create_standalone_instance("test_db_hook").await;
let mut instance = standalone.instance;
let mut plugins = Plugins::new();
let plugins = Plugins::new();
let hook = Arc::new(DisableDBOpHook::default());
plugins.insert::<SqlQueryInterceptorRef<Error>>(hook.clone());
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));

View File

@@ -31,7 +31,7 @@ pub mod prom;
pub mod prometheus;
mod script;
mod server;
pub(crate) mod statement;
pub mod statement;
mod table;
#[cfg(test)]
mod tests;

View File

@@ -20,6 +20,7 @@ use common_base::Plugins;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
use servers::configurator::ConfiguratorRef;
use servers::error::Error::InternalIo;
use servers::grpc::GrpcServer;
use servers::http::HttpServerBuilder;
@@ -56,7 +57,7 @@ impl Services {
T: FrontendInstance,
{
let mut result = Vec::<ServerHandler>::with_capacity(plugins.len());
let user_provider = plugins.get::<UserProviderRef>().cloned();
let user_provider = plugins.get::<UserProviderRef>();
if let Some(opts) = &opts.grpc_options {
let grpc_addr = parse_addr(&opts.addr)?;
@@ -179,6 +180,10 @@ impl Services {
}
http_server_builder.with_metrics_handler(MetricsHandler);
http_server_builder.with_script_handler(instance.clone());
if let Some(configurator) = plugins.get::<Option<ConfiguratorRef>>() {
http_server_builder.with_configurator(configurator);
}
let http_server = http_server_builder.build();
result.push((Box::new(http_server), http_addr));
}

View File

@@ -40,7 +40,7 @@ use crate::error::{
};
#[derive(Clone)]
pub(crate) struct StatementExecutor {
pub struct StatementExecutor {
catalog_manager: CatalogManagerRef,
query_engine: QueryEngineRef,
sql_stmt_executor: SqlStatementExecutorRef,
@@ -70,7 +70,7 @@ impl StatementExecutor {
}
}
async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
pub async fn execute_sql(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
match stmt {
Statement::Query(_) | Statement::Explain(_) | Statement::Delete(_) => {
self.plan_exec(QueryStatement::Sql(stmt), query_ctx).await

View File

@@ -196,7 +196,7 @@ impl StatementExecutor {
}
}
pub(crate) async fn copy_table_from(&self, req: CopyTableRequest) -> Result<Output> {
pub async fn copy_table_from(&self, req: CopyTableRequest) -> Result<Output> {
let table_ref = TableReference {
catalog: &req.catalog_name,
schema: &req.schema_name,

View File

@@ -110,8 +110,7 @@ impl QueryEngineState {
pub(crate) fn disallow_cross_schema_query(&self) -> bool {
self.plugins
.get::<QueryOptions>()
.map(|x| x.disallow_cross_schema_query)
.map::<QueryOptions, _, _>(|x| x.disallow_cross_schema_query)
.unwrap_or(false)
}

View File

@@ -127,7 +127,7 @@ async fn test_query_validate() -> Result<()> {
let catalog_list = catalog_list()?;
// set plugins
let mut plugins = Plugins::new();
let plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
});

View File

@@ -0,0 +1,25 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::Router;
pub trait Configurator: Send + Sync {
fn config_http(&self, route: Router) -> Router {
route
}
}
pub type ConfiguratorRef = Arc<dyn Configurator>;

View File

@@ -58,6 +58,7 @@ use tower_http::trace::TraceLayer;
use self::authorize::HttpAuth;
use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write};
use crate::auth::UserProviderRef;
use crate::configurator::ConfiguratorRef;
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
use crate::http::admin::flush;
use crate::metrics_handler::MetricsHandler;
@@ -112,6 +113,7 @@ pub struct HttpServer {
shutdown_tx: Mutex<Option<Sender<()>>>,
user_provider: Option<UserProviderRef>,
metrics_handler: Option<MetricsHandler>,
configurator: Option<ConfiguratorRef>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
@@ -256,7 +258,7 @@ pub struct JsonResponse {
}
impl JsonResponse {
fn with_error(error: String, error_code: StatusCode) -> Self {
pub fn with_error(error: String, error_code: StatusCode) -> Self {
JsonResponse {
error: Some(error),
code: error_code as u32,
@@ -280,7 +282,7 @@ impl JsonResponse {
}
/// Create a json response from query result
async fn from_output(outputs: Vec<Result<Output>>) -> Self {
pub async fn from_output(outputs: Vec<Result<Output>>) -> Self {
// TODO(sunng87): this api response structure cannot represent error
// well. It hides successful execution results from error response
let mut results = Vec::with_capacity(outputs.len());
@@ -382,6 +384,7 @@ impl HttpServerBuilder {
script_handler: None,
metrics_handler: None,
shutdown_tx: Mutex::new(None),
configurator: None,
},
}
}
@@ -425,6 +428,12 @@ impl HttpServerBuilder {
self.inner.metrics_handler.get_or_insert(handler);
self
}
pub fn with_configurator(&mut self, configurator: Option<ConfiguratorRef>) -> &mut Self {
self.inner.configurator = configurator;
self
}
pub fn build(&mut self) -> HttpServer {
std::mem::take(self).inner
}
@@ -512,7 +521,10 @@ impl HttpServer {
router = router.nest("/dashboard", dashboard::dashboard());
}
}
router
}
pub fn build(&self, router: Router) -> Router {
router
// middlewares
.layer(
@@ -605,7 +617,11 @@ impl Server for HttpServer {
AlreadyStartedSnafu { server: "HTTP" }
);
let app = self.make_app();
let mut app = self.make_app();
if let Some(configurator) = self.configurator.as_ref() {
app = configurator.config_http(app);
}
let app = self.build(app);
let server = axum::Server::bind(&listening).serve(app.into_make_service());
*shutdown_tx = Some(tx);
@@ -719,7 +735,7 @@ mod test {
.with_sql_handler(sql_instance)
.with_grpc_handler(grpc_instance)
.build();
server.make_app().route(
server.build(server.make_app()).route(
"/test/timeout",
get(forever.layer(
ServiceBuilder::new()

View File

@@ -19,6 +19,7 @@ use common_catalog::consts::DEFAULT_CATALOG_NAME;
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod configurator;
pub mod error;
pub mod grpc;
pub mod http;
@@ -62,7 +63,7 @@ pub enum Mode {
/// schema name
/// - if `[<catalog>-]` is provided, we split database name with `-` and use
/// `<catalog>` and `<schema>`.
pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) {
pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) {
let parts = db.splitn(2, '-').collect::<Vec<&str>>();
if parts.len() == 2 {
(parts[0], parts[1])

View File

@@ -28,7 +28,7 @@ fn make_test_app() -> Router {
MemTable::default_numbers_table(),
))
.build();
server.make_app()
server.build(server.make_app())
}
#[tokio::test]

View File

@@ -109,7 +109,7 @@ fn make_test_app(tx: Arc<mpsc::Sender<(String, String)>>, db_name: Option<&str>)
server_builder.with_influxdb_handler(instance);
let server = server_builder.build();
server.make_app()
server.build(server.make_app())
}
#[tokio::test]

View File

@@ -97,7 +97,7 @@ fn make_test_app(tx: mpsc::Sender<String>) -> Router {
.with_sql_handler(instance.clone())
.with_opentsdb_handler(instance)
.build();
server.make_app()
server.build(server.make_app())
}
#[tokio::test]

View File

@@ -122,7 +122,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
.with_sql_handler(instance.clone())
.with_prom_handler(instance)
.build();
server.make_app()
server.build(server.make_app())
}
#[tokio::test]

View File

@@ -306,7 +306,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router
.with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(instance.clone()))
.with_metrics_handler(MetricsHandler)
.build();
(http_server.make_app(), guard)
(http_server.build(http_server.make_app()), guard)
}
pub async fn setup_test_http_app_with_frontend(
@@ -332,7 +332,7 @@ pub async fn setup_test_http_app_with_frontend(
.with_grpc_handler(ServerGrpcQueryHandlerAdaptor::arc(frontend_ref.clone()))
.with_script_handler(frontend_ref)
.build();
let app = http_server.make_app();
let app = http_server.build(http_server.make_app());
(app, guard)
}