feat: supports passing user params into coprocessor (#962)
* feat: make args in coprocessor optional
* feat: supports kwargs for coprocessor as params passed by the users
* feat: supports params for /run-script
* fix: we should rewrite the coprocessor by removing kwargs
* fix: remove println
* fix: compile error after rebasing
* fix: improve http_handler_test
* test: http scripts api with user params
* refactor: tweak all to_owned
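The feature in one picture: a coprocessor may now declare a `**kwargs`-style parameter, and any extra query-string keys sent to the /run-script HTTP API are handed to the script as a dict of strings under that name. A minimal sketch of the new usage (the `a`/`b` keys mirror the tests in this commit; the exact URL shape is illustrative):

    @copr(returns = ["result"])
    def test(**params) -> vector[i64]:
        # user params always arrive as strings, so convert explicitly
        return int(params['a']) + int(params['b'])

    # e.g. run-script?db=test&name=test&a=30&b=12  ->  42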
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+
 use async_trait::async_trait;
 use common_query::Output;
 use common_telemetry::timer;

@@ -34,8 +36,15 @@ impl ScriptHandler for Instance {
         .await
     }
 
-    async fn execute_script(&self, schema: &str, name: &str) -> servers::error::Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> servers::error::Result<Output> {
         let _timer = timer!(metric::METRIC_RUN_SCRIPT_ELAPSED);
-        self.script_executor.execute_script(schema, name).await
+        self.script_executor
+            .execute_script(schema, name, params)
+            .await
     }
 }

@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
+
 use catalog::CatalogManagerRef;
 use common_query::Output;
 use query::QueryEngineRef;

@@ -34,13 +36,19 @@ mod dummy {
 
     pub async fn insert_script(
         &self,
         _schema: &str,
         _name: &str,
         _script: &str,
     ) -> servers::error::Result<()> {
         servers::error::NotSupportedSnafu { feat: "script" }.fail()
     }
 
-    pub async fn execute_script(&self, _script: &str) -> servers::error::Result<Output> {
+    pub async fn execute_script(
+        &self,
+        _schema: &str,
+        _name: &str,
+        _params: HashMap<String, String>,
+    ) -> servers::error::Result<Output> {
         servers::error::NotSupportedSnafu { feat: "script" }.fail()
     }
 }

@@ -94,9 +102,10 @@ mod python {
         &self,
         schema: &str,
         name: &str,
+        params: HashMap<String, String>,
     ) -> servers::error::Result<Output> {
         self.script_manager
-            .execute(schema, name)
+            .execute(schema, name, params)
             .await
             .map_err(|e| {
                 error!(e; "Instance failed to execute script");

@@ -19,6 +19,7 @@ mod opentsdb;
 mod prometheus;
 mod standalone;
 
+use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 

@@ -512,9 +513,14 @@ impl ScriptHandler for Instance {
         }
     }
 
-    async fn execute_script(&self, schema: &str, script: &str) -> server_error::Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        script: &str,
+        params: HashMap<String, String>,
+    ) -> server_error::Result<Output> {
         if let Some(handler) = &self.script_handler {
-            handler.execute_script(schema, script).await
+            handler.execute_script(schema, script, params).await
         } else {
             server_error::NotSupportedSnafu {
                 feat: "Script execution in Frontend",

@@ -15,6 +15,7 @@
 //! Script engine
 
 use std::any::Any;
+use std::collections::HashMap;
 
 use async_trait::async_trait;
 use common_error::ext::ErrorExt;

@@ -30,7 +31,11 @@ pub trait Script {
     fn as_any(&self) -> &dyn Any;
 
     /// Execute the script and returns the output.
-    async fn execute(&self, ctx: EvalContext) -> std::result::Result<Output, Self::Error>;
+    async fn execute(
+        &self,
+        params: HashMap<String, String>,
+        ctx: EvalContext,
+    ) -> std::result::Result<Output, Self::Error>;
 }
 
 #[async_trait]

@@ -76,7 +76,12 @@ impl ScriptManager {
         Ok(compiled_script)
     }
 
-    pub async fn execute(&self, schema: &str, name: &str) -> Result<Output> {
+    pub async fn execute(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output> {
         let script = {
             let s = self.compiled.read().unwrap().get(name).cloned();
 

@@ -90,7 +95,7 @@ impl ScriptManager {
         let script = script.context(ScriptNotFoundSnafu { name })?;
 
         script
-            .execute(EvalContext::default())
+            .execute(params, EvalContext::default())
             .await
             .context(ExecutePythonSnafu { name })
     }

@@ -97,7 +97,7 @@ pub fn try_into_columnar_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResul
             .borrow_vec()
             .iter()
             .map(|obj| -> PyResult<ScalarValue> {
-                let col = try_into_columnar_value(obj.to_owned(), vm)?;
+                let col = try_into_columnar_value(obj.clone(), vm)?;
                 match col {
                     DFColValue::Array(arr) => Err(vm.new_type_error(format!(
                         "Expect only scalar value in a list, found a vector of type {:?} nested in list", arr.data_type()

@@ -243,7 +243,7 @@ macro_rules! bind_aggr_fn {
             $(
                 Arc::new(expressions::Column::new(stringify!($EXPR_ARGS), 0)) as _,
             )*
-            stringify!($AGGR_FUNC), $DATA_TYPE.to_owned()),
+            stringify!($AGGR_FUNC), $DATA_TYPE.clone()),
             $ARGS, $VM)
     };
 }

@@ -597,7 +597,7 @@ pub(crate) mod greptime_builtin {
                     Arc::new(percent) as _,
                 ],
                 "ApproxPercentileCont",
-                (values.to_arrow_array().data_type()).to_owned(),
+                (values.to_arrow_array().data_type()).clone(),
             )
             .map_err(|err| from_df_err(err, vm))?,
             &[values.to_arrow_array()],

@@ -839,7 +839,7 @@ pub(crate) mod greptime_builtin {
             return Ok(ret.into());
         }
         let cur = cur.slice(0, cur.len() - 1); // except the last one that is
-        let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
+        let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
         let ret = compute::concat(&[&*fill, &*cur]).map_err(|err| {
             vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
         })?;

@@ -864,7 +864,7 @@ pub(crate) mod greptime_builtin {
             return Ok(ret.into());
         }
         let cur = cur.slice(1, cur.len() - 1); // except the last one that is
-        let fill = gen_none_array(cur.data_type().to_owned(), 1, vm)?;
+        let fill = gen_none_array(cur.data_type().clone(), 1, vm)?;
         let ret = compute::concat(&[&*cur, &*fill]).map_err(|err| {
             vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]!{err:#?}"))
         })?;

@@ -1048,7 +1048,7 @@ pub(crate) mod greptime_builtin {
             match (ch.is_ascii_digit(), &state) {
                 (true, State::Separator(_)) => {
                     let res = &input[prev..idx];
-                    let res = State::Separator(res.to_owned());
+                    let res = State::Separator(res.to_string());
                     parsed.push(res);
                     prev = idx;
                     state = State::Num(Default::default());

@@ -1073,7 +1073,7 @@ pub(crate) mod greptime_builtin {
             }
             State::Separator(_) => {
                 let res = &input[prev..];
-                State::Separator(res.to_owned())
+                State::Separator(res.to_string())
             }
         };
         parsed.push(last);

@@ -240,10 +240,7 @@ impl PyValue {
                 .as_any()
                 .downcast_ref::<Float64Array>()
                 .ok_or(format!("Can't cast {vec_f64:#?} to Float64Array!"))?;
-            let ret = vec_f64
-                .into_iter()
-                .map(|v| v.map(|inner| inner.to_owned()))
-                .collect::<Vec<_>>();
+            let ret = vec_f64.into_iter().collect::<Vec<_>>();
             if ret.iter().all(|x| x.is_some()) {
                 Ok(Self::FloatVec(
                     ret.into_iter().map(|i| i.unwrap()).collect(),

@@ -266,7 +263,6 @@ impl PyValue {
                     v.ok_or(format!(
                         "No null element expected, found one in {idx} position"
                     ))
-                    .map(|v| v.to_owned())
                 })
                 .collect::<Result<_, String>>()?;
             Ok(Self::IntVec(ret))

@@ -275,13 +271,13 @@ impl PyValue {
             }
         } else if is_instance::<PyInt>(obj, vm) {
             let res = obj
-                .to_owned()
+                .clone()
                 .try_into_value::<i64>(vm)
                 .map_err(|err| format_py_error(err, vm).to_string())?;
             Ok(Self::Int(res))
         } else if is_instance::<PyFloat>(obj, vm) {
             let res = obj
-                .to_owned()
+                .clone()
                 .try_into_value::<f64>(vm)
                 .map_err(|err| format_py_error(err, vm).to_string())?;
             Ok(Self::Float(res))

@@ -338,7 +334,7 @@ fn run_builtin_fn_testcases() {
             .compile(
                 &case.script,
                 rustpython_compiler_core::Mode::BlockExpr,
-                "<embedded>".to_owned(),
+                "<embedded>".to_string(),
             )
             .map_err(|err| vm.new_syntax_error(&err))
             .unwrap();

@@ -389,7 +385,7 @@ fn set_item_into_scope(
     scope
         .locals
         .as_object()
-        .set_item(&name.to_owned(), vm.new_pyobj(value), vm)
+        .set_item(&name.to_string(), vm.new_pyobj(value), vm)
         .map_err(|err| {
             format!(
                 "Error in setting var {name} in scope: \n{}",

@@ -408,7 +404,7 @@ fn set_lst_of_vecs_in_scope(
     scope
         .locals
         .as_object()
-        .set_item(name.to_owned(), vm.new_pyobj(vector), vm)
+        .set_item(&name.to_string(), vm.new_pyobj(vector), vm)
         .map_err(|err| {
             format!(
                 "Error in setting var {name} in scope: \n{}",

@@ -447,7 +443,7 @@ fn test_vm() {
 from udf_builtins import *
 sin(values)"#,
             rustpython_compiler_core::Mode::BlockExpr,
-            "<embedded>".to_owned(),
+            "<embedded>".to_string(),
         )
         .map_err(|err| vm.new_syntax_error(&err))
         .unwrap();

@@ -16,7 +16,7 @@ pub mod compile;
 pub mod parse;
 
 use std::cell::RefCell;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::result::Result as StdResult;
 use std::sync::{Arc, Weak};
 

@@ -36,7 +36,7 @@ use rustpython_vm::AsObject;
 #[cfg(test)]
 use serde::Deserialize;
 use snafu::{OptionExt, ResultExt};
-use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
+use vm::builtins::{PyBaseExceptionRef, PyDict, PyList, PyListRef, PyStr, PyTuple};
 use vm::convert::ToPyObject;
 use vm::scope::Scope;
 use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};

@@ -73,6 +73,8 @@ pub struct Coprocessor {
     pub arg_types: Vec<Option<AnnotationInfo>>,
     /// get from python function returns' annotation, first is type, second is is_nullable
     pub return_types: Vec<Option<AnnotationInfo>>,
+    /// kwargs in coprocessor function's signature
+    pub kwarg: Option<String>,
     /// store its corresponding script, also skip serde when in `cfg(test)` to reduce work in compare
     #[cfg_attr(test, serde(skip))]
     pub script: String,

@@ -103,7 +105,7 @@ impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
 impl std::fmt::Debug for QueryEngineWeakRef {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_tuple("QueryEngineWeakRef")
-            .field(&self.0.upgrade().map(|f| f.name().to_owned()))
+            .field(&self.0.upgrade().map(|f| f.name().to_string()))
             .finish()
     }
 }

@@ -147,7 +149,7 @@ impl Coprocessor {
             let AnnotationInfo {
                 datatype: ty,
                 is_nullable,
-            } = anno[idx].to_owned().unwrap_or_else(|| {
+            } = anno[idx].clone().unwrap_or_else(|| {
                 // default to be not nullable and use DataType inferred by PyVector itself
                 AnnotationInfo {
                     datatype: Some(real_ty.clone()),

@@ -248,20 +250,23 @@ fn check_args_anno_real_type(
     rb: &RecordBatch,
 ) -> Result<()> {
     for (idx, arg) in args.iter().enumerate() {
-        let anno_ty = copr.arg_types[idx].to_owned();
-        let real_ty = arg.to_arrow_array().data_type().to_owned();
+        let anno_ty = copr.arg_types[idx].clone();
+        let real_ty = arg.to_arrow_array().data_type().clone();
         let real_ty = ConcreteDataType::from_arrow_type(&real_ty);
         let is_nullable: bool = rb.schema.column_schemas()[idx].is_nullable();
         ensure!(
             anno_ty
-                .to_owned()
+                .clone()
                 .map(|v| v.datatype.is_none() // like a vector[_]
-                    || v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
+                    || v.datatype == Some(real_ty.clone()) && v.is_nullable == is_nullable)
                 .unwrap_or(true),
             OtherSnafu {
                 reason: format!(
                     "column {}'s Type annotation is {:?}, but actual type is {:?}",
-                    copr.deco_args.arg_names[idx], anno_ty, real_ty
+                    // It's safe to unwrap here, we already ensure the args and types number is the same when parsing
+                    copr.deco_args.arg_names.as_ref().unwrap()[idx],
+                    anno_ty,
+                    real_ty
                 )
             }
         )

@@ -343,12 +348,12 @@ fn set_items_in_scope(
 /// You can return constant in python code like `return 1, 1.0, True`
 /// which create a constant array(with same value)(currently support int, float and bool) as column on return
 #[cfg(test)]
-pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
+pub fn exec_coprocessor(script: &str, rb: &Option<RecordBatch>) -> Result<RecordBatch> {
     // 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
     // 2. also check for exist of `args` in `rb`, if not found, return error
     // TODO(discord9): cache the result of parse_copr
     let copr = parse::parse_and_compile_copr(script, None)?;
-    exec_parsed(&copr, rb)
+    exec_parsed(&copr, rb, &HashMap::new())
 }
 
 #[pyclass(module = false, name = "query_engine")]

@@ -412,7 +417,7 @@ impl PyQueryEngine {
                 for rb in rbs.iter() {
                     let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
                     for v in rb.columns() {
-                        let v = PyVector::from(v.to_owned());
+                        let v = PyVector::from(v.clone());
                         vec_of_vec.push(v.to_pyobject(vm));
                     }
                     let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);

@@ -440,18 +445,25 @@ fn set_query_engine_in_scope(
         .map_err(|e| format_py_error(e, vm))
 }
 
-pub(crate) fn exec_with_cached_vm(
+fn exec_with_cached_vm(
     copr: &Coprocessor,
-    rb: &RecordBatch,
+    rb: &Option<RecordBatch>,
     args: Vec<PyVector>,
+    params: &HashMap<String, String>,
     vm: &Arc<Interpreter>,
 ) -> Result<RecordBatch> {
     vm.enter(|vm| -> Result<RecordBatch> {
         PyVector::make_class(&vm.ctx);
         // set arguments with given name and values
         let scope = vm.new_scope_with_builtins();
-        set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
-        set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
+        if let Some(rb) = rb {
+            set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
+        }
+
+        if let Some(arg_names) = &copr.deco_args.arg_names {
+            assert_eq!(arg_names.len(), args.len());
+            set_items_in_scope(&scope, vm, arg_names, args)?;
+        }
 
         if let Some(engine) = &copr.query_engine {
             let query_engine = PyQueryEngine {

@@ -463,6 +475,19 @@ pub(crate) fn exec_with_cached_vm(
             set_query_engine_in_scope(&scope, vm, query_engine)?;
         }
 
+        if let Some(kwarg) = &copr.kwarg {
+            let dict = PyDict::new_ref(&vm.ctx);
+            for (k, v) in params {
+                dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
+                    .map_err(|e| format_py_error(e, vm))?;
+            }
+            scope
+                .locals
+                .as_object()
+                .set_item(kwarg, vm.new_pyobj(dict), vm)
+                .map_err(|e| format_py_error(e, vm))?;
+        }
+
         // It's safe to unwrap code_object, it's already compiled before.
         let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
         let ret = vm

@@ -470,7 +495,7 @@ pub(crate) fn exec_with_cached_vm(
             .map_err(|e| format_py_error(e, vm))?;
 
         // 5. get returns as either a PyVector or a PyTuple, and naming schema them according to `returns`
-        let col_len = rb.num_rows();
+        let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
         let mut cols = try_into_columns(&ret, vm, col_len)?;
         ensure!(
             cols.len() == copr.deco_args.ret_names.len(),

@@ -485,6 +510,7 @@ pub(crate) fn exec_with_cached_vm(
 
         // if cols and schema's data types is not match, try coerce it to given type(if annotated)(if error occur, return relevant error with question mark)
         copr.check_and_cast_type(&mut cols)?;
+
         // 6. return a assembled DfRecordBatch
         let schema = copr.gen_schema(&cols)?;
         RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)

@@ -533,13 +559,23 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
 }
 
 /// using a parsed `Coprocessor` struct as input to execute python code
-pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<RecordBatch> {
+pub(crate) fn exec_parsed(
+    copr: &Coprocessor,
+    rb: &Option<RecordBatch>,
+    params: &HashMap<String, String>,
+) -> Result<RecordBatch> {
     // 3. get args from `rb`, and cast them into PyVector
-    let args: Vec<PyVector> = select_from_rb(rb, &copr.deco_args.arg_names)?;
-    check_args_anno_real_type(&args, copr, rb)?;
+    let args: Vec<PyVector> = if let Some(rb) = rb {
+        let args = select_from_rb(rb, copr.deco_args.arg_names.as_ref().unwrap_or(&vec![]))?;
+        check_args_anno_real_type(&args, copr, rb)?;
+        args
+    } else {
+        vec![]
+    };
 
     let interpreter = init_interpreter();
     // 4. then set args in scope and compile then run `CodeObject` which already append a new `Call` node
-    exec_with_cached_vm(copr, rb, args, &interpreter)
+    exec_with_cached_vm(copr, rb, args, params, &interpreter)
 }
 
 /// execute script just like [`exec_coprocessor`] do,

@@ -551,7 +587,7 @@ pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &RecordBatch) -> Result<Record
 #[allow(dead_code)]
 pub fn exec_copr_print(
     script: &str,
-    rb: &RecordBatch,
+    rb: &Option<RecordBatch>,
     ln_offset: usize,
     filename: &str,
 ) -> StdResult<RecordBatch, String> {

@@ -572,7 +608,7 @@ def add(a, b):
     return a + b
 
 @copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
-def test(a, b, c):
+def test(a, b, c, **params):
     import greptime as g
     return add(a, b) / g.sqrt(c)
 "#;

@@ -585,9 +621,10 @@ def test(a, b, c):
         "select number as a,number as b,number as c from numbers limit 100"
     );
     assert_eq!(deco_args.ret_names, vec!["r"]);
-    assert_eq!(deco_args.arg_names, vec!["a", "b", "c"]);
+    assert_eq!(deco_args.arg_names.unwrap(), vec!["a", "b", "c"]);
     assert_eq!(copr.arg_types, vec![None, None, None]);
     assert_eq!(copr.return_types, vec![None]);
+    assert_eq!(copr.kwarg, Some("params".to_string()));
     assert_eq!(copr.script, script);
     assert!(copr.code_obj.is_some());
 }

@@ -16,7 +16,7 @@
 use rustpython_codegen::compile::compile_top;
 use rustpython_compiler::{CompileOpts, Mode};
 use rustpython_compiler_core::CodeObject;
-use rustpython_parser::ast::{Located, Location};
+use rustpython_parser::ast::{ArgData, Located, Location};
 use rustpython_parser::{ast, parser};
 use snafu::ResultExt;
 

@@ -31,23 +31,40 @@ fn create_located<T>(node: T, loc: Location) -> Located<T> {
 /// generate a call to the coprocessor function
 /// with arguments given in decorator's `args` list
 /// also set in location in source code to `loc`
-fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<()> {
-    let mut loc = loc.to_owned();
+fn gen_call(
+    name: &str,
+    deco_args: &DecoratorArgs,
+    kwarg: &Option<String>,
+    loc: &Location,
+) -> ast::Stmt<()> {
+    let mut loc = *loc;
     // adding a line to avoid confusing if any error occurs when calling the function
     // then the pretty print will point to the last line in code
     // instead of point to any of existing code written by user.
     loc.newline();
-    let args: Vec<Located<ast::ExprKind>> = deco_args
-        .arg_names
-        .iter()
-        .map(|v| {
-            let node = ast::ExprKind::Name {
-                id: v.to_owned(),
-                ctx: ast::ExprContext::Load,
-            };
-            create_located(node, loc)
-        })
-        .collect();
+    let mut args: Vec<Located<ast::ExprKind>> = if let Some(arg_names) = &deco_args.arg_names {
+        arg_names
+            .iter()
+            .map(|v| {
+                let node = ast::ExprKind::Name {
+                    id: v.clone(),
+                    ctx: ast::ExprContext::Load,
+                };
+                create_located(node, loc)
+            })
+            .collect()
+    } else {
+        vec![]
+    };
+
+    if let Some(kwarg) = kwarg {
+        let node = ast::ExprKind::Name {
+            id: kwarg.clone(),
+            ctx: ast::ExprContext::Load,
+        };
+        args.push(create_located(node, loc));
+    }
+
     let func = ast::ExprKind::Call {
         func: Box::new(create_located(
             ast::ExprKind::Name {

@@ -71,7 +88,12 @@ fn gen_call(name: &str, deco_args: &DecoratorArgs, loc: &Location) -> ast::Stmt<
 /// So we should avoid running too much Python Bytecode, hence in this function we delete `@` decorator(instead of actually write a decorator in python)
 /// And add a function call in the end and also
 /// strip type annotation
-pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Result<CodeObject> {
+pub fn compile_script(
+    name: &str,
+    deco_args: &DecoratorArgs,
+    kwarg: &Option<String>,
+    script: &str,
+) -> Result<CodeObject> {
     // note that it's important to use `parser::Mode::Interactive` so the ast can be compile to return a result instead of return None in eval mode
     let mut top =
         parser::parse(script, parser::Mode::Interactive, "<embedded>").context(PyParseSnafu)?;

@@ -89,6 +111,20 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Re
             type_comment: __main__,
         } = &mut stmt.node
         {
+            // Rewrite kwargs in coprocessor, make it as a positional argument
+            if !decorator_list.is_empty() {
+                if let Some(kwarg) = kwarg {
+                    args.kwarg = None;
+                    let node = ArgData {
+                        arg: kwarg.clone(),
+                        annotation: None,
+                        type_comment: Some("kwargs".to_string()),
+                    };
+                    let kwarg = create_located(node, stmt.location);
+                    args.args.push(kwarg);
+                }
+            }
+
             *decorator_list = Vec::new();
             // strip type annotation
             // def a(b: int, c:int) -> int

@@ -115,14 +151,14 @@ pub fn compile_script(name: &str, deco_args: &DecoratorArgs, script: &str) -> Re
         }
         // Append statement which calling coprocessor function.
         // It's safe to unwrap loc, it is always exists.
-        stmts.push(gen_call(name, deco_args, &loc.unwrap()));
+        stmts.push(gen_call(name, deco_args, kwarg, &loc.unwrap()));
     } else {
         return fail_parse_error!(format!("Expect statement in script, found: {top:?}"), None);
     }
     // use `compile::Mode::BlockExpr` so it return the result of statement
     compile_top(
         &top,
-        "<embedded>".to_owned(),
+        "<embedded>".to_string(),
         Mode::BlockExpr,
         CompileOpts { optimize: 0 },
     )

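The net effect of the compile-step changes above: when a coprocessor declares `**kwargs`, the decorator-stripping pass rewrites that keyword argument into an ordinary positional argument, and the generated trailing call passes the same name through. In Python terms the transformation is roughly this (a sketch; the dict itself is bound into the scope by the Rust side shown earlier):

    # What the user writes:
    @copr(args=["a"], returns=["r"], sql="...")
    def test(a, **params):
        return a

    # What is effectively compiled and executed after the rewrite:
    def test(a, params):    # `**params` became a plain positional argument
        return a
    test(a, params)         # appended call; `params` is the injected dict of strings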
@@ -29,7 +29,7 @@ use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};
 #[cfg_attr(test, derive(Deserialize))]
 #[derive(Default, Debug, Clone, PartialEq, Eq)]
 pub struct DecoratorArgs {
-    pub arg_names: Vec<String>,
+    pub arg_names: Option<Vec<String>>,
     pub ret_names: Vec<String>,
     pub sql: Option<String>,
     // maybe add a URL for connecting or what?

@@ -58,7 +58,7 @@ fn py_str_to_string(s: &ast::Expr<()>) -> Result<String> {
         kind: _,
     } = &s.node
     {
-        Ok(v.to_owned())
+        Ok(v.clone())
     } else {
         fail_parse_error!(
             format!(

@@ -100,10 +100,7 @@ fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType
         // for any datatype
         "_" => Ok(None),
         // note the different between "_" and _
-        _ => fail_parse_error!(
-            format!("Unknown datatype: {ty} at {loc:?}"),
-            Some(loc.to_owned())
-        ),
+        _ => fail_parse_error!(format!("Unknown datatype: {ty} at {loc:?}"), Some(*loc)),
     }
 }
 

@@ -263,7 +260,7 @@ fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
 fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
     // more keys maybe add to this list of `avail_key`(like `sql` for querying and maybe config for connecting to database?), for better extension using a `HashSet` in here
     let avail_key = HashSet::from(["args", "returns", "sql"]);
-    let opt_keys = HashSet::from(["sql"]);
+    let opt_keys = HashSet::from(["sql", "args"]);
     let mut visited_key = HashSet::new();
     let len_min = avail_key.len() - opt_keys.len();
     let len_max = avail_key.len();

@@ -298,7 +295,7 @@ fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
             visited_key.insert(s);
         }
         match s {
-            "args" => ret_args.arg_names = pylist_to_vec(&kw.node.value)?,
+            "args" => ret_args.arg_names = Some(pylist_to_vec(&kw.node.value)?),
             "returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
             "sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
             _ => unreachable!(),

@@ -476,17 +473,19 @@ pub fn parse_and_compile_copr(
 
     // make sure both arguments&returns in function
     // and in decorator have same length
-    ensure!(
-        deco_args.arg_names.len() == arg_types.len(),
-        CoprParseSnafu {
-            reason: format!(
-                "args number in decorator({}) and function({}) doesn't match",
-                deco_args.arg_names.len(),
-                arg_types.len()
-            ),
-            loc: None
-        }
-    );
+    if let Some(arg_names) = &deco_args.arg_names {
+        ensure!(
+            arg_names.len() == arg_types.len(),
+            CoprParseSnafu {
+                reason: format!(
+                    "args number in decorator({}) and function({}) doesn't match",
+                    arg_names.len(),
+                    arg_types.len()
+                ),
+                loc: None
+            }
+        );
+    }
     ensure!(
         deco_args.ret_names.len() == return_types.len(),
         CoprParseSnafu {

@@ -498,13 +497,15 @@ pub fn parse_and_compile_copr(
             loc: None
         }
     );
+    let kwarg = fn_args.kwarg.as_ref().map(|arg| arg.node.arg.clone());
     coprocessor = Some(Coprocessor {
-        code_obj: Some(compile::compile_script(name, &deco_args, script)?),
+        code_obj: Some(compile::compile_script(name, &deco_args, &kwarg, script)?),
         name: name.to_string(),
         deco_args,
         arg_types,
         return_types,
-        script: script.to_owned(),
+        kwarg,
+        script: script.to_string(),
         query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
     });
 }

@@ -14,6 +14,7 @@
 
 //! Python script engine
 use std::any::Any;
+use std::collections::HashMap;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};

@@ -25,7 +26,9 @@ use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
 use common_query::prelude::Signature;
 use common_query::Output;
 use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
-use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream};
+use common_recordbatch::{
+    RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
+};
 use datafusion_expr::Volatility;
 use datatypes::schema::{ColumnSchema, SchemaRef};
 use datatypes::vectors::VectorRef;

@@ -53,7 +56,12 @@ impl std::fmt::Display for PyUDF {
             f,
             "{}({})->",
             &self.copr.name,
-            &self.copr.deco_args.arg_names.join(",")
+            self.copr
+                .deco_args
+                .arg_names
+                .as_ref()
+                .unwrap_or(&vec![])
+                .join(",")
         )
     }
 }

@@ -73,11 +81,17 @@ impl PyUDF {
 
     /// Fake a schema, should only be used with dynamically eval a Python Udf
     fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
-        let arg_names = &self.copr.deco_args.arg_names;
+        let empty_args = vec![];
+        let arg_names = self
+            .copr
+            .deco_args
+            .arg_names
+            .as_ref()
+            .unwrap_or(&empty_args);
         let col_sch: Vec<_> = columns
             .iter()
             .enumerate()
-            .map(|(i, col)| ColumnSchema::new(arg_names[i].to_owned(), col.data_type(), true))
+            .map(|(i, col)| ColumnSchema::new(arg_names[i].clone(), col.data_type(), true))
             .collect();
         let schema = datatypes::schema::Schema::new(col_sch);
         Arc::new(schema)

@@ -97,7 +111,7 @@ impl Function for PyUDF {
         match self.copr.return_types.get(0) {
             Some(Some(AnnotationInfo {
                 datatype: Some(ty), ..
-            })) => Ok(ty.to_owned()),
+            })) => Ok(ty.clone()),
             _ => PyUdfSnafu {
                 msg: "Can't found return type for python UDF {self}",
             }

@@ -113,7 +127,7 @@ impl Function for PyUDF {
             match ty {
                 Some(AnnotationInfo {
                     datatype: Some(ty), ..
-                }) => arg_types.push(ty.to_owned()),
+                }) => arg_types.push(ty.clone()),
                 _ => {
                     know_all_types = false;
                     break;

@@ -135,9 +149,8 @@ impl Function for PyUDF {
         // FIXME(discord9): exec_parsed require a RecordBatch(basically a Vector+Schema), where schema can't pop out from nowhere, right?
         let schema = self.fake_schema(columns);
         let columns = columns.to_vec();
-        // TODO(discord9): remove unwrap
-        let rb = RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?;
-        let res = exec_parsed(&self.copr, &rb).map_err(|err| {
+        let rb = Some(RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?);
+        let res = exec_parsed(&self.copr, &rb, &HashMap::new()).map_err(|err| {
             PyUdfSnafu {
                 msg: format!("{err:#?}"),
             }

@@ -153,7 +166,7 @@ impl Function for PyUDF {
 
         // TODO(discord9): more error handling
         let res0 = res.column(0);
-        Ok(res0.to_owned())
+        Ok(res0.clone())
     }
 }
 

@@ -168,13 +181,14 @@ impl PyScript {
     pub fn register_udf(&self) {
         let udf = PyUDF::from_copr(self.copr.clone());
         PyUDF::register_as_udf(udf.clone());
-        PyUDF::register_to_query_engine(udf, self.query_engine.to_owned());
+        PyUDF::register_to_query_engine(udf, self.query_engine.clone());
     }
 }
 
 pub struct CoprStream {
     stream: SendableRecordBatchStream,
     copr: CoprocessorRef,
+    params: HashMap<String, String>,
 }
 
 impl RecordBatchStream for CoprStream {

@@ -190,7 +204,7 @@ impl Stream for CoprStream {
         match Pin::new(&mut self.stream).poll_next(cx) {
             Poll::Pending => Poll::Pending,
             Poll::Ready(Some(Ok(recordbatch))) => {
-                let batch = exec_parsed(&self.copr, &recordbatch)
+                let batch = exec_parsed(&self.copr, &Some(recordbatch), &self.params)
                     .map_err(BoxedError::new)
                     .context(ExternalSnafu)?;
 

@@ -218,7 +232,7 @@ impl Script for PyScript {
         self
     }
 
-    async fn execute(&self, _ctx: EvalContext) -> Result<Output> {
+    async fn execute(&self, params: HashMap<String, String>, _ctx: EvalContext) -> Result<Output> {
         if let Some(sql) = &self.copr.deco_args.sql {
             let stmt = QueryLanguageParser::parse_sql(sql).unwrap();
             ensure!(

@@ -231,12 +245,17 @@ impl Script for PyScript {
             let res = self.query_engine.execute(&plan).await?;
             let copr = self.copr.clone();
             match res {
-                Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream { copr, stream }))),
+                Output::Stream(stream) => Ok(Output::Stream(Box::pin(CoprStream {
+                    params,
+                    copr,
+                    stream,
+                }))),
                 _ => unreachable!(),
             }
         } else {
-            // TODO(boyan): try to retrieve sql from user request
-            error::MissingSqlSnafu {}.fail()
+            let batch = exec_parsed(&self.copr, &None, &params)?;
+            let batches = RecordBatches::try_new(batch.schema.clone(), vec![batch]).unwrap();
+            Ok(Output::RecordBatches(batches))
         }
     }
 }

@@ -284,6 +303,7 @@ mod tests {
     use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
     use common_recordbatch::util;
     use datatypes::prelude::ScalarVector;
+    use datatypes::value::Value;
     use datatypes::vectors::{Float64Vector, Int64Vector};
     use query::QueryEngineFactory;
     use table::table::numbers::NumbersTable;

@@ -326,7 +346,10 @@ def test(number)->vector[u32]:
             .compile(script, CompileContext::default())
             .await
             .unwrap();
-        let output = script.execute(EvalContext::default()).await.unwrap();
+        let output = script
+            .execute(HashMap::default(), EvalContext::default())
+            .await
+            .unwrap();
         let res = common_recordbatch::util::collect_batches(match output {
             Output::Stream(s) => s,
             _ => unreachable!(),

@@ -337,6 +360,36 @@ def test(number)->vector[u32]:
         assert_eq!(rb.column(0).len(), 100);
     }
 
+    #[tokio::test]
+    async fn test_user_params_in_py() {
+        let script_engine = sample_script_engine();
+
+        let script = r#"
+@copr(returns = ["number"])
+def test(**params)->vector[i64]:
+    return int(params['a']) + int(params['b'])
+"#;
+        let script = script_engine
+            .compile(script, CompileContext::default())
+            .await
+            .unwrap();
+        let mut params = HashMap::new();
+        params.insert("a".to_string(), "30".to_string());
+        params.insert("b".to_string(), "12".to_string());
+        let _output = script
+            .execute(params, EvalContext::default())
+            .await
+            .unwrap();
+        let res = match _output {
+            Output::RecordBatches(s) => s,
+            _ => todo!(),
+        };
+        let rb = res.iter().next().expect("One and only one recordbatch");
+        assert_eq!(rb.column(0).len(), 1);
+        let result = rb.column(0).get(0);
+        assert!(matches!(result, Value::Int64(42)));
+    }
+
     #[tokio::test]
     async fn test_data_frame_in_py() {
         let script_engine = sample_script_engine();

@@ -353,7 +406,10 @@ def test(number)->vector[u32]:
             .compile(script, CompileContext::default())
             .await
             .unwrap();
-        let _output = script.execute(EvalContext::default()).await.unwrap();
+        let _output = script
+            .execute(HashMap::new(), EvalContext::default())
+            .await
+            .unwrap();
         let res = common_recordbatch::util::collect_batches(match _output {
             Output::Stream(s) => s,
             _ => todo!(),

@@ -382,7 +438,10 @@ def test(a, b, c):
             .compile(script, CompileContext::default())
             .await
             .unwrap();
-        let output = script.execute(EvalContext::default()).await.unwrap();
+        let output = script
+            .execute(HashMap::new(), EvalContext::default())
+            .await
+            .unwrap();
         match output {
             Output::Stream(stream) => {
                 let numbers = util::collect(stream).await.unwrap();

@@ -417,7 +476,10 @@ def test(a):
             .compile(script, CompileContext::default())
             .await
             .unwrap();
-        let output = script.execute(EvalContext::default()).await.unwrap();
+        let output = script
+            .execute(HashMap::new(), EvalContext::default())
+            .await
+            .unwrap();
         match output {
             Output::Stream(stream) => {
                 let numbers = util::collect(stream).await.unwrap();

@@ -216,7 +216,7 @@ pub fn visualize_loc(
 /// extract a reason for [`Error`] in string format, also return a location if possible
 pub fn get_error_reason_loc(err: &Error) -> (String, Option<Location>) {
     match err {
-        Error::CoprParse { reason, loc, .. } => (reason.clone(), loc.to_owned()),
+        Error::CoprParse { reason, loc, .. } => (reason.clone(), *loc),
         Error::Other { reason, .. } => (reason.clone(), None),
         Error::PyRuntime { msg, .. } => (msg.clone(), None),
         Error::PyParse { source, .. } => (source.error.to_string(), Some(source.location)),

@@ -126,7 +126,7 @@ fn run_ron_testcases() {
             }
             Predicate::ExecIsOk { fields, columns } => {
                 let rb = create_sample_recordbatch();
-                let res = coprocessor::exec_coprocessor(&testcase.code, &rb).unwrap();
+                let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb)).unwrap();
                 fields
                     .iter()
                     .zip(res.schema.column_schemas())

@@ -152,7 +152,7 @@ fn run_ron_testcases() {
                 reason: part_reason,
             } => {
                 let rb = create_sample_recordbatch();
-                let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
+                let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb));
                 assert!(res.is_err(), "{res:#?}\nExpect Err(...), actual Ok(...)");
                 if let Err(res) = res {
                     error!(

@@ -254,7 +254,7 @@ def calc_rvs(open_time, close):
         ],
     )
     .unwrap();
-    let ret = coprocessor::exec_coprocessor(python_source, &rb);
+    let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
     if let Err(Error::PyParse {
         backtrace: _,
         source,

@@ -304,7 +304,7 @@ def a(cpu, mem):
         ],
     )
     .unwrap();
-    let ret = coprocessor::exec_coprocessor(python_source, &rb);
+    let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
     if let Err(Error::PyParse {
         backtrace: _,
         source,

@@ -19,7 +19,7 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
        result: (
            name: "a",
            deco_args: (
-                arg_names: ["cpu", "mem"],
+                arg_names: Some(["cpu", "mem"]),
                ret_names: ["perf", "what", "how", "why"],
            ),
            arg_types: [

@@ -49,7 +49,62 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vecto
                    datatype: None,
                    is_nullable: true
                )),
-            ]
+            ],
+            kwarg: None,
        )
    )
 ),
+(
+    name: "correct_parse_params",
+    code: r#"
+import greptime as gt
+from greptime import pow
+def add(a, b):
+    return a + b
+def sub(a, b):
+    return a - b
+@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
+def a(cpu: vector[f32], mem: vector[f64], **params) -> (vector[f64], vector[f64|None], vector[_], vector[_ | None]):
+    for key, value in params.items():
+        print("%s == %s" % (key, value))
+    return add(cpu, mem), sub(cpu, mem), cpu * mem, cpu / mem
+"#,
+    predicate: ParseIsOk(
+        result: (
+            name: "a",
+            deco_args: (
+                arg_names: Some(["cpu", "mem"]),
+                ret_names: ["perf", "what", "how", "why"],
+            ),
+            arg_types: [
+                Some((
+                    datatype: Some(Float32(())),
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: false
+                )),
+            ],
+            return_types: [
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: Some(Float64(())),
+                    is_nullable: true
+                )),
+                Some((
+                    datatype: None,
+                    is_nullable: false
+                )),
+                Some((
+                    datatype: None,
+                    is_nullable: true
+                )),
+            ],
+            kwarg: Some("params"),
+        )
+    )
+),

@@ -231,7 +286,7 @@ def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)],
 "#,
    predicate: ParseIsErr(
        reason:
-        " keyword argument, found "
+        "Expect `returns` keyword"
    )
 ),
 (

@@ -490,6 +490,21 @@ impl PyVector {
         self.as_vector_ref().len()
     }
 
+    #[pymethod(name = "concat")]
+    fn concat(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
+        let left = self.to_arrow_array();
+        let right = other.to_arrow_array();
+
+        let res = compute::concat(&[left.as_ref(), right.as_ref()]);
+        let res = res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
+        let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
+            vm.new_type_error(format!(
+                "Can't cast result into vector, result: {res:?}, err: {e:?}",
+            ))
+        })?;
+        Ok(ret.into())
+    }
+
     /// take a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true).
     #[pymethod(name = "filter")]
     fn filter(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {

@@ -549,7 +564,7 @@ impl PyVector {
         // in the newest version of rustpython_vm, wrapped_at for isize is replace by wrap_index(i, len)
         let i = i
             .wrapped_at(self.len())
-            .ok_or_else(|| vm.new_index_error("PyVector index out of range".to_owned()))?;
+            .ok_or_else(|| vm.new_index_error("PyVector index out of range".to_string()))?;
         Ok(val_to_pyobj(self.as_vector_ref().get(i), vm))
     }
 

@@ -912,7 +927,7 @@ impl AsSequence for PyVector {
         zelf.getitem_by_index(i, vm)
     }),
     ass_item: atomic_func!(|_seq, _i, _value, vm| {
-        Err(vm.new_type_error("PyVector object doesn't support item assigns".to_owned()))
+        Err(vm.new_type_error("PyVector object doesn't support item assigns".to_string()))
     }),
     ..PySequenceMethods::NOT_IMPLEMENTED
 });

@@ -1080,7 +1095,7 @@ pub mod tests {
             .compile(
                 script,
                 rustpython_compiler_core::Mode::BlockExpr,
-                "<embedded>".to_owned(),
+                "<embedded>".to_string(),
             )
             .map_err(|err| vm.new_syntax_error(&err))?;
         let ret = vm.run_code_obj(code_obj, scope);

@@ -11,7 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::collections::HashMap;
 use std::time::Instant;
 
 use axum::extract::{Json, Query, RawBody, State};

@@ -81,10 +81,12 @@ pub async fn scripts(
     }
 }
 
-#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
 pub struct ScriptQuery {
     pub db: Option<String>,
     pub name: Option<String>,
+    #[serde(flatten)]
+    pub params: HashMap<String, String>,
 }
 
 /// Handler to execute script

@@ -110,7 +112,7 @@ pub async fn run_script(
     // TODO(sunng87): query_context and db name resolution
 
     let output = script_handler
-        .execute_script(schema.unwrap(), name.unwrap())
+        .execute_script(schema.unwrap(), name.unwrap(), params.params)
         .await;
     let resp = JsonResponse::from_output(vec![output]).await;

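Because `params` is flattened by serde, every query-string key other than `db` and `name` is collected into the map and forwarded to the script unchanged. For example (the `a` key appears in the tests below; `threshold` is a made-up extra key for illustration):

    # .../run-script?db=test&name=test&a=42&threshold=0.5
    # deserializes into a ScriptQuery whose flattened params map is:
    params = {"a": "42", "threshold": "0.5"}  # values are always strings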
@@ -25,6 +25,7 @@
 pub mod grpc;
 pub mod sql;
 
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use api::prometheus::remote::{ReadRequest, WriteRequest};

@@ -45,7 +46,12 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
 #[async_trait]
 pub trait ScriptHandler {
     async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
-    async fn execute_script(&self, schema: &str, name: &str) -> Result<Output>;
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output>;
 }
 
 #[async_trait]

@@ -23,7 +23,10 @@ use servers::http::{handler as http_handler, script as script_handler, ApiState,
 use session::context::UserInfo;
 use table::test_util::MemTable;
 
-use crate::{create_testing_script_handler, create_testing_sql_query_handler};
+use crate::{
+    create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
+    ServerSqlQueryHandlerRef,
+};
 
 #[tokio::test]
 async fn test_sql_not_provided() {

@@ -68,6 +71,25 @@ async fn test_sql_output_rows() {
     match &json.output().expect("assertion failed")[0] {
         JsonOutput::Records(records) => {
             assert_eq!(1, records.num_rows());
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "SUM(numbers.uint32s)",
+        "data_type": "UInt64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      4950
+    ]
+  ]
+}"#
+            );
         }
         _ => unreachable!(),
     }

@@ -95,6 +117,25 @@ async fn test_sql_form() {
     match &json.output().expect("assertion failed")[0] {
         JsonOutput::Records(records) => {
             assert_eq!(1, records.num_rows());
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "SUM(numbers.uint32s)",
+        "data_type": "UInt64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      4950
+    ]
+  ]
+}"#
+            );
         }
         _ => unreachable!(),
     }

@@ -110,18 +151,11 @@ async fn test_metrics() {
     assert!(text.contains("test_metrics counter"));
 }
 
-#[tokio::test]
-async fn test_scripts() {
-    common_telemetry::init_default_ut_logging();
-
-    let script = r#"
-@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
-def test(n):
-    return n;
-"#
-    .to_string();
-    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
-    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+async fn insert_script(
+    script: String,
+    script_handler: ScriptHandlerRef,
+    sql_handler: ServerSqlQueryHandlerRef,
+) {
     let body = RawBody(Body::from(script.clone()));
     let invalid_query = create_invalid_script_query();
     let Json(json) = script_handler::scripts(

@@ -136,12 +170,13 @@ def test(n):
     assert!(!json.success(), "{json:?}");
     assert_eq!(json.error().unwrap(), "Invalid argument: invalid schema");
 
-    let body = RawBody(Body::from(script));
+    let body = RawBody(Body::from(script.clone()));
     let exec = create_script_query();
+    // Insert the script
     let Json(json) = script_handler::scripts(
         State(ApiState {
-            sql_handler,
-            script_handler: Some(script_handler),
+            sql_handler: sql_handler.clone(),
+            script_handler: Some(script_handler.clone()),
         }),
         exec,
         body,

@@ -152,10 +187,144 @@ def test(n):
     assert!(json.output().is_none());
 }
 
+#[tokio::test]
+async fn test_scripts() {
+    common_telemetry::init_default_ut_logging();
+
+    let script = r#"
+@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
+def test(n) -> vector[i64]:
+    return n;
+"#
+    .to_string();
+    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
+    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+
+    insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
+    // Run the script
+    let exec = create_script_query();
+    let Json(json) = script_handler::run_script(
+        State(ApiState {
+            sql_handler,
+            script_handler: Some(script_handler),
+        }),
+        exec,
+    )
+    .await;
+    assert!(json.success(), "{json:?}");
+    assert!(json.error().is_none());
+
+    match &json.output().unwrap()[0] {
+        JsonOutput::Records(records) => {
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(5, records.num_rows());
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "n",
+        "data_type": "Int64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      0
+    ],
+    [
+      1
+    ],
+    [
+      2
+    ],
+    [
+      3
+    ],
+    [
+      4
+    ]
+  ]
+}"#
+            );
+        }
+        _ => unreachable!(),
+    }
+}
+
+#[tokio::test]
+async fn test_scripts_with_params() {
+    common_telemetry::init_default_ut_logging();
+
+    let script = r#"
+@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
+def test(n, **params) -> vector[i64]:
+    return n + int(params['a'])
+"#
+    .to_string();
+    let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
+    let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
+
+    insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
+    // Run the script
+    let mut exec = create_script_query();
+    exec.0.params.insert("a".to_string(), "42".to_string());
+    let Json(json) = script_handler::run_script(
+        State(ApiState {
+            sql_handler,
+            script_handler: Some(script_handler),
+        }),
+        exec,
+    )
+    .await;
+    assert!(json.success(), "{json:?}");
+    assert!(json.error().is_none());
+
+    match &json.output().unwrap()[0] {
+        JsonOutput::Records(records) => {
+            let json = serde_json::to_string_pretty(&records).unwrap();
+            assert_eq!(5, records.num_rows());
+            assert_eq!(
+                json,
+                r#"{
+  "schema": {
+    "column_schemas": [
+      {
+        "name": "n",
+        "data_type": "Int64"
+      }
+    ]
+  },
+  "rows": [
+    [
+      42
+    ],
+    [
+      43
+    ],
+    [
+      44
+    ],
+    [
+      45
+    ],
+    [
+      46
+    ]
+  ]
+}"#
+            );
+        }
+        _ => unreachable!(),
+    }
+}
+
 fn create_script_query() -> Query<script_handler::ScriptQuery> {
     Query(script_handler::ScriptQuery {
         db: Some("test".to_string()),
         name: Some("test".to_string()),
+        ..Default::default()
     })
 }

@@ -163,6 +332,7 @@ fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
     Query(script_handler::ScriptQuery {
         db: None,
         name: None,
+        ..Default::default()
     })
 }

@@ -120,12 +120,20 @@ impl ScriptHandler for DummyInstance {
         Ok(())
     }
 
-    async fn execute_script(&self, schema: &str, name: &str) -> Result<Output> {
+    async fn execute_script(
+        &self,
+        schema: &str,
+        name: &str,
+        params: HashMap<String, String>,
+    ) -> Result<Output> {
         let key = format!("{schema}_{name}");
 
         let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
 
-        Ok(py_script.execute(EvalContext::default()).await.unwrap())
+        Ok(py_script
+            .execute(params, EvalContext::default())
+            .await
+            .unwrap())
     }
 }