feat: from/to numpy&collect concat (#1339)

* feat: from/to numpy&collect concat

* feat: PyRecordBatch

* test: try import first,allow w/out numpy/pyarrow

* fix: cond compile flag

* doc: license

* feat: sql() ret PyRecordBatch&repr

* fix: after merge

* style: fmt

* chore: CR advices

* docs: update

* chore: resolve conflict
This commit is contained in:
discord9
2023-04-13 10:46:25 +08:00
committed by GitHub
parent 33dbf7264f
commit c20dbda598
12 changed files with 320 additions and 107 deletions

View File

@@ -13,6 +13,8 @@
// limitations under the License.
// TODO(discord9): spawn new process for executing python script(if hit gil limit) and use shared memory to communicate
#![deny(clippy::implicit_clone)]
pub mod engine;
pub mod error;
#[cfg(feature = "python")]

View File

@@ -439,8 +439,8 @@ from greptime import col
@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number) -> vector[u32]:
from greptime import dataframe
return dataframe().filter(col("number")==col("number")).collect()[0][0]
from greptime import PyDataFrame
return PyDataFrame.from_sql("select * from numbers").filter(col("number")==col("number")).collect()[0][0]
"#;
let script = script_engine
.compile(script, CompileContext::default())

View File

@@ -13,6 +13,7 @@
// limitations under the License.
pub(crate) mod copr;
pub(crate) mod py_recordbatch;
pub(crate) mod utils;
pub(crate) mod vector;
pub(crate) use copr::{check_args_anno_real_type, select_from_rb, Coprocessor};

View File

@@ -36,10 +36,10 @@ use rustpython_vm as vm;
use serde::Deserialize;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use vm::builtins::{PyList, PyListRef};
use vm::convert::ToPyObject;
use vm::{pyclass as rspyclass, PyPayload, PyResult, VirtualMachine};
use vm::{pyclass as rspyclass, PyObjectRef, PyPayload, PyResult, VirtualMachine};
use super::py_recordbatch::PyRecordBatch;
use crate::python::error::{ensure, ArrowSnafu, OtherSnafu, Result, TypeCastSnafu};
use crate::python::ffi_types::PyVector;
#[cfg(feature = "pyo3_backend")]
@@ -410,9 +410,11 @@ impl PyQueryEngine {
.map_err(|e| format!("Dedicated thread for sql query panic: {e:?}"))?
}
// TODO(discord9): find a better way to call the sql query api; for now we don't know whether we are in an async context or not
/// return sql query results in List[PyVector], or List[usize] for AffectedRows number if no recordbatches is returned
/// - return sql query results in `PyRecordBatch`, or
/// - an empty `PyDict` if the query result is empty
/// - or number of AffectedRows
#[pymethod]
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
self.query_with_new_thread(s)
.map_err(|e| vm.new_system_error(e))
.map(|rbs| match rbs {
@@ -428,17 +430,11 @@ impl PyQueryEngine {
RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
vm.new_runtime_error(format!("Failed to cast recordbatch: {e:#?}"))
})?;
let columns_vectors = rb
.columns()
.iter()
.map(|v| PyVector::from(v.clone()).to_pyobject(vm))
.collect::<Vec<_>>();
Ok(PyList::new_ref(columns_vectors, vm.as_ref()))
let rb = PyRecordBatch::new(rb);
Ok(rb.to_pyobject(vm))
}
Either::AffectedRows(cnt) => Ok(PyList::new_ref(
vec![vm.ctx.new_int(cnt).into()],
vm.as_ref(),
)),
Either::AffectedRows(cnt) => Ok(vm.ctx.new_int(cnt).to_pyobject(vm)),
})?
}
}

View File

@@ -414,16 +414,16 @@ def boolean_array() -> vector[f64]:
@copr(returns=["value"], backend="pyo3")
def boolean_array() -> vector[f64]:
from greptime import vector
from greptime import query, dataframe
from greptime import query
print("query()=", query())
assert "query_engine object at" in str(query())
try:
print("dataframe()=", dataframe())
except KeyError as e:
print("dataframe()=", e)
print(str(e), type(str(e)), 'No __dataframe__' in str(e))
assert 'No __dataframe__' in str(e)
rb = query().sql(
"select number from numbers limit 5"
)
print(rb)
assert len(rb) == 5
v = vector([1.0, 2.0, 3.0])
# This returns a vector([2.0])
@@ -437,24 +437,25 @@ def boolean_array() -> vector[f64]:
@copr(returns=["value"], backend="rspy")
def boolean_array() -> vector[f64]:
from greptime import vector, col
from greptime import query, dataframe, PyDataFrame
from greptime import query, PyDataFrame
df = PyDataFrame.from_sql("select number from numbers limit 5")
print("df from sql=", df)
collected = df.collect()
print("df.collect()=", collected)
assert len(collected[0][0]) == 5
assert len(collected[0]) == 5
df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
collected = df.collect()
assert len(collected[0][0]) == 2
assert len(collected[0]) == 2
assert collected[0] == collected["number"]
print("query()=", query())
assert "query_engine object at" in repr(query())
try:
print("dataframe()=", dataframe())
except KeyError as e:
print("dataframe()=", e)
assert "__dataframe__" in str(e)
rb = query().sql(
"select number from numbers limit 5"
)
print(rb)
assert len(rb) == 5
v = vector([1.0, 2.0, 3.0])
# This returns a vector([2.0])
@@ -469,16 +470,17 @@ def boolean_array() -> vector[f64]:
@copr(returns=["value"], backend="pyo3")
def boolean_array() -> vector[f64]:
from greptime import vector
from greptime import query, dataframe, PyDataFrame, col
from greptime import query, PyDataFrame, col
df = PyDataFrame.from_sql("select number from numbers limit 5")
print("df from sql=", df)
ret = df.collect()
print("df.collect()=", ret)
assert len(ret[0][0]) == 5
assert len(ret[0]) == 5
df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
collected = df.collect()
assert len(collected[0][0]) == 2
return ret[0][0]
assert len(collected[0]) == 2
assert collected[0] == collected["number"]
return ret[0]
"#
.to_string(),
expect: Some(ronish!("value": vector!(UInt32Vector, [0, 1, 2, 3, 4]))),
@@ -546,6 +548,12 @@ def answer() -> vector[i64]:
@copr(returns=["value"], backend="pyo3")
def answer() -> vector[i64]:
from greptime import vector
try:
import pyarrow as pa
except ImportError:
# Python didn't have pyarrow
print("Warning: no pyarrow in current python")
return vector([42, 43, 44])
return vector.from_pyarrow(vector([42, 43, 44]).to_pyarrow())
"#
.to_string(),
@@ -557,7 +565,12 @@ def answer() -> vector[i64]:
@copr(returns=["value"], backend="pyo3")
def answer() -> vector[i64]:
from greptime import vector
import pyarrow as pa
try:
import pyarrow as pa
except ImportError:
# Python didn't have pyarrow
print("Warning: no pyarrow in current python")
return vector([42, 43, 44])
return vector.from_pyarrow(pa.array([42, 43, 44]))
"#
.to_string(),
@@ -567,9 +580,9 @@ def answer() -> vector[i64]:
script: r#"
@copr(args=[], returns = ["number"], sql = "select * from numbers", backend="rspy")
def answer() -> vector[i64]:
from greptime import vector, col, lit, dataframe
from greptime import vector, col, lit, PyDataFrame
expr_0 = (col("number")<lit(3)) & (col("number")>0)
ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
return ret
"#
.to_string(),
@@ -580,10 +593,10 @@ def answer() -> vector[i64]:
script: r#"
@copr(args=[], returns = ["number"], sql = "select * from numbers", backend="pyo3")
def answer() -> vector[i64]:
from greptime import vector, col, lit, dataframe
from greptime import vector, col, lit, PyDataFrame
# Bitwise Operator pred comparison operator
expr_0 = (col("number")<lit(3)) & (col("number")>0)
ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
return ret
"#
.to_string(),
@@ -595,8 +608,13 @@ def answer() -> vector[i64]:
@copr(returns=["value"], backend="pyo3")
def answer() -> vector[i64]:
from greptime import vector
import pyarrow as pa
a = vector.from_pyarrow(pa.array([42, 43, 44]))
try:
import pyarrow as pa
except ImportError:
# Python didn't have pyarrow
print("Warning: no pyarrow in current python")
return vector([42, 43, 44])
a = vector.from_pyarrow(pa.array([42]))
return a[0:1]
"#
.to_string(),
@@ -691,6 +709,36 @@ def normalize0(x):
def normalize(v) -> vector[i64]:
return [normalize0(x) for x in v]
"#
.to_string(),
expect: Some(ronish!(
"value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
)),
},
#[cfg(feature = "pyo3_backend")]
CoprTestCase {
script: r#"
import math
@coprocessor(args=[], returns=["value"], backend="pyo3")
def test_numpy() -> vector[i64]:
    # Round-trip a vector through numpy and pyarrow; fall back to a plain
    # vector when either package is missing from the host Python.
    #
    # `vector` must be imported before the try block: the import statement
    # makes `vector` a local name of this function, so using it in the
    # `except` branch before the import has run raises UnboundLocalError.
    from greptime import vector
    try:
        import numpy as np
        import pyarrow as pa
    except ImportError as e:
        # Python didn't have numpy or pyarrow
        print("Warning: no pyarrow or numpy found in current python", e)
        return vector([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
    v = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
    v = pa.array(v)
    v = vector.from_pyarrow(v)
    v = vector.from_numpy(v.numpy())
    v = v.to_pyarrow()
    v = v.to_numpy()
    v = vector.from_numpy(v)
    return v
"#
.to_string(),
expect: Some(ronish!(

View File

@@ -0,0 +1,137 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! PyRecordBatch is a Python class that wraps a RecordBatch
//! and provides a PyMapping Protocol to
//! access the columns of the RecordBatch.
use common_recordbatch::RecordBatch;
use crossbeam_utils::atomic::AtomicCell;
#[cfg(feature = "pyo3_backend")]
use pyo3::{
exceptions::{PyKeyError, PyRuntimeError},
pyclass as pyo3class, pymethods, PyObject, PyResult, Python,
};
use rustpython_vm::builtins::PyStr;
use rustpython_vm::protocol::PyMappingMethods;
use rustpython_vm::types::AsMapping;
use rustpython_vm::{
atomic_func, pyclass as rspyclass, PyObject as RsPyObject, PyPayload, PyResult as RsPyResult,
VirtualMachine,
};
use crate::python::ffi_types::PyVector;
/// A wrapper around a `RecordBatch` that is exposed to Python as `PyRecordBatch`.
///
/// It implements the PyMapping protocol, so both `a[0]` and `a["number"]` retrieve a column.
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "PyRecordBatch"))]
#[rspyclass(module = false, name = "PyRecordBatch")]
#[derive(Debug, PyPayload)]
pub(crate) struct PyRecordBatch {
// The wrapped record batch whose columns are served to Python code.
record_batch: RecordBatch,
}
impl PyRecordBatch {
pub fn new(record_batch: RecordBatch) -> Self {
Self { record_batch }
}
}
impl From<RecordBatch> for PyRecordBatch {
fn from(record_batch: RecordBatch) -> Self {
Self::new(record_batch)
}
}
#[cfg(feature = "pyo3_backend")]
#[pymethods]
impl PyRecordBatch {
    /// Debug-format the underlying datafusion record batch.
    fn __repr__(&self) -> String {
        // TODO(discord9): a better pretty print
        format!("{:#?}", &self.record_batch.df_record_batch())
    }

    /// Retrieve a column either by name (`str` key) or by position (`int` key).
    ///
    /// # Errors
    /// - `KeyError` if no column has the given name, or the position is out of range.
    /// - `RuntimeError` if the key can be extracted neither as a string nor as an
    ///   unsigned integer.
    fn __getitem__(&self, py: Python, key: PyObject) -> PyResult<PyVector> {
        let column = if let Ok(name) = key.extract::<String>(py) {
            self.record_batch.column_by_name(&name)
        } else if let Ok(index) = key.extract::<usize>(py) {
            // Checked lookup: a position coming from Python must never be able to
            // index past the end of the column list; a miss becomes a `KeyError` below.
            self.record_batch.columns().get(index)
        } else {
            return Err(PyRuntimeError::new_err(format!(
                "Expect either str or int, found {key:?}"
            )));
        };
        let column =
            column.ok_or_else(|| PyKeyError::new_err(format!("Column {} not found", key)))?;
        Ok(PyVector::from(column.clone()))
    }

    /// Return every column of the batch as a `PyVector`.
    // NOTE(review): this hands Python a list rather than an iterator object;
    // confirm that `iter(record_batch)` is accepted by the interpreter.
    fn __iter__(&self) -> PyResult<Vec<PyVector>> {
        let iter: Vec<_> = self
            .record_batch
            .columns()
            .iter()
            .map(|i| PyVector::from(i.clone()))
            .collect();
        Ok(iter)
    }

    /// Number of rows in the batch (not the number of columns).
    fn __len__(&self) -> PyResult<usize> {
        Ok(self.len())
    }
}
impl PyRecordBatch {
    /// Number of rows in the wrapped record batch.
    fn len(&self) -> usize {
        self.record_batch.num_rows()
    }

    /// Mapping-protocol subscript used by the RustPython backend.
    ///
    /// `needle` is tried as an unsigned integer position first, then as a column name.
    ///
    /// # Errors
    /// Raises `KeyError` when the position is out of range, when no column has the
    /// given name, or when `needle` is neither an integer nor a string.
    fn get_item(&self, needle: &RsPyObject, vm: &VirtualMachine) -> RsPyResult {
        if let Ok(index) = needle.try_to_value::<usize>(vm) {
            // Checked lookup: an out-of-range position from a script is reported as a
            // Python `KeyError` instead of indexing past the end of the column list.
            let column = self.record_batch.columns().get(index).ok_or_else(|| {
                vm.new_key_error(
                    PyStr::from(format!("Column {} not found", index)).into_pyobject(vm),
                )
            })?;
            let v = PyVector::from(column.clone());
            Ok(v.into_pyobject(vm))
        } else if let Ok(index) = needle.try_to_value::<String>(vm) {
            let key = index.as_str();
            let v = self.record_batch.column_by_name(key).ok_or_else(|| {
                vm.new_key_error(PyStr::from(format!("Column {} not found", key)).into_pyobject(vm))
            })?;
            let v: PyVector = v.clone().into();
            Ok(v.into_pyobject(vm))
        } else {
            Err(vm.new_key_error(
                PyStr::from(format!("Expect either str or int, found {needle:?}"))
                    .into_pyobject(vm),
            ))
        }
    }
}
#[rspyclass(with(AsMapping))]
impl PyRecordBatch {
    /// `__repr__` for the RustPython backend: the pretty debug dump of the
    /// underlying datafusion record batch.
    #[pymethod(name = "__repr__")]
    fn rspy_repr(&self) -> String {
        let df_batch = self.record_batch.df_record_batch();
        format!("{df_batch:#?}")
    }
}
// Mapping-protocol table for the RustPython backend: provides `len(rb)` and
// `rb[key]` for `PyRecordBatch`.
impl AsMapping for PyRecordBatch {
fn as_mapping() -> &'static PyMappingMethods {
static AS_MAPPING: PyMappingMethods = PyMappingMethods {
// `len(rb)` delegates to `PyRecordBatch::len`.
length: atomic_func!(|mapping, _vm| Ok(PyRecordBatch::mapping_downcast(mapping).len())),
// `rb[key]` delegates to `PyRecordBatch::get_item`.
subscript: atomic_func!(
|mapping, needle, vm| PyRecordBatch::mapping_downcast(mapping).get_item(needle, vm)
),
// No item-assignment handler is registered.
ass_subscript: AtomicCell::new(None),
};
&AS_MAPPING
}
}

View File

@@ -126,6 +126,8 @@ fn get_globals(py: Python) -> PyResult<&PyDict> {
Ok(globals)
}
/// Useful when you do not want to repeat the same sql statement in the script;
/// this function is still useful even though we already have `PyDataFrame.from_sql()`.
#[pyfunction]
fn dataframe(py: Python) -> PyResult<PyDataFrame> {
let globals = get_globals(py)?;

View File

@@ -14,6 +14,7 @@
use std::collections::HashMap;
use arrow::compute;
use common_recordbatch::RecordBatch;
use common_telemetry::timer;
use datafusion_common::ScalarValue;
@@ -21,11 +22,12 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{Helper, VectorRef};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple};
use pyo3::{pymethods, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
use pyo3::{pymethods, IntoPy, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
use snafu::{ensure, Location, ResultExt};
use crate::python::error::{self, NewRecordBatchSnafu, OtherSnafu, Result};
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
use crate::python::metric;
use crate::python::pyo3::dataframe_impl::PyDataFrame;
@@ -36,23 +38,22 @@ impl PyQueryEngine {
#[pyo3(name = "sql")]
pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult<PyObject> {
let res = self
.query_with_new_thread(s)
.query_with_new_thread(s.clone())
.map_err(PyValueError::new_err)?;
match res {
crate::python::ffi_types::copr::Either::Rb(rbs) => {
let mut top_vec = Vec::with_capacity(rbs.iter().count());
for rb in rbs.iter() {
let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
for v in rb.columns() {
let v = PyVector::from(v.clone());
let v = PyCell::new(py, v)?;
vec_of_vec.push(v.to_object(py));
}
let vec_of_vec = PyList::new(py, vec_of_vec);
top_vec.push(vec_of_vec);
}
let top_vec = PyList::new(py, top_vec);
Ok(top_vec.to_object(py))
let rb = compute::concat_batches(
rbs.schema().arrow_schema(),
rbs.iter().map(|rb| rb.df_record_batch()),
)
.map_err(|e| PyRuntimeError::new_err(format!("{e:?}")))?;
let rb = RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
PyRuntimeError::new_err(format!(
"Convert datafusion record batch to record batch failed for query {s}: {e}"
))
})?;
let rb = PyRecordBatch::new(rb);
Ok(rb.into_py(py))
}
crate::python::ffi_types::copr::Either::AffectedRows(count) => Ok(count.to_object(py)),
}

View File

@@ -12,17 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use common_recordbatch::DfRecordBatch;
use arrow::compute;
use common_recordbatch::{DfRecordBatch, RecordBatch};
use datafusion::dataframe::DataFrame as DfDataFrame;
use datafusion_expr::Expr as DfExpr;
use datatypes::schema::Schema;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyList, PyType};
use pyo3::types::{PyDict, PyType};
use snafu::ResultExt;
use crate::python::error::DataFusionSnafu;
use crate::python::ffi_types::PyVector;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::pyo3::builtins::query_engine;
use crate::python::pyo3::utils::pyo3_obj_try_to_typed_scalar_value;
use crate::python::utils::block_on_async;
@@ -210,32 +212,32 @@ impl PyDataFrame {
.map_err(|e| PyValueError::new_err(e.to_string()))?
.into())
}
/// collect `DataFrame` results into `List[List[Vector]]`
fn collect<'a>(&self, py: Python<'a>) -> PyResult<&'a PyList> {
/// collect `DataFrame` results into `PyRecordBatch` that impl Mapping Protocol
fn collect(&self, py: Python) -> PyResult<PyObject> {
let inner = self.inner.clone();
let res = block_on_async(async { inner.collect().await });
let res = res
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?
.map_err(|e| PyValueError::new_err(e.to_string()))?;
let outer_list: Vec<PyObject> = res
.iter()
.map(|elem| -> PyResult<_> {
let inner_list: Vec<_> = elem
.columns()
.iter()
.map(|arr| -> PyResult<_> {
datatypes::vectors::Helper::try_into_vector(arr)
.map(PyVector::from)
.map(|v| PyCell::new(py, v))
.map_err(|e| PyValueError::new_err(e.to_string()))
.and_then(|x| x)
})
.collect::<Result<_, _>>()?;
let inner_list = PyList::new(py, inner_list);
Ok(inner_list.into())
})
.collect::<Result<_, _>>()?;
Ok(PyList::new(py, outer_list))
if res.is_empty() {
return Ok(PyDict::new(py).into());
}
let concat_rb = compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
PyRuntimeError::new_err(format!("Concat batches failed for dataframe {self:?}: {e}"))
})?;
let schema = Schema::try_from(concat_rb.schema()).map_err(|e| {
PyRuntimeError::new_err(format!(
"Convert to Schema failed for dataframe {self:?}: {e}"
))
})?;
let rb = RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
PyRuntimeError::new_err(format!(
"Convert to RecordBatch failed for dataframe {self:?}: {e}"
))
})?;
let rb = PyRecordBatch::new(rb);
Ok(rb.into_py(py))
}
}

View File

@@ -24,7 +24,7 @@ use datatypes::vectors::Helper;
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString, PyType};
use pyo3::types::{PyBool, PyFloat, PyInt, PySequence, PySlice, PyString, PyType};
use super::utils::val_to_py_any;
use crate::python::ffi_types::vector::{arrow_rtruediv, wrap_bool_result, wrap_result, PyVector};
@@ -93,12 +93,28 @@ impl PyVector {
#[pymethods]
impl PyVector {
/// convert from numpy array to [`PyVector`]
#[classmethod]
fn from_numpy(cls: &PyType, py: Python<'_>, obj: PyObject) -> PyResult<PyObject> {
let pa = py.import("pyarrow")?;
let obj = pa.call_method1("array", (obj,))?;
let zelf = Self::from_pyarrow(cls, py, obj.into())?;
Ok(zelf.into_py(py))
}
fn numpy(&self, py: Python<'_>) -> PyResult<PyObject> {
let pa_arrow = self.to_arrow_array().data().to_pyarrow(py)?;
let ndarray = pa_arrow.call_method0(py, "to_numpy")?;
Ok(ndarray)
}
/// create a `PyVector` with a `PyList` that contains only elements of same type
#[new]
pub(crate) fn py_new(iterable: &PyList) -> PyResult<Self> {
pub(crate) fn py_new(iterable: PyObject, py: Python<'_>) -> PyResult<Self> {
let iterable = iterable.downcast::<PySequence>(py)?;
let dtype = get_py_type(iterable.get_item(0)?)?;
let mut buf = dtype.create_mutable_vector(iterable.len());
for i in 0..iterable.len() {
let mut buf = dtype.create_mutable_vector(iterable.len()?);
for i in 0..iterable.len()? {
let element = iterable.get_item(i)?;
let val = pyo3_obj_try_to_typed_val(element, Some(dtype.clone()))?;
buf.push_value_ref(val.as_value_ref());

View File

@@ -29,6 +29,7 @@ use snafu::{OptionExt, ResultExt};
use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
use crate::python::metric;
use crate::python::rspython::builtins::init_greptime_builtins;
@@ -216,6 +217,7 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
// add our own custom datatype and module
PyVector::make_class(&vm.ctx);
PyQueryEngine::make_class(&vm.ctx);
PyRecordBatch::make_class(&vm.ctx);
init_greptime_builtins("greptime", vm);
init_data_frame("data_frame", vm);
}));

View File

@@ -28,7 +28,6 @@ pub(crate) mod data_frame {
use datafusion::dataframe::DataFrame as DfDataFrame;
use datafusion::execution::context::SessionContext;
use datafusion_expr::Expr as DfExpr;
use rustpython_vm::builtins::{PyList, PyListRef};
use rustpython_vm::function::PyComparisonValue;
use rustpython_vm::types::{Comparable, PyComparisonOp};
use rustpython_vm::{
@@ -37,7 +36,7 @@ pub(crate) mod data_frame {
use snafu::ResultExt;
use crate::python::error::DataFusionSnafu;
use crate::python::ffi_types::PyVector;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::rspython::builtins::greptime_builtin::{
lit, query as get_query_engine, PyDataFrame,
};
@@ -235,31 +234,38 @@ pub(crate) mod data_frame {
}
#[pymethod]
/// collect `DataFrame` results into `List[List[Vector]]`
fn collect(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
/// collect `DataFrame` results into `PyRecordBatch` that impl Mapping Protocol
fn collect(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
let inner = self.inner.clone();
let res = block_on_async(async { inner.collect().await });
let res = res
.map_err(|e| vm.new_runtime_error(format!("{e:?}")))?
.map_err(|e| vm.new_runtime_error(e.to_string()))?;
let outer_list: Vec<_> = res
.iter()
.map(|elem| -> PyResult<_> {
let inner_list: Vec<_> = elem
.columns()
.iter()
.map(|arr| -> PyResult<_> {
datatypes::vectors::Helper::try_into_vector(arr)
.map(PyVector::from)
.map(|v| vm.new_pyobj(v))
.map_err(|e| vm.new_runtime_error(e.to_string()))
})
.collect::<Result<_, _>>()?;
let inner_list = PyList::new_ref(inner_list, vm.as_ref());
Ok(inner_list.into())
})
.collect::<Result<_, _>>()?;
Ok(PyList::new_ref(outer_list, vm.as_ref()))
if res.is_empty() {
return Ok(vm.ctx.new_dict().into());
}
let concat_rb =
arrow::compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
vm.new_runtime_error(format!(
"Concat batches failed for dataframe {self:?}: {e}"
))
})?;
// we are inside a macro, so using full path
let schema = datatypes::schema::Schema::try_from(concat_rb.schema()).map_err(|e| {
vm.new_runtime_error(format!(
"Convert to Schema failed for dataframe {self:?}: {e}"
))
})?;
let rb =
RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
vm.new_runtime_error(format!(
"Convert to RecordBatch failed for dataframe {self:?}: {e}"
))
})?;
let rb = PyRecordBatch::new(rb);
Ok(rb.into_pyobject(vm))
}
}