mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 01:40:36 +00:00
fix: check full table name during logical plan creation (#948)
This commit is contained in:
@@ -8,6 +8,7 @@ license.workspace = true
|
||||
arc-swap = "1.0"
|
||||
async-trait = "0.1"
|
||||
catalog = { path = "../catalog" }
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-function = { path = "../common/function" }
|
||||
|
||||
@@ -22,6 +22,7 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_function::scalars::udf::create_udf;
|
||||
@@ -60,9 +61,9 @@ pub(crate) struct DatafusionQueryEngine {
|
||||
}
|
||||
|
||||
impl DatafusionQueryEngine {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
Self {
|
||||
state: QueryEngineState::new(catalog_list.clone()),
|
||||
state: QueryEngineState::new(catalog_list.clone(), plugins),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +66,9 @@ pub enum Error {
|
||||
#[snafu(display("Failure during query parsing, query: {}, source: {}", query, source))]
|
||||
QueryParse { query: String, source: BoxedError },
|
||||
|
||||
#[snafu(display("Illegal access to catalog: {} and schema: {}", catalog, schema))]
|
||||
QueryAccessDenied { catalog: String, schema: String },
|
||||
|
||||
#[snafu(display("The SQL string has multiple statements, query: {}", query))]
|
||||
MultipleStatements { query: String, backtrace: Backtrace },
|
||||
|
||||
@@ -83,6 +86,7 @@ impl ErrorExt for Error {
|
||||
| CatalogNotFound { .. }
|
||||
| SchemaNotFound { .. }
|
||||
| TableNotFound { .. } => StatusCode::InvalidArguments,
|
||||
QueryAccessDenied { .. } => StatusCode::AccessDenied,
|
||||
Catalog { source } => source.status_code(),
|
||||
VectorComputation { source } => source.status_code(),
|
||||
CreateRecordBatch { source } => source.status_code(),
|
||||
|
||||
@@ -13,12 +13,14 @@
|
||||
// limitations under the License.
|
||||
|
||||
mod context;
|
||||
pub mod options;
|
||||
mod state;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
|
||||
use common_query::physical_plan::PhysicalPlan;
|
||||
@@ -63,23 +65,29 @@ pub struct QueryEngineFactory {
|
||||
|
||||
impl QueryEngineFactory {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list));
|
||||
|
||||
for func in FUNCTION_REGISTRY.functions() {
|
||||
query_engine.register_function(func);
|
||||
}
|
||||
|
||||
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
|
||||
query_engine.register_aggregate_function(accumulator);
|
||||
}
|
||||
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, Default::default()));
|
||||
register_functions(&query_engine);
|
||||
Self { query_engine }
|
||||
}
|
||||
|
||||
pub fn new_with_plugins(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
let query_engine = Arc::new(DatafusionQueryEngine::new(catalog_list, plugins));
|
||||
register_functions(&query_engine);
|
||||
Self { query_engine }
|
||||
}
|
||||
|
||||
pub fn query_engine(&self) -> QueryEngineRef {
|
||||
self.query_engine.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryEngineFactory {
|
||||
pub fn query_engine(&self) -> QueryEngineRef {
|
||||
self.query_engine.clone()
|
||||
fn register_functions(query_engine: &Arc<DatafusionQueryEngine>) {
|
||||
for func in FUNCTION_REGISTRY.functions() {
|
||||
query_engine.register_function(func);
|
||||
}
|
||||
|
||||
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
|
||||
query_engine.register_aggregate_function(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
132
src/query/src/query_engine/options.rs
Normal file
132
src/query/src/query_engine/options.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use datafusion_common::TableReference;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ensure;
|
||||
|
||||
use crate::error::{QueryAccessDeniedSnafu, Result};
|
||||
|
||||
/// Options that tune how the query engine validates incoming queries.
///
/// The default value leaves every restriction disabled.
#[derive(Debug, Clone, Default)]
pub struct QueryOptions {
    /// When `true`, table references are validated against the catalog and schema
    /// of the current query context before a table provider is resolved.
    pub disallow_cross_schema_query: bool,
}
|
||||
|
||||
pub fn validate_catalog_and_schema(
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> Result<()> {
|
||||
ensure!(
|
||||
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: catalog.to_string(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_table_references(name: TableReference, query_ctx: &QueryContextRef) -> Result<()> {
|
||||
match name {
|
||||
TableReference::Bare { .. } => Ok(()),
|
||||
TableReference::Partial { schema, .. } => {
|
||||
ensure!(
|
||||
schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: query_ctx.current_catalog(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
TableReference::Full {
|
||||
catalog, schema, ..
|
||||
} => {
|
||||
ensure!(
|
||||
catalog == query_ctx.current_catalog() && schema == query_ctx.current_schema(),
|
||||
QueryAccessDeniedSnafu {
|
||||
catalog: catalog.to_string(),
|
||||
schema: schema.to_string(),
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use session::context::QueryContext;

    use super::*;

    /// Each table reference is paired with whether validation is expected to accept it
    /// under a context created with catalog `greptime` and schema `public`.
    #[test]
    fn test_validate_table_ref() {
        let context = Arc::new(QueryContext::with("greptime", "public"));

        let cases = [
            (
                TableReference::Bare {
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Partial {
                    schema: "public",
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Partial {
                    schema: "wrong_schema",
                    table: "table_name",
                },
                false,
            ),
            (
                TableReference::Full {
                    catalog: "greptime",
                    schema: "public",
                    table: "table_name",
                },
                true,
            ),
            (
                TableReference::Full {
                    catalog: "wrong_catalog",
                    schema: "public",
                    table: "table_name",
                },
                false,
            ),
        ];

        for (table_ref, accepted) in cases {
            let outcome = validate_table_references(table_ref, &context);
            assert_eq!(outcome.is_ok(), accepted);
        }
    }

    /// Only the exact catalog/schema pair of the context passes validation.
    #[test]
    fn test_validate_catalog_and_schema() {
        let context = Arc::new(QueryContext::with("greptime", "public"));

        let cases = [
            ("greptime", "public", true),
            ("greptime", "wrong_schema", false),
            ("wrong_catalog", "public", false),
            ("wrong_catalog", "wrong_schema", false),
        ];

        for (catalog, schema, accepted) in cases {
            let outcome = validate_catalog_and_schema(catalog, schema, &context);
            assert_eq!(outcome.is_ok(), accepted);
        }
    }
}
|
||||
@@ -18,6 +18,7 @@ use std::sync::{Arc, RwLock};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use catalog::CatalogListRef;
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_function::scalars::aggregate::AggregateFunctionMetaRef;
|
||||
use common_query::physical_plan::{SessionContext, TaskContext};
|
||||
@@ -39,6 +40,7 @@ use session::context::QueryContextRef;
|
||||
|
||||
use crate::datafusion::DfCatalogListAdapter;
|
||||
use crate::optimizer::TypeConversionRule;
|
||||
use crate::query_engine::options::{validate_table_references, QueryOptions};
|
||||
|
||||
/// Query engine global state
|
||||
// TODO(yingwen): This QueryEngineState still relies on datafusion, maybe we can define a trait for it,
|
||||
@@ -49,6 +51,7 @@ pub struct QueryEngineState {
|
||||
df_context: SessionContext,
|
||||
catalog_list: CatalogListRef,
|
||||
aggregate_functions: Arc<RwLock<HashMap<String, AggregateFunctionMetaRef>>>,
|
||||
plugins: Arc<Plugins>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for QueryEngineState {
|
||||
@@ -59,7 +62,7 @@ impl fmt::Debug for QueryEngineState {
|
||||
}
|
||||
|
||||
impl QueryEngineState {
|
||||
pub fn new(catalog_list: CatalogListRef) -> Self {
|
||||
pub fn new(catalog_list: CatalogListRef, plugins: Arc<Plugins>) -> Self {
|
||||
let runtime_env = Arc::new(RuntimeEnv::default());
|
||||
let session_config = SessionConfig::new()
|
||||
.with_default_catalog_and_schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME);
|
||||
@@ -78,6 +81,7 @@ impl QueryEngineState {
|
||||
df_context,
|
||||
catalog_list,
|
||||
aggregate_functions: Arc::new(RwLock::new(HashMap::new())),
|
||||
plugins,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,6 +124,13 @@ impl QueryEngineState {
|
||||
name: TableReference,
|
||||
) -> DfResult<Arc<dyn TableSource>> {
|
||||
let state = self.df_context.state();
|
||||
|
||||
if let Some(opts) = self.plugins.get::<QueryOptions>() {
|
||||
if opts.disallow_cross_schema_query {
|
||||
validate_table_references(name, &query_ctx)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let TableReference::Bare { table } = name {
|
||||
let name = TableReference::Partial {
|
||||
schema: &query_ctx.current_schema(),
|
||||
|
||||
@@ -21,8 +21,9 @@ mod function;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use catalog::local::{MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
|
||||
use common_base::Plugins;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_query::prelude::{create_udf, make_scalar_function, Volatility};
|
||||
@@ -36,6 +37,7 @@ use datatypes::vectors::UInt32Vector;
|
||||
use query::error::{QueryExecutionSnafu, Result};
|
||||
use query::parser::QueryLanguageParser;
|
||||
use query::plan::LogicalPlan;
|
||||
use query::query_engine::options::QueryOptions;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use session::context::QueryContext;
|
||||
use snafu::ResultExt;
|
||||
@@ -107,9 +109,7 @@ async fn test_datafusion_query_engine() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udf() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
fn catalog_list() -> Result<Arc<MemoryCatalogManager>> {
|
||||
let catalog_list = catalog::local::new_memory_catalog_list()
|
||||
.map_err(BoxedError::new)
|
||||
.context(QueryExecutionSnafu)?;
|
||||
@@ -125,6 +125,39 @@ async fn test_udf() -> Result<()> {
|
||||
catalog_list
|
||||
.register_catalog(DEFAULT_CATALOG_NAME.to_string(), default_catalog)
|
||||
.unwrap();
|
||||
Ok(catalog_list)
|
||||
}
|
||||
|
||||
/// With `disallow_cross_schema_query` enabled, planning a query against
/// `public.numbers` succeeds while `wrongschema.numbers` is rejected.
#[test]
fn test_query_validate() -> Result<()> {
    common_telemetry::init_default_ut_logging();
    let catalog_list = catalog_list()?;

    // Turn the restriction on through the plugin registry handed to the factory.
    let mut plugins = Plugins::new();
    plugins.insert(QueryOptions {
        disallow_cross_schema_query: true,
    });

    let factory = QueryEngineFactory::new_with_plugins(catalog_list, Arc::new(plugins));
    let engine = factory.query_engine();

    // Parses `sql` and plans it under a freshly created query context.
    let plan_for = |sql: &str| {
        let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
        engine.statement_to_plan(stmt, Arc::new(QueryContext::new()))
    };

    assert!(plan_for("select number from public.numbers").is_ok());
    assert!(plan_for("select number from wrongschema.numbers").is_err());

    Ok(())
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_udf() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let catalog_list = catalog_list()?;
|
||||
|
||||
let factory = QueryEngineFactory::new(catalog_list);
|
||||
let engine = factory.query_engine();
|
||||
|
||||
Reference in New Issue
Block a user