feat: use Python Script as UDF in SQL (#839)

* feat: reg PyScript as UDF

* refactor: use `ConcreteDataType` instead

* fix: accept `str` data type

* fix: allow binary to capture SIGINT

* test: add test for py udf

* Update src/servers/tests/py_script/mod.rs

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* style: clippy problem

* style: add newline

* chore: PR advices

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
discord9
2023-01-13 14:35:03 +08:00
committed by GitHub
parent 58c37f588d
commit e428a84446
11 changed files with 276 additions and 61 deletions

22
Cargo.lock generated
View File

@@ -5932,7 +5932,7 @@ dependencies = [
[[package]]
name = "rustpython-ast"
version = "0.1.0"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"num-bigint",
"rustpython-common",
@@ -5942,7 +5942,7 @@ dependencies = [
[[package]]
name = "rustpython-codegen"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"ahash 0.7.6",
"bitflags",
@@ -5959,7 +5959,7 @@ dependencies = [
[[package]]
name = "rustpython-common"
version = "0.0.0"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"ascii",
"cfg-if 1.0.0",
@@ -5982,7 +5982,7 @@ dependencies = [
[[package]]
name = "rustpython-compiler"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"rustpython-codegen",
"rustpython-compiler-core",
@@ -5993,7 +5993,7 @@ dependencies = [
[[package]]
name = "rustpython-compiler-core"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"bincode 1.3.3",
"bitflags",
@@ -6010,7 +6010,7 @@ dependencies = [
[[package]]
name = "rustpython-derive"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"rustpython-compiler",
"rustpython-derive-impl",
@@ -6020,7 +6020,7 @@ dependencies = [
[[package]]
name = "rustpython-derive-impl"
version = "0.0.0"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"indexmap",
"itertools",
@@ -6046,7 +6046,7 @@ dependencies = [
[[package]]
name = "rustpython-parser"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"ahash 0.7.6",
"anyhow",
@@ -6071,7 +6071,7 @@ dependencies = [
[[package]]
name = "rustpython-pylib"
version = "0.1.0"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"glob",
"rustpython-compiler-core",
@@ -6081,7 +6081,7 @@ dependencies = [
[[package]]
name = "rustpython-stdlib"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"adler32",
"ahash 0.7.6",
@@ -6146,7 +6146,7 @@ dependencies = [
[[package]]
name = "rustpython-vm"
version = "0.1.2"
source = "git+https://github.com/discord9/RustPython?rev=f89b1537#f89b1537b9c789ff566717d1c2ad3d0777cbf5d6"
source = "git+https://github.com/discord9/RustPython?rev=2e126345#2e12634569d01674724490193eb9638f056e51ca"
dependencies = [
"adler32",
"ahash 0.7.6",

View File

@@ -16,6 +16,7 @@ use std::any::Any;
use arrow::error::ArrowError;
use common_error::prelude::*;
use common_recordbatch::error::Error as RecordbatchError;
use datafusion_common::DataFusionError;
use datatypes::arrow;
use datatypes::arrow::datatypes::DataType as ArrowDatatype;
@@ -26,6 +27,22 @@ use statrs::StatsError;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Fail to execute Python UDF, source: {}", msg))]
PyUdf {
// TODO(discord9): find a way to prevent a circular dependency (query<-script<-query) so that script's error type can be used
msg: String,
backtrace: Backtrace,
},
#[snafu(display(
"Fail to create temporary recordbatch when eval Python UDF, source: {}",
source
))]
UdfTempRecordBatch {
#[snafu(backtrace)]
source: RecordbatchError,
},
#[snafu(display("Fail to execute function, source: {}", source))]
ExecuteFunction {
source: DataFusionError,
@@ -167,7 +184,9 @@ pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::ExecuteFunction { .. }
Error::UdfTempRecordBatch { .. }
| Error::PyUdf { .. }
| Error::ExecuteFunction { .. }
| Error::GenerateFunction { .. }
| Error::CreateAccumulator { .. }
| Error::DowncastVector { .. }

View File

@@ -45,16 +45,16 @@ once_cell = "1.17.0"
paste = { workspace = true, optional = true }
query = { path = "../query" }
# TODO(discord9): This is a forked and tweaked version of RustPython; please update it to the newest upstream RustPython after updating the toolchain to 1.65
rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [
rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [
"freeze-stdlib",
] }
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537" }
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "f89b1537", features = [
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345" }
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "2e126345", features = [
"default",
"codegen",
] }

View File

@@ -58,6 +58,10 @@ impl ScriptManager {
logging::info!("Compiled and cached script: {}", name);
script.as_ref().register_udf();
logging::info!("Script register as UDF: {}", name);
Ok(script)
}

View File

@@ -24,7 +24,6 @@ use common_recordbatch::RecordBatch;
use common_telemetry::info;
use datatypes::arrow::array::Array;
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::{Helper, VectorRef};
@@ -55,7 +54,7 @@ thread_local!(static INTERPRETER: RefCell<Option<Arc<Interpreter>>> = RefCell::n
pub struct AnnotationInfo {
/// if None, use types inferred by PyVector
// TODO(yingwen): We should use our data type. i.e. ConcreteDataType.
pub datatype: Option<ArrowDataType>,
pub datatype: Option<ConcreteDataType>,
pub is_nullable: bool,
}
@@ -122,14 +121,12 @@ impl Coprocessor {
} = anno[idx].to_owned().unwrap_or_else(|| {
// default to be not nullable and use DataType inferred by PyVector itself
AnnotationInfo {
datatype: Some(real_ty.as_arrow_type()),
datatype: Some(real_ty.clone()),
is_nullable: false,
}
});
let column_type = match ty {
Some(arrow_type) => {
ConcreteDataType::try_from(&arrow_type).context(TypeCastSnafu)?
}
Some(anno_type) => anno_type,
// if type is like `_` or `_ | None`
None => real_ty,
};
@@ -165,9 +162,10 @@ impl Coprocessor {
{
let real_ty = col.data_type();
let anno_ty = datatype;
if real_ty.as_arrow_type() != *anno_ty {
if real_ty != *anno_ty {
let array = col.to_arrow_array();
let array = compute::cast(&array, anno_ty).context(ArrowSnafu)?;
let array =
compute::cast(&array, &anno_ty.as_arrow_type()).context(ArrowSnafu)?;
*col = Helper::try_into_vector(array).context(TypeCastSnafu)?;
}
}
@@ -223,6 +221,7 @@ fn check_args_anno_real_type(
for (idx, arg) in args.iter().enumerate() {
let anno_ty = copr.arg_types[idx].to_owned();
let real_ty = arg.to_arrow_array().data_type().to_owned();
let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
ensure!(
anno_ty
@@ -372,7 +371,11 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
let native_module_allow_list = HashSet::from([
"array", "cmath", "gc", "hashlib", "_json", "_random", "math",
]);
let interpreter = Arc::new(vm::Interpreter::with_init(Default::default(), |vm| {
// TODO(discord9): edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]`
// so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868
let mut settings = vm::Settings::default();
settings.no_sig_int = true;
let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| {
// not using full stdlib to prevent security issue, instead filter out a few simple util module
vm.add_native_modules(
rustpython_stdlib::get_module_inits()

View File

@@ -14,7 +14,7 @@
use std::collections::HashSet;
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::ConcreteDataType;
use rustpython_parser::ast::{Arguments, Location};
use rustpython_parser::{ast, parser};
#[cfg(test)]
@@ -81,20 +81,20 @@ fn pylist_to_vec(lst: &ast::Expr<()>) -> Result<Vec<String>> {
}
}
fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<DataType>> {
fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType>> {
match ty {
"bool" => Ok(Some(DataType::Boolean)),
"u8" => Ok(Some(DataType::UInt8)),
"u16" => Ok(Some(DataType::UInt16)),
"u32" => Ok(Some(DataType::UInt32)),
"u64" => Ok(Some(DataType::UInt64)),
"i8" => Ok(Some(DataType::Int8)),
"i16" => Ok(Some(DataType::Int16)),
"i32" => Ok(Some(DataType::Int32)),
"i64" => Ok(Some(DataType::Int64)),
"f16" => Ok(Some(DataType::Float16)),
"f32" => Ok(Some(DataType::Float32)),
"f64" => Ok(Some(DataType::Float64)),
"bool" => Ok(Some(ConcreteDataType::boolean_datatype())),
"u8" => Ok(Some(ConcreteDataType::uint8_datatype())),
"u16" => Ok(Some(ConcreteDataType::uint16_datatype())),
"u32" => Ok(Some(ConcreteDataType::uint32_datatype())),
"u64" => Ok(Some(ConcreteDataType::uint64_datatype())),
"i8" => Ok(Some(ConcreteDataType::int8_datatype())),
"i16" => Ok(Some(ConcreteDataType::int16_datatype())),
"i32" => Ok(Some(ConcreteDataType::int32_datatype())),
"i64" => Ok(Some(ConcreteDataType::int64_datatype())),
"f32" => Ok(Some(ConcreteDataType::float32_datatype())),
"f64" => Ok(Some(ConcreteDataType::float64_datatype())),
"str" => Ok(Some(ConcreteDataType::string_datatype())),
// for any datatype
"_" => Ok(None),
// note the difference between "_" and _

View File

@@ -20,10 +20,15 @@ use std::task::{Context, Poll};
use async_trait::async_trait;
use common_error::prelude::BoxedError;
use common_function::scalars::{Function, FUNCTION_REGISTRY};
use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
use common_query::prelude::Signature;
use common_query::Output;
use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use datatypes::schema::SchemaRef;
use datafusion_expr::Volatility;
use datatypes::schema::{ColumnSchema, SchemaRef};
use datatypes::vectors::VectorRef;
use futures::Stream;
use query::parser::{QueryLanguageParser, QueryStatement};
use query::QueryEngineRef;
@@ -32,16 +37,141 @@ use snafu::{ensure, ResultExt};
use sql::statements::statement::Statement;
use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::python::coprocessor::{exec_parsed, parse, CoprocessorRef};
use crate::python::coprocessor::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
use crate::python::error::{self, Result};
const PY_ENGINE: &str = "python";
#[derive(Debug)]
/// A Python coprocessor script exposed as a scalar UDF.
///
/// Wraps a parsed coprocessor so it can be registered into the function
/// registry / query engine and invoked from SQL.
pub struct PyUDF {
    // Parsed coprocessor (shared, reference-counted) that backs this UDF.
    copr: CoprocessorRef,
}
/// Renders the UDF as `name(arg1,arg2)->` for error messages and logs.
// NOTE(review): the format string ends with a bare "->" and nothing after it —
// presumably return types were meant to follow; confirm whether this is
// intentional or an unfinished Display impl.
impl std::fmt::Display for PyUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}({})->",
            &self.copr.name,
            &self.copr.deco_args.arg_names.join(",")
        )
    }
}
impl PyUDF {
    /// Wrap a parsed coprocessor in an `Arc`'d UDF handle.
    fn from_copr(copr: CoprocessorRef) -> Arc<Self> {
        Arc::new(Self { copr })
    }

    /// Register this UDF into the global `FUNCTION_REGISTRY`.
    fn register_as_udf(zelf: Arc<Self>) {
        FUNCTION_REGISTRY.register(zelf)
    }

    /// Register this UDF into the given query engine.
    fn register_to_query_engine(zelf: Arc<Self>, engine: QueryEngineRef) {
        engine.register_function(zelf)
    }

    /// Build a synthetic schema from the coprocessor's argument names and the
    /// runtime types of `columns`; should only be used when dynamically
    /// evaluating a Python UDF.
    fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
        let names = &self.copr.deco_args.arg_names;
        let mut column_schemas = Vec::with_capacity(columns.len());
        for (idx, column) in columns.iter().enumerate() {
            // Index into `names` (rather than zipping) to keep the original
            // panic-on-mismatch behavior when more columns than argument
            // names are supplied.
            column_schemas.push(ColumnSchema::new(
                names[idx].to_owned(),
                column.data_type(),
                true,
            ));
        }
        Arc::new(datatypes::schema::Schema::new(column_schemas))
    }
}
/// Scalar-function adapter: lets the query engine call a Python coprocessor
/// like any built-in function.
impl Function for PyUDF {
    fn name(&self) -> &str {
        &self.copr.name
    }

    /// Return type is taken from the first annotated return type of the
    /// coprocessor; errors out when no annotation is available.
    fn return_type(
        &self,
        _input_types: &[datatypes::prelude::ConcreteDataType],
    ) -> common_query::error::Result<datatypes::prelude::ConcreteDataType> {
        // TODO(discord9): use correct return annotation if exist
        match self.copr.return_types.first() {
            Some(Some(AnnotationInfo {
                datatype: Some(ty), ..
            })) => Ok(ty.to_owned()),
            _ => PyUdfSnafu {
                // Bug fix: the original passed the literal string
                // "... UDF {self}" — a plain &str never interpolates, so the
                // message showed "{self}" verbatim. `format!` is required.
                msg: format!("Can't find return type for python UDF {self}"),
            }
            .fail(),
        }
    }

    /// Builds a typed signature from the argument annotations when every
    /// argument is annotated; otherwise falls back to an untyped signature of
    /// the right arity.
    fn signature(&self) -> common_query::prelude::Signature {
        // Try our best to get a type signature.
        let mut arg_types = Vec::with_capacity(self.copr.arg_types.len());
        let mut know_all_types = true;
        for ty in self.copr.arg_types.iter() {
            match ty {
                Some(AnnotationInfo {
                    datatype: Some(ty), ..
                }) => arg_types.push(ty.to_owned()),
                _ => {
                    know_all_types = false;
                    break;
                }
            }
        }
        // NOTE(review): `variadic` means "any number of arguments, each drawn
        // from `arg_types`", not an exact positional signature — confirm this
        // is the intended semantics for multi-argument coprocessors.
        if know_all_types {
            Signature::variadic(arg_types, Volatility::Immutable)
        } else {
            Signature::any(self.copr.arg_types.len(), Volatility::Immutable)
        }
    }

    /// Evaluates the coprocessor over `columns` and returns its first result
    /// column.
    fn eval(
        &self,
        _func_ctx: common_function::scalars::function::FunctionContext,
        columns: &[datatypes::vectors::VectorRef],
    ) -> common_query::error::Result<datatypes::vectors::VectorRef> {
        // FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
        let schema = self.fake_schema(columns);
        let columns = columns.to_vec();
        let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?;
        let res = exec_parsed(&self.copr, &rb).map_err(|err| {
            PyUdfSnafu {
                msg: format!("{err:#?}"),
            }
            .build()
        })?;
        if res.columns().is_empty() {
            return PyUdfSnafu {
                msg: "Python UDF should return exactly one column, found zero column".to_string(),
            }
            .fail();
        }
        // If more than one column is returned, just take the first one.
        // TODO(discord9): more error handling
        let res0 = res.column(0);
        Ok(res0.to_owned())
    }
}
/// A compiled Python script together with the query engine it can register
/// its UDF into.
pub struct PyScript {
    // Engine used when registering this script's UDF for SQL queries.
    query_engine: QueryEngineRef,
    // The parsed coprocessor compiled from the script source.
    copr: CoprocessorRef,
}
impl PyScript {
/// Register Current Script as UDF, register name is same as script name
/// FIXME(discord9): possible inject attack?
pub fn register_udf(&self) {
let udf = PyUDF::from_copr(self.copr.clone());
PyUDF::register_as_udf(udf.clone());
PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
}
}
pub struct CoprStream {
stream: SendableRecordBatchStream,
copr: CoprocessorRef,

View File

@@ -132,7 +132,7 @@ fn run_ron_testcases() {
.zip(res.schema.column_schemas())
.for_each(|(anno, real)| {
assert!(
anno.datatype.as_ref().unwrap() == &real.data_type.as_arrow_type()
anno.datatype.as_ref().unwrap() == &real.data_type
&& anno.is_nullable == real.is_nullable(),
"Fields expected to be {anno:#?}, actual {real:#?}"
);

View File

@@ -24,21 +24,21 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
),
arg_types: [
Some((
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
)),
Some((
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: false
)),
],
return_types: [
Some((
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: false
)),
Some((
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
)),
Some((
@@ -246,11 +246,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
predicate: ExecIsOk(
fields: [
(
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
),
(
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
),
],
@@ -278,11 +278,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
predicate: ExecIsOk(
fields: [
(
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
),
(
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
),
],
@@ -310,11 +310,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
predicate: ExecIsOk(
fields: [
(
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
),
(
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
),
],
@@ -342,11 +342,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
predicate: ExecIsOk(
fields: [
(
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
),
(
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
),
],
@@ -372,7 +372,7 @@ def a(cpu: vector[f32], mem: vector[f64]):
predicate: ExecIsOk(
fields: [
(
datatype: Some(Utf8),
datatype: Some(String(())),
is_nullable: false,
),
],
@@ -480,11 +480,11 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
predicate: ExecIsOk(
fields: [
(
datatype: Some(Float64),
datatype: Some(Float64(())),
is_nullable: true
),
(
datatype: Some(Float32),
datatype: Some(Float32(())),
is_nullable: false
),
],

View File

@@ -37,6 +37,8 @@ mod interceptor;
mod opentsdb;
mod postgres;
mod py_script;
struct DummyInstance {
query_engine: QueryEngineRef,
py_engine: Arc<PyEngine>,
@@ -88,6 +90,7 @@ impl ScriptHandler for DummyInstance {
.compile(script, CompileContext::default())
.await
.unwrap();
script.register_udf();
self.scripts
.write()
.unwrap()

View File

@@ -0,0 +1,56 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use servers::error::Result;
use servers::query_handler::{ScriptHandler, SqlQueryHandler};
use session::context::QueryContext;
use table::test_util::MemTable;
use crate::create_testing_instance;
#[tokio::test]
/// End-to-end check: insert a Python coprocessor script, then call it as a
/// UDF from a SQL query and verify the doubled value comes back.
async fn test_insert_py_udf_and_query() -> Result<()> {
    let query_ctx = Arc::new(QueryContext::new());
    let table = MemTable::default_numbers_table();
    let instance = create_testing_instance(table);
    // Coprocessor that doubles the `uint32s` column element-wise.
    let src = r#"
@coprocessor(args=["uint32s"], returns = ["ret"])
def double_that(col)->vector[u32]:
    return col*2
"#;
    instance.insert_script("double_that", src).await?;
    let res = instance
        .do_query("select double_that(uint32s) from numbers", query_ctx)
        .await
        .remove(0)
        .unwrap();
    match res {
        // A SELECT should not report affected rows; nothing to assert here.
        common_query::Output::AffectedRows(_) => (),
        common_query::Output::RecordBatches(_) => {
            unreachable!()
        }
        common_query::Output::Stream(s) => {
            let batches = common_recordbatch::util::collect_batches(s).await.unwrap();
            assert_eq!(batches.iter().count(), 1);
            let first = batches.iter().next().unwrap();
            let col = first.column(0);
            // Row index 1 holds the number 1, so double_that yields 2.
            let val = col.get(1);
            assert_eq!(val, datatypes::value::Value::UInt32(2));
        }
    }
    Ok(())
}