mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-28 02:40:38 +00:00
feat: sql query interceptor and plugin refactoring (#773)
* feat: let instance hold plugins * feat: add sql query interceptor definition * docs: add comments to key apis * feat: add implementation for pre-parsing and post-parsing * feat: add post_execute hook * test: add tests for interceptor * chore: add license header * fix: clippy error * Update src/cmd/src/frontend.rs Co-authored-by: LFC <bayinamine@gmail.com> * refactor: batching post_parsing calls * refactor: rename AnyMap2 to Plugins * feat: call pre_execute with logical plan empty at the moment Co-authored-by: LFC <bayinamine@gmail.com>
This commit is contained in:
@@ -14,7 +14,6 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use anymap::AnyMap;
|
||||
use meta_client::MetaClientOpts;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::auth::UserProviderRef;
|
||||
@@ -31,6 +30,7 @@ use crate::opentsdb::OpentsdbOptions;
|
||||
use crate::postgres::PostgresOptions;
|
||||
use crate::prometheus::PrometheusOptions;
|
||||
use crate::server::Services;
|
||||
use crate::Plugins;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FrontendOptions {
|
||||
@@ -67,11 +67,11 @@ where
|
||||
{
|
||||
opts: FrontendOptions,
|
||||
instance: Option<T>,
|
||||
plugins: AnyMap,
|
||||
plugins: Arc<Plugins>,
|
||||
}
|
||||
|
||||
impl<T: FrontendInstance> Frontend<T> {
|
||||
pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self {
|
||||
pub fn new(opts: FrontendOptions, instance: T, plugins: Arc<Plugins>) -> Self {
|
||||
Self {
|
||||
opts,
|
||||
instance: Some(instance),
|
||||
@@ -90,6 +90,7 @@ impl<T: FrontendInstance> Frontend<T> {
|
||||
|
||||
let instance = Arc::new(instance);
|
||||
|
||||
// TODO(sunng87): merge this into instance
|
||||
let provider = self.plugins.get::<UserProviderRef>().cloned();
|
||||
|
||||
Services::start(&self.opts, instance, provider).await
|
||||
|
||||
@@ -43,6 +43,7 @@ use datanode::instance::InstanceRef as DnInstanceRef;
|
||||
use distributed::DistInstance;
|
||||
use meta_client::client::{MetaClient, MetaClientBuilder};
|
||||
use meta_client::MetaClientOpts;
|
||||
use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef};
|
||||
use servers::query_handler::{
|
||||
GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
|
||||
InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler,
|
||||
@@ -69,6 +70,7 @@ use crate::frontend::FrontendOptions;
|
||||
use crate::sql::insert_to_request;
|
||||
use crate::table::insert::insert_request_to_insert_batch;
|
||||
use crate::table::route::TableRoutes;
|
||||
use crate::Plugins;
|
||||
|
||||
#[async_trait]
|
||||
pub trait FrontendInstance:
|
||||
@@ -105,6 +107,10 @@ pub struct Instance {
|
||||
sql_handler: SqlQueryHandlerRef,
|
||||
grpc_query_handler: GrpcQueryHandlerRef,
|
||||
grpc_admin_handler: GrpcAdminHandlerRef,
|
||||
|
||||
/// plugins: this map holds extensions to customize query or auth
|
||||
/// behaviours.
|
||||
plugins: Arc<Plugins>,
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
@@ -135,6 +141,7 @@ impl Instance {
|
||||
sql_handler: dist_instance_ref.clone(),
|
||||
grpc_query_handler: dist_instance_ref.clone(),
|
||||
grpc_admin_handler: dist_instance_ref,
|
||||
plugins: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -178,6 +185,7 @@ impl Instance {
|
||||
sql_handler: dn_instance.clone(),
|
||||
grpc_query_handler: dn_instance.clone(),
|
||||
grpc_admin_handler: dn_instance,
|
||||
plugins: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,6 +459,14 @@ impl Instance {
|
||||
|
||||
Ok(Output::RecordBatches(RecordBatches::empty()))
|
||||
}
|
||||
|
||||
pub fn set_plugins(&mut self, map: Arc<Plugins>) {
|
||||
self.plugins = map;
|
||||
}
|
||||
|
||||
pub fn plugins(&self) -> Arc<Plugins> {
|
||||
self.plugins.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -563,15 +579,33 @@ impl SqlQueryHandler for Instance {
|
||||
query: &str,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Vec<server_error::Result<Output>> {
|
||||
match parse_stmt(query)
|
||||
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef>();
|
||||
let query = match query_interceptor.pre_parsing(query, query_ctx.clone()) {
|
||||
Ok(q) => q,
|
||||
Err(e) => return vec![Err(e)],
|
||||
};
|
||||
|
||||
match parse_stmt(query.as_ref())
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
||||
.and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))
|
||||
{
|
||||
Ok(stmts) => {
|
||||
let mut results = Vec::with_capacity(stmts.len());
|
||||
for stmt in stmts {
|
||||
// TODO(sunng87): figure out at which stage we can call
|
||||
// this hook after ArrowFlight adoption. We need to provide
|
||||
// LogicalPlan as to this hook.
|
||||
if let Err(e) = query_interceptor.pre_execute(&stmt, None, query_ctx.clone()) {
|
||||
results.push(Err(e));
|
||||
break;
|
||||
}
|
||||
match self.query_statement(stmt, query_ctx.clone()).await {
|
||||
Ok(output) => results.push(Ok(output)),
|
||||
Ok(output) => {
|
||||
let output_result =
|
||||
query_interceptor.post_execute(output, query_ctx.clone());
|
||||
results.push(output_result);
|
||||
}
|
||||
Err(e) => {
|
||||
results.push(Err(e));
|
||||
break;
|
||||
@@ -591,7 +625,15 @@ impl SqlQueryHandler for Instance {
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<Output> {
|
||||
self.query_statement(stmt, query_ctx).await
|
||||
let query_interceptor = self.plugins.get::<SqlQueryInterceptorRef>();
|
||||
|
||||
// TODO(sunng87): figure out at which stage we can call
|
||||
// this hook after ArrowFlight adoption. We need to provide
|
||||
// LogicalPlan as to this hook.
|
||||
query_interceptor.pre_execute(&stmt, None, query_ctx.clone())?;
|
||||
self.query_statement(stmt, query_ctx.clone())
|
||||
.await
|
||||
.and_then(|output| query_interceptor.post_execute(output, query_ctx.clone()))
|
||||
}
|
||||
|
||||
fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result<bool> {
|
||||
@@ -673,6 +715,8 @@ impl GrpcAdminHandler for Instance {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::assert_matches::assert_matches;
|
||||
use std::borrow::Cow;
|
||||
use std::sync::atomic::AtomicU32;
|
||||
|
||||
use api::v1::codec::SelectResult;
|
||||
use api::v1::column::SemanticType;
|
||||
@@ -972,4 +1016,164 @@ mod tests {
|
||||
region_ids: vec![0],
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_sql_interceptor_plugin() {
|
||||
#[derive(Default)]
|
||||
struct AssertionHook {
|
||||
pub(crate) c: AtomicU32,
|
||||
}
|
||||
|
||||
impl SqlQueryInterceptor for AssertionHook {
|
||||
fn pre_parsing<'a>(
|
||||
&self,
|
||||
query: &'a str,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<std::borrow::Cow<'a, str>> {
|
||||
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
assert!(query.starts_with("CREATE TABLE demo"));
|
||||
Ok(Cow::Borrowed(query))
|
||||
}
|
||||
|
||||
fn post_parsing(
|
||||
&self,
|
||||
statements: Vec<Statement>,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<Vec<Statement>> {
|
||||
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
assert!(matches!(statements[0], Statement::CreateTable(_)));
|
||||
Ok(statements)
|
||||
}
|
||||
|
||||
fn pre_execute(
|
||||
&self,
|
||||
_statement: &Statement,
|
||||
_plan: Option<&query::plan::LogicalPlan>,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<()> {
|
||||
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn post_execute(
|
||||
&self,
|
||||
mut output: Output,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<Output> {
|
||||
self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
match &mut output {
|
||||
Output::AffectedRows(rows) => {
|
||||
assert_eq!(*rows, 1);
|
||||
// update output result
|
||||
*rows = 10;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let (mut instance, _guard) = tests::create_frontend_instance("test_hook").await;
|
||||
|
||||
let mut plugins = Plugins::new();
|
||||
let counter_hook = Arc::new(AssertionHook::default());
|
||||
plugins.insert::<SqlQueryInterceptorRef>(counter_hook.clone());
|
||||
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
|
||||
|
||||
let sql = r#"CREATE TABLE demo(
|
||||
host STRING,
|
||||
ts TIMESTAMP,
|
||||
cpu DOUBLE NULL,
|
||||
memory DOUBLE NULL,
|
||||
disk_util DOUBLE DEFAULT 9.9,
|
||||
TIME INDEX (ts),
|
||||
PRIMARY KEY(host)
|
||||
) engine=mito with(regions=1);"#;
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.remove(0)
|
||||
.unwrap();
|
||||
|
||||
// assert that the hook is called 3 times
|
||||
assert_eq!(4, counter_hook.c.load(std::sync::atomic::Ordering::Relaxed));
|
||||
match output {
|
||||
Output::AffectedRows(rows) => assert_eq!(rows, 10),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_disable_db_operation_plugin() {
|
||||
#[derive(Default)]
|
||||
struct DisableDBOpHook;
|
||||
|
||||
impl SqlQueryInterceptor for DisableDBOpHook {
|
||||
fn post_parsing(
|
||||
&self,
|
||||
statements: Vec<Statement>,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<Vec<Statement>> {
|
||||
for s in &statements {
|
||||
match s {
|
||||
Statement::CreateDatabase(_) | Statement::ShowDatabases(_) => {
|
||||
return Err(server_error::Error::NotSupported {
|
||||
feat: "Database operations".to_owned(),
|
||||
})
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statements)
|
||||
}
|
||||
}
|
||||
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let (mut instance, _guard) = tests::create_frontend_instance("test_db_hook").await;
|
||||
|
||||
let mut plugins = Plugins::new();
|
||||
let hook = Arc::new(DisableDBOpHook::default());
|
||||
plugins.insert::<SqlQueryInterceptorRef>(hook.clone());
|
||||
Arc::make_mut(&mut instance).set_plugins(Arc::new(plugins));
|
||||
|
||||
let sql = r#"CREATE TABLE demo(
|
||||
host STRING,
|
||||
ts TIMESTAMP,
|
||||
cpu DOUBLE NULL,
|
||||
memory DOUBLE NULL,
|
||||
disk_util DOUBLE DEFAULT 9.9,
|
||||
TIME INDEX (ts),
|
||||
PRIMARY KEY(host)
|
||||
) engine=mito with(regions=1);"#;
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.remove(0)
|
||||
.unwrap();
|
||||
|
||||
match output {
|
||||
Output::AffectedRows(rows) => assert_eq!(rows, 1),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let sql = r#"CREATE DATABASE tomcat"#;
|
||||
if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.remove(0)
|
||||
{
|
||||
assert!(matches!(e, server_error::Error::NotSupported { .. }));
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
|
||||
let sql = r#"SELECT 1; SHOW DATABASES"#;
|
||||
if let Err(e) = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.remove(0)
|
||||
{
|
||||
assert!(matches!(e, server_error::Error::NotSupported { .. }));
|
||||
} else {
|
||||
unreachable!();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
#![feature(assert_matches)]
|
||||
|
||||
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
|
||||
|
||||
mod catalog;
|
||||
mod datanode;
|
||||
pub mod error;
|
||||
|
||||
Reference in New Issue
Block a user