diff --git a/src/script/src/lib.rs b/src/script/src/lib.rs
index f53b365e1d..9f88abcc5d 100644
--- a/src/script/src/lib.rs
+++ b/src/script/src/lib.rs
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 // TODO(discord9): spawn a new process for executing python scripts (if the GIL limit is hit) and use shared memory to communicate
+#![deny(clippy::implicit_clone)]
+
 pub mod engine;
 pub mod error;
 #[cfg(feature = "python")]
diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs
index 1c497e9c55..1e51c6141b 100644
--- a/src/script/src/python/engine.rs
+++ b/src/script/src/python/engine.rs
@@ -439,8 +439,8 @@
 from greptime import col
 
 @copr(args=["number"], returns = ["number"], sql = "select * from numbers")
 def test(number) -> vector[u32]:
-    from greptime import dataframe
-    return dataframe().filter(col("number")==col("number")).collect()[0][0]
+    from greptime import PyDataFrame
+    return PyDataFrame.from_sql("select * from numbers").filter(col("number")==col("number")).collect()[0][0]
 "#;
         let script = script_engine
             .compile(script, CompileContext::default())
diff --git a/src/script/src/python/ffi_types.rs b/src/script/src/python/ffi_types.rs
index 6506585e23..5240e4e403 100644
--- a/src/script/src/python/ffi_types.rs
+++ b/src/script/src/python/ffi_types.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 pub(crate) mod copr;
+pub(crate) mod py_recordbatch;
 pub(crate) mod utils;
 pub(crate) mod vector;
 pub(crate) use copr::{check_args_anno_real_type, select_from_rb, Coprocessor};
diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs
index ce9fc85ad9..fe970ec307 100644
--- a/src/script/src/python/ffi_types/copr.rs
+++ b/src/script/src/python/ffi_types/copr.rs
@@ -36,10 +36,10 @@
 use rustpython_vm as vm;
 use serde::Deserialize;
 use session::context::QueryContext;
 use snafu::{OptionExt, ResultExt};
-use vm::builtins::{PyList, PyListRef};
 use vm::convert::ToPyObject;
-use vm::{pyclass as rspyclass, PyPayload, PyResult, VirtualMachine};
+use vm::{pyclass as rspyclass, PyObjectRef, PyPayload, PyResult, VirtualMachine};
 
+use super::py_recordbatch::PyRecordBatch;
 use crate::python::error::{ensure, ArrowSnafu, OtherSnafu, Result, TypeCastSnafu};
 use crate::python::ffi_types::PyVector;
 #[cfg(feature = "pyo3_backend")]
@@ -410,9 +410,11 @@ impl PyQueryEngine {
             .map_err(|e| format!("Dedicated thread for sql query panic: {e:?}"))?
     }
 
     // TODO(discord9): find a better way to call the sql query api; right now we don't know whether we are in an async context or not
-    /// return sql query results in List[PyVector], or List[usize] for AffectedRows number if no recordbatches is returned
+    /// - return sql query results in a `PyRecordBatch`, or
+    /// - an empty `PyDict` if the query result is empty, or
+    /// - the number of affected rows
     #[pymethod]
-    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
+    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
         self.query_with_new_thread(s)
             .map_err(|e| vm.new_system_error(e))
             .map(|rbs| match rbs {
@@ -428,17 +430,11 @@
                     RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
                         vm.new_runtime_error(format!("Failed to cast recordbatch: {e:#?}"))
                     })?;
-                    let columns_vectors = rb
-                        .columns()
-                        .iter()
-                        .map(|v| PyVector::from(v.clone()).to_pyobject(vm))
-                        .collect::<Vec<_>>();
-                    Ok(PyList::new_ref(columns_vectors, vm.as_ref()))
+                    let rb = PyRecordBatch::new(rb);
+
+                    Ok(rb.to_pyobject(vm))
                 }
-                Either::AffectedRows(cnt) => Ok(PyList::new_ref(
-                    vec![vm.ctx.new_int(cnt).into()],
-                    vm.as_ref(),
-                )),
+                Either::AffectedRows(cnt) => Ok(vm.ctx.new_int(cnt).to_pyobject(vm)),
             })?
     }
 }
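With this change, `query().sql()` no longer returns `List[PyVector]`; a script gets a `PyRecordBatch` (or a plain int for affected rows). A hypothetical coprocessor illustrating the new contract (a sketch, not part of the patch; the function name is made up, and it assumes the demo `numbers` table used throughout the tests):

```python
@copr(args=[], returns=["number"], sql="select * from numbers", backend="rspy")
def sql_contract() -> vector[u32]:
    from greptime import query
    rb = query().sql("select number from numbers limit 3")
    # rb is a PyRecordBatch: len() is the row count, columns are subscriptable
    assert len(rb) == 3
    # DDL/DML statements would instead return the affected row count as an int
    return rb["number"]
```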
diff --git a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs
index fbcf261074..8902fcd7c9 100644
--- a/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs
+++ b/src/script/src/python/ffi_types/pair_tests/sample_testcases.rs
@@ -414,16 +414,16 @@ def boolean_array() -> vector[f64]:
 @copr(returns=["value"], backend="pyo3")
 def boolean_array() -> vector[f64]:
     from greptime import vector
-    from greptime import query, dataframe
+    from greptime import query
     print("query()=", query())
     assert "query_engine object at" in str(query())
-    try:
-        print("dataframe()=", dataframe())
-    except KeyError as e:
-        print("dataframe()=", e)
-        print(str(e), type(str(e)), 'No __dataframe__' in str(e))
-        assert 'No __dataframe__' in str(e)
+    rb = query().sql(
+        "select number from numbers limit 5"
+    )
+    print(rb)
+    assert len(rb) == 5
+
     v = vector([1.0, 2.0, 3.0])
     # This returns a vector([2.0])
@@ -437,24 +437,25 @@ def boolean_array() -> vector[f64]:
 @copr(returns=["value"], backend="rspy")
 def boolean_array() -> vector[f64]:
     from greptime import vector, col
-    from greptime import query, dataframe, PyDataFrame
+    from greptime import query, PyDataFrame
     df = PyDataFrame.from_sql("select number from numbers limit 5")
     print("df from sql=", df)
     collected = df.collect()
     print("df.collect()=", collected)
-    assert len(collected[0][0]) == 5
+    assert len(collected[0]) == 5
     df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
     collected = df.collect()
-    assert len(collected[0][0]) == 2
+    assert len(collected[0]) == 2
+    assert collected[0] == collected["number"]
     print("query()=", query())
     assert "query_engine object at" in repr(query())
-    try:
-        print("dataframe()=", dataframe())
-    except KeyError as e:
-        print("dataframe()=", e)
-        assert "__dataframe__" in str(e)
+    rb = query().sql(
+        "select number from numbers limit 5"
+    )
+    print(rb)
+    assert len(rb) == 5
     v = vector([1.0, 2.0, 3.0])
     # This returns a vector([2.0])
@@ -469,16 +470,17 @@ def boolean_array() -> vector[f64]:
 @copr(returns=["value"], backend="pyo3")
 def boolean_array() -> vector[f64]:
     from greptime import vector
-    from greptime import query, dataframe, PyDataFrame, col
+    from greptime import query, PyDataFrame, col
     df = PyDataFrame.from_sql("select number from numbers limit 5")
     print("df from sql=", df)
     ret = df.collect()
     print("df.collect()=", ret)
-    assert len(ret[0][0]) == 5
+    assert len(ret[0]) == 5
     df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
     collected = df.collect()
-    assert len(collected[0][0]) == 2
-    return ret[0][0]
+    assert len(collected[0]) == 2
+    assert collected[0] == collected["number"]
+    return ret[0]
 "#
     .to_string(),
         expect: Some(ronish!("value": vector!(UInt32Vector, [0, 1, 2, 3, 4]))),
@@ -546,6 +548,12 @@ def answer() -> vector[i64]:
 @copr(returns=["value"], backend="pyo3")
 def answer() -> vector[i64]:
     from greptime import vector
+    try:
+        import pyarrow as pa
+    except ImportError:
+        # the current Python has no pyarrow installed
+        print("Warning: no pyarrow in current python")
+        return vector([42, 43, 44])
     return vector.from_pyarrow(vector([42, 43, 44]).to_pyarrow())
 "#
 .to_string(),
@@ -557,7 +565,12 @@ def answer() -> vector[i64]:
 @copr(returns=["value"], backend="pyo3")
 def answer() -> vector[i64]:
     from greptime import vector
-    import pyarrow as pa
+    try:
+        import pyarrow as pa
+    except ImportError:
+        # the current Python has no pyarrow installed
+        print("Warning: no pyarrow in current python")
+        return vector([42, 43, 44])
     return vector.from_pyarrow(pa.array([42, 43, 44]))
 "#
 .to_string(),
@@ -567,9 +580,9 @@
     script: r#"
 @copr(args=[], returns = ["number"], sql = "select * from numbers", backend="rspy")
 def answer() -> vector[i64]:
-    from greptime import vector, col, lit, dataframe
+    from greptime import vector, col, lit, PyDataFrame
     expr_0 = (col("number")<lit(3)) & (col("number")>0)
-    ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
+    ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
     return ret
 "#
 .to_string(),
@@ -580,10 +593,10 @@
     script: r#"
 @copr(args=[], returns = ["number"], sql = "select * from numbers", backend="pyo3")
 def answer() -> vector[i64]:
-    from greptime import vector, col, lit, dataframe
+    from greptime import vector, col, lit, PyDataFrame
     # bitwise operators bind more tightly than comparison operators, hence the parentheses
     expr_0 = (col("number")<lit(3)) & (col("number")>0)
-    ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
+    ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
     return ret
 "#
 .to_string(),
@@ -595,8 +608,13 @@ def answer() -> vector[i64]:
 @copr(returns=["value"], backend="pyo3")
 def answer() -> vector[i64]:
     from greptime import vector
-    import pyarrow as pa
-    a = vector.from_pyarrow(pa.array([42, 43, 44]))
+    try:
+        import pyarrow as pa
+    except ImportError:
+        # the current Python has no pyarrow installed
+        print("Warning: no pyarrow in current python")
+        return vector([42, 43, 44])
+    a = vector.from_pyarrow(pa.array([42]))
     return a[0:1]
 "#
 .to_string(),
@@ -691,6 +709,36 @@ def normalize0(x):
 
 def normalize(v) -> vector[i64]:
     return [normalize0(x) for x in v]
+"#
+    .to_string(),
+        expect: Some(ronish!(
+            "value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
+        )),
+    },
+    #[cfg(feature = "pyo3_backend")]
+    CoprTestCase {
+        script: r#"
+import math
+
+@coprocessor(args=[], returns=["value"], backend="pyo3")
+def test_numpy() -> vector[i64]:
+    from greptime import vector
+    try:
+        import numpy as np
+        import pyarrow as pa
+    except ImportError as e:
+        # the current Python has no numpy or pyarrow installed
+        print("Warning: no pyarrow or numpy found in current python", e)
+        return vector([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
+    v = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
+    v = pa.array(v)
+    v = vector.from_pyarrow(v)
+    v = vector.from_numpy(v.numpy())
+    v = v.to_pyarrow()
+    v = v.to_numpy()
+    v = vector.from_numpy(v)
+    return v
 "#
 .to_string(),
 expect: Some(ronish!(
diff --git a/src/script/src/python/ffi_types/py_recordbatch.rs b/src/script/src/python/ffi_types/py_recordbatch.rs
new file mode 100644
index 0000000000..74d9208dbe
--- /dev/null
+++ b/src/script/src/python/ffi_types/py_recordbatch.rs
@@ -0,0 +1,137 @@
+// Copyright 2023 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! PyRecordBatch is a Python class that wraps a RecordBatch
+//! and provides the PyMapping protocol to
+//! access the columns of the RecordBatch.
+
+use common_recordbatch::RecordBatch;
+use crossbeam_utils::atomic::AtomicCell;
+#[cfg(feature = "pyo3_backend")]
+use pyo3::{
+    exceptions::{PyKeyError, PyRuntimeError},
+    pyclass as pyo3class, pymethods, PyObject, PyResult, Python,
+};
+use rustpython_vm::builtins::PyStr;
+use rustpython_vm::protocol::PyMappingMethods;
+use rustpython_vm::types::AsMapping;
+use rustpython_vm::{
+    atomic_func, pyclass as rspyclass, PyObject as RsPyObject, PyPayload, PyResult as RsPyResult,
+    VirtualMachine,
+};
+
+use crate::python::ffi_types::PyVector;
+
+/// A wrapper around a RecordBatch that implements the PyMapping protocol, so both `a[0]` and `a["number"]` retrieve a column.
+#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "PyRecordBatch"))]
+#[rspyclass(module = false, name = "PyRecordBatch")]
+#[derive(Debug, PyPayload)]
+pub(crate) struct PyRecordBatch {
+    record_batch: RecordBatch,
+}
+
+impl PyRecordBatch {
+    pub fn new(record_batch: RecordBatch) -> Self {
+        Self { record_batch }
+    }
+}
+
+impl From<RecordBatch> for PyRecordBatch {
+    fn from(record_batch: RecordBatch) -> Self {
+        Self::new(record_batch)
+    }
+}
+
+#[cfg(feature = "pyo3_backend")]
+#[pymethods]
+impl PyRecordBatch {
+    fn __repr__(&self) -> String {
+        // TODO(discord9): a better pretty print
+        format!("{:#?}", &self.record_batch.df_record_batch())
+    }
+
+    fn __getitem__(&self, py: Python, key: PyObject) -> PyResult<PyVector> {
+        let column = if let Ok(key) = key.extract::<String>(py) {
+            self.record_batch.column_by_name(&key)
+        } else if let Ok(key) = key.extract::<usize>(py) {
+            Some(self.record_batch.column(key))
+        } else {
+            return Err(PyRuntimeError::new_err(format!(
+                "Expect either str or int, found {key:?}"
+            )));
+        }
+        .ok_or_else(|| PyKeyError::new_err(format!("Column {} not found", key)))?;
+        let v = PyVector::from(column.clone());
+        Ok(v)
+    }
+
+    fn __iter__(&self) -> PyResult<Vec<PyVector>> {
+        let iter: Vec<_> = self
+            .record_batch
+            .columns()
+            .iter()
+            .map(|i| PyVector::from(i.clone()))
+            .collect();
+        Ok(iter)
+    }
+
+    fn __len__(&self) -> PyResult<usize> {
+        Ok(self.len())
+    }
+}
+
+impl PyRecordBatch {
+    fn len(&self) -> usize {
+        self.record_batch.num_rows()
+    }
+
+    fn get_item(&self, needle: &RsPyObject, vm: &VirtualMachine) -> RsPyResult {
+        if let Ok(index) = needle.try_to_value::<usize>(vm) {
+            let column = self.record_batch.column(index);
+            let v = PyVector::from(column.clone());
+            Ok(v.into_pyobject(vm))
+        } else if let Ok(index) = needle.try_to_value::<String>(vm) {
+            let key = index.as_str();
+
+            let v = self.record_batch.column_by_name(key).ok_or_else(|| {
+                vm.new_key_error(PyStr::from(format!("Column {} not found", key)).into_pyobject(vm))
+            })?;
+            let v: PyVector = v.clone().into();
+            Ok(v.into_pyobject(vm))
+        } else {
+            Err(vm.new_key_error(
+                PyStr::from(format!("Expect either str or int, found {needle:?}"))
+                    .into_pyobject(vm),
+            ))
+        }
+    }
+}
+
+#[rspyclass(with(AsMapping))]
+impl PyRecordBatch {
+    #[pymethod(name = "__repr__")]
+    fn rspy_repr(&self) -> String {
+        format!("{:#?}", &self.record_batch.df_record_batch())
+    }
+}
+
+impl AsMapping for PyRecordBatch {
+    fn as_mapping() -> &'static PyMappingMethods {
+        static AS_MAPPING: PyMappingMethods = PyMappingMethods {
+            length: atomic_func!(|mapping, _vm| Ok(PyRecordBatch::mapping_downcast(mapping).len())),
+            subscript: atomic_func!(
+                |mapping, needle, vm| PyRecordBatch::mapping_downcast(mapping).get_item(needle, vm)
+            ),
+            ass_subscript: AtomicCell::new(None),
+        };
+        &AS_MAPPING
+    }
+}
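Since the mapping protocol is wired up on both backends, subscripting a `PyRecordBatch` accepts an int (column position) or a str (column name); anything else raises. A usage sketch (hypothetical script, assuming the demo `numbers` table):

```python
from greptime import query

rb = query().sql("select number from numbers limit 5")
assert len(rb) == 5        # __len__ reports the row count
by_pos = rb[0]             # subscript with int: column by position
by_name = rb["number"]     # subscript with str: column by name
for column in rb:          # iteration yields each column as a vector (pyo3 backend)
    print(column)
try:
    rb[1.5]                # neither str nor int
except Exception as e:
    print("expected error:", e)
```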
diff --git a/src/script/src/python/pyo3/builtins.rs b/src/script/src/python/pyo3/builtins.rs
index 8d9609421c..b678d86008 100644
--- a/src/script/src/python/pyo3/builtins.rs
+++ b/src/script/src/python/pyo3/builtins.rs
@@ -126,6 +126,8 @@ fn get_globals(py: Python) -> PyResult<&PyDict> {
     Ok(globals)
 }
 
+/// This function is still useful even though `PyDataFrame.from_sql()` exists:
+/// it saves repeating the same sql statement in the script.
 #[pyfunction]
 fn dataframe(py: Python) -> PyResult<PyDataFrame> {
     let globals = get_globals(py)?;
diff --git a/src/script/src/python/pyo3/copr_impl.rs b/src/script/src/python/pyo3/copr_impl.rs
index 66e9f7fb05..6f7dd8a46e 100644
--- a/src/script/src/python/pyo3/copr_impl.rs
+++ b/src/script/src/python/pyo3/copr_impl.rs
@@ -14,6 +14,7 @@
 use std::collections::HashMap;
 
+use arrow::compute;
 use common_recordbatch::RecordBatch;
 use common_telemetry::timer;
 use datafusion_common::ScalarValue;
@@ -21,11 +22,12 @@
 use datatypes::prelude::ConcreteDataType;
 use datatypes::vectors::{Helper, VectorRef};
 use pyo3::exceptions::{PyRuntimeError, PyValueError};
 use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple};
-use pyo3::{pymethods, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
+use pyo3::{pymethods, IntoPy, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
 use snafu::{ensure, Location, ResultExt};
 
 use crate::python::error::{self, NewRecordBatchSnafu, OtherSnafu, Result};
 use crate::python::ffi_types::copr::PyQueryEngine;
+use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
 use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
 use crate::python::metric;
 use crate::python::pyo3::dataframe_impl::PyDataFrame;
@@ -36,23 +38,22 @@ impl PyQueryEngine {
     #[pyo3(name = "sql")]
     pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult<PyObject> {
         let res = self
-            .query_with_new_thread(s)
+            .query_with_new_thread(s.clone())
             .map_err(PyValueError::new_err)?;
         match res {
             crate::python::ffi_types::copr::Either::Rb(rbs) => {
-                let mut top_vec = Vec::with_capacity(rbs.iter().count());
-                for rb in rbs.iter() {
-                    let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
-                    for v in rb.columns() {
-                        let v = PyVector::from(v.clone());
-                        let v = PyCell::new(py, v)?;
-                        vec_of_vec.push(v.to_object(py));
-                    }
-                    let vec_of_vec = PyList::new(py, vec_of_vec);
-                    top_vec.push(vec_of_vec);
-                }
-                let top_vec = PyList::new(py, top_vec);
-                Ok(top_vec.to_object(py))
+                let rb = compute::concat_batches(
+                    rbs.schema().arrow_schema(),
+                    rbs.iter().map(|rb| rb.df_record_batch()),
+                )
+                .map_err(|e| PyRuntimeError::new_err(format!("{e:?}")))?;
+                let rb = RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
+                    PyRuntimeError::new_err(format!(
+                        "Convert datafusion record batch to record batch failed for query {s}: {e}"
+                    ))
+                })?;
+                let rb = PyRecordBatch::new(rb);
+                Ok(rb.into_py(py))
             }
             crate::python::ffi_types::copr::Either::AffectedRows(count) => Ok(count.to_object(py)),
         }
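On the pyo3 path, multiple record batches are now concatenated via `concat_batches`, so a script always sees one `PyRecordBatch` whose length is the total row count, while DML still comes back as a plain int. A sketch of what that means for a script (`some_table` is made up for illustration):

```python
from greptime import query

rb = query().sql("select number from numbers")   # may span several batches internally
rows = len(rb)                                   # one concatenated PyRecordBatch

n = query().sql("insert into some_table values (1)")
assert isinstance(n, int)                        # affected-rows count, not a batch
```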
diff --git a/src/script/src/python/pyo3/dataframe_impl.rs b/src/script/src/python/pyo3/dataframe_impl.rs
index e15ebc5f31..7b1adeb239 100644
--- a/src/script/src/python/pyo3/dataframe_impl.rs
+++ b/src/script/src/python/pyo3/dataframe_impl.rs
@@ -12,17 +12,19 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use common_recordbatch::DfRecordBatch;
+use arrow::compute;
+use common_recordbatch::{DfRecordBatch, RecordBatch};
 use datafusion::dataframe::DataFrame as DfDataFrame;
 use datafusion_expr::Expr as DfExpr;
+use datatypes::schema::Schema;
 use pyo3::exceptions::{PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pyclass::CompareOp;
-use pyo3::types::{PyList, PyType};
+use pyo3::types::{PyDict, PyType};
 use snafu::ResultExt;
 
 use crate::python::error::DataFusionSnafu;
-use crate::python::ffi_types::PyVector;
+use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
 use crate::python::pyo3::builtins::query_engine;
 use crate::python::pyo3::utils::pyo3_obj_try_to_typed_scalar_value;
 use crate::python::utils::block_on_async;
@@ -210,32 +212,32 @@
             .map_err(|e| PyValueError::new_err(e.to_string()))?
             .into())
     }
 
-    /// collect `DataFrame` results into `List[List[Vector]]`
-    fn collect<'a>(&self, py: Python<'a>) -> PyResult<&'a PyList> {
+    /// collect `DataFrame` results into a `PyRecordBatch` that implements the mapping protocol
+    fn collect(&self, py: Python) -> PyResult<PyObject> {
         let inner = self.inner.clone();
         let res = block_on_async(async { inner.collect().await });
         let res = res
             .map_err(|e| PyValueError::new_err(format!("{e:?}")))?
             .map_err(|e| PyValueError::new_err(e.to_string()))?;
-        let outer_list: Vec<PyObject> = res
-            .iter()
-            .map(|elem| -> PyResult<_> {
-                let inner_list: Vec<_> = elem
-                    .columns()
-                    .iter()
-                    .map(|arr| -> PyResult<_> {
-                        datatypes::vectors::Helper::try_into_vector(arr)
-                            .map(PyVector::from)
-                            .map(|v| PyCell::new(py, v))
-                            .map_err(|e| PyValueError::new_err(e.to_string()))
-                            .and_then(|x| x)
-                    })
-                    .collect::<PyResult<_>>()?;
-                let inner_list = PyList::new(py, inner_list);
-                Ok(inner_list.into())
-            })
-            .collect::<PyResult<_>>()?;
-        Ok(PyList::new(py, outer_list))
+        if res.is_empty() {
+            return Ok(PyDict::new(py).into());
+        }
+        let concat_rb = compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
+            PyRuntimeError::new_err(format!("Concat batches failed for dataframe {self:?}: {e}"))
+        })?;
+
+        let schema = Schema::try_from(concat_rb.schema()).map_err(|e| {
+            PyRuntimeError::new_err(format!(
+                "Convert to Schema failed for dataframe {self:?}: {e}"
+            ))
+        })?;
+        let rb = RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
+            PyRuntimeError::new_err(format!(
+                "Convert to RecordBatch failed for dataframe {self:?}: {e}"
+            ))
+        })?;
+        let rb = PyRecordBatch::new(rb);
+        Ok(rb.into_py(py))
     }
 }
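One caveat of the new `collect()`: an empty result set comes back as an empty dict rather than a `PyRecordBatch`, so a defensive script may want to check for that. A sketch (hypothetical query, assuming the demo `numbers` table):

```python
from greptime import PyDataFrame

res = PyDataFrame.from_sql("select number from numbers where number < 0").collect()
if len(res) == 0:
    # empty query results collect to an empty dict on both backends
    print("no rows")
else:
    print(res["number"])
```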
diff --git a/src/script/src/python/pyo3/vector_impl.rs b/src/script/src/python/pyo3/vector_impl.rs
index ffbb570d60..c6b1119e39 100644
--- a/src/script/src/python/pyo3/vector_impl.rs
+++ b/src/script/src/python/pyo3/vector_impl.rs
@@ -24,7 +24,7 @@
 use datatypes::vectors::Helper;
 use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pyclass::CompareOp;
-use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString, PyType};
+use pyo3::types::{PyBool, PyFloat, PyInt, PySequence, PySlice, PyString, PyType};
 
 use super::utils::val_to_py_any;
 use crate::python::ffi_types::vector::{arrow_rtruediv, wrap_bool_result, wrap_result, PyVector};
@@ -93,12 +93,28 @@ impl PyVector {
 #[pymethods]
 impl PyVector {
+    /// convert a numpy array into a [`PyVector`] (by way of pyarrow)
+    #[classmethod]
+    fn from_numpy(cls: &PyType, py: Python<'_>, obj: PyObject) -> PyResult<PyObject> {
+        let pa = py.import("pyarrow")?;
+        let obj = pa.call_method1("array", (obj,))?;
+        let zelf = Self::from_pyarrow(cls, py, obj.into())?;
+        Ok(zelf.into_py(py))
+    }
+
+    /// convert this vector into a numpy array (by way of pyarrow)
+    fn numpy(&self, py: Python<'_>) -> PyResult<PyObject> {
+        let pa_arrow = self.to_arrow_array().data().to_pyarrow(py)?;
+        let ndarray = pa_arrow.call_method0(py, "to_numpy")?;
+        Ok(ndarray)
+    }
+
     /// create a `PyVector` with a `PyList` that contains only elements of same type
     #[new]
-    pub(crate) fn py_new(iterable: &PyList) -> PyResult<Self> {
+    pub(crate) fn py_new(iterable: PyObject, py: Python<'_>) -> PyResult<Self> {
+        let iterable = iterable.downcast::<PySequence>(py)?;
         let dtype = get_py_type(iterable.get_item(0)?)?;
-        let mut buf = dtype.create_mutable_vector(iterable.len());
-        for i in 0..iterable.len() {
+        let mut buf = dtype.create_mutable_vector(iterable.len()?);
+        for i in 0..iterable.len()? {
             let element = iterable.get_item(i)?;
             let val = pyo3_obj_try_to_typed_val(element, Some(dtype.clone()))?;
             buf.push_value_ref(val.as_value_ref());
diff --git a/src/script/src/python/rspython/copr_impl.rs b/src/script/src/python/rspython/copr_impl.rs
index b042445592..576f0fa112 100644
--- a/src/script/src/python/rspython/copr_impl.rs
+++ b/src/script/src/python/rspython/copr_impl.rs
@@ -29,6 +29,7 @@
 use snafu::{OptionExt, ResultExt};
 
 use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
 use crate::python::ffi_types::copr::PyQueryEngine;
+use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
 use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
 use crate::python::metric;
 use crate::python::rspython::builtins::init_greptime_builtins;
@@ -216,6 +217,7 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
         // add our own custom datatype and module
         PyVector::make_class(&vm.ctx);
         PyQueryEngine::make_class(&vm.ctx);
+        PyRecordBatch::make_class(&vm.ctx);
         init_greptime_builtins("greptime", vm);
         init_data_frame("data_frame", vm);
     }));
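The new `from_numpy`/`numpy` pair on the pyo3 backend round-trips through pyarrow. A sketch of the conversion chain (assumes numpy and pyarrow are installed in the embedded Python):

```python
import numpy as np
from greptime import vector

v = vector.from_numpy(np.array([1, 2, 3]))  # numpy -> pyarrow -> vector
arr = v.numpy()                             # vector -> pyarrow -> numpy
assert (arr == np.array([1, 2, 3])).all()
```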
diff --git a/src/script/src/python/rspython/dataframe_impl.rs b/src/script/src/python/rspython/dataframe_impl.rs
index 381ece78ef..e269db409a 100644
--- a/src/script/src/python/rspython/dataframe_impl.rs
+++ b/src/script/src/python/rspython/dataframe_impl.rs
@@ -28,7 +28,6 @@ pub(crate) mod data_frame {
     use datafusion::dataframe::DataFrame as DfDataFrame;
     use datafusion::execution::context::SessionContext;
     use datafusion_expr::Expr as DfExpr;
-    use rustpython_vm::builtins::{PyList, PyListRef};
     use rustpython_vm::function::PyComparisonValue;
     use rustpython_vm::types::{Comparable, PyComparisonOp};
     use rustpython_vm::{
@@ -37,7 +36,7 @@
     use snafu::ResultExt;
 
     use crate::python::error::DataFusionSnafu;
-    use crate::python::ffi_types::PyVector;
+    use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
     use crate::python::rspython::builtins::greptime_builtin::{
         lit, query as get_query_engine, PyDataFrame,
     };
@@ -235,31 +234,38 @@
     }
 
     #[pymethod]
-    /// collect `DataFrame` results into `List[List[Vector]]`
-    fn collect(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
+    /// collect `DataFrame` results into a `PyRecordBatch` that implements the mapping protocol
+    fn collect(&self, vm: &VirtualMachine) -> PyResult {
         let inner = self.inner.clone();
         let res = block_on_async(async { inner.collect().await });
         let res = res
             .map_err(|e| vm.new_runtime_error(format!("{e:?}")))?
             .map_err(|e| vm.new_runtime_error(e.to_string()))?;
-        let outer_list: Vec<_> = res
-            .iter()
-            .map(|elem| -> PyResult<_> {
-                let inner_list: Vec<_> = elem
-                    .columns()
-                    .iter()
-                    .map(|arr| -> PyResult<_> {
-                        datatypes::vectors::Helper::try_into_vector(arr)
-                            .map(PyVector::from)
-                            .map(|v| vm.new_pyobj(v))
-                            .map_err(|e| vm.new_runtime_error(e.to_string()))
-                    })
-                    .collect::<PyResult<_>>()?;
-                let inner_list = PyList::new_ref(inner_list, vm.as_ref());
-                Ok(inner_list.into())
-            })
-            .collect::<PyResult<_>>()?;
-        Ok(PyList::new_ref(outer_list, vm.as_ref()))
+        if res.is_empty() {
+            return Ok(vm.ctx.new_dict().into());
+        }
+        let concat_rb =
+            arrow::compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
+                vm.new_runtime_error(format!(
+                    "Concat batches failed for dataframe {self:?}: {e}"
+                ))
+            })?;
+
+        // we are inside a macro, so use full paths here
+        let schema = datatypes::schema::Schema::try_from(concat_rb.schema()).map_err(|e| {
+            vm.new_runtime_error(format!(
+                "Convert to Schema failed for dataframe {self:?}: {e}"
+            ))
+        })?;
+        let rb =
+            RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
+                vm.new_runtime_error(format!(
+                    "Convert to RecordBatch failed for dataframe {self:?}: {e}"
+                ))
+            })?;
+
+        let rb = PyRecordBatch::new(rb);
+        Ok(rb.into_pyobject(vm))
    }
}
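With both backends aligned on `PyRecordBatch`, the same script should behave identically under `backend="rspy"` and `backend="pyo3"`. A closing sketch (hypothetical coprocessor name, assuming the demo `numbers` table) that relies only on the shared surface:

```python
@copr(args=[], returns=["number"], sql="select * from numbers", backend="rspy")
def shared_surface() -> vector[u32]:
    from greptime import PyDataFrame, col
    rb = PyDataFrame.from_sql("select * from numbers").filter(col("number") < 3).collect()
    assert rb[0] == rb["number"]   # positional and named access agree
    return rb["number"]
```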