feat: update our cross schema check to cross catalog (#3123)

This commit is contained in:
Ning Sun
2024-01-09 17:38:48 +08:00
committed by GitHub
parent db98484796
commit 1fc168bf6a
7 changed files with 35 additions and 47 deletions

View File

@@ -15,7 +15,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use common_catalog::format_full_table_name;
use datafusion::common::{ResolvedTableReference, TableReference};
use datafusion::datasource::provider_as_source;
@@ -30,7 +29,7 @@ use crate::CatalogManagerRef;
pub struct DfTableSourceProvider {
catalog_manager: CatalogManagerRef,
resolved_tables: HashMap<String, Arc<dyn TableSource>>,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
default_catalog: String,
default_schema: String,
}
@@ -38,12 +37,12 @@ pub struct DfTableSourceProvider {
impl DfTableSourceProvider {
pub fn new(
catalog_manager: CatalogManagerRef,
disallow_cross_schema_query: bool,
disallow_cross_catalog_query: bool,
query_ctx: &QueryContext,
) -> Self {
Self {
catalog_manager,
disallow_cross_schema_query,
disallow_cross_catalog_query,
resolved_tables: HashMap::new(),
default_catalog: query_ctx.current_catalog().to_owned(),
default_schema: query_ctx.current_schema().to_owned(),
@@ -54,29 +53,18 @@ impl DfTableSourceProvider {
&'a self,
table_ref: TableReference<'a>,
) -> Result<ResolvedTableReference<'a>> {
if self.disallow_cross_schema_query {
if self.disallow_cross_catalog_query {
match &table_ref {
TableReference::Bare { .. } => (),
TableReference::Partial { schema, .. } => {
ensure!(
schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME,
QueryAccessDeniedSnafu {
catalog: &self.default_catalog,
schema: schema.as_ref(),
}
);
}
TableReference::Partial { .. } => {}
TableReference::Full {
catalog, schema, ..
} => {
ensure!(
catalog.as_ref() == self.default_catalog
&& (schema.as_ref() == self.default_schema
|| schema.as_ref() == INFORMATION_SCHEMA_NAME),
catalog.as_ref() == self.default_catalog,
QueryAccessDeniedSnafu {
catalog: catalog.as_ref(),
schema: schema.as_ref()
schema: schema.as_ref(),
}
);
}
@@ -136,21 +124,21 @@ mod tests {
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());
let table_ref = TableReference::Partial {
schema: Cow::Borrowed("public"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());
let table_ref = TableReference::Partial {
schema: Cow::Borrowed("wrong_schema"),
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_err());
assert!(result.is_ok());
let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
@@ -158,7 +146,7 @@ mod tests {
table: Cow::Borrowed("table_name"),
};
let result = table_provider.resolve_table_ref(table_ref);
let _ = result.unwrap();
assert!(result.is_ok());
let table_ref = TableReference::Full {
catalog: Cow::Borrowed("wrong_catalog"),
@@ -172,14 +160,15 @@ mod tests {
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
let result = table_provider.resolve_table_ref(table_ref);
assert!(result.is_ok());
let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("information_schema"),
table: Cow::Borrowed("columns"),
};
let _ = table_provider.resolve_table_ref(table_ref).unwrap();
assert!(table_provider.resolve_table_ref(table_ref).is_ok());
let table_ref = TableReference::Full {
catalog: Cow::Borrowed("dummy"),
@@ -187,5 +176,12 @@ mod tests {
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_err());
let table_ref = TableReference::Full {
catalog: Cow::Borrowed("greptime"),
schema: Cow::Borrowed("greptime_private"),
table: Cow::Borrowed("columns"),
};
assert!(table_provider.resolve_table_ref(table_ref).is_ok());
}
}

View File

@@ -442,7 +442,7 @@ pub fn check_permission(
) -> Result<()> {
let need_validate = plugins
.get::<QueryOptions>()
.map(|opts| opts.disallow_cross_schema_query)
.map(|opts| opts.disallow_cross_catalog_query)
.unwrap_or_default();
if !need_validate {
@@ -520,7 +520,7 @@ mod tests {
let query_ctx = QueryContext::arc();
let plugins: Plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});
let sql = r#"
@@ -556,8 +556,6 @@ mod tests {
}
let wrong = vec![
("", "wrongschema."),
("greptime.", "wrongschema."),
("wrongcatalog.", "public."),
("wrongcatalog.", "wrongschema."),
];
@@ -607,10 +605,10 @@ mod tests {
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
check_permission(plugins.clone(), &stmt[0], &query_ctx).unwrap();
let sql = "SHOW TABLES FROM wrongschema";
let sql = "SHOW TABLES FROM private";
let stmt = parse_stmt(sql, &GreptimeDbDialect {}).unwrap();
let re = check_permission(plugins.clone(), &stmt[0], &query_ctx);
assert!(re.is_err());
assert!(re.is_ok());
// test describe table
let sql = "DESC TABLE {catalog}{schema}demo;";

View File

@@ -56,7 +56,7 @@ impl DfContextProviderAdapter {
let mut table_provider = DfTableSourceProvider::new(
engine_state.catalog_manager().clone(),
engine_state.disallow_cross_schema_query(),
engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);

View File

@@ -58,7 +58,7 @@ impl DfLogicalPlanner {
let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);
@@ -91,7 +91,7 @@ impl DfLogicalPlanner {
async fn plan_pql(&self, stmt: EvalStmt, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
let table_provider = DfTableSourceProvider::new(
self.engine_state.catalog_manager().clone(),
self.engine_state.disallow_cross_schema_query(),
self.engine_state.disallow_cross_catalog_query(),
query_ctx.as_ref(),
);
PromPlanner::stmt_to_plan(table_provider, stmt)

View File

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use common_catalog::consts::INFORMATION_SCHEMA_NAME;
use session::context::QueryContextRef;
use snafu::ensure;
@@ -20,7 +19,7 @@ use crate::error::{QueryAccessDeniedSnafu, Result};
#[derive(Default, Clone)]
pub struct QueryOptions {
pub disallow_cross_schema_query: bool,
pub disallow_cross_catalog_query: bool,
}
// TODO(shuiyisong): remove one method after #559 is done
@@ -29,13 +28,8 @@ pub fn validate_catalog_and_schema(
schema: &str,
query_ctx: &QueryContextRef,
) -> Result<()> {
// information_schema is an exception
if schema.eq_ignore_ascii_case(INFORMATION_SCHEMA_NAME) {
return Ok(());
}
ensure!(
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
catalog == query_ctx.current_catalog(),
QueryAccessDeniedSnafu {
catalog: catalog.to_string(),
schema: schema.to_string(),
@@ -57,8 +51,8 @@ mod tests {
let context = QueryContext::with("greptime", "public");
validate_catalog_and_schema("greptime", "public", &context).unwrap();
let re = validate_catalog_and_schema("greptime", "wrong_schema", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("greptime", "private_schema", &context);
assert!(re.is_ok());
let re = validate_catalog_and_schema("wrong_catalog", "public", &context);
assert!(re.is_err());
let re = validate_catalog_and_schema("wrong_catalog", "wrong_schema", &context);

View File

@@ -163,9 +163,9 @@ impl QueryEngineState {
self.table_mutation_handler.as_ref()
}
pub(crate) fn disallow_cross_schema_query(&self) -> bool {
pub(crate) fn disallow_cross_catalog_query(&self) -> bool {
self.plugins
.map::<QueryOptions, _, _>(|x| x.disallow_cross_schema_query)
.map::<QueryOptions, _, _>(|x| x.disallow_cross_catalog_query)
.unwrap_or(false)
}

View File

@@ -125,7 +125,7 @@ async fn test_query_validate() -> Result<()> {
// set plugins
let plugins = Plugins::new();
plugins.insert(QueryOptions {
disallow_cross_schema_query: true,
disallow_cross_catalog_query: true,
});
let factory = QueryEngineFactory::new_with_plugins(catalog_list, None, None, false, plugins);