diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs index 6b95d75e20..1ca9f9b482 100644 --- a/src/script/src/python/ffi_types/copr.rs +++ b/src/script/src/python/ffi_types/copr.rs @@ -261,11 +261,12 @@ pub(crate) fn check_args_anno_real_type( .unwrap_or(true), OtherSnafu { reason: format!( - "column {}'s Type annotation is {:?}, but actual type is {:?}", + "column {}'s Type annotation is {:?}, but actual type is {:?} with nullable=={}", // It's safe to unwrap here, we already ensure the args and types number is the same when parsing copr.deco_args.arg_names.as_ref().unwrap()[idx], anno_ty, - real_ty + real_ty, + is_nullable ) } ) @@ -344,20 +345,35 @@ pub(crate) enum Either { Rb(RecordBatches), AffectedRows(usize), } + +impl PyQueryEngine { + pub(crate) fn sql_to_rb(&self, sql: String) -> StdResult { + let res = self.query_with_new_thread(sql.clone())?; + match res { + Either::Rb(rbs) => { + let rb = compute::concat_batches( + rbs.schema().arrow_schema(), + rbs.iter().map(|r| r.df_record_batch()), + ) + .map_err(|e| format!("Concat batches failed for query {sql}: {e}"))?; + RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| + format!( + "Convert datafusion record batch to record batch failed for query {sql}: {e}" + ) + ) + } + Either::AffectedRows(_) => Err(format!("Expect actual results from query {sql}")), + } + } +} + #[rspyclass] impl PyQueryEngine { pub(crate) fn from_weakref(inner: QueryEngineWeakRef) -> Self { Self { inner } } - #[cfg(feature = "pyo3_backend")] - pub(crate) fn get_ref(&self) -> Option> { - self.inner.0.upgrade() - } - pub(crate) fn query_with_new_thread( - &self, - query: Option>, - s: String, - ) -> StdResult { + pub(crate) fn query_with_new_thread(&self, s: String) -> StdResult { + let query = self.inner.0.upgrade(); let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> { if let Some(engine) = query { let stmt = QueryLanguageParser::parse_sql(&s).map_err(|e| e.to_string())?; @@ -401,8 +417,7 @@ impl PyQueryEngine { /// return sql query results in List[List[PyVector]], or List[usize] for AffectedRows number if no recordbatches is returned #[pymethod] fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult { - let query = self.inner.0.upgrade(); - self.query_with_new_thread(query, s) + self.query_with_new_thread(s) .map_err(|e| vm.new_system_error(e)) .map(|rbs| match rbs { Either::Rb(rbs) => { diff --git a/src/script/src/python/ffi_types/pair_tests.rs b/src/script/src/python/ffi_types/pair_tests.rs index 7d92d24669..c860c4e2a7 100644 --- a/src/script/src/python/ffi_types/pair_tests.rs +++ b/src/script/src/python/ffi_types/pair_tests.rs @@ -109,10 +109,12 @@ async fn integrated_py_copr_test() { } } +#[allow(clippy::print_stdout)] #[test] fn pyo3_rspy_test_in_pairs() { let testcases = sample_test_case(); for case in testcases { + println!("Testcase: {}", case.script); eval_rspy(case.clone()); #[cfg(feature = "pyo3_backend")] eval_pyo3(case); @@ -122,6 +124,9 @@ fn pyo3_rspy_test_in_pairs() { fn check_equal(v0: VectorRef, v1: VectorRef) -> bool { let v0 = v0.to_arrow_array(); let v1 = v1.to_arrow_array(); + if v0.len() != v1.len() { + return false; + } fn is_float(ty: &ArrowDataType) -> bool { use ArrowDataType::*; matches!(ty, Float16 | Float32 | Float64) diff --git a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs index 8805552102..1e6ed16bc0 100644 --- a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs +++ b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs @@ -17,10 +17,11 @@ use std::f64::consts; use std::sync::Arc; use datatypes::prelude::ScalarVector; +#[cfg(feature = "pyo3_backend")] +use datatypes::vectors::UInt32Vector; use datatypes::vectors::{BooleanVector, Float64Vector, Int32Vector, Int64Vector, VectorRef}; -use super::CoprTestCase; -use crate::python::ffi_types::pair_tests::CodeBlockTestCase; +use crate::python::ffi_types::pair_tests::{CodeBlockTestCase, CoprTestCase}; macro_rules! vector { ($ty: ident, $slice: expr) => { Arc::new($ty::from_slice($slice)) as VectorRef @@ -56,14 +57,14 @@ def boolean_array() -> vector[f64]: from greptime import vector from greptime import query, dataframe - try: - print("query()=", query()) - except KeyError as e: - print("query()=", e) + print("query()=", query()) + assert "query_engine object at" in str(query()) try: print("dataframe()=", dataframe()) except KeyError as e: print("dataframe()=", e) + print(str(e), type(str(e)), 'No __dataframe__' in str(e)) + assert 'No __dataframe__' in str(e) v = vector([1.0, 2.0, 3.0]) # This returns a vector([2.0]) @@ -72,6 +73,57 @@ def boolean_array() -> vector[f64]: .to_string(), expect: Some(ronish!("value": vector!(Float64Vector, [2.0f64]))), }, + CoprTestCase { + script: r#" +@copr(returns=["value"], backend="rspy") +def boolean_array() -> vector[f64]: + from greptime import vector, col + from greptime import query, dataframe, PyDataFrame + + df = PyDataFrame.from_sql("select number from numbers limit 5") + print("df from sql=", df) + collected = df.collect() + print("df.collect()=", collected) + assert len(collected[0][0]) == 5 + df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2) + collected = df.collect() + assert len(collected[0][0]) == 2 + print("query()=", query()) + + assert "query_engine object at" in repr(query()) + try: + print("dataframe()=", dataframe()) + except KeyError as e: + print("dataframe()=", e) + assert "__dataframe__" in str(e) + + v = vector([1.0, 2.0, 3.0]) + # This returns a vector([2.0]) + return v[(v > 1) & (v < 3)] +"# + .to_string(), + expect: Some(ronish!("value": vector!(Float64Vector, [2.0f64]))), + }, + #[cfg(feature = "pyo3_backend")] + CoprTestCase { + script: r#" +@copr(returns=["value"], backend="pyo3") +def boolean_array() -> vector[f64]: + from greptime import vector + from greptime import query, dataframe, PyDataFrame, col + df = PyDataFrame.from_sql("select number from numbers limit 5") + print("df from sql=", df) + ret = df.collect() + print("df.collect()=", ret) + assert len(ret[0][0]) == 5 + df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2) + collected = df.collect() + assert len(collected[0][0]) == 2 + return ret[0][0] +"# + .to_string(), + expect: Some(ronish!("value": vector!(UInt32Vector, [0, 1, 2, 3, 4]))), + }, #[cfg(feature = "pyo3_backend")] CoprTestCase { script: r#" @@ -178,6 +230,64 @@ def answer() -> vector[i64]: .to_string(), expect: Some(ronish!("number": vector!(Int64Vector, [1, 2]))), }, + #[cfg(feature = "pyo3_backend")] + CoprTestCase { + script: r#" +@copr(returns=["value"], backend="pyo3") +def answer() -> vector[i64]: + from greptime import vector + import pyarrow as pa + a = vector.from_pyarrow(pa.array([42, 43, 44])) + return a[0:1] +"# + .to_string(), + expect: Some(ronish!("value": vector!(Int64Vector, [42]))), + }, + #[cfg(feature = "pyo3_backend")] + CoprTestCase { + script: r#" +@copr(returns=["value"], backend="pyo3") +def answer() -> vector[i64]: + from greptime import vector + a = vector([42, 43, 44]) + # slicing test + assert a[0:2] == a[:-1] + assert len(a[:-1]) == vector([42,44]) + assert a[0:1] == a[:-2] + assert a[0:1] == vector([42]) + assert a[:-2] == vector([42]) + assert a[:-1:2] == vector([42]) + assert a[::2] == vector([42,44]) + # negative step + assert a[-1::-2] == vector([44, 42]) + assert a[-2::-2] == vector([44]) + return a[0:1] +"# + .to_string(), + expect: Some(ronish!("value": vector!(Int64Vector, [42]))), + }, + CoprTestCase { + script: r#" +@copr(returns=["value"], backend="rspy") +def answer() -> vector[i64]: + from greptime import vector + a = vector([42, 43, 44]) + # slicing test + assert a[0:2] == a[:-1] + assert len(a[:-1]) == vector([42,44]) + assert a[0:1] == a[:-2] + assert a[0:1] == vector([42]) + assert a[:-2] == vector([42]) + assert a[:-1:2] == vector([42]) + assert a[::2] == vector([42,44]) + # negative step + assert a[-1::-2] == vector([44, 42]) + assert a[-2::-2] == vector([44]) + return a[-2:-1] +"# + .to_string(), + expect: Some(ronish!("value": vector!(Int64Vector, [43]))), + }, ] } @@ -185,6 +295,7 @@ def answer() -> vector[i64]: /// Using a function to generate testcase instead of `.ron` configure file because it's more flexible and we are in #[cfg(test)] so no binary bloat worrying #[allow(clippy::approx_constant)] pub(super) fn sample_test_case() -> Vec { + // TODO(discord9): detailed tests for slicing vector vec![ CodeBlockTestCase { input: ronish! { @@ -192,13 +303,54 @@ pub(super) fn sample_test_case() -> Vec { }, script: r#" from greptime import * -ret = a+3.0 -ret = ret * 2.0 -ret = ret / 2.0 -ret = ret - 3.0 +ret = a[0:1] ret"# .to_string(), - expect: vector!(Float64Vector, [1.0f64, 2.0, 3.0]), + expect: vector!(Float64Vector, [1.0f64]), + }, + CodeBlockTestCase { + input: ronish! { + "a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]) + }, + script: r#" +from greptime import * +ret = a[0:1:1] +ret"# + .to_string(), + expect: vector!(Float64Vector, [1.0f64]), + }, + CodeBlockTestCase { + input: ronish! { + "a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]) + }, + script: r#" +from greptime import * +ret = a[-2:-1] +ret"# + .to_string(), + expect: vector!(Float64Vector, [2.0f64]), + }, + CodeBlockTestCase { + input: ronish! { + "a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]) + }, + script: r#" +from greptime import * +ret = a[-1:-2:-1] +ret"# + .to_string(), + expect: vector!(Float64Vector, [3.0f64]), + }, + CodeBlockTestCase { + input: ronish! { + "a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]) + }, + script: r#" +from greptime import * +ret = a[-1:-4:-1] +ret"# + .to_string(), + expect: vector!(Float64Vector, [3.0f64, 2.0, 1.0]), }, CodeBlockTestCase { input: ronish! { diff --git a/src/script/src/python/ffi_types/vector.rs b/src/script/src/python/ffi_types/vector.rs index 3eaf8c8016..3660e6c43d 100644 --- a/src/script/src/python/ffi_types/vector.rs +++ b/src/script/src/python/ffi_types/vector.rs @@ -381,7 +381,6 @@ impl PyVector { // adjust_indices so negative number is transform to usize let (mut range, step, slice_len) = slice.adjust_indices(self.len()); let vector = self.as_vector_ref(); - let mut buf = vector.data_type().create_mutable_vector(slice_len); if slice_len == 0 { let v: PyVector = buf.to_vector().into(); @@ -391,6 +390,7 @@ impl PyVector { Ok(v.into_pyobject(vm)) } else if step.is_negative() { // Negative step require special treatment + // range.start > range.stop if slice can found no-empty for i in range.rev().step_by(step.unsigned_abs()) { // Safety: This mutable vector is created from the vector's data type. buf.push_value_ref(vector.get_ref(i)); diff --git a/src/script/src/python/pyo3/builtins.rs b/src/script/src/python/pyo3/builtins.rs index 809fbfe43c..5d1e22a2f0 100644 --- a/src/script/src/python/pyo3/builtins.rs +++ b/src/script/src/python/pyo3/builtins.rs @@ -61,6 +61,7 @@ macro_rules! batch_import { #[pyo3(name = "greptime")] pub(crate) fn greptime_builtins(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; + m.add_class::()?; use self::query_engine; batch_import!( m, @@ -137,7 +138,7 @@ fn dataframe(py: Python) -> PyResult { #[pyfunction] #[pyo3(name = "query")] -fn query_engine(py: Python) -> PyResult { +pub(crate) fn query_engine(py: Python) -> PyResult { let globals = get_globals(py)?; let query = globals .get_item("__query__") diff --git a/src/script/src/python/pyo3/copr_impl.rs b/src/script/src/python/pyo3/copr_impl.rs index 0ed7ccb6e2..e2be48c6ff 100644 --- a/src/script/src/python/pyo3/copr_impl.rs +++ b/src/script/src/python/pyo3/copr_impl.rs @@ -16,7 +16,7 @@ use std::collections::HashMap; use common_recordbatch::RecordBatch; use datatypes::vectors::{Helper, VectorRef}; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::types::{PyDict, PyList, PyModule, PyTuple}; use pyo3::{pymethods, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject}; use snafu::{ensure, Backtrace, GenerateImplicitData, ResultExt}; @@ -31,9 +31,8 @@ use crate::python::pyo3::utils::{init_cpython_interpreter, pyo3_obj_try_to_typed impl PyQueryEngine { #[pyo3(name = "sql")] pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult { - let query = self.get_ref(); let res = self - .query_with_new_thread(query, s) + .query_with_new_thread(s) .map_err(PyValueError::new_err)?; match res { crate::python::ffi_types::copr::Either::Rb(rbs) => { @@ -143,7 +142,7 @@ coprocessor = copr py_any_to_vec(result, col_len) })() .map_err(|err| error::Error::PyRuntime { - msg: err.to_string(), + msg: err.into_value(py).to_string(), backtrace: Backtrace::generate(), })?; ensure!( @@ -165,6 +164,23 @@ coprocessor = copr /// Cast return of py script result to `Vec`, /// constants will be broadcast to length of `col_len` fn py_any_to_vec(obj: &PyAny, col_len: usize) -> PyResult> { + // check if obj is of two types: + // 1. tuples of PyVector + // 2. a single PyVector + let check = if obj.is_instance_of::()? { + let tuple = obj.downcast::()?; + + (0..tuple.len()) + .map(|idx| tuple.get_item(idx).map(|i| i.is_instance_of::())) + .all(|i| matches!(i, Ok(Ok(true)))) + } else { + obj.is_instance_of::()? + }; + if !check { + return Err(PyRuntimeError::new_err(format!( + "Expect a tuple of vectors or one single vector, found {obj}" + ))); + } if let Ok(tuple) = obj.downcast::() { let len = tuple.len(); let v = (0..len) @@ -219,7 +235,7 @@ def a(cpu, mem, **kwargs): for k, v in kwargs.items(): print("%s == %s" % (k, v)) print(dataframe().select([col("cpu")= 0.75) + return (0.5 < cpu) & ~(cpu >= 0.75) "#; let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.3]); let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]); diff --git a/src/script/src/python/pyo3/dataframe_impl.rs b/src/script/src/python/pyo3/dataframe_impl.rs index 07bee23df8..e15ebc5f31 100644 --- a/src/script/src/python/pyo3/dataframe_impl.rs +++ b/src/script/src/python/pyo3/dataframe_impl.rs @@ -15,14 +15,15 @@ use common_recordbatch::DfRecordBatch; use datafusion::dataframe::DataFrame as DfDataFrame; use datafusion_expr::Expr as DfExpr; -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; -use pyo3::types::PyList; +use pyo3::types::{PyList, PyType}; use snafu::ResultExt; use crate::python::error::DataFusionSnafu; use crate::python::ffi_types::PyVector; +use crate::python::pyo3::builtins::query_engine; use crate::python::pyo3::utils::pyo3_obj_try_to_typed_scalar_value; use crate::python::utils::block_on_async; type PyExprRef = Py; @@ -49,6 +50,15 @@ impl PyDataFrame { #[pymethods] impl PyDataFrame { + #[classmethod] + fn from_sql(_cls: &PyType, py: Python, sql: String) -> PyResult { + let query = query_engine(py)?; + let rb = query.sql_to_rb(sql).map_err(PyRuntimeError::new_err)?; + let ctx = datafusion::execution::context::SessionContext::new(); + ctx.read_batch(rb.df_record_batch().clone()) + .map_err(|e| PyRuntimeError::new_err(format!("{e:?}"))) + .map(Self::from) + } fn __call__(&self) -> PyResult { Ok(self.clone()) } diff --git a/src/script/src/python/pyo3/vector_impl.rs b/src/script/src/python/pyo3/vector_impl.rs index 7c6a778fd1..d26309aacd 100644 --- a/src/script/src/python/pyo3/vector_impl.rs +++ b/src/script/src/python/pyo3/vector_impl.rs @@ -24,7 +24,7 @@ use datatypes::vectors::Helper; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pyclass::CompareOp; -use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyType}; +use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString, PyType}; use super::utils::val_to_py_any; use crate::python::ffi_types::vector::{arrow_rtruediv, wrap_bool_result, wrap_result, PyVector}; @@ -299,6 +299,49 @@ impl PyVector { })?; let ret = Self::from(ret).into_py(py); Ok(ret) + } else if let Ok(slice) = needle.downcast::(py) { + let indices = slice.indices(self.len() as i64)?; + let (start, stop, step, _slicelength) = ( + indices.start, + indices.stop, + indices.step, + indices.slicelength, + ); + if start < 0 { + return Err(PyValueError::new_err(format!( + "Negative start is not supported, found {start} in {indices:?}" + ))); + } // Negative stop is supported, means from "indices.start" to the actual start of the vector + let vector = self.as_vector_ref(); + + let mut buf = vector + .data_type() + .create_mutable_vector(indices.slicelength as usize); + let v = if indices.slicelength == 0 { + buf.to_vector() + } else { + if indices.step > 0 { + let range = if stop == -1 { + start as usize..start as usize + } else { + start as usize..stop as usize + }; + for i in range.step_by(step.unsigned_abs()) { + buf.push_value_ref(vector.get_ref(i)); + } + } else { + // if no-empty, then stop < start + // note: start..stop is empty is start >= stop + // stop>=-1 + let range = { (stop + 1) as usize..=start as usize }; + for i in range.rev().step_by(step.unsigned_abs()) { + buf.push_value_ref(vector.get_ref(i)); + } + } + buf.to_vector() + }; + let v: PyVector = v.into(); + Ok(v.into_py(py)) } else if let Ok(index) = needle.extract::(py) { // deal with negative index let len = self.len() as isize; diff --git a/src/script/src/python/rspython/builtins.rs b/src/script/src/python/rspython/builtins.rs index def55fd5df..82a6d77b2b 100644 --- a/src/script/src/python/rspython/builtins.rs +++ b/src/script/src/python/rspython/builtins.rs @@ -290,6 +290,7 @@ pub(crate) mod greptime_builtin { use common_function::scalars::math::PowFunction; use common_function::scalars::{Function, FunctionRef, FUNCTION_REGISTRY}; use datafusion::arrow::datatypes::DataType as ArrowDataType; + use datafusion::dataframe::DataFrame as DfDataFrame; use datafusion::physical_plan::expressions; use datafusion_expr::{ColumnarValue as DFColValue, Expr as DfExpr}; use datafusion_physical_expr::math_expressions; @@ -300,20 +301,29 @@ pub(crate) mod greptime_builtin { use paste::paste; use rustpython_vm::builtins::{PyFloat, PyFunction, PyInt, PyStr}; use rustpython_vm::function::{FuncArgs, KwArgs, OptionalArg}; - use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine}; - - use super::{ - all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj, - type_cast_error, + use rustpython_vm::{ + pyclass, AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; + use crate::python::ffi_types::copr::PyQueryEngine; use crate::python::ffi_types::vector::val_to_pyobj; use crate::python::ffi_types::PyVector; - use crate::python::rspython::dataframe_impl::data_frame::{PyDataFrame, PyExpr, PyExprRef}; + use crate::python::rspython::builtins::{ + all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj, + type_cast_error, + }; + use crate::python::rspython::dataframe_impl::data_frame::{PyExpr, PyExprRef}; use crate::python::rspython::utils::{ is_instance, py_obj_to_value, py_obj_to_vec, PyVectorRef, }; + #[pyattr] + #[pyclass(module = "greptime_builtin", name = "PyDataFrame")] + #[derive(PyPayload, Debug, Clone)] + pub struct PyDataFrame { + pub inner: DfDataFrame, + } + /// get `__dataframe__` from globals and return it /// TODO(discord9): this is a terrible hack, we should find a better way to get `__dataframe__` #[pyfunction] @@ -327,9 +337,9 @@ pub(crate) mod greptime_builtin { } /// get `__query__` from globals and return it - /// TODO(discord9): this is a terrible hack, we should find a better way to get `__dataframe__` + /// TODO(discord9): this is a terrible hack, we should find a better way to get `__query__` #[pyfunction] - fn query(vm: &VirtualMachine) -> PyResult { + pub(crate) fn query(vm: &VirtualMachine) -> PyResult { let query_engine = vm.current_globals().get_item("__query__", vm)?; let query_engine = query_engine.payload::().ok_or_else(|| { vm.new_type_error(format!("object {:?} is not a QueryEngine", query_engine)) diff --git a/src/script/src/python/rspython/dataframe_impl.rs b/src/script/src/python/rspython/dataframe_impl.rs index 0884a9ad16..381ece78ef 100644 --- a/src/script/src/python/rspython/dataframe_impl.rs +++ b/src/script/src/python/rspython/dataframe_impl.rs @@ -14,8 +14,10 @@ use rustpython_vm::class::PyClassImpl; use rustpython_vm::{pymodule as rspymodule, VirtualMachine}; + +use crate::python::rspython::builtins::greptime_builtin::PyDataFrame; pub(crate) fn init_data_frame(module_name: &str, vm: &mut VirtualMachine) { - data_frame::PyDataFrame::make_class(&vm.ctx); + PyDataFrame::make_class(&vm.ctx); data_frame::PyExpr::make_class(&vm.ctx); vm.add_native_module(module_name.to_owned(), Box::new(data_frame::make_module)); } @@ -24,6 +26,7 @@ pub(crate) fn init_data_frame(module_name: &str, vm: &mut VirtualMachine) { pub(crate) mod data_frame { use common_recordbatch::{DfRecordBatch, RecordBatch}; use datafusion::dataframe::DataFrame as DfDataFrame; + use datafusion::execution::context::SessionContext; use datafusion_expr::Expr as DfExpr; use rustpython_vm::builtins::{PyList, PyListRef}; use rustpython_vm::function::PyComparisonValue; @@ -35,13 +38,10 @@ pub(crate) mod data_frame { use crate::python::error::DataFusionSnafu; use crate::python::ffi_types::PyVector; - use crate::python::rspython::builtins::greptime_builtin::lit; + use crate::python::rspython::builtins::greptime_builtin::{ + lit, query as get_query_engine, PyDataFrame, + }; use crate::python::utils::block_on_async; - #[rspyclass(module = "data_frame", name = "DataFrame")] - #[derive(PyPayload, Debug, Clone)] - pub struct PyDataFrame { - pub inner: DfDataFrame, - } impl From for PyDataFrame { fn from(inner: DfDataFrame) -> Self { @@ -63,9 +63,20 @@ pub(crate) mod data_frame { } #[rspyclass] impl PyDataFrame { + #[pymethod] + fn from_sql(sql: String, vm: &VirtualMachine) -> PyResult { + let query_engine = get_query_engine(vm)?; + let rb = query_engine.sql_to_rb(sql.clone()).map_err(|e| { + vm.new_runtime_error(format!("failed to execute sql: {:?}, error: {:?}", sql, e)) + })?; + let ctx = SessionContext::new(); + ctx.read_batch(rb.df_record_batch().clone()) + .map_err(|e| vm.new_runtime_error(format!("{e:?}"))) + .map(|df| df.into()) + } /// TODO(discord9): error handling fn from_record_batch(rb: &DfRecordBatch) -> crate::python::error::Result { - let ctx = datafusion::execution::context::SessionContext::new(); + let ctx = SessionContext::new(); let inner = ctx.read_batch(rb.clone()).context(DataFusionSnafu)?; Ok(Self { inner }) }