From 3d195ff858826d4c61adf020cd41dfb5fee18607 Mon Sep 17 00:00:00 2001
From: discord9 <55937128+discord9@users.noreply.github.com>
Date: Fri, 4 Nov 2022 15:49:41 +0800
Subject: [PATCH] feat: bind Greptime's own UDF&UDAF into Python Coprocessor Module (#335)

* feat: port own UDF&UDAF into py copr (untested yet)
* refactor: move UDF&UDAF to greptime_builtins
* feat: support List in val2py_obj
* test: some testcases for newly added UDFs
* test: complete tests for all added GreptimeDB UDFs
* refactor: add underscore for long func name
* feat: better error message
* fix: typo
---
 .../function/src/scalars/function_registry.rs |   4 +
 src/script/src/python/builtins/mod.rs         | 171 ++++++++++++++--
 src/script/src/python/builtins/testcases.ron  | 193 ++++++++++++++++++
 src/script/src/python/vector.rs               |  11 +-
 4 files changed, 362 insertions(+), 17 deletions(-)

diff --git a/src/common/function/src/scalars/function_registry.rs b/src/common/function/src/scalars/function_registry.rs
index 0de2935592..fb2a99ef02 100644
--- a/src/common/function/src/scalars/function_registry.rs
+++ b/src/common/function/src/scalars/function_registry.rs
@@ -31,6 +31,10 @@ impl FunctionRegistry {
             .insert(func.name(), func);
     }
 
+    pub fn get_aggr_function(&self, name: &str) -> Option<AggregateFunctionMetaRef> {
+        self.aggregate_functions.read().unwrap().get(name).cloned()
+    }
+
     pub fn get_function(&self, name: &str) -> Option<FunctionRef> {
         self.functions.read().unwrap().get(name).cloned()
     }
diff --git a/src/script/src/python/builtins/mod.rs b/src/script/src/python/builtins/mod.rs
index 75e05c3bdd..e077d0031a 100644
--- a/src/script/src/python/builtins/mod.rs
+++ b/src/script/src/python/builtins/mod.rs
@@ -270,31 +270,45 @@ pub(crate) mod greptime_builtin {
     // P.S.: not extract to file because not-inlined proc macro attribute is *unstable*
     use std::sync::Arc;
 
-    use common_function::scalars::math::PowFunction;
-    use common_function::scalars::{function::FunctionContext, Function};
-    use datafusion::arrow::compute::comparison::{gt_eq_scalar, lt_eq_scalar};
-    use datafusion::arrow::datatypes::DataType;
-    use datafusion::arrow::error::ArrowError;
-    use datafusion::arrow::scalar::{PrimitiveScalar, Scalar};
-    use datafusion::physical_plan::expressions;
+    use common_function::scalars::{
+        function::FunctionContext, math::PowFunction, Function, FunctionRef, FUNCTION_REGISTRY,
+    };
+    use datafusion::{
+        arrow::{
+            compute::comparison::{gt_eq_scalar, lt_eq_scalar},
+            datatypes::DataType,
+            error::ArrowError,
+            scalar::{PrimitiveScalar, Scalar},
+        },
+        physical_plan::expressions,
+    };
     use datafusion_expr::ColumnarValue as DFColValue;
     use datafusion_physical_expr::math_expressions;
-    use datatypes::arrow;
-    use datatypes::arrow::array::{ArrayRef, NullArray};
-    use datatypes::arrow::compute;
     use datatypes::vectors::{ConstantVector, Float64Vector, Helper, Int64Vector};
+    use datatypes::{
+        arrow::{
+            self,
+            array::{ArrayRef, NullArray},
+            compute,
+        },
+        vectors::VectorRef,
+    };
     use paste::paste;
-    use rustpython_vm::builtins::{PyFloat, PyFunction, PyInt, PyStr};
-    use rustpython_vm::function::{FuncArgs, KwArgs, OptionalArg};
-    use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine};
+    use rustpython_vm::{
+        builtins::{PyFloat, PyFunction, PyInt, PyStr},
+        function::{FuncArgs, KwArgs, OptionalArg},
+        AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
+    };
 
     use crate::python::builtins::{
         all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj,
         type_cast_error,
     };
-    use crate::python::utils::PyVectorRef;
-    use crate::python::utils::{is_instance, py_vec_obj_to_array};
-    use crate::python::PyVector;
+    use crate::python::{
+        utils::{is_instance, py_vec_obj_to_array, PyVectorRef},
+        vector::val_to_pyobj,
+        PyVector,
+    };
 
     #[pyfunction]
     fn vector(args: OptionalArg, vm: &VirtualMachine) -> PyResult {
@@ -303,10 +317,135 @@ pub(crate) mod greptime_builtin {
 
     // the main binding code, due to proc macro things, can't directly use a simpler macro
     // because pyfunction is not a attr?
+    // ------
+    // GreptimeDB's own UDF&UDAF
+    // ------
+
+    fn eval_func(name: &str, v: &[PyVectorRef], vm: &VirtualMachine) -> PyResult<PyVector> {
+        let v: Vec<VectorRef> = v.iter().map(|v| v.as_vector_ref()).collect();
+        let func: Option<FunctionRef> = FUNCTION_REGISTRY.get_function(name);
+        let res = match func {
+            Some(f) => f.eval(Default::default(), &v),
+            None => return Err(vm.new_type_error(format!("Can't find function {}", name))),
+        };
+        match res {
+            Ok(v) => Ok(v.into()),
+            Err(err) => {
+                Err(vm.new_runtime_error(format!("Failed to evaluate the function: {}", err)))
+            }
+        }
+    }
+
+    fn eval_aggr_func(
+        name: &str,
+        args: &[PyVectorRef],
+        vm: &VirtualMachine,
+    ) -> PyResult<PyObjectRef> {
+        let v: Vec<VectorRef> = args.iter().map(|v| v.as_vector_ref()).collect();
+        let func = FUNCTION_REGISTRY.get_aggr_function(name);
+        let f = match func {
+            Some(f) => f.create().creator(),
+            None => return Err(vm.new_type_error(format!("Can't find function {}", name))),
+        };
+        let types: Vec<_> = v.iter().map(|v| v.data_type()).collect();
+        let acc = f(&types);
+        let mut acc = match acc {
+            Ok(acc) => acc,
+            Err(err) => {
+                return Err(vm.new_runtime_error(format!("Failed to create accumulator: {}", err)))
+            }
+        };
+        match acc.update_batch(&v) {
+            Ok(_) => (),
+            Err(err) => {
+                return Err(vm.new_runtime_error(format!("Failed to update batch: {}", err)))
+            }
+        };
+        let res = match acc.evaluate() {
+            Ok(r) => r,
+            Err(err) => {
+                return Err(vm.new_runtime_error(format!("Failed to evaluate accumulator: {}", err)))
+            }
+        };
+        let res = val_to_pyobj(res, vm);
+        Ok(res)
+    }
+
+    /// GreptimeDB's own impl of the pow function
+    #[pyfunction]
+    fn pow_gp(v0: PyVectorRef, v1: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
+        eval_func("pow", &[v0, v1], vm)
+    }
+
+    #[pyfunction]
+    fn clip(
+        v0: PyVectorRef,
+        v1: PyVectorRef,
+        v2: PyVectorRef,
+        vm: &VirtualMachine,
+    ) -> PyResult<PyVector> {
+        eval_func("clip", &[v0, v1, v2], vm)
+    }
+
+    #[pyfunction]
+    fn median(v: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("median", &[v], vm)
+    }
+
+    #[pyfunction]
+    fn diff(v: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("diff", &[v], vm)
+    }
+
+    #[pyfunction]
+    fn mean(v: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("mean", &[v], vm)
+    }
+
+    #[pyfunction]
+    fn polyval(v0: PyVectorRef, v1: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("polyval", &[v0, v1], vm)
+    }
+
+    #[pyfunction]
+    fn argmax(v0: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("argmax", &[v0], vm)
+    }
+
+    #[pyfunction]
+    fn argmin(v0: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("argmin", &[v0], vm)
+    }
+
+    #[pyfunction]
+    fn percentile(v0: PyVectorRef, v1: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        eval_aggr_func("percentile", &[v0, v1], vm)
+    }
+
+    #[pyfunction]
+    fn scipy_stats_norm_cdf(
+        v0: PyVectorRef,
+        v1: PyVectorRef,
+        vm: &VirtualMachine,
+    ) -> PyResult<PyObjectRef> {
+        eval_aggr_func("scipystatsnormcdf", &[v0, v1], vm)
+    }
+
+    #[pyfunction]
+    fn scipy_stats_norm_pdf(
+        v0: PyVectorRef,
+        v1: PyVectorRef,
+        vm: &VirtualMachine,
+    ) -> PyResult<PyObjectRef> {
+        eval_aggr_func("scipystatsnormpdf", &[v0, v1], vm)
+    }
 
     // The math function return a general PyObjectRef
     // so it can return both PyVector or a scalar PyInt/Float/Bool
 
+    // ------
+    // DataFusion's UDF&UDAF
+    // ------
     /// simple math function, the backing implement is datafusion's `sqrt` math function
     #[pyfunction]
     fn sqrt(val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
diff --git a/src/script/src/python/builtins/testcases.ron b/src/script/src/python/builtins/testcases.ron
index 1d073effd1..994c50b811 100644
--- a/src/script/src/python/builtins/testcases.ron
+++ b/src/script/src/python/builtins/testcases.ron
@@ -924,5 +924,198 @@ sum(prev(values))"#,
             ty: Float64,
             value: Float(3.0)
         ))
+    ),
+    TestCase(
+        input: {
+            "values": Var(
+                ty: Float64,
+                value: FloatVec([1.0, 2.0, 3.0])
+            ),
+            "pows": Var(
+                ty: Float64,
+                value: FloatVec([1.0, 2.0, 3.0])
+            ),
+        },
+        script: r#"
+from greptime import *
+pow_gp(values, pows)"#,
+        expect: Ok((
+            ty: Float64,
+            value: FloatVec([1.0, 4.0, 27.0])
+        ))
+    ),
+    TestCase(
+        input: {
+            "values": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 0.5])
+            ),
+            "lower": Var(
+                ty: Float64,
+                value: FloatVec([0.0, 0.0, 0.0])
+            ),
+            "upper": Var(
+                ty: Float64,
+                value: FloatVec([1.0, 1.0, 1.0])
+            ),
+        },
+        script: r#"
+from greptime import *
+clip(values, lower, upper)"#,
+        expect: Ok((
+            ty: Float64,
+            value: FloatVec([0.0, 1.0, 0.5])
+        ))
+    ),
+    TestCase(
+        input: {
+            "values": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 2.0, 0.5])
+            )
+        },
+        script: r#"
+from greptime import *
+median(values)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(1.25)
+        ))
+    ),
+    TestCase(
+        input: {
+            "values": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 2.0, 0.5])
+            )
+        },
+        script: r#"
+from greptime import *
+diff(values)"#,
+        expect: Ok((
+            ty: Float64,
+            value: FloatVec([3.0, 0.0, -1.5])
+        ))
+    ),
+    TestCase(
+        input: {
+            "values": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 2.0, 0.0])
+            )
+        },
+        script: r#"
+from greptime import *
+mean(values)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(0.75)
+        ))
+    ),
+    TestCase(
+        input: {
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0])
+            ),
+            "x": Var(
+                ty: Int64,
+                value: IntVec([1, 1])
+            )
+        },
+        script: r#"
+from greptime import *
+polyval(p, x)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(1.0)
+        ))
+    ),
+    TestCase(
+        input: {
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 3.0])
+            )
+        },
+        script: r#"
+from greptime import *
+argmax(p)"#,
+        expect: Ok((
+            ty: Int64,
+            value: Int(2)
+        ))
+    ),
+    TestCase(
+        input: {
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 3.0])
+            )
+        },
+        script: r#"
+from greptime import *
+argmin(p)"#,
+        expect: Ok((
+            ty: Int64,
+            value: Int(0)
+        ))
+    ),
+    TestCase(
+        input: {
+            "x": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 3.0])
+            ),
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([0.5, 0.5, 0.5])
+            )
+        },
+        script: r#"
+from greptime import *
+percentile(x, p)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(-0.97)
+        ))
+    ),
+    TestCase(
+        input: {
+            "x": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 3.0])
+            ),
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([0.5, 0.5, 0.5])
+            )
+        },
+        script: r#"
+from greptime import *
+scipy_stats_norm_cdf(x, p)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(0.3444602779022303)
+        ))
+    ),
+    TestCase(
+        input: {
+            "x": Var(
+                ty: Float64,
+                value: FloatVec([-1.0, 2.0, 3.0])
+            ),
+            "p": Var(
+                ty: Float64,
+                value: FloatVec([0.5, 0.5, 0.5])
+            )
+        },
+        script: r#"
+from greptime import *
+scipy_stats_norm_pdf(x, p)"#,
+        expect: Ok((
+            ty: Float64,
+            value: Float(0.1768885735289059)
+        ))
     )
 ]
diff --git a/src/script/src/python/vector.rs b/src/script/src/python/vector.rs
index 6ad17572c6..c3b583002b 100644
--- a/src/script/src/python/vector.rs
+++ b/src/script/src/python/vector.rs
@@ -939,7 +939,16 @@ pub fn val_to_pyobj(val: value::Value, vm: &VirtualMachine) -> PyObjectRef {
         value::Value::DateTime(v) => vm.ctx.new_int(v.val()).into(),
         // FIXME(dennis): lose the timestamp unit here
         Value::Timestamp(v) => vm.ctx.new_int(v.value()).into(),
-        value::Value::List(_) => unreachable!(),
+        value::Value::List(list) => {
+            let list = list.items().as_ref();
+            match list {
+                Some(list) => {
+                    let list: Vec<_> = list.iter().map(|v| val_to_pyobj(v.clone(), vm)).collect();
+                    vm.ctx.new_list(list).into()
+                }
+                None => vm.ctx.new_list(Vec::new()).into(),
+            }
+        }
     }
 }
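
For reference, the new builtins are called from Python exactly the way the test cases above exercise them. A minimal sketch follows (the input vectors such as `values`, `pows`, `lower`, and `upper` are assumed to be bound by the coprocessor engine, just as the test harness binds them):

    from greptime import *

    # vector-returning UDFs: element-wise, same length as the inputs
    clipped = clip(values, lower, upper)   # clamp each element into [lower, upper]
    powered = pow_gp(values, pows)         # GreptimeDB's own pow UDF

    # aggregate UDAFs: consume whole vectors and return a plain Python value
    avg = mean(values)
    mid = median(values)
    top = argmax(values)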