From 9989a8c192a324a2b520b9be6aefdf3560d90e2e Mon Sep 17 00:00:00 2001
From: shuiyisong <113876041+shuiyisong@users.noreply.github.com>
Date: Thu, 9 Feb 2023 17:23:28 +0800
Subject: [PATCH] fix: check full table name during logical plan creation
 (#948)

---
 Cargo.lock                               |  11 +-
 src/cmd/Cargo.toml                       |   1 +
 src/cmd/src/frontend.rs                  |   5 +-
 src/cmd/src/standalone.rs                |   2 +-
 src/common/base/Cargo.toml               |   1 +
 src/common/base/src/lib.rs               |   2 +
 src/frontend/Cargo.toml                  |   2 +-
 src/frontend/src/frontend.rs             |   2 +-
 src/frontend/src/instance.rs             | 219 ++++++++++++++++++++++-
 src/frontend/src/instance/distributed.rs |  11 +-
 src/frontend/src/lib.rs                  |   2 -
 src/frontend/src/server.rs               |   2 +-
 src/frontend/src/tests.rs                |   1 +
 src/promql/src/planner.rs                |   2 +-
 src/query/Cargo.toml                     |   1 +
 src/query/src/datafusion.rs              |   5 +-
 src/query/src/error.rs                   |   4 +
 src/query/src/query_engine.rs            |  34 ++--
 src/query/src/query_engine/options.rs    | 132 ++++++++++++++
 src/query/src/query_engine/state.rs      |  13 +-
 src/query/tests/query_engine_test.rs     |  41 ++++-
 21 files changed, 451 insertions(+), 42 deletions(-)
 create mode 100644 src/query/src/query_engine/options.rs

diff --git a/Cargo.lock b/Cargo.lock
index c792e77f64..d5752bfb61 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1356,6 +1356,7 @@ dependencies = [
  "anymap",
  "build-data",
  "clap 3.2.23",
+ "common-base",
  "common-error",
  "common-telemetry",
  "datanode",
@@ -1396,6 +1397,7 @@ dependencies = [
 name = "common-base"
 version = "0.1.0"
 dependencies = [
+ "anymap",
  "bitvec",
  "bytes",
  "common-error",
@@ -2663,7 +2665,6 @@ dependencies = [
 name = "frontend"
 version = "0.1.0"
 dependencies = [
- "anymap",
  "api",
  "async-stream",
  "async-trait",
@@ -2702,6 +2703,7 @@ dependencies = [
  "snafu",
  "sql",
  "store-api",
+ "strfmt",
  "substrait 0.1.0",
  "table",
  "tempdir",
@@ -5454,6 +5456,7 @@ dependencies = [
  "arc-swap",
  "async-trait",
  "catalog",
+ "common-base",
  "common-catalog",
  "common-error",
  "common-function",
@@ -7116,6 +7119,12 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "strfmt"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b73de159e5e71c4c7579ea3041d3e765f46790555790c24489195554210f1fb4"
+
 [[package]]
 name = "string_cache"
 version = "0.8.4"
diff --git a/src/cmd/Cargo.toml b/src/cmd/Cargo.toml
index d8ba0850c0..b55960817f 100644
--- a/src/cmd/Cargo.toml
+++ b/src/cmd/Cargo.toml
@@ -12,6 +12,7 @@ path = "src/bin/greptime.rs"
 [dependencies]
 anymap = "1.0.0-beta.2"
 clap = { version = "3.1", features = ["derive"] }
+common-base = { path = "../common/base" }
 common-error = { path = "../common/error" }
 common-telemetry = { path = "../common/telemetry", features = [
     "deadlock_detection",
diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs
index 24e115988c..cf6587ca3d 100644
--- a/src/cmd/src/frontend.rs
+++ b/src/cmd/src/frontend.rs
@@ -15,6 +15,7 @@
 use std::sync::Arc;
 
 use clap::Parser;
+use common_base::Plugins;
 use frontend::frontend::{Frontend, FrontendOptions};
 use frontend::grpc::GrpcOptions;
 use frontend::influxdb::InfluxdbOptions;
@@ -22,7 +23,6 @@ use frontend::instance::Instance;
 use frontend::mysql::MysqlOptions;
 use frontend::opentsdb::OpentsdbOptions;
 use frontend::postgres::PostgresOptions;
-use frontend::Plugins;
 use meta_client::MetaClientOpts;
 use servers::auth::UserProviderRef;
 use servers::http::HttpOptions;
@@ -91,10 +91,9 @@ impl StartCommand {
         let plugins = Arc::new(load_frontend_plugins(&self.user_provider)?);
 
         let opts: FrontendOptions = self.try_into()?;
-        let mut instance = Instance::try_new_distributed(&opts)
+        let instance = Instance::try_new_distributed(&opts, plugins.clone())
             .await
             .context(error::StartFrontendSnafu)?;
-        instance.set_plugins(plugins.clone());
 
         let mut frontend = Frontend::new(opts, instance, plugins);
         frontend.start().await.context(error::StartFrontendSnafu)
diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs
index cb4efc072d..f3f1ea6ae5 100644
--- a/src/cmd/src/standalone.rs
+++ b/src/cmd/src/standalone.rs
@@ -15,6 +15,7 @@
 use std::sync::Arc;
 
 use clap::Parser;
+use common_base::Plugins;
 use common_telemetry::info;
 use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig, WalConfig};
 use datanode::instance::InstanceRef;
@@ -27,7 +28,6 @@ use frontend::opentsdb::OpentsdbOptions;
 use frontend::postgres::PostgresOptions;
 use frontend::prometheus::PrometheusOptions;
 use frontend::promql::PromqlOptions;
-use frontend::Plugins;
 use serde::{Deserialize, Serialize};
 use servers::http::HttpOptions;
 use servers::tls::{TlsMode, TlsOption};
diff --git a/src/common/base/Cargo.toml b/src/common/base/Cargo.toml
index a4e785ae0b..9af84ecd55 100644
--- a/src/common/base/Cargo.toml
+++ b/src/common/base/Cargo.toml
@@ -5,6 +5,7 @@ edition.workspace = true
 license.workspace = true
 
 [dependencies]
+anymap = "1.0.0-beta.2"
 bitvec = "1.0"
 bytes = { version = "1.1", features = ["serde"] }
 common-error = { path = "../error" }
diff --git a/src/common/base/src/lib.rs b/src/common/base/src/lib.rs
index e782b96967..cf3a19a771 100644
--- a/src/common/base/src/lib.rs
+++ b/src/common/base/src/lib.rs
@@ -19,3 +19,5 @@ pub mod bytes;
 pub mod readable_size;
 
 pub use bit_vec::BitVec;
+
+pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml
index 2738022252..716c51e539 100644
--- a/src/frontend/Cargo.toml
+++ b/src/frontend/Cargo.toml
@@ -5,7 +5,6 @@ edition.workspace = true
 license.workspace = true
 
 [dependencies]
-anymap = "1.0.0-beta.2"
 api = { path = "../api" }
 async-stream.workspace = true
 async-trait = "0.1"
@@ -52,5 +51,6 @@ tonic.workspace = true
 datanode = { path = "../datanode" }
 futures = "0.3"
 meta-srv = { path = "../meta-srv", features = ["mock"] }
+strfmt = "0.2"
 tempdir = "0.3"
 tower = "0.4"
diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs
index 899c5929bc..8cf492873b 100644
--- a/src/frontend/src/frontend.rs
+++ b/src/frontend/src/frontend.rs
@@ -14,6 +14,7 @@
 
 use std::sync::Arc;
 
+use common_base::Plugins;
 use meta_client::MetaClientOpts;
 use serde::{Deserialize, Serialize};
 use servers::http::HttpOptions;
@@ -30,7 +31,6 @@ use crate::postgres::PostgresOptions;
 use crate::prometheus::PrometheusOptions;
 use crate::promql::PromqlOptions;
 use crate::server::Services;
-use crate::Plugins;
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(default)]
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index b92c2ddfdc..501b87ff77 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -29,10 +29,14 @@ use api::v1::{AddColumns, AlterExpr, Column, DdlRequest, InsertRequest};
 use async_trait::async_trait;
 use catalog::remote::MetaKvBackend;
 use catalog::CatalogManagerRef;
+use common_base::Plugins;
+use common_error::ext::BoxedError;
 use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
 use common_query::Output;
 use common_recordbatch::RecordBatches;
 use common_telemetry::logging::{debug, info};
+use datafusion::sql::sqlparser::ast::ObjectName;
+use datafusion_common::TableReference;
 use datanode::instance::InstanceRef as DnInstanceRef;
 use datatypes::schema::Schema;
 use distributed::DistInstance;
@@ -40,6 +44,7 @@ use meta_client::client::{MetaClient, MetaClientBuilder};
 use meta_client::MetaClientOpts;
 use partition::manager::PartitionRuleManager;
 use partition::route::TableRoutes;
+use query::query_engine::options::QueryOptions;
 use servers::error as server_error;
 use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef};
 use servers::promql::{PromqlHandler, PromqlHandlerRef};
@@ -58,12 +63,12 @@ use sql::statements::statement::Statement;
 use crate::catalog::FrontendCatalogManager;
 use crate::datanode::DatanodeClients;
 use crate::error::{
-    self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, Result,
+    self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, ParseSqlSnafu,
+    Result, SqlExecInterceptedSnafu,
 };
 use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
 use crate::frontend::FrontendOptions;
 use crate::instance::standalone::{StandaloneGrpcQueryHandler, StandaloneSqlQueryHandler};
-use crate::Plugins;
 
 #[async_trait]
 pub trait FrontendInstance:
@@ -101,7 +106,10 @@ pub struct Instance {
 }
 
 impl Instance {
-    pub async fn try_new_distributed(opts: &FrontendOptions) -> Result<Self> {
+    pub async fn try_new_distributed(
+        opts: &FrontendOptions,
+        plugins: Arc<Plugins>,
+    ) -> Result<Self> {
         let meta_client = Self::create_meta_client(opts).await?;
 
         let meta_backend = Arc::new(MetaKvBackend {
@@ -117,8 +125,12 @@
             datanode_clients.clone(),
         ));
 
-        let dist_instance =
-            DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients);
+        let dist_instance = DistInstance::new(
+            meta_client,
+            catalog_manager.clone(),
+            datanode_clients,
+            plugins.clone(),
+        );
         let dist_instance = Arc::new(dist_instance);
 
         Ok(Instance {
@@ -128,7 +140,7 @@
             sql_handler: dist_instance.clone(),
             grpc_query_handler: dist_instance,
             promql_handler: None,
-            plugins: Default::default(),
+            plugins,
         })
     }
@@ -365,11 +377,12 @@ impl FrontendInstance for Instance {
 }
 
 fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
-    ParserContext::create_with_dialect(sql, &GenericDialect {}).context(error::ParseSqlSnafu)
+    ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
 }
 
 impl Instance {
     async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
+        check_permission(self.plugins.clone(), &stmt, &query_ctx)?;
         match stmt {
             Statement::CreateDatabase(_)
             | Statement::ShowDatabases(_)
@@ -521,6 +534,87 @@ impl PromqlHandler for Instance {
     }
 }
 
+pub fn check_permission(
+    plugins: Arc<Plugins>,
+    stmt: &Statement,
+    query_ctx: &QueryContextRef,
+) -> Result<()> {
+    let need_validate = plugins
+        .get::<QueryOptions>()
+        .map(|opts| opts.disallow_cross_schema_query)
+        .unwrap_or_default();
+
+    if !need_validate {
+        return Ok(());
+    }
+
+    match stmt {
+        // query and explain will be checked in QueryEngineState
+        Statement::Query(_) | Statement::Explain(_) => {}
+        // database ops won't be checked
+        Statement::CreateDatabase(_) | Statement::ShowDatabases(_) | Statement::Use(_) => {}
+        // show create table and alter are not supported yet
+        Statement::ShowCreateTable(_) | Statement::Alter(_) => {}
+
+        Statement::Insert(insert) => {
+            let (catalog, schema, _) = insert.full_table_name().context(ParseSqlSnafu)?;
+            validate_param(&catalog, &schema, query_ctx)?;
+        }
+        Statement::CreateTable(stmt) => {
+            let tab_ref = obj_name_to_tab_ref(&stmt.name)?;
+            validate_tab_ref(tab_ref, query_ctx)?;
+        }
+        Statement::DropTable(drop_stmt) => {
+            let tab_ref = obj_name_to_tab_ref(drop_stmt.table_name())?;
+            validate_tab_ref(tab_ref, query_ctx)?;
+        }
+        Statement::ShowTables(stmt) => {
+            if let Some(database) = &stmt.database {
+                validate_param(&query_ctx.current_catalog(), database, query_ctx)?;
+            }
+        }
+        Statement::DescribeTable(stmt) => {
+            let tab_ref = obj_name_to_tab_ref(stmt.name())?;
+            validate_tab_ref(tab_ref, query_ctx)?;
+        }
+    }
+    Ok(())
+}
+
+fn obj_name_to_tab_ref(obj: &ObjectName) -> Result<TableReference> {
+    match &obj.0[..] {
+        [table] => Ok(TableReference::Bare {
+            table: &table.value,
+        }),
+        [schema, table] => Ok(TableReference::Partial {
+            schema: &schema.value,
+            table: &table.value,
+        }),
+        [catalog, schema, table] => Ok(TableReference::Full {
+            catalog: &catalog.value,
+            schema: &schema.value,
+            table: &table.value,
+        }),
+        _ => error::InvalidSqlSnafu {
+            err_msg: format!(
+                "expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {obj}",
+            ),
+        }.fail(),
+    }
+}
+
+fn validate_tab_ref(tab_ref: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
+    query::query_engine::options::validate_table_references(tab_ref, query_ctx)
+        .map_err(BoxedError::new)
+        .context(SqlExecInterceptedSnafu)
+}
+
+fn validate_param(catalog: &str, schema: &str, query_ctx: &QueryContextRef) -> Result<()> {
+    query::query_engine::options::validate_catalog_and_schema(catalog, schema, query_ctx)
+        .map_err(BoxedError::new)
+        .context(SqlExecInterceptedSnafu)
+}
+
 #[cfg(test)]
 mod tests {
     use std::borrow::Cow;
@@ -528,13 +622,124 @@ mod tests {
     use std::sync::atomic::AtomicU32;
 
     use catalog::helper::{TableGlobalKey, TableGlobalValue};
+    use query::query_engine::options::QueryOptions;
     use session::context::QueryContext;
+    use strfmt::Format;
 
     use super::*;
     use crate::table::DistTable;
    use crate::tests;
     use crate::tests::MockDistributedInstance;
 
+    #[test]
+    fn test_exec_validation() {
+        let query_ctx = Arc::new(QueryContext::new());
+        let mut plugins = Plugins::new();
+        plugins.insert(QueryOptions {
+            disallow_cross_schema_query: true,
+        });
+        let plugins = Arc::new(plugins);
+
+        let sql = r#"
+        SELECT * FROM demo;
+        EXPLAIN SELECT * FROM demo;
+        CREATE DATABASE test_database;
+        SHOW DATABASES;
+        "#;
+        let stmts = parse_stmt(sql).unwrap();
+        assert_eq!(stmts.len(), 4);
+        for stmt in stmts {
+            let re = check_permission(plugins.clone(), &stmt, &query_ctx);
+            assert!(re.is_ok());
+        }
+
+        let sql = r#"
+        SHOW CREATE TABLE demo;
+        ALTER TABLE demo ADD COLUMN new_col INT;
+        "#;
+        let stmts = parse_stmt(sql).unwrap();
+        assert_eq!(stmts.len(), 2);
+        for stmt in stmts {
+            let re = check_permission(plugins.clone(), &stmt, &query_ctx);
+            assert!(re.is_ok());
+        }
+
+        let sql = "USE randomschema";
+        let stmts = parse_stmt(sql).unwrap();
+        let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
+        assert!(re.is_ok());
+
+        fn replace_test(template_sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef) {
+            // test right
+            let right = vec![("", ""), ("", "public."), ("greptime.", "public.")];
+            for (catalog, schema) in right {
+                let sql = do_fmt(template_sql, catalog, schema);
+                do_test(&sql, plugins.clone(), query_ctx, true);
+            }
+
+            let wrong = vec![
+                ("", "wrongschema."),
+                ("greptime.", "wrongschema."),
+                ("wrongcatalog.", "public."),
+                ("wrongcatalog.", "wrongschema."),
+            ];
+            for (catalog, schema) in wrong {
+                let sql = do_fmt(template_sql, catalog, schema);
+                do_test(&sql, plugins.clone(), query_ctx, false);
+            }
+        }
+
+        fn do_fmt(template: &str, catalog: &str, schema: &str) -> String {
+            let mut vars = HashMap::new();
+            vars.insert("catalog".to_string(), catalog);
+            vars.insert("schema".to_string(), schema);
+
+            template.format(&vars).unwrap()
+        }
+
+        fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
+            let stmt = &parse_stmt(sql).unwrap()[0];
+            let re = check_permission(plugins.clone(), stmt, query_ctx);
+            if is_ok {
+                assert!(re.is_ok());
+            } else {
+                assert!(re.is_err());
+            }
+        }
+
+        // test insert
+        let sql = "INSERT INTO {catalog}{schema}monitor(host) VALUES ('host1');";
+        replace_test(sql, plugins.clone(), &query_ctx);
+
+        // test create table
+        let sql = r#"CREATE TABLE {catalog}{schema}demo(
+                        host STRING,
+                        ts TIMESTAMP,
+                        TIME INDEX (ts),
+                        PRIMARY KEY(host)
+                     ) engine=mito with(regions=1);"#;
+        replace_test(sql, plugins.clone(), &query_ctx);
+
+        // test drop table
+        let sql = "DROP TABLE {catalog}{schema}demo;";
+        replace_test(sql, plugins.clone(), &query_ctx);
+
+        // test show tables
+        let sql = "SHOW TABLES FROM public";
+        let stmt = parse_stmt(sql).unwrap();
+        let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
+        assert!(re.is_ok());
+
+        let sql = "SHOW TABLES FROM wrongschema";
+        let stmt = parse_stmt(sql).unwrap();
+        let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
+        assert!(re.is_err());
+
+        // test describe table
+        let sql = "DESC TABLE {catalog}{schema}demo;";
+        replace_test(sql, plugins.clone(), &query_ctx);
+    }
+
     #[tokio::test(flavor = "multi_thread")]
     async fn test_standalone_exec_sql() {
         let standalone = tests::create_standalone_instance("test_standalone_exec_sql").await;
diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs
index 5398762006..d504b05d6e 100644
--- a/src/frontend/src/instance/distributed.rs
+++ b/src/frontend/src/instance/distributed.rs
@@ -26,6 +26,7 @@ use catalog::helper::{SchemaKey, SchemaValue};
 use catalog::{CatalogList, CatalogManager, DeregisterTableRequest, RegisterTableRequest};
 use chrono::DateTime;
 use client::Database;
+use common_base::Plugins;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_error::prelude::BoxedError;
 use common_query::Output;
@@ -79,8 +80,11 @@ impl DistInstance {
         meta_client: Arc<MetaClient>,
         catalog_manager: Arc<FrontendCatalogManager>,
         datanode_clients: Arc<DatanodeClients>,
+        plugins: Arc<Plugins>,
     ) -> Self {
-        let query_engine = QueryEngineFactory::new(catalog_manager.clone()).query_engine();
+        let query_engine =
+            QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone())
+                .query_engine();
         Self {
             meta_client,
             catalog_manager,
@@ -422,7 +426,7 @@ impl DistInstance {
         self.meta_client
             .create_route(request)
             .await
-            .context(error::RequestMetaSnafu)
+            .context(RequestMetaSnafu)
     }
 
     // TODO(LFC): Refactor insertion implementation for DistTable,
@@ -621,8 +625,7 @@ fn find_partition_entries(
             let v = match v {
                 SqlValue::Number(n, _) if n == "MAXVALUE" => PartitionBound::MaxValue,
                 _ => PartitionBound::Value(
-                    sql_value_to_value(column_name, data_type, v)
-                        .context(error::ParseSqlSnafu)?,
+                    sql_value_to_value(column_name, data_type, v).context(ParseSqlSnafu)?,
                 ),
             };
             values.push(v);
diff --git a/src/frontend/src/lib.rs b/src/frontend/src/lib.rs
index 36fa94ec1f..ea9fd25d70 100644
--- a/src/frontend/src/lib.rs
+++ b/src/frontend/src/lib.rs
@@ -14,8 +14,6 @@
 
 #![feature(assert_matches)]
 
-pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
-
 mod catalog;
 mod datanode;
 pub mod error;
diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs
index daff47ed26..2b07fc6fd0 100644
--- a/src/frontend/src/server.rs
+++ b/src/frontend/src/server.rs
@@ -15,6 +15,7 @@
 use std::net::SocketAddr;
 use std::sync::Arc;
 
+use common_base::Plugins;
 use common_runtime::Builder as RuntimeBuilder;
 use common_telemetry::info;
 use servers::auth::UserProviderRef;
@@ -37,7 +38,6 @@ use crate::frontend::FrontendOptions;
 use crate::influxdb::InfluxdbOptions;
 use crate::instance::FrontendInstance;
 use crate::prometheus::PrometheusOptions;
-use crate::Plugins;
 
 pub(crate) struct Services;
diff --git a/src/frontend/src/tests.rs b/src/frontend/src/tests.rs
index e2d5e116c9..d4a54b7886 100644
--- a/src/frontend/src/tests.rs
+++ b/src/frontend/src/tests.rs
@@ -257,6 +257,7 @@ pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistribu
         meta_client.clone(),
         catalog_manager,
         datanode_clients.clone(),
+        Default::default(),
     );
     let dist_instance = Arc::new(dist_instance);
     let frontend = Instance::new_distributed(dist_instance.clone());
diff --git a/src/promql/src/planner.rs b/src/promql/src/planner.rs
index 1b749e9142..72c76402e9 100644
--- a/src/promql/src/planner.rs
+++ b/src/promql/src/planner.rs
@@ -797,7 +797,7 @@ mod test {
             .await
             .unwrap();
 
-        let query_engine_state = QueryEngineState::new(catalog_list);
+        let query_engine_state = QueryEngineState::new(catalog_list, Default::default());
         let query_context = QueryContext::new();
         DfContextProviderAdapter::new(query_engine_state, query_context.into())
     }
diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml
index 2ba78232d4..031622c6d1 100644
--- a/src/query/Cargo.toml
+++ b/src/query/Cargo.toml
@@ -8,6 +8,7 @@ license.workspace = true
 arc-swap = "1.0"
 async-trait = "0.1"
 catalog = { path = "../catalog" }
+common-base = { path = "../common/base" }
 common-catalog = { path = "../common/catalog" }
 common-error = { path = "../common/error" }
 common-function = { path = "../common/function" }
diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs
index 813892becf..9730c4d62d 100644
--- a/src/query/src/datafusion.rs
+++ b/src/query/src/datafusion.rs
@@ -22,6 +22,7 @@ use std::sync::Arc;
 
 use async_trait::async_trait;
 use catalog::CatalogListRef;
+use common_base::Plugins;
 use common_error::prelude::BoxedError;
 use common_function::scalars::aggregate::AggregateFunctionMetaRef;
 use common_function::scalars::udf::create_udf;
@@ -60,9 +61,9 @@ pub(crate) struct DatafusionQueryEngine {
 }
 
 impl DatafusionQueryEngine {
-    pub fn new(catalog_list: CatalogListRef) -> Self {
+    pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
         Self {
-            state: QueryEngineState::new(catalog_list.clone()),
+            state: QueryEngineState::new(catalog_list.clone(), plugins),
         }
     }
diff --git a/src/query/src/error.rs b/src/query/src/error.rs
index 26b88107de..e9d5058daf 100644
--- a/src/query/src/error.rs
+++ b/src/query/src/error.rs
@@ -66,6 +66,9 @@ pub enum Error {
     #[snafu(display("Failure during query parsing, query: {}, source: {}", query, source))]
     QueryParse { query: String, source: BoxedError },
 
+    #[snafu(display("Illegal access to catalog: {} and schema: {}", catalog, schema))]
+    QueryAccessDenied { catalog: String, schema: String },
+
     #[snafu(display("The SQL string has multiple statements, query: {}", query))]
     MultipleStatements { query: String, backtrace: Backtrace },
 
@@ -83,6 +86,7 @@ impl ErrorExt for Error {
             | CatalogNotFound { .. }
             | SchemaNotFound { .. }
             | TableNotFound { .. } => StatusCode::InvalidArguments,
+            QueryAccessDenied { .. } => StatusCode::AccessDenied,
             Catalog { source } => source.status_code(),
             VectorComputation { source } => source.status_code(),
             CreateRecordBatch { source } => source.status_code(),
diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs
index 6a7e748aab..3d4adfda83 100644
--- a/src/query/src/query_engine.rs
+++ b/src/query/src/query_engine.rs
@@ -13,12 +13,14 @@
 // limitations under the License.
 
 mod context;
+pub mod options;
 mod state;
 
 use std::sync::Arc;
 
 use async_trait::async_trait;
 use catalog::CatalogListRef;
+use common_base::Plugins;
 use common_function::scalars::aggregate::AggregateFunctionMetaRef;
 use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
 use common_query::physical_plan::PhysicalPlan;
@@ -63,23 +65,29 @@ pub struct QueryEngineFactory {
 
 impl QueryEngineFactory {
     pub fn new(catalog_list: CatalogListRef) -> Self {
-        let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list));
-
-        for func in FUNCTION_REGISTRY.functions() {
-            query_engine.register_function(func);
-        }
-
-        for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
-            query_engine.register_aggregate_function(accumulator);
-        }
-
+        let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, Default::default()));
+        register_functions(&query_engine);
         Self { query_engine }
     }
+
+    pub fn new_with_plugins(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
+        let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, plugins));
+        register_functions(&query_engine);
+        Self { query_engine }
+    }
+
+    pub fn query_engine(&self) -> QueryEngineRef {
+        self.query_engine.clone()
+    }
 }
 
-impl QueryEngineFactory {
-    pub fn query_engine(&self) -> QueryEngineRef {
-        self.query_engine.clone()
+fn register_functions(query_engine: &Arc<DatafusionQueryEngine>) {
+    for func in FUNCTION_REGISTRY.functions() {
+        query_engine.register_function(func);
+    }
+
+    for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
+        query_engine.register_aggregate_function(accumulator);
     }
 }
diff --git a/src/query/src/query_engine/options.rs b/src/query/src/query_engine/options.rs
new file mode 100644
index 0000000000..ec5bf2e8d2
--- /dev/null
+++ b/src/query/src/query_engine/options.rs
@@ -0,0 +1,132 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use datafusion_common::TableReference;
+use session::context::QueryContextRef;
+use snafu::ensure;
+
+use crate::error::{QueryAccessDeniedSnafu, Result};
+
+#[derive(Default)]
+pub struct QueryOptions {
+    pub disallow_cross_schema_query: bool,
+}
+
+pub fn validate_catalog_and_schema(
+    catalog: &str,
+    schema: &str,
+    query_ctx: &QueryContextRef,
+) -> Result<()> {
+    ensure!(
+        catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
+        QueryAccessDeniedSnafu {
+            catalog: catalog.to_string(),
+            schema: schema.to_string(),
+        }
+    );
+
+    Ok(())
+}
+
+pub fn validate_table_references(name: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
+    match name {
+        TableReference::Bare { .. } => Ok(()),
+        TableReference::Partial { schema, .. } => {
+            ensure!(
+                schema == query_ctx.current_schema(),
+                QueryAccessDeniedSnafu {
+                    catalog: query_ctx.current_catalog(),
+                    schema: schema.to_string(),
+                }
+            );
+            Ok(())
+        }
+        TableReference::Full {
+            catalog, schema, ..
+        } => {
+            ensure!(
+                catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
+                QueryAccessDeniedSnafu {
+                    catalog: catalog.to_string(),
+                    schema: schema.to_string(),
+                }
+            );
+            Ok(())
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use session::context::QueryContext;
+
+    use super::*;
+
+    #[test]
+    fn test_validate_table_ref() {
+        let context = Arc::new(QueryContext::with("greptime", "public"));
+
+        let table_ref = TableReference::Bare {
+            table: "table_name",
+        };
+        let re = validate_table_references(table_ref, &context);
+        assert!(re.is_ok());
+
+        let table_ref = TableReference::Partial {
+            schema: "public",
+            table: "table_name",
+        };
+        let re = validate_table_references(table_ref, &context);
+        assert!(re.is_ok());
+
+        let table_ref = TableReference::Partial {
+            schema: "wrong_schema",
+            table: "table_name",
+        };
+        let re = validate_table_references(table_ref, &context);
+        assert!(re.is_err());
+
+        let table_ref = TableReference::Full {
+            catalog: "greptime",
+            schema: "public",
+            table: "table_name",
+        };
+        let re = validate_table_references(table_ref, &context);
+        assert!(re.is_ok());
+
+        let table_ref = TableReference::Full {
+            catalog: "wrong_catalog",
+            schema: "public",
+            table: "table_name",
+        };
+        let re = validate_table_references(table_ref, &context);
+        assert!(re.is_err());
+    }
+
+    #[test]
+    fn test_validate_catalog_and_schema() {
+        let context = Arc::new(QueryContext::with("greptime", "public"));
+
+        let re = validate_catalog_and_schema("greptime", "public", &context);
+        assert!(re.is_ok());
+        let re = validate_catalog_and_schema("greptime", "wrong_schema", &context);
+        assert!(re.is_err());
+        let re = validate_catalog_and_schema("wrong_catalog", "public", &context);
+        assert!(re.is_err());
+        let re = validate_catalog_and_schema("wrong_catalog", "wrong_schema", &context);
+        assert!(re.is_err());
+    }
+}
diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs
index 0e0921a038..9f6246e07e 100644
--- a/src/query/src/query_engine/state.rs
+++ b/src/query/src/query_engine/state.rs
@@ -18,6 +18,7 @@ use std::sync::{Arc, RwLock};
 
 use async_trait::async_trait;
 use catalog::CatalogListRef;
+use common_base::Plugins;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_function::scalars::aggregate::AggregateFunctionMetaRef;
 use common_query::physical_plan::{SessionContext, TaskContext};
@@ -39,6 +40,7 @@ use session::context::QueryContextRef;
 
 use crate::datafusion::DfCatalogListAdapter;
 use crate::optimizer::TypeConversionRule;
+use crate::query_engine::options::{validate_table_references, QueryOptions};
 
 /// Query engine global state
 // TODO(yingwen): This QueryEngineState still relies on datafusion, maybe we can define a trait for it,
@@ -49,6 +51,7 @@ pub struct QueryEngineState {
     df_context: SessionContext,
     catalog_list: CatalogListRef,
     aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
+    plugins: Arc<Plugins>,
 }
 
 impl fmt::Debug for QueryEngineState {
@@ -59,7 +62,7 @@ impl fmt::Debug for QueryEngineState {
 }
 
 impl QueryEngineState {
-    pub fn new(catalog_list: CatalogListRef) -> Self {
+    pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
         let runtime_env = Arc::new(RuntimeEnv::default());
         let session_config = SessionConfig::new()
             .with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME);
@@ -78,6 +81,7 @@ impl QueryEngineState {
             df_context,
             catalog_list,
             aggregate_functions: Arc::new(RwLock::new(HashMap::new())),
+            plugins,
         }
     }
@@ -120,6 +124,13 @@ impl QueryEngineState {
         name: TableReference,
     ) -> DfResult<Arc<dyn TableProvider>> {
         let state = self.df_context.state();
+
+        if let Some(opts) = self.plugins.get::<QueryOptions>() {
+            if opts.disallow_cross_schema_query {
+                validate_table_references(name, &query_ctx)?;
+            }
+        }
+
         if let TableReference::Bare { table } = name {
             let name = TableReference::Partial {
                 schema: &query_ctx.current_schema(),
diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs
index d479805602..9c642427f8 100644
--- a/src/query/tests/query_engine_test.rs
+++ b/src/query/tests/query_engine_test.rs
@@ -21,8 +21,9 @@ mod function;
 
 use std::sync::Arc;
 
-use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
+use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
 use catalog::{CatalogList, CatalogProvider, SchemaProvider};
+use common_base::Plugins;
 use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_error::prelude::BoxedError;
 use common_query::prelude::{create_udf, make_scalar_function, Volatility};
@@ -36,6 +37,7 @@ use datatypes::vectors::UInt32Vector;
 use query::error::{QueryExecutionSnafu, Result};
 use query::parser::QueryLanguageParser;
 use query::plan::LogicalPlan;
+use query::query_engine::options::QueryOptions;
 use query::query_engine::QueryEngineFactory;
 use session::context::QueryContext;
 use snafu::ResultExt;
@@ -107,9 +109,7 @@ async fn test_datafusion_query_engine() -> Result<()> {
     Ok(())
 }
 
-#[tokio::test]
-async fn test_udf() -> Result<()> {
-    common_telemetry::init_default_ut_logging();
+fn catalog_list() -> Result<Arc<MemoryCatalogManager>> {
     let catalog_list = catalog::local::new_memory_catalog_list()
         .map_err(BoxedError::new)
         .context(QueryExecutionSnafu)?;
@@ -125,6 +125,39 @@
     catalog_list
         .register_catalog(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
         .unwrap();
+    Ok(catalog_list)
+}
+
+#[test]
+fn test_query_validate() -> Result<()> {
+    common_telemetry::init_default_ut_logging();
+    let catalog_list = catalog_list()?;
+
+    // set plugins
+    let mut plugins = Plugins::new();
+    plugins.insert(QueryOptions {
+        disallow_cross_schema_query: true,
+    });
+    let plugins = Arc::new(plugins);
+
+    let factory = QueryEngineFactory::new_with_plugins(catalog_list, plugins);
+    let engine = factory.query_engine();
+
+    let stmt = QueryLanguageParser::parse_sql("select number from public.numbers").unwrap();
+    let re = engine.statement_to_plan(stmt, Arc::new(QueryContext::new()));
+    assert!(re.is_ok());
+
+    let stmt = QueryLanguageParser::parse_sql("select number from wrongschema.numbers").unwrap();
+    let re = engine.statement_to_plan(stmt, Arc::new(QueryContext::new()));
+    assert!(re.is_err());
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_udf() -> Result<()> {
+    common_telemetry::init_default_ut_logging();
+    let catalog_list = catalog_list()?;
 
     let factory = QueryEngineFactory::new(catalog_list);
     let engine = factory.query_engine();