mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-15 12:30:38 +00:00
fix: check full table name during logical plan creation (#948)
This commit is contained in:
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -1356,6 +1356,7 @@ dependencies = [
|
||||
"anymap",
|
||||
"build-data",
|
||||
"clap 3.2.23",
|
||||
"common-base",
|
||||
"common-error",
|
||||
"common-telemetry",
|
||||
"datanode",
|
||||
@@ -1396,6 +1397,7 @@ dependencies = [
|
||||
name = "common-base"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"bitvec",
|
||||
"bytes",
|
||||
"common-error",
|
||||
@@ -2663,7 +2665,6 @@ dependencies = [
|
||||
name = "frontend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anymap",
|
||||
"api",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
@@ -2702,6 +2703,7 @@ dependencies = [
|
||||
"snafu",
|
||||
"sql",
|
||||
"store-api",
|
||||
"strfmt",
|
||||
"substrait 0.1.0",
|
||||
"table",
|
||||
"tempdir",
|
||||
@@ -5454,6 +5456,7 @@ dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"catalog",
|
||||
"common-base",
|
||||
"common-catalog",
|
||||
"common-error",
|
||||
"common-function",
|
||||
@@ -7116,6 +7119,12 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strfmt"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b73de159e5e71c4c7579ea3041d3e765f46790555790c24489195554210f1fb4"
|
||||
|
||||
[[package]]
|
||||
name = "string_cache"
|
||||
version = "0.8.4"
|
||||
|
||||
@@ -12,6 +12,7 @@ path = "src/bin/greptime.rs"
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
clap = { version = "3.1", features = ["derive"] }
|
||||
common-base = { path = "../common/base" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-telemetry = { path = "../common/telemetry", features = [
|
||||
"deadlock_detection",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use common_base::Plugins;
|
||||
use frontend::frontend::{Frontend, FrontendOptions};
|
||||
use frontend::grpc::GrpcOptions;
|
||||
use frontend::influxdb::InfluxdbOptions;
|
||||
@@ -22,7 +23,6 @@ use frontend::instance::Instance;
|
||||
use frontend::mysql::MysqlOptions;
|
||||
use frontend::opentsdb::OpentsdbOptions;
|
||||
use frontend::postgres::PostgresOptions;
|
||||
use frontend::Plugins;
|
||||
use meta_client::MetaClientOpts;
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::http::HttpOptions;
|
||||
@@ -91,10 +91,9 @@ impl StartCommand {
|
||||
let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?);
|
||||
let opts: FrontendOptions = self.try_into()?;
|
||||
|
||||
let mut instance = Instance::try_new_distributed(&opts)
|
||||
let instance = Instance::try_new_distributed(&opts, plugins.clone())
|
||||
.await
|
||||
.context(error::StartFrontendSnafu)?;
|
||||
instance.set_plugins(plugins.clone());
|
||||
|
||||
let mut frontend = Frontend::new(opts, instance, plugins);
|
||||
frontend.start().await.context(error::StartFrontendSnafu)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use common_base::Plugins;
|
||||
use common_telemetry::info;
|
||||
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig, WalConfig};
|
||||
use datanode::instance::InstanceRef;
|
||||
@@ -27,7 +28,6 @@ use frontend::opentsdb::OpentsdbOptions;
|
||||
use frontend::postgres::PostgresOptions;
|
||||
use frontend::prometheus::PrometheusOptions;
|
||||
use frontend::promql::PromqlOptions;
|
||||
use frontend::Plugins;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::http::HttpOptions;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
|
||||
@@ -5,6 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
bitvec = "1.0"
|
||||
bytes = { version = "1.1", features = ["serde"] }
|
||||
common-error = { path = "../error" }
|
||||
|
||||
@@ -19,3 +19,5 @@ pub mod bytes;
|
||||
pub mod readable_size;
|
||||
|
||||
pub use bit_vec::BitVec;
|
||||
|
||||
/// Type-keyed map of plugin values; every stored value must be `Send + Sync`.
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
|
||||
|
||||
@@ -5,7 +5,6 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anymap = "1.0.0-beta.2"
|
||||
api = { path = "../api" }
|
||||
async-stream.workspace = true
|
||||
async-trait = "0.1"
|
||||
@@ -52,5 +51,6 @@ tonic.workspace = true
|
||||
datanode = { path = "../datanode" }
|
||||
futures = "0.3"
|
||||
meta-srv = { path = "../meta-srv", features = ["mock"] }
|
||||
strfmt = "0.2"
|
||||
tempdir = "0.3"
|
||||
tower = "0.4"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_base::Plugins;
|
||||
use meta_client::MetaClientOpts;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::http::HttpOptions;
|
||||
@@ -30,7 +31,6 @@ use crate::postgres::PostgresOptions;
|
||||
use crate::prometheus::PrometheusOptions;
|
||||
use crate::promql::PromqlOptions;
|
||||
use crate::server::Services;
|
||||
use crate::Plugins;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
|
||||
@@ -29,10 +29,14 @@ use api::v1::{AddColumns, AlterExpr, Column, DdlRequest, InsertRequest};
|
||||
use async_trait::async_trait;
|
||||
use catalog::remote::MetaKvBackend;
|
||||
use catalog::CatalogManagerRef;
|
||||
use common_base::Plugins;
|
||||
use common_error::ext::BoxedError;
|
||||
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
|
||||
use common_query::Output;
|
||||
use common_recordbatch::RecordBatches;
|
||||
use common_telemetry::logging::{debug, info};
|
||||
use datafusion::sql::sqlparser::ast::ObjectName;
|
||||
use datafusion_common::TableReference;
|
||||
use datanode::instance::InstanceRef as DnInstanceRef;
|
||||
use datatypes::schema::Schema;
|
||||
use distributed::DistInstance;
|
||||
@@ -40,6 +44,7 @@ use meta_client::client::{MetaClient, MetaClientBuilder};
|
||||
use meta_client::MetaClientOpts;
|
||||
use partition::manager::PartitionRuleManager;
|
||||
use partition::route::TableRoutes;
|
||||
use query::query_engine::options::QueryOptions;
|
||||
use servers::error as server_error;
|
||||
use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef};
|
||||
use servers::promql::{PromqlHandler, PromqlHandlerRef};
|
||||
@@ -58,12 +63,12 @@ use sql::statements::statement::Statement;
|
||||
use crate::catalog::FrontendCatalogManager;
|
||||
use crate::datanode::DatanodeClients;
|
||||
use crate::error::{
|
||||
self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, Result,
|
||||
self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, ParseSqlSnafu,
|
||||
Result, SqlExecInterceptedSnafu,
|
||||
};
|
||||
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
|
||||
use crate::frontend::FrontendOptions;
|
||||
use crate::instance::standalone::{StandaloneGrpcQueryHandler, StandaloneSqlQueryHandler};
|
||||
use crate::Plugins;
|
||||
|
||||
#[async_trait]
|
||||
pub trait FrontendInstance:
|
||||
@@ -101,7 +106,10 @@ pub struct Instance {
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
pub async fn try_new_distributed(opts: &FrontendOptions) -> Result<Self> {
|
||||
pub async fn try_new_distributed(
|
||||
opts: &FrontendOptions,
|
||||
plugins: Arc<Plugins>,
|
||||
) -> Result<Self> {
|
||||
let meta_client = Self::create_meta_client(opts).await?;
|
||||
|
||||
let meta_backend = Arc::new(MetaKvBackend {
|
||||
@@ -117,8 +125,12 @@ impl Instance {
|
||||
datanode_clients.clone(),
|
||||
));
|
||||
|
||||
let dist_instance =
|
||||
DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients);
|
||||
let dist_instance = DistInstance::new(
|
||||
meta_client,
|
||||
catalog_manager.clone(),
|
||||
datanode_clients,
|
||||
plugins.clone(),
|
||||
);
|
||||
let dist_instance = Arc::new(dist_instance);
|
||||
|
||||
Ok(Instance {
|
||||
@@ -128,7 +140,7 @@ impl Instance {
|
||||
sql_handler: dist_instance.clone(),
|
||||
grpc_query_handler: dist_instance,
|
||||
promql_handler: None,
|
||||
plugins: Default::default(),
|
||||
plugins,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -365,11 +377,12 @@ impl FrontendInstance for Instance {
|
||||
}
|
||||
|
||||
fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
|
||||
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(error::ParseSqlSnafu)
|
||||
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
|
||||
}
|
||||
|
||||
impl Instance {
|
||||
async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
check_permission(self.plugins.clone(), &stmt, &query_ctx)?;
|
||||
match stmt {
|
||||
Statement::CreateDatabase(_)
|
||||
| Statement::ShowDatabases(_)
|
||||
@@ -521,6 +534,87 @@ impl PromqlHandler for Instance {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_permission(
|
||||
plugins: Arc<Plugins>,
|
||||
stmt: &Statement,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> Result<()> {
|
||||
let need_validate = plugins
|
||||
.get::<QueryOptions>()
|
||||
.map(|opts| opts.disallow_cross_schema_query)
|
||||
.unwrap_or_default();
|
||||
|
||||
if !need_validate {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match stmt {
|
||||
// query and explain will be checked in QueryEngineState
|
||||
Statement::Query(_) | Statement::Explain(_) => {}
|
||||
// database ops won't be checked
|
||||
Statement::CreateDatabase(_) | Statement::ShowDatabases(_) | Statement::Use(_) => {}
|
||||
// show create table and alter are not supported yet
|
||||
Statement::ShowCreateTable(_) | Statement::Alter(_) => {}
|
||||
|
||||
Statement::Insert(insert) => {
|
||||
let (catalog, schema, _) = insert.full_table_name().context(ParseSqlSnafu)?;
|
||||
validate_param(&catalog, &schema, query_ctx)?;
|
||||
}
|
||||
Statement::CreateTable(stmt) => {
|
||||
let tab_ref = obj_name_to_tab_ref(&stmt.name)?;
|
||||
validate_tab_ref(tab_ref, query_ctx)?;
|
||||
}
|
||||
Statement::DropTable(drop_stmt) => {
|
||||
let tab_ref = obj_name_to_tab_ref(drop_stmt.table_name())?;
|
||||
validate_tab_ref(tab_ref, query_ctx)?;
|
||||
}
|
||||
Statement::ShowTables(stmt) => {
|
||||
if let Some(database) = &stmt.database {
|
||||
validate_param(&query_ctx.current_catalog(), database, query_ctx)?;
|
||||
}
|
||||
}
|
||||
Statement::DescribeTable(stmt) => {
|
||||
let tab_ref = obj_name_to_tab_ref(stmt.name())?;
|
||||
validate_tab_ref(tab_ref, query_ctx)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn obj_name_to_tab_ref(obj: &ObjectName) -> Result<TableReference> {
|
||||
match &obj.0[..] {
|
||||
[table] => Ok(TableReference::Bare {
|
||||
table: &table.value,
|
||||
}),
|
||||
[schema, table] => Ok(TableReference::Partial {
|
||||
schema: &schema.value,
|
||||
table: &table.value,
|
||||
}),
|
||||
[catalog, schema, table] => Ok(TableReference::Full {
|
||||
catalog: &catalog.value,
|
||||
schema: &schema.value,
|
||||
table: &table.value,
|
||||
}),
|
||||
_ => error::InvalidSqlSnafu {
|
||||
err_msg: format!(
|
||||
"expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {obj}",
|
||||
),
|
||||
}.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_tab_ref(tab_ref: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
|
||||
query::query_engine::options::validate_table_references(tab_ref, query_ctx)
|
||||
.map_err(BoxedError::new)
|
||||
.context(SqlExecInterceptedSnafu)
|
||||
}
|
||||
|
||||
fn validate_param(catalog: &str, schema: &str, query_ctx: &QueryContextRef) -> Result<()> {
|
||||
query::query_engine::options::validate_catalog_and_schema(catalog, schema, query_ctx)
|
||||
.map_err(BoxedError::new)
|
||||
.context(SqlExecInterceptedSnafu)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::borrow::Cow;
|
||||
@@ -528,13 +622,124 @@ mod tests {
|
||||
use std::sync::atomic::AtomicU32;
|
||||
|
||||
use catalog::helper::{TableGlobalKey, TableGlobalValue};
|
||||
use query::query_engine::options::QueryOptions;
|
||||
use session::context::QueryContext;
|
||||
use strfmt::Format;
|
||||
|
||||
use super::*;
|
||||
use crate::table::DistTable;
|
||||
use crate::tests;
|
||||
use crate::tests::MockDistributedInstance;
|
||||
|
||||
#[test]
|
||||
fn test_exec_validation() {
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let mut plugins = Plugins::new();
|
||||
plugins.insert(QueryOptions {
|
||||
disallow_cross_schema_query: true,
|
||||
});
|
||||
let plugins = Arc::new(plugins);
|
||||
|
||||
let sql = r#"
|
||||
SELECT * FROM demo;
|
||||
EXPLAIN SELECT * FROM demo;
|
||||
CREATE DATABASE test_database;
|
||||
SHOW DATABASES;
|
||||
"#;
|
||||
let stmts = parse_stmt(sql).unwrap();
|
||||
assert_eq!(stmts.len(), 4);
|
||||
for stmt in stmts {
|
||||
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
let sql = r#"
|
||||
SHOW CREATE TABLE demo;
|
||||
ALTER TABLE demo ADD COLUMN new_col INT;
|
||||
"#;
|
||||
let stmts = parse_stmt(sql).unwrap();
|
||||
assert_eq!(stmts.len(), 2);
|
||||
for stmt in stmts {
|
||||
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
let sql = "USE randomschema";
|
||||
let stmts = parse_stmt(sql).unwrap();
|
||||
let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
|
||||
assert!(re.is_ok());
|
||||
|
||||
fn replace_test(template_sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef) {
|
||||
// test right
|
||||
let right = vec![("", ""), ("", "public."), ("greptime.", "public.")];
|
||||
for (catalog, schema) in right {
|
||||
let sql = do_fmt(template_sql, catalog, schema);
|
||||
do_test(&sql, plugins.clone(), query_ctx, true);
|
||||
}
|
||||
|
||||
let wrong = vec![
|
||||
("", "wrongschema."),
|
||||
("greptime.", "wrongschema."),
|
||||
("wrongcatalog.", "public."),
|
||||
("wrongcatalog.", "wrongschema."),
|
||||
];
|
||||
for (catalog, schema) in wrong {
|
||||
let sql = do_fmt(template_sql, catalog, schema);
|
||||
do_test(&sql, plugins.clone(), query_ctx, false);
|
||||
}
|
||||
}
|
||||
|
||||
fn do_fmt(template: &str, catalog: &str, schema: &str) -> String {
|
||||
let mut vars = HashMap::new();
|
||||
vars.insert("catalog".to_string(), catalog);
|
||||
vars.insert("schema".to_string(), schema);
|
||||
|
||||
template.format(&vars).unwrap()
|
||||
}
|
||||
|
||||
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
|
||||
let stmt = &parse_stmt(sql).unwrap()[0];
|
||||
let re = check_permission(plugins.clone(), stmt, query_ctx);
|
||||
if is_ok {
|
||||
assert!(re.is_ok());
|
||||
} else {
|
||||
assert!(re.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
// test insert
|
||||
let sql = "INSERT INTO {catalog}{schema}monitor(host) VALUES ('host1');";
|
||||
replace_test(sql, plugins.clone(), &query_ctx);
|
||||
|
||||
// test create table
|
||||
let sql = r#"CREATE TABLE {catalog}{schema}demo(
|
||||
host STRING,
|
||||
ts TIMESTAMP,
|
||||
TIME INDEX (ts),
|
||||
PRIMARY KEY(host)
|
||||
) engine=mito with(regions=1);"#;
|
||||
replace_test(sql, plugins.clone(), &query_ctx);
|
||||
|
||||
// test drop table
|
||||
let sql = "DROP TABLE {catalog}{schema}demo;";
|
||||
replace_test(sql, plugins.clone(), &query_ctx);
|
||||
|
||||
// test show tables
|
||||
let sql = "SHOW TABLES FROM public";
|
||||
let stmt = parse_stmt(sql).unwrap();
|
||||
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
|
||||
assert!(re.is_ok());
|
||||
|
||||
let sql = "SHOW TABLES FROM wrongschema";
|
||||
let stmt = parse_stmt(sql).unwrap();
|
||||
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
|
||||
assert!(re.is_err());
|
||||
|
||||
// test describe table
|
||||
let sql = "DESC TABLE {catalog}{schema}demo;";
|
||||
replace_test(sql, plugins.clone(), &query_ctx);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_standalone_exec_sql() {
|
||||
let standalone = tests::create_standalone_instance("test_standalone_exec_sql").await;
|
||||
|
||||
@@ -26,6 +26,7 @@ use catalog::helper::{SchemaKey, SchemaValue};
|
||||
use catalog::{CatalogList, CatalogManager, DeregisterTableRequest, RegisterTableRequest};
|
||||
use chrono::DateTime;
|
||||
use client::Database;
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_query::Output;
|
||||
@@ -79,8 +80,11 @@ impl DistInstance {
|
||||
meta_client: Arc<MetaClient>,
|
||||
catalog_manager: Arc<FrontendCatalogManager>,
|
||||
datanode_clients: Arc<DatanodeClients>,
|
||||
plugins: Arc<Plugins>,
|
||||
) -> Self {
|
||||
let query_engine = QueryEngineFactory::new(catalog_manager.clone()).query_engine();
|
||||
let query_engine =
|
||||
QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone())
|
||||
.query_engine();
|
||||
Self {
|
||||
meta_client,
|
||||
catalog_manager,
|
||||
@@ -422,7 +426,7 @@ impl DistInstance {
|
||||
self.meta_client
|
||||
.create_route(request)
|
||||
.await
|
||||
.context(error::RequestMetaSnafu)
|
||||
.context(RequestMetaSnafu)
|
||||
}
|
||||
|
||||
// TODO(LFC): Refactor insertion implementation for DistTable,
|
||||
@@ -621,8 +625,7 @@ fn find_partition_entries(
|
||||
let v = match v {
|
||||
SqlValue::Number(n, _) if n == "MAXVALUE" => PartitionBound::MaxValue,
|
||||
_ => PartitionBound::Value(
|
||||
sql_value_to_value(column_name, data_type, v)
|
||||
.context(error::ParseSqlSnafu)?,
|
||||
sql_value_to_value(column_name, data_type, v).context(ParseSqlSnafu)?,
|
||||
),
|
||||
};
|
||||
values.push(v);
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
|
||||
#![feature(assert_matches)]
|
||||
|
||||
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
|
||||
|
||||
mod catalog;
|
||||
mod datanode;
|
||||
pub mod error;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_base::Plugins;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use common_telemetry::info;
|
||||
use servers::auth::UserProviderRef;
|
||||
@@ -37,7 +38,6 @@ use crate::frontend::FrontendOptions;
|
||||
use crate::influxdb::InfluxdbOptions;
|
||||
use crate::instance::FrontendInstance;
|
||||
use crate::prometheus::PrometheusOptions;
|
||||
use crate::Plugins;
|
||||
|
||||
pub(crate) struct Services;
|
||||
|
||||
|
||||
@@ -257,6 +257,7 @@ pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistribu
|
||||
meta_client.clone(),
|
||||
catalog_manager,
|
||||
datanode_clients.clone(),
|
||||
Default::default(),
|
||||
);
|
||||
let dist_instance = Arc::new(dist_instance);
|
||||
let frontend = Instance::new_distributed(dist_instance.clone());
|
||||
|
||||
@@ -797,7 +797,7 @@ mod test {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let query_engine_state = QueryEngineState::new(catalog_list);
|
||||
let query_engine_state = QueryEngineState::new(catalog_list, Default::default());
|
||||
let query_context = QueryContext::new();
|
||||
DfContextProviderAdapter::new(query_engine_state, query_context.into())
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ license.workspace = true
|
||||
arc-swap = "1.0"
|
||||
async-trait = "0.1"
|
||||
catalog = { path = "../catalog" }
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-function = { path = "../common/function" }
|
||||
|
||||
@@ -22,6 +22,7 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_function::scalars::udf::create_udf;
|
||||
@@ -60,9 +61,9 @@ pub(crate) struct DatafusionQueryEngine {
|
||||
}
|
||||
|
||||
impl DatafusionQueryEngine {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
Self {
|
||||
state: QueryEngineState::new(catalog_list.clone()),
|
||||
state: QueryEngineState::new(catalog_list.clone(), plugins),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,9 @@ pub enum Error {
|
||||
#[snafu(display("Failure during query parsing, query: {}, source: {}", query, source))]
|
||||
QueryParse { query: String, source: BoxedError },
|
||||
|
||||
#[snafu(display("Illegal access to catalog: {} and schema: {}", catalog, schema))]
|
||||
QueryAccessDenied { catalog: String, schema: String },
|
||||
|
||||
#[snafu(display("The SQL string has multiple statements, query: {}", query))]
|
||||
MultipleStatements { query: String, backtrace: Backtrace },
|
||||
|
||||
@@ -83,6 +86,7 @@ impl ErrorExt for Error {
|
||||
| CatalogNotFound { .. }
|
||||
| SchemaNotFound { .. }
|
||||
| TableNotFound { .. } => StatusCode::InvalidArguments,
|
||||
QueryAccessDenied { .. } => StatusCode::AccessDenied,
|
||||
Catalog { source } => source.status_code(),
|
||||
VectorComputation { source } => source.status_code(),
|
||||
CreateRecordBatch { source } => source.status_code(),
|
||||
|
||||
@@ -13,12 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
mod context;
|
||||
pub mod options;
|
||||
mod state;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
|
||||
use common_query::physical_plan::PhysicalPlan;
|
||||
@@ -63,23 +65,29 @@ pub struct QueryEngineFactory {
|
||||
|
||||
impl QueryEngineFactory {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list));
|
||||
|
||||
for func in FUNCTION_REGISTRY.functions() {
|
||||
query_engine.register_function(func);
|
||||
}
|
||||
|
||||
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
|
||||
query_engine.register_aggregate_function(accumulator);
|
||||
}
|
||||
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, Default::default()));
|
||||
register_functions(&query_engine);
|
||||
Self { query_engine }
|
||||
}
|
||||
|
||||
pub fn new_with_plugins(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, plugins));
|
||||
register_functions(&query_engine);
|
||||
Self { query_engine }
|
||||
}
|
||||
|
||||
pub fn query_engine(&self) -> QueryEngineRef {
|
||||
self.query_engine.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryEngineFactory {
|
||||
pub fn query_engine(&self) -> QueryEngineRef {
|
||||
self.query_engine.clone()
|
||||
fn register_functions(query_engine: &Arc<DatafusionQueryEngine>) {
|
||||
for func in FUNCTION_REGISTRY.functions() {
|
||||
query_engine.register_function(func);
|
||||
}
|
||||
|
||||
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
|
||||
query_engine.register_aggregate_function(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
132
src/query/src/query_engine/options.rs
Normal file
132
src/query/src/query_engine/options.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use datafusion_common::TableReference;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::error::{QueryAccessDeniedSnafu, Result};
|
||||
|
||||
/// Query-engine settings that are stored in the plugin map and looked up by type.
///
/// The `Default` value leaves every restriction switched off.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct QueryOptions {
    /// When `true`, a table reference that names a schema or catalog other
    /// than the one selected in the query context is rejected.
    pub disallow_cross_schema_query: bool,
}
|
||||
|
||||
pub fn validate_catalog_and_schema(
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> Result<()> {
|
||||
ensure!(
|
||||
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: catalog.to_string(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_table_references(name: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
|
||||
match name {
|
||||
TableReference::Bare { .. } => Ok(()),
|
||||
TableReference::Partial { schema, .. } => {
|
||||
ensure!(
|
||||
schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: query_ctx.current_catalog(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
TableReference::Full {
|
||||
catalog, schema, ..
|
||||
} => {
|
||||
ensure!(
|
||||
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: catalog.to_string(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use session::context::QueryContext;

    use super::*;

    /// Table references are judged against the `greptime.public` context.
    #[test]
    fn test_validate_table_ref() {
        let context = Arc::new(QueryContext::with("greptime", "public"));

        // (reference under test, whether it must be accepted)
        let cases = vec![
            (
                TableReference::Bare {
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Partial {
                    schema: "public",
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Partial {
                    schema: "wrong_schema",
                    table: "table_name",
                },
                false,
            ),
            (
                TableReference::Full {
                    catalog: "greptime",
                    schema: "public",
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Full {
                    catalog: "wrong_catalog",
                    schema: "public",
                    table: "table_name",
                },
                false,
            ),
        ];

        for (table_ref, accepted) in cases {
            let re = validate_table_references(table_ref, &context);
            assert_eq!(re.is_ok(), accepted);
        }
    }

    /// Catalog/schema pairs are judged against the `greptime.public` context.
    #[test]
    fn test_validate_catalog_and_schema() {
        let context = Arc::new(QueryContext::with("greptime", "public"));

        // (catalog, schema, whether the pair must be accepted)
        let cases = vec![
            ("greptime", "public", true),
            ("greptime", "wrong_schema", false),
            ("wrong_catalog", "public", false),
            ("wrong_catalog", "wrong_schema", false),
        ];

        for (catalog, schema, accepted) in cases {
            let re = validate_catalog_and_schema(catalog, schema, &context);
            assert_eq!(re.is_ok(), accepted);
        }
    }
}
|
||||
@@ -18,6 +18,7 @@ use std::sync::{Arc, RwLock};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_query::physical_plan::{SessionContext, TaskContext};
|
||||
@@ -39,6 +40,7 @@ use session::context::QueryContextRef;
|
||||
|
||||
use crate::datafusion::DfCatalogListAdapter;
|
||||
use crate::optimizer::TypeConversionRule;
|
||||
use crate::query_engine::options::{validate_table_references, QueryOptions};
|
||||
|
||||
/// Query engine global state
|
||||
// TODO(yingwen): This QueryEngineState still relies on datafusion, maybe we can define a trait for it,
|
||||
@@ -49,6 +51,7 @@ pub struct QueryEngineState {
|
||||
df_context: SessionContext,
|
||||
catalog_list: CatalogListRef,
|
||||
aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
|
||||
plugins: Arc<Plugins>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for QueryEngineState {
|
||||
@@ -59,7 +62,7 @@ impl fmt::Debug for QueryEngineState {
|
||||
}
|
||||
|
||||
impl QueryEngineState {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
let runtime_env = Arc::new(RuntimeEnv::default());
|
||||
let session_config = SessionConfig::new()
|
||||
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME);
|
||||
@@ -78,6 +81,7 @@ impl QueryEngineState {
|
||||
df_context,
|
||||
catalog_list,
|
||||
aggregate_functions: Arc::new(RwLock::new(HashMap::new())),
|
||||
plugins,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +124,13 @@ impl QueryEngineState {
|
||||
name: TableReference,
|
||||
) -> DfResult<Arc<dyn TableSource>> {
|
||||
let state = self.df_context.state();
|
||||
|
||||
if let Some(opts) = self.plugins.get::<QueryOptions>() {
|
||||
if opts.disallow_cross_schema_query {
|
||||
validate_table_references(name, &query_ctx)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let TableReference::Bare { table } = name {
|
||||
let name = TableReference::Partial {
|
||||
schema: &query_ctx.current_schema(),
|
||||
|
||||
@@ -21,8 +21,9 @@ mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
|
||||
@@ -36,6 +37,7 @@ use datatypes::vectors::UInt32Vector;
|
||||
use query::error::{QueryExecutionSnafu, Result};
|
||||
use query::parser::QueryLanguageParser;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::options::QueryOptions;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use session::context::QueryContext;
|
||||
use snafu::ResultExt;
|
||||
@@ -107,9 +109,7 @@ async fn test_datafusion_query_engine() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udf() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
fn catalog_list() -> Result<Arc<MemoryCatalogManager>> {
|
||||
let catalog_list = catalog::local::new_memory_catalog_list()
|
||||
.map_err(BoxedError::new)
|
||||
.context(QueryExecutionSnafu)?;
|
||||
@@ -125,6 +125,39 @@ async fn test_udf() -> Result<()> {
|
||||
catalog_list
|
||||
.register_catalog(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
|
||||
.unwrap();
|
||||
Ok(catalog_list)
|
||||
}
|
||||
|
||||
/// With `disallow_cross_schema_query` enabled, planning a query against the
/// `public` schema succeeds while planning one against another schema fails.
#[test]
fn test_query_validate() -> Result<()> {
    common_telemetry::init_default_ut_logging();
    let catalog_list = catalog_list()?;

    // set plugins
    let plugins = {
        let mut plugins = Plugins::new();
        plugins.insert(QueryOptions {
            disallow_cross_schema_query: true,
        });
        Arc::new(plugins)
    };

    let engine = QueryEngineFactory::new_with_plugins(catalog_list, plugins).query_engine();

    // (sql text, whether plan creation must succeed)
    let cases = vec![
        ("select number from public.numbers", true),
        ("select number from wrongschema.numbers", false),
    ];
    for (sql, plannable) in cases {
        let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
        let re = engine.statement_to_plan(stmt, Arc::new(QueryContext::new()));
        assert_eq!(re.is_ok(), plannable);
    }

    Ok(())
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udf() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let catalog_list = catalog_list()?;
|
||||
|
||||
let factory = QueryEngineFactory::new(catalog_list);
|
||||
let engine = factory.query_engine();
|
||||
|
||||
Reference in New Issue
Block a user