fix: check full table name during logical plan creation (#948)

This commit is contained in:
shuiyisong
2023-02-09 17:23:28 +08:00
committed by GitHub
parent 19dd8b1246
commit 9989a8c192
21 changed files with 451 additions and 42 deletions

View File

@@ -5,7 +5,6 @@ edition.workspace = true
license.workspace = true
[dependencies]
anymap = "1.0.0-beta.2"
api = { path = "../api" }
async-stream.workspace = true
async-trait = "0.1"
@@ -52,5 +51,6 @@ tonic.workspace = true
datanode = { path = "../datanode" }
futures = "0.3"
meta-srv = { path = "../meta-srv", features = ["mock"] }
strfmt = "0.2"
tempdir = "0.3"
tower = "0.4"

View File

@@ -14,6 +14,7 @@
use std::sync::Arc;
use common_base::Plugins;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::http::HttpOptions;
@@ -30,7 +31,6 @@ use crate::postgres::PostgresOptions;
use crate::prometheus::PrometheusOptions;
use crate::promql::PromqlOptions;
use crate::server::Services;
use crate::Plugins;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]

View File

@@ -29,10 +29,14 @@ use api::v1::{AddColumns, AlterExpr, Column, DdlRequest, InsertRequest};
use async_trait::async_trait;
use catalog::remote::MetaKvBackend;
use catalog::CatalogManagerRef;
use common_base::Plugins;
use common_error::ext::BoxedError;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_telemetry::logging::{debug, info};
use datafusion::sql::sqlparser::ast::ObjectName;
use datafusion_common::TableReference;
use datanode::instance::InstanceRef as DnInstanceRef;
use datatypes::schema::Schema;
use distributed::DistInstance;
@@ -40,6 +44,7 @@ use meta_client::client::{MetaClient, MetaClientBuilder};
use meta_client::MetaClientOpts;
use partition::manager::PartitionRuleManager;
use partition::route::TableRoutes;
use query::query_engine::options::QueryOptions;
use servers::error as server_error;
use servers::interceptor::{SqlQueryInterceptor, SqlQueryInterceptorRef};
use servers::promql::{PromqlHandler, PromqlHandlerRef};
@@ -58,12 +63,12 @@ use sql::statements::statement::Statement;
use crate::catalog::FrontendCatalogManager;
use crate::datanode::DatanodeClients;
use crate::error::{
self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, Result,
self, Error, ExecutePromqlSnafu, MissingMetasrvOptsSnafu, NotSupportedSnafu, ParseSqlSnafu,
Result, SqlExecInterceptedSnafu,
};
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
use crate::frontend::FrontendOptions;
use crate::instance::standalone::{StandaloneGrpcQueryHandler, StandaloneSqlQueryHandler};
use crate::Plugins;
#[async_trait]
pub trait FrontendInstance:
@@ -101,7 +106,10 @@ pub struct Instance {
}
impl Instance {
pub async fn try_new_distributed(opts: &FrontendOptions) -> Result<Self> {
pub async fn try_new_distributed(
opts: &FrontendOptions,
plugins: Arc<Plugins>,
) -> Result<Self> {
let meta_client = Self::create_meta_client(opts).await?;
let meta_backend = Arc::new(MetaKvBackend {
@@ -117,8 +125,12 @@ impl Instance {
datanode_clients.clone(),
));
let dist_instance =
DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients);
let dist_instance = DistInstance::new(
meta_client,
catalog_manager.clone(),
datanode_clients,
plugins.clone(),
);
let dist_instance = Arc::new(dist_instance);
Ok(Instance {
@@ -128,7 +140,7 @@ impl Instance {
sql_handler: dist_instance.clone(),
grpc_query_handler: dist_instance,
promql_handler: None,
plugins: Default::default(),
plugins,
})
}
@@ -365,11 +377,12 @@ impl FrontendInstance for Instance {
}
fn parse_stmt(sql: &str) -> Result<Vec<Statement>> {
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(error::ParseSqlSnafu)
ParserContext::create_with_dialect(sql, &GenericDialect {}).context(ParseSqlSnafu)
}
impl Instance {
async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result<Output> {
check_permission(self.plugins.clone(), &stmt, &query_ctx)?;
match stmt {
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
@@ -521,6 +534,87 @@ impl PromqlHandler for Instance {
}
}
pub fn check_permission(
plugins: Arc<Plugins>,
stmt: &Statement,
query_ctx: &QueryContextRef,
) -> Result<()> {
let need_validate = plugins
.get::<QueryOptions>()
.map(|opts| opts.disallow_cross_schema_query)
.unwrap_or_default();
if !need_validate {
return Ok(());
}
match stmt {
// query and explain will be checked in QueryEngineState
Statement::Query(_) | Statement::Explain(_) => {}
// database ops won't be checked
Statement::CreateDatabase(_) | Statement::ShowDatabases(_) | Statement::Use(_) => {}
// show create table and alter are not supported yet
Statement::ShowCreateTable(_) | Statement::Alter(_) => {}
Statement::Insert(insert) => {
let (catalog, schema, _) = insert.full_table_name().context(ParseSqlSnafu)?;
validate_param(&catalog, &schema, query_ctx)?;
}
Statement::CreateTable(stmt) => {
let tab_ref = obj_name_to_tab_ref(&stmt.name)?;
validate_tab_ref(tab_ref, query_ctx)?;
}
Statement::DropTable(drop_stmt) => {
let tab_ref = obj_name_to_tab_ref(drop_stmt.table_name())?;
validate_tab_ref(tab_ref, query_ctx)?;
}
Statement::ShowTables(stmt) => {
if let Some(database) = &stmt.database {
validate_param(&query_ctx.current_catalog(), database, query_ctx)?;
}
}
Statement::DescribeTable(stmt) => {
let tab_ref = obj_name_to_tab_ref(stmt.name())?;
validate_tab_ref(tab_ref, query_ctx)?;
}
}
Ok(())
}
fn obj_name_to_tab_ref(obj: &ObjectName) -> Result<TableReference> {
match &obj.0[..] {
[table] => Ok(TableReference::Bare {
table: &table.value,
}),
[schema, table] => Ok(TableReference::Partial {
schema: &schema.value,
table: &table.value,
}),
[catalog, schema, table] => Ok(TableReference::Full {
catalog: &catalog.value,
schema: &schema.value,
table: &table.value,
}),
_ => error::InvalidSqlSnafu {
err_msg: format!(
"expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {obj}",
),
}.fail(),
}
}
fn validate_tab_ref(tab_ref: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
query::query_engine::options::validate_table_references(tab_ref, query_ctx)
.map_err(BoxedError::new)
.context(SqlExecInterceptedSnafu)
}
fn validate_param(catalog: &str, schema: &str, query_ctx: &QueryContextRef) -> Result<()> {
query::query_engine::options::validate_catalog_and_schema(catalog, schema, query_ctx)
.map_err(BoxedError::new)
.context(SqlExecInterceptedSnafu)
}
#[cfg(test)]
mod tests {
use std::borrow::Cow;
@@ -528,13 +622,124 @@ mod tests {
use std::sync::atomic::AtomicU32;
use catalog::helper::{TableGlobalKey, TableGlobalValue};
use query::query_engine::options::QueryOptions;
use session::context::QueryContext;
use strfmt::Format;
use super::*;
use crate::table::DistTable;
use crate::tests;
use crate::tests::MockDistributedInstance;
#[test]
fn test_exec_validation() {
let query_ctx = Arc::new(QueryContext::new());
let mut plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
});
let plugins = Arc::new(plugins);
let sql = r#"
SELECT * FROM demo;
EXPLAIN SELECT * FROM demo;
CREATE DATABASE test_database;
SHOW DATABASES;
"#;
let stmts = parse_stmt(sql).unwrap();
assert_eq!(stmts.len(), 4);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
assert!(re.is_ok());
}
let sql = r#"
SHOW CREATE TABLE demo;
ALTER TABLE demo ADD COLUMN new_col INT;
"#;
let stmts = parse_stmt(sql).unwrap();
assert_eq!(stmts.len(), 2);
for stmt in stmts {
let re = check_permission(plugins.clone(), &stmt, &query_ctx);
assert!(re.is_ok());
}
let sql = "USE randomschema";
let stmts = parse_stmt(sql).unwrap();
let re = check_permission(plugins.clone(), &stmts[0], &query_ctx);
assert!(re.is_ok());
fn replace_test(template_sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef) {
// test right
let right = vec![("", ""), ("", "public."), ("greptime.", "public.")];
for (catalog, schema) in right {
let sql = do_fmt(template_sql, catalog, schema);
do_test(&sql, plugins.clone(), query_ctx, true);
}
let wrong = vec![
("", "wrongschema."),
("greptime.", "wrongschema."),
("wrongcatalog.", "public."),
("wrongcatalog.", "wrongschema."),
];
for (catalog, schema) in wrong {
let sql = do_fmt(template_sql, catalog, schema);
do_test(&sql, plugins.clone(), query_ctx, false);
}
}
fn do_fmt(template: &str, catalog: &str, schema: &str) -> String {
let mut vars = HashMap::new();
vars.insert("catalog".to_string(), catalog);
vars.insert("schema".to_string(), schema);
template.format(&vars).unwrap()
}
fn do_test(sql: &str, plugins: Arc<Plugins>, query_ctx: &QueryContextRef, is_ok: bool) {
let stmt = &parse_stmt(sql).unwrap()[0];
let re = check_permission(plugins.clone(), stmt, query_ctx);
if is_ok {
assert!(re.is_ok());
} else {
assert!(re.is_err());
}
}
// test insert
let sql = "INSERT INTO {catalog}{schema}monitor(host) VALUES ('host1');";
replace_test(sql, plugins.clone(), &query_ctx);
// test create table
let sql = r#"CREATE TABLE {catalog}{schema}demo(
host STRING,
ts TIMESTAMP,
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#;
replace_test(sql, plugins.clone(), &query_ctx);
// test drop table
let sql = "DROP TABLE {catalog}{schema}demo;";
replace_test(sql, plugins.clone(), &query_ctx);
// test show tables
let sql = "SHOW TABLES FROM public";
let stmt = parse_stmt(sql).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_ok());
let sql = "SHOW TABLES FROM wrongschema";
let stmt = parse_stmt(sql).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());
// test describe table
let sql = "DESC TABLE {catalog}{schema}demo;";
replace_test(sql, plugins.clone(), &query_ctx);
}
#[tokio::test(flavor = "multi_thread")]
async fn test_standalone_exec_sql() {
let standalone = tests::create_standalone_instance("test_standalone_exec_sql").await;

View File

@@ -26,6 +26,7 @@ use catalog::helper::{SchemaKey, SchemaValue};
use catalog::{CatalogList, CatalogManager, DeregisterTableRequest, RegisterTableRequest};
use chrono::DateTime;
use client::Database;
use common_base::Plugins;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::BoxedError;
use common_query::Output;
@@ -79,8 +80,11 @@ impl DistInstance {
meta_client: Arc<MetaClient>,
catalog_manager: Arc<FrontendCatalogManager>,
datanode_clients: Arc<DatanodeClients>,
plugins: Arc<Plugins>,
) -> Self {
let query_engine = QueryEngineFactory::new(catalog_manager.clone()).query_engine();
let query_engine =
QueryEngineFactory::new_with_plugins(catalog_manager.clone(), plugins.clone())
.query_engine();
Self {
meta_client,
catalog_manager,
@@ -422,7 +426,7 @@ impl DistInstance {
self.meta_client
.create_route(request)
.await
.context(error::RequestMetaSnafu)
.context(RequestMetaSnafu)
}
// TODO(LFC): Refactor insertion implementation for DistTable,
@@ -621,8 +625,7 @@ fn find_partition_entries(
let v = match v {
SqlValue::Number(n, _) if n == "MAXVALUE" => PartitionBound::MaxValue,
_ => PartitionBound::Value(
sql_value_to_value(column_name, data_type, v)
.context(error::ParseSqlSnafu)?,
sql_value_to_value(column_name, data_type, v).context(ParseSqlSnafu)?,
),
};
values.push(v);

View File

@@ -14,8 +14,6 @@
#![feature(assert_matches)]
pub type Plugins = anymap::Map<dyn core::any::Any + Send + Sync>;
mod catalog;
mod datanode;
pub mod error;

View File

@@ -15,6 +15,7 @@
use std::net::SocketAddr;
use std::sync::Arc;
use common_base::Plugins;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
@@ -37,7 +38,6 @@ use crate::frontend::FrontendOptions;
use crate::influxdb::InfluxdbOptions;
use crate::instance::FrontendInstance;
use crate::prometheus::PrometheusOptions;
use crate::Plugins;
pub(crate) struct Services;

View File

@@ -257,6 +257,7 @@ pub(crate) async fn create_distributed_instance(test_name: &str) -> MockDistribu
meta_client.clone(),
catalog_manager,
datanode_clients.clone(),
Default::default(),
);
let dist_instance = Arc::new(dist_instance);
let frontend = Instance::new_distributed(dist_instance.clone());