mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 05:42:57 +00:00
feat: Query using sql inside python script (#884)
* feat: add weakref to QueryEngine in copr
* feat: sql query in python
* fix: make_class for Query Engine
* fix: use `Handle::try_current` instead
* fix: cache `Runtime`
* fix: lock file conflict
* fix: dedicated thread for blocking & fix test
* test: remove unnecessary print
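Before the diff itself, here is a minimal sketch of what the change enables, adapted from the `test_sql_in_py` test added below; the exact result indexing is illustrative rather than authoritative:

import greptime as gt

@copr(args=["number"], returns=["number"], sql="select * from numbers")
def test(number) -> vector[u32]:
    # `query` is injected by the engine; query.sql(...) returns a list of
    # record batches, each a list of column vectors (see the doc comment in the diff)
    return query.sql("select * from numbers")[0][0]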
@@ -18,15 +18,17 @@ pub mod parse;
 use std::cell::RefCell;
 use std::collections::HashSet;
 use std::result::Result as StdResult;
-use std::sync::Arc;
+use std::sync::{Arc, Weak};
 
-use common_recordbatch::RecordBatch;
+use common_recordbatch::{RecordBatch, RecordBatches};
 use common_telemetry::info;
 use datatypes::arrow::array::Array;
 use datatypes::arrow::compute;
 use datatypes::data_type::{ConcreteDataType, DataType};
 use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
 use datatypes::vectors::{Helper, VectorRef};
+use query::parser::QueryLanguageParser;
+use query::QueryEngine;
 use rustpython_compiler_core::CodeObject;
 use rustpython_vm as vm;
 use rustpython_vm::class::PyClassImpl;
@@ -34,9 +36,10 @@ use rustpython_vm::AsObject;
 #[cfg(test)]
 use serde::Deserialize;
 use snafu::{OptionExt, ResultExt};
-use vm::builtins::{PyBaseExceptionRef, PyTuple};
+use vm::builtins::{PyBaseExceptionRef, PyList, PyListRef, PyTuple};
+use vm::convert::ToPyObject;
 use vm::scope::Scope;
-use vm::{Interpreter, PyObjectRef, VirtualMachine};
+use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
 
 use crate::python::builtins::greptime_builtin;
 use crate::python::coprocessor::parse::DecoratorArgs;
@@ -77,6 +80,31 @@ pub struct Coprocessor {
     // but CodeObject doesn't.
     #[cfg_attr(test, serde(skip))]
     pub code_obj: Option<CodeObject>,
+    #[cfg_attr(test, serde(skip))]
+    pub query_engine: Option<QueryEngineWeakRef>,
 }
+
+#[derive(Clone)]
+pub struct QueryEngineWeakRef(pub Weak<dyn QueryEngine>);
+
+impl From<Weak<dyn QueryEngine>> for QueryEngineWeakRef {
+    fn from(value: Weak<dyn QueryEngine>) -> Self {
+        Self(value)
+    }
+}
+
+impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
+    fn from(value: &Arc<dyn QueryEngine>) -> Self {
+        Self(Arc::downgrade(value))
+    }
+}
+
+impl std::fmt::Debug for QueryEngineWeakRef {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_tuple("QueryEngineWeakRef")
+            .field(&self.0.upgrade().map(|f| f.name().to_owned()))
+            .finish()
+    }
+}
 
 impl PartialEq for Coprocessor {
@@ -318,10 +346,99 @@ pub fn exec_coprocessor(script: &str, rb: &RecordBatch) -> Result<RecordBatch> {
     // 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
     // 2. also check that `args` exist in `rb`; if not found, return an error
     // TODO(discord9): cache the result of parse_copr
-    let copr = parse::parse_and_compile_copr(script)?;
+    let copr = parse::parse_and_compile_copr(script, None)?;
     exec_parsed(&copr, rb)
 }
+
+#[pyclass(module = false, name = "query_engine")]
+#[derive(Debug, PyPayload)]
+pub struct PyQueryEngine {
+    inner: QueryEngineWeakRef,
+}
+
+#[pyclass]
+impl PyQueryEngine {
+    // TODO(discord9): find a better way to call the sql query api; right now we don't know if we are in an async context or not
+    /// return sql query results in List[List[PyVector]], or List[usize] for AffectedRows number if no recordbatches are returned
+    #[pymethod]
+    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
+        enum Either {
+            Rb(RecordBatches),
+            AffectedRows(usize),
+        }
+        let query = self.inner.0.upgrade();
+        let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> {
+            if let Some(engine) = query {
+                let stmt = QueryLanguageParser::parse_sql(s.as_str()).map_err(|e| e.to_string())?;
+                let plan = engine
+                    .statement_to_plan(stmt, Default::default())
+                    .map_err(|e| e.to_string())?;
+                // To prevent the error of creating a nested Runtime: if already nested, use the parent runtime instead
+
+                let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
+                let handle = rt.handle().clone();
+                let res = handle.block_on(async {
+                    let res = engine
+                        .clone()
+                        .execute(&plan)
+                        .await
+                        .map_err(|e| e.to_string());
+                    match res {
+                        Ok(common_query::Output::AffectedRows(cnt)) => {
+                            Ok(Either::AffectedRows(cnt))
+                        }
+                        Ok(common_query::Output::RecordBatches(rbs)) => Ok(Either::Rb(rbs)),
+                        Ok(common_query::Output::Stream(s)) => Ok(Either::Rb(
+                            common_recordbatch::util::collect_batches(s).await.unwrap(),
+                        )),
+                        Err(e) => Err(e),
+                    }
+                })?;
+                Ok(res)
+            } else {
+                Err("Query Engine is already dropped".to_string())
+            }
+        });
+        thread_handle
+            .join()
+            .map_err(|e| {
+                vm.new_system_error(format!("Dedicated thread for sql query panic: {e:?}"))
+            })?
+            .map_err(|e| vm.new_system_error(e))
+            .map(|rbs| match rbs {
+                Either::Rb(rbs) => {
+                    let mut top_vec = Vec::with_capacity(rbs.iter().count());
+                    for rb in rbs.iter() {
+                        let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
+                        for v in rb.columns() {
+                            let v = PyVector::from(v.to_owned());
+                            vec_of_vec.push(v.to_pyobject(vm));
+                        }
+                        let vec_of_vec = PyList::new_ref(vec_of_vec, vm.as_ref()).to_pyobject(vm);
+                        top_vec.push(vec_of_vec);
+                    }
+                    let top_vec = PyList::new_ref(top_vec, vm.as_ref());
+                    top_vec
+                }
+                Either::AffectedRows(cnt) => {
+                    PyList::new_ref(vec![vm.ctx.new_int(cnt).into()], vm.as_ref())
+                }
+            })
+    }
+}
+
+fn set_query_engine_in_scope(
+    scope: &Scope,
+    vm: &VirtualMachine,
+    query_engine: PyQueryEngine,
+) -> Result<()> {
+    scope
+        .locals
+        .as_object()
+        .set_item("query", query_engine.to_pyobject(vm), vm)
+        .map_err(|e| format_py_error(e, vm))
+}
 
 pub(crate) fn exec_with_cached_vm(
     copr: &Coprocessor,
     rb: &RecordBatch,
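Per the doc comment on `sql()` above, the result is either a list with one entry per record batch (each entry a list of column vectors) or a one-element list carrying the affected-row count. A hedged sketch of consuming that shape from a coprocessor script; the table and column here are illustrative:

res = query.sql("select number from numbers limit 5")
for batch in res:            # one entry per record batch
    first_column = batch[0]  # a column vector of that batch
# for statements that only report affected rows, `res` would instead look like [42]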
@@ -333,6 +450,15 @@ pub(crate) fn exec_with_cached_vm(
         // set arguments with given name and values
         let scope = vm.new_scope_with_builtins();
         set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;
+        if let Some(engine) = &copr.query_engine {
+            let query_engine = PyQueryEngine {
+                inner: engine.clone(),
+            };
+
+            // put an object named `query` of class PyQueryEngine in scope
+            PyQueryEngine::make_class(&vm.ctx);
+            set_query_engine_in_scope(&scope, vm, query_engine)?;
+        }
 
         // It's safe to unwrap code_object, it's already compiled before.
         let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
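Note that the `query` name is only bound when the coprocessor was compiled with a query engine attached; `exec_coprocessor` above passes `None`, so scripts run through that path get no binding. A small illustrative sketch (not from this commit) of what that means on the script side:

# the host sets the name `query` in the script scope, so no import is needed;
# when compiled without a query engine the name simply does not exist
try:
    rows = query.sql("select number from numbers limit 10")
except NameError:
    rows = None  # running without an attached query engine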
@@ -444,7 +570,7 @@ def test(a, b, c):
     return add(a, b) / g.sqrt(c)
 "#;
 
-        let copr = parse_and_compile_copr(script).unwrap();
+        let copr = parse_and_compile_copr(script, None).unwrap();
         assert_eq!(copr.name, "test");
         let deco_args = copr.deco_args.clone();
         assert_eq!(
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 use std::collections::HashSet;
+use std::sync::Arc;
 
 use datatypes::prelude::ConcreteDataType;
+use query::QueryEngineRef;
 use rustpython_parser::ast::{Arguments, Location};
 use rustpython_parser::{ast, parser};
 #[cfg(test)]
@@ -423,7 +425,10 @@ fn get_return_annotations(rets: &ast::Expr<()>) -> Result<Vec<Option<AnnotationI
 }
 
 /// parse script and return `Coprocessor` struct with info extracted from ast
-pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
+pub fn parse_and_compile_copr(
+    script: &str,
+    query_engine: Option<QueryEngineRef>,
+) -> Result<Coprocessor> {
     let python_ast = parser::parse_program(script, "<embedded>").context(PyParseSnafu)?;
 
     let mut coprocessor = None;
@@ -500,6 +505,7 @@ pub fn parse_and_compile_copr(script: &str) -> Result<Coprocessor> {
                 arg_types,
                 return_types,
                 script: script.to_owned(),
+                query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
             });
         }
     } else if matches!(
@@ -265,7 +265,10 @@ impl ScriptEngine for PyEngine {
     }
 
     async fn compile(&self, script: &str, _ctx: CompileContext) -> Result<PyScript> {
-        let copr = Arc::new(parse::parse_and_compile_copr(script)?);
+        let copr = Arc::new(parse::parse_and_compile_copr(
+            script,
+            Some(self.query_engine.clone()),
+        )?);
 
         Ok(PyScript {
             copr,
@@ -287,8 +290,7 @@ mod tests {
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_compile_execute() {
+    fn sample_script_engine() -> PyEngine {
        let catalog_list = catalog::local::new_memory_catalog_list().unwrap();
 
        let default_schema = Arc::new(MemorySchemaProvider::new());
@@ -306,7 +308,38 @@ mod tests {
        let factory = QueryEngineFactory::new(catalog_list);
        let query_engine = factory.query_engine();
 
-        let script_engine = PyEngine::new(query_engine.clone());
+        PyEngine::new(query_engine.clone())
+    }
+
+    #[tokio::test]
+    async fn test_sql_in_py() {
+        let script_engine = sample_script_engine();
+
+        let script = r#"
+import greptime as gt
+
+@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
+def test(number)->vector[u32]:
+    return query.sql("select * from numbers")[0][0][1]
+"#;
+        let script = script_engine
+            .compile(script, CompileContext::default())
+            .await
+            .unwrap();
+        let _output = script.execute(EvalContext::default()).await.unwrap();
+        let res = common_recordbatch::util::collect_batches(match _output {
+            Output::Stream(s) => s,
+            _ => todo!(),
+        })
+        .await
+        .unwrap();
+        let rb = res.iter().next().expect("One and only one recordbatch");
+        assert_eq!(rb.column(0).len(), 100);
+    }
+
+    #[tokio::test]
+    async fn test_compile_execute() {
+        let script_engine = sample_script_engine();
 
        // To avoid divide by zero, the script divides `add(a, b)` by `g.sqrt(c + 1)` instead of `g.sqrt(c)`
        let script = r#"
@@ -351,7 +384,7 @@ import greptime as gt
 
 @copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
 def test(a):
-    return gt.vector([x for x in a if x % 2 == 0])
+    return gt.vector([x for x in a if x % 2 == 0])
 "#;
        let script = script_engine
            .compile(script, CompileContext::default())
@@ -103,13 +103,13 @@ fn run_ron_testcases() {
        info!(".ron test {}", testcase.name);
        match testcase.predicate {
            Predicate::ParseIsOk { result } => {
-                let copr = parse_and_compile_copr(&testcase.code);
+                let copr = parse_and_compile_copr(&testcase.code, None);
                let mut copr = copr.unwrap();
                copr.script = "".into();
                assert_eq!(copr, *result);
            }
            Predicate::ParseIsErr { reason } => {
-                let copr = parse_and_compile_copr(&testcase.code);
+                let copr = parse_and_compile_copr(&testcase.code, None);
                assert!(copr.is_err(), "Expect to be err, actual {copr:#?}");
 
                let res = &copr.unwrap_err();
@@ -183,7 +183,7 @@ def a(cpu, mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
 "#;
    let pyast = parser::parse(python_source, parser::Mode::Interactive, "<embedded>").unwrap();
-    let copr = parse_and_compile_copr(python_source);
+    let copr = parse_and_compile_copr(python_source, None);
    dbg!(copr);
 }