From 7c34b009ec7f25321a6e29ffadbefdee07cbb70a Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:21:57 +0800 Subject: [PATCH] feat: bind DataFrame API into python script (#945) * chore: remove unused magic fn * feat: dataframe * feat: add data_frame crate * feat: more api binded * fix: `Comparable` for overload op * fix: license&more test * chore: PR advices * chore: more PR advices --- src/script/src/python.rs | 1 + src/script/src/python/coprocessor.rs | 8 + src/script/src/python/dataframe.rs | 344 +++++++++++++++++++++++++++ src/script/src/python/engine.rs | 29 ++- src/script/src/python/testcases.ron | 68 ++++++ src/script/src/python/utils.rs | 14 ++ src/script/src/python/vector.rs | 43 ---- 7 files changed, 463 insertions(+), 44 deletions(-) create mode 100644 src/script/src/python/dataframe.rs diff --git a/src/script/src/python.rs b/src/script/src/python.rs index 30e402ef29..b340d04cf7 100644 --- a/src/script/src/python.rs +++ b/src/script/src/python.rs @@ -16,6 +16,7 @@ mod builtins; pub(crate) mod coprocessor; +mod dataframe; mod engine; pub mod error; #[cfg(test)] diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs index c10e5dc513..f96d56a5e6 100644 --- a/src/script/src/python/coprocessor.rs +++ b/src/script/src/python/coprocessor.rs @@ -43,6 +43,7 @@ use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine} use crate::python::builtins::greptime_builtin; use crate::python::coprocessor::parse::DecoratorArgs; +use crate::python::dataframe::data_frame::{self, set_dataframe_in_scope}; use crate::python::error::{ ensure, ret_other_error_with, ArrowSnafu, NewRecordBatchSnafu, OtherSnafu, Result, TypeCastSnafu, @@ -450,6 +451,8 @@ pub(crate) fn exec_with_cached_vm( // set arguments with given name and values let scope = vm.new_scope_with_builtins(); set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?; + set_dataframe_in_scope(&scope, vm, "dataframe", rb)?; + if let Some(engine) = &copr.query_engine { let query_engine = PyQueryEngine { inner: engine.clone(), @@ -500,6 +503,7 @@ pub(crate) fn init_interpreter() -> Arc { // TODO(discord9): edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]` // so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868 let mut settings = vm::Settings::default(); + // disable SIG_INT handler so our own binary can take ctrl_c handler settings.no_sig_int = true; let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| { // not using full stdlib to prevent security issue, instead filter out a few simple util module @@ -517,6 +521,10 @@ pub(crate) fn init_interpreter() -> Arc { // add our own custom datatype and module PyVector::make_class(&vm.ctx); vm.add_native_module("greptime", Box::new(greptime_builtin::make_module)); + + data_frame::PyDataFrame::make_class(&vm.ctx); + data_frame::PyExpr::make_class(&vm.ctx); + vm.add_native_module("data_frame", Box::new(data_frame::make_module)); })); info!("Initialized Python interpreter."); interpreter diff --git a/src/script/src/python/dataframe.rs b/src/script/src/python/dataframe.rs new file mode 100644 index 0000000000..37a8551cd4 --- /dev/null +++ b/src/script/src/python/dataframe.rs @@ -0,0 +1,344 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use rustpython_vm::pymodule as rspymodule; + +/// with `register_batch`, and then wrap DataFrame API in it +#[rspymodule] +pub(crate) mod data_frame { + use common_recordbatch::{DfRecordBatch, RecordBatch}; + use datafusion::dataframe::DataFrame as DfDataFrame; + use datafusion_expr::Expr as DfExpr; + use rustpython_vm::builtins::{PyList, PyListRef}; + use rustpython_vm::function::PyComparisonValue; + use rustpython_vm::types::{Comparable, PyComparisonOp}; + use rustpython_vm::{ + pyclass as rspyclass, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, + }; + use snafu::ResultExt; + + use crate::python::error::DataFusionSnafu; + use crate::python::utils::block_on_async; + #[rspyclass(module = "data_frame", name = "DataFrame")] + #[derive(PyPayload, Debug)] + pub struct PyDataFrame { + pub inner: DfDataFrame, + } + + impl From for PyDataFrame { + fn from(inner: DfDataFrame) -> Self { + Self { inner } + } + } + /// set DataFrame instance into current scope with given name + pub fn set_dataframe_in_scope( + scope: &rustpython_vm::scope::Scope, + vm: &VirtualMachine, + name: &str, + rb: &RecordBatch, + ) -> crate::python::error::Result<()> { + let df = PyDataFrame::from_record_batch(rb.df_record_batch())?; + scope + .locals + .set_item(name, vm.new_pyobj(df), vm) + .map_err(|e| crate::python::utils::format_py_error(e, vm)) + } + #[rspyclass] + impl PyDataFrame { + /// TODO(discord9): error handling + fn from_record_batch(rb: &DfRecordBatch) -> crate::python::error::Result { + let ctx = datafusion::execution::context::SessionContext::new(); + let inner = ctx.read_batch(rb.clone()).context(DataFusionSnafu)?; + Ok(Self { inner }) + } + + #[pymethod] + fn select_columns(&self, columns: Vec, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .select_columns(&columns.iter().map(AsRef::as_ref).collect::>()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn select(&self, expr_list: Vec, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .select(expr_list.iter().map(|e| e.inner.clone()).collect()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn filter(&self, predicate: PyExprRef, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .filter(predicate.inner.clone()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn aggregate( + &self, + group_expr: Vec, + aggr_expr: Vec, + vm: &VirtualMachine, + ) -> PyResult { + let ret = self.inner.clone().aggregate( + group_expr.iter().map(|i| i.inner.clone()).collect(), + aggr_expr.iter().map(|i| i.inner.clone()).collect(), + ); + Ok(ret.map_err(|e| vm.new_runtime_error(e.to_string()))?.into()) + } + + #[pymethod] + fn limit(&self, skip: usize, fetch: Option, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .limit(skip, fetch) + .map_err(|e| vm.new_runtime_error(e.to_string()))? 
+ .into()) + } + + #[pymethod] + fn union(&self, df: PyRef, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .union(df.inner.clone()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn union_distinct(&self, df: PyRef, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .union_distinct(df.inner.clone()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn distinct(&self, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .distinct() + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn sort(&self, expr: Vec, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .sort(expr.iter().map(|e| e.inner.clone()).collect()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn join( + &self, + right: PyRef, + join_type: String, + left_cols: Vec, + right_cols: Vec, + filter: Option, + vm: &VirtualMachine, + ) -> PyResult { + use datafusion::prelude::JoinType; + let join_type = match join_type.as_str() { + "inner" | "Inner" => JoinType::Inner, + "left" | "Left" => JoinType::Left, + "right" | "Right" => JoinType::Right, + "full" | "Full" => JoinType::Full, + "leftSemi" | "LeftSemi" => JoinType::LeftSemi, + "rightSemi" | "RightSemi" => JoinType::RightSemi, + "leftAnti" | "LeftAnti" => JoinType::LeftAnti, + "rightAnti" | "RightAnti" => JoinType::RightAnti, + _ => return Err(vm.new_runtime_error(format!("Unknown join type: {join_type}"))), + }; + let left_cols: Vec<&str> = left_cols.iter().map(AsRef::as_ref).collect(); + let right_cols: Vec<&str> = right_cols.iter().map(AsRef::as_ref).collect(); + let filter = filter.map(|f| f.inner.clone()); + Ok(self + .inner + .clone() + .join( + right.inner.clone(), + join_type, + &left_cols, + &right_cols, + filter, + ) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn intersect(&self, df: PyRef, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .intersect(df.inner.clone()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + fn except(&self, df: PyRef, vm: &VirtualMachine) -> PyResult { + Ok(self + .inner + .clone() + .except(df.inner.clone()) + .map_err(|e| vm.new_runtime_error(e.to_string()))? + .into()) + } + + #[pymethod] + /// collect `DataFrame` results into `List[List[Vector]]` + fn collect(&self, vm: &VirtualMachine) -> PyResult { + let inner = self.inner.clone(); + let res = block_on_async(async { inner.collect().await }); + let res = res + .map_err(|e| vm.new_runtime_error(format!("{e:?}")))? 
+ .map_err(|e| vm.new_runtime_error(e.to_string()))?; + let outer_list: Vec<_> = res + .iter() + .map(|elem| -> PyResult<_> { + let inner_list: Vec<_> = elem + .columns() + .iter() + .map(|arr| -> PyResult<_> { + datatypes::vectors::Helper::try_into_vector(arr) + .map(crate::python::PyVector::from) + .map(|v| vm.new_pyobj(v)) + .map_err(|e| vm.new_runtime_error(e.to_string())) + }) + .collect::>()?; + let inner_list = PyList::new_ref(inner_list, vm.as_ref()); + Ok(inner_list.into()) + }) + .collect::>()?; + Ok(PyList::new_ref(outer_list, vm.as_ref())) + } + } + + #[rspyclass(module = "data_frame", name = "Expr")] + #[derive(PyPayload, Debug, Clone)] + pub struct PyExpr { + pub inner: DfExpr, + } + + #[pyfunction] + fn col(name: String, vm: &VirtualMachine) -> PyExprRef { + let expr: PyExpr = DfExpr::Column(datafusion_common::Column::from_name(name)).into(); + expr.into_ref(vm) + } + + // TODO(discord9): lit function that take PyObject and turn it into ScalarValue + + type PyExprRef = PyRef; + + impl From for PyExpr { + fn from(value: DfExpr) -> Self { + Self { inner: value } + } + } + + impl Comparable for PyExpr { + fn slot_richcompare( + zelf: &PyObject, + other: &PyObject, + op: PyComparisonOp, + vm: &VirtualMachine, + ) -> PyResult> { + if let (Some(zelf), Some(other)) = + (zelf.downcast_ref::(), other.downcast_ref::()) + { + let ret = zelf.richcompare((**other).clone(), op, vm)?; + let ret = ret.into_pyobject(vm); + Ok(rustpython_vm::function::Either::A(ret)) + } else { + Err(vm.new_type_error(format!( + "unexpected payload {zelf:?} and {other:?} for op {}", + op.method_name(&vm.ctx).as_str() + ))) + } + } + fn cmp( + _zelf: &rustpython_vm::Py, + _other: &PyObject, + _op: PyComparisonOp, + _vm: &VirtualMachine, + ) -> PyResult { + Ok(PyComparisonValue::NotImplemented) + } + } + + #[rspyclass(with(Comparable))] + impl PyExpr { + fn richcompare( + &self, + other: Self, + op: PyComparisonOp, + _vm: &VirtualMachine, + ) -> PyResult { + let f = match op { + PyComparisonOp::Eq => DfExpr::eq, + PyComparisonOp::Ne => DfExpr::not_eq, + PyComparisonOp::Gt => DfExpr::gt, + PyComparisonOp::Lt => DfExpr::lt, + PyComparisonOp::Ge => DfExpr::gt_eq, + PyComparisonOp::Le => DfExpr::lt_eq, + }; + Ok(f(self.inner.clone(), other.inner).into()) + } + #[pymethod] + fn alias(&self, name: String) -> PyResult { + Ok(self.inner.clone().alias(name).into()) + } + + #[pymethod(magic)] + fn and(&self, other: PyExprRef) -> PyResult { + Ok(self.inner.clone().and(other.inner.clone()).into()) + } + #[pymethod(magic)] + fn or(&self, other: PyExprRef) -> PyResult { + Ok(self.inner.clone().or(other.inner.clone()).into()) + } + + /// `~` operator, return `!self` + #[pymethod(magic)] + fn invert(&self) -> PyResult { + Ok(self.inner.clone().not().into()) + } + + /// sort ascending&nulls_first + #[pymethod] + fn sort(&self) -> PyExpr { + self.inner.clone().sort(true, true).into() + } + } +} diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 7a68a63e6b..fd89ee74be 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -320,7 +320,34 @@ import greptime as gt @copr(args=["number"], returns = ["number"], sql = "select * from numbers") def test(number)->vector[u32]: - return query.sql("select * from numbers")[0][0][1] + return query.sql("select * from numbers")[0][0] +"#; + let script = script_engine + .compile(script, CompileContext::default()) + .await + .unwrap(); + let output = script.execute(EvalContext::default()).await.unwrap(); + let res = 
common_recordbatch::util::collect_batches(match output { + Output::Stream(s) => s, + _ => unreachable!(), + }) + .await + .unwrap(); + let rb = res.iter().next().expect("One and only one recordbatch"); + assert_eq!(rb.column(0).len(), 100); + } + + #[tokio::test] + async fn test_data_frame_in_py() { + let script_engine = sample_script_engine(); + + let script = r#" +import greptime as gt +from data_frame import col + +@copr(args=["number"], returns = ["number"], sql = "select * from numbers") +def test(number)->vector[u32]: + return dataframe.filter(col("number")==col("number")).collect()[0][0] "#; let script = script_engine .compile(script, CompileContext::default()) diff --git a/src/script/src/python/testcases.ron b/src/script/src/python/testcases.ron index ab70ec0662..3ebd2d5e4c 100644 --- a/src/script/src/python/testcases.ron +++ b/src/script/src/python/testcases.ron @@ -476,6 +476,74 @@ def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], math.ceil(0.2) import string return cpu + mem, 1 +"#, + predicate: ExecIsOk( + fields: [ + ( + datatype: Some(Float64(())), + is_nullable: true + ), + ( + datatype: Some(Float32(())), + is_nullable: false + ), + ], + columns: [ + ( + ty: Float64, + len: 4 + ), + ( + ty: Float32, + len: 4 + ) + ] + ) + ), + ( + // constant column(int) + name: "test_data_frame", + code: r#" +from data_frame import col +@copr(args=["cpu", "mem"], returns=["perf", "what"]) +def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], + vector[f32]): + ret = dataframe.select([col("cpu"), col("mem")]).collect()[0] + return ret[0], ret[1] +"#, + predicate: ExecIsOk( + fields: [ + ( + datatype: Some(Float64(())), + is_nullable: true + ), + ( + datatype: Some(Float32(())), + is_nullable: false + ), + ], + columns: [ + ( + ty: Float64, + len: 4 + ), + ( + ty: Float32, + len: 4 + ) + ] + ) + ), + ( + // constant column(int) + name: "test_data_frame", + code: r#" +from data_frame import col +@copr(args=["cpu", "mem"], returns=["perf", "what"]) +def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], + vector[f32]): + ret = dataframe.filter(col("cpu")>col("mem")).collect()[0] + return ret[0], ret[1] "#, predicate: ExecIsOk( fields: [ diff --git a/src/script/src/python/utils.rs b/src/script/src/python/utils.rs index a0f1df2a8b..e9c6fb6775 100644 --- a/src/script/src/python/utils.rs +++ b/src/script/src/python/utils.rs @@ -20,6 +20,7 @@ use datatypes::prelude::ScalarVector; use datatypes::vectors::{ BooleanVector, Float64Vector, Helper, Int64Vector, NullVector, StringVector, VectorRef, }; +use futures::Future; use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyList, PyStr}; use rustpython_vm::{PyObjectRef, PyPayload, PyRef, VirtualMachine}; use snafu::{Backtrace, GenerateImplicitData, OptionExt, ResultExt}; @@ -113,3 +114,16 @@ pub fn py_vec_obj_to_array( ret_other_error_with(format!("Expect a vector or a constant, found {obj:?}")).fail() } } + +/// a terrible hack to call async from sync by: +/// TODO(discord9): find a better way +/// 1. spawn a new thread +/// 2. 
create a new runtime in new thread and call `block_on` on it +pub fn block_on_async(f: F) -> std::thread::Result +where + F: Future + Send + 'static, + T: Send + 'static, +{ + let rt = tokio::runtime::Runtime::new().map_err(|e| Box::new(e) as _)?; + std::thread::spawn(move || rt.block_on(f)).join() +} diff --git a/src/script/src/python/vector.rs b/src/script/src/python/vector.rs index 0d8294772c..c776f140b3 100644 --- a/src/script/src/python/vector.rs +++ b/src/script/src/python/vector.rs @@ -436,48 +436,6 @@ impl PyVector { } } - // it seems rustpython's richcompare support is not good - // The Comparable Trait only support normal cmp - // (yes there is a slot_richcompare function, but it is not used in anywhere) - // so use our own function - // TODO(discord9): test those function - - #[pymethod(name = "eq")] - #[pymethod(magic)] - fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Eq, vm) - } - - #[pymethod(name = "ne")] - #[pymethod(magic)] - fn ne(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Ne, vm) - } - - #[pymethod(name = "gt")] - #[pymethod(magic)] - fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Gt, vm) - } - - #[pymethod(name = "lt")] - #[pymethod(magic)] - fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Lt, vm) - } - - #[pymethod(name = "ge")] - #[pymethod(magic)] - fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Ge, vm) - } - - #[pymethod(name = "le")] - #[pymethod(magic)] - fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - self.richcompare(other, PyComparisonOp::Le, vm) - } - #[pymethod(magic)] fn and(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult { let left = self.to_arrow_array(); @@ -516,7 +474,6 @@ impl PyVector { #[pymethod(magic)] fn invert(&self, vm: &VirtualMachine) -> PyResult { - dbg!(); let left = self.to_arrow_array(); let left = left .as_any()
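For reference, a minimal coprocessor script using the newly bound API might look like the following. This is a rough sketch that mirrors the test cases added in testcases.ron: the `dataframe` name is the one set into scope by `exec_with_cached_vm` via `set_dataframe_in_scope`, and `col` comes from the new `data_frame` module.

from data_frame import col

@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64]) -> (vector[f64|None], vector[f32]):
    # `dataframe` is injected into the script scope from the input record batch;
    # filter() keeps rows where cpu > mem, and collect() returns List[List[Vector]]
    ret = dataframe.filter(col("cpu") > col("mem")).collect()[0]
    return ret[0], ret[1]

The predicate compares one column against another because this patch does not yet bind a `lit` helper for turning Python objects into ScalarValue expressions (see the TODO in dataframe.rs), so comparisons are currently limited to column-to-column expressions.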