feat: supports passing user params into coprocessor (#962)

* feat: make args in coprocessor optional

* feat: supports kwargs for coprocessor as params passed by the users

* feat: supports params for /run-script

* fix: we should rewrite the coprocessor by removing kwargs

* fix: remove println

* fix: compile error after rebasing

* fix: improve http_handler_test

* test: http scripts api with user params

* refactor: tweak all to_owned
This commit is contained in:
dennis zhuang
2023-02-16 16:11:26 +08:00
committed by GitHub
parent ddbc97befb
commit 5ec1a7027b
19 changed files with 564 additions and 142 deletions

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use async_trait::async_trait;
use common_query::Output;
use common_telemetry::timer;
@@ -34,8 +36,15 @@ impl ScriptHandler for Instance {
.await
}
async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result<Output> {
async fn execute_script(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> servers::error::Result<Output> {
let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED);
self.script_executor.execute_script(schema, name).await
self.script_executor
.execute_script(schema, name, params)
.await
}
}

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use catalog::CatalogManagerRef;
use common_query::Output;
use query::QueryEngineRef;
@@ -34,13 +36,19 @@ mod dummy {
pub async fn insert_script(
&self,
_schema: &str,
_name: &str,
_script: &str,
) -> servers::error::Result<()> {
servers::error::NotSupportedSnafu { feat: "script" }.fail()
}
pub async fn execute_script(&self, _script: &str) -> servers::error::Result<Output> {
pub async fn execute_script(
&self,
_schema: &str,
_name: &str,
_params: HashMap<String, String>,
) -> servers::error::Result<Output> {
servers::error::NotSupportedSnafu { feat: "script" }.fail()
}
}
@@ -94,9 +102,10 @@ mod python {
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> servers::error::Result<Output> {
self.script_manager
.execute(schema, name)
.execute(schema, name, params)
.await
.map_err(|e| {
error!(e; "Instance failed to execute script");

View File

@@ -19,6 +19,7 @@ mod opentsdb;
mod prometheus;
mod standalone;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
@@ -512,9 +513,14 @@ impl ScriptHandler for Instance {
}
}
async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result<Output> {
async fn execute_script(
&self,
schema: &str,
script: &str,
params: HashMap<String, String>,
) -> server_error::Result<Output> {
if let Some(handler) = &self.script_handler {
handler.execute_script(schema, script).await
handler.execute_script(schema, script, params).await
} else {
server_error::NotSupportedSnafu {
feat: "Script execution in Frontend",

View File

@@ -15,6 +15,7 @@
//! Script engine
use std::any::Any;
use std::collections::HashMap;
use async_trait::async_trait;
use common_error::ext::ErrorExt;
@@ -30,7 +31,11 @@ pub trait Script {
fn as_any(&self) -> &dyn Any;
/// Execute the script and returns the output.
async fn execute(&self, ctx: EvalContext) -> std::result::Result<Output, Self::Error>;
async fn execute(
&self,
params: HashMap<String, String>,
ctx: EvalContext,
) -> std::result::Result<Output, Self::Error>;
}
#[async_trait]

View File

@@ -76,7 +76,12 @@ impl ScriptManager {
Ok(compiled_script)
}
pub async fn execute(&self, schema: &str, name: &str) -> Result<Output> {
pub async fn execute(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> Result<Output> {
let script = {
let s = self.compiled.read().unwrap().get(name).cloned();
@@ -90,7 +95,7 @@ impl ScriptManager {
let script = script.context(ScriptNotFoundSnafu { name })?;
script
.execute(EvalContext::default())
.execute(params, EvalContext::default())
.await
.context(ExecutePythonSnafu { name })
}

View File

@@ -97,7 +97,7 @@ pub fn try_into_columnar_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResul
.borrow_vec()
.iter()
.map(|obj| -> PyResult<ScalarValue> {
let col = try_into_columnar_value(obj.to_owned(), vm)?;
let col = try_into_columnar_value(obj.clone(), vm)?;
match col {
DFColValue::Array(arr) => Err(vm.new_type_error(format!(
"Expect only scalar value in a list, found a vector of type {:?} nested in list", arr.data_type()
@@ -243,7 +243,7 @@ macro_rules! bind_aggr_fn {
$(
Arc::new(expressions::Column::new(stringify!($EXPR_ARGS), 0)) as _,
)*
stringify!($AGGR_FUNC), $DATA_TYPE.to_owned()),
stringify!($AGGR_FUNC), $DATA_TYPE.clone()),
$ARGS, $VM)
};
}
@@ -597,7 +597,7 @@ pub(crate) mod greptime_builtin {
Arc::new(percent) as _,
],
"ApproxPercentileCont",
(values.to_arrow_array().data_type()).to_owned(),
(values.to_arrow_array().data_type()).clone(),
)
.map_err(|err| from_df_err(err, vm))?,
&[values.to_arrow_array()],
@@ -839,7 +839,7 @@ pub(crate) mod greptime_builtin {
return Ok(ret.into());
}
let cur = cur.slice(0, cur.len() - 1); // except the last one that is
let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
let ret = compute::concat(&[&*fill, &*cur]).map_err(|err| {
vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
})?;
@@ -864,7 +864,7 @@ pub(crate) mod greptime_builtin {
return Ok(ret.into());
}
let cur = cur.slice(1, cur.len() - 1); // except the last one that is
let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
let ret = compute::concat(&[&*cur, &*fill]).map_err(|err| {
vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
})?;
@@ -1048,7 +1048,7 @@ pub(crate) mod greptime_builtin {
match (ch.is_ascii_digit(), &state) {
(true, State::Separator(_)) => {
let res = &input[prev..idx];
let res = State::Separator(res.to_owned());
let res = State::Separator(res.to_string());
parsed.push(res);
prev = idx;
state = State::Num(Default::default());
@@ -1073,7 +1073,7 @@ pub(crate) mod greptime_builtin {
}
State::Separator(_) => {
let res = &input[prev..];
State::Separator(res.to_owned())
State::Separator(res.to_string())
}
};
parsed.push(last);

View File

@@ -240,10 +240,7 @@ impl PyValue {
.as_any()
.downcast_ref::<Float64Array>()
.ok_or(format!("Can't cast {vec_f64:#?} to Float64Array!"))?;
let ret = vec_f64
.into_iter()
.map(|v| v.map(|inner| inner.to_owned()))
.collect::<Vec<_>>();
let ret = vec_f64.into_iter().collect::<Vec<_>>();
if ret.iter().all(|x| x.is_some()) {
Ok(Self::FloatVec(
ret.into_iter().map(|i| i.unwrap()).collect(),
@@ -266,7 +263,6 @@ impl PyValue {
v.ok_or(format!(
"No null element expected, found one in {idx} position"
))
.map(|v| v.to_owned())
})
.collect::<Result<_, String>>()?;
Ok(Self::IntVec(ret))
@@ -275,13 +271,13 @@ impl PyValue {
}
} else if is_instance::<PyInt>(obj, vm) {
let res = obj
.to_owned()
.clone()
.try_into_value::<i64>(vm)
.map_err(|err| format_py_error(err, vm).to_string())?;
Ok(Self::Int(res))
} else if is_instance::<PyFloat>(obj, vm) {
let res = obj
.to_owned()
.clone()
.try_into_value::<f64>(vm)
.map_err(|err| format_py_error(err, vm).to_string())?;
Ok(Self::Float(res))
@@ -338,7 +334,7 @@ fn run_builtin_fn_testcases() {
.compile(
&case.script,
rustpython_compiler_core::Mode::BlockExpr,
"<embedded>".to_owned(),
"<embedded>".to_string(),
)
.map_err(|err| vm.new_syntax_error(&err))
.unwrap();
@@ -389,7 +385,7 @@ fn set_item_into_scope(
scope
.locals
.as_object()
.set_item(&name.to_owned(), vm.new_pyobj(value), vm)
.set_item(&name.to_string(), vm.new_pyobj(value), vm)
.map_err(|err| {
format!(
"Error in setting var {name} in scope: \n{}",
@@ -408,7 +404,7 @@ fn set_lst_of_vecs_in_scope(
scope
.locals
.as_object()
.set_item(name.to_owned(), vm.new_pyobj(vector), vm)
.set_item(&name.to_string(), vm.new_pyobj(vector), vm)
.map_err(|err| {
format!(
"Error in setting var {name} in scope: \n{}",
@@ -447,7 +443,7 @@ fn test_vm() {
from udf_builtins import *
sin(values)"#,
rustpython_compiler_core::Mode::BlockExpr,
"<embedded>".to_owned(),
"<embedded>".to_string(),
)
.map_err(|err| vm.new_syntax_error(&err))
.unwrap();

View File

@@ -16,7 +16,7 @@ pub mod compile;
pub mod parse;
use std::cell::RefCell;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::result::Result as StdResult;
use std::sync::{Arc, Weak};
@@ -36,7 +36,7 @@ use rustpython_vm::AsObject;
#[cfg(test)]
use serde::Deserialize;
use snafu::{OptionExt, ResultExt};
use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
use vm::builtins::{PyBaseExceptionRef, PyDict, PyList, PyListRef, PyStr, PyTuple};
use vm::convert::ToPyObject;
use vm::scope::Scope;
use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
@@ -73,6 +73,8 @@ pub struct Coprocessor {
pub arg_types: Vec<Option<AnnotationInfo>>,
/// get from python function returns' annotation, first is type, second is is_nullable
pub return_types: Vec<Option<AnnotationInfo>>,
/// kwargs in coprocessor function's signature
pub kwarg: Option<String>,
/// store its corresponding script, also skip serde when in `cfg(test)` to reduce work in compare
#[cfg_attr(test, serde(skip))]
pub script: String,
@@ -103,7 +105,7 @@ impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
impl std::fmt::Debug for QueryEngineWeakRef {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("QueryEngineWeakRef")
.field(&self.0.upgrade().map(|f| f.name().to_owned()))
.field(&self.0.upgrade().map(|f| f.name().to_string()))
.finish()
}
}
@@ -147,7 +149,7 @@ impl Coprocessor {
let AnnotationInfo {
datatype: ty,
is_nullable,
} = anno[idx].to_owned().unwrap_or_else(|| {
} = anno[idx].clone().unwrap_or_else(|| {
// default to be not nullable and use DataType inferred by PyVector itself
AnnotationInfo {
datatype: Some(real_ty.clone()),
@@ -248,20 +250,23 @@ fn check_args_anno_real_type(
rb: &RecordBatch,
) -> Result<()> {
for (idx, arg) in args.iter().enumerate() {
let anno_ty = copr.arg_types[idx].to_owned();
let real_ty = arg.to_arrow_array().data_type().to_owned();
let anno_ty = copr.arg_types[idx].clone();
let real_ty = arg.to_arrow_array().data_type().clone();
let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
ensure!(
anno_ty
.to_owned()
.clone()
.map(|v| v.datatype.is_none() // like a vector[_]
|| v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
|| v.datatype == Some(real_ty.clone()) && v.is_nullable == is_nullable)
.unwrap_or(true),
OtherSnafu {
reason: format!(
"column {}'s Type annotation is {:?}, but actual type is {:?}",
copr.deco_args.arg_names[idx], anno_ty, real_ty
// It's safe to unwrap here, we already ensure the args and types number is the same when parsing
copr.deco_args.arg_names.as_ref().unwrap()[idx],
anno_ty,
real_ty
)
}
)
@@ -343,12 +348,12 @@ fn set_items_in_scope(
/// You can return constant in python code like `return 1, 1.0, True`
/// which create a constant array(with same value)(currently support int, float and bool) as column on return
#[cfg(test)]
pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
pub fn exec_coprocessor(script: &str, rb: &Option<RecordBatch>) -> Result<RecordBatch> {
// 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
// 2. also check for exist of `args` in `rb`, if not found, return error
// TODO(discord9): cache the result of parse_copr
let copr = parse::parse_and_compile_copr(script, None)?;
exec_parsed(&copr, rb)
exec_parsed(&copr, rb, &HashMap::new())
}
#[pyclass(module = false, name = "query_engine")]
@@ -412,7 +417,7 @@ impl PyQueryEngine {
for rb in rbs.iter() {
let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
for v in rb.columns() {
let v = PyVector::from(v.to_owned());
let v = PyVector::from(v.clone());
vec_of_vec.push(v.to_pyobject(vm));
}
let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);
@@ -440,18 +445,25 @@ fn set_query_engine_in_scope(
.map_err(|e| format_py_error(e, vm))
}
pub(crate) fn exec_with_cached_vm(
fn exec_with_cached_vm(
copr: &Coprocessor,
rb: &RecordBatch,
rb: &Option<RecordBatch>,
args: Vec<PyVector>,
params: &HashMap<String, String>,
vm: &Arc<Interpreter>,
) -> Result<RecordBatch> {
vm.enter(|vm| -> Result<RecordBatch> {
PyVector::make_class(&vm.ctx);
// set arguments with given name and values
let scope = vm.new_scope_with_builtins();
set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
if let Some(rb) = rb {
set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
}
if let Some(arg_names) = &copr.deco_args.arg_names {
assert_eq!(arg_names.len(), args.len());
set_items_in_scope(&scope, vm, arg_names, args)?;
}
if let Some(engine) = &copr.query_engine {
let query_engine = PyQueryEngine {
@@ -463,6 +475,19 @@ pub(crate) fn exec_with_cached_vm(
set_query_engine_in_scope(&scope, vm, query_engine)?;
}
if let Some(kwarg) = &copr.kwarg {
let dict = PyDict::new_ref(&vm.ctx);
for (k, v) in params {
dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
.map_err(|e| format_py_error(e, vm))?;
}
scope
.locals
.as_object()
.set_item(kwarg, vm.new_pyobj(dict), vm)
.map_err(|e| format_py_error(e, vm))?;
}
// It's safe to unwrap code_object, it's already compiled before.
let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
let ret = vm
@@ -470,7 +495,7 @@ pub(crate) fn exec_with_cached_vm(
.map_err(|e| format_py_error(e, vm))?;
// 5. get returns as either a PyVector or a PyTuple, and naming schema them according to `returns`
let col_len = rb.num_rows();
let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
let mut cols = try_into_columns(&ret, vm, col_len)?;
ensure!(
cols.len() == copr.deco_args.ret_names.len(),
@@ -485,6 +510,7 @@ pub(crate) fn exec_with_cached_vm(
// if cols and schema's data types is not match, try coerce it to given type(if annotated)(if error occur, return relevant error with question mark)
copr.check_and_cast_type(&mut cols)?;
// 6. return a assembled DfRecordBatch
let schema = copr.gen_schema(&cols)?;
RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
@@ -533,13 +559,23 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
}
/// using a parsed `Coprocessor` struct as input to execute python code
pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<RecordBatch> {
pub(crate) fn exec_parsed(
copr: &Coprocessor,
rb: &Option<RecordBatch>,
params: &HashMap<String, String>,
) -> Result<RecordBatch> {
// 3. get args from `rb`, and cast them into PyVector
let args: Vec<PyVector> = select_from_rb(rb, &copr.deco_args.arg_names)?;
check_args_anno_real_type(&args, copr, rb)?;
let args: Vec<PyVector> = if let Some(rb) = rb {
let args = select_from_rb(rb, copr.deco_args.arg_names.as_ref().unwrap_or(&vec![]))?;
check_args_anno_real_type(&args, copr, rb)?;
args
} else {
vec![]
};
let interpreter = init_interpreter();
// 4. then set args in scope and compile then run `CodeObject` which already append a new `Call` node
exec_with_cached_vm(copr, rb, args, &interpreter)
exec_with_cached_vm(copr, rb, args, params, &interpreter)
}
/// execute script just like [`exec_coprocessor`] do,
@@ -551,7 +587,7 @@ pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<Record
#[allow(dead_code)]
pub fn exec_copr_print(
script: &str,
rb: &RecordBatch,
rb: &Option<RecordBatch>,
ln_offset: usize,
filename: &str,
) -> StdResult<RecordBatch, String> {
@@ -572,7 +608,7 @@ def add(a, b):
return a + b
@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
def test(a, b, c):
def test(a, b, c, **params):
import greptime as g
return add(a, b) / g.sqrt(c)
"#;
@@ -585,9 +621,10 @@ def test(a, b, c):
"select number as a,number as b,number as c from numbers limit 100"
);
assert_eq!(deco_args.ret_names, vec!["r"]);
assert_eq!(deco_args.arg_names, vec!["a", "b", "c"]);
assert_eq!(deco_args.arg_names.unwrap(), vec!["a", "b", "c"]);
assert_eq!(copr.arg_types, vec![None, None, None]);
assert_eq!(copr.return_types, vec![None]);
assert_eq!(copr.kwarg, Some("params".to_string()));
assert_eq!(copr.script, script);
assert!(copr.code_obj.is_some());
}

View File

@@ -16,7 +16,7 @@
use rustpython_codegen::compile::compile_top;
use rustpython_compiler::{CompileOpts, Mode};
use rustpython_compiler_core::CodeObject;
use rustpython_parser::ast::{Located, Location};
use rustpython_parser::ast::{ArgData, Located, Location};
use rustpython_parser::{ast, parser};
use snafu::ResultExt;
@@ -31,23 +31,40 @@ fn create_located<T>(node: T, loc: Location) -> Located<T> {
/// generate a call to the coprocessor function
/// with arguments given in decorator's `args` list
/// also set in location in source code to `loc`
fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<()> {
let mut loc = loc.to_owned();
fn gen_call(
name: &str,
deco_args: &DecoratorArgs,
kwarg: &Option<String>,
loc: &Location,
) -> ast::Stmt<()> {
let mut loc = *loc;
// adding a line to avoid confusing if any error occurs when calling the function
// then the pretty print will point to the last line in code
// instead of point to any of existing code written by user.
loc.newline();
let args: Vec<Located<ast::ExprKind>> = deco_args
.arg_names
.iter()
.map(|v| {
let node = ast::ExprKind::Name {
id: v.to_owned(),
ctx: ast::ExprContext::Load,
};
create_located(node, loc)
})
.collect();
let mut args: Vec<Located<ast::ExprKind>> = if let Some(arg_names) = &deco_args.arg_names {
arg_names
.iter()
.map(|v| {
let node = ast::ExprKind::Name {
id: v.clone(),
ctx: ast::ExprContext::Load,
};
create_located(node, loc)
})
.collect()
} else {
vec![]
};
if let Some(kwarg) = kwarg {
let node = ast::ExprKind::Name {
id: kwarg.clone(),
ctx: ast::ExprContext::Load,
};
args.push(create_located(node, loc));
}
let func = ast::ExprKind::Call {
func: Box::new(create_located(
ast::ExprKind::Name {
@@ -71,7 +88,12 @@ fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<
/// So we should avoid running too much Python Bytecode, hence in this function we delete `@` decorator(instead of actually write a decorator in python)
/// And add a function call in the end and also
/// strip type annotation
pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Result<CodeObject> {
pub fn compile_script(
name: &str,
deco_args: &DecoratorArgs,
kwarg: &Option<String>,
script: &str,
) -> Result<CodeObject> {
// note that it's important to use `parser::Mode::Interactive` so the ast can be compile to return a result instead of return None in eval mode
let mut top =
parser::parse(script, parser::Mode::Interactive, "<embedded>").context(PyParseSnafu)?;
@@ -89,6 +111,20 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Re
type_comment: __main__,
} = &mut stmt.node
{
// Rewrite kwargs in coprocessor, make it as a positional argument
if !decorator_list.is_empty() {
if let Some(kwarg) = kwarg {
args.kwarg = None;
let node = ArgData {
arg: kwarg.clone(),
annotation: None,
type_comment: Some("kwargs".to_string()),
};
let kwarg = create_located(node, stmt.location);
args.args.push(kwarg);
}
}
*decorator_list = Vec::new();
// strip type annotation
// def a(b: int, c:int) -> int
@@ -115,14 +151,14 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Re
}
// Append statement which calling coprocessor function.
// It's safe to unwrap loc, it is always exists.
stmts.push(gen_call(name, deco_args, &loc.unwrap()));
stmts.push(gen_call(name, deco_args, kwarg, &loc.unwrap()));
} else {
return fail_parse_error!(format!("Expect statement in script, found: {top:?}"), None);
}
// use `compile::Mode::BlockExpr` so it return the result of statement
compile_top(
&top,
"<embedded>".to_owned(),
"<embedded>".to_string(),
Mode::BlockExpr,
CompileOpts { optimize: 0 },
)

View File

@@ -29,7 +29,7 @@ use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};
#[cfg_attr(test, derive(Deserialize))]
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct DecoratorArgs {
pub arg_names: Vec<String>,
pub arg_names: Option<Vec<String>>,
pub ret_names: Vec<String>,
pub sql: Option<String>,
// maybe add a URL for connecting or what?
@@ -58,7 +58,7 @@ fn py_str_to_string(s: &ast::Expr<()>) -> Result<String> {
kind: _,
} = &s.node
{
Ok(v.to_owned())
Ok(v.clone())
} else {
fail_parse_error!(
format!(
@@ -100,10 +100,7 @@ fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType
// for any datatype
"_" => Ok(None),
// note the different between "_" and _
_ => fail_parse_error!(
format!("Unknown datatype: {ty} at {loc:?}"),
Some(loc.to_owned())
),
_ => fail_parse_error!(format!("Unknown datatype: {ty} at {loc:?}"), Some(*loc)),
}
}
@@ -263,7 +260,7 @@ fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
// more keys maybe add to this list of `avail_key`(like `sql` for querying and maybe config for connecting to database?), for better extension using a `HashSet` in here
let avail_key = HashSet::from(["args", "returns", "sql"]);
let opt_keys = HashSet::from(["sql"]);
let opt_keys = HashSet::from(["sql", "args"]);
let mut visited_key = HashSet::new();
let len_min = avail_key.len() - opt_keys.len();
let len_max = avail_key.len();
@@ -298,7 +295,7 @@ fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
visited_key.insert(s);
}
match s {
"args" => ret_args.arg_names = pylist_to_vec(&kw.node.value)?,
"args" => ret_args.arg_names = Some(pylist_to_vec(&kw.node.value)?),
"returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
"sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
_ => unreachable!(),
@@ -476,17 +473,19 @@ pub fn parse_and_compile_copr(
// make sure both arguments&returns in function
// and in decorator have same length
ensure!(
deco_args.arg_names.len() == arg_types.len(),
CoprParseSnafu {
reason: format!(
"args number in decorator({}) and function({}) doesn't match",
deco_args.arg_names.len(),
arg_types.len()
),
loc: None
}
);
if let Some(arg_names) = &deco_args.arg_names {
ensure!(
arg_names.len() == arg_types.len(),
CoprParseSnafu {
reason: format!(
"args number in decorator({}) and function({}) doesn't match",
arg_names.len(),
arg_types.len()
),
loc: None
}
);
}
ensure!(
deco_args.ret_names.len() == return_types.len(),
CoprParseSnafu {
@@ -498,13 +497,15 @@ pub fn parse_and_compile_copr(
loc: None
}
);
let kwarg = fn_args.kwarg.as_ref().map(|arg| arg.node.arg.clone());
coprocessor = Some(Coprocessor {
code_obj: Some(compile::compile_script(name, &deco_args, script)?),
code_obj: Some(compile::compile_script(name, &deco_args, &kwarg, script)?),
name: name.to_string(),
deco_args,
arg_types,
return_types,
script: script.to_owned(),
kwarg,
script: script.to_string(),
query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
});
}

View File

@@ -14,6 +14,7 @@
//! Python script engine
use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -25,7 +26,9 @@ use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
use common_query::prelude::Signature;
use common_query::Output;
use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
use common_recordbatch::{
RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
};
use datafusion_expr::Volatility;
use datatypes::schema::{ColumnSchema, SchemaRef};
use datatypes::vectors::VectorRef;
@@ -53,7 +56,12 @@ impl std::fmt::Display for PyUDF {
f,
"{}({})->",
&self.copr.name,
&self.copr.deco_args.arg_names.join(",")
self.copr
.deco_args
.arg_names
.as_ref()
.unwrap_or(&vec![])
.join(",")
)
}
}
@@ -73,11 +81,17 @@ impl PyUDF {
/// Fake a schema, should only be used with dynamically eval a Python Udf
fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
let arg_names = &self.copr.deco_args.arg_names;
let empty_args = vec![];
let arg_names = self
.copr
.deco_args
.arg_names
.as_ref()
.unwrap_or(&empty_args);
let col_sch: Vec<_> = columns
.iter()
.enumerate()
.map(|(i, col)| ColumnSchema::new(arg_names[i].to_owned(), col.data_type(), true))
.map(|(i, col)| ColumnSchema::new(arg_names[i].clone(), col.data_type(), true))
.collect();
let schema = datatypes::schema::Schema::new(col_sch);
Arc::new(schema)
@@ -97,7 +111,7 @@ impl Function for PyUDF {
match self.copr.return_types.get(0) {
Some(Some(AnnotationInfo {
datatype: Some(ty), ..
})) => Ok(ty.to_owned()),
})) => Ok(ty.clone()),
_ => PyUdfSnafu {
msg: "Can't found return type for python UDF {self}",
}
@@ -113,7 +127,7 @@ impl Function for PyUDF {
match ty {
Some(AnnotationInfo {
datatype: Some(ty), ..
}) => arg_types.push(ty.to_owned()),
}) => arg_types.push(ty.clone()),
_ => {
know_all_types = false;
break;
@@ -135,9 +149,8 @@ impl Function for PyUDF {
// FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
let schema = self.fake_schema(columns);
let columns = columns.to_vec();
// TODO(discord9): remove unwrap
let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?;
let res = exec_parsed(&self.copr, &rb).map_err(|err| {
let rb = Some(RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?);
let res = exec_parsed(&self.copr, &rb, &HashMap::new()).map_err(|err| {
PyUdfSnafu {
msg: format!("{err:#?}"),
}
@@ -153,7 +166,7 @@ impl Function for PyUDF {
// TODO(discord9): more error handling
let res0 = res.column(0);
Ok(res0.to_owned())
Ok(res0.clone())
}
}
@@ -168,13 +181,14 @@ impl PyScript {
pub fn register_udf(&self) {
let udf = PyUDF::from_copr(self.copr.clone());
PyUDF::register_as_udf(udf.clone());
PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
PyUDF::register_to_query_engine(udf, self.query_engine.clone());
}
}
pub struct CoprStream {
stream: SendableRecordBatchStream,
copr: CoprocessorRef,
params: HashMap<String, String>,
}
impl RecordBatchStream for CoprStream {
@@ -190,7 +204,7 @@ impl Stream for CoprStream {
match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(Ok(recordbatch))) => {
let batch = exec_parsed(&self.copr, &recordbatch)
let batch = exec_parsed(&self.copr, &Some(recordbatch), &self.params)
.map_err(BoxedError::new)
.context(ExternalSnafu)?;
@@ -218,7 +232,7 @@ impl Script for PyScript {
self
}
async fn execute(&self, _ctx: EvalContext) -> Result<Output> {
async fn execute(&self, params: HashMap<String, String>, _ctx: EvalContext) -> Result<Output> {
if let Some(sql) = &self.copr.deco_args.sql {
let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
ensure!(
@@ -231,12 +245,17 @@ impl Script for PyScript {
let res = self.query_engine.execute(&plan).await?;
let copr = self.copr.clone();
match res {
Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream { copr, stream }))),
Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream {
params,
copr,
stream,
}))),
_ => unreachable!(),
}
} else {
// TODO(boyan): try to retrieve sql from user request
error::MissingSqlSnafu {}.fail()
let batch = exec_parsed(&self.copr, &None, &params)?;
let batches = RecordBatches::try_new(batch.schema.clone(), vec![batch]).unwrap();
Ok(Output::RecordBatches(batches))
}
}
}
@@ -284,6 +303,7 @@ mod tests {
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_recordbatch::util;
use datatypes::prelude::ScalarVector;
use datatypes::value::Value;
use datatypes::vectors::{Float64Vector, Int64Vector};
use query::QueryEngineFactory;
use table::table::numbers::NumbersTable;
@@ -326,7 +346,10 @@ def test(number)->vector[u32]:
.compile(script, CompileContext::default())
.await
.unwrap();
let output = script.execute(EvalContext::default()).await.unwrap();
let output = script
.execute(HashMap::default(), EvalContext::default())
.await
.unwrap();
let res = common_recordbatch::util::collect_batches(match output {
Output::Stream(s) => s,
_ => unreachable!(),
@@ -337,6 +360,36 @@ def test(number)->vector[u32]:
assert_eq!(rb.column(0).len(), 100);
}
#[tokio::test]
async fn test_user_params_in_py() {
let script_engine = sample_script_engine();
let script = r#"
@copr(returns = ["number"])
def test(**params)->vector[i64]:
return int(params['a']) + int(params['b'])
"#;
let script = script_engine
.compile(script, CompileContext::default())
.await
.unwrap();
let mut params = HashMap::new();
params.insert("a".to_string(), "30".to_string());
params.insert("b".to_string(), "12".to_string());
let _output = script
.execute(params, EvalContext::default())
.await
.unwrap();
let res = match _output {
Output::RecordBatches(s) => s,
_ => todo!(),
};
let rb = res.iter().next().expect("One and only one recordbatch");
assert_eq!(rb.column(0).len(), 1);
let result = rb.column(0).get(0);
assert!(matches!(result, Value::Int64(42)));
}
#[tokio::test]
async fn test_data_frame_in_py() {
let script_engine = sample_script_engine();
@@ -353,7 +406,10 @@ def test(number)->vector[u32]:
.compile(script, CompileContext::default())
.await
.unwrap();
let _output = script.execute(EvalContext::default()).await.unwrap();
let _output = script
.execute(HashMap::new(), EvalContext::default())
.await
.unwrap();
let res = common_recordbatch::util::collect_batches(match _output {
Output::Stream(s) => s,
_ => todo!(),
@@ -382,7 +438,10 @@ def test(a, b, c):
.compile(script, CompileContext::default())
.await
.unwrap();
let output = script.execute(EvalContext::default()).await.unwrap();
let output = script
.execute(HashMap::new(), EvalContext::default())
.await
.unwrap();
match output {
Output::Stream(stream) => {
let numbers = util::collect(stream).await.unwrap();
@@ -417,7 +476,10 @@ def test(a):
.compile(script, CompileContext::default())
.await
.unwrap();
let output = script.execute(EvalContext::default()).await.unwrap();
let output = script
.execute(HashMap::new(), EvalContext::default())
.await
.unwrap();
match output {
Output::Stream(stream) => {
let numbers = util::collect(stream).await.unwrap();

View File

@@ -216,7 +216,7 @@ pub fn visualize_loc(
/// extract a reason for [`Error`] in string format, also return a location if possible
pub fn get_error_reason_loc(err: &Error) -> (String, Option<Location>) {
match err {
Error::CoprParse { reason, loc, .. } => (reason.clone(), loc.to_owned()),
Error::CoprParse { reason, loc, .. } => (reason.clone(), *loc),
Error::Other { reason, .. } => (reason.clone(), None),
Error::PyRuntime { msg, .. } => (msg.clone(), None),
Error::PyParse { source, .. } => (source.error.to_string(), Some(source.location)),

View File

@@ -126,7 +126,7 @@ fn run_ron_testcases() {
}
Predicate::ExecIsOk { fields, columns } => {
let rb = create_sample_recordbatch();
let res = coprocessor::exec_coprocessor(&testcase.code, &rb).unwrap();
let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb)).unwrap();
fields
.iter()
.zip(res.schema.column_schemas())
@@ -152,7 +152,7 @@ fn run_ron_testcases() {
reason: part_reason,
} => {
let rb = create_sample_recordbatch();
let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb));
assert!(res.is_err(), "{res:#?}\nExpect Err(...), actual Ok(...)");
if let Err(res) = res {
error!(
@@ -254,7 +254,7 @@ def calc_rvs(open_time, close):
],
)
.unwrap();
let ret = coprocessor::exec_coprocessor(python_source, &rb);
let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
if let Err(Error::PyParse {
backtrace: _,
source,
@@ -304,7 +304,7 @@ def a(cpu, mem):
],
)
.unwrap();
let ret = coprocessor::exec_coprocessor(python_source, &rb);
let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
if let Err(Error::PyParse {
backtrace: _,
source,

View File

@@ -19,7 +19,7 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
result: (
name: "a",
deco_args: (
arg_names: ["cpu", "mem"],
arg_names: Some(["cpu", "mem"]),
ret_names: ["perf", "what", "how", "why"],
),
arg_types: [
@@ -49,7 +49,62 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
datatype: None,
is_nullable: true
)),
]
],
kwarg: None,
)
)
),
(
name: "correct_parse_params",
code: r#"
import greptime as gt
from greptime import pow
def add(a, b):
return a + b
def sub(a, b):
return a - b
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64], **params) -> (vector[f64], vector[f64|None], vector[_], vector[_ | None]):
for key, value in params.items():
print("%s == %s" % (key, value))
return add(cpu, mem), sub(cpu, mem), cpu * mem, cpu / mem
"#,
predicate: ParseIsOk(
result: (
name: "a",
deco_args: (
arg_names: Some(["cpu", "mem"]),
ret_names: ["perf", "what", "how", "why"],
),
arg_types: [
Some((
datatype: Some(Float32(())),
is_nullable: false
)),
Some((
datatype: Some(Float64(())),
is_nullable: false
)),
],
return_types: [
Some((
datatype: Some(Float64(())),
is_nullable: false
)),
Some((
datatype: Some(Float64(())),
is_nullable: true
)),
Some((
datatype: None,
is_nullable: false
)),
Some((
datatype: None,
is_nullable: true
)),
],
kwarg: Some("params"),
)
)
),
@@ -231,7 +286,7 @@ def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)],
"#,
predicate: ParseIsErr(
reason:
" keyword argument, found "
"Expect `returns` keyword"
)
),
(

View File

@@ -490,6 +490,21 @@ impl PyVector {
self.as_vector_ref().len()
}
#[pymethod(name = "concat")]
fn concat(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
let left = self.to_arrow_array();
let right = other.to_arrow_array();
let res = compute::concat(&[left.as_ref(), right.as_ref()]);
let res = res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
vm.new_type_error(format!(
"Can't cast result into vector, result: {res:?}, err: {e:?}",
))
})?;
Ok(ret.into())
}
/// take a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true).
#[pymethod(name = "filter")]
fn filter(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
@@ -549,7 +564,7 @@ impl PyVector {
// in the newest version of rustpython_vm, wrapped_at for isize is replace by wrap_index(i, len)
let i = i
.wrapped_at(self.len())
.ok_or_else(|| vm.new_index_error("PyVector index out of range".to_owned()))?;
.ok_or_else(|| vm.new_index_error("PyVector index out of range".to_string()))?;
Ok(val_to_pyobj(self.as_vector_ref().get(i), vm))
}
@@ -912,7 +927,7 @@ impl AsSequence for PyVector {
zelf.getitem_by_index(i, vm)
}),
ass_item: atomic_func!(|_seq, _i, _value, vm| {
Err(vm.new_type_error("PyVector object doesn't support item assigns".to_owned()))
Err(vm.new_type_error("PyVector object doesn't support item assigns".to_string()))
}),
..PySequenceMethods::NOT_IMPLEMENTED
});
@@ -1080,7 +1095,7 @@ pub mod tests {
.compile(
script,
rustpython_compiler_core::Mode::BlockExpr,
"<embedded>".to_owned(),
"<embedded>".to_string(),
)
.map_err(|err| vm.new_syntax_error(&err))?;
let ret = vm.run_code_obj(code_obj, scope);

View File

@@ -11,7 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::time::Instant;
use axum::extract::{Json, Query, RawBody, State};
@@ -81,10 +81,12 @@ pub async fn scripts(
}
}
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
pub struct ScriptQuery {
pub db: Option<String>,
pub name: Option<String>,
#[serde(flatten)]
pub params: HashMap<String, String>,
}
/// Handler to execute script
@@ -110,7 +112,7 @@ pub async fn run_script(
// TODO(sunng87): query_context and db name resolution
let output = script_handler
.execute_script(schema.unwrap(), name.unwrap())
.execute_script(schema.unwrap(), name.unwrap(), params.params)
.await;
let resp = JsonResponse::from_output(vec![output]).await;

View File

@@ -25,6 +25,7 @@
pub mod grpc;
pub mod sql;
use std::collections::HashMap;
use std::sync::Arc;
use api::prometheus::remote::{ReadRequest, WriteRequest};
@@ -45,7 +46,12 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
#[async_trait]
pub trait ScriptHandler {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
async fn execute_script(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> Result<Output>;
}
#[async_trait]

View File

@@ -23,7 +23,10 @@ use servers::http::{handler as http_handler, script as script_handler, ApiState,
use session::context::UserInfo;
use table::test_util::MemTable;
use crate::{create_testing_script_handler, create_testing_sql_query_handler};
use crate::{
create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
ServerSqlQueryHandlerRef,
};
#[tokio::test]
async fn test_sql_not_provided() {
@@ -68,6 +71,25 @@ async fn test_sql_output_rows() {
match &json.output().expect("assertion failed")[0] {
JsonOutput::Records(records) => {
assert_eq!(1, records.num_rows());
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "SUM(numbers.uint32s)",
"data_type": "UInt64"
}
]
},
"rows": [
[
4950
]
]
}"#
);
}
_ => unreachable!(),
}
@@ -95,6 +117,25 @@ async fn test_sql_form() {
match &json.output().expect("assertion failed")[0] {
JsonOutput::Records(records) => {
assert_eq!(1, records.num_rows());
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "SUM(numbers.uint32s)",
"data_type": "UInt64"
}
]
},
"rows": [
[
4950
]
]
}"#
);
}
_ => unreachable!(),
}
@@ -110,18 +151,11 @@ async fn test_metrics() {
assert!(text.contains("test_metrics counter"));
}
#[tokio::test]
async fn test_scripts() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
def test(n):
return n;
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
async fn insert_script(
script: String,
script_handler: ScriptHandlerRef,
sql_handler: ServerSqlQueryHandlerRef,
) {
let body = RawBody(Body::from(script.clone()));
let invalid_query = create_invalid_script_query();
let Json(json) = script_handler::scripts(
@@ -136,12 +170,13 @@ def test(n):
assert!(!json.success(), "{json:?}");
assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
let body = RawBody(Body::from(script));
let body = RawBody(Body::from(script.clone()));
let exec = create_script_query();
// Insert the script
let Json(json) = script_handler::scripts(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
sql_handler: sql_handler.clone(),
script_handler: Some(script_handler.clone()),
}),
exec,
body,
@@ -152,10 +187,144 @@ def test(n):
assert!(json.output().is_none());
}
#[tokio::test]
async fn test_scripts() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
def test(n) -> vector[i64]:
return n;
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
// Run the script
let exec = create_script_query();
let Json(json) = script_handler::run_script(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
}),
exec,
)
.await;
assert!(json.success(), "{json:?}");
assert!(json.error().is_none());
match &json.output().unwrap()[0] {
JsonOutput::Records(records) => {
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(5, records.num_rows());
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "n",
"data_type": "Int64"
}
]
},
"rows": [
[
0
],
[
1
],
[
2
],
[
3
],
[
4
]
]
}"#
);
}
_ => unreachable!(),
}
}
#[tokio::test]
async fn test_scripts_with_params() {
common_telemetry::init_default_ut_logging();
let script = r#"
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
def test(n, **params) -> vector[i64]:
return n + int(params['a'])
"#
.to_string();
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
// Run the script
let mut exec = create_script_query();
exec.0.params.insert("a".to_string(), "42".to_string());
let Json(json) = script_handler::run_script(
State(ApiState {
sql_handler,
script_handler: Some(script_handler),
}),
exec,
)
.await;
assert!(json.success(), "{json:?}");
assert!(json.error().is_none());
match &json.output().unwrap()[0] {
JsonOutput::Records(records) => {
let json = serde_json::to_string_pretty(&records).unwrap();
assert_eq!(5, records.num_rows());
assert_eq!(
json,
r#"{
"schema": {
"column_schemas": [
{
"name": "n",
"data_type": "Int64"
}
]
},
"rows": [
[
42
],
[
43
],
[
44
],
[
45
],
[
46
]
]
}"#
);
}
_ => unreachable!(),
}
}
fn create_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery {
db: Some("test".to_string()),
name: Some("test".to_string()),
..Default::default()
})
}
@@ -163,6 +332,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
Query(script_handler::ScriptQuery {
db: None,
name: None,
..Default::default()
})
}

View File

@@ -120,12 +120,20 @@ impl ScriptHandler for DummyInstance {
Ok(())
}
async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
async fn execute_script(
&self,
schema: &str,
name: &str,
params: HashMap<String, String>,
) -> Result<Output> {
let key = format!("{schema}_{name}");
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
Ok(py_script.execute(EvalContext::default()).await.unwrap())
Ok(py_script
.execute(params, EvalContext::default())
.await
.unwrap())
}
}