feat: Query using sql inside python script (#884)

* feat: add weakref to QueryEngine in copr

* feat: sql query in python

* fix: make_class for Query Engine

* fix: use `Handle::try_current` instead

* fix: cache `Runtime`

* fix: lock file conflict

* fix: dedicated thread for blocking & fix test

* test: remove unnecessary print
This commit is contained in:
discord9
2023-02-03 15:05:27 +08:00
committed by GitHub
parent 54fe81dad9
commit ebbf1e43b5
4 changed files with 180 additions and 15 deletions

View File

@@ -18,15 +18,17 @@ pub mod parse;
use std::cell::RefCell;
use std::collections::HashSet;
use std::result::Result as StdResult;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use common_recordbatch::RecordBatch;
use common_recordbatch::{RecordBatch, RecordBatches};
use common_telemetry::info;
use datatypes::arrow::array::Array;
use datatypes::arrow::compute;
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::{Helper, VectorRef};
use query::parser::QueryLanguageParser;
use query::QueryEngine;
use rustpython_compiler_core::CodeObject;
use rustpython_vm as vm;
use rustpython_vm::class::PyClassImpl;
@@ -34,9 +36,10 @@ use rustpython_vm::AsObject;
#[cfg(test)]
use serde::Deserialize;
use snafu::{OptionExt, ResultExt};
use vm::builtins::{PyBaseExceptionRef, PyTuple};
use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
use vm::convert::ToPyObject;
use vm::scope::Scope;
use vm::{Interpreter, PyObjectRef, VirtualMachine};
use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
use crate::python::builtins::greptime_builtin;
use crate::python::coprocessor::parse::DecoratorArgs;
@@ -77,6 +80,31 @@ pub struct Coprocessor {
// but CodeObject doesn't.
#[cfg_attr(test, serde(skip))]
pub code_obj: Option<CodeObject>,
#[cfg_attr(test, serde(skip))]
pub query_engine: Option<QueryEngineWeakRef>,
}
/// A weak reference to a [`QueryEngine`], held by a coprocessor so that a
/// cached script does not keep the engine alive.
#[derive(Clone)]
pub struct QueryEngineWeakRef(pub Weak<dyn QueryEngine>);

impl From<Weak<dyn QueryEngine>> for QueryEngineWeakRef {
    fn from(value: Weak<dyn QueryEngine>) -> Self {
        Self(value)
    }
}

impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
    // Downgrades a strong engine reference into a weak one.
    fn from(value: &Arc<dyn QueryEngine>) -> Self {
        Self(Arc::downgrade(value))
    }
}

impl std::fmt::Debug for QueryEngineWeakRef {
    // Prints the engine's name when the weak ref can still be upgraded,
    // or `None` when the engine has already been dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("QueryEngineWeakRef")
            .field(&self.0.upgrade().map(|f| f.name().to_owned()))
            .finish()
    }
}
impl PartialEq for Coprocessor {
@@ -318,10 +346,99 @@ pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
// 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
// 2. also check for exist of `args` in `rb`, if not found, return error
// TODO(discord9): cache the result of parse_copr
let copr = parse::parse_and_compile_copr(script)?;
let copr = parse::parse_and_compile_copr(script, None)?;
exec_parsed(&copr, rb)
}
/// Python-visible class (`query_engine`) wrapping a weak reference to the
/// query engine; scripts access an instance of it via the `query` global
/// (see `set_query_engine_in_scope`).
#[pyclass(module = false, name = "query_engine")]
#[derive(Debug, PyPayload)]
pub struct PyQueryEngine {
    // Weak ref so the Python object does not keep the engine alive.
    inner: QueryEngineWeakRef,
}

#[pyclass]
impl PyQueryEngine {
    // TODO(discord9): find a better way to call sql query api, now we don't know if we are in async context or not
    /// return sql query results in List[List[PyVector]], or List[usize] for AffectedRows number if no recordbatches is returned
    #[pymethod]
    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
        // Local result type: either materialized record batches, or an
        // affected-row count for statements that return no batches.
        enum Either {
            Rb(RecordBatches),
            AffectedRows(usize),
        }
        let query = self.inner.0.upgrade();
        // Run the query on a dedicated OS thread so a fresh tokio Runtime can
        // be created and blocked on without panicking inside a caller's
        // async context.
        let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> {
            if let Some(engine) = query {
                let stmt = QueryLanguageParser::parse_sql(s.as_str()).map_err(|e| e.to_string())?;
                let plan = engine
                    .statement_to_plan(stmt, Default::default())
                    .map_err(|e| e.to_string())?;
                // To prevent the error of nested creating Runtime, if is nested, use the parent runtime instead
                let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
                let handle = rt.handle().clone();
                let res = handle.block_on(async {
                    let res = engine
                        .clone()
                        .execute(&plan)
                        .await
                        .map_err(|e| e.to_string());
                    match res {
                        Ok(common_query::Output::AffectedRows(cnt)) => {
                            Ok(Either::AffectedRows(cnt))
                        }
                        Ok(common_query::Output::RecordBatches(rbs)) => Ok(Either::Rb(rbs)),
                        // A stream output is drained eagerly into in-memory batches.
                        Ok(common_query::Output::Stream(s)) => Ok(Either::Rb(
                            common_recordbatch::util::collect_batches(s).await.unwrap(),
                        )),
                        Err(e) => Err(e),
                    }
                })?;
                Ok(res)
            } else {
                // The weak ref could not be upgraded: the engine is gone.
                Err("Query Engine is already dropped".to_string())
            }
        });
        thread_handle
            .join()
            .map_err(|e| {
                vm.new_system_error(format!("Dedicated thread for sql query panic: {e:?}"))
            })?
            .map_err(|e| vm.new_system_error(e))
            // Convert into Python lists: one inner list of PyVector columns per
            // record batch, or a single-element int list for affected rows.
            .map(|rbs| match rbs {
                Either::Rb(rbs) => {
                    let mut top_vec = Vec::with_capacity(rbs.iter().count());
                    for rb in rbs.iter() {
                        let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
                        for v in rb.columns() {
                            let v = PyVector::from(v.to_owned());
                            vec_of_vec.push(v.to_pyobject(vm));
                        }
                        let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);
                        top_vec.push(vec_of_vec);
                    }
                    let top_vec = PyList::new_ref(top_vec, vm.as_ref());
                    top_vec
                }
                Either::AffectedRows(cnt) => {
                    PyList::new_ref(vec![vm.ctx.new_int(cnt).into()], vm.as_ref())
                }
            })
    }
}
/// Binds `query_engine` into the script scope under the name `"query"`, so a
/// Python coprocessor can call `query.sql(...)`.
///
/// Python errors raised while setting the item are converted via
/// `format_py_error`.
fn set_query_engine_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    query_engine: PyQueryEngine,
) -> Result<()> {
    scope
        .locals
        .as_object()
        .set_item("query", query_engine.to_pyobject(vm), vm)
        .map_err(|e| format_py_error(e, vm))
}
pub(crate) fn exec_with_cached_vm(
copr: &Coprocessor,
rb: &RecordBatch,
@@ -333,6 +450,15 @@ pub(crate) fn exec_with_cached_vm(
// set arguments with given name and values
let scope = vm.new_scope_with_builtins();
set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
if let Some(engine) = &copr.query_engine {
let query_engine = PyQueryEngine {
inner: engine.clone(),
};
// put a object named with query of class PyQueryEngine in scope
PyQueryEngine::make_class(&vm.ctx);
set_query_engine_in_scope(&scope, vm, query_engine)?;
}
// It's safe to unwrap code_object, it's already compiled before.
let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
@@ -444,7 +570,7 @@ def test(a, b, c):
return add(a, b) / g.sqrt(c)
"#;
let copr = parse_and_compile_copr(script).unwrap();
let copr = parse_and_compile_copr(script, None).unwrap();
assert_eq!(copr.name, "test");
let deco_args = copr.deco_args.clone();
assert_eq!(

View File

@@ -13,8 +13,10 @@
// limitations under the License.
use std::collections::HashSet;
use std::sync::Arc;
use datatypes::prelude::ConcreteDataType;
use query::QueryEngineRef;
use rustpython_parser::ast::{Arguments, Location};
use rustpython_parser::{ast, parser};
#[cfg(test)]
@@ -423,7 +425,10 @@ fn get_return_annotations(rets: &ast::Expr<()>) -> Result<Vec<Option<AnnotationI
}
/// parse script and return `Coprocessor` struct with info extract from ast
pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
pub fn parse_and_compile_copr(
script: &str,
query_engine: Option<QueryEngineRef>,
) -> Result<Coprocessor> {
let python_ast = parser::parse_program(script, "<embedded>").context(PyParseSnafu)?;
let mut coprocessor = None;
@@ -500,6 +505,7 @@ pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
arg_types,
return_types,
script: script.to_owned(),
query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
});
}
} else if matches!(

View File

@@ -265,7 +265,10 @@ impl ScriptEngine for PyEngine {
}
async fn compile(&self, script: &str, _ctx: CompileContext) -> Result<PyScript> {
let copr = Arc::new(parse::parse_and_compile_copr(script)?);
let copr = Arc::new(parse::parse_and_compile_copr(
script,
Some(self.query_engine.clone()),
)?);
Ok(PyScript {
copr,
@@ -287,8 +290,7 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_compile_execute() {
fn sample_script_engine() -> PyEngine {
let catalog_list = catalog::local::new_memory_catalog_list().unwrap();
let default_schema = Arc::new(MemorySchemaProvider::new());
@@ -306,7 +308,38 @@ mod tests {
let factory = QueryEngineFactory::new(catalog_list);
let query_engine = factory.query_engine();
let script_engine = PyEngine::new(query_engine.clone());
PyEngine::new(query_engine.clone())
}
#[tokio::test]
async fn test_sql_in_py() {
    let script_engine = sample_script_engine();

    // The coprocessor body issues a SQL query through the injected `query`
    // object; `sql()` returns List[List[PyVector]], so [0][0][1] is the second
    // element of the first column of the first batch.
    // NOTE(review): the script's internal indentation/blank lines appear
    // stripped by the diff rendering — confirm against the original commit.
    let script = r#"
import greptime as gt
@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number)->vector[u32]:
return query.sql("select * from numbers")[0][0][1]
"#;
    let script = script_engine
        .compile(script, CompileContext::default())
        .await
        .unwrap();
    let _output = script.execute(EvalContext::default()).await.unwrap();
    // Drain the stream output and check the `numbers` table yields 100 rows.
    let res = common_recordbatch::util::collect_batches(match _output {
        Output::Stream(s) => s,
        _ => todo!(),
    })
    .await
    .unwrap();
    let rb = res.iter().next().expect("One and only one recordbatch");
    assert_eq!(rb.column(0).len(), 100);
}
#[tokio::test]
async fn test_compile_execute() {
let script_engine = sample_script_engine();
// To avoid divide by zero, the script divides `add(a, b)` by `g.sqrt(c + 1)` instead of `g.sqrt(c)`
let script = r#"
@@ -351,7 +384,7 @@ import greptime as gt
@copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
def test(a):
return gt.vector([x for x in a if x % 2 == 0])
return gt.vector([x for x in a if x % 2 == 0])
"#;
let script = script_engine
.compile(script, CompileContext::default())

View File

@@ -103,13 +103,13 @@ fn run_ron_testcases() {
info!(".ron test {}", testcase.name);
match testcase.predicate {
Predicate::ParseIsOk { result } => {
let copr = parse_and_compile_copr(&testcase.code);
let copr = parse_and_compile_copr(&testcase.code, None);
let mut copr = copr.unwrap();
copr.script = "".into();
assert_eq!(copr, *result);
}
Predicate::ParseIsErr { reason } => {
let copr = parse_and_compile_copr(&testcase.code);
let copr = parse_and_compile_copr(&testcase.code, None);
assert!(copr.is_err(), "Expect to be err, actual {copr:#?}");
let res = &copr.unwrap_err();
@@ -183,7 +183,7 @@ def a(cpu, mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[
return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#;
let pyast = parser::parse(python_source, parser::Mode::Interactive, "<embedded>").unwrap();
let copr = parse_and_compile_copr(python_source);
let copr = parse_and_compile_copr(python_source, None);
dbg!(copr);
}