refactor: refactor admin functions with async udf (#6770)

* refactor: use async udf for admin functions

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: sqlness test

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: code style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: clippy

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: remove unused error

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: code style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: apply suggestions

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: logical_metric_table test

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
This commit is contained in:
dennis zhuang
2025-08-20 11:35:38 +08:00
committed by GitHub
parent 7402320abc
commit d6bc117408
27 changed files with 1076 additions and 487 deletions

View File

@@ -21,8 +21,6 @@ mod reconcile_database;
mod reconcile_table;
mod remove_region_follower;
use std::sync::Arc;
use add_region_follower::AddRegionFollowerFunction;
use flush_compact_region::{CompactRegionFunction, FlushRegionFunction};
use flush_compact_table::{CompactTableFunction, FlushTableFunction};
@@ -35,22 +33,22 @@ use remove_region_follower::RemoveRegionFollowerFunction;
use crate::flush_flow::FlushFlowFunction;
use crate::function_registry::FunctionRegistry;
/// Table functions
/// Administration functions
pub(crate) struct AdminFunction;
impl AdminFunction {
/// Register all table functions to [`FunctionRegistry`].
/// Register all admin functions to [`FunctionRegistry`].
pub fn register(registry: &FunctionRegistry) {
registry.register_async(Arc::new(MigrateRegionFunction));
registry.register_async(Arc::new(AddRegionFollowerFunction));
registry.register_async(Arc::new(RemoveRegionFollowerFunction));
registry.register_async(Arc::new(FlushRegionFunction));
registry.register_async(Arc::new(CompactRegionFunction));
registry.register_async(Arc::new(FlushTableFunction));
registry.register_async(Arc::new(CompactTableFunction));
registry.register_async(Arc::new(FlushFlowFunction));
registry.register_async(Arc::new(ReconcileCatalogFunction));
registry.register_async(Arc::new(ReconcileDatabaseFunction));
registry.register_async(Arc::new(ReconcileTableFunction));
registry.register(MigrateRegionFunction::factory());
registry.register(AddRegionFollowerFunction::factory());
registry.register(RemoveRegionFollowerFunction::factory());
registry.register(FlushRegionFunction::factory());
registry.register(CompactRegionFunction::factory());
registry.register(FlushTableFunction::factory());
registry.register(CompactTableFunction::factory());
registry.register(FlushFlowFunction::factory());
registry.register(ReconcileCatalogFunction::factory());
registry.register(ReconcileDatabaseFunction::factory());
registry.register(ReconcileTableFunction::factory());
}
}

View File

@@ -18,7 +18,8 @@ use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -82,7 +83,13 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// add_region_follower(region_id, peer)
TypeSignature::Uniform(2, ConcreteDataType::numerics()),
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
],
Volatility::Immutable,
)
@@ -92,38 +99,57 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{UInt64Vector, VectorRef};
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
#[test]
fn test_add_region_follower_misc() {
let f = AddRegionFollowerFunction;
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("add_region_follower", f.name());
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
} if sigs.len() == 1));
}
#[tokio::test]
async fn test_add_region_follower() {
let f = AddRegionFollowerFunction;
let args = vec![1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([0u64]));
assert_eq!(result, expect);
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![2]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
}
}

View File

@@ -16,7 +16,8 @@ use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingTableMutationHandlerSnafu, Result, UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, Volatility};
use datafusion_expr::{Signature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use session::context::QueryContextRef;
use snafu::ensure;
@@ -66,71 +67,99 @@ define_region_function!(FlushRegionFunction, flush_region, flush_region);
define_region_function!(CompactRegionFunction, compact_region, compact_region);
fn signature() -> Signature {
Signature::uniform(1, ConcreteDataType::numerics(), Volatility::Immutable)
Signature::uniform(
1,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
Volatility::Immutable,
)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::UInt64Vector;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
macro_rules! define_region_function_test {
($name: ident, $func: ident) => {
paste::paste! {
#[test]
fn [<test_ $name _misc>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!(stringify!($name), f.name());
assert_eq!(
ConcreteDataType::uint64_datatype(),
DataType::UInt64,
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == ConcreteDataType::numerics()));
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &ConcreteDataType::numerics().into_iter().map(|dt| { use datatypes::data_type::DataType; dt.as_arrow_type() }).collect::<Vec<_>>()));
}
#[tokio::test]
async fn [<test_ $name _missing_table_mutation>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let args = vec![99];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![99]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
assert_eq!(
"Missing TableMutationHandler, not expected",
"Execution error: Handler error: Missing TableMutationHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn [<test_ $name>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![99]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = vec![99];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([42]));
assert_eq!(expect, result);
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 42u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(42)));
}
}
}
}
};

View File

@@ -15,14 +15,15 @@
use std::str::FromStr;
use api::v1::region::{compact_request, StrictWindow};
use arrow::datatypes::DataType as ArrowDataType;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingTableMutationHandlerSnafu, Result, TableMutationSnafu,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::*;
use session::context::QueryContextRef;
use session::table_name::table_name_to_full_name;
@@ -105,18 +106,11 @@ pub(crate) async fn compact_table(
}
fn flush_signature() -> Signature {
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
}
fn compact_signature() -> Signature {
Signature::variadic(
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
Signature::variadic(vec![ArrowDataType::Utf8], Volatility::Immutable)
}
/// Parses `compact_table` UDF parameters. This function accepts following combinations:
@@ -204,66 +198,87 @@ mod tests {
use std::sync::Arc;
use api::v1::region::compact_request::Options;
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::prelude::TypeSignature;
use datatypes::vectors::{StringVector, UInt64Vector};
use datafusion_expr::ColumnarValue;
use session::context::QueryContext;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
macro_rules! define_table_function_test {
($name: ident, $func: ident) => {
paste::paste!{
#[test]
fn [<test_ $name _misc>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!(stringify!($name), f.name());
assert_eq!(
ConcreteDataType::uint64_datatype(),
DataType::UInt64,
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == vec![ConcreteDataType::string_datatype()]));
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &vec![ArrowDataType::Utf8]));
}
#[tokio::test]
async fn [<test_ $name _missing_table_mutation>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let args = vec!["test"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from(vec![arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
assert_eq!(
"Missing TableMutationHandler, not expected",
"Execution error: Handler error: Missing TableMutationHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn [<test_ $name>]() {
let f = $func;
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = vec!["test"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from(vec![arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([42]));
assert_eq!(expect, result);
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<arrow::array::UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 42u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(42)));
}
}
}
}
}

View File

@@ -17,7 +17,8 @@ use std::time::Duration;
use common_macro::admin_fn;
use common_meta::rpc::procedure::MigrateRegionRequest;
use common_query::error::{InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -103,9 +104,21 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// migrate_region(region_id, from_peer, to_peer)
TypeSignature::Uniform(3, ConcreteDataType::numerics()),
TypeSignature::Uniform(
3,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
// migrate_region(region_id, from_peer, to_peer, timeout(secs))
TypeSignature::Uniform(4, ConcreteDataType::numerics()),
TypeSignature::Uniform(
4,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
],
Volatility::Immutable,
)
@@ -115,59 +128,89 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
use arrow::array::{StringArray, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
#[test]
fn test_migrate_region_misc() {
let f = MigrateRegionFunction;
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("migrate_region", f.name());
assert_eq!(
ConcreteDataType::string_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
} if sigs.len() == 2));
}
#[tokio::test]
async fn test_missing_procedure_service() {
let f = MigrateRegionFunction;
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let args = vec![1, 1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
assert_eq!(
"Missing ProcedureServiceHandler, not expected",
"Execution error: Handler error: Missing ProcedureServiceHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn test_migrate_region() {
let f = MigrateRegionFunction;
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let args = vec![1, 1, 1];
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
}
}

View File

@@ -14,13 +14,15 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileCatalog, ReconcileRequest};
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use session::context::QueryContextRef;
@@ -104,15 +106,15 @@ fn signature() -> Signature {
let mut signs = Vec::with_capacity(2 + nums.len());
signs.extend([
// reconcile_catalog()
TypeSignature::NullAry,
TypeSignature::Nullary,
// reconcile_catalog(resolve_strategy)
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
]);
for sign in nums {
// reconcile_catalog(resolve_strategy, parallelism)
signs.push(TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
sign,
ArrowDataType::Utf8,
sign.as_arrow_type(),
]));
}
Signature::one_of(signs, Volatility::Immutable)
@@ -120,60 +122,149 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
use arrow::array::{StringArray, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use crate::admin::reconcile_catalog::ReconcileCatalogFunction;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
#[tokio::test]
async fn test_reconcile_catalog() {
common_telemetry::init_default_ut_logging();
// reconcile_catalog()
let f = ReconcileCatalogFunction;
let args = vec![];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// reconcile_catalog(resolve_strategy)
let f = ReconcileCatalogFunction;
let args = vec![Arc::new(StringVector::from(vec!["UseMetasrv"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"UseMetasrv",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// reconcile_catalog(resolve_strategy, parallelism)
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt64Vector::from_slice([10])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![10]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// unsupported input data type
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(StringVector::from(vec!["test"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
// invalid function args
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt64Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["10"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::InvalidFuncArgs { .. });
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["10"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
}
}

View File

@@ -14,13 +14,15 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileDatabase, ReconcileRequest};
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use session::context::QueryContextRef;
@@ -113,19 +115,16 @@ fn signature() -> Signature {
let mut signs = Vec::with_capacity(2 + nums.len());
signs.extend([
// reconcile_database(database_name)
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
// reconcile_database(database_name, resolve_strategy)
TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
]),
TypeSignature::Exact(vec![ArrowDataType::Utf8, ArrowDataType::Utf8]),
]);
for sign in nums {
// reconcile_database(database_name, resolve_strategy, parallelism)
signs.push(TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
sign,
ArrowDataType::Utf8,
ArrowDataType::Utf8,
sign.as_arrow_type(),
]));
}
Signature::one_of(signs, Volatility::Immutable)
@@ -133,66 +132,160 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::{StringVector, UInt32Vector, VectorRef};
use arrow::array::{StringArray, UInt32Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use crate::admin::reconcile_database::ReconcileDatabaseFunction;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
// Exercises `ReconcileDatabaseFunction` with one, two and three arguments, and
// then with two argument lists that are required to fail.
// NOTE(review): the test is named `test_reconcile_catalog` although every case
// targets `ReconcileDatabaseFunction` — confirm whether a rename is intended.
// NOTE(review): every case appears twice, once through `f.eval(..)` on the
// function struct and once through `invoke_async_with_args` on the UDF obtained
// from `ScalarFunctionFactory`; this looks like both sides of a diff shown
// together — confirm against the real file before relying on it.
#[tokio::test]
async fn test_reconcile_catalog() {
common_telemetry::init_default_ut_logging();
// reconcile_database(database_name)
let f = ReconcileDatabaseFunction;
let args = vec![Arc::new(StringVector::from(vec!["test"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// Factory path: build the UDF with a mocked context and take its async view.
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
// One single-row Utf8 argument; `arg_fields` carries one field per argument.
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"test",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
// The result may be an array or a scalar; either way it must be "test_pid".
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// reconcile_database(database_name, resolve_strategy)
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// Factory path with two Utf8 arguments.
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// reconcile_database(database_name, resolve_strategy, parallelism)
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// Factory path with Utf8, Utf8 and UInt32 arguments.
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
Arc::new(Field::new("arg_2", DataType::UInt32, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// invalid function args
// Four arguments are supplied; the direct `eval` call asserts `InvalidFuncArgs`.
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["v1"])) as _,
Arc::new(StringVector::from(vec!["v2"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::InvalidFuncArgs { .. });
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v1"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v2"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt32, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
Arc::new(Field::new("arg_3", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
// Only failure is required here; the error value itself is not inspected.
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
// unsupported input data type
// Three arguments typed Utf8, UInt32, Utf8; the direct `eval` call asserts
// `UnsupportedInputDataType`.
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["v1"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v1"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt32, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
// Only failure is required here; the error value itself is not inspected.
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
}
}

View File

@@ -14,14 +14,15 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileRequest, ReconcileTable, ResolveStrategy};
use arrow::datatypes::DataType as ArrowDataType;
use common_catalog::format_full_table_name;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
MissingProcedureServiceHandlerSnafu, Result, TableMutationSnafu, UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::prelude::*;
use session::context::QueryContextRef;
use session::table_name::table_name_to_full_name;
@@ -93,12 +94,9 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// reconcile_table(table_name)
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
// reconcile_table(table_name, resolve_strategy)
TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
]),
TypeSignature::Exact(vec![ArrowDataType::Utf8, ArrowDataType::Utf8]),
],
Volatility::Immutable,
)
@@ -106,44 +104,101 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::{StringVector, VectorRef};
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use crate::admin::reconcile_table::ReconcileTableFunction;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
// Exercises `ReconcileTableFunction` with one and two arguments, then with a
// three-argument list that is required to fail.
// NOTE(review): every case appears twice, once through `f.eval(..)` on the
// function struct and once through `invoke_async_with_args` on the UDF obtained
// from `ScalarFunctionFactory`; this looks like both sides of a diff shown
// together — confirm against the real file before relying on it.
#[tokio::test]
async fn test_reconcile_table() {
common_telemetry::init_default_ut_logging();
// reconcile_table(table_name)
let f = ReconcileTableFunction;
let args = vec![Arc::new(StringVector::from(vec!["test"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// Factory path: build the UDF with a mocked context and take its async view.
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
// One single-row Utf8 argument; `arg_fields` carries one field per argument.
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"test",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
// The result may be an array or a scalar; either way it must be "test_pid".
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// reconcile_table(table_name, resolve_strategy)
let f = ReconcileTableFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseMetasrv"])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// Factory path with two Utf8 arguments.
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseMetasrv"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
// unsupported input data type
// Three Utf8 arguments are supplied; the direct `eval` call asserts
// `UnsupportedInputDataType`.
let f = ReconcileTableFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseMetasrv"])) as _,
Arc::new(StringVector::from(vec!["10"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseMetasrv"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["10"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
// Only failure is required here; the error value itself is not inspected.
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
}
}

View File

@@ -18,7 +18,8 @@ use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -82,7 +83,13 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// remove_region_follower(region_id, peer_id)
TypeSignature::Uniform(2, ConcreteDataType::numerics()),
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
],
Volatility::Immutable,
)
@@ -92,38 +99,57 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{UInt64Vector, VectorRef};
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
// Checks the metadata of `RemoveRegionFollowerFunction`: its name, its return
// type, and that its signature is a `OneOf` with exactly one alternative and
// `Immutable` volatility.
// NOTE(review): the binding of `f`, the return-type assertion and the
// `matches!` pattern each appear in two variants (function struct with
// `ConcreteDataType` vs. factory-provided UDF with arrow `DataType`); this looks
// like both sides of a diff shown together — confirm against the real file.
#[test]
fn test_remove_region_follower_misc() {
let f = RemoveRegionFollowerFunction;
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("remove_region_follower", f.name());
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
} if sigs.len() == 1));
}
// Calls `RemoveRegionFollowerFunction` with two single-row UInt64 arguments
// (both 1) under a mocked context and expects the value 0 back.
// NOTE(review): the call appears twice, once through `f.eval(..)` with
// `UInt64Vector` inputs and once through `invoke_async_with_args` with arrow
// `UInt64Array` inputs; this looks like both sides of a diff shown together —
// confirm against the real file.
#[tokio::test]
async fn test_remove_region_follower() {
let f = RemoveRegionFollowerFunction;
let args = vec![1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
// Factory path: build the UDF with a mocked context and take its async view.
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([0u64]));
assert_eq!(result, expect);
// `arg_fields` carries one field per argument; the return field is UInt64.
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
// The result may be an array or a scalar; either way it must be 0.
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
}
}

View File

@@ -12,29 +12,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::datatypes::DataType as ArrowDataType;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
ExecuteSnafu, InvalidFuncArgsSnafu, MissingFlowServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::Signature;
use datafusion::logical_expr::Volatility;
use datafusion_expr::{Signature, Volatility};
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
use snafu::{ensure, ResultExt};
use sql::ast::ObjectNamePartExt;
use sql::parser::ParserContext;
use store_api::storage::ConcreteDataType;
use crate::handlers::FlowServiceHandlerRef;
/// Signature of the flush-flow function: exactly one argument of string type,
/// with `Immutable` volatility.
// NOTE(review): two `Signature::uniform` expressions follow, one built from
// `ConcreteDataType::string_datatype()` and one from `ArrowDataType::Utf8`; this
// looks like both sides of a diff shown together — confirm which one the real
// file keeps.
fn flush_signature() -> Signature {
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
}
#[admin_fn(
@@ -106,44 +101,55 @@ fn parse_flush_flow(
mod test {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::StringVector;
use session::context::QueryContext;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
// Checks the metadata of `FlushFlowFunction`: its name ("flush_flow"), its
// return type (UInt64) and its signature (uniform, one string argument,
// `Immutable`).
// NOTE(review): the binding of `f` and both assertions appear in two variants
// (`ConcreteDataType` based vs. arrow/DataFusion based); this looks like both
// sides of a diff shown together — confirm against the real file.
#[test]
fn test_flush_flow_metadata() {
let f = FlushFlowFunction;
let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("flush_flow", f.name());
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(
f.signature(),
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
assert_eq!(ArrowDataType::UInt64, f.return_type(&[]).unwrap());
let expected_signature = datafusion_expr::Signature::uniform(
1,
vec![ArrowDataType::Utf8],
datafusion_expr::Volatility::Immutable,
);
// The factory-provided `signature()` is dereferenced before comparison.
assert_eq!(*f.signature(), expected_signature);
}
// Calls `FlushFlowFunction` with `FunctionContext::default()` instead of
// `FunctionContext::mock()` and expects the call to fail with a
// "Missing FlowServiceHandler" message.
// NOTE(review): the invocation and the expected message each appear in two
// variants (`f.eval(..)` with the bare message vs. `invoke_async_with_args` with
// the "Execution error: Handler error: " prefix); this looks like both sides of
// a diff shown together — confirm against the real file.
#[tokio::test]
async fn test_missing_flow_service() {
let f = FlushFlowFunction;
let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
let binding = factory.provide(FunctionContext::default());
let f = binding.as_async().unwrap();
let args = vec!["flow_name"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let flow_name_array = Arc::new(arrow::array::StringArray::from(vec!["flow_name"]));
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
// One single-row Utf8 argument; the declared return field is UInt64.
let columnar_args = vec![datafusion_expr::ColumnarValue::Array(flow_name_array as _)];
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: columnar_args,
arg_fields: vec![Arc::new(arrow::datatypes::Field::new(
"arg_0",
ArrowDataType::Utf8,
false,
))],
return_field: Arc::new(arrow::datatypes::Field::new(
"result",
ArrowDataType::UInt64,
true,
)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
// The error is compared by its `to_string()` rendering.
assert_eq!(
"Missing FlowServiceHandler, not expected",
"Execution error: Handler error: Missing FlowServiceHandler, not expected",
result.to_string()
);
}

View File

@@ -41,6 +41,12 @@ impl FunctionContext {
}
}
impl std::fmt::Display for FunctionContext {
    /// Renders the context as `FunctionContext { query_ctx: <query_ctx> }`,
    /// using the `Display` output of the `query_ctx` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let query_ctx = &self.query_ctx;
        f.write_fmt(format_args!(
            "FunctionContext {{ query_ctx: {} }}",
            query_ctx
        ))
    }
}
impl Default for FunctionContext {
fn default() -> Self {
Self {
@@ -67,22 +73,3 @@ pub trait Function: fmt::Display + Sync + Send {
}
pub type FunctionRef = Arc<dyn Function>;
/// Async Scalar function trait
///
/// Implementors must also be `Display`, `Sync` and `Send`. The trait is declared
/// through `async_trait` so that `eval` can be an `async fn`.
#[async_trait::async_trait]
pub trait AsyncFunction: fmt::Display + Sync + Send {
/// Returns the name of the function, should be unique.
fn name(&self) -> &str;
/// The returned data type of function execution.
///
/// `input_types` holds the data types of the call's arguments.
fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType>;
/// The signature of function.
fn signature(&self) -> Signature;
/// Evaluate the function, e.g. run/execute the function.
///
/// Takes the function context by value and the argument columns as a slice of
/// vectors, and resolves to a single result vector or an error.
/// TODO(dennis): simplify the signature and refactor all the admin functions.
async fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef>;
}
/// Shared, reference-counted handle to an [`AsyncFunction`] trait object.
pub type AsyncFunctionRef = Arc<dyn AsyncFunction>;

View File

@@ -22,8 +22,8 @@ use crate::scalars::udf::create_udf;
/// A factory for creating `ScalarUDF` that require a function context.
///
/// Cloning is cheap for the `factory` field because the closure sits behind an `Arc`.
// NOTE(review): each field is listed twice below (private and `pub(crate)`); this
// looks like both sides of a diff shown together — confirm which visibility the
// real file keeps.
#[derive(Clone)]
pub struct ScalarFunctionFactory {
// Function name.
name: String,
// Closure that builds a `ScalarUDF` from a `FunctionContext`.
factory: Arc<dyn Fn(FunctionContext) -> ScalarUDF + Send + Sync>,
pub(crate) name: String,
pub(crate) factory: Arc<dyn Fn(FunctionContext) -> ScalarUDF + Send + Sync>,
}
impl ScalarFunctionFactory {

View File

@@ -24,7 +24,7 @@ use crate::aggrs::aggr_wrapper::StateMergeHelper;
use crate::aggrs::approximate::ApproximateFunction;
use crate::aggrs::count_hash::CountHash;
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
use crate::function::{AsyncFunctionRef, Function, FunctionRef};
use crate::function::{Function, FunctionRef};
use crate::function_factory::ScalarFunctionFactory;
use crate::scalars::date::DateFunction;
use crate::scalars::expression::ExpressionFunction;
@@ -42,11 +42,18 @@ use crate::system::SystemFunction;
/// Registry of functions, held in `String`-keyed maps that are each guarded by
/// their own `RwLock`. `Default` yields empty maps.
// NOTE(review): `async_functions` may be a removed field shown next to the
// surviving ones in this diff view — confirm against the real file.
#[derive(Default)]
pub struct FunctionRegistry {
// Scalar function factories.
functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
// Async function trait objects.
async_functions: RwLock<HashMap<String, AsyncFunctionRef>>,
// Aggregate UDFs.
aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
}
impl FunctionRegistry {
/// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
///
/// # Arguments
///
/// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
///
/// The function is inserted into the internal function map, keyed by its name.
/// If a function with the same name already exists, it will be replaced.
pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
let func = func.into();
let _ = self
@@ -56,18 +63,12 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
/// Register a scalar function in the registry.
///
/// The function is wrapped in an `Arc`, coerced to a [`FunctionRef`] and passed
/// on to [`Self::register`].
pub fn register_scalar(&self, func: impl Function + 'static) {
    let func: FunctionRef = Arc::new(func);
    self.register(func);
}
/// Register an async function under the name it reports via `name()`.
///
/// Takes the write lock on `async_functions`; the value returned by `insert`
/// (any entry previously stored under the same name) is discarded.
pub fn register_async(&self, func: AsyncFunctionRef) {
let _ = self
.async_functions
.write()
.unwrap()
.insert(func.name().to_string(), func);
}
/// Register an aggregate function in the registry.
pub fn register_aggr(&self, func: AggregateUDF) {
let _ = self
.aggregate_functions
@@ -76,28 +77,16 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
/// Looks up an async function by name and returns a clone of its handle, or
/// `None` when no entry exists under `name`.
pub fn get_async_function(&self, name: &str) -> Option<AsyncFunctionRef> {
self.async_functions.read().unwrap().get(name).cloned()
}
/// Returns clones of all async function handles currently stored in the registry.
pub fn async_functions(&self) -> Vec<AsyncFunctionRef> {
self.async_functions
.read()
.unwrap()
.values()
.cloned()
.collect()
}
/// Test-only lookup: returns a clone of the scalar function factory stored under
/// `name`, or `None` when there is no such entry.
#[cfg(test)]
pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
    let functions = self.functions.read().unwrap();
    functions.get(name).cloned()
}
/// Returns a list of all scalar functions registered in the registry.
///
/// Each factory is cloned out of the map while the read lock is held.
pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
    let functions = self.functions.read().unwrap();
    functions.values().cloned().collect()
}
/// Returns a list of all aggregate functions registered in the registry.
pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
self.aggregate_functions
.read()
@@ -107,6 +96,7 @@ impl FunctionRegistry {
.collect()
}
/// Returns true if an aggregate function with the given name exists in the registry.
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
    let aggregates = self.aggregate_functions.read().unwrap();
    aggregates.contains_key(name)
}

View File

@@ -19,8 +19,6 @@ mod procedure_state;
mod timezone;
mod version;
use std::sync::Arc;
use build::BuildFunction;
use database::{
ConnectionIdFunction, CurrentSchemaFunction, DatabaseFunction, PgBackendPidFunction,
@@ -46,7 +44,7 @@ impl SystemFunction {
registry.register_scalar(PgBackendPidFunction);
registry.register_scalar(ConnectionIdFunction);
registry.register_scalar(TimezoneFunction);
registry.register_async(Arc::new(ProcedureStateFunction));
registry.register(ProcedureStateFunction::factory());
PGCatalogFunction::register(registry);
}
}

View File

@@ -13,13 +13,14 @@
// limitations under the License.
use api::v1::meta::ProcedureStatus;
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_meta::rpc::procedure::ProcedureStateResponse;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, Volatility};
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::*;
use serde::Serialize;
use session::context::QueryContextRef;
@@ -81,73 +82,86 @@ pub(crate) async fn procedure_state(
}
/// Signature of the procedure-state function: exactly one argument of string
/// type, with `Immutable` volatility.
// NOTE(review): two `Signature::uniform` expressions follow, one built from
// `ConcreteDataType::string_datatype()` and one from `ArrowDataType::Utf8`; this
// looks like both sides of a diff shown together — confirm which one the real
// file keeps.
fn signature() -> Signature {
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::StringVector;
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::{AsyncFunction, FunctionContext};
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
// Checks the metadata of `ProcedureStateFunction`: its name ("procedure_state"),
// its return type (string) and that its signature is `Uniform(1, [string])` with
// `Immutable` volatility.
// NOTE(review): the binding of `f`, the return-type assertion and the
// `matches!` pattern each appear in two variants (`ConcreteDataType` based vs.
// arrow/DataFusion based); this looks like both sides of a diff shown together —
// confirm against the real file.
#[test]
fn test_procedure_state_misc() {
let f = ProcedureStateFunction;
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("procedure_state", f.name());
assert_eq!(
ConcreteDataType::string_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == vec![ConcreteDataType::string_datatype()]
));
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &vec![ArrowDataType::Utf8]));
}
// Calls `ProcedureStateFunction` with `FunctionContext::default()` instead of
// `FunctionContext::mock()` and expects the call to fail.
// NOTE(review): two variants are shown — `f.eval(..)` asserting the exact
// message "Missing ProcedureServiceHandler, not expected", and
// `invoke_async_with_args` asserting only `is_err()`. This looks like both sides
// of a diff shown together; if the second is the surviving one, the error
// message is no longer checked — consider asserting it.
#[tokio::test]
async fn test_missing_procedure_service() {
let f = ProcedureStateFunction;
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let binding = factory.provide(FunctionContext::default());
let f = binding.as_async().unwrap();
let args = vec!["pid"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Missing ProcedureServiceHandler, not expected",
result.to_string()
);
// One single-row Utf8 argument; `arg_fields` carries one field per argument.
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"pid",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await;
assert!(result.is_err());
}
// Calls `ProcedureStateFunction` with the single string argument "pid" under a
// mocked context and expects the JSON text {"status":"Done","error":"OK"}.
// NOTE(review): the call and the result check each appear in two variants
// (`f.eval(..)` compared against a `StringVector` vs. `invoke_async_with_args`
// matched on `ColumnarValue`); this looks like both sides of a diff shown
// together — confirm against the real file.
#[tokio::test]
async fn test_procedure_state() {
let f = ProcedureStateFunction;
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let args = vec!["pid"];
// One single-row Utf8 argument; `arg_fields` carries one field per argument.
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"pid",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![
"{\"status\":\"Done\",\"error\":\"OK\"}",
]));
assert_eq!(expect, result);
// The result may be an array or a scalar; both shapes are checked against the
// same JSON text.
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(
result_array.value(0),
"{\"status\":\"Done\",\"error\":\"OK\"}"
);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some(
"{\"status\":\"Done\",\"error\":\"OK\"}".to_string()
))
);
}
}
}
}