From 11bdb33d37a35489d89bea39b36e08dbfb29dcd5 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Fri, 23 Dec 2022 15:22:12 +0800 Subject: [PATCH] feat: sql query interceptor and plugin refactoring (#773) * feat: let instance hold plugins * feat: add sql query interceptor definition * docs: add comments to key apis * feat: add implementation for pre-parsing and post-parsing * feat: add post_execute hook * test: add tests for interceptor * chore: add license header * fix: clippy error * Update src/cmd/src/frontend.rs Co-authored-by: LFC * refactor: batching post_parsing calls * refactor: rename AnyMap2 to Plugins * feat: call pre_execute with logical plan empty at the moment Co-authored-by: LFC --- src/cmd/src/frontend.rs | 24 ++-- src/cmd/src/standalone.rs | 9 +- src/frontend/src/frontend.rs | 7 +- src/frontend/src/instance.rs | 210 ++++++++++++++++++++++++++++++- src/frontend/src/lib.rs | 2 + src/servers/Cargo.toml | 2 +- src/servers/src/interceptor.rs | 105 ++++++++++++++++ src/servers/src/lib.rs | 1 + src/servers/tests/interceptor.rs | 38 ++++++ src/servers/tests/mod.rs | 1 + 10 files changed, 378 insertions(+), 21 deletions(-) create mode 100644 src/servers/src/interceptor.rs create mode 100644 src/servers/tests/interceptor.rs diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index 6bea05ce67..0563629ab2 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use anymap::AnyMap; +use std::sync::Arc; + use clap::Parser; use frontend::frontend::{Frontend, FrontendOptions}; use frontend::grpc::GrpcOptions; @@ -21,6 +22,7 @@ use frontend::instance::Instance; use frontend::mysql::MysqlOptions; use frontend::opentsdb::OpentsdbOptions; use frontend::postgres::PostgresOptions; +use frontend::Plugins; use meta_client::MetaClientOpts; use servers::auth::UserProviderRef; use servers::http::HttpOptions; @@ -86,21 +88,21 @@ pub struct StartCommand { impl StartCommand { async fn run(self) -> Result<()> { - let plugins = load_frontend_plugins(&self.user_provider)?; + let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?); let opts: FrontendOptions = self.try_into()?; - let mut frontend = Frontend::new( - opts.clone(), - Instance::try_new_distributed(&opts) - .await - .context(error::StartFrontendSnafu)?, - plugins, - ); + + let mut instance = Instance::try_new_distributed(&opts) + .await + .context(error::StartFrontendSnafu)?; + instance.set_plugins(plugins.clone()); + + let mut frontend = Frontend::new(opts, instance, plugins); frontend.start().await.context(error::StartFrontendSnafu) } } -pub fn load_frontend_plugins(user_provider: &Option) -> Result { - let mut plugins = AnyMap::new(); +pub fn load_frontend_plugins(user_provider: &Option) -> Result { + let mut plugins = Plugins::new(); if let Some(provider) = user_provider { let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?; diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs index e14f6f6e0a..83809f0b43 100644 --- a/src/cmd/src/standalone.rs +++ b/src/cmd/src/standalone.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use anymap::AnyMap; +use std::sync::Arc; + use clap::Parser; use common_telemetry::info; use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig}; @@ -25,6 +26,7 @@ use frontend::mysql::MysqlOptions; use frontend::opentsdb::OpentsdbOptions; use frontend::postgres::PostgresOptions; use frontend::prometheus::PrometheusOptions; +use frontend::Plugins; use serde::{Deserialize, Serialize}; use servers::http::HttpOptions; use servers::tls::{TlsMode, TlsOption}; @@ -150,7 +152,7 @@ impl StartCommand { async fn run(self) -> Result<()> { let enable_memory_catalog = self.enable_memory_catalog; let config_file = self.config_file.clone(); - let plugins = load_frontend_plugins(&self.user_provider)?; + let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?); let fe_opts = FrontendOptions::try_from(self)?; let dn_opts: DatanodeOptions = { let mut opts: StandaloneOptions = if let Some(path) = config_file { @@ -187,11 +189,12 @@ impl StartCommand { /// Build frontend instance in standalone mode async fn build_frontend( fe_opts: FrontendOptions, - plugins: AnyMap, + plugins: Arc, datanode_instance: InstanceRef, ) -> Result> { let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone()); frontend_instance.set_script_handler(datanode_instance); + frontend_instance.set_plugins(plugins.clone()); Ok(Frontend::new(fe_opts, frontend_instance, plugins)) } diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index c73d229e1b..2c943de82d 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -14,7 +14,6 @@ use std::sync::Arc; -use anymap::AnyMap; use meta_client::MetaClientOpts; use serde::{Deserialize, Serialize}; use servers::auth::UserProviderRef; @@ -31,6 +30,7 @@ use crate::opentsdb::OpentsdbOptions; use crate::postgres::PostgresOptions; use crate::prometheus::PrometheusOptions; use crate::server::Services; +use crate::Plugins; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct 
FrontendOptions { @@ -67,11 +67,11 @@ where { opts: FrontendOptions, instance: Option, - plugins: AnyMap, + plugins: Arc, } impl Frontend { - pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self { + pub fn new(opts: FrontendOptions, instance: T, plugins: Arc) -> Self { Self { opts, instance: Some(instance), @@ -90,6 +90,7 @@ impl Frontend { let instance = Arc::new(instance); + // TODO(sunng87): merge this into instance let provider = self.plugins.get::().cloned(); Services::start(&self.opts, instance, provider).await diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index b14bb136eb..8550ce456f 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -43,6 +43,7 @@ use datanode::instance::InstanceRef as DnInstanceRef; use distributed::DistInstance; use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOpts; +use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef}; use servers::query_handler::{ GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef, InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler, @@ -69,6 +70,7 @@ use crate::frontend::FrontendOptions; use crate::sql::insert_to_request; use crate::table::insert::insert_request_to_insert_batch; use crate::table::route::TableRoutes; +use crate::Plugins; #[async_trait] pub trait FrontendInstance: @@ -105,6 +107,10 @@ pub struct Instance { sql_handler: SqlQueryHandlerRef, grpc_query_handler: GrpcQueryHandlerRef, grpc_admin_handler: GrpcAdminHandlerRef, + + /// plugins: this map holds extensions to customize query or auth + /// behaviours. 
+ plugins: Arc, } impl Instance { @@ -135,6 +141,7 @@ impl Instance { sql_handler: dist_instance_ref.clone(), grpc_query_handler: dist_instance_ref.clone(), grpc_admin_handler: dist_instance_ref, + plugins: Default::default(), }) } @@ -178,6 +185,7 @@ impl Instance { sql_handler: dn_instance.clone(), grpc_query_handler: dn_instance.clone(), grpc_admin_handler: dn_instance, + plugins: Default::default(), } } @@ -451,6 +459,14 @@ impl Instance { Ok(Output::RecordBatches(RecordBatches::empty())) } + + pub fn set_plugins(&mut self, map: Arc) { + self.plugins = map; + } + + pub fn plugins(&self) -> Arc { + self.plugins.clone() + } } #[async_trait] @@ -563,15 +579,33 @@ impl SqlQueryHandler for Instance { query: &str, query_ctx: QueryContextRef, ) -> Vec> { - match parse_stmt(query) + let query_interceptor = self.plugins.get::(); + let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) { + Ok(q) => q, + Err(e) => return vec![Err(e)], + }; + + match parse_stmt(query.as_ref()) .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { query }) + .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone())) { Ok(stmts) => { let mut results = Vec::with_capacity(stmts.len()); for stmt in stmts { + // TODO(sunng87): figure out at which stage we can call + // this hook after ArrowFlight adoption. We need to provide + // LogicalPlan as to this hook. 
+ if let Err(e) = query_interceptor.pre_execute(&stmt, None, query_ctx.clone()) { + results.push(Err(e)); + break; + } match self.query_statement(stmt, query_ctx.clone()).await { - Ok(output) => results.push(Ok(output)), + Ok(output) => { + let output_result = + query_interceptor.post_execute(output, query_ctx.clone()); + results.push(output_result); + } Err(e) => { results.push(Err(e)); break; @@ -591,7 +625,15 @@ impl SqlQueryHandler for Instance { stmt: Statement, query_ctx: QueryContextRef, ) -> server_error::Result { - self.query_statement(stmt, query_ctx).await + let query_interceptor = self.plugins.get::(); + + // TODO(sunng87): figure out at which stage we can call + // this hook after ArrowFlight adoption. We need to provide + // LogicalPlan as to this hook. + query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?; + self.query_statement(stmt, query_ctx.clone()) + .await + .and_then(|output| query_interceptor.post_execute(output, query_ctx.clone())) } fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { @@ -673,6 +715,8 @@ impl GrpcAdminHandler for Instance { #[cfg(test)] mod tests { use std::assert_matches::assert_matches; + use std::borrow::Cow; + use std::sync::atomic::AtomicU32; use api::v1::codec::SelectResult; use api::v1::column::SemanticType; @@ -972,4 +1016,164 @@ mod tests { region_ids: vec![0], } } + + #[tokio::test(flavor = "multi_thread")] + async fn test_sql_interceptor_plugin() { + #[derive(Default)] + struct AssertionHook { + pub(crate) c: AtomicU32, + } + + impl SqlQueryInterceptor for AssertionHook { + fn pre_parsing<'a>( + &self, + query: &'a str, + _query_ctx: QueryContextRef, + ) -> server_error::Result> { + self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + assert!(query.starts_with("CREATE TABLE demo")); + Ok(Cow::Borrowed(query)) + } + + fn post_parsing( + &self, + statements: Vec, + _query_ctx: QueryContextRef, + ) -> server_error::Result> { + self.c.fetch_add(1, 
std::sync::atomic::Ordering::Relaxed); + assert!(matches!(statements[0], Statement::CreateTable(_))); + Ok(statements) + } + + fn pre_execute( + &self, + _statement: &Statement, + _plan: Option<&query::plan::LogicalPlan>, + _query_ctx: QueryContextRef, + ) -> server_error::Result<()> { + self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Ok(()) + } + + fn post_execute( + &self, + mut output: Output, + _query_ctx: QueryContextRef, + ) -> server_error::Result { + self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + match &mut output { + Output::AffectedRows(rows) => { + assert_eq!(*rows, 1); + // update output result + *rows = 10; + } + _ => unreachable!(), + } + Ok(output) + } + } + + let query_ctx = Arc::new(QueryContext::new()); + let (mut instance, _guard) = tests::create_frontend_instance("test_hook").await; + + let mut plugins = Plugins::new(); + let counter_hook = Arc::new(AssertionHook::default()); + plugins.insert::(counter_hook.clone()); + Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); + + let sql = r#"CREATE TABLE demo( + host STRING, + ts TIMESTAMP, + cpu DOUBLE NULL, + memory DOUBLE NULL, + disk_util DOUBLE DEFAULT 9.9, + TIME INDEX (ts), + PRIMARY KEY(host) + ) engine=mito with(regions=1);"#; + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .remove(0) + .unwrap(); + + // assert that each of the 4 hooks is called exactly once + assert_eq!(4, counter_hook.c.load(std::sync::atomic::Ordering::Relaxed)); + match output { + Output::AffectedRows(rows) => assert_eq!(rows, 10), + _ => unreachable!(), + } + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_disable_db_operation_plugin() { + #[derive(Default)] + struct DisableDBOpHook; + + impl SqlQueryInterceptor for DisableDBOpHook { + fn post_parsing( + &self, + statements: Vec, + _query_ctx: QueryContextRef, + ) -> server_error::Result> { + for s in &statements { + match s { + Statement::CreateDatabase(_) | Statement::ShowDatabases(_) => { +
return Err(server_error::Error::NotSupported { + feat: "Database operations".to_owned(), + }) + } + _ => {} + } + } + + Ok(statements) + } + } + + let query_ctx = Arc::new(QueryContext::new()); + let (mut instance, _guard) = tests::create_frontend_instance("test_db_hook").await; + + let mut plugins = Plugins::new(); + let hook = Arc::new(DisableDBOpHook::default()); + plugins.insert::(hook.clone()); + Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins)); + + let sql = r#"CREATE TABLE demo( + host STRING, + ts TIMESTAMP, + cpu DOUBLE NULL, + memory DOUBLE NULL, + disk_util DOUBLE DEFAULT 9.9, + TIME INDEX (ts), + PRIMARY KEY(host) + ) engine=mito with(regions=1);"#; + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .remove(0) + .unwrap(); + + match output { + Output::AffectedRows(rows) => assert_eq!(rows, 1), + _ => unreachable!(), + } + + let sql = r#"CREATE DATABASE tomcat"#; + if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .remove(0) + { + assert!(matches!(e, server_error::Error::NotSupported { .. })); + } else { + unreachable!(); + } + + let sql = r#"SELECT 1; SHOW DATABASES"#; + if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .remove(0) + { + assert!(matches!(e, server_error::Error::NotSupported { .. 
})); + } else { + unreachable!(); + } + } } diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs index 82807d0582..0c5bf33816 100644 --- a/src/frontend/src/lib.rs +++ b/src/frontend/src/lib.rs @@ -14,6 +14,8 @@ #![feature(assert_matches)] +pub type Plugins = anymap::Map; + mod catalog; mod datanode; pub mod error; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 6cdd0a83cf..3d3978cd20 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -37,6 +37,7 @@ openmetrics-parser = "0.4" opensrv-mysql = "0.3" pgwire = "0.6.3" prost = "0.11" +query = { path = "../query" } rand = "0.8" regex = "1.6" rustls = "0.20" @@ -65,7 +66,6 @@ common-base = { path = "../common/base" } mysql_async = { version = "0.31", default-features = false, features = [ "default-rustls", ] } -query = { path = "../query" } rand = "0.8" script = { path = "../script", features = ["python"] } serde_json = "1.0" diff --git a/src/servers/src/interceptor.rs b/src/servers/src/interceptor.rs new file mode 100644 index 0000000000..3f105e7dd8 --- /dev/null +++ b/src/servers/src/interceptor.rs @@ -0,0 +1,105 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::borrow::Cow; +use std::sync::Arc; + +use common_query::Output; +use query::plan::LogicalPlan; +use session::context::QueryContextRef; +use sql::statements::statement::Statement; + +use crate::error::Result; + +/// SqlQueryInterceptor can track the life cycle of a sql query and customize +/// or abort its execution at given points. +pub trait SqlQueryInterceptor { + /// Called before a query string is parsed into sql statements. + /// The implementation is allowed to change the sql string if needed. + fn pre_parsing<'a>(&self, query: &'a str, _query_ctx: QueryContextRef) -> Result> { + Ok(Cow::Borrowed(query)) + } + + /// Called after sql is parsed into statements. This hook is called once + /// with all the parsed statements; the implementation can alter them or + /// abort execution by returning an error. + fn post_parsing( + &self, + statements: Vec, + _query_ctx: QueryContextRef, + ) -> Result> { + Ok(statements) + } + + /// Called before sql is actually executed. The plan is passed as `None` at the moment. + fn pre_execute( + &self, + _statement: &Statement, + _plan: Option<&LogicalPlan>, + _query_ctx: QueryContextRef, + ) -> Result<()> { + Ok(()) + } + + /// Called after execution finished. The implementation can modify the + /// output if needed.
+ fn post_execute(&self, output: Output, _query_ctx: QueryContextRef) -> Result { + Ok(output) + } +} + +pub type SqlQueryInterceptorRef = Arc; + +impl SqlQueryInterceptor for Option<&SqlQueryInterceptorRef> { + fn pre_parsing<'a>(&self, query: &'a str, query_ctx: QueryContextRef) -> Result> { + if let Some(this) = self { + this.pre_parsing(query, query_ctx) + } else { + Ok(Cow::Borrowed(query)) + } + } + + fn post_parsing( + &self, + statements: Vec, + query_ctx: QueryContextRef, + ) -> Result> { + if let Some(this) = self { + this.post_parsing(statements, query_ctx) + } else { + Ok(statements) + } + } + + fn pre_execute( + &self, + statement: &Statement, + plan: Option<&LogicalPlan>, + query_ctx: QueryContextRef, + ) -> Result<()> { + if let Some(this) = self { + this.pre_execute(statement, plan, query_ctx) + } else { + Ok(()) + } + } + + fn post_execute(&self, output: Output, query_ctx: QueryContextRef) -> Result { + if let Some(this) = self { + this.post_execute(output, query_ctx) + } else { + Ok(output) + } + } +} diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index e58e4363a7..e18caf7fa3 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -22,6 +22,7 @@ pub mod error; pub mod grpc; pub mod http; pub mod influxdb; +pub mod interceptor; pub mod line_writer; pub mod mysql; pub mod opentsdb; diff --git a/src/servers/tests/interceptor.rs b/src/servers/tests/interceptor.rs new file mode 100644 index 0000000000..c1acd7c808 --- /dev/null +++ b/src/servers/tests/interceptor.rs @@ -0,0 +1,38 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::sync::Arc; + +use servers::error::Result; +use servers::interceptor::SqlQueryInterceptor; +use session::context::{QueryContext, QueryContextRef}; + +pub struct NoopInterceptor; + +impl SqlQueryInterceptor for NoopInterceptor { + fn pre_parsing<'a>(&self, query: &'a str, _query_ctx: QueryContextRef) -> Result> { + let modified_query = format!("{query};"); + Ok(Cow::Owned(modified_query)) + } +} + +#[test] +fn test_default_interceptor_behaviour() { + let di = NoopInterceptor; + let ctx = Arc::new(QueryContext::new()); + + let query = "SELECT 1"; + assert_eq!("SELECT 1;", di.pre_parsing(query, ctx).unwrap()); +} diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 0105f7d3f1..32a7638578 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -33,6 +33,7 @@ use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; use session::context::QueryContextRef; +mod interceptor; mod opentsdb; mod postgres;