mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-18 14:00:39 +00:00
feat: from/to numpy&collect concat (#1339)
* feat: from/to numpy&collect concat * feat: PyRecordBatch * test: try import first,allow w/out numpy/pyarrow * fix: cond compile flag * doc: license * feat: sql() ret PyRecordBatch&repr * fix: after merge * style: fmt * chore: CR advices * docs: update * chore: resolve conflict
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
// limitations under the License.
|
||||
|
||||
// TODO(discord9): spawn new process for executing python script(if hit gil limit) and use shared memory to communicate
|
||||
#![deny(clippy::implicit_clone)]
|
||||
|
||||
pub mod engine;
|
||||
pub mod error;
|
||||
#[cfg(feature = "python")]
|
||||
|
||||
@@ -439,8 +439,8 @@ from greptime import col
|
||||
|
||||
@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
|
||||
def test(number) -> vector[u32]:
|
||||
from greptime import dataframe
|
||||
return dataframe().filter(col("number")==col("number")).collect()[0][0]
|
||||
from greptime import PyDataFrame
|
||||
return PyDataFrame.from_sql("select * from numbers").filter(col("number")==col("number")).collect()[0][0]
|
||||
"#;
|
||||
let script = script_engine
|
||||
.compile(script, CompileContext::default())
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
pub(crate) mod copr;
|
||||
pub(crate) mod py_recordbatch;
|
||||
pub(crate) mod utils;
|
||||
pub(crate) mod vector;
|
||||
pub(crate) use copr::{check_args_anno_real_type, select_from_rb, Coprocessor};
|
||||
|
||||
@@ -36,10 +36,10 @@ use rustpython_vm as vm;
|
||||
use serde::Deserialize;
|
||||
use session::context::QueryContext;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use vm::builtins::{PyList, PyListRef};
|
||||
use vm::convert::ToPyObject;
|
||||
use vm::{pyclass as rspyclass, PyPayload, PyResult, VirtualMachine};
|
||||
use vm::{pyclass as rspyclass, PyObjectRef, PyPayload, PyResult, VirtualMachine};
|
||||
|
||||
use super::py_recordbatch::PyRecordBatch;
|
||||
use crate::python::error::{ensure, ArrowSnafu, OtherSnafu, Result, TypeCastSnafu};
|
||||
use crate::python::ffi_types::PyVector;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
@@ -410,9 +410,11 @@ impl PyQueryEngine {
|
||||
.map_err(|e| format!("Dedicated thread for sql query panic: {e:?}"))?
|
||||
}
|
||||
// TODO(discord9): find a better way to call sql query api, now we don't if we are in async context or not
|
||||
/// return sql query results in List[PyVector], or List[usize] for AffectedRows number if no recordbatches is returned
|
||||
/// - return sql query results in `PyRecordBatch`, or
|
||||
/// - a empty `PyDict` if query results is empty
|
||||
/// - or number of AffectedRows
|
||||
#[pymethod]
|
||||
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
|
||||
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
self.query_with_new_thread(s)
|
||||
.map_err(|e| vm.new_system_error(e))
|
||||
.map(|rbs| match rbs {
|
||||
@@ -428,17 +430,11 @@ impl PyQueryEngine {
|
||||
RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
|
||||
vm.new_runtime_error(format!("Failed to cast recordbatch: {e:#?}"))
|
||||
})?;
|
||||
let columns_vectors = rb
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|v| PyVector::from(v.clone()).to_pyobject(vm))
|
||||
.collect::<Vec<_>>();
|
||||
Ok(PyList::new_ref(columns_vectors, vm.as_ref()))
|
||||
let rb = PyRecordBatch::new(rb);
|
||||
|
||||
Ok(rb.to_pyobject(vm))
|
||||
}
|
||||
Either::AffectedRows(cnt) => Ok(PyList::new_ref(
|
||||
vec![vm.ctx.new_int(cnt).into()],
|
||||
vm.as_ref(),
|
||||
)),
|
||||
Either::AffectedRows(cnt) => Ok(vm.ctx.new_int(cnt).to_pyobject(vm)),
|
||||
})?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,16 +414,16 @@ def boolean_array() -> vector[f64]:
|
||||
@copr(returns=["value"], backend="pyo3")
|
||||
def boolean_array() -> vector[f64]:
|
||||
from greptime import vector
|
||||
from greptime import query, dataframe
|
||||
from greptime import query
|
||||
|
||||
print("query()=", query())
|
||||
assert "query_engine object at" in str(query())
|
||||
try:
|
||||
print("dataframe()=", dataframe())
|
||||
except KeyError as e:
|
||||
print("dataframe()=", e)
|
||||
print(str(e), type(str(e)), 'No __dataframe__' in str(e))
|
||||
assert 'No __dataframe__' in str(e)
|
||||
rb = query().sql(
|
||||
"select number from numbers limit 5"
|
||||
)
|
||||
print(rb)
|
||||
assert len(rb) == 5
|
||||
|
||||
|
||||
v = vector([1.0, 2.0, 3.0])
|
||||
# This returns a vector([2.0])
|
||||
@@ -437,24 +437,25 @@ def boolean_array() -> vector[f64]:
|
||||
@copr(returns=["value"], backend="rspy")
|
||||
def boolean_array() -> vector[f64]:
|
||||
from greptime import vector, col
|
||||
from greptime import query, dataframe, PyDataFrame
|
||||
from greptime import query, PyDataFrame
|
||||
|
||||
df = PyDataFrame.from_sql("select number from numbers limit 5")
|
||||
print("df from sql=", df)
|
||||
collected = df.collect()
|
||||
print("df.collect()=", collected)
|
||||
assert len(collected[0][0]) == 5
|
||||
assert len(collected[0]) == 5
|
||||
df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
|
||||
collected = df.collect()
|
||||
assert len(collected[0][0]) == 2
|
||||
assert len(collected[0]) == 2
|
||||
assert collected[0] == collected["number"]
|
||||
print("query()=", query())
|
||||
|
||||
assert "query_engine object at" in repr(query())
|
||||
try:
|
||||
print("dataframe()=", dataframe())
|
||||
except KeyError as e:
|
||||
print("dataframe()=", e)
|
||||
assert "__dataframe__" in str(e)
|
||||
rb = query().sql(
|
||||
"select number from numbers limit 5"
|
||||
)
|
||||
print(rb)
|
||||
assert len(rb) == 5
|
||||
|
||||
v = vector([1.0, 2.0, 3.0])
|
||||
# This returns a vector([2.0])
|
||||
@@ -469,16 +470,17 @@ def boolean_array() -> vector[f64]:
|
||||
@copr(returns=["value"], backend="pyo3")
|
||||
def boolean_array() -> vector[f64]:
|
||||
from greptime import vector
|
||||
from greptime import query, dataframe, PyDataFrame, col
|
||||
from greptime import query, PyDataFrame, col
|
||||
df = PyDataFrame.from_sql("select number from numbers limit 5")
|
||||
print("df from sql=", df)
|
||||
ret = df.collect()
|
||||
print("df.collect()=", ret)
|
||||
assert len(ret[0][0]) == 5
|
||||
assert len(ret[0]) == 5
|
||||
df = PyDataFrame.from_sql("select number from numbers limit 5").filter(col("number") > 2)
|
||||
collected = df.collect()
|
||||
assert len(collected[0][0]) == 2
|
||||
return ret[0][0]
|
||||
assert len(collected[0]) == 2
|
||||
assert collected[0] == collected["number"]
|
||||
return ret[0]
|
||||
"#
|
||||
.to_string(),
|
||||
expect: Some(ronish!("value": vector!(UInt32Vector, [0, 1, 2, 3, 4]))),
|
||||
@@ -546,6 +548,12 @@ def answer() -> vector[i64]:
|
||||
@copr(returns=["value"], backend="pyo3")
|
||||
def answer() -> vector[i64]:
|
||||
from greptime import vector
|
||||
try:
|
||||
import pyarrow as pa
|
||||
except ImportError:
|
||||
# Python didn't have pyarrow
|
||||
print("Warning: no pyarrow in current python")
|
||||
return vector([42, 43, 44])
|
||||
return vector.from_pyarrow(vector([42, 43, 44]).to_pyarrow())
|
||||
"#
|
||||
.to_string(),
|
||||
@@ -557,7 +565,12 @@ def answer() -> vector[i64]:
|
||||
@copr(returns=["value"], backend="pyo3")
|
||||
def answer() -> vector[i64]:
|
||||
from greptime import vector
|
||||
import pyarrow as pa
|
||||
try:
|
||||
import pyarrow as pa
|
||||
except ImportError:
|
||||
# Python didn't have pyarrow
|
||||
print("Warning: no pyarrow in current python")
|
||||
return vector([42, 43, 44])
|
||||
return vector.from_pyarrow(pa.array([42, 43, 44]))
|
||||
"#
|
||||
.to_string(),
|
||||
@@ -567,9 +580,9 @@ def answer() -> vector[i64]:
|
||||
script: r#"
|
||||
@copr(args=[], returns = ["number"], sql = "select * from numbers", backend="rspy")
|
||||
def answer() -> vector[i64]:
|
||||
from greptime import vector, col, lit, dataframe
|
||||
from greptime import vector, col, lit, PyDataFrame
|
||||
expr_0 = (col("number")<lit(3)) & (col("number")>0)
|
||||
ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
|
||||
ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
|
||||
return ret
|
||||
"#
|
||||
.to_string(),
|
||||
@@ -580,10 +593,10 @@ def answer() -> vector[i64]:
|
||||
script: r#"
|
||||
@copr(args=[], returns = ["number"], sql = "select * from numbers", backend="pyo3")
|
||||
def answer() -> vector[i64]:
|
||||
from greptime import vector, col, lit, dataframe
|
||||
from greptime import vector, col, lit, PyDataFrame
|
||||
# Bitwise Operator pred comparison operator
|
||||
expr_0 = (col("number")<lit(3)) & (col("number")>0)
|
||||
ret = dataframe().select([col("number")]).filter(expr_0).collect()[0][0]
|
||||
ret = PyDataFrame.from_sql("select * from numbers").select([col("number")]).filter(expr_0).collect()[0]
|
||||
return ret
|
||||
"#
|
||||
.to_string(),
|
||||
@@ -595,8 +608,13 @@ def answer() -> vector[i64]:
|
||||
@copr(returns=["value"], backend="pyo3")
|
||||
def answer() -> vector[i64]:
|
||||
from greptime import vector
|
||||
import pyarrow as pa
|
||||
a = vector.from_pyarrow(pa.array([42, 43, 44]))
|
||||
try:
|
||||
import pyarrow as pa
|
||||
except ImportError:
|
||||
# Python didn't have pyarrow
|
||||
print("Warning: no pyarrow in current python")
|
||||
return vector([42, 43, 44])
|
||||
a = vector.from_pyarrow(pa.array([42]))
|
||||
return a[0:1]
|
||||
"#
|
||||
.to_string(),
|
||||
@@ -691,6 +709,36 @@ def normalize0(x):
|
||||
def normalize(v) -> vector[i64]:
|
||||
return [normalize0(x) for x in v]
|
||||
|
||||
"#
|
||||
.to_string(),
|
||||
expect: Some(ronish!(
|
||||
"value": vector!(Int64Vector, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
|
||||
)),
|
||||
},
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
CoprTestCase {
|
||||
script: r#"
|
||||
import math
|
||||
|
||||
@coprocessor(args=[], returns=["value"], backend="pyo3")
|
||||
def test_numpy() -> vector[i64]:
|
||||
try:
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
except ImportError as e:
|
||||
# Python didn't have numpy or pyarrow
|
||||
print("Warning: no pyarrow or numpy found in current python", e)
|
||||
return vector([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
|
||||
from greptime import vector
|
||||
v = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9,])
|
||||
v = pa.array(v)
|
||||
v = vector.from_pyarrow(v)
|
||||
v = vector.from_numpy(v.numpy())
|
||||
v = v.to_pyarrow()
|
||||
v = v.to_numpy()
|
||||
v = vector.from_numpy(v)
|
||||
return v
|
||||
|
||||
"#
|
||||
.to_string(),
|
||||
expect: Some(ronish!(
|
||||
|
||||
137
src/script/src/python/ffi_types/py_recordbatch.rs
Normal file
137
src/script/src/python/ffi_types/py_recordbatch.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//! PyRecordBatch is a Python class that wraps a RecordBatch,
|
||||
|
||||
//! PyRecordBatch is a Python class that wraps a RecordBatch,
|
||||
//! and provide a PyMapping Protocol to
|
||||
//! access the columns of the RecordBatch.
|
||||
|
||||
use common_recordbatch::RecordBatch;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use pyo3::{
|
||||
exceptions::{PyKeyError, PyRuntimeError},
|
||||
pyclass as pyo3class, pymethods, PyObject, PyResult, Python,
|
||||
};
|
||||
use rustpython_vm::builtins::PyStr;
|
||||
use rustpython_vm::protocol::PyMappingMethods;
|
||||
use rustpython_vm::types::AsMapping;
|
||||
use rustpython_vm::{
|
||||
atomic_func, pyclass as rspyclass, PyObject as RsPyObject, PyPayload, PyResult as RsPyResult,
|
||||
VirtualMachine,
|
||||
};
|
||||
|
||||
use crate::python::ffi_types::PyVector;
|
||||
|
||||
/// This is a Wrapper around a RecordBatch, impl PyMapping Protocol so you can do both `a[0]` and `a["number"]` to retrieve column.
|
||||
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "PyRecordBatch"))]
|
||||
#[rspyclass(module = false, name = "PyRecordBatch")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
pub(crate) struct PyRecordBatch {
|
||||
record_batch: RecordBatch,
|
||||
}
|
||||
|
||||
impl PyRecordBatch {
|
||||
pub fn new(record_batch: RecordBatch) -> Self {
|
||||
Self { record_batch }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RecordBatch> for PyRecordBatch {
|
||||
fn from(record_batch: RecordBatch) -> Self {
|
||||
Self::new(record_batch)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
#[pymethods]
|
||||
impl PyRecordBatch {
|
||||
fn __repr__(&self) -> String {
|
||||
// TODO(discord9): a better pretty print
|
||||
format!("{:#?}", &self.record_batch.df_record_batch())
|
||||
}
|
||||
fn __getitem__(&self, py: Python, key: PyObject) -> PyResult<PyVector> {
|
||||
let column = if let Ok(key) = key.extract::<String>(py) {
|
||||
self.record_batch.column_by_name(&key)
|
||||
} else if let Ok(key) = key.extract::<usize>(py) {
|
||||
Some(self.record_batch.column(key))
|
||||
} else {
|
||||
return Err(PyRuntimeError::new_err(format!(
|
||||
"Expect either str or int, found {key:?}"
|
||||
)));
|
||||
}
|
||||
.ok_or_else(|| PyKeyError::new_err(format!("Column {} not found", key)))?;
|
||||
let v = PyVector::from(column.clone());
|
||||
Ok(v)
|
||||
}
|
||||
fn __iter__(&self) -> PyResult<Vec<PyVector>> {
|
||||
let iter: Vec<_> = self
|
||||
.record_batch
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|i| PyVector::from(i.clone()))
|
||||
.collect();
|
||||
Ok(iter)
|
||||
}
|
||||
fn __len__(&self) -> PyResult<usize> {
|
||||
Ok(self.len())
|
||||
}
|
||||
}
|
||||
|
||||
impl PyRecordBatch {
|
||||
fn len(&self) -> usize {
|
||||
self.record_batch.num_rows()
|
||||
}
|
||||
fn get_item(&self, needle: &RsPyObject, vm: &VirtualMachine) -> RsPyResult {
|
||||
if let Ok(index) = needle.try_to_value::<usize>(vm) {
|
||||
let column = self.record_batch.column(index);
|
||||
let v = PyVector::from(column.clone());
|
||||
Ok(v.into_pyobject(vm))
|
||||
} else if let Ok(index) = needle.try_to_value::<String>(vm) {
|
||||
let key = index.as_str();
|
||||
|
||||
let v = self.record_batch.column_by_name(key).ok_or_else(|| {
|
||||
vm.new_key_error(PyStr::from(format!("Column {} not found", key)).into_pyobject(vm))
|
||||
})?;
|
||||
let v: PyVector = v.clone().into();
|
||||
Ok(v.into_pyobject(vm))
|
||||
} else {
|
||||
Err(vm.new_key_error(
|
||||
PyStr::from(format!("Expect either str or int, found {needle:?}"))
|
||||
.into_pyobject(vm),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[rspyclass(with(AsMapping))]
|
||||
impl PyRecordBatch {
|
||||
#[pymethod(name = "__repr__")]
|
||||
fn rspy_repr(&self) -> String {
|
||||
format!("{:#?}", &self.record_batch.df_record_batch())
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMapping for PyRecordBatch {
|
||||
fn as_mapping() -> &'static PyMappingMethods {
|
||||
static AS_MAPPING: PyMappingMethods = PyMappingMethods {
|
||||
length: atomic_func!(|mapping, _vm| Ok(PyRecordBatch::mapping_downcast(mapping).len())),
|
||||
subscript: atomic_func!(
|
||||
|mapping, needle, vm| PyRecordBatch::mapping_downcast(mapping).get_item(needle, vm)
|
||||
),
|
||||
ass_subscript: AtomicCell::new(None),
|
||||
};
|
||||
&AS_MAPPING
|
||||
}
|
||||
}
|
||||
@@ -126,6 +126,8 @@ fn get_globals(py: Python) -> PyResult<&PyDict> {
|
||||
Ok(globals)
|
||||
}
|
||||
|
||||
/// In case of not wanting to repeat the same sql statement in sql,
|
||||
/// this function is still useful even we already have PyDataFrame.from_sql()
|
||||
#[pyfunction]
|
||||
fn dataframe(py: Python) -> PyResult<PyDataFrame> {
|
||||
let globals = get_globals(py)?;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow::compute;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::timer;
|
||||
use datafusion_common::ScalarValue;
|
||||
@@ -21,11 +22,12 @@ use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use pyo3::exceptions::{PyRuntimeError, PyValueError};
|
||||
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple};
|
||||
use pyo3::{pymethods, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
|
||||
use pyo3::{pymethods, IntoPy, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
|
||||
use snafu::{ensure, Location, ResultExt};
|
||||
|
||||
use crate::python::error::{self, NewRecordBatchSnafu, OtherSnafu, Result};
|
||||
use crate::python::ffi_types::copr::PyQueryEngine;
|
||||
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
|
||||
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
|
||||
use crate::python::metric;
|
||||
use crate::python::pyo3::dataframe_impl::PyDataFrame;
|
||||
@@ -36,23 +38,22 @@ impl PyQueryEngine {
|
||||
#[pyo3(name = "sql")]
|
||||
pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult<PyObject> {
|
||||
let res = self
|
||||
.query_with_new_thread(s)
|
||||
.query_with_new_thread(s.clone())
|
||||
.map_err(PyValueError::new_err)?;
|
||||
match res {
|
||||
crate::python::ffi_types::copr::Either::Rb(rbs) => {
|
||||
let mut top_vec = Vec::with_capacity(rbs.iter().count());
|
||||
for rb in rbs.iter() {
|
||||
let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
|
||||
for v in rb.columns() {
|
||||
let v = PyVector::from(v.clone());
|
||||
let v = PyCell::new(py, v)?;
|
||||
vec_of_vec.push(v.to_object(py));
|
||||
}
|
||||
let vec_of_vec = PyList::new(py, vec_of_vec);
|
||||
top_vec.push(vec_of_vec);
|
||||
}
|
||||
let top_vec = PyList::new(py, top_vec);
|
||||
Ok(top_vec.to_object(py))
|
||||
let rb = compute::concat_batches(
|
||||
rbs.schema().arrow_schema(),
|
||||
rbs.iter().map(|rb| rb.df_record_batch()),
|
||||
)
|
||||
.map_err(|e| PyRuntimeError::new_err(format!("{e:?}")))?;
|
||||
let rb = RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
|
||||
PyRuntimeError::new_err(format!(
|
||||
"Convert datafusion record batch to record batch failed for query {s}: {e}"
|
||||
))
|
||||
})?;
|
||||
let rb = PyRecordBatch::new(rb);
|
||||
Ok(rb.into_py(py))
|
||||
}
|
||||
crate::python::ffi_types::copr::Either::AffectedRows(count) => Ok(count.to_object(py)),
|
||||
}
|
||||
|
||||
@@ -12,17 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_recordbatch::DfRecordBatch;
|
||||
use arrow::compute;
|
||||
use common_recordbatch::{DfRecordBatch, RecordBatch};
|
||||
use datafusion::dataframe::DataFrame as DfDataFrame;
|
||||
use datafusion_expr::Expr as DfExpr;
|
||||
use datatypes::schema::Schema;
|
||||
use pyo3::exceptions::{PyRuntimeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::{PyList, PyType};
|
||||
use pyo3::types::{PyDict, PyType};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::python::error::DataFusionSnafu;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
|
||||
use crate::python::pyo3::builtins::query_engine;
|
||||
use crate::python::pyo3::utils::pyo3_obj_try_to_typed_scalar_value;
|
||||
use crate::python::utils::block_on_async;
|
||||
@@ -210,32 +212,32 @@ impl PyDataFrame {
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
/// collect `DataFrame` results into `List[List[Vector]]`
|
||||
fn collect<'a>(&self, py: Python<'a>) -> PyResult<&'a PyList> {
|
||||
/// collect `DataFrame` results into `PyRecordBatch` that impl Mapping Protocol
|
||||
fn collect(&self, py: Python) -> PyResult<PyObject> {
|
||||
let inner = self.inner.clone();
|
||||
let res = block_on_async(async { inner.collect().await });
|
||||
let res = res
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?;
|
||||
let outer_list: Vec<PyObject> = res
|
||||
.iter()
|
||||
.map(|elem| -> PyResult<_> {
|
||||
let inner_list: Vec<_> = elem
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|arr| -> PyResult<_> {
|
||||
datatypes::vectors::Helper::try_into_vector(arr)
|
||||
.map(PyVector::from)
|
||||
.map(|v| PyCell::new(py, v))
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
.and_then(|x| x)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let inner_list = PyList::new(py, inner_list);
|
||||
Ok(inner_list.into())
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(PyList::new(py, outer_list))
|
||||
if res.is_empty() {
|
||||
return Ok(PyDict::new(py).into());
|
||||
}
|
||||
let concat_rb = compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
|
||||
PyRuntimeError::new_err(format!("Concat batches failed for dataframe {self:?}: {e}"))
|
||||
})?;
|
||||
|
||||
let schema = Schema::try_from(concat_rb.schema()).map_err(|e| {
|
||||
PyRuntimeError::new_err(format!(
|
||||
"Convert to Schema failed for dataframe {self:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
let rb = RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
|
||||
PyRuntimeError::new_err(format!(
|
||||
"Convert to RecordBatch failed for dataframe {self:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
let rb = PyRecordBatch::new(rb);
|
||||
Ok(rb.into_py(py))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ use datatypes::vectors::Helper;
|
||||
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PySlice, PyString, PyType};
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PySequence, PySlice, PyString, PyType};
|
||||
|
||||
use super::utils::val_to_py_any;
|
||||
use crate::python::ffi_types::vector::{arrow_rtruediv, wrap_bool_result, wrap_result, PyVector};
|
||||
@@ -93,12 +93,28 @@ impl PyVector {
|
||||
|
||||
#[pymethods]
|
||||
impl PyVector {
|
||||
/// convert from numpy array to [`PyVector`]
|
||||
#[classmethod]
|
||||
fn from_numpy(cls: &PyType, py: Python<'_>, obj: PyObject) -> PyResult<PyObject> {
|
||||
let pa = py.import("pyarrow")?;
|
||||
let obj = pa.call_method1("array", (obj,))?;
|
||||
let zelf = Self::from_pyarrow(cls, py, obj.into())?;
|
||||
Ok(zelf.into_py(py))
|
||||
}
|
||||
|
||||
fn numpy(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
let pa_arrow = self.to_arrow_array().data().to_pyarrow(py)?;
|
||||
let ndarray = pa_arrow.call_method0(py, "to_numpy")?;
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// create a `PyVector` with a `PyList` that contains only elements of same type
|
||||
#[new]
|
||||
pub(crate) fn py_new(iterable: &PyList) -> PyResult<Self> {
|
||||
pub(crate) fn py_new(iterable: PyObject, py: Python<'_>) -> PyResult<Self> {
|
||||
let iterable = iterable.downcast::<PySequence>(py)?;
|
||||
let dtype = get_py_type(iterable.get_item(0)?)?;
|
||||
let mut buf = dtype.create_mutable_vector(iterable.len());
|
||||
for i in 0..iterable.len() {
|
||||
let mut buf = dtype.create_mutable_vector(iterable.len()?);
|
||||
for i in 0..iterable.len()? {
|
||||
let element = iterable.get_item(i)?;
|
||||
let val = pyo3_obj_try_to_typed_val(element, Some(dtype.clone()))?;
|
||||
buf.push_value_ref(val.as_value_ref());
|
||||
|
||||
@@ -29,6 +29,7 @@ use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
|
||||
use crate::python::ffi_types::copr::PyQueryEngine;
|
||||
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
|
||||
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
|
||||
use crate::python::metric;
|
||||
use crate::python::rspython::builtins::init_greptime_builtins;
|
||||
@@ -216,6 +217,7 @@ pub(crate) fn init_interpreter() -> Arc<Interpreter> {
|
||||
// add our own custom datatype and module
|
||||
PyVector::make_class(&vm.ctx);
|
||||
PyQueryEngine::make_class(&vm.ctx);
|
||||
PyRecordBatch::make_class(&vm.ctx);
|
||||
init_greptime_builtins("greptime", vm);
|
||||
init_data_frame("data_frame", vm);
|
||||
}));
|
||||
|
||||
@@ -28,7 +28,6 @@ pub(crate) mod data_frame {
|
||||
use datafusion::dataframe::DataFrame as DfDataFrame;
|
||||
use datafusion::execution::context::SessionContext;
|
||||
use datafusion_expr::Expr as DfExpr;
|
||||
use rustpython_vm::builtins::{PyList, PyListRef};
|
||||
use rustpython_vm::function::PyComparisonValue;
|
||||
use rustpython_vm::types::{Comparable, PyComparisonOp};
|
||||
use rustpython_vm::{
|
||||
@@ -37,7 +36,7 @@ pub(crate) mod data_frame {
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::python::error::DataFusionSnafu;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
|
||||
use crate::python::rspython::builtins::greptime_builtin::{
|
||||
lit, query as get_query_engine, PyDataFrame,
|
||||
};
|
||||
@@ -235,31 +234,38 @@ pub(crate) mod data_frame {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
/// collect `DataFrame` results into `List[List[Vector]]`
|
||||
fn collect(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
|
||||
/// collect `DataFrame` results into `PyRecordBatch` that impl Mapping Protocol
|
||||
fn collect(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
let inner = self.inner.clone();
|
||||
let res = block_on_async(async { inner.collect().await });
|
||||
let res = res
|
||||
.map_err(|e| vm.new_runtime_error(format!("{e:?}")))?
|
||||
.map_err(|e| vm.new_runtime_error(e.to_string()))?;
|
||||
let outer_list: Vec<_> = res
|
||||
.iter()
|
||||
.map(|elem| -> PyResult<_> {
|
||||
let inner_list: Vec<_> = elem
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|arr| -> PyResult<_> {
|
||||
datatypes::vectors::Helper::try_into_vector(arr)
|
||||
.map(PyVector::from)
|
||||
.map(|v| vm.new_pyobj(v))
|
||||
.map_err(|e| vm.new_runtime_error(e.to_string()))
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let inner_list = PyList::new_ref(inner_list, vm.as_ref());
|
||||
Ok(inner_list.into())
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(PyList::new_ref(outer_list, vm.as_ref()))
|
||||
if res.is_empty() {
|
||||
return Ok(vm.ctx.new_dict().into());
|
||||
}
|
||||
let concat_rb =
|
||||
arrow::compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
|
||||
vm.new_runtime_error(format!(
|
||||
"Concat batches failed for dataframe {self:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
// we are inside a macro, so using full path
|
||||
let schema = datatypes::schema::Schema::try_from(concat_rb.schema()).map_err(|e| {
|
||||
vm.new_runtime_error(format!(
|
||||
"Convert to Schema failed for dataframe {self:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
let rb =
|
||||
RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
|
||||
vm.new_runtime_error(format!(
|
||||
"Convert to RecordBatch failed for dataframe {self:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let rb = PyRecordBatch::new(rb);
|
||||
Ok(rb.into_pyobject(vm))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user