fix: set locals to main.dict too (#1242)

This commit is contained in:
discord9
2023-03-27 15:23:52 +08:00
committed by GitHub
parent 0eeb5b460c
commit 8ba0741c81
2 changed files with 51 additions and 2 deletions

View File

@@ -288,6 +288,56 @@ def answer() -> vector[i64]:
.to_string(),
expect: Some(ronish!("value": vector!(Int64Vector, [43]))),
},
CoprTestCase {
script: r#"
import math
def normalize0(x):
if x is None or math.isnan(x):
return 0
elif x > 100:
return 100
elif x < 0:
return 0
else:
return x
@coprocessor(args=["number"], sql="select number from numbers limit 10", returns=["value"], backend="rspy")
def normalize(v) -> vector[i64]:
return [normalize0(x) for x in v]
"#
.to_string(),
expect: Some(ronish!(
"value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
)),
},
#[cfg(feature = "pyo3_backend")]
CoprTestCase {
script: r#"
import math
from greptime import vector
def normalize0(x):
if x is None or math.isnan(x):
return 0
elif x > 100:
return 100
elif x < 0:
return 0
else:
return x
@coprocessor(args=["number"], sql="select number from numbers limit 10", returns=["value"], backend="pyo3")
def normalize(v) -> vector[i64]:
return vector([normalize0(x) for x in v])
"#
.to_string(),
expect: Some(ronish!(
"value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
)),
},
]
}

View File

@@ -88,7 +88,6 @@ coprocessor = copr
";
let gen_call = format!("\n_return_from_coprocessor = {}(*_args_for_coprocessor, **_kwargs_for_coprocessor)", copr.name);
let script = format!("{}{}{}", dummy_decorator, copr.script, gen_call);
let args = args
.clone()
.into_iter()
@@ -106,7 +105,7 @@ coprocessor = copr
let py_main = PyModule::import(py, "__main__")?;
let globals = py_main.dict();
let locals = PyDict::new(py);
let locals = py_main.dict();
if let Some(engine) = &copr.query_engine {
let query_engine = PyQueryEngine::from_weakref(engine.clone());