diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index 4a6b8f91e1..9d2a15c370 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::sync::{Arc, LazyLock, RwLock}; +use datafusion::catalog::TableFunction; use datafusion_expr::AggregateUDF; use crate::admin::AdminFunction; @@ -42,6 +43,7 @@ use crate::system::SystemFunction; pub struct FunctionRegistry { functions: RwLock>, aggregate_functions: RwLock>, + table_functions: RwLock>>, } impl FunctionRegistry { @@ -87,6 +89,15 @@ impl FunctionRegistry { .insert(func.name().to_string(), func); } + /// Register a table function + pub fn register_table_function(&self, func: TableFunction) { + let _ = self + .table_functions + .write() + .unwrap() + .insert(func.name().to_string(), Arc::new(func)); + } + pub fn get_function(&self, name: &str) -> Option { self.functions.read().unwrap().get(name).cloned() } @@ -106,6 +117,15 @@ impl FunctionRegistry { .collect() } + pub fn table_functions(&self) -> Vec> { + self.table_functions + .read() + .unwrap() + .values() + .cloned() + .collect() + } + /// Returns true if an aggregate function with the given name exists in the registry. pub fn is_aggr_func_exist(&self, name: &str) -> bool { self.aggregate_functions.read().unwrap().contains_key(name) diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index 1f6a882415..ae0af78904 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -23,6 +23,7 @@ use common_function::function_factory::ScalarFunctionFactory; use common_query::Output; use common_runtime::Runtime; use common_runtime::runtime::{BuilderBuild, RuntimeTrait}; +use datafusion::catalog::TableFunction; use datafusion_expr::{AggregateUDF, LogicalPlan}; use query::dataframe::DataFrame; use query::planner::LogicalPlanner; @@ -80,6 +81,8 @@ impl QueryEngine for MockQueryEngine { fn register_scalar_function(&self, _func: ScalarFunctionFactory) {} + fn register_table_function(&self, _func: Arc) {} + fn read_table(&self, _table: TableRef) -> query::error::Result { unimplemented!() } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 67a92f0be2..96e1f60763 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -30,6 +30,7 @@ use common_query::{Output, OutputData, OutputMeta}; use common_recordbatch::adapter::RecordBatchStreamAdapter; use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::tracing; +use datafusion::catalog::TableFunction; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -508,6 +509,10 @@ impl QueryEngine for DatafusionQueryEngine { self.state.register_scalar_function(func); } + fn register_table_function(&self, func: Arc) { + self.state.register_table_function(func); + } + fn read_table(&self, table: TableRef) -> Result { Ok(DataFrame::DataFusion( self.state diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 4e6917fad3..d9c74b9d5a 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -224,15 +224,22 @@ impl ContextProvider for DfContextProviderAdapter { name: &str, args: Vec, ) -> DfResult> { - let tbl_func = self - .session_state - .table_functions() - .get(name) - .cloned() - .ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?; - let provider = tbl_func.create_table_provider(&args)?; + if let Some(tbl_func) = self.engine_state.table_function(name) { + let provider = tbl_func.create_table_provider(&args)?; + Ok(provider_as_source(provider)) + } else { + let tbl_func = self + .session_state + .table_functions() + .get(name) + .cloned() + .ok_or_else(|| { + DataFusionError::Plan(format!("table function '{name}' not found")) + })?; + let provider = tbl_func.create_table_provider(&args)?; - Ok(provider_as_source(provider)) + Ok(provider_as_source(provider)) + } } fn create_cte_work_table( diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index db9ab7140a..34a4fee209 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -28,6 +28,7 @@ use common_function::handlers::{ FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef, }; use common_query::Output; +use datafusion::catalog::TableFunction; use datafusion_expr::{AggregateUDF, LogicalPlan}; use datatypes::schema::Schema; pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer}; @@ -85,6 +86,9 @@ pub trait QueryEngine: Send + Sync { /// Will override if the function with same name is already registered. fn register_scalar_function(&self, func: ScalarFunctionFactory); + /// Register table function + fn register_table_function(&self, func: Arc); + /// Create a DataFrame from a table. fn read_table(&self, table: TableRef) -> Result; @@ -164,6 +168,10 @@ fn register_functions(query_engine: &Arc) { for accumulator in FUNCTION_REGISTRY.aggregate_functions() { query_engine.register_aggregate_function(accumulator); } + + for table_function in FUNCTION_REGISTRY.table_functions() { + query_engine.register_table_function(table_function); + } } pub type QueryEngineRef = Arc; diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index 51b7f68684..2608264261 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -25,6 +25,7 @@ use common_function::handlers::{ }; use common_function::state::FunctionState; use common_telemetry::warn; +use datafusion::catalog::TableFunction; use datafusion::dataframe::DataFrame; use datafusion::error::Result as DfResult; use datafusion::execution::SessionStateBuilder; @@ -72,6 +73,7 @@ pub struct QueryEngineState { function_state: Arc, scalar_functions: Arc>>, aggr_functions: Arc>>, + table_functions: Arc>>>, extension_rules: Vec>, plugins: Plugins, } @@ -196,6 +198,7 @@ impl QueryEngineState { flow_service_handler, }), aggr_functions: Arc::new(RwLock::new(HashMap::new())), + table_functions: Arc::new(RwLock::new(HashMap::new())), extension_rules, plugins, scalar_functions: Arc::new(RwLock::new(HashMap::new())), @@ -265,6 +268,25 @@ impl QueryEngineState { .collect() } + /// Retrieve table function by name + pub fn table_function(&self, function_name: &str) -> Option> { + self.table_functions + .read() + .unwrap() + .get(function_name) + .cloned() + } + + /// Retrieve table function names. + pub fn table_function_names(&self) -> Vec { + self.table_functions + .read() + .unwrap() + .keys() + .cloned() + .collect() + } + /// Register an scalar function. /// Will override if the function with same name is already registered. pub fn register_scalar_function(&self, func: ScalarFunctionFactory) { @@ -301,6 +323,19 @@ impl QueryEngineState { ); } + pub fn register_table_function(&self, func: Arc) { + let name = func.name(); + let x = self + .table_functions + .write() + .unwrap() + .insert(name.to_string(), func.clone()); + + if x.is_some() { + warn!("Already registered table function '{name}"); + } + } + pub fn catalog_manager(&self) -> &CatalogManagerRef { &self.catalog_manager }