From e428a8444687e0ff5905b09ec4027d45997c8bc2 Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Fri, 13 Jan 2023 14:35:03 +0800 Subject: [PATCH] feat: use Python Script as UDF in SQL (#839) * feat: reg PyScript as UDF * refactor: use `ConcreteDataType` instead * fix: accept `str` data type * fix: allow binary to capture SIGINT * test: add test for py udf * Update src/servers/tests/py_script/mod.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * style: clippy problem * style: add newline * chore: PR advices Co-authored-by: Ruihang Xia Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 22 ++-- src/common/query/src/error.rs | 21 +++- src/script/Cargo.toml | 16 +-- src/script/src/manager.rs | 4 + src/script/src/python/coprocessor.rs | 21 ++-- src/script/src/python/coprocessor/parse.rs | 28 ++--- src/script/src/python/engine.rs | 134 ++++++++++++++++++++- src/script/src/python/test.rs | 2 +- src/script/src/python/testcases.ron | 30 ++--- src/servers/tests/mod.rs | 3 + src/servers/tests/py_script/mod.rs | 56 +++++++++ 11 files changed, 276 insertions(+), 61 deletions(-) create mode 100644 src/servers/tests/py_script/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 7559787902..03e5f33374 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5932,7 +5932,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.1.0" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "num-bigint", "rustpython-common", @@ -5942,7 +5942,7 @@ dependencies = [ [[package]] name = "rustpython-codegen" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "ahash 0.7.6", "bitflags", @@ -5959,7 +5959,7 @@ dependencies = [ [[package]] name = "rustpython-common" version = "0.0.0" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "ascii", "cfg-if 1.0.0", @@ -5982,7 +5982,7 @@ dependencies = [ [[package]] name = "rustpython-compiler" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "rustpython-codegen", "rustpython-compiler-core", @@ -5993,7 +5993,7 @@ dependencies = [ [[package]] name = "rustpython-compiler-core" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "bincode 1.3.3", "bitflags", @@ -6010,7 +6010,7 @@ dependencies = [ [[package]] name = "rustpython-derive" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "rustpython-compiler", "rustpython-derive-impl", @@ -6020,7 +6020,7 @@ dependencies = [ [[package]] name = "rustpython-derive-impl" version = "0.0.0" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "indexmap", "itertools", @@ -6046,7 +6046,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "ahash 0.7.6", "anyhow", @@ -6071,7 +6071,7 @@ dependencies = [ [[package]] name = "rustpython-pylib" version = "0.1.0" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "glob", "rustpython-compiler-core", @@ -6081,7 +6081,7 @@ dependencies = [ [[package]] name = "rustpython-stdlib" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "adler32", "ahash 0.7.6", @@ -6146,7 +6146,7 @@ dependencies = [ [[package]] name = "rustpython-vm" version = "0.1.2" -source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6" +source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca" dependencies = [ "adler32", "ahash 0.7.6", diff --git a/src/common/query/src/error.rs b/src/common/query/src/error.rs index 553ccd6919..22c2da80de 100644 --- a/src/common/query/src/error.rs +++ b/src/common/query/src/error.rs @@ -16,6 +16,7 @@ use std::any::Any; use arrow::error::ArrowError; use common_error::prelude::*; +use common_recordbatch::error::Error as RecordbatchError; use datafusion_common::DataFusionError; use datatypes::arrow; use datatypes::arrow::datatypes::DataType as ArrowDatatype; @@ -26,6 +27,22 @@ use statrs::StatsError; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { + #[snafu(display("Fail to execute Python UDF, source: {}", msg))] + PyUdf { + // TODO(discord9): find a way that prevent circle depend(query<-script<-query) and can use script's error type + msg: String, + backtrace: Backtrace, + }, + + #[snafu(display( + "Fail to create temporary recordbatch when eval Python UDF, source: {}", + source + ))] + UdfTempRecordBatch { + #[snafu(backtrace)] + source: RecordbatchError, + }, + #[snafu(display("Fail to execute function, source: {}", source))] ExecuteFunction { source: DataFusionError, @@ -167,7 +184,9 @@ pub type Result = std::result::Result; impl ErrorExt for Error { fn status_code(&self) -> StatusCode { match self { - Error::ExecuteFunction { .. } + Error::UdfTempRecordBatch { .. } + | Error::PyUdf { .. } + | Error::ExecuteFunction { .. } | Error::GenerateFunction { .. } | Error::CreateAccumulator { .. } | Error::DowncastVector { .. } diff --git a/src/script/Cargo.toml b/src/script/Cargo.toml index 433c3b3c5b..c9cf6a5b46 100644 --- a/src/script/Cargo.toml +++ b/src/script/Cargo.toml @@ -45,16 +45,16 @@ once_cell = "1.17.0" paste = { workspace = true, optional = true } query = { path = "../query" } # TODO(discord9): This is a forked and tweaked version of RustPython, please update it to newest original RustPython After Update toolchain to 1.65 -rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [ +rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [ "freeze-stdlib", ] } -rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" } -rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [ +rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" } +rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [ "default", "codegen", ] } diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs index 01d4d56c53..59c6ea62cb 100644 --- a/src/script/src/manager.rs +++ b/src/script/src/manager.rs @@ -58,6 +58,10 @@ impl ScriptManager { logging::info!("Compiled and cached script: {}", name); + script.as_ref().register_udf(); + + logging::info!("Script register as UDF: {}", name); + Ok(script) } diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs index 165691912f..f05f53dcb1 100644 --- a/src/script/src/python/coprocessor.rs +++ b/src/script/src/python/coprocessor.rs @@ -24,7 +24,6 @@ use common_recordbatch::RecordBatch; use common_telemetry::info; use datatypes::arrow::array::Array; use datatypes::arrow::compute; -use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::data_type::{ConcreteDataType, DataType}; use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; use datatypes::vectors::{Helper, VectorRef}; @@ -55,7 +54,7 @@ thread_local!(static INTERPRETER: RefCell>> = RefCell::n pub struct AnnotationInfo { /// if None, use types inferred by PyVector // TODO(yingwen): We should use our data type. i.e. ConcreteDataType. - pub datatype: Option, + pub datatype: Option, pub is_nullable: bool, } @@ -122,14 +121,12 @@ impl Coprocessor { } = anno[idx].to_owned().unwrap_or_else(|| { // default to be not nullable and use DataType inferred by PyVector itself AnnotationInfo { - datatype: Some(real_ty.as_arrow_type()), + datatype: Some(real_ty.clone()), is_nullable: false, } }); let column_type = match ty { - Some(arrow_type) => { - ConcreteDataType::try_from(&arrow_type).context(TypeCastSnafu)? - } + Some(anno_type) => anno_type, // if type is like `_` or `_ | None` None => real_ty, }; @@ -165,9 +162,10 @@ impl Coprocessor { { let real_ty = col.data_type(); let anno_ty = datatype; - if real_ty.as_arrow_type() != *anno_ty { + if real_ty != *anno_ty { let array = col.to_arrow_array(); - let array = compute::cast(&array, anno_ty).context(ArrowSnafu)?; + let array = + compute::cast(&array, &anno_ty.as_arrow_type()).context(ArrowSnafu)?; *col = Helper::try_into_vector(array).context(TypeCastSnafu)?; } } @@ -223,6 +221,7 @@ fn check_args_anno_real_type( for (idx, arg) in args.iter().enumerate() { let anno_ty = copr.arg_types[idx].to_owned(); let real_ty = arg.to_arrow_array().data_type().to_owned(); + let real_ty = ConcreteDataType::from_arrow_type(&real_ty); let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable(); ensure!( anno_ty @@ -372,7 +371,11 @@ pub(crate) fn init_interpreter() -> Arc { let native_module_allow_list = HashSet::from([ "array", "cmath", "gc", "hashlib", "_json", "_random", "math", ]); - let interpreter = Arc::new(vm::Interpreter::with_init(Default::default(), |vm| { + // TODO(discord9): edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]` + // so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868 + let mut settings = vm::Settings::default(); + settings.no_sig_int = true; + let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| { // not using full stdlib to prevent security issue, instead filter out a few simple util module vm.add_native_modules( rustpython_stdlib::get_module_inits() diff --git a/src/script/src/python/coprocessor/parse.rs b/src/script/src/python/coprocessor/parse.rs index 57b348552a..570c3368f8 100644 --- a/src/script/src/python/coprocessor/parse.rs +++ b/src/script/src/python/coprocessor/parse.rs @@ -14,7 +14,7 @@ use std::collections::HashSet; -use datatypes::arrow::datatypes::DataType; +use datatypes::prelude::ConcreteDataType; use rustpython_parser::ast::{Arguments, Location}; use rustpython_parser::{ast, parser}; #[cfg(test)] @@ -81,20 +81,20 @@ fn pylist_to_vec(lst: &ast::Expr<()>) -> Result> { } } -fn try_into_datatype(ty: &str, loc: &Location) -> Result> { +fn try_into_datatype(ty: &str, loc: &Location) -> Result> { match ty { - "bool" => Ok(Some(DataType::Boolean)), - "u8" => Ok(Some(DataType::UInt8)), - "u16" => Ok(Some(DataType::UInt16)), - "u32" => Ok(Some(DataType::UInt32)), - "u64" => Ok(Some(DataType::UInt64)), - "i8" => Ok(Some(DataType::Int8)), - "i16" => Ok(Some(DataType::Int16)), - "i32" => Ok(Some(DataType::Int32)), - "i64" => Ok(Some(DataType::Int64)), - "f16" => Ok(Some(DataType::Float16)), - "f32" => Ok(Some(DataType::Float32)), - "f64" => Ok(Some(DataType::Float64)), + "bool" => Ok(Some(ConcreteDataType::boolean_datatype())), + "u8" => Ok(Some(ConcreteDataType::uint8_datatype())), + "u16" => Ok(Some(ConcreteDataType::uint16_datatype())), + "u32" => Ok(Some(ConcreteDataType::uint32_datatype())), + "u64" => Ok(Some(ConcreteDataType::uint64_datatype())), + "i8" => Ok(Some(ConcreteDataType::int8_datatype())), + "i16" => Ok(Some(ConcreteDataType::int16_datatype())), + "i32" => Ok(Some(ConcreteDataType::int32_datatype())), + "i64" => Ok(Some(ConcreteDataType::int64_datatype())), + "f32" => Ok(Some(ConcreteDataType::float32_datatype())), + "f64" => Ok(Some(ConcreteDataType::float64_datatype())), + "str" => Ok(Some(ConcreteDataType::string_datatype())), // for any datatype "_" => Ok(None), // note the different between "_" and _ diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 8b1bf46f94..fe1060ab9e 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -20,10 +20,15 @@ use std::task::{Context, Poll}; use async_trait::async_trait; use common_error::prelude::BoxedError; +use common_function::scalars::{Function, FUNCTION_REGISTRY}; +use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu}; +use common_query::prelude::Signature; use common_query::Output; use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult}; use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream}; -use datatypes::schema::SchemaRef; +use datafusion_expr::Volatility; +use datatypes::schema::{ColumnSchema, SchemaRef}; +use datatypes::vectors::VectorRef; use futures::Stream; use query::parser::{QueryLanguageParser, QueryStatement}; use query::QueryEngineRef; @@ -32,16 +37,141 @@ use snafu::{ensure, ResultExt}; use sql::statements::statement::Statement; use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine}; -use crate::python::coprocessor::{exec_parsed, parse, CoprocessorRef}; +use crate::python::coprocessor::{exec_parsed, parse, AnnotationInfo, CoprocessorRef}; use crate::python::error::{self, Result}; const PY_ENGINE: &str = "python"; +#[derive(Debug)] +pub struct PyUDF { + copr: CoprocessorRef, +} + +impl std::fmt::Display for PyUDF { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}({})->", + &self.copr.name, + &self.copr.deco_args.arg_names.join(",") + ) + } +} + +impl PyUDF { + fn from_copr(copr: CoprocessorRef) -> Arc { + Arc::new(Self { copr }) + } + + /// Register to `FUNCTION_REGISTRY` + fn register_as_udf(zelf: Arc) { + FUNCTION_REGISTRY.register(zelf) + } + fn register_to_query_engine(zelf: Arc, engine: QueryEngineRef) { + engine.register_function(zelf) + } + + /// Fake a schema, should only be used with dynamically eval a Python Udf + fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef { + let arg_names = &self.copr.deco_args.arg_names; + let col_sch: Vec<_> = columns + .iter() + .enumerate() + .map(|(i, col)| ColumnSchema::new(arg_names[i].to_owned(), col.data_type(), true)) + .collect(); + let schema = datatypes::schema::Schema::new(col_sch); + Arc::new(schema) + } +} + +impl Function for PyUDF { + fn name(&self) -> &str { + &self.copr.name + } + + fn return_type( + &self, + _input_types: &[datatypes::prelude::ConcreteDataType], + ) -> common_query::error::Result { + // TODO(discord9): use correct return annotation if exist + match self.copr.return_types.get(0) { + Some(Some(AnnotationInfo { + datatype: Some(ty), .. + })) => Ok(ty.to_owned()), + _ => PyUdfSnafu { + msg: "Can't found return type for python UDF {self}", + } + .fail(), + } + } + + fn signature(&self) -> common_query::prelude::Signature { + // try our best to get a type signature + let mut arg_types = Vec::with_capacity(self.copr.arg_types.len()); + let mut know_all_types = true; + for ty in self.copr.arg_types.iter() { + match ty { + Some(AnnotationInfo { + datatype: Some(ty), .. + }) => arg_types.push(ty.to_owned()), + _ => { + know_all_types = false; + break; + } + } + } + if know_all_types { + Signature::variadic(arg_types, Volatility::Immutable) + } else { + Signature::any(self.copr.arg_types.len(), Volatility::Immutable) + } + } + + fn eval( + &self, + _func_ctx: common_function::scalars::function::FunctionContext, + columns: &[datatypes::vectors::VectorRef], + ) -> common_query::error::Result { + // FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right? + let schema = self.fake_schema(columns); + let columns = columns.to_vec(); + // TODO(discord9): remove unwrap + let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?; + let res = exec_parsed(&self.copr, &rb).map_err(|err| { + PyUdfSnafu { + msg: format!("{err:#?}"), + } + .build() + })?; + let len = res.columns().len(); + if len == 0 { + return PyUdfSnafu { + msg: "Python UDF should return exactly one column, found zero column".to_string(), + } + .fail(); + } // if more than one columns, just return first one + + // TODO(discord9): more error handling + let res0 = res.column(0); + Ok(res0.to_owned()) + } +} + pub struct PyScript { query_engine: QueryEngineRef, copr: CoprocessorRef, } +impl PyScript { + /// Register Current Script as UDF, register name is same as script name + /// FIXME(discord9): possible inject attack? + pub fn register_udf(&self) { + let udf = PyUDF::from_copr(self.copr.clone()); + PyUDF::register_as_udf(udf.clone()); + PyUDF::register_to_query_engine(udf, self.query_engine.to_owned()); + } +} + pub struct CoprStream { stream: SendableRecordBatchStream, copr: CoprocessorRef, diff --git a/src/script/src/python/test.rs b/src/script/src/python/test.rs index 59502fe8c3..99fb5b82fe 100644 --- a/src/script/src/python/test.rs +++ b/src/script/src/python/test.rs @@ -132,7 +132,7 @@ fn run_ron_testcases() { .zip(res.schema.column_schemas()) .for_each(|(anno, real)| { assert!( - anno.datatype.as_ref().unwrap() == &real.data_type.as_arrow_type() + anno.datatype.as_ref().unwrap() == &real.data_type && anno.is_nullable == real.is_nullable(), "Fields expected to be {anno:#?}, actual {real:#?}" ); diff --git a/src/script/src/python/testcases.ron b/src/script/src/python/testcases.ron index 8e2415429a..ab70ec0662 100644 --- a/src/script/src/python/testcases.ron +++ b/src/script/src/python/testcases.ron @@ -24,21 +24,21 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto ), arg_types: [ Some(( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false )), Some(( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: false )), ], return_types: [ Some(( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: false )), Some(( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true )), Some(( @@ -246,11 +246,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], predicate: ExecIsOk( fields: [ ( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true ), ( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false ), ], @@ -278,11 +278,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], predicate: ExecIsOk( fields: [ ( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true ), ( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false ), ], @@ -310,11 +310,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], predicate: ExecIsOk( fields: [ ( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true ), ( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false ), ], @@ -342,11 +342,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], predicate: ExecIsOk( fields: [ ( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true ), ( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false ), ], @@ -372,7 +372,7 @@ def a(cpu: vector[f32], mem: vector[f64]): predicate: ExecIsOk( fields: [ ( - datatype: Some(Utf8), + datatype: Some(String(())), is_nullable: false, ), ], @@ -480,11 +480,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], predicate: ExecIsOk( fields: [ ( - datatype: Some(Float64), + datatype: Some(Float64(())), is_nullable: true ), ( - datatype: Some(Float32), + datatype: Some(Float32(())), is_nullable: false ), ], diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 498940cbcb..d4b265b1ea 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -37,6 +37,8 @@ mod interceptor; mod opentsdb; mod postgres; +mod py_script; + struct DummyInstance { query_engine: QueryEngineRef, py_engine: Arc, @@ -88,6 +90,7 @@ impl ScriptHandler for DummyInstance { .compile(script, CompileContext::default()) .await .unwrap(); + script.register_udf(); self.scripts .write() .unwrap() diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs new file mode 100644 index 0000000000..eba376189c --- /dev/null +++ b/src/servers/tests/py_script/mod.rs @@ -0,0 +1,56 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use servers::error::Result; +use servers::query_handler::{ScriptHandler, SqlQueryHandler}; +use session::context::QueryContext; +use table::test_util::MemTable; + +use crate::create_testing_instance; + +#[tokio::test] +async fn test_insert_py_udf_and_query() -> Result<()> { + let query_ctx = Arc::new(QueryContext::new()); + let table = MemTable::default_numbers_table(); + + let instance = create_testing_instance(table); + let src = r#" +@coprocessor(args=["uint32s"], returns = ["ret"]) +def double_that(col)->vector[u32]: + return col*2 + "#; + instance.insert_script("double_that", src).await?; + let res = instance + .do_query("select double_that(uint32s) from numbers", query_ctx) + .await + .remove(0) + .unwrap(); + match res { + common_query::Output::AffectedRows(_) => (), + common_query::Output::RecordBatches(_) => { + unreachable!() + } + common_query::Output::Stream(s) => { + let batches = common_recordbatch::util::collect_batches(s).await.unwrap(); + assert_eq!(batches.iter().count(), 1); + let first = batches.iter().next().unwrap(); + let col = first.column(0); + let val = col.get(1); + assert_eq!(val, datatypes::value::Value::UInt32(2)); + } + } + Ok(()) +}