mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 14:22:58 +00:00
feat: use Python Script as UDF in SQL (#839)
* feat: reg PyScript as UDF * refactor: use `ConcreteDataType` instead * fix: accept `str` data type * fix: allow binary to capture SIGINT * test: add test for py udf * Update src/servers/tests/py_script/mod.rs Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * style: clippy problem * style: add newline * chore: PR advices Co-authored-by: Ruihang Xia <waynestxia@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
22
Cargo.lock
generated
22
Cargo.lock
generated
@@ -5932,7 +5932,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-ast"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"rustpython-common",
|
||||
@@ -5942,7 +5942,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-codegen"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"ahash 0.7.6",
|
||||
"bitflags",
|
||||
@@ -5959,7 +5959,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-common"
|
||||
version = "0.0.0"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"ascii",
|
||||
"cfg-if 1.0.0",
|
||||
@@ -5982,7 +5982,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-compiler"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"rustpython-codegen",
|
||||
"rustpython-compiler-core",
|
||||
@@ -5993,7 +5993,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-compiler-core"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"bincode 1.3.3",
|
||||
"bitflags",
|
||||
@@ -6010,7 +6010,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-derive"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"rustpython-compiler",
|
||||
"rustpython-derive-impl",
|
||||
@@ -6020,7 +6020,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-derive-impl"
|
||||
version = "0.0.0"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itertools",
|
||||
@@ -6046,7 +6046,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-parser"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"ahash 0.7.6",
|
||||
"anyhow",
|
||||
@@ -6071,7 +6071,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-pylib"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"rustpython-compiler-core",
|
||||
@@ -6081,7 +6081,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-stdlib"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
"ahash 0.7.6",
|
||||
@@ -6146,7 +6146,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "rustpython-vm"
|
||||
version = "0.1.2"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
|
||||
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
|
||||
dependencies = [
|
||||
"adler32",
|
||||
"ahash 0.7.6",
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::any::Any;
|
||||
|
||||
use arrow::error::ArrowError;
|
||||
use common_error::prelude::*;
|
||||
use common_recordbatch::error::Error as RecordbatchError;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datatypes::arrow;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDatatype;
|
||||
@@ -26,6 +27,22 @@ use statrs::StatsError;
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
pub enum Error {
|
||||
#[snafu(display("Fail to execute Python UDF, source: {}", msg))]
|
||||
PyUdf {
|
||||
// TODO(discord9): find a way that prevent circle depend(query<-script<-query) and can use script's error type
|
||||
msg: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display(
|
||||
"Fail to create temporary recordbatch when eval Python UDF, source: {}",
|
||||
source
|
||||
))]
|
||||
UdfTempRecordBatch {
|
||||
#[snafu(backtrace)]
|
||||
source: RecordbatchError,
|
||||
},
|
||||
|
||||
#[snafu(display("Fail to execute function, source: {}", source))]
|
||||
ExecuteFunction {
|
||||
source: DataFusionError,
|
||||
@@ -167,7 +184,9 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
impl ErrorExt for Error {
|
||||
fn status_code(&self) -> StatusCode {
|
||||
match self {
|
||||
Error::ExecuteFunction { .. }
|
||||
Error::UdfTempRecordBatch { .. }
|
||||
| Error::PyUdf { .. }
|
||||
| Error::ExecuteFunction { .. }
|
||||
| Error::GenerateFunction { .. }
|
||||
| Error::CreateAccumulator { .. }
|
||||
| Error::DowncastVector { .. }
|
||||
|
||||
@@ -45,16 +45,16 @@ once_cell = "1.17.0"
|
||||
paste = { workspace = true, optional = true }
|
||||
query = { path = "../query" }
|
||||
# TODO(discord9): This is a forked and tweaked version of RustPython, please update it to newest original RustPython After Update toolchain to 1.65
|
||||
rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [
|
||||
rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [
|
||||
"freeze-stdlib",
|
||||
] }
|
||||
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
|
||||
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [
|
||||
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
|
||||
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [
|
||||
"default",
|
||||
"codegen",
|
||||
] }
|
||||
|
||||
@@ -58,6 +58,10 @@ impl ScriptManager {
|
||||
|
||||
logging::info!("Compiled and cached script: {}", name);
|
||||
|
||||
script.as_ref().register_udf();
|
||||
|
||||
logging::info!("Script register as UDF: {}", name);
|
||||
|
||||
Ok(script)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::info;
|
||||
use datatypes::arrow::array::Array;
|
||||
use datatypes::arrow::compute;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::data_type::{ConcreteDataType, DataType};
|
||||
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
@@ -55,7 +54,7 @@ thread_local!(static INTERPRETER: RefCell<Option<Arc<Interpreter>>> = RefCell::n
|
||||
pub struct AnnotationInfo {
|
||||
/// if None, use types inferred by PyVector
|
||||
// TODO(yingwen): We should use our data type. i.e. ConcreteDataType.
|
||||
pub datatype: Option<ArrowDataType>,
|
||||
pub datatype: Option<ConcreteDataType>,
|
||||
pub is_nullable: bool,
|
||||
}
|
||||
|
||||
@@ -122,14 +121,12 @@ impl Coprocessor {
|
||||
} = anno[idx].to_owned().unwrap_or_else(|| {
|
||||
// default to be not nullable and use DataType inferred by PyVector itself
|
||||
AnnotationInfo {
|
||||
datatype: Some(real_ty.as_arrow_type()),
|
||||
datatype: Some(real_ty.clone()),
|
||||
is_nullable: false,
|
||||
}
|
||||
});
|
||||
let column_type = match ty {
|
||||
Some(arrow_type) => {
|
||||
ConcreteDataType::try_from(&arrow_type).context(TypeCastSnafu)?
|
||||
}
|
||||
Some(anno_type) => anno_type,
|
||||
// if type is like `_` or `_ | None`
|
||||
None => real_ty,
|
||||
};
|
||||
@@ -165,9 +162,10 @@ impl Coprocessor {
|
||||
{
|
||||
let real_ty = col.data_type();
|
||||
let anno_ty = datatype;
|
||||
if real_ty.as_arrow_type() != *anno_ty {
|
||||
if real_ty != *anno_ty {
|
||||
let array = col.to_arrow_array();
|
||||
let array = compute::cast(&array, anno_ty).context(ArrowSnafu)?;
|
||||
let array =
|
||||
compute::cast(&array, &anno_ty.as_arrow_type()).context(ArrowSnafu)?;
|
||||
*col = Helper::try_into_vector(array).context(TypeCastSnafu)?;
|
||||
}
|
||||
}
|
||||
@@ -223,6 +221,7 @@ fn check_args_anno_real_type(
|
||||
for (idx, arg) in args.iter().enumerate() {
|
||||
let anno_ty = copr.arg_types[idx].to_owned();
|
||||
let real_ty = arg.to_arrow_array().data_type().to_owned();
|
||||
let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
|
||||
let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
|
||||
ensure!(
|
||||
anno_ty
|
||||
@@ -372,7 +371,11 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
|
||||
let native_module_allow_list = HashSet::from([
|
||||
"array", "cmath", "gc", "hashlib", "_json", "_random", "math",
|
||||
]);
|
||||
let interpreter = Arc::new(vm::Interpreter::with_init(Default::default(), |vm| {
|
||||
// TODO(discord9): edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]`
|
||||
// so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868
|
||||
let mut settings = vm::Settings::default();
|
||||
settings.no_sig_int = true;
|
||||
let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| {
|
||||
// not using full stdlib to prevent security issue, instead filter out a few simple util module
|
||||
vm.add_native_modules(
|
||||
rustpython_stdlib::get_module_inits()
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use rustpython_parser::ast::{Arguments, Location};
|
||||
use rustpython_parser::{ast, parser};
|
||||
#[cfg(test)]
|
||||
@@ -81,20 +81,20 @@ fn pylist_to_vec(lst: &ast::Expr<()>) -> Result<Vec<String>> {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<DataType>> {
|
||||
fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType>> {
|
||||
match ty {
|
||||
"bool" => Ok(Some(DataType::Boolean)),
|
||||
"u8" => Ok(Some(DataType::UInt8)),
|
||||
"u16" => Ok(Some(DataType::UInt16)),
|
||||
"u32" => Ok(Some(DataType::UInt32)),
|
||||
"u64" => Ok(Some(DataType::UInt64)),
|
||||
"i8" => Ok(Some(DataType::Int8)),
|
||||
"i16" => Ok(Some(DataType::Int16)),
|
||||
"i32" => Ok(Some(DataType::Int32)),
|
||||
"i64" => Ok(Some(DataType::Int64)),
|
||||
"f16" => Ok(Some(DataType::Float16)),
|
||||
"f32" => Ok(Some(DataType::Float32)),
|
||||
"f64" => Ok(Some(DataType::Float64)),
|
||||
"bool" => Ok(Some(ConcreteDataType::boolean_datatype())),
|
||||
"u8" => Ok(Some(ConcreteDataType::uint8_datatype())),
|
||||
"u16" => Ok(Some(ConcreteDataType::uint16_datatype())),
|
||||
"u32" => Ok(Some(ConcreteDataType::uint32_datatype())),
|
||||
"u64" => Ok(Some(ConcreteDataType::uint64_datatype())),
|
||||
"i8" => Ok(Some(ConcreteDataType::int8_datatype())),
|
||||
"i16" => Ok(Some(ConcreteDataType::int16_datatype())),
|
||||
"i32" => Ok(Some(ConcreteDataType::int32_datatype())),
|
||||
"i64" => Ok(Some(ConcreteDataType::int64_datatype())),
|
||||
"f32" => Ok(Some(ConcreteDataType::float32_datatype())),
|
||||
"f64" => Ok(Some(ConcreteDataType::float64_datatype())),
|
||||
"str" => Ok(Some(ConcreteDataType::string_datatype())),
|
||||
// for any datatype
|
||||
"_" => Ok(None),
|
||||
// note the different between "_" and _
|
||||
|
||||
@@ -20,10 +20,15 @@ use std::task::{Context, Poll};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_function::scalars::{Function, FUNCTION_REGISTRY};
|
||||
use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
|
||||
use common_query::prelude::Signature;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
|
||||
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
|
||||
use datatypes::schema::SchemaRef;
|
||||
use datafusion_expr::Volatility;
|
||||
use datatypes::schema::{ColumnSchema, SchemaRef};
|
||||
use datatypes::vectors::VectorRef;
|
||||
use futures::Stream;
|
||||
use query::parser::{QueryLanguageParser, QueryStatement};
|
||||
use query::QueryEngineRef;
|
||||
@@ -32,16 +37,141 @@ use snafu::{ensure, ResultExt};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use crate::python::coprocessor::{exec_parsed, parse, CoprocessorRef};
|
||||
use crate::python::coprocessor::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
|
||||
use crate::python::error::{self, Result};
|
||||
|
||||
const PY_ENGINE: &str = "python";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PyUDF {
|
||||
copr: CoprocessorRef,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PyUDF {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}({})->",
|
||||
&self.copr.name,
|
||||
&self.copr.deco_args.arg_names.join(",")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl PyUDF {
|
||||
fn from_copr(copr: CoprocessorRef) -> Arc<Self> {
|
||||
Arc::new(Self { copr })
|
||||
}
|
||||
|
||||
/// Register to `FUNCTION_REGISTRY`
|
||||
fn register_as_udf(zelf: Arc<Self>) {
|
||||
FUNCTION_REGISTRY.register(zelf)
|
||||
}
|
||||
fn register_to_query_engine(zelf: Arc<Self>, engine: QueryEngineRef) {
|
||||
engine.register_function(zelf)
|
||||
}
|
||||
|
||||
/// Fake a schema, should only be used with dynamically eval a Python Udf
|
||||
fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
|
||||
let arg_names = &self.copr.deco_args.arg_names;
|
||||
let col_sch: Vec<_> = columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, col)| ColumnSchema::new(arg_names[i].to_owned(), col.data_type(), true))
|
||||
.collect();
|
||||
let schema = datatypes::schema::Schema::new(col_sch);
|
||||
Arc::new(schema)
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for PyUDF {
|
||||
fn name(&self) -> &str {
|
||||
&self.copr.name
|
||||
}
|
||||
|
||||
fn return_type(
|
||||
&self,
|
||||
_input_types: &[datatypes::prelude::ConcreteDataType],
|
||||
) -> common_query::error::Result<datatypes::prelude::ConcreteDataType> {
|
||||
// TODO(discord9): use correct return annotation if exist
|
||||
match self.copr.return_types.get(0) {
|
||||
Some(Some(AnnotationInfo {
|
||||
datatype: Some(ty), ..
|
||||
})) => Ok(ty.to_owned()),
|
||||
_ => PyUdfSnafu {
|
||||
msg: "Can't found return type for python UDF {self}",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
}
|
||||
|
||||
fn signature(&self) -> common_query::prelude::Signature {
|
||||
// try our best to get a type signature
|
||||
let mut arg_types = Vec::with_capacity(self.copr.arg_types.len());
|
||||
let mut know_all_types = true;
|
||||
for ty in self.copr.arg_types.iter() {
|
||||
match ty {
|
||||
Some(AnnotationInfo {
|
||||
datatype: Some(ty), ..
|
||||
}) => arg_types.push(ty.to_owned()),
|
||||
_ => {
|
||||
know_all_types = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if know_all_types {
|
||||
Signature::variadic(arg_types, Volatility::Immutable)
|
||||
} else {
|
||||
Signature::any(self.copr.arg_types.len(), Volatility::Immutable)
|
||||
}
|
||||
}
|
||||
|
||||
fn eval(
|
||||
&self,
|
||||
_func_ctx: common_function::scalars::function::FunctionContext,
|
||||
columns: &[datatypes::vectors::VectorRef],
|
||||
) -> common_query::error::Result<datatypes::vectors::VectorRef> {
|
||||
// FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
|
||||
let schema = self.fake_schema(columns);
|
||||
let columns = columns.to_vec();
|
||||
// TODO(discord9): remove unwrap
|
||||
let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?;
|
||||
let res = exec_parsed(&self.copr, &rb).map_err(|err| {
|
||||
PyUdfSnafu {
|
||||
msg: format!("{err:#?}"),
|
||||
}
|
||||
.build()
|
||||
})?;
|
||||
let len = res.columns().len();
|
||||
if len == 0 {
|
||||
return PyUdfSnafu {
|
||||
msg: "Python UDF should return exactly one column, found zero column".to_string(),
|
||||
}
|
||||
.fail();
|
||||
} // if more than one columns, just return first one
|
||||
|
||||
// TODO(discord9): more error handling
|
||||
let res0 = res.column(0);
|
||||
Ok(res0.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PyScript {
|
||||
query_engine: QueryEngineRef,
|
||||
copr: CoprocessorRef,
|
||||
}
|
||||
|
||||
impl PyScript {
|
||||
/// Register Current Script as UDF, register name is same as script name
|
||||
/// FIXME(discord9): possible inject attack?
|
||||
pub fn register_udf(&self) {
|
||||
let udf = PyUDF::from_copr(self.copr.clone());
|
||||
PyUDF::register_as_udf(udf.clone());
|
||||
PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CoprStream {
|
||||
stream: SendableRecordBatchStream,
|
||||
copr: CoprocessorRef,
|
||||
|
||||
@@ -132,7 +132,7 @@ fn run_ron_testcases() {
|
||||
.zip(res.schema.column_schemas())
|
||||
.for_each(|(anno, real)| {
|
||||
assert!(
|
||||
anno.datatype.as_ref().unwrap() == &real.data_type.as_arrow_type()
|
||||
anno.datatype.as_ref().unwrap() == &real.data_type
|
||||
&& anno.is_nullable == real.is_nullable(),
|
||||
"Fields expected to be {anno:#?}, actual {real:#?}"
|
||||
);
|
||||
|
||||
@@ -24,21 +24,21 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
|
||||
),
|
||||
arg_types: [
|
||||
Some((
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
)),
|
||||
Some((
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: false
|
||||
)),
|
||||
],
|
||||
return_types: [
|
||||
Some((
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: false
|
||||
)),
|
||||
Some((
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
)),
|
||||
Some((
|
||||
@@ -246,11 +246,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
@@ -278,11 +278,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
@@ -310,11 +310,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
@@ -342,11 +342,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
@@ -372,7 +372,7 @@ def a(cpu: vector[f32], mem: vector[f64]):
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Utf8),
|
||||
datatype: Some(String(())),
|
||||
is_nullable: false,
|
||||
),
|
||||
],
|
||||
@@ -480,11 +480,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
datatype: Some(Float64(())),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
datatype: Some(Float32(())),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
|
||||
@@ -37,6 +37,8 @@ mod interceptor;
|
||||
mod opentsdb;
|
||||
mod postgres;
|
||||
|
||||
mod py_script;
|
||||
|
||||
struct DummyInstance {
|
||||
query_engine: QueryEngineRef,
|
||||
py_engine: Arc<PyEngine>,
|
||||
@@ -88,6 +90,7 @@ impl ScriptHandler for DummyInstance {
|
||||
.compile(script, CompileContext::default())
|
||||
.await
|
||||
.unwrap();
|
||||
script.register_udf();
|
||||
self.scripts
|
||||
.write()
|
||||
.unwrap()
|
||||
|
||||
56
src/servers/tests/py_script/mod.rs
Normal file
56
src/servers/tests/py_script/mod.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use servers::error::Result;
|
||||
use servers::query_handler::{ScriptHandler, SqlQueryHandler};
|
||||
use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::create_testing_instance;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_insert_py_udf_and_query() -> Result<()> {
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let instance = create_testing_instance(table);
|
||||
let src = r#"
|
||||
@coprocessor(args=["uint32s"], returns = ["ret"])
|
||||
def double_that(col)->vector[u32]:
|
||||
return col*2
|
||||
"#;
|
||||
instance.insert_script("double_that", src).await?;
|
||||
let res = instance
|
||||
.do_query("select double_that(uint32s) from numbers", query_ctx)
|
||||
.await
|
||||
.remove(0)
|
||||
.unwrap();
|
||||
match res {
|
||||
common_query::Output::AffectedRows(_) => (),
|
||||
common_query::Output::RecordBatches(_) => {
|
||||
unreachable!()
|
||||
}
|
||||
common_query::Output::Stream(s) => {
|
||||
let batches = common_recordbatch::util::collect_batches(s).await.unwrap();
|
||||
assert_eq!(batches.iter().count(), 1);
|
||||
let first = batches.iter().next().unwrap();
|
||||
let col = first.column(0);
|
||||
let val = col.get(1);
|
||||
assert_eq!(val, datatypes::value::Value::UInt32(2));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user