feat: sql query interceptor and plugin refactoring (#773)

* feat: let instance hold plugins

* feat: add sql query interceptor definition

* docs: add comments to key apis

* feat: add implementation for pre-parsing and post-parsing

* feat: add post_execute hook

* test: add tests for interceptor

* chore: add license header

* fix: clippy error

* Update src/cmd/src/frontend.rs

Co-authored-by: LFC <bayinamine@gmail.com>

* refactor: batching post_parsing calls

* refactor: rename AnyMap2 to Plugins

* feat: call pre_execute with logical plan empty at the moment

Co-authored-by: LFC <bayinamine@gmail.com>
This commit is contained in:
Ning Sun
2022-12-23 15:22:12 +08:00
committed by GitHub
parent 1daba75e7b
commit 11bdb33d37
10 changed files with 378 additions and 21 deletions

View File

@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anymap::AnyMap;
use std::sync::Arc;
use clap::Parser;
use frontend::frontend::{Frontend, FrontendOptions};
use frontend::grpc::GrpcOptions;
@@ -21,6 +22,7 @@ use frontend::instance::Instance;
use frontend::mysql::MysqlOptions;
use frontend::opentsdb::OpentsdbOptions;
use frontend::postgres::PostgresOptions;
use frontend::Plugins;
use meta_client::MetaClientOpts;
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
@@ -86,21 +88,21 @@ pub struct StartCommand {
impl StartCommand {
async fn run(self) -> Result<()> {
let plugins = load_frontend_plugins(&self.user_provider)?;
let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?);
let opts: FrontendOptions = self.try_into()?;
let mut frontend = Frontend::new(
opts.clone(),
Instance::try_new_distributed(&opts)
.await
.context(error::StartFrontendSnafu)?,
plugins,
);
let mut instance = Instance::try_new_distributed(&opts)
.await
.context(error::StartFrontendSnafu)?;
instance.set_plugins(plugins.clone());
let mut frontend = Frontend::new(opts, instance, plugins);
frontend.start().await.context(error::StartFrontendSnafu)
}
}
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<AnyMap> {
let mut plugins = AnyMap::new();
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<Plugins> {
let mut plugins = Plugins::new();
if let Some(provider) = user_provider {
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;

View File

@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use anymap::AnyMap;
use std::sync::Arc;
use clap::Parser;
use common_telemetry::info;
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig};
@@ -25,6 +26,7 @@ use frontend::mysql::MysqlOptions;
use frontend::opentsdb::OpentsdbOptions;
use frontend::postgres::PostgresOptions;
use frontend::prometheus::PrometheusOptions;
use frontend::Plugins;
use serde::{Deserialize, Serialize};
use servers::http::HttpOptions;
use servers::tls::{TlsMode, TlsOption};
@@ -150,7 +152,7 @@ impl StartCommand {
async fn run(self) -> Result<()> {
let enable_memory_catalog = self.enable_memory_catalog;
let config_file = self.config_file.clone();
let plugins = load_frontend_plugins(&self.user_provider)?;
let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?);
let fe_opts = FrontendOptions::try_from(self)?;
let dn_opts: DatanodeOptions = {
let mut opts: StandaloneOptions = if let Some(path) = config_file {
@@ -187,11 +189,12 @@ impl StartCommand {
/// Build frontend instance in standalone mode
async fn build_frontend(
fe_opts: FrontendOptions,
plugins: AnyMap,
plugins: Arc<Plugins>,
datanode_instance: InstanceRef,
) -> Result<Frontend<FeInstance>> {
let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone());
frontend_instance.set_script_handler(datanode_instance);
frontend_instance.set_plugins(plugins.clone());
Ok(Frontend::new(fe_opts, frontend_instance, plugins))
}

View File

@@ -14,7 +14,6 @@
use std::sync::Arc;
use anymap::AnyMap;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::auth::UserProviderRef;
@@ -31,6 +30,7 @@ use crate::opentsdb::OpentsdbOptions;
use crate::postgres::PostgresOptions;
use crate::prometheus::PrometheusOptions;
use crate::server::Services;
use crate::Plugins;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FrontendOptions {
@@ -67,11 +67,11 @@ where
{
opts: FrontendOptions,
instance: Option<T>,
plugins: AnyMap,
plugins: Arc<Plugins>,
}
impl<T: FrontendInstance> Frontend<T> {
pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self {
pub fn new(opts: FrontendOptions, instance: T, plugins: Arc<Plugins>) -> Self {
Self {
opts,
instance: Some(instance),
@@ -90,6 +90,7 @@ impl<T: FrontendInstance> Frontend<T> {
let instance = Arc::new(instance);
// TODO(sunng87): merge this into instance
let provider = self.plugins.get::<UserProviderRef>().cloned();
Services::start(&self.opts, instance, provider).await

View File

@@ -43,6 +43,7 @@ use datanode::instance::InstanceRef as DnInstanceRef;
use distributed::DistInstance;
use meta_client::client::{MetaClient, MetaClientBuilder};
use meta_client::MetaClientOpts;
use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef};
use servers::query_handler::{
GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler,
@@ -69,6 +70,7 @@ use crate::frontend::FrontendOptions;
use crate::sql::insert_to_request;
use crate::table::insert::insert_request_to_insert_batch;
use crate::table::route::TableRoutes;
use crate::Plugins;
#[async_trait]
pub trait FrontendInstance:
@@ -105,6 +107,10 @@ pub struct Instance {
sql_handler: SqlQueryHandlerRef,
grpc_query_handler: GrpcQueryHandlerRef,
grpc_admin_handler: GrpcAdminHandlerRef,
/// plugins: this map holds extensions to customize query or auth
/// behaviours.
plugins: Arc<Plugins>,
}
impl Instance {
@@ -135,6 +141,7 @@ impl Instance {
sql_handler: dist_instance_ref.clone(),
grpc_query_handler: dist_instance_ref.clone(),
grpc_admin_handler: dist_instance_ref,
plugins: Default::default(),
})
}
@@ -178,6 +185,7 @@ impl Instance {
sql_handler: dn_instance.clone(),
grpc_query_handler: dn_instance.clone(),
grpc_admin_handler: dn_instance,
plugins: Default::default(),
}
}
@@ -451,6 +459,14 @@ impl Instance {
Ok(Output::RecordBatches(RecordBatches::empty()))
}
pub fn set_plugins(&mut self, map: Arc<Plugins>) {
self.plugins = map;
}
pub fn plugins(&self) -> Arc<Plugins> {
self.plugins.clone()
}
}
#[async_trait]
@@ -563,15 +579,33 @@ impl SqlQueryHandler for Instance {
query: &str,
query_ctx: QueryContextRef,
) -> Vec<server_error::Result<Output>> {
match parse_stmt(query)
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef>();
let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
Ok(q) => q,
Err(e) => return vec![Err(e)],
};
match parse_stmt(query.as_ref())
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
{
Ok(stmts) => {
let mut results = Vec::with_capacity(stmts.len());
for stmt in stmts {
// TODO(sunng87): figure out at which stage we can call
// this hook after ArrowFlight adoption. We need to provide
// LogicalPlan as to this hook.
if let Err(e) = query_interceptor.pre_execute(&stmt, None, query_ctx.clone()) {
results.push(Err(e));
break;
}
match self.query_statement(stmt, query_ctx.clone()).await {
Ok(output) => results.push(Ok(output)),
Ok(output) => {
let output_result =
query_interceptor.post_execute(output, query_ctx.clone());
results.push(output_result);
}
Err(e) => {
results.push(Err(e));
break;
@@ -591,7 +625,15 @@ impl SqlQueryHandler for Instance {
stmt: Statement,
query_ctx: QueryContextRef,
) -> server_error::Result<Output> {
self.query_statement(stmt, query_ctx).await
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef>();
// TODO(sunng87): figure out at which stage we can call
// this hook after ArrowFlight adoption. We need to provide
// LogicalPlan as to this hook.
query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?;
self.query_statement(stmt, query_ctx.clone())
.await
.and_then(|output| query_interceptor.post_execute(output, query_ctx.clone()))
}
fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result<bool> {
@@ -673,6 +715,8 @@ impl GrpcAdminHandler for Instance {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::borrow::Cow;
use std::sync::atomic::AtomicU32;
use api::v1::codec::SelectResult;
use api::v1::column::SemanticType;
@@ -972,4 +1016,164 @@ mod tests {
region_ids: vec![0],
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_sql_interceptor_plugin() {
#[derive(Default)]
struct AssertionHook {
pub(crate) c: AtomicU32,
}
impl SqlQueryInterceptor for AssertionHook {
fn pre_parsing<'a>(
&self,
query: &'a str,
_query_ctx: QueryContextRef,
) -> server_error::Result<std::borrow::Cow<'a, str>> {
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
assert!(query.starts_with("CREATE TABLE demo"));
Ok(Cow::Borrowed(query))
}
fn post_parsing(
&self,
statements: Vec<Statement>,
_query_ctx: QueryContextRef,
) -> server_error::Result<Vec<Statement>> {
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
assert!(matches!(statements[0], Statement::CreateTable(_)));
Ok(statements)
}
fn pre_execute(
&self,
_statement: &Statement,
_plan: Option<&query::plan::LogicalPlan>,
_query_ctx: QueryContextRef,
) -> server_error::Result<()> {
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(())
}
fn post_execute(
&self,
mut output: Output,
_query_ctx: QueryContextRef,
) -> server_error::Result<Output> {
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
match &mut output {
Output::AffectedRows(rows) => {
assert_eq!(*rows, 1);
// update output result
*rows = 10;
}
_ => unreachable!(),
}
Ok(output)
}
}
let query_ctx = Arc::new(QueryContext::new());
let (mut instance, _guard) = tests::create_frontend_instance("test_hook").await;
let mut plugins = Plugins::new();
let counter_hook = Arc::new(AssertionHook::default());
plugins.insert::<SqlQueryInterceptorRef>(counter_hook.clone());
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
let sql = r#"CREATE TABLE demo(
host STRING,
ts TIMESTAMP,
cpu DOUBLE NULL,
memory DOUBLE NULL,
disk_util DOUBLE DEFAULT 9.9,
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#;
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.remove(0)
.unwrap();
// assert that the hook is called 3 times
assert_eq!(4, counter_hook.c.load(std::sync::atomic::Ordering::Relaxed));
match output {
Output::AffectedRows(rows) => assert_eq!(rows, 10),
_ => unreachable!(),
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_disable_db_operation_plugin() {
#[derive(Default)]
struct DisableDBOpHook;
impl SqlQueryInterceptor for DisableDBOpHook {
fn post_parsing(
&self,
statements: Vec<Statement>,
_query_ctx: QueryContextRef,
) -> server_error::Result<Vec<Statement>> {
for s in &statements {
match s {
Statement::CreateDatabase(_) | Statement::ShowDatabases(_) => {
return Err(server_error::Error::NotSupported {
feat: "Database operations".to_owned(),
})
}
_ => {}
}
}
Ok(statements)
}
}
let query_ctx = Arc::new(QueryContext::new());
let (mut instance, _guard) = tests::create_frontend_instance("test_db_hook").await;
let mut plugins = Plugins::new();
let hook = Arc::new(DisableDBOpHook::default());
plugins.insert::<SqlQueryInterceptorRef>(hook.clone());
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
let sql = r#"CREATE TABLE demo(
host STRING,
ts TIMESTAMP,
cpu DOUBLE NULL,
memory DOUBLE NULL,
disk_util DOUBLE DEFAULT 9.9,
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#;
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.remove(0)
.unwrap();
match output {
Output::AffectedRows(rows) => assert_eq!(rows, 1),
_ => unreachable!(),
}
let sql = r#"CREATE DATABASE tomcat"#;
if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.remove(0)
{
assert!(matches!(e, server_error::Error::NotSupported { .. }));
} else {
unreachable!();
}
let sql = r#"SELECT 1; SHOW DATABASES"#;
if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.remove(0)
{
assert!(matches!(e, server_error::Error::NotSupported { .. }));
} else {
unreachable!();
}
}
}

View File

@@ -14,6 +14,8 @@
#![feature(assert_matches)]
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
mod catalog;
mod datanode;
pub mod error;

View File

@@ -37,6 +37,7 @@ openmetrics-parser = "0.4"
opensrv-mysql = "0.3"
pgwire = "0.6.3"
prost = "0.11"
query = { path = "../query" }
rand = "0.8"
regex = "1.6"
rustls = "0.20"
@@ -65,7 +66,6 @@ common-base = { path = "../common/base" }
mysql_async = { version = "0.31", default-features = false, features = [
"default-rustls",
] }
query = { path = "../query" }
rand = "0.8"
script = { path = "../script", features = ["python"] }
serde_json = "1.0"

View File

@@ -0,0 +1,105 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use common_query::Output;
use query::plan::LogicalPlan;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use crate::error::Result;
/// SqlQueryInterceptor can track life cycle of a sql query and customize or
/// abort its execution at given point.
pub trait SqlQueryInterceptor {
/// Called before a query string is parsed into sql statements.
/// The implementation is allowed to change the sql string if needed.
fn pre_parsing<'a>(&self, query: &'a str, _query_ctx: QueryContextRef) -> Result<Cow<'a, str>> {
Ok(Cow::Borrowed(query))
}
/// Called after sql is parsed into statements. This interceptor is called
/// on each statement and the implementation can alter the statement or
/// abort execution by raising an error.
fn post_parsing(
&self,
statements: Vec<Statement>,
_query_ctx: QueryContextRef,
) -> Result<Vec<Statement>> {
Ok(statements)
}
/// Called before sql is actually executed. This hook is not called at the moment.
fn pre_execute(
&self,
_statement: &Statement,
_plan: Option<&LogicalPlan>,
_query_ctx: QueryContextRef,
) -> Result<()> {
Ok(())
}
/// Called after execution finished. The implementation can modify the
/// output if needed.
fn post_execute(&self, output: Output, _query_ctx: QueryContextRef) -> Result<Output> {
Ok(output)
}
}
pub type SqlQueryInterceptorRef = Arc<dyn SqlQueryInterceptor + Send + Sync + 'static>;
impl SqlQueryInterceptor for Option<&SqlQueryInterceptorRef> {
fn pre_parsing<'a>(&self, query: &'a str, query_ctx: QueryContextRef) -> Result<Cow<'a, str>> {
if let Some(this) = self {
this.pre_parsing(query, query_ctx)
} else {
Ok(Cow::Borrowed(query))
}
}
fn post_parsing(
&self,
statements: Vec<Statement>,
query_ctx: QueryContextRef,
) -> Result<Vec<Statement>> {
if let Some(this) = self {
this.post_parsing(statements, query_ctx)
} else {
Ok(statements)
}
}
fn pre_execute(
&self,
statement: &Statement,
plan: Option<&LogicalPlan>,
query_ctx: QueryContextRef,
) -> Result<()> {
if let Some(this) = self {
this.pre_execute(statement, plan, query_ctx)
} else {
Ok(())
}
}
fn post_execute(&self, output: Output, query_ctx: QueryContextRef) -> Result<Output> {
if let Some(this) = self {
this.post_execute(output, query_ctx)
} else {
Ok(output)
}
}
}

View File

@@ -22,6 +22,7 @@ pub mod error;
pub mod grpc;
pub mod http;
pub mod influxdb;
pub mod interceptor;
pub mod line_writer;
pub mod mysql;
pub mod opentsdb;

View File

@@ -0,0 +1,38 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use servers::error::Result;
use servers::interceptor::SqlQueryInterceptor;
use session::context::{QueryContext, QueryContextRef};
pub struct NoopInterceptor;
impl SqlQueryInterceptor for NoopInterceptor {
fn pre_parsing<'a>(&self, query: &'a str, _query_ctx: QueryContextRef) -> Result<Cow<'a, str>> {
let modified_query = format!("{query};");
Ok(Cow::Owned(modified_query))
}
}
#[test]
fn test_default_interceptor_behaviour() {
let di = NoopInterceptor;
let ctx = Arc::new(QueryContext::new());
let query = "SELECT 1";
assert_eq!("SELECT 1;", di.pre_parsing(query, ctx).unwrap());
}

View File

@@ -33,6 +33,7 @@ use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::{PyEngine, PyScript};
use session::context::QueryContextRef;
mod interceptor;
mod opentsdb;
mod postgres;