diff --git a/src/datanode/src/instance/script.rs b/src/datanode/src/instance/script.rs
index 8dd878546b..fc7757a365 100644
--- a/src/datanode/src/instance/script.rs
+++ b/src/datanode/src/instance/script.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+
 use async_trait::async_trait;
 use common_query::Output;
 use common_telemetry::timer;
@@ -34,8 +36,15 @@ impl ScriptHandler for Instance {
             .await
     }
 
-    async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> servers::error::Result<Output> {
         let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED);
-        self.script_executor.execute_script(schema, name).await
+        self.script_executor
+            .execute_script(schema, name, params)
+            .await
     }
 }
diff --git a/src/datanode/src/script.rs b/src/datanode/src/script.rs
index b7cc622c95..d63efccc6e 100644
--- a/src/datanode/src/script.rs
+++ b/src/datanode/src/script.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+
 use catalog::CatalogManagerRef;
 use common_query::Output;
 use query::QueryEngineRef;
@@ -34,13 +36,19 @@ mod dummy {
 
         pub async fn insert_script(
             &self,
+            _schema: &str,
             _name: &str,
             _script: &str,
         ) -> servers::error::Result<()> {
             servers::error::NotSupportedSnafu { feat: "script" }.fail()
         }
 
-        pub async fn execute_script(&self, _script: &str) -> servers::error::Result<Output> {
+        pub async fn execute_script(
+            &self,
+            _schema: &str,
+            _name: &str,
+            _params: HashMap<String, String>,
+        ) -> servers::error::Result<Output> {
             servers::error::NotSupportedSnafu { feat: "script" }.fail()
         }
     }
@@ -94,9 +102,10 @@ mod python {
             &self,
             schema: &str,
             name: &str,
+            params: HashMap<String, String>,
         ) -> servers::error::Result<Output> {
             self.script_manager
-                .execute(schema, name)
+                .execute(schema, name, params)
                 .await
                 .map_err(|e| {
                     error!(e; "Instance failed to execute script");
diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs
index 11becc4a7e..127597bc7d 100644
--- a/src/frontend/src/instance.rs
+++ b/src/frontend/src/instance.rs
@@ -19,6 +19,7 @@ mod opentsdb;
 mod prometheus;
 mod standalone;
 
+use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -512,9 +513,14 @@ impl ScriptHandler for Instance {
         }
     }
 
-    async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        script: &str,
+        params: HashMap<String, String>,
+    ) -> server_error::Result<Output> {
         if let Some(handler) = &self.script_handler {
-            handler.execute_script(schema, script).await
+            handler.execute_script(schema, script, params).await
         } else {
            server_error::NotSupportedSnafu {
                feat: "Script execution in Frontend",
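
[review note, not part of the patch] Datanode, its dummy/python executors, and Frontend all gain the same extra argument, so `params: HashMap<String, String>` is threaded unchanged from the HTTP layer down to the script engine. A minimal sketch of what a caller looks like against the new signature; the `instance` value, schema, and script name here are made up for illustration:

    use std::collections::HashMap;

    // `instance` is anything implementing servers' ScriptHandler trait.
    async fn run_with_params(instance: &dyn ScriptHandler) -> servers::error::Result<Output> {
        let mut params = HashMap::new();
        // Values always arrive as strings; scripts convert them as needed.
        params.insert("a".to_string(), "30".to_string());
        instance.execute_script("public", "my_script", params).await
    }
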
diff --git a/src/script/src/engine.rs b/src/script/src/engine.rs
index 004ce351bb..78ddfcf283 100644
--- a/src/script/src/engine.rs
+++ b/src/script/src/engine.rs
@@ -15,6 +15,7 @@
 //! Script engine
 
 use std::any::Any;
+use std::collections::HashMap;
 
 use async_trait::async_trait;
 use common_error::ext::ErrorExt;
@@ -30,7 +31,11 @@ pub trait Script {
     fn as_any(&self) -> &dyn Any;
 
     /// Execute the script and returns the output.
-    async fn execute(&self, ctx: EvalContext) -> std::result::Result<Output, Self::Error>;
+    async fn execute(
+        &self,
+        params: HashMap<String, String>,
+        ctx: EvalContext,
+    ) -> std::result::Result<Output, Self::Error>;
 }
 
 #[async_trait]
diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs
index 637e1194a7..ba79a361c8 100644
--- a/src/script/src/manager.rs
+++ b/src/script/src/manager.rs
@@ -76,7 +76,12 @@ impl ScriptManager {
         Ok(compiled_script)
     }
 
-    pub async fn execute(&self, schema: &str, name: &str) -> Result<Output> {
+    pub async fn execute(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output> {
         let script = {
             let s = self.compiled.read().unwrap().get(name).cloned();
 
@@ -90,7 +95,7 @@ impl ScriptManager {
         let script = script.context(ScriptNotFoundSnafu { name })?;
 
         script
-            .execute(EvalContext::default())
+            .execute(params, EvalContext::default())
             .await
             .context(ExecutePythonSnafu { name })
     }
diff --git a/src/script/src/python/builtins.rs b/src/script/src/python/builtins.rs
index 71815f1750..e37bd6b60c 100644
--- a/src/script/src/python/builtins.rs
+++ b/src/script/src/python/builtins.rs
@@ -97,7 +97,7 @@ pub fn try_into_columnar_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<DFColValue> {
             .borrow_vec()
             .iter()
             .map(|obj| -> PyResult<DFColValue> {
-                let col = try_into_columnar_value(obj.to_owned(), vm)?;
+                let col = try_into_columnar_value(obj.clone(), vm)?;
                 match col {
                     DFColValue::Array(arr) => Err(vm.new_type_error(format!(
                         "Expect only scalar value in a list, found a vector of type {:?} nested in list", arr.data_type()
@@ -243,7 +243,7 @@ macro_rules! bind_aggr_fn {
             $(
                 Arc::new(expressions::Column::new(stringify!($EXPR_ARGS), 0)) as _,
             )*
-            stringify!($AGGR_FUNC), $DATA_TYPE.to_owned()),
+            stringify!($AGGR_FUNC), $DATA_TYPE.clone()),
             $ARGS, $VM)
     };
 }
@@ -597,7 +597,7 @@ pub(crate) mod greptime_builtin {
                     Arc::new(percent) as _,
                 ],
                 "ApproxPercentileCont",
-                (values.to_arrow_array().data_type()).to_owned(),
+                (values.to_arrow_array().data_type()).clone(),
             )
             .map_err(|err| from_df_err(err, vm))?,
             &[values.to_arrow_array()],
@@ -839,7 +839,7 @@ pub(crate) mod greptime_builtin {
             return Ok(ret.into());
         }
         let cur = cur.slice(0, cur.len() - 1); // except the last one that is
-        let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
+        let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
         let ret = compute::concat(&[&*fill, &*cur]).map_err(|err| {
             vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
         })?;
@@ -864,7 +864,7 @@ pub(crate) mod greptime_builtin {
             return Ok(ret.into());
         }
         let cur = cur.slice(1, cur.len() - 1); // except the last one that is
-        let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
+        let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
         let ret = compute::concat(&[&*cur, &*fill]).map_err(|err| {
             vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
         })?;
@@ -1048,7 +1048,7 @@ pub(crate) mod greptime_builtin {
             match (ch.is_ascii_digit(), &state) {
                 (true, State::Separator(_)) => {
                     let res = &input[prev..idx];
-                    let res = State::Separator(res.to_owned());
+                    let res = State::Separator(res.to_string());
                     parsed.push(res);
                     prev = idx;
                     state = State::Num(Default::default());
@@ -1073,7 +1073,7 @@ pub(crate) mod greptime_builtin {
             }
             State::Separator(_) => {
                 let res = &input[prev..];
-                State::Separator(res.to_owned())
+                State::Separator(res.to_string())
             }
         };
         parsed.push(last);
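
[review note, not part of the patch] The `Script` trait change above is the core API break: every implementation now receives the user parameters ahead of the `EvalContext`. A toy implementation sketch under stated assumptions — the trait's associated `Error` type is implied by `Self::Error` but not shown in this diff, `error::Error` is a placeholder, and any other required trait items are omitted:

    use std::collections::HashMap;
    use async_trait::async_trait;

    struct EchoScript;

    #[async_trait]
    impl Script for EchoScript {
        type Error = error::Error; // assumed associated error type

        fn as_any(&self) -> &dyn std::any::Any {
            self
        }

        async fn execute(
            &self,
            params: HashMap<String, String>,
            _ctx: EvalContext,
        ) -> std::result::Result<Output, Self::Error> {
            // A real engine would expose `params` to the script; here we
            // just report how many key/value pairs were passed in.
            Ok(Output::AffectedRows(params.len()))
        }
    }
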
diff --git a/src/script/src/python/builtins/test.rs b/src/script/src/python/builtins/test.rs
index 6738f9e722..46cafb5ee3 100644
--- a/src/script/src/python/builtins/test.rs
+++ b/src/script/src/python/builtins/test.rs
@@ -240,10 +240,7 @@ impl PyValue {
                 .as_any()
                 .downcast_ref::<Float64Array>()
                 .ok_or(format!("Can't cast {vec_f64:#?} to Float64Array!"))?;
-            let ret = vec_f64
-                .into_iter()
-                .map(|v| v.map(|inner| inner.to_owned()))
-                .collect::<Vec<_>>();
+            let ret = vec_f64.into_iter().collect::<Vec<_>>();
             if ret.iter().all(|x| x.is_some()) {
                 Ok(Self::FloatVec(
                     ret.into_iter().map(|i| i.unwrap()).collect(),
@@ -266,7 +263,6 @@ impl PyValue {
                     v.ok_or(format!(
                         "No null element expected, found one in {idx} position"
                     ))
-                    .map(|v| v.to_owned())
                 })
                 .collect::<Result<_, _>>()?;
             Ok(Self::IntVec(ret))
@@ -275,13 +271,13 @@ impl PyValue {
         } else if is_instance::<PyInt>(obj, vm) {
             let res = obj
-                .to_owned()
+                .clone()
                 .try_into_value::<i64>(vm)
                 .map_err(|err| format_py_error(err, vm).to_string())?;
             Ok(Self::Int(res))
         } else if is_instance::<PyFloat>(obj, vm) {
             let res = obj
-                .to_owned()
+                .clone()
                 .try_into_value::<f64>(vm)
                 .map_err(|err| format_py_error(err, vm).to_string())?;
             Ok(Self::Float(res))
@@ -338,7 +334,7 @@ fn run_builtin_fn_testcases() {
         .compile(
             &case.script,
             rustpython_compiler_core::Mode::BlockExpr,
-            "<embedded>".to_owned(),
+            "<embedded>".to_string(),
         )
         .map_err(|err| vm.new_syntax_error(&err))
         .unwrap();
@@ -389,7 +385,7 @@ fn set_item_into_scope(
     scope
         .locals
         .as_object()
-        .set_item(&name.to_owned(), vm.new_pyobj(value), vm)
+        .set_item(&name.to_string(), vm.new_pyobj(value), vm)
         .map_err(|err| {
             format!(
                 "Error in setting var {name} in scope: \n{}",
@@ -408,7 +404,7 @@ fn set_lst_of_vecs_in_scope(
     scope
         .locals
         .as_object()
-        .set_item(name.to_owned(), vm.new_pyobj(vector), vm)
+        .set_item(&name.to_string(), vm.new_pyobj(vector), vm)
         .map_err(|err| {
             format!(
                 "Error in setting var {name} in scope: \n{}",
@@ -447,7 +443,7 @@ fn test_vm() {
 from udf_builtins import *
 sin(values)"#,
             rustpython_compiler_core::Mode::BlockExpr,
-            "<embedded>".to_owned(),
+            "<embedded>".to_string(),
         )
         .map_err(|err| vm.new_syntax_error(&err))
         .unwrap();
diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs
index 9e8bf8e50b..9f793887e4 100644
--- a/src/script/src/python/coprocessor.rs
+++ b/src/script/src/python/coprocessor.rs
@@ -16,7 +16,7 @@ pub mod compile;
 pub mod parse;
 
 use std::cell::RefCell;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::result::Result as StdResult;
 use std::sync::{Arc, Weak};
 
@@ -36,7 +36,7 @@ use rustpython_vm::AsObject;
 #[cfg(test)]
 use serde::Deserialize;
 use snafu::{OptionExt, ResultExt};
-use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
+use vm::builtins::{PyBaseExceptionRef, PyDict, PyList, PyListRef, PyStr, PyTuple};
 use vm::convert::ToPyObject;
 use vm::scope::Scope;
 use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
@@ -73,6 +73,8 @@ pub struct Coprocessor {
     pub arg_types: Vec<Option<AnnotationInfo>>,
     /// get from python function returns' annotation, first is type, second is is_nullable
     pub return_types: Vec<Option<AnnotationInfo>>,
+    /// kwargs in coprocessor function's signature
+    pub kwarg: Option<String>,
     /// store its corresponding script, also skip serde when in `cfg(test)` to reduce work in compare
     #[cfg_attr(test, serde(skip))]
     pub script: String,
@@ -103,7 +105,7 @@ impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
 impl std::fmt::Debug for QueryEngineWeakRef {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_tuple("QueryEngineWeakRef")
-            .field(&self.0.upgrade().map(|f| f.name().to_owned()))
+            .field(&self.0.upgrade().map(|f| f.name().to_string()))
             .finish()
     }
 }
@@ -147,7 +149,7 @@ impl Coprocessor {
             let AnnotationInfo {
                 datatype: ty,
                 is_nullable,
-            } = anno[idx].to_owned().unwrap_or_else(|| {
+            } = anno[idx].clone().unwrap_or_else(|| {
                 // default to be not nullable and use DataType inferred by PyVector itself
                 AnnotationInfo {
                     datatype: Some(real_ty.clone()),
@@ -248,20 +250,23 @@ fn check_args_anno_real_type(
     rb: &RecordBatch,
 ) -> Result<()> {
     for (idx, arg) in args.iter().enumerate() {
-        let anno_ty = copr.arg_types[idx].to_owned();
-        let real_ty = arg.to_arrow_array().data_type().to_owned();
+        let anno_ty = copr.arg_types[idx].clone();
+        let real_ty = arg.to_arrow_array().data_type().clone();
         let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
         let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
         ensure!(
             anno_ty
-                .to_owned()
+                .clone()
                 .map(|v| v.datatype.is_none() // like a vector[_]
-                    || v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
+                    || v.datatype == Some(real_ty.clone()) && v.is_nullable == is_nullable)
                 .unwrap_or(true),
             OtherSnafu {
                 reason: format!(
                     "column {}'s Type annotation is {:?}, but actual type is {:?}",
-                    copr.deco_args.arg_names[idx], anno_ty, real_ty
+                    // It's safe to unwrap here, we already ensure the args and types number is the same when parsing
+                    copr.deco_args.arg_names.as_ref().unwrap()[idx],
+                    anno_ty,
+                    real_ty
                 )
             }
         )
@@ -343,12 +348,12 @@ fn set_items_in_scope(
 /// You can return constant in python code like `return 1, 1.0, True`
 /// which create a constant array(with same value)(currently support int, float and bool) as column on return
 #[cfg(test)]
-pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
+pub fn exec_coprocessor(script: &str, rb: &Option<RecordBatch>) -> Result<RecordBatch> {
     // 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
     // 2. also check for exist of `args` in `rb`, if not found, return error
     // TODO(discord9): cache the result of parse_copr
     let copr = parse::parse_and_compile_copr(script, None)?;
-    exec_parsed(&copr, rb)
+    exec_parsed(&copr, rb, &HashMap::new())
 }
 
 #[pyclass(module = false, name = "query_engine")]
@@ -412,7 +417,7 @@ impl PyQueryEngine {
                 for rb in rbs.iter() {
                     let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
                     for v in rb.columns() {
-                        let v = PyVector::from(v.to_owned());
+                        let v = PyVector::from(v.clone());
                         vec_of_vec.push(v.to_pyobject(vm));
                     }
                     let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);
@@ -440,18 +445,25 @@ fn set_query_engine_in_scope(
         .map_err(|e| format_py_error(e, vm))
 }
 
-pub(crate) fn exec_with_cached_vm(
+fn exec_with_cached_vm(
     copr: &Coprocessor,
-    rb: &RecordBatch,
+    rb: &Option<RecordBatch>,
     args: Vec<PyVector>,
+    params: &HashMap<String, String>,
     vm: &Arc<Interpreter>,
 ) -> Result<RecordBatch> {
     vm.enter(|vm| -> Result<RecordBatch> {
         PyVector::make_class(&vm.ctx);
         // set arguments with given name and values
         let scope = vm.new_scope_with_builtins();
-        set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
-        set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
+        if let Some(rb) = rb {
+            set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
+        }
+
+        if let Some(arg_names) = &copr.deco_args.arg_names {
+            assert_eq!(arg_names.len(), args.len());
+            set_items_in_scope(&scope, vm, arg_names, args)?;
+        }
 
         if let Some(engine) = &copr.query_engine {
             let query_engine = PyQueryEngine {
@@ -463,6 +475,19 @@ pub(crate) fn exec_with_cached_vm(
             set_query_engine_in_scope(&scope, vm, query_engine)?;
         }
 
+        if let Some(kwarg) = &copr.kwarg {
+            let dict = PyDict::new_ref(&vm.ctx);
+            for (k, v) in params {
+                dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
+                    .map_err(|e| format_py_error(e, vm))?;
+            }
+            scope
+                .locals
+                .as_object()
+                .set_item(kwarg, vm.new_pyobj(dict), vm)
+                .map_err(|e| format_py_error(e, vm))?;
+        }
+
         // It's safe to unwrap code_object, it's already compiled before.
         let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
         let ret = vm
@@ -470,7 +495,7 @@ pub(crate) fn exec_with_cached_vm(
             .map_err(|e| format_py_error(e, vm))?;
 
         // 5. get returns as either a PyVector or a PyTuple, and naming schema them according to `returns`
-        let col_len = rb.num_rows();
+        let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
         let mut cols = try_into_columns(&ret, vm, col_len)?;
         ensure!(
             cols.len() == copr.deco_args.ret_names.len(),
@@ -485,6 +510,7 @@ pub(crate) fn exec_with_cached_vm(
         // if cols and schema's data types is not match, try coerce it to given type(if annotated)(if error occur, return relevant error with question mark)
         copr.check_and_cast_type(&mut cols)?;
+
         // 6. return a assembled DfRecordBatch
         let schema = copr.gen_schema(&cols)?;
         RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
@@ -533,13 +559,23 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
 }
 
 /// using a parsed `Coprocessor` struct as input to execute python code
-pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<RecordBatch> {
+pub(crate) fn exec_parsed(
+    copr: &Coprocessor,
+    rb: &Option<RecordBatch>,
+    params: &HashMap<String, String>,
+) -> Result<RecordBatch> {
     // 3. get args from `rb`, and cast them into PyVector
-    let args: Vec<PyVector> = select_from_rb(rb, &copr.deco_args.arg_names)?;
-    check_args_anno_real_type(&args, copr, rb)?;
+    let args: Vec<PyVector> = if let Some(rb) = rb {
+        let args = select_from_rb(rb, copr.deco_args.arg_names.as_ref().unwrap_or(&vec![]))?;
+        check_args_anno_real_type(&args, copr, rb)?;
+        args
+    } else {
+        vec![]
+    };
+
     let interpreter = init_interpreter();
     // 4. then set args in scope and compile then run `CodeObject` which already append a new `Call` node
-    exec_with_cached_vm(copr, rb, args, &interpreter)
+    exec_with_cached_vm(copr, rb, args, params, &interpreter)
 }
 
 /// execute script just like [`exec_coprocessor`] do,
@@ -551,7 +587,7 @@ pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<RecordBatch> {
 pub fn exec_copr_print(
     script: &str,
-    rb: &RecordBatch,
+    rb: &Option<RecordBatch>,
     ln_offset: usize,
     filename: &str,
 ) -> StdResult<String, String> {
@@ -572,7 +608,7 @@ def add(a, b):
     return a + b
 
 @copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
-def test(a, b, c):
+def test(a, b, c, **params):
     import greptime as g
     return add(a, b) / g.sqrt(c)
 "#;
@@ -585,9 +621,10 @@ def test(a, b, c):
         "select number as a,number as b,number as c from numbers limit 100"
     );
     assert_eq!(deco_args.ret_names, vec!["r"]);
-    assert_eq!(deco_args.arg_names, vec!["a", "b", "c"]);
+    assert_eq!(deco_args.arg_names.unwrap(), vec!["a", "b", "c"]);
     assert_eq!(copr.arg_types, vec![None, None, None]);
     assert_eq!(copr.return_types, vec![None]);
+    assert_eq!(copr.kwarg, Some("params".to_string()));
     assert_eq!(copr.script, script);
     assert!(copr.code_obj.is_some());
 }
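
[review note, not part of the patch] Net effect of the plumbing above: when a coprocessor declares a `**kwargs` parameter, `exec_with_cached_vm` builds a `PyDict` from the request's string parameters and binds it to that name in the script's scope, and a missing SQL/args pair no longer aborts execution (`col_len` falls back to 1). An equivalent standalone script, taken almost verbatim from the engine test later in this diff:

    @copr(returns=["n"])
    def test(**params) -> vector[i64]:
        # every value in `params` is a string; cast before doing arithmetic
        return int(params['a']) + int(params['b'])
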
diff --git a/src/script/src/python/coprocessor/compile.rs b/src/script/src/python/coprocessor/compile.rs
index 9f5f0cd82e..2f73f8f6c7 100644
--- a/src/script/src/python/coprocessor/compile.rs
+++ b/src/script/src/python/coprocessor/compile.rs
@@ -16,7 +16,7 @@ use rustpython_codegen::compile::compile_top;
 use rustpython_compiler::{CompileOpts, Mode};
 use rustpython_compiler_core::CodeObject;
-use rustpython_parser::ast::{Located, Location};
+use rustpython_parser::ast::{ArgData, Located, Location};
 use rustpython_parser::{ast, parser};
 use snafu::ResultExt;
 
@@ -31,23 +31,40 @@ fn create_located<T>(node: T, loc: Location) -> Located<T> {
 /// generate a call to the coprocessor function
 /// with arguments given in decorator's `args` list
 /// also set in location in source code to `loc`
-fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<()> {
-    let mut loc = loc.to_owned();
+fn gen_call(
+    name: &str,
+    deco_args: &DecoratorArgs,
+    kwarg: &Option<String>,
+    loc: &Location,
+) -> ast::Stmt<()> {
+    let mut loc = *loc;
     // adding a line to avoid confusing if any error occurs when calling the function
     // then the pretty print will point to the last line in code
     // instead of point to any of existing code written by user.
     loc.newline();
-    let args: Vec<ast::Expr<()>> = deco_args
-        .arg_names
-        .iter()
-        .map(|v| {
-            let node = ast::ExprKind::Name {
-                id: v.to_owned(),
-                ctx: ast::ExprContext::Load,
-            };
-            create_located(node, loc)
-        })
-        .collect();
+    let mut args: Vec<ast::Expr<()>> = if let Some(arg_names) = &deco_args.arg_names {
+        arg_names
+            .iter()
+            .map(|v| {
+                let node = ast::ExprKind::Name {
+                    id: v.clone(),
+                    ctx: ast::ExprContext::Load,
+                };
+                create_located(node, loc)
+            })
+            .collect()
+    } else {
+        vec![]
+    };
+
+    if let Some(kwarg) = kwarg {
+        let node = ast::ExprKind::Name {
+            id: kwarg.clone(),
+            ctx: ast::ExprContext::Load,
+        };
+        args.push(create_located(node, loc));
+    }
+
     let func = ast::ExprKind::Call {
         func: Box::new(create_located(
             ast::ExprKind::Name {
@@ -71,7 +88,12 @@ fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<
 /// So we should avoid running too much Python Bytecode, hence in this function we delete `@` decorator(instead of actually write a decorator in python)
 /// And add a function call in the end and also
 /// strip type annotation
-pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Result<CodeObject> {
+pub fn compile_script(
+    name: &str,
+    deco_args: &DecoratorArgs,
+    kwarg: &Option<String>,
+    script: &str,
+) -> Result<CodeObject> {
     // note that it's important to use `parser::Mode::Interactive` so the ast can be compile to return a result instead of return None in eval mode
     let mut top = parser::parse(script, parser::Mode::Interactive, "<embedded>").context(PyParseSnafu)?;
 
@@ -89,6 +111,20 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Result<CodeObject> {
             type_comment: __main__,
         } = &mut stmt.node
         {
+            // Rewrite kwargs in coprocessor, make it as a positional argument
+            if !decorator_list.is_empty() {
+                if let Some(kwarg) = kwarg {
+                    args.kwarg = None;
+                    let node = ArgData {
+                        arg: kwarg.clone(),
+                        annotation: None,
+                        type_comment: Some("kwargs".to_string()),
+                    };
+                    let kwarg = create_located(node, stmt.location);
+                    args.args.push(kwarg);
+                }
+            }
+
             *decorator_list = Vec::new();
             // strip type annotation
             // def a(b: int, c:int) -> int
@@ -115,14 +151,14 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Result<CodeObject> {
     }
     // Append statement which calling coprocessor function.
     // It's safe to unwrap loc, it is always exists.
-    stmts.push(gen_call(name, deco_args, &loc.unwrap()));
+    stmts.push(gen_call(name, deco_args, kwarg, &loc.unwrap()));
 } else {
     return fail_parse_error!(format!("Expect statement in script, found: {top:?}"), None);
 }
 // use `compile::Mode::BlockExpr` so it return the result of statement
 compile_top(
     &top,
-    "<embedded>".to_owned(),
+    "<embedded>".to_string(),
     Mode::BlockExpr,
     CompileOpts { optimize: 0 },
 )
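
[review note, not part of the patch] The AST rewrite above is easier to see on concrete source. Roughly — the real transform works on the parsed AST, not on source text:

    # user script
    @copr(args=["a"], returns=["r"], sql="select ...")
    def test(a, **params):
        return a

    # after compile_script: decorator stripped, `**params` demoted to a
    # plain positional argument, and a trailing call appended so the
    # compiled block evaluates to the function's result
    def test(a, params):
        return a
    test(a, params)   # `a` and `params` are looked up from the scope
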
diff --git a/src/script/src/python/coprocessor/parse.rs b/src/script/src/python/coprocessor/parse.rs
index 18bcdca8c8..a616680352 100644
--- a/src/script/src/python/coprocessor/parse.rs
+++ b/src/script/src/python/coprocessor/parse.rs
@@ -29,7 +29,7 @@ use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};
 #[cfg_attr(test, derive(Deserialize))]
 #[derive(Default, Debug, Clone, PartialEq, Eq)]
 pub struct DecoratorArgs {
-    pub arg_names: Vec<String>,
+    pub arg_names: Option<Vec<String>>,
     pub ret_names: Vec<String>,
     pub sql: Option<String>,
     // maybe add a URL for connecting or what?
@@ -58,7 +58,7 @@ fn py_str_to_string(s: &ast::Expr<()>) -> Result<String> {
         kind: _,
     } = &s.node
     {
-        Ok(v.to_owned())
+        Ok(v.clone())
     } else {
         fail_parse_error!(
             format!(
@@ -100,10 +100,7 @@ fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType>> {
         "_" => Ok(None),
         // note the different between "_" and _
-        _ => fail_parse_error!(
-            format!("Unknown datatype: {ty} at {loc:?}"),
-            Some(loc.to_owned())
-        ),
+        _ => fail_parse_error!(format!("Unknown datatype: {ty} at {loc:?}"), Some(*loc)),
     }
 }
@@ -263,7 +260,7 @@ fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
 fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
     // more keys maybe add to this list of `avail_key`(like `sql` for querying and maybe config for connecting to database?), for better extension using a `HashSet` in here
     let avail_key = HashSet::from(["args", "returns", "sql"]);
-    let opt_keys = HashSet::from(["sql"]);
+    let opt_keys = HashSet::from(["sql", "args"]);
     let mut visited_key = HashSet::new();
     let len_min = avail_key.len() - opt_keys.len();
     let len_max = avail_key.len();
@@ -298,7 +295,7 @@ fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
             visited_key.insert(s);
         }
         match s {
-            "args" => ret_args.arg_names = pylist_to_vec(&kw.node.value)?,
+            "args" => ret_args.arg_names = Some(pylist_to_vec(&kw.node.value)?),
             "returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
             "sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
             _ => unreachable!(),
@@ -476,17 +473,19 @@ pub fn parse_and_compile_copr(
     // make sure both arguments&returns in function
     // and in decorator have same length
-    ensure!(
-        deco_args.arg_names.len() == arg_types.len(),
-        CoprParseSnafu {
-            reason: format!(
-                "args number in decorator({}) and function({}) doesn't match",
-                deco_args.arg_names.len(),
-                arg_types.len()
-            ),
-            loc: None
-        }
-    );
+    if let Some(arg_names) = &deco_args.arg_names {
+        ensure!(
+            arg_names.len() == arg_types.len(),
+            CoprParseSnafu {
+                reason: format!(
+                    "args number in decorator({}) and function({}) doesn't match",
+                    arg_names.len(),
+                    arg_types.len()
+                ),
+                loc: None
+            }
+        );
+    }
     ensure!(
         deco_args.ret_names.len() == return_types.len(),
         CoprParseSnafu {
@@ -498,13 +497,15 @@ pub fn parse_and_compile_copr(
             loc: None
         }
     );
+    let kwarg = fn_args.kwarg.as_ref().map(|arg| arg.node.arg.clone());
     coprocessor = Some(Coprocessor {
-        code_obj: Some(compile::compile_script(name, &deco_args, script)?),
+        code_obj: Some(compile::compile_script(name, &deco_args, &kwarg, script)?),
         name: name.to_string(),
         deco_args,
         arg_types,
         return_types,
-        script: script.to_owned(),
+        kwarg,
+        script: script.to_string(),
         query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
     });
 }
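
[review note, not part of the patch] With `args` moved into `opt_keys` and `arg_names` now an `Option<Vec<String>>`, both decorator forms below should parse (illustrative only):

    # `args` + `sql` form: columns selected by the query are bound to the arguments
    @copr(args=["number"], returns=["n"], sql="select number from numbers")
    def with_input(number):
        return number

    # no `args` and no `sql`: the body runs once, fed only by **params
    @copr(returns=["n"])
    def standalone(**params) -> vector[i64]:
        return int(params["x"])
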
diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs
index fd89ee74be..25af2c14ba 100644
--- a/src/script/src/python/engine.rs
+++ b/src/script/src/python/engine.rs
@@ -14,6 +14,7 @@
 //! Python script engine
 
 use std::any::Any;
+use std::collections::HashMap;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -25,7 +26,9 @@
 use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
 use common_query::prelude::Signature;
 use common_query::Output;
 use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
-use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
+use common_recordbatch::{
+    RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
+};
 use datafusion_expr::Volatility;
 use datatypes::schema::{ColumnSchema, SchemaRef};
 use datatypes::vectors::VectorRef;
@@ -53,7 +56,12 @@ impl std::fmt::Display for PyUDF {
             f,
             "{}({})->",
             &self.copr.name,
-            &self.copr.deco_args.arg_names.join(",")
+            self.copr
+                .deco_args
+                .arg_names
+                .as_ref()
+                .unwrap_or(&vec![])
+                .join(",")
         )
     }
 }
@@ -73,11 +81,17 @@ impl PyUDF {
     /// Fake a schema, should only be used with dynamically eval a Python Udf
     fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
-        let arg_names = &self.copr.deco_args.arg_names;
+        let empty_args = vec![];
+        let arg_names = self
+            .copr
+            .deco_args
+            .arg_names
+            .as_ref()
+            .unwrap_or(&empty_args);
         let col_sch: Vec<_> = columns
             .iter()
             .enumerate()
-            .map(|(i, col)| ColumnSchema::new(arg_names[i].to_owned(), col.data_type(), true))
+            .map(|(i, col)| ColumnSchema::new(arg_names[i].clone(), col.data_type(), true))
             .collect();
         let schema = datatypes::schema::Schema::new(col_sch);
         Arc::new(schema)
@@ -97,7 +111,7 @@ impl Function for PyUDF {
         match self.copr.return_types.get(0) {
             Some(Some(AnnotationInfo {
                 datatype: Some(ty), ..
-            })) => Ok(ty.to_owned()),
+            })) => Ok(ty.clone()),
             _ => PyUdfSnafu {
                 msg: "Can't found return type for python UDF {self}",
             }
@@ -113,7 +127,7 @@ impl Function for PyUDF {
             match ty {
                 Some(AnnotationInfo {
                     datatype: Some(ty), ..
-                }) => arg_types.push(ty.to_owned()),
+                }) => arg_types.push(ty.clone()),
                 _ => {
                     know_all_types = false;
                     break;
@@ -135,9 +149,8 @@ impl Function for PyUDF {
         // FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
         let schema = self.fake_schema(columns);
         let columns = columns.to_vec();
-        // TODO(discord9): remove unwrap
-        let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?;
-        let res = exec_parsed(&self.copr, &rb).map_err(|err| {
+        let rb = Some(RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?);
+        let res = exec_parsed(&self.copr, &rb, &HashMap::new()).map_err(|err| {
             PyUdfSnafu {
                 msg: format!("{err:#?}"),
             }
@@ -153,7 +166,7 @@ impl Function for PyUDF {
         // TODO(discord9): more error handling
         let res0 = res.column(0);
-        Ok(res0.to_owned())
+        Ok(res0.clone())
     }
 }
@@ -168,13 +181,14 @@ impl PyScript {
     pub fn register_udf(&self) {
         let udf = PyUDF::from_copr(self.copr.clone());
         PyUDF::register_as_udf(udf.clone());
-        PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
+        PyUDF::register_to_query_engine(udf, self.query_engine.clone());
     }
 }
 
 pub struct CoprStream {
     stream: SendableRecordBatchStream,
     copr: CoprocessorRef,
+    params: HashMap<String, String>,
 }
 
 impl RecordBatchStream for CoprStream {
@@ -190,7 +204,7 @@ impl Stream for CoprStream {
         match Pin::new(&mut self.stream).poll_next(cx) {
             Poll::Pending => Poll::Pending,
             Poll::Ready(Some(Ok(recordbatch))) => {
-                let batch = exec_parsed(&self.copr, &recordbatch)
+                let batch = exec_parsed(&self.copr, &Some(recordbatch), &self.params)
                     .map_err(BoxedError::new)
                     .context(ExternalSnafu)?;
 
@@ -218,7 +232,7 @@ impl Script for PyScript {
         self
     }
 
-    async fn execute(&self, _ctx: EvalContext) -> Result<Output> {
+    async fn execute(&self, params: HashMap<String, String>, _ctx: EvalContext) -> Result<Output> {
         if let Some(sql) = &self.copr.deco_args.sql {
             let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
             ensure!(
@@ -231,12 +245,17 @@ impl Script for PyScript {
             let res = self.query_engine.execute(&plan).await?;
             let copr = self.copr.clone();
             match res {
-                Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream { copr, stream }))),
+                Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream {
+                    params,
+                    copr,
+                    stream,
+                }))),
                 _ => unreachable!(),
             }
         } else {
-            // TODO(boyan): try to retrieve sql from user request
-            error::MissingSqlSnafu {}.fail()
+            let batch = exec_parsed(&self.copr, &None, &params)?;
+            let batches = RecordBatches::try_new(batch.schema.clone(), vec![batch]).unwrap();
+            Ok(Output::RecordBatches(batches))
         }
     }
 }
@@ -284,6 +303,7 @@ mod tests {
     use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
     use common_recordbatch::util;
     use datatypes::prelude::ScalarVector;
+    use datatypes::value::Value;
     use datatypes::vectors::{Float64Vector, Int64Vector};
     use query::QueryEngineFactory;
     use table::table::numbers::NumbersTable;
@@ -326,7 +346,10 @@ def test(number)->vector[u32]:
         .compile(script, CompileContext::default())
         .await
         .unwrap();
-    let output = script.execute(EvalContext::default()).await.unwrap();
+    let output = script
+        .execute(HashMap::default(), EvalContext::default())
+        .await
+        .unwrap();
     let res = common_recordbatch::util::collect_batches(match output {
         Output::Stream(s) => s,
         _ => unreachable!(),
@@ -337,6 +360,36 @@ def test(number)->vector[u32]:
     assert_eq!(rb.column(0).len(), 100);
 }
 
+#[tokio::test]
+async fn test_user_params_in_py() {
+    let script_engine = sample_script_engine();
+
+    let script = r#"
+@copr(returns = ["number"])
+def test(**params)->vector[i64]:
+    return int(params['a']) + int(params['b'])
+"#;
+    let script = script_engine
+        .compile(script, CompileContext::default())
+        .await
+        .unwrap();
+    let mut params = HashMap::new();
+    params.insert("a".to_string(), "30".to_string());
+    params.insert("b".to_string(), "12".to_string());
+    let _output = script
+        .execute(params, EvalContext::default())
+        .await
+        .unwrap();
+    let res = match _output {
+        Output::RecordBatches(s) => s,
+        _ => todo!(),
+    };
+    let rb = res.iter().next().expect("One and only one recordbatch");
+    assert_eq!(rb.column(0).len(), 1);
+    let result = rb.column(0).get(0);
+    assert!(matches!(result, Value::Int64(42)));
+}
+
 #[tokio::test]
 async fn test_data_frame_in_py() {
     let script_engine = sample_script_engine();
@@ -353,7 +406,10 @@ def test(number)->vector[u32]:
         .compile(script, CompileContext::default())
         .await
         .unwrap();
-    let _output = script.execute(EvalContext::default()).await.unwrap();
+    let _output = script
+        .execute(HashMap::new(), EvalContext::default())
+        .await
+        .unwrap();
     let res = common_recordbatch::util::collect_batches(match _output {
         Output::Stream(s) => s,
         _ => todo!(),
@@ -382,7 +438,10 @@ def test(a, b, c):
         .compile(script, CompileContext::default())
         .await
         .unwrap();
-    let output = script.execute(EvalContext::default()).await.unwrap();
+    let output = script
+        .execute(HashMap::new(), EvalContext::default())
+        .await
+        .unwrap();
     match output {
         Output::Stream(stream) => {
             let numbers = util::collect(stream).await.unwrap();
@@ -417,7 +476,10 @@ def test(a):
         .compile(script, CompileContext::default())
         .await
         .unwrap();
-    let output = script.execute(EvalContext::default()).await.unwrap();
+    let output = script
+        .execute(HashMap::new(), EvalContext::default())
+        .await
+        .unwrap();
     match output {
         Output::Stream(stream) => {
             let numbers = util::collect(stream).await.unwrap();
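
[review note, not part of the patch] The `else` branch above replaces the old `MissingSqlSnafu` failure: a script without `sql` in its decorator now runs once and has its columns wrapped into `Output::RecordBatches`. A sketch of the caller-visible difference, mirroring `test_user_params_in_py` above:

    let mut params = HashMap::new();
    params.insert("a".to_string(), "30".to_string());
    params.insert("b".to_string(), "12".to_string());
    match script.execute(params, EvalContext::default()).await? {
        // decorator has `sql`: rows stream through CoprStream as before
        Output::Stream(stream) => { /* ... */ }
        // no `sql`: a single one-row batch carrying the returned values
        Output::RecordBatches(batches) => { /* ... */ }
        _ => unreachable!(),
    }
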
diff --git a/src/script/src/python/error.rs b/src/script/src/python/error.rs
index 70e3b89ab3..5fa21a3230 100644
--- a/src/script/src/python/error.rs
+++ b/src/script/src/python/error.rs
@@ -216,7 +216,7 @@ pub fn visualize_loc(
 /// extract a reason for [`Error`] in string format, also return a location if possible
 pub fn get_error_reason_loc(err: &Error) -> (String, Option<Location>) {
     match err {
-        Error::CoprParse { reason, loc, .. } => (reason.clone(), loc.to_owned()),
+        Error::CoprParse { reason, loc, .. } => (reason.clone(), *loc),
         Error::Other { reason, .. } => (reason.clone(), None),
         Error::PyRuntime { msg, .. } => (msg.clone(), None),
         Error::PyParse { source, .. } => (source.error.to_string(), Some(source.location)),
diff --git a/src/script/src/python/test.rs b/src/script/src/python/test.rs
index e0778c32ad..a935f6ca41 100644
--- a/src/script/src/python/test.rs
+++ b/src/script/src/python/test.rs
@@ -126,7 +126,7 @@ fn run_ron_testcases() {
             }
             Predicate::ExecIsOk { fields, columns } => {
                 let rb = create_sample_recordbatch();
-                let res = coprocessor::exec_coprocessor(&testcase.code, &rb).unwrap();
+                let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb)).unwrap();
                 fields
                     .iter()
                     .zip(res.schema.column_schemas())
@@ -152,7 +152,7 @@ fn run_ron_testcases() {
                 reason: part_reason,
             } => {
                 let rb = create_sample_recordbatch();
-                let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
+                let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb));
                 assert!(res.is_err(), "{res:#?}\nExpect Err(...), actual Ok(...)");
                 if let Err(res) = res {
                     error!(
@@ -254,7 +254,7 @@ def calc_rvs(open_time, close):
         ],
     )
     .unwrap();
-    let ret = coprocessor::exec_coprocessor(python_source, &rb);
+    let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
     if let Err(Error::PyParse {
         backtrace: _,
         source,
@@ -304,7 +304,7 @@ def a(cpu, mem):
         ],
     )
     .unwrap();
-    let ret = coprocessor::exec_coprocessor(python_source, &rb);
+    let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
     if let Err(Error::PyParse {
         backtrace: _,
         source,
diff --git a/src/script/src/python/testcases.ron b/src/script/src/python/testcases.ron
index 3ebd2d5e4c..23f819a241 100644
--- a/src/script/src/python/testcases.ron
+++ b/src/script/src/python/testcases.ron
@@ -19,7 +19,7 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
     result: (
         name: "a",
         deco_args: (
-            arg_names: ["cpu", "mem"],
+            arg_names: Some(["cpu", "mem"]),
             ret_names: ["perf", "what", "how", "why"],
         ),
         arg_types: [
@@ -49,7 +49,62 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
                 datatype: None,
                 is_nullable: true
             )),
-        ]
+        ],
+        kwarg: None,
     )
   )
 ),
+(
+    name: "correct_parse_params",
+    code: r#"
+import greptime as gt
+from greptime import pow
+def add(a, b):
+    return a + b
+def sub(a, b):
+    return a - b
+@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
+def a(cpu: vector[f32], mem: vector[f64], **params) -> (vector[f64], vector[f64|None], vector[_], vector[_ | None]):
+    for key, value in params.items():
+        print("%s == %s" % (key, value))
+    return add(cpu, mem), sub(cpu, mem), cpu * mem, cpu / mem
+    "#,
+    predicate: ParseIsOk(
+        result: (
+            name: "a",
+            deco_args: (
+                arg_names: Some(["cpu", "mem"]),
+                ret_names: ["perf", "what", "how", "why"],
+            ),
+            arg_types: [
+                Some((
+                    datatype: Some(Float32(())),
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: false
+                )),
+            ],
+            return_types: [
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: true
+                )),
+                Some((
+                    datatype: None,
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: None,
+                    is_nullable: true
+                )),
+            ],
+            kwarg: Some("params"),
+        )
+    )
+),
@@ -231,7 +286,7 @@ def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)],
 "#,
     predicate: ParseIsErr(
         reason:
-            " keyword argument, found "
+            "Expect `returns` keyword"
     )
 ),
 (
diff --git a/src/script/src/python/vector.rs b/src/script/src/python/vector.rs
index c776f140b3..24c7aa6976 100644
--- a/src/script/src/python/vector.rs
+++ b/src/script/src/python/vector.rs
@@ -490,6 +490,21 @@ impl PyVector {
         self.as_vector_ref().len()
     }
 
+    #[pymethod(name = "concat")]
+    fn concat(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<Self> {
+        let left = self.to_arrow_array();
+        let right = other.to_arrow_array();
+
+        let res = compute::concat(&[left.as_ref(), right.as_ref()]);
+        let res = res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
+        let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
+            vm.new_type_error(format!(
+                "Can't cast result into vector, result: {res:?}, err: {e:?}",
+            ))
+        })?;
+        Ok(ret.into())
+    }
+
     /// take a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true).
     #[pymethod(name = "filter")]
     fn filter(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<Self> {
@@ -549,7 +564,7 @@ impl PyVector {
         // in the newest version of rustpython_vm, wrapped_at for isize is replace by wrap_index(i, len)
         let i = i
             .wrapped_at(self.len())
-            .ok_or_else(|| vm.new_index_error("PyVector index out of range".to_owned()))?;
+            .ok_or_else(|| vm.new_index_error("PyVector index out of range".to_string()))?;
         Ok(val_to_pyobj(self.as_vector_ref().get(i), vm))
     }
@@ -912,7 +927,7 @@ impl AsSequence for PyVector {
             zelf.getitem_by_index(i, vm)
         }),
         ass_item: atomic_func!(|_seq, _i, _value, vm| {
-            Err(vm.new_type_error("PyVector object doesn't support item assigns".to_owned()))
+            Err(vm.new_type_error("PyVector object doesn't support item assigns".to_string()))
         }),
         ..PySequenceMethods::NOT_IMPLEMENTED
     });
@@ -1080,7 +1095,7 @@ pub mod tests {
         .compile(
             script,
             rustpython_compiler_core::Mode::BlockExpr,
-            "<embedded>".to_owned(),
+            "<embedded>".to_string(),
         )
         .map_err(|err| vm.new_syntax_error(&err))?;
         let ret = vm.run_code_obj(code_obj, scope);
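
[review note, not part of the patch] `concat` is a new Python-visible method on PyVector, unrelated to the params feature but riding along in this diff. Hypothetical usage inside a coprocessor, with vectors `a` and `b` assumed to be in scope:

    c = a.concat(b)   # len(c) == len(a) + len(b)
    # Arrow-level failures (e.g. incompatible element types) surface as
    # Python runtime/type errors, per the map_err calls above.
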
-
+use std::collections::HashMap;
 use std::time::Instant;
 
 use axum::extract::{Json, Query, RawBody, State};
@@ -81,10 +81,12 @@ pub async fn scripts(
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
 pub struct ScriptQuery {
     pub db: Option<String>,
     pub name: Option<String>,
+    #[serde(flatten)]
+    pub params: HashMap<String, String>,
 }
 
 /// Handler to execute script
@@ -110,7 +112,7 @@ pub async fn run_script(
     // TODO(sunng87): query_context and db name resolution
     let output = script_handler
-        .execute_script(schema.unwrap(), name.unwrap())
+        .execute_script(schema.unwrap(), name.unwrap(), params.params)
         .await;
 
     let resp = JsonResponse::from_output(vec![output]).await;
diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs
index edaad6f0ad..d94cdb1c0a 100644
--- a/src/servers/src/query_handler.rs
+++ b/src/servers/src/query_handler.rs
@@ -25,6 +25,7 @@
 pub mod grpc;
 pub mod sql;
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use api::prometheus::remote::{ReadRequest, WriteRequest};
@@ -45,7 +46,12 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
 #[async_trait]
 pub trait ScriptHandler {
     async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
-    async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output>;
 }
 
 #[async_trait]
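
[review note, not part of the patch] Because `ScriptQuery` flattens unrecognized query-string pairs into `params` via `#[serde(flatten)]`, callers pass parameters as plain URL arguments. A hypothetical invocation — the route prefix and port depend on how the HTTP server is mounted and are not fixed by this diff:

    curl -XPOST 'http://localhost:4000/v1/run-script?db=public&name=test&a=30&b=12'

Everything except the reserved `db` and `name` keys lands in `params` and reaches the script as strings.
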
diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs
index 3206c33abd..6ad18a9ac8 100644
--- a/src/servers/tests/http/http_handler_test.rs
+++ b/src/servers/tests/http/http_handler_test.rs
@@ -23,7 +23,10 @@ use servers::http::{handler as http_handler, script as script_handler, ApiState,
 use session::context::UserInfo;
 use table::test_util::MemTable;
 
-use crate::{create_testing_script_handler, create_testing_sql_query_handler};
+use crate::{
+    create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
+    ServerSqlQueryHandlerRef,
+};
 
 #[tokio::test]
 async fn test_sql_not_provided() {
@@ -68,6 +71,25 @@ async fn test_sql_output_rows() {
     match &json.output().expect("assertion failed")[0] {
         JsonOutput::Records(records) => {
             assert_eq!(1, records.num_rows());
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "SUM(numbers.uint32s)",
+        "data_type": "UInt64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      4950
+    ]
+  ]
+}"#
+            );
         }
         _ => unreachable!(),
     }
@@ -95,6 +117,25 @@ async fn test_sql_form() {
     match &json.output().expect("assertion failed")[0] {
         JsonOutput::Records(records) => {
             assert_eq!(1, records.num_rows());
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "SUM(numbers.uint32s)",
+        "data_type": "UInt64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      4950
+    ]
+  ]
+}"#
+            );
         }
         _ => unreachable!(),
     }
@@ -110,18 +151,11 @@ async fn test_metrics() {
     assert!(text.contains("test_metrics counter"));
 }
 
-#[tokio::test]
-async fn test_scripts() {
-    common_telemetry::init_default_ut_logging();
-
-    let script = r#"
-@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
-def test(n):
-    return n;
-"#
-    .to_string();
-    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
-    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+async fn insert_script(
+    script: String,
+    script_handler: ScriptHandlerRef,
+    sql_handler: ServerSqlQueryHandlerRef,
+) {
     let body = RawBody(Body::from(script.clone()));
     let invalid_query = create_invalid_script_query();
     let Json(json) = script_handler::scripts(
@@ -136,12 +170,13 @@ def test(n):
     assert!(!json.success(), "{json:?}");
     assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
 
-    let body = RawBody(Body::from(script));
+    let body = RawBody(Body::from(script.clone()));
     let exec = create_script_query();
+    // Insert the script
     let Json(json) = script_handler::scripts(
         State(ApiState {
-            sql_handler,
-            script_handler: Some(script_handler),
+            sql_handler: sql_handler.clone(),
+            script_handler: Some(script_handler.clone()),
         }),
         exec,
         body,
     )
     .await;
@@ -152,10 +187,144 @@ def test(n):
     assert!(json.output().is_none());
 }
 
+#[tokio::test]
+async fn test_scripts() {
+    common_telemetry::init_default_ut_logging();
+
+    let script = r#"
+@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
+def test(n) -> vector[i64]:
+    return n;
+"#
+    .to_string();
+    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
+    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+
+    insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
+    // Run the script
+    let exec = create_script_query();
+    let Json(json) = script_handler::run_script(
+        State(ApiState {
+            sql_handler,
+            script_handler: Some(script_handler),
+        }),
+        exec,
+    )
+    .await;
+    assert!(json.success(), "{json:?}");
+    assert!(json.error().is_none());
+
+    match &json.output().unwrap()[0] {
+        JsonOutput::Records(records) => {
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(5, records.num_rows());
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "n",
+        "data_type": "Int64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      0
+    ],
+    [
+      1
+    ],
+    [
+      2
+    ],
+    [
+      3
+    ],
+    [
+      4
+    ]
+  ]
+}"#
+            );
+        }
+        _ => unreachable!(),
+    }
+}
+
+#[tokio::test]
+async fn test_scripts_with_params() {
+    common_telemetry::init_default_ut_logging();
+
+    let script = r#"
+@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
+def test(n, **params) -> vector[i64]:
+    return n + int(params['a'])
+"#
+    .to_string();
+    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
+    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+
+    insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
+    // Run the script
+    let mut exec = create_script_query();
+    exec.0.params.insert("a".to_string(), "42".to_string());
+    let Json(json) = script_handler::run_script(
+        State(ApiState {
+            sql_handler,
+            script_handler: Some(script_handler),
+        }),
+        exec,
+    )
+    .await;
+    assert!(json.success(), "{json:?}");
+    assert!(json.error().is_none());
+
+    match &json.output().unwrap()[0] {
+        JsonOutput::Records(records) => {
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(5, records.num_rows());
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "n",
+        "data_type": "Int64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      42
+    ],
+    [
+      43
+    ],
+    [
+      44
+    ],
+    [
+      45
+    ],
+    [
+      46
+    ]
+  ]
+}"#
+            );
+        }
+        _ => unreachable!(),
+    }
+}
+
 fn create_script_query() -> Query<script_handler::ScriptQuery> {
     Query(script_handler::ScriptQuery {
         db: Some("test".to_string()),
         name: Some("test".to_string()),
+        ..Default::default()
     })
 }
@@ -163,6 +332,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
     Query(script_handler::ScriptQuery {
         db: None,
         name: None,
+        ..Default::default()
     })
 }
diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs
index fc0818902a..055f95ff2f 100644
--- a/src/servers/tests/mod.rs
+++ b/src/servers/tests/mod.rs
@@ -120,12 +120,20 @@ impl ScriptHandler for DummyInstance {
         Ok(())
     }
 
-    async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output> {
         let key = format!("{schema}_{name}");
         let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
 
-        Ok(py_script.execute(EvalContext::default()).await.unwrap())
+        Ok(py_script
+            .execute(params, EvalContext::default())
+            .await
+            .unwrap())
     }
 }