feat: add udtf (table function) registration (#6922)

* feat: work-in-progress udtf support

Signed-off-by: Ning Sun <sunning@greptime.com>

* feat: add table function support

Signed-off-by: Ning Sun <sunning@greptime.com>

* test: resolve test error

Signed-off-by: Ning Sun <sunning@greptime.com>

---------

Signed-off-by: Ning Sun <sunning@greptime.com>
This commit is contained in:
Ning Sun
2025-09-09 09:26:38 +08:00
committed by GitHub
parent 16febbd4c2
commit 948a6578fa
6 changed files with 86 additions and 8 deletions

View File

@@ -16,6 +16,7 @@
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, RwLock};
use datafusion::catalog::TableFunction;
use datafusion_expr::AggregateUDF;
use crate::admin::AdminFunction;
@@ -42,6 +43,7 @@ use crate::system::SystemFunction;
pub struct FunctionRegistry {
functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
}
impl FunctionRegistry {
@@ -87,6 +89,15 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
/// Register a table function
pub fn register_table_function(&self, func: TableFunction) {
let _ = self
.table_functions
.write()
.unwrap()
.insert(func.name().to_string(), Arc::new(func));
}
pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
self.functions.read().unwrap().get(name).cloned()
}
@@ -106,6 +117,15 @@ impl FunctionRegistry {
.collect()
}
pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
self.table_functions
.read()
.unwrap()
.values()
.cloned()
.collect()
}
/// Returns true if an aggregate function with the given name exists in the registry.
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
self.aggregate_functions.read().unwrap().contains_key(name)

View File

@@ -23,6 +23,7 @@ use common_function::function_factory::ScalarFunctionFactory;
use common_query::Output;
use common_runtime::Runtime;
use common_runtime::runtime::{BuilderBuild, RuntimeTrait};
use datafusion::catalog::TableFunction;
use datafusion_expr::{AggregateUDF, LogicalPlan};
use query::dataframe::DataFrame;
use query::planner::LogicalPlanner;
@@ -80,6 +81,8 @@ impl QueryEngine for MockQueryEngine {
fn register_scalar_function(&self, _func: ScalarFunctionFactory) {}
fn register_table_function(&self, _func: Arc<TableFunction>) {}
fn read_table(&self, _table: TableRef) -> query::error::Result<DataFrame> {
unimplemented!()
}

View File

@@ -30,6 +30,7 @@ use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::tracing;
use datafusion::catalog::TableFunction;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::analyze::AnalyzeExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -508,6 +509,10 @@ impl QueryEngine for DatafusionQueryEngine {
self.state.register_scalar_function(func);
}
fn register_table_function(&self, func: Arc<TableFunction>) {
self.state.register_table_function(func);
}
fn read_table(&self, table: TableRef) -> Result<DataFrame> {
Ok(DataFrame::DataFusion(
self.state

View File

@@ -224,15 +224,22 @@ impl ContextProvider for DfContextProviderAdapter {
name: &str,
args: Vec<datafusion_expr::Expr>,
) -> DfResult<Arc<dyn TableSource>> {
let tbl_func = self
.session_state
.table_functions()
.get(name)
.cloned()
.ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?;
let provider = tbl_func.create_table_provider(&args)?;
if let Some(tbl_func) = self.engine_state.table_function(name) {
let provider = tbl_func.create_table_provider(&args)?;
Ok(provider_as_source(provider))
} else {
let tbl_func = self
.session_state
.table_functions()
.get(name)
.cloned()
.ok_or_else(|| {
DataFusionError::Plan(format!("table function '{name}' not found"))
})?;
let provider = tbl_func.create_table_provider(&args)?;
Ok(provider_as_source(provider))
Ok(provider_as_source(provider))
}
}
fn create_cte_work_table(

View File

@@ -28,6 +28,7 @@ use common_function::handlers::{
FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
};
use common_query::Output;
use datafusion::catalog::TableFunction;
use datafusion_expr::{AggregateUDF, LogicalPlan};
use datatypes::schema::Schema;
pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer};
@@ -85,6 +86,9 @@ pub trait QueryEngine: Send + Sync {
/// Will override if the function with same name is already registered.
fn register_scalar_function(&self, func: ScalarFunctionFactory);
/// Register table function
fn register_table_function(&self, func: Arc<TableFunction>);
/// Create a DataFrame from a table.
fn read_table(&self, table: TableRef) -> Result<DataFrame>;
@@ -164,6 +168,10 @@ fn register_functions(query_engine: &Arc<DatafusionQueryEngine>) {
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
query_engine.register_aggregate_function(accumulator);
}
for table_function in FUNCTION_REGISTRY.table_functions() {
query_engine.register_table_function(table_function);
}
}
pub type QueryEngineRef = Arc<dyn QueryEngine>;

View File

@@ -25,6 +25,7 @@ use common_function::handlers::{
};
use common_function::state::FunctionState;
use common_telemetry::warn;
use datafusion::catalog::TableFunction;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result as DfResult;
use datafusion::execution::SessionStateBuilder;
@@ -72,6 +73,7 @@ pub struct QueryEngineState {
function_state: Arc<FunctionState>,
scalar_functions: Arc<RwLock<HashMap<String, ScalarFunctionFactory>>>,
aggr_functions: Arc<RwLock<HashMap<String, AggregateUDF>>>,
table_functions: Arc<RwLock<HashMap<String, Arc<TableFunction>>>>,
extension_rules: Vec<Arc<dyn ExtensionAnalyzerRule + Send + Sync>>,
plugins: Plugins,
}
@@ -196,6 +198,7 @@ impl QueryEngineState {
flow_service_handler,
}),
aggr_functions: Arc::new(RwLock::new(HashMap::new())),
table_functions: Arc::new(RwLock::new(HashMap::new())),
extension_rules,
plugins,
scalar_functions: Arc::new(RwLock::new(HashMap::new())),
@@ -265,6 +268,25 @@ impl QueryEngineState {
.collect()
}
/// Retrieve table function by name
pub fn table_function(&self, function_name: &str) -> Option<Arc<TableFunction>> {
self.table_functions
.read()
.unwrap()
.get(function_name)
.cloned()
}
/// Retrieve table function names.
pub fn table_function_names(&self) -> Vec<String> {
self.table_functions
.read()
.unwrap()
.keys()
.cloned()
.collect()
}
/// Register an scalar function.
/// Will override if the function with same name is already registered.
pub fn register_scalar_function(&self, func: ScalarFunctionFactory) {
@@ -301,6 +323,19 @@ impl QueryEngineState {
);
}
pub fn register_table_function(&self, func: Arc<TableFunction>) {
let name = func.name();
let x = self
.table_functions
.write()
.unwrap()
.insert(name.to_string(), func.clone());
if x.is_some() {
warn!("Already registered table function '{name}");
}
}
pub fn catalog_manager(&self) -> &CatalogManagerRef {
&self.catalog_manager
}