From 8ba0741c819a88f69091674d3b53d93410c302ea Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Mon, 27 Mar 2023 15:23:52 +0800 Subject: [PATCH] fix: set locals to main.dict too (#1242) --- .../ffi_types/pair_tests/sample_testcases.rs | 50 +++++++++++++++++++ src/script/src/python/pyo3/copr_impl.rs | 3 +- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs index 1e6ed16bc0..54b12a076d 100644 --- a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs +++ b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs @@ -288,6 +288,56 @@ def answer() -> vector[i64]: .to_string(), expect: Some(ronish!("value": vector!(Int64Vector, [43]))), }, + CoprTestCase { + script: r#" +import math + +def normalize0(x): + if x is None or math.isnan(x): + return 0 + elif x > 100: + return 100 + elif x < 0: + return 0 + else: + return x + +@coprocessor(args=["number"], sql="select number from numbers limit 10", returns=["value"], backend="rspy") +def normalize(v) -> vector[i64]: + return [normalize0(x) for x in v] + +"# + .to_string(), + expect: Some(ronish!( + "value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,]) + )), + }, + #[cfg(feature = "pyo3_backend")] + CoprTestCase { + script: r#" +import math +from greptime import vector + +def normalize0(x): + if x is None or math.isnan(x): + return 0 + elif x > 100: + return 100 + elif x < 0: + return 0 + else: + return x + +@coprocessor(args=["number"], sql="select number from numbers limit 10", returns=["value"], backend="pyo3") +def normalize(v) -> vector[i64]: + return vector([normalize0(x) for x in v]) + +"# + .to_string(), + expect: Some(ronish!( + "value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,]) + )), + }, ] } diff --git a/src/script/src/python/pyo3/copr_impl.rs b/src/script/src/python/pyo3/copr_impl.rs index e2be48c6ff..25c8aa4f4c 100644 --- a/src/script/src/python/pyo3/copr_impl.rs +++ b/src/script/src/python/pyo3/copr_impl.rs @@ -88,7 +88,6 @@ coprocessor = copr "; let gen_call = format!("\n_return_from_coprocessor = {}(*_args_for_coprocessor, **_kwargs_for_coprocessor)", copr.name); let script = format!("{}{}{}", dummy_decorator, copr.script, gen_call); - let args = args .clone() .into_iter() @@ -106,7 +105,7 @@ coprocessor = copr let py_main = PyModule::import(py, "__main__")?; let globals = py_main.dict(); - let locals = PyDict::new(py); + let locals = py_main.dict(); if let Some(engine) = &copr.query_engine { let query_engine = PyQueryEngine::from_weakref(engine.clone());