From ebbf1e43b5b2e2e09baf31b57bee19b5dbd83eb4 Mon Sep 17 00:00:00 2001
From: discord9 <55937128+discord9@users.noreply.github.com>
Date: Fri, 3 Feb 2023 15:05:27 +0800
Subject: [PATCH] feat: Query using sql inside python script (#884)

* feat: add weakref to QueryEngine in copr

* feat: sql query in python

* fix: make_class for Query Engine

* fix: use `Handle::try_current` instead

* fix: cache `Runtime`

* fix: lock file conflict

* fix: dedicated thread for blocking & fix test

* test: remove unnecessary print
---
 src/script/src/python/coprocessor.rs       | 138 ++++++++++++++++++++-
 src/script/src/python/coprocessor/parse.rs |   8 +-
 src/script/src/python/engine.rs            |  43 ++++++-
 src/script/src/python/test.rs              |   6 +-
 4 files changed, 180 insertions(+), 15 deletions(-)

diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs
index f05f53dcb1..c10e5dc513 100644
--- a/src/script/src/python/coprocessor.rs
+++ b/src/script/src/python/coprocessor.rs
@@ -18,15 +18,17 @@ pub mod parse;
 use std::cell::RefCell;
 use std::collections::HashSet;
 use std::result::Result as StdResult;
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 
-use common_recordbatch::RecordBatch;
+use common_recordbatch::{RecordBatch, RecordBatches};
 use common_telemetry::info;
 use datatypes::arrow::array::Array;
 use datatypes::arrow::compute;
 use datatypes::data_type::{ConcreteDataType, DataType};
 use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
 use datatypes::vectors::{Helper, VectorRef};
+use query::parser::QueryLanguageParser;
+use query::QueryEngine;
 use rustpython_compiler_core::CodeObject;
 use rustpython_vm as vm;
 use rustpython_vm::class::PyClassImpl;
@@ -34,9 +36,10 @@ use rustpython_vm::AsObject;
 #[cfg(test)]
 use serde::Deserialize;
 use snafu::{OptionExt, ResultExt};
-use vm::builtins::{PyBaseExceptionRef, PyTuple};
+use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
+use vm::convert::ToPyObject;
 use vm::scope::Scope;
-use vm::{Interpreter, PyObjectRef, VirtualMachine};
+use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
 
 use crate::python::builtins::greptime_builtin;
 use crate::python::coprocessor::parse::DecoratorArgs;
@@ -77,6 +80,31 @@ pub struct Coprocessor {
     // but CodeObject doesn't.
     #[cfg_attr(test, serde(skip))]
     pub code_obj: Option<CodeObject>,
+    #[cfg_attr(test, serde(skip))]
+    pub query_engine: Option<QueryEngineWeakRef>,
+}
+
+#[derive(Clone)]
+pub struct QueryEngineWeakRef(pub Weak<dyn QueryEngine>);
+
+impl From<Weak<dyn QueryEngine>> for QueryEngineWeakRef {
+    fn from(value: Weak<dyn QueryEngine>) -> Self {
+        Self(value)
+    }
+}
+
+impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
+    fn from(value: &Arc<dyn QueryEngine>) -> Self {
+        Self(Arc::downgrade(value))
+    }
+}
+
+impl std::fmt::Debug for QueryEngineWeakRef {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_tuple("QueryEngineWeakRef")
+            .field(&self.0.upgrade().map(|f| f.name().to_owned()))
+            .finish()
+    }
 }
 
 impl PartialEq for Coprocessor {
@@ -318,10 +346,99 @@ pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
     // 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
     // 2. also check for the existence of `args` in `rb`; if not found, return an error
     // TODO(discord9): cache the result of parse_copr
-    let copr = parse::parse_and_compile_copr(script)?;
+    let copr = parse::parse_and_compile_copr(script, None)?;
     exec_parsed(&copr, rb)
 }
 
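The `+` block that follows introduces `PyQueryEngine`, the class behind the `query` object that coprocessor scripts see in their scope. For orientation, a script that uses it looks roughly like this (a minimal sketch modeled on the `test_sql_in_py` test added further down; it assumes the in-memory `numbers` table that those tests register):

    @copr(args=["number"], returns=["n"], sql="select * from numbers")
    def test(number) -> vector[u32]:
        # `query` is injected into the scope by exec_with_cached_vm below;
        # sql() returns one entry per RecordBatch, each a list of column vectors
        return query.sql("select * from numbers")[0][0][1]
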
+#[pyclass(module = false, name = "query_engine")]
+#[derive(Debug, PyPayload)]
+pub struct PyQueryEngine {
+    inner: QueryEngineWeakRef,
+}
+
+#[pyclass]
+impl PyQueryEngine {
+    // TODO(discord9): find a better way to call the sql query API; for now we can't tell whether we are in an async context or not
+    /// Returns SQL query results as List[List[PyVector]], or List[usize] holding the affected row count when no recordbatches are returned
+    #[pymethod]
+    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
+        enum Either {
+            Rb(RecordBatches),
+            AffectedRows(usize),
+        }
+        let query = self.inner.0.upgrade();
+        let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> {
+            if let Some(engine) = query {
+                let stmt = QueryLanguageParser::parse_sql(s.as_str()).map_err(|e| e.to_string())?;
+                let plan = engine
+                    .statement_to_plan(stmt, Default::default())
+                    .map_err(|e| e.to_string())?;
+                // To prevent the error of creating a nested Runtime: if already inside one, the parent runtime should be used instead
+
+                let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
+                let handle = rt.handle().clone();
+                let res = handle.block_on(async {
+                    let res = engine
+                        .clone()
+                        .execute(&plan)
+                        .await
+                        .map_err(|e| e.to_string());
+                    match res {
+                        Ok(common_query::Output::AffectedRows(cnt)) => {
+                            Ok(Either::AffectedRows(cnt))
+                        }
+                        Ok(common_query::Output::RecordBatches(rbs)) => Ok(Either::Rb(rbs)),
+                        Ok(common_query::Output::Stream(s)) => Ok(Either::Rb(
+                            common_recordbatch::util::collect_batches(s).await.unwrap(),
+                        )),
+                        Err(e) => Err(e),
+                    }
+                })?;
+                Ok(res)
+            } else {
+                Err("Query Engine is already dropped".to_string())
+            }
+        });
+        thread_handle
+            .join()
+            .map_err(|e| {
+                vm.new_system_error(format!("Dedicated thread for sql query panicked: {e:?}"))
+            })?
+            .map_err(|e| vm.new_system_error(e))
+            .map(|rbs| match rbs {
+                Either::Rb(rbs) => {
+                    let mut top_vec = Vec::with_capacity(rbs.iter().count());
+                    for rb in rbs.iter() {
+                        let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
+                        for v in rb.columns() {
+                            let v = PyVector::from(v.to_owned());
+                            vec_of_vec.push(v.to_pyobject(vm));
+                        }
+                        let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);
+                        top_vec.push(vec_of_vec);
+                    }
+                    let top_vec = PyList::new_ref(top_vec, vm.as_ref());
+                    top_vec
+                }
+                Either::AffectedRows(cnt) => {
+                    PyList::new_ref(vec![vm.ctx.new_int(cnt).into()], vm.as_ref())
+                }
+            })
+    }
+}
+
+fn set_query_engine_in_scope(
+    scope: &Scope,
+    vm: &VirtualMachine,
+    query_engine: PyQueryEngine,
+) -> Result<()> {
+    scope
+        .locals
+        .as_object()
+        .set_item("query", query_engine.to_pyobject(vm), vm)
+        .map_err(|e| format_py_error(e, vm))
+}
+
 pub(crate) fn exec_with_cached_vm(
     copr: &Coprocessor,
     rb: &RecordBatch,
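To make the return shape of `sql()` concrete: a query that yields recordbatches surfaces as List[List[PyVector]], indexed first by recordbatch and then by column, while a statement that only reports Output::AffectedRows surfaces as a one-element list holding the count. A hedged sketch (the insert statement and table `t` are hypothetical; whether DML is reachable through this path depends on the engine):

    res = query.sql("select * from numbers")
    first_column = res[0][0]  # first column of the first RecordBatch, as a PyVector

    # an AffectedRows output would come back as [cnt], e.g.:
    # cnt = query.sql("insert into t values (1)")[0]  # hypothetical table `t`
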
@@ -333,6 +450,15 @@ pub(crate) fn exec_with_cached_vm(
     // set arguments with given name and values
     let scope = vm.new_scope_with_builtins();
     set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
+    if let Some(engine) = &copr.query_engine {
+        let query_engine = PyQueryEngine {
+            inner: engine.clone(),
+        };
+
+        // put an object named `query` of class PyQueryEngine into the scope
+        PyQueryEngine::make_class(&vm.ctx);
+        set_query_engine_in_scope(&scope, vm, query_engine)?;
+    }
 
     // It's safe to unwrap code_object, it's already compiled before.
     let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
@@ -444,7 +570,7 @@ def test(a, b, c):
     return add(a, b) / g.sqrt(c)
 "#;
-        let copr = parse_and_compile_copr(script).unwrap();
+        let copr = parse_and_compile_copr(script, None).unwrap();
         assert_eq!(copr.name, "test");
         let deco_args = copr.deco_args.clone();
         assert_eq!(
diff --git a/src/script/src/python/coprocessor/parse.rs b/src/script/src/python/coprocessor/parse.rs
index 570c3368f8..18bcdca8c8 100644
--- a/src/script/src/python/coprocessor/parse.rs
+++ b/src/script/src/python/coprocessor/parse.rs
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 use std::collections::HashSet;
+use std::sync::Arc;
 
 use datatypes::prelude::ConcreteDataType;
+use query::QueryEngineRef;
 use rustpython_parser::ast::{Arguments, Location};
 use rustpython_parser::{ast, parser};
 #[cfg(test)]
@@ -423,7 +425,10 @@ fn get_return_annotations(rets: &ast::Expr<()>) -> Result<Vec<Option<AnnotationInfo>>> {
 
-pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
+pub fn parse_and_compile_copr(
+    script: &str,
+    query_engine: Option<QueryEngineRef>,
+) -> Result<Coprocessor> {
     let python_ast = parser::parse_program(script, "<embedded>").context(PyParseSnafu)?;
     let mut coprocessor = None;
@@ -500,6 +505,7 @@ pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
                 arg_types,
                 return_types,
                 script: script.to_owned(),
+                query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
             });
         }
     } else if matches!(
diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs
index fe1060ab9e..7a68a63e6b 100644
--- a/src/script/src/python/engine.rs
+++ b/src/script/src/python/engine.rs
@@ -265,7 +265,10 @@ impl ScriptEngine for PyEngine {
     }
 
     async fn compile(&self, script: &str, _ctx: CompileContext) -> Result<PyScript> {
-        let copr = Arc::new(parse::parse_and_compile_copr(script)?);
+        let copr = Arc::new(parse::parse_and_compile_copr(
+            script,
+            Some(self.query_engine.clone()),
+        )?);
 
         Ok(PyScript {
             copr,
@@ -287,8 +290,7 @@ mod tests {
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_compile_execute() {
+    fn sample_script_engine() -> PyEngine {
         let catalog_list = catalog::local::new_memory_catalog_list().unwrap();
 
         let default_schema = Arc::new(MemorySchemaProvider::new());
@@ -306,7 +308,38 @@ mod tests {
 
         let factory = QueryEngineFactory::new(catalog_list);
         let query_engine = factory.query_engine();
-        let script_engine = PyEngine::new(query_engine.clone());
+        PyEngine::new(query_engine.clone())
+    }
+
+    #[tokio::test]
+    async fn test_sql_in_py() {
+        let script_engine = sample_script_engine();
+
+        let script = r#"
+import greptime as gt
+
+@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
+def test(number)->vector[u32]:
+    return query.sql("select * from numbers")[0][0][1]
+"#;
+        let script = script_engine
+            .compile(script, CompileContext::default())
+            .await
+            .unwrap();
+        let output = script.execute(EvalContext::default()).await.unwrap();
+        let res = common_recordbatch::util::collect_batches(match output {
+            Output::Stream(s) => s,
+            _ => todo!(),
+        })
+        .await
+        .unwrap();
+        let rb = res.iter().next().expect("One and only one recordbatch");
+        assert_eq!(rb.column(0).len(), 100);
+    }
+
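The pre-existing `test_compile_execute` below still drives its input through the decorator's `sql=` argument; for comparison, the same even-number filtering could now pull its data through an in-script query instead (a sketch only, under the same `numbers` table assumption; `even_numbers` is a hypothetical name):

    import greptime as gt

    @copr(args=["number"], returns=["r"], sql="select number from numbers limit 100")
    def even_numbers(number):
        # fetch the first column of the first RecordBatch via the injected `query` object
        col = query.sql("select number from numbers limit 100")[0][0]
        return gt.vector([x for x in col if x % 2 == 0])
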
+    #[tokio::test]
+    async fn test_compile_execute() {
+        let script_engine = sample_script_engine();
 
         // To avoid divide by zero, the script divides `add(a, b)` by `g.sqrt(c + 1)` instead of `g.sqrt(c)`
         let script = r#"
@@ -351,7 +384,7 @@ import greptime as gt
 
 @copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
 def test(a):
-    return gt.vector([x for x in a if x % 2 == 0]) 
+    return gt.vector([x for x in a if x % 2 == 0])
 "#;
         let script = script_engine
            .compile(script, CompileContext::default())
diff --git a/src/script/src/python/test.rs b/src/script/src/python/test.rs
index 99fb5b82fe..e0778c32ad 100644
--- a/src/script/src/python/test.rs
+++ b/src/script/src/python/test.rs
@@ -103,13 +103,13 @@ fn run_ron_testcases() {
         info!(".ron test {}", testcase.name);
         match testcase.predicate {
             Predicate::ParseIsOk { result } => {
-                let copr = parse_and_compile_copr(&testcase.code);
+                let copr = parse_and_compile_copr(&testcase.code, None);
                 let mut copr = copr.unwrap();
                 copr.script = "".into();
                 assert_eq!(copr, *result);
             }
             Predicate::ParseIsErr { reason } => {
-                let copr = parse_and_compile_copr(&testcase.code);
+                let copr = parse_and_compile_copr(&testcase.code, None);
                 assert!(copr.is_err(), "Expect to be err, actual {copr:#?}");
 
                 let res = &copr.unwrap_err();
@@ -183,7 +183,7 @@ def a(cpu, mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_]):
     return cpu + mem, cpu - mem, cpu * mem, cpu / mem
 "#;
     let pyast = parser::parse(python_source, parser::Mode::Interactive, "<embedded>").unwrap();
-    let copr = parse_and_compile_copr(python_source);
+    let copr = parse_and_compile_copr(python_source, None);
     dbg!(copr);
 }
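Since the new engine parameter is optional end to end, passing `None` here simply leaves no `query` object bound in the scope, and scripts that never touch it parse and run unchanged, e.g. (a sketch in the shape of the test source above; the `perc` return name and body are illustrative):

    @copr(args=["cpu", "mem"], returns=["perc"])
    def a(cpu, mem):
        # no use of `query`, so no query engine is needed to compile or run this
        return (cpu + mem) / 2.0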