mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 01:40:36 +00:00
feat: add udtf (table function) registration (#6922)
* feat: work-in-progress udtf support Signed-off-by: Ning Sun <sunning@greptime.com> * feat: add table function support Signed-off-by: Ning Sun <sunning@greptime.com> * test: resolve test error Signed-off-by: Ning Sun <sunning@greptime.com> --------- Signed-off-by: Ning Sun <sunning@greptime.com>
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, LazyLock, RwLock};
|
||||
|
||||
use datafusion::catalog::TableFunction;
|
||||
use datafusion_expr::AggregateUDF;
|
||||
|
||||
use crate::admin::AdminFunction;
|
||||
@@ -42,6 +43,7 @@ use crate::system::SystemFunction;
|
||||
pub struct FunctionRegistry {
|
||||
functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
|
||||
aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
|
||||
table_functions: RwLock<HashMap<String, Arc<TableFunction>>>,
|
||||
}
|
||||
|
||||
impl FunctionRegistry {
|
||||
@@ -87,6 +89,15 @@ impl FunctionRegistry {
|
||||
.insert(func.name().to_string(), func);
|
||||
}
|
||||
|
||||
/// Register a table function
|
||||
pub fn register_table_function(&self, func: TableFunction) {
|
||||
let _ = self
|
||||
.table_functions
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(func.name().to_string(), Arc::new(func));
|
||||
}
|
||||
|
||||
pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
|
||||
self.functions.read().unwrap().get(name).cloned()
|
||||
}
|
||||
@@ -106,6 +117,15 @@ impl FunctionRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn table_functions(&self) -> Vec<Arc<TableFunction>> {
|
||||
self.table_functions
|
||||
.read()
|
||||
.unwrap()
|
||||
.values()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns true if an aggregate function with the given name exists in the registry.
|
||||
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
|
||||
self.aggregate_functions.read().unwrap().contains_key(name)
|
||||
|
||||
@@ -23,6 +23,7 @@ use common_function::function_factory::ScalarFunctionFactory;
|
||||
use common_query::Output;
|
||||
use common_runtime::Runtime;
|
||||
use common_runtime::runtime::{BuilderBuild, RuntimeTrait};
|
||||
use datafusion::catalog::TableFunction;
|
||||
use datafusion_expr::{AggregateUDF, LogicalPlan};
|
||||
use query::dataframe::DataFrame;
|
||||
use query::planner::LogicalPlanner;
|
||||
@@ -80,6 +81,8 @@ impl QueryEngine for MockQueryEngine {
|
||||
|
||||
fn register_scalar_function(&self, _func: ScalarFunctionFactory) {}
|
||||
|
||||
fn register_table_function(&self, _func: Arc<TableFunction>) {}
|
||||
|
||||
fn read_table(&self, _table: TableRef) -> query::error::Result<DataFrame> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ use common_query::{Output, OutputData, OutputMeta};
|
||||
use common_recordbatch::adapter::RecordBatchStreamAdapter;
|
||||
use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
|
||||
use common_telemetry::tracing;
|
||||
use datafusion::catalog::TableFunction;
|
||||
use datafusion::physical_plan::ExecutionPlan;
|
||||
use datafusion::physical_plan::analyze::AnalyzeExec;
|
||||
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
|
||||
@@ -508,6 +509,10 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
self.state.register_scalar_function(func);
|
||||
}
|
||||
|
||||
fn register_table_function(&self, func: Arc<TableFunction>) {
|
||||
self.state.register_table_function(func);
|
||||
}
|
||||
|
||||
fn read_table(&self, table: TableRef) -> Result<DataFrame> {
|
||||
Ok(DataFrame::DataFusion(
|
||||
self.state
|
||||
|
||||
@@ -224,15 +224,22 @@ impl ContextProvider for DfContextProviderAdapter {
|
||||
name: &str,
|
||||
args: Vec<datafusion_expr::Expr>,
|
||||
) -> DfResult<Arc<dyn TableSource>> {
|
||||
let tbl_func = self
|
||||
.session_state
|
||||
.table_functions()
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| DataFusionError::Plan(format!("table function '{name}' not found")))?;
|
||||
let provider = tbl_func.create_table_provider(&args)?;
|
||||
if let Some(tbl_func) = self.engine_state.table_function(name) {
|
||||
let provider = tbl_func.create_table_provider(&args)?;
|
||||
Ok(provider_as_source(provider))
|
||||
} else {
|
||||
let tbl_func = self
|
||||
.session_state
|
||||
.table_functions()
|
||||
.get(name)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
DataFusionError::Plan(format!("table function '{name}' not found"))
|
||||
})?;
|
||||
let provider = tbl_func.create_table_provider(&args)?;
|
||||
|
||||
Ok(provider_as_source(provider))
|
||||
Ok(provider_as_source(provider))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_cte_work_table(
|
||||
|
||||
@@ -28,6 +28,7 @@ use common_function::handlers::{
|
||||
FlowServiceHandlerRef, ProcedureServiceHandlerRef, TableMutationHandlerRef,
|
||||
};
|
||||
use common_query::Output;
|
||||
use datafusion::catalog::TableFunction;
|
||||
use datafusion_expr::{AggregateUDF, LogicalPlan};
|
||||
use datatypes::schema::Schema;
|
||||
pub use default_serializer::{DefaultPlanDecoder, DefaultSerializer};
|
||||
@@ -85,6 +86,9 @@ pub trait QueryEngine: Send + Sync {
|
||||
/// Will override if the function with same name is already registered.
|
||||
fn register_scalar_function(&self, func: ScalarFunctionFactory);
|
||||
|
||||
/// Register table function
|
||||
fn register_table_function(&self, func: Arc<TableFunction>);
|
||||
|
||||
/// Create a DataFrame from a table.
|
||||
fn read_table(&self, table: TableRef) -> Result<DataFrame>;
|
||||
|
||||
@@ -164,6 +168,10 @@ fn register_functions(query_engine: &Arc<DatafusionQueryEngine>) {
|
||||
for accumulator in FUNCTION_REGISTRY.aggregate_functions() {
|
||||
query_engine.register_aggregate_function(accumulator);
|
||||
}
|
||||
|
||||
for table_function in FUNCTION_REGISTRY.table_functions() {
|
||||
query_engine.register_table_function(table_function);
|
||||
}
|
||||
}
|
||||
|
||||
pub type QueryEngineRef = Arc<dyn QueryEngine>;
|
||||
|
||||
@@ -25,6 +25,7 @@ use common_function::handlers::{
|
||||
};
|
||||
use common_function::state::FunctionState;
|
||||
use common_telemetry::warn;
|
||||
use datafusion::catalog::TableFunction;
|
||||
use datafusion::dataframe::DataFrame;
|
||||
use datafusion::error::Result as DfResult;
|
||||
use datafusion::execution::SessionStateBuilder;
|
||||
@@ -72,6 +73,7 @@ pub struct QueryEngineState {
|
||||
function_state: Arc<FunctionState>,
|
||||
scalar_functions: Arc<RwLock<HashMap<String, ScalarFunctionFactory>>>,
|
||||
aggr_functions: Arc<RwLock<HashMap<String, AggregateUDF>>>,
|
||||
table_functions: Arc<RwLock<HashMap<String, Arc<TableFunction>>>>,
|
||||
extension_rules: Vec<Arc<dyn ExtensionAnalyzerRule + Send + Sync>>,
|
||||
plugins: Plugins,
|
||||
}
|
||||
@@ -196,6 +198,7 @@ impl QueryEngineState {
|
||||
flow_service_handler,
|
||||
}),
|
||||
aggr_functions: Arc::new(RwLock::new(HashMap::new())),
|
||||
table_functions: Arc::new(RwLock::new(HashMap::new())),
|
||||
extension_rules,
|
||||
plugins,
|
||||
scalar_functions: Arc::new(RwLock::new(HashMap::new())),
|
||||
@@ -265,6 +268,25 @@ impl QueryEngineState {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Retrieve table function by name
|
||||
pub fn table_function(&self, function_name: &str) -> Option<Arc<TableFunction>> {
|
||||
self.table_functions
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(function_name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Retrieve table function names.
|
||||
pub fn table_function_names(&self) -> Vec<String> {
|
||||
self.table_functions
|
||||
.read()
|
||||
.unwrap()
|
||||
.keys()
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Register an scalar function.
|
||||
/// Will override if the function with same name is already registered.
|
||||
pub fn register_scalar_function(&self, func: ScalarFunctionFactory) {
|
||||
@@ -301,6 +323,19 @@ impl QueryEngineState {
|
||||
);
|
||||
}
|
||||
|
||||
pub fn register_table_function(&self, func: Arc<TableFunction>) {
|
||||
let name = func.name();
|
||||
let x = self
|
||||
.table_functions
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(name.to_string(), func.clone());
|
||||
|
||||
if x.is_some() {
|
||||
warn!("Already registered table function '{name}");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn catalog_manager(&self) -> &CatalogManagerRef {
|
||||
&self.catalog_manager
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user