mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 18:00:41 +00:00
feat: add PyO3 (hence CPython) as an optional backend (#976)
* refactor: ffi_types * style: fmt * refactor: use `String` for return when possible * todo: vector_impl * feat: pyobj_try_typed_val * refactor: more backend indep function * feat: +-*/ magic methods * refactor: copr * style: fmt * feat: add paired tests * refactor: more * refactor: move inside `python` folder * refactor: all but test code * feat: builtins for PyO3 * chore: add licenses * chore: remove unused&add todos * refactor: remove old files * chore: mark unused * chore: fmt * chore: license * feat: query in PyO3 * test: paired testcases for rspy&pyo3 * feat: PyDataFrame(Untested) * feat: some allow_threads * style: fmt * style: add license * feat: rebase manually of #962 * feat: more `allow_threads` * chore: typo * chore: remove some `TODO` * test: allow margin of epsilon * chore: code review advices * chore: more CR adjust * chore: more adjust * feat: kwargs&its test * chore: remove some `dbg!` * chore: allow params * fix: put `dataframe` into scope * chore: newline * fix: adjust after rebase * fix: test serde skip attr * style: taplo * feat: add `pyo3_backend` feature * doc: update CI&readme
This commit is contained in:
2
.github/workflows/develop.yml
vendored
2
.github/workflows/develop.yml
vendored
@@ -208,7 +208,7 @@ jobs:
|
||||
- name: Install cargo-llvm-cov
|
||||
uses: taiki-e/install-action@cargo-llvm-cov
|
||||
- name: Collect coverage data
|
||||
run: cargo llvm-cov nextest --workspace --lcov --output-path lcov.info
|
||||
run: cargo llvm-cov nextest --workspace --lcov --output-path lcov.info -F pyo3_backend
|
||||
env:
|
||||
CARGO_BUILD_RUSTFLAGS: "-C link-arg=-fuse-ld=lld"
|
||||
RUST_BACKTRACE: 1
|
||||
|
||||
88
Cargo.lock
generated
88
Cargo.lock
generated
@@ -3357,6 +3357,12 @@ dependencies = [
|
||||
"unicode-width",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indoc"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa799dd5ed20a7e349f3b4639aa80d74549c81716d9ec4f994c9b5815598306"
|
||||
|
||||
[[package]]
|
||||
name = "influxdb_line_protocol"
|
||||
version = "0.1.0"
|
||||
@@ -3929,6 +3935,15 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memoffset"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "meta-client"
|
||||
version = "0.1.0"
|
||||
@@ -5500,6 +5515,66 @@ version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b7e158a385023d209d6d5f2585c4b468f6dcb3dd5aca9b75c4f1678c05bb375"
|
||||
|
||||
[[package]]
|
||||
name = "pyo3"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06a3d8e8a46ab2738109347433cb7b96dffda2e4a218b03ef27090238886b147"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"indoc",
|
||||
"libc",
|
||||
"memoffset 0.8.0",
|
||||
"parking_lot",
|
||||
"pyo3-build-config",
|
||||
"pyo3-ffi",
|
||||
"pyo3-macros",
|
||||
"unindent",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-build-config"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75439f995d07ddfad42b192dfcf3bc66a7ecfd8b4a1f5f6f046aa5c2c5d7677d"
|
||||
dependencies = [
|
||||
"once_cell",
|
||||
"target-lexicon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-ffi"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "839526a5c07a17ff44823679b68add4a58004de00512a95b6c1c98a6dcac0ee5"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"pyo3-build-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd44cf207476c6a9760c4653559be4f206efafb924d3e4cbf2721475fc0d6cc5"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyo3-macros-backend"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc1f43d8e30460f36350d18631ccf85ded64c059829208fe680904c65bcd0a4c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.10.1"
|
||||
@@ -6601,6 +6676,7 @@ dependencies = [
|
||||
"mito",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"pyo3",
|
||||
"query",
|
||||
"ron",
|
||||
"rustpython-ast",
|
||||
@@ -7565,6 +7641,12 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||
|
||||
[[package]]
|
||||
name = "target-lexicon"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ae9980cab1db3fceee2f6c6f643d5d8de2997c58ee8d25fb0cc8a9e9e7348e5"
|
||||
|
||||
[[package]]
|
||||
name = "tempdir"
|
||||
version = "0.3.7"
|
||||
@@ -8505,6 +8587,12 @@ version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "029df4cc8238cefc911704ff8fa210853a0f3bce2694d8f51181dd41ee0f3301"
|
||||
|
||||
[[package]]
|
||||
name = "unindent"
|
||||
version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1766d682d402817b5ac4490b3c3002d91dfa0d22812f341609f97b08757359c"
|
||||
|
||||
[[package]]
|
||||
name = "unsafe-libyaml"
|
||||
version = "0.2.5"
|
||||
|
||||
@@ -61,6 +61,11 @@ To compile GreptimeDB from source, you'll need:
|
||||
find installation instructions [here](https://grpc.io/docs/protoc-installation/).
|
||||
**Note that `protoc` version needs to be >= 3.15** because we have used the `optional`
|
||||
keyword. You can check it with `protoc --version`.
|
||||
- python3-dev or python3-devel (optional, only needed if you want to run scripts in CPython): this installs the Python shared library required for running the Python scripting engine in CPython mode.
|
||||
It is available as `python3-dev` on Ubuntu (install it with `sudo apt install python3-dev`) or as `python3-devel` on RPM-based distributions (e.g. Fedora, Red Hat, SuSE). On macOS, the Python 3 package should already include this shared library by default.
|
||||
Then, you can build GreptimeDB from source code:
|
||||
|
||||
```
|
||||
|
||||
#### Build with Docker
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ RUN apt-get update && apt-get install -y \
|
||||
protobuf-compiler \
|
||||
curl \
|
||||
build-essential \
|
||||
pkg-config
|
||||
pkg-config \
|
||||
python3-dev
|
||||
|
||||
# Install Rust.
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
@@ -6,6 +6,7 @@ license.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["python"]
|
||||
pyo3_backend = ["pyo3"]
|
||||
python = [
|
||||
"dep:datafusion",
|
||||
"dep:datafusion-common",
|
||||
@@ -58,6 +59,7 @@ rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = tru
|
||||
"default",
|
||||
"codegen",
|
||||
] }
|
||||
pyo3 = { version = "0.18", optional = true }
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { path = "../sql" }
|
||||
|
||||
@@ -14,15 +14,14 @@
|
||||
|
||||
//! Python script coprocessor
|
||||
|
||||
mod builtins;
|
||||
pub(crate) mod coprocessor;
|
||||
mod dataframe;
|
||||
mod engine;
|
||||
pub mod error;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
pub(crate) mod utils;
|
||||
mod vector;
|
||||
|
||||
pub use self::engine::{PyEngine, PyScript};
|
||||
pub use self::vector::PyVector;
|
||||
|
||||
mod ffi_types;
|
||||
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
mod pyo3;
|
||||
mod rspython;
|
||||
|
||||
@@ -40,8 +40,8 @@ use snafu::{ensure, ResultExt};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use crate::python::coprocessor::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
|
||||
use crate::python::error::{self, Result};
|
||||
use crate::python::ffi_types::copr::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
|
||||
|
||||
const PY_ENGINE: &str = "python";
|
||||
|
||||
|
||||
21
src/script/src/python/ffi_types.rs
Normal file
21
src/script/src/python/ffi_types.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
pub(crate) mod copr;
|
||||
pub(crate) mod utils;
|
||||
pub(crate) mod vector;
|
||||
pub(crate) use copr::{check_args_anno_real_type, select_from_rb, Coprocessor};
|
||||
pub(crate) use vector::PyVector;
|
||||
#[cfg(test)]
|
||||
mod pair_tests;
|
||||
@@ -15,43 +15,36 @@
|
||||
pub mod compile;
|
||||
pub mod parse;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashMap;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::{Arc, Weak};
|
||||
|
||||
use common_recordbatch::{RecordBatch, RecordBatches};
|
||||
use common_telemetry::info;
|
||||
use datatypes::arrow::array::Array;
|
||||
use datatypes::arrow::compute;
|
||||
use datatypes::data_type::{ConcreteDataType, DataType};
|
||||
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
// use crate::python::builtins::greptime_builtin;
|
||||
use parse::DecoratorArgs;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use pyo3::pyclass as pyo3class;
|
||||
use query::parser::QueryLanguageParser;
|
||||
use query::QueryEngine;
|
||||
use rustpython_compiler_core::CodeObject;
|
||||
use rustpython_vm as vm;
|
||||
use rustpython_vm::class::PyClassImpl;
|
||||
use rustpython_vm::AsObject;
|
||||
#[cfg(test)]
|
||||
use serde::Deserialize;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use vm::builtins::{PyBaseExceptionRef, PyDict, PyList, PyListRef, PyStr, PyTuple};
|
||||
use vm::builtins::{PyList, PyListRef};
|
||||
use vm::convert::ToPyObject;
|
||||
use vm::scope::Scope;
|
||||
use vm::{pyclass, Interpreter, PyObjectRef, PyPayload, PyResult, VirtualMachine};
|
||||
use vm::{pyclass as rspyclass, PyPayload, PyResult, VirtualMachine};
|
||||
|
||||
use crate::python::builtins::greptime_builtin;
|
||||
use crate::python::coprocessor::parse::DecoratorArgs;
|
||||
use crate::python::dataframe::data_frame::{self, set_dataframe_in_scope};
|
||||
use crate::python::error::{
|
||||
ensure, ret_other_error_with, ArrowSnafu, NewRecordBatchSnafu, OtherSnafu, Result,
|
||||
TypeCastSnafu,
|
||||
};
|
||||
use crate::python::utils::{format_py_error, is_instance, py_vec_obj_to_array};
|
||||
use crate::python::PyVector;
|
||||
|
||||
thread_local!(static INTERPRETER: RefCell<Option<Arc<Interpreter>>> = RefCell::new(None));
|
||||
use crate::python::error::{ensure, ArrowSnafu, OtherSnafu, Result, TypeCastSnafu};
|
||||
use crate::python::ffi_types::PyVector;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use crate::python::pyo3::pyo3_exec_parsed;
|
||||
use crate::python::rspython::rspy_exec_parsed;
|
||||
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
@@ -62,6 +55,16 @@ pub struct AnnotationInfo {
|
||||
pub is_nullable: bool,
|
||||
}
|
||||
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
#[derive(Debug, Default, Clone, Eq, PartialEq)]
|
||||
pub enum BackendType {
|
||||
#[default]
|
||||
RustPython,
|
||||
// TODO(discord9): integration test
|
||||
#[allow(unused)]
|
||||
CPython,
|
||||
}
|
||||
|
||||
pub type CoprocessorRef = Arc<Coprocessor>;
|
||||
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
@@ -85,6 +88,10 @@ pub struct Coprocessor {
|
||||
pub code_obj: Option<CodeObject>,
|
||||
#[cfg_attr(test, serde(skip))]
|
||||
pub query_engine: Option<QueryEngineWeakRef>,
|
||||
/// Which backend to use to run this script
|
||||
/// Ideally both backends should be exercised in tests, so this field is skipped when deserializing test cases
|
||||
#[cfg_attr(test, serde(skip))]
|
||||
pub backend: BackendType,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -126,7 +133,7 @@ impl Coprocessor {
|
||||
/// generate [`Schema`] according to return names, types,
|
||||
/// if no annotation
|
||||
/// the datatypes of the actual columns is used directly
|
||||
fn gen_schema(&self, cols: &[VectorRef]) -> Result<SchemaRef> {
|
||||
pub(crate) fn gen_schema(&self, cols: &[VectorRef]) -> Result<SchemaRef> {
|
||||
let names = &self.deco_args.ret_names;
|
||||
let anno = &self.return_types;
|
||||
ensure!(
|
||||
@@ -169,7 +176,7 @@ impl Coprocessor {
|
||||
}
|
||||
|
||||
/// check if real types and annotation types(if have) is the same, if not try cast columns to annotated type
|
||||
fn check_and_cast_type(&self, cols: &mut [VectorRef]) -> Result<()> {
|
||||
pub(crate) fn check_and_cast_type(&self, cols: &mut [VectorRef]) -> Result<()> {
|
||||
let return_types = &self.return_types;
|
||||
// allow ignore Return Type Annotation
|
||||
if return_types.is_empty() {
|
||||
@@ -205,32 +212,9 @@ impl Coprocessor {
|
||||
}
|
||||
}
|
||||
|
||||
/// convert a tuple of `PyVector` or one `PyVector`(wrapped in a Python Object Ref[`PyObjectRef`])
|
||||
/// to a `Vec<ArrayRef>`
|
||||
/// by default, a constant(int/float/bool) gives the a constant array of same length with input args
|
||||
fn try_into_columns(
|
||||
obj: &PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
col_len: usize,
|
||||
) -> Result<Vec<VectorRef>> {
|
||||
if is_instance::<PyTuple>(obj, vm) {
|
||||
let tuple = obj
|
||||
.payload::<PyTuple>()
|
||||
.with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyTuple)")))?;
|
||||
let cols = tuple
|
||||
.iter()
|
||||
.map(|obj| py_vec_obj_to_array(obj, vm, col_len))
|
||||
.collect::<Result<Vec<VectorRef>>>()?;
|
||||
Ok(cols)
|
||||
} else {
|
||||
let col = py_vec_obj_to_array(obj, vm, col_len)?;
|
||||
Ok(vec![col])
|
||||
}
|
||||
}
|
||||
|
||||
/// select columns according to `fetch_names` from `rb`
|
||||
/// and cast them into a Vec of PyVector
|
||||
fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result<Vec<PyVector>> {
|
||||
pub(crate) fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result<Vec<PyVector>> {
|
||||
fetch_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
@@ -243,8 +227,8 @@ fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result<Vec<PyVect
|
||||
}
|
||||
|
||||
/// match between arguments' real type and annotation types
|
||||
/// if type anno is vector[_] then use real type
|
||||
fn check_args_anno_real_type(
|
||||
/// if type anno is `vector[_]` then use real type(from RecordBatch's schema)
|
||||
pub(crate) fn check_args_anno_real_type(
|
||||
args: &[PyVector],
|
||||
copr: &Coprocessor,
|
||||
rb: &RecordBatch,
|
||||
@@ -274,27 +258,6 @@ fn check_args_anno_real_type(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// set arguments with given name and values in python scopes
|
||||
fn set_items_in_scope(
|
||||
scope: &Scope,
|
||||
vm: &VirtualMachine,
|
||||
arg_names: &[String],
|
||||
args: Vec<PyVector>,
|
||||
) -> Result<()> {
|
||||
let _ = arg_names
|
||||
.iter()
|
||||
.zip(args)
|
||||
.map(|(name, vector)| {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(name, vm.new_pyobj(vector), vm)
|
||||
})
|
||||
.collect::<StdResult<Vec<()>, PyBaseExceptionRef>>()
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The coprocessor function accept a python script and a Record Batch:
|
||||
/// ## What it does
|
||||
/// 1. it take a python script and a [`RecordBatch`], extract columns and annotation info according to `args` given in decorator in python script
|
||||
@@ -351,31 +314,39 @@ fn set_items_in_scope(
|
||||
pub fn exec_coprocessor(script: &str, rb: &Option<RecordBatch>) -> Result<RecordBatch> {
|
||||
// 1. parse the script and check if it's only a function with `@coprocessor` decorator, and get `args` and `returns`,
|
||||
// 2. also check for exist of `args` in `rb`, if not found, return error
|
||||
// TODO(discord9): cache the result of parse_copr
|
||||
// cache the result of parse_copr
|
||||
let copr = parse::parse_and_compile_copr(script, None)?;
|
||||
exec_parsed(&copr, rb, &HashMap::new())
|
||||
}
|
||||
|
||||
#[pyclass(module = false, name = "query_engine")]
|
||||
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "query_engine"))]
|
||||
#[rspyclass(module = false, name = "query_engine")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
pub struct PyQueryEngine {
|
||||
inner: QueryEngineWeakRef,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub(crate) enum Either {
|
||||
Rb(RecordBatches),
|
||||
AffectedRows(usize),
|
||||
}
|
||||
#[rspyclass]
|
||||
impl PyQueryEngine {
|
||||
// TODO(discord9): find a better way to call sql query api, now we don't if we are in async context or not
|
||||
/// return sql query results in List[List[PyVector]], or List[usize] for AffectedRows number if no recordbatches is returned
|
||||
#[pymethod]
|
||||
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
|
||||
enum Either {
|
||||
Rb(RecordBatches),
|
||||
AffectedRows(usize),
|
||||
}
|
||||
let query = self.inner.0.upgrade();
|
||||
pub(crate) fn from_weakref(inner: QueryEngineWeakRef) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
pub(crate) fn get_ref(&self) -> Option<Arc<dyn QueryEngine>> {
|
||||
self.inner.0.upgrade()
|
||||
}
|
||||
pub(crate) fn query_with_new_thread(
|
||||
&self,
|
||||
query: Option<Arc<dyn QueryEngine>>,
|
||||
s: String,
|
||||
) -> StdResult<Either, String> {
|
||||
let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> {
|
||||
if let Some(engine) = query {
|
||||
let stmt = QueryLanguageParser::parse_sql(s.as_str()).map_err(|e| e.to_string())?;
|
||||
let stmt = QueryLanguageParser::parse_sql(&s).map_err(|e| e.to_string())?;
|
||||
|
||||
// To prevent the error of nested creating Runtime, if is nested, use the parent runtime instead
|
||||
|
||||
let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
|
||||
@@ -385,7 +356,6 @@ impl PyQueryEngine {
|
||||
.statement_to_plan(stmt, Default::default())
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let res = engine
|
||||
.clone()
|
||||
.execute(&plan)
|
||||
@@ -409,9 +379,14 @@ impl PyQueryEngine {
|
||||
});
|
||||
thread_handle
|
||||
.join()
|
||||
.map_err(|e| {
|
||||
vm.new_system_error(format!("Dedicated thread for sql query panic: {e:?}"))
|
||||
})?
|
||||
.map_err(|e| format!("Dedicated thread for sql query panic: {e:?}"))?
|
||||
}
|
||||
// TODO(discord9): find a better way to call the sql query api; for now we don't know whether we are in an async context or not
|
||||
/// return sql query results in List[List[PyVector]], or List[usize] for AffectedRows number if no recordbatches is returned
|
||||
#[pymethod]
|
||||
fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyListRef> {
|
||||
let query = self.inner.0.upgrade();
|
||||
self.query_with_new_thread(query, s)
|
||||
.map_err(|e| vm.new_system_error(e))
|
||||
.map(|rbs| match rbs {
|
||||
Either::Rb(rbs) => {
|
||||
@@ -435,149 +410,27 @@ impl PyQueryEngine {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_query_engine_in_scope(
|
||||
scope: &Scope,
|
||||
vm: &VirtualMachine,
|
||||
query_engine: PyQueryEngine,
|
||||
) -> Result<()> {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item("query", query_engine.to_pyobject(vm), vm)
|
||||
.map_err(|e| format_py_error(e, vm))
|
||||
}
|
||||
|
||||
fn exec_with_cached_vm(
|
||||
copr: &Coprocessor,
|
||||
rb: &Option<RecordBatch>,
|
||||
args: Vec<PyVector>,
|
||||
params: &HashMap<String, String>,
|
||||
vm: &Arc<Interpreter>,
|
||||
) -> Result<RecordBatch> {
|
||||
vm.enter(|vm| -> Result<RecordBatch> {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
// set arguments with given name and values
|
||||
let scope = vm.new_scope_with_builtins();
|
||||
if let Some(rb) = rb {
|
||||
set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
|
||||
}
|
||||
|
||||
if let Some(arg_names) = &copr.deco_args.arg_names {
|
||||
assert_eq!(arg_names.len(), args.len());
|
||||
set_items_in_scope(&scope, vm, arg_names, args)?;
|
||||
}
|
||||
|
||||
if let Some(engine) = &copr.query_engine {
|
||||
let query_engine = PyQueryEngine {
|
||||
inner: engine.clone(),
|
||||
};
|
||||
|
||||
// put a object named with query of class PyQueryEngine in scope
|
||||
PyQueryEngine::make_class(&vm.ctx);
|
||||
set_query_engine_in_scope(&scope, vm, query_engine)?;
|
||||
}
|
||||
|
||||
if let Some(kwarg) = &copr.kwarg {
|
||||
let dict = PyDict::new_ref(&vm.ctx);
|
||||
for (k, v) in params {
|
||||
dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
}
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(kwarg, vm.new_pyobj(dict), vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
}
|
||||
|
||||
// It's safe to unwrap code_object, it's already compiled before.
|
||||
let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
|
||||
let ret = vm
|
||||
.run_code_obj(code_obj, scope)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
// 5. get returns as either a PyVector or a PyTuple, and naming schema them according to `returns`
|
||||
let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
|
||||
let mut cols = try_into_columns(&ret, vm, col_len)?;
|
||||
ensure!(
|
||||
cols.len() == copr.deco_args.ret_names.len(),
|
||||
OtherSnafu {
|
||||
reason: format!(
|
||||
"The number of return Vector is wrong, expect {}, found {}",
|
||||
copr.deco_args.ret_names.len(),
|
||||
cols.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
// if cols and schema's data types is not match, try coerce it to given type(if annotated)(if error occur, return relevant error with question mark)
|
||||
copr.check_and_cast_type(&mut cols)?;
|
||||
|
||||
// 6. return a assembled DfRecordBatch
|
||||
let schema = copr.gen_schema(&cols)?;
|
||||
RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
|
||||
})
|
||||
}
|
||||
|
||||
/// init interpreter with type PyVector and Module: greptime
|
||||
pub(crate) fn init_interpreter() -> Arc<Interpreter> {
|
||||
INTERPRETER.with(|i| {
|
||||
i.borrow_mut()
|
||||
.get_or_insert_with(|| {
|
||||
// we limit stdlib imports for safety reason, i.e `fcntl` is not allowed here
|
||||
let native_module_allow_list = HashSet::from([
|
||||
"array", "cmath", "gc", "hashlib", "_json", "_random", "math",
|
||||
]);
|
||||
// TODO(discord9): edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]`
|
||||
// so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868
|
||||
let mut settings = vm::Settings::default();
|
||||
// disable SIG_INT handler so our own binary can take ctrl_c handler
|
||||
settings.no_sig_int = true;
|
||||
let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| {
|
||||
// not using full stdlib to prevent security issue, instead filter out a few simple util module
|
||||
vm.add_native_modules(
|
||||
rustpython_stdlib::get_module_inits()
|
||||
.filter(|(k, _)| native_module_allow_list.contains(k.as_ref())),
|
||||
);
|
||||
|
||||
// We are freezing the stdlib to include the standard library inside the binary.
|
||||
// so according to this issue:
|
||||
// https://github.com/RustPython/RustPython/issues/4292
|
||||
// add this line for stdlib, so rustpython can found stdlib's python part in bytecode format
|
||||
vm.add_frozen(rustpython_pylib::frozen_stdlib());
|
||||
// add our own custom datatype and module
|
||||
PyVector::make_class(&vm.ctx);
|
||||
vm.add_native_module("greptime", Box::new(greptime_builtin::make_module));
|
||||
|
||||
data_frame::PyDataFrame::make_class(&vm.ctx);
|
||||
data_frame::PyExpr::make_class(&vm.ctx);
|
||||
vm.add_native_module("data_frame", Box::new(data_frame::make_module));
|
||||
}));
|
||||
info!("Initialized Python interpreter.");
|
||||
interpreter
|
||||
})
|
||||
.clone()
|
||||
})
|
||||
}
|
||||
|
||||
/// using a parsed `Coprocessor` struct as input to execute python code
|
||||
pub(crate) fn exec_parsed(
|
||||
pub fn exec_parsed(
|
||||
copr: &Coprocessor,
|
||||
rb: &Option<RecordBatch>,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<RecordBatch> {
|
||||
// 3. get args from `rb`, and cast them into PyVector
|
||||
let args: Vec<PyVector> = if let Some(rb) = rb {
|
||||
let args = select_from_rb(rb, copr.deco_args.arg_names.as_ref().unwrap_or(&vec![]))?;
|
||||
check_args_anno_real_type(&args, copr, rb)?;
|
||||
args
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let interpreter = init_interpreter();
|
||||
// 4. then set args in scope and compile then run `CodeObject` which already append a new `Call` node
|
||||
exec_with_cached_vm(copr, rb, args, params, &interpreter)
|
||||
match copr.backend {
|
||||
BackendType::RustPython => rspy_exec_parsed(copr, rb, params),
|
||||
BackendType::CPython => {
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
{
|
||||
pyo3_exec_parsed(copr, rb, params)
|
||||
}
|
||||
#[cfg(not(feature = "pyo3_backend"))]
|
||||
OtherSnafu {
|
||||
reason: "`pyo3` feature is disabled, therefore can't run scripts in cpython"
|
||||
.to_string(),
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// execute script just like [`exec_coprocessor`] do,
|
||||
@@ -601,7 +454,7 @@ pub fn exec_copr_print(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::python::coprocessor::parse::parse_and_compile_copr;
|
||||
use crate::python::ffi_types::copr::parse::parse_and_compile_copr;
|
||||
|
||||
#[test]
|
||||
fn test_parse_copr() {
|
||||
@@ -612,7 +465,7 @@ def add(a, b):
|
||||
@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
|
||||
def test(a, b, c, **params):
|
||||
import greptime as g
|
||||
return add(a, b) / g.sqrt(c)
|
||||
return ( a + b ) / g.sqrt(c)
|
||||
"#;
|
||||
|
||||
let copr = parse_and_compile_copr(script, None).unwrap();
|
||||
@@ -21,8 +21,8 @@ use rustpython_parser::{ast, parser};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::fail_parse_error;
|
||||
use crate::python::coprocessor::parse::{ret_parse_error, DecoratorArgs};
|
||||
use crate::python::error::{PyCompileSnafu, PyParseSnafu, Result};
|
||||
use crate::python::ffi_types::copr::parse::{ret_parse_error, DecoratorArgs};
|
||||
|
||||
fn create_located<T>(node: T, loc: Location) -> Located<T> {
|
||||
Located::new(loc, loc, node)
|
||||
@@ -23,17 +23,17 @@ use rustpython_parser::{ast, parser};
|
||||
use serde::Deserialize;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::python::coprocessor::{compile, AnnotationInfo, Coprocessor};
|
||||
use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};
|
||||
|
||||
use crate::python::ffi_types::copr::{compile, AnnotationInfo, BackendType, Coprocessor};
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DecoratorArgs {
|
||||
pub arg_names: Option<Vec<String>>,
|
||||
pub ret_names: Vec<String>,
|
||||
pub sql: Option<String>,
|
||||
// maybe add a URL for connecting or what?
|
||||
// also predicate for timed triggered or conditional triggered?
|
||||
#[cfg_attr(test, serde(skip))]
|
||||
pub backend: BackendType, // maybe add a URL for connecting or what?
|
||||
// also predicate for timed triggered or conditional triggered?
|
||||
}
|
||||
|
||||
/// Return a CoprParseSnafu for you to chain fail() to return correct err Result type
|
||||
@@ -259,8 +259,8 @@ fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
|
||||
/// parse a list of keyword and return args and returns list from keywords
|
||||
fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
|
||||
// more keys maybe add to this list of `avail_key`(like `sql` for querying and maybe config for connecting to database?), for better extension using a `HashSet` in here
|
||||
let avail_key = HashSet::from(["args", "returns", "sql"]);
|
||||
let opt_keys = HashSet::from(["sql", "args"]);
|
||||
let avail_key = HashSet::from(["args", "returns", "sql", "backend"]);
|
||||
let opt_keys = HashSet::from(["sql", "args", "backend"]);
|
||||
let mut visited_key = HashSet::new();
|
||||
let len_min = avail_key.len() - opt_keys.len();
|
||||
let len_max = avail_key.len();
|
||||
@@ -298,6 +298,23 @@ fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
|
||||
"args" => ret_args.arg_names = Some(pylist_to_vec(&kw.node.value)?),
|
||||
"returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
|
||||
"sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
|
||||
"backend" => {
|
||||
let value = py_str_to_string(&kw.node.value)?;
|
||||
match value.as_str() {
|
||||
// although this is default option to use RustPython for interpreter
|
||||
// but that could change in the future
|
||||
"rspy" => ret_args.backend = BackendType::RustPython,
|
||||
"pyo3" => ret_args.backend = BackendType::CPython,
|
||||
_ => {
|
||||
return fail_parse_error!(
|
||||
format!(
|
||||
"backend type can only be of `rspy` and `pyo3`, found {value}"
|
||||
),
|
||||
Some(kw.location),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@@ -497,6 +514,8 @@ pub fn parse_and_compile_copr(
|
||||
loc: None
|
||||
}
|
||||
);
|
||||
|
||||
let backend = deco_args.backend.clone();
|
||||
let kwarg = fn_args.kwarg.as_ref().map(|arg| arg.node.arg.clone());
|
||||
coprocessor = Some(Coprocessor {
|
||||
code_obj: Some(compile::compile_script(name, &deco_args, &kwarg, script)?),
|
||||
@@ -507,6 +526,7 @@ pub fn parse_and_compile_copr(
|
||||
kwarg,
|
||||
script: script.to_string(),
|
||||
query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
|
||||
backend,
|
||||
});
|
||||
}
|
||||
} else if matches!(
|
||||
145
src/script/src/python/ffi_types/pair_tests.rs
Normal file
145
src/script/src/python/ffi_types/pair_tests.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod sample_testcases;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use datafusion::arrow::array::Float64Array;
|
||||
use datafusion::arrow::compute;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::vectors::VectorRef;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use pyo3::{types::PyDict, Python};
|
||||
use rustpython_compiler::Mode;
|
||||
|
||||
use crate::python::ffi_types::pair_tests::sample_testcases::sample_test_case;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use crate::python::pyo3::{init_cpython_interpreter, vector_impl::into_pyo3_cell};
|
||||
use crate::python::rspython::init_interpreter;
|
||||
|
||||
/// generate testcases that should be tested in paired both in RustPython and CPython
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestCase {
|
||||
input: HashMap<String, VectorRef>,
|
||||
script: String,
|
||||
expect: VectorRef,
|
||||
}
|
||||
|
||||
/// Runs every sample case through RustPython and, if the `pyo3_backend`
/// feature is enabled, through CPython too.
#[test]
fn pyo3_rspy_test_in_pairs() {
    sample_test_case().into_iter().for_each(|case| {
        // The RustPython run consumes its own copy so the case stays
        // available for the optional CPython run.
        eval_rspy(case.clone());
        #[cfg(feature = "pyo3_backend")]
        eval_pyo3(case);
    });
}
|
||||
|
||||
fn check_equal(v0: VectorRef, v1: VectorRef) -> bool {
|
||||
let v0 = v0.to_arrow_array();
|
||||
let v1 = v1.to_arrow_array();
|
||||
fn is_float(ty: &ArrowDataType) -> bool {
|
||||
use ArrowDataType::*;
|
||||
matches!(ty, Float16 | Float32 | Float64)
|
||||
}
|
||||
if is_float(v0.data_type()) || is_float(v1.data_type()) {
|
||||
let v0 = compute::cast(&v0, &ArrowDataType::Float64).unwrap();
|
||||
let v0 = v0.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
|
||||
let v1 = compute::cast(&v1, &ArrowDataType::Float64).unwrap();
|
||||
let v1 = v1.as_any().downcast_ref::<Float64Array>().unwrap();
|
||||
|
||||
let res = compute::subtract(v0, v1).unwrap();
|
||||
res.iter().all(|v| {
|
||||
if let Some(v) = v {
|
||||
v.abs() <= 2.0 * f32::EPSILON as f64
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
} else {
|
||||
*v0 == *v1
|
||||
}
|
||||
}
|
||||
|
||||
/// will panic if something is wrong, used in tests only
|
||||
fn eval_rspy(case: TestCase) {
|
||||
let interpreter = init_interpreter();
|
||||
interpreter.enter(|vm| {
|
||||
let scope = vm.new_scope_with_builtins();
|
||||
for (k, v) in case.input {
|
||||
let v = PyVector::from(v);
|
||||
scope.locals.set_item(&k, vm.new_pyobj(v), vm).unwrap();
|
||||
}
|
||||
let code_obj = vm
|
||||
.compile(&case.script, Mode::BlockExpr, "<embedded>".to_owned())
|
||||
.map_err(|err| {
|
||||
dbg!(&err);
|
||||
vm.new_syntax_error(&err)
|
||||
})
|
||||
.unwrap();
|
||||
let result_vector = vm
|
||||
.run_code_obj(code_obj, scope)
|
||||
.map_err(|e| {
|
||||
dbg!(&e);
|
||||
dbg!(&case.script);
|
||||
e
|
||||
})
|
||||
.unwrap()
|
||||
.downcast::<PyVector>()
|
||||
.unwrap();
|
||||
|
||||
if !check_equal(result_vector.as_vector_ref(), case.expect.clone()) {
|
||||
panic!(
|
||||
"(RsPy)code:{}\nReal: {:?}!=Expected: {:?}",
|
||||
case.script, result_vector, case.expect
|
||||
)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Evaluates `case` with the CPython interpreter (through PyO3) and panics on
/// any failure or mismatch. Intended for tests only.
///
/// The inputs are placed in a locals dict, the script is run against it, and
/// the variable named `ret` is read back and compared with `case.expect`
/// (see `check_equal`).
#[cfg(feature = "pyo3_backend")]
fn eval_pyo3(case: TestCase) {
    init_cpython_interpreter();
    Python::with_gil(|py| {
        let locals = PyDict::new(py);
        for (name, vector) in case.input {
            let cell = into_pyo3_cell(py, PyVector::from(vector)).unwrap();
            locals.set_item(name, cell).unwrap();
        }

        py.run(&case.script, None, Some(locals)).unwrap();

        let extracted = locals.get_item("ret").unwrap().extract::<PyVector>();
        let res_vec = extracted
            .map_err(|e| {
                dbg!(&case.script);
                e
            })
            .unwrap();

        assert!(
            check_equal(res_vec.as_vector_ref(), case.expect.clone()),
            "(PyO3)code:{}\nReal: {:?}!=Expected: {:?}",
            case.script,
            res_vec,
            case.expect
        );
    })
}
|
||||
491
src/script/src/python/ffi_types/pair_tests/sample_testcases.rs
Normal file
491
src/script/src/python/ffi_types/pair_tests/sample_testcases.rs
Normal file
@@ -0,0 +1,491 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::f64::consts;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{BooleanVector, Float64Vector, Int64Vector, VectorRef};
|
||||
|
||||
use crate::python::ffi_types::pair_tests::TestCase;
|
||||
|
||||
/// Builds a `VectorRef` of the concrete vector type `$ty` from the values in
/// `$slice`, using that type's `from_slice` constructor.
macro_rules! vector {
    ($ty: ident, $slice: expr) => {{
        let boxed: VectorRef = Arc::new($ty::from_slice($slice));
        boxed
    }};
}
|
||||
|
||||
/// Builds a `HashMap<String, _>` from `"key": value` pairs written in a
/// RON-like literal syntax. A trailing comma is accepted and an empty body
/// yields an empty map.
macro_rules! ronish {
    ($($key: literal : $expr: expr),*$(,)?) => {{
        // `mut` is unused when the macro is invoked with no entries.
        #[allow(unused_mut)]
        let mut map = HashMap::new();
        $(
            map.insert($key.to_string(), $expr);
        )*
        map
    }};
}
|
||||
|
||||
/// Generates the test cases with a function rather than a `.ron` configuration file: a function is more flexible, and because this code is only compiled under `#[cfg(test)]` there is no binary bloat to worry about.
|
||||
#[allow(clippy::approx_constant)]
|
||||
pub(super) fn sample_test_case() -> Vec<TestCase> {
|
||||
vec![
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0f64, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = a+3.0
|
||||
ret = ret * 2.0
|
||||
ret = ret / 2.0
|
||||
ret = ret - 3.0
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1.0f64, 2.0, 3.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [3.0f64, 2.0, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = a+b
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [4.0f64, 4.0, 4.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [3.0f64, 2.0, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = a-b
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [-2.0f64, 0.0, 2.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [3.0f64, 2.0, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = a*b
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [3.0f64, 4.0, 3.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0f64, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [3.0f64, 2.0, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = a/b
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1. / 3., 1.0, 3.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0f64, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = sqrt(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[1.0f64, std::f64::consts::SQRT_2, 1.7320508075688772,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = sin(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[0.8414709848078965, 0.9092974268256817, 0.1411200080598672,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = cos(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[0.5403023058681398, -0.4161468365471424, -0.9899924966004454,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = tan(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[1.5574077246549023, -2.185039863261519, -0.1425465430742778,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = asin(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[0.3046926540153975, 0.5235987755982989, 1.5707963267948966,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = acos(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[1.2661036727794992, 1.0471975511965979, 0.0,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = atan(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(
|
||||
Float64Vector,
|
||||
[0.2914567944778671, 0.4636476090008061, 0.8329812666744317,]
|
||||
),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = floor(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, 0.0, 1.0,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = ceil(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1.0, 1.0, 2.0,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = round(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, 1.0, 1.0,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0.3, 0.5, 1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = trunc(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, 0.0, 1.0,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [-0.3, 0.5, -1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = abs(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.3, 0.5, 1.1,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [-0.3, 0.5, -1.1])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = signum(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [-1.0, 1.0, -1.0,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [0., 1.0, 2.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = exp(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1.0, consts::E, 7.38905609893065,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = ln(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, consts::LN_2, 1.0986122886681098,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = log2(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, 1.0, 1.584962500721156,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = log10(values)
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [0.0, consts::LOG10_2, 0.47712125471966244,]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = 0.0<=random(3)<=1.0
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(BooleanVector, &[true, true, true]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Int64Vector, [1, 2, 2, 3])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([approx_distinct(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Int64Vector, [3]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Int64Vector, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([approx_percentile_cont(values, 0.6)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Int64Vector, [6]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector(array_agg(values))
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1.0, 2.0, 3.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([avg(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [2.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [1.0, 0.0, -1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([correlation(a, b)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [-1.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Int64Vector, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([count(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Int64Vector, [10]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [1.0, 0.0, -1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([covariance(a, b)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [-1.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
"b": vector!(Float64Vector, [1.0, 0.0, -1.0])
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([covariance_pop(a, b)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [-0.6666666666666666]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([max(a)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [3.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"a": vector!(Float64Vector, [1.0, 2.0, 3.0]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([min(a)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [1.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([stddev(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [3.0276503540974917]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([stddev_pop(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [2.8722813232690143]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([sum(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [55.0]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([variance(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [9.166666666666666]),
|
||||
},
|
||||
TestCase {
|
||||
input: ronish! {
|
||||
"values": vector!(Float64Vector, [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]),
|
||||
},
|
||||
script: r#"
|
||||
from greptime import *
|
||||
ret = vector([variance_pop(values)])
|
||||
ret"#
|
||||
.to_string(),
|
||||
expect: vector!(Float64Vector, [8.25]),
|
||||
},
|
||||
// TODO(discord9): GrepTime's Own UDF
|
||||
]
|
||||
}
|
||||
81
src/script/src/python/ffi_types/utils.rs
Normal file
81
src/script/src/python/ffi_types/utils.rs
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// to avoid putting too many `#[cfg]` attributes for the pyo3 feature flag
|
||||
#![allow(unused)]
|
||||
use datafusion::arrow::compute;
|
||||
use datafusion::arrow::datatypes::Field;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::ColumnarValue;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
|
||||
pub fn new_item_field(data_type: ArrowDataType) -> Field {
|
||||
Field::new("item", data_type, false)
|
||||
}
|
||||
|
||||
/// Generate friendly error message when the type of the input `values` is different than `ty`
|
||||
/// # Example
|
||||
/// `values` is [Int64(1), Float64(1.0), Int64(2)] and `ty` is Int64
|
||||
/// then the error message will be: " Float64 at 2th location\n"
|
||||
pub(crate) fn collect_diff_types_string(values: &[ScalarValue], ty: &ArrowDataType) -> String {
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, val)| {
|
||||
if val.get_datatype() != *ty {
|
||||
Some((idx, val.get_datatype()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(|(idx, ty)| format!(" {:?} at {}th location\n", ty, idx + 1))
|
||||
.reduce(|mut acc, item| {
|
||||
acc.push_str(&item);
|
||||
acc
|
||||
})
|
||||
.unwrap_or_else(|| "Nothing".to_string())
|
||||
}
|
||||
|
||||
/// Because most of the datafusion's UDF only support f32/64, so cast all to f64 to use datafusion's UDF
|
||||
pub fn all_to_f64(col: ColumnarValue) -> Result<ColumnarValue, String> {
|
||||
match col {
|
||||
ColumnarValue::Array(arr) => {
|
||||
let res = compute::cast(&arr, &ArrowDataType::Float64).map_err(|err| {
|
||||
format!(
|
||||
"Arrow Type Cast Fail(from {:#?} to {:#?}): {err:#?}",
|
||||
arr.data_type(),
|
||||
ArrowDataType::Float64
|
||||
)
|
||||
})?;
|
||||
Ok(ColumnarValue::Array(res))
|
||||
}
|
||||
ColumnarValue::Scalar(val) => {
|
||||
let val_in_f64 = match val {
|
||||
ScalarValue::Float64(Some(v)) => v,
|
||||
ScalarValue::Int64(Some(v)) => v as f64,
|
||||
ScalarValue::Boolean(Some(v)) => v as i64 as f64,
|
||||
_ => {
|
||||
return Err(format!(
|
||||
"Can't cast type {:#?} to {:#?}",
|
||||
val.get_datatype(),
|
||||
ArrowDataType::Float64
|
||||
))
|
||||
}
|
||||
};
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
|
||||
val_in_f64,
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
539
src/script/src/python/ffi_types/vector.rs
Normal file
539
src/script/src/python/ffi_types/vector.rs
Normal file
@@ -0,0 +1,539 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::arrow::array::{
|
||||
Array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array,
|
||||
};
|
||||
use datatypes::arrow::compute;
|
||||
use datatypes::arrow::compute::kernels::{arithmetic, comparison};
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::arrow::error::Result as ArrowResult;
|
||||
use datatypes::data_type::DataType;
|
||||
use datatypes::prelude::Value;
|
||||
use datatypes::value::{self, OrderedFloat};
|
||||
use datatypes::vectors::{Helper, NullVector, VectorRef};
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use pyo3::pyclass as pyo3class;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyNone, PyStr};
|
||||
use rustpython_vm::sliceable::{SaturatedSlice, SequenceIndex, SequenceIndexOp};
|
||||
use rustpython_vm::types::PyComparisonOp;
|
||||
use rustpython_vm::{
|
||||
pyclass as rspyclass, AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
|
||||
VirtualMachine,
|
||||
};
|
||||
|
||||
use crate::python::utils::is_instance;
|
||||
|
||||
/// The Main FFI type `PyVector` that is used both in RustPython and PyO3
|
||||
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "vector"))]
|
||||
#[rspyclass(module = false, name = "vector")]
|
||||
#[repr(transparent)]
|
||||
#[derive(PyPayload, Debug, Clone)]
|
||||
pub struct PyVector {
|
||||
pub(crate) vector: VectorRef,
|
||||
}
|
||||
|
||||
impl From<VectorRef> for PyVector {
|
||||
fn from(vector: VectorRef) -> Self {
|
||||
Self { vector }
|
||||
}
|
||||
}
|
||||
|
||||
fn to_type_error(vm: &'_ VirtualMachine) -> impl FnOnce(String) -> PyBaseExceptionRef + '_ {
|
||||
|msg: String| vm.new_type_error(msg)
|
||||
}
|
||||
|
||||
/// Performs `val - arr`.
|
||||
pub(crate) fn arrow_rsub(arr: &dyn Array, val: &dyn Array) -> Result<ArrayRef, String> {
|
||||
arithmetic::subtract_dyn(val, arr).map_err(|e| format!("rsub error: {e}"))
|
||||
}
|
||||
|
||||
/// Performs `val / arr`
|
||||
pub(crate) fn arrow_rtruediv(arr: &dyn Array, val: &dyn Array) -> Result<ArrayRef, String> {
|
||||
arithmetic::divide_dyn(val, arr).map_err(|e| format!("rtruediv error: {e}"))
|
||||
}
|
||||
|
||||
/// Performs `val / arr`, but cast to i64.
|
||||
pub(crate) fn arrow_rfloordiv(arr: &dyn Array, val: &dyn Array) -> Result<ArrayRef, String> {
|
||||
let array =
|
||||
arithmetic::divide_dyn(val, arr).map_err(|e| format!("rfloordiv divide error: {e}"))?;
|
||||
compute::cast(&array, &ArrowDataType::Int64).map_err(|e| format!("rfloordiv cast error: {e}"))
|
||||
}
|
||||
|
||||
pub(crate) fn wrap_result<F>(f: F) -> impl Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> ArrowResult<ArrayRef>,
|
||||
{
|
||||
move |left, right| f(left, right).map_err(|e| format!("arithmetic error {e}"))
|
||||
}
|
||||
|
||||
/// Adapts an arrow kernel that yields a `BooleanArray` so that it returns a
/// type-erased `ArrayRef` and reports its error as a `String` prefixed with
/// `logical op error: `.
#[cfg(feature = "pyo3_backend")]
pub(crate) fn wrap_bool_result<F>(
    op_bool_arr: F,
) -> impl Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>
where
    F: Fn(&dyn Array, &dyn Array) -> ArrowResult<BooleanArray>,
{
    move |a: &dyn Array, b: &dyn Array| -> Result<ArrayRef, String> {
        match op_bool_arr(a, b) {
            Ok(mask) => Ok(Arc::new(mask)),
            Err(e) => Err(format!("logical op error: {e}")),
        }
    }
}
|
||||
|
||||
#[inline]
|
||||
fn is_float(datatype: &ArrowDataType) -> bool {
|
||||
matches!(
|
||||
datatype,
|
||||
ArrowDataType::Float16 | ArrowDataType::Float32 | ArrowDataType::Float64
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_signed(datatype: &ArrowDataType) -> bool {
|
||||
matches!(
|
||||
datatype,
|
||||
ArrowDataType::Int8 | ArrowDataType::Int16 | ArrowDataType::Int32 | ArrowDataType::Int64
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_unsigned(datatype: &ArrowDataType) -> bool {
|
||||
matches!(
|
||||
datatype,
|
||||
ArrowDataType::UInt8
|
||||
| ArrowDataType::UInt16
|
||||
| ArrowDataType::UInt32
|
||||
| ArrowDataType::UInt64
|
||||
)
|
||||
}
|
||||
|
||||
fn cast(array: ArrayRef, target_type: &ArrowDataType) -> Result<ArrayRef, String> {
|
||||
compute::cast(&array, target_type).map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
/// Identity `AsRef`, so that a `PyVector` can be passed where
/// `impl AsRef<PyVector>` is expected.
impl AsRef<PyVector> for PyVector {
    fn as_ref(&self) -> &Self {
        self
    }
}
|
||||
|
||||
impl PyVector {
|
||||
pub(crate) fn vector_and(left: &Self, right: &Self) -> Result<Self, String> {
|
||||
let left = left.to_arrow_array();
|
||||
let right = right.to_arrow_array();
|
||||
let left = left
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
|
||||
let right = right
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| format!("Can't cast {right:#?} as a Boolean Array"))?;
|
||||
let res =
|
||||
Arc::new(compute::kernels::boolean::and(left, right).map_err(|err| err.to_string())?)
|
||||
as ArrayRef;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
pub(crate) fn vector_or(left: &Self, right: &Self) -> Result<Self, String> {
|
||||
let left = left.to_arrow_array();
|
||||
let right = right.to_arrow_array();
|
||||
let left = left
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
|
||||
let right = right
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| format!("Can't cast {right:#?} as a Boolean Array"))?;
|
||||
let res =
|
||||
Arc::new(compute::kernels::boolean::or(left, right).map_err(|err| err.to_string())?)
|
||||
as ArrayRef;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
pub(crate) fn vector_invert(left: &Self) -> Result<Self, String> {
|
||||
let zelf = left.to_arrow_array();
|
||||
let zelf = zelf
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
|
||||
let res = Arc::new(compute::kernels::boolean::not(zelf).map_err(|err| err.to_string())?)
|
||||
as ArrayRef;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
/// create a ref to inner vector
|
||||
#[inline]
|
||||
pub fn as_vector_ref(&self) -> VectorRef {
|
||||
self.vector.clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn to_arrow_array(&self) -> ArrayRef {
|
||||
self.vector.to_arrow_array()
|
||||
}
|
||||
|
||||
pub(crate) fn scalar_arith_op<F>(
|
||||
&self,
|
||||
right: value::Value,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>,
|
||||
{
|
||||
let right_type = right.data_type().as_arrow_type();
|
||||
// assuming they are all 64 bit type if possible
|
||||
let left = self.to_arrow_array();
|
||||
|
||||
let left_type = left.data_type();
|
||||
let right_type = &right_type;
|
||||
let target_type = Self::coerce_types(left_type, right_type, &target_type);
|
||||
let left = cast(left, &target_type)?;
|
||||
let left_len = left.len();
|
||||
|
||||
// Convert `right` to an array of `target_type`.
|
||||
let right: Box<dyn Array> = if is_float(&target_type) {
|
||||
match right {
|
||||
value::Value::Int64(v) => Box::new(Float64Array::from_value(v as f64, left_len)),
|
||||
value::Value::UInt64(v) => Box::new(Float64Array::from_value(v as f64, left_len)),
|
||||
value::Value::Float64(v) => {
|
||||
Box::new(Float64Array::from_value(f64::from(v), left_len))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else if is_signed(&target_type) {
|
||||
match right {
|
||||
value::Value::Int64(v) => Box::new(Int64Array::from_value(v, left_len)),
|
||||
value::Value::UInt64(v) => Box::new(Int64Array::from_value(v as i64, left_len)),
|
||||
value::Value::Float64(v) => Box::new(Int64Array::from_value(v.0 as i64, left_len)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else if is_unsigned(&target_type) {
|
||||
match right {
|
||||
value::Value::Int64(v) => Box::new(UInt64Array::from_value(v as u64, left_len)),
|
||||
value::Value::UInt64(v) => Box::new(UInt64Array::from_value(v, left_len)),
|
||||
value::Value::Float64(v) => Box::new(UInt64Array::from_value(v.0 as u64, left_len)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
return Err(format!(
|
||||
"Can't cast source operand of type {:?} into target type of {:?}",
|
||||
right_type, &target_type
|
||||
));
|
||||
};
|
||||
|
||||
let result = op(left.as_ref(), right.as_ref())?;
|
||||
|
||||
Ok(Helper::try_into_vector(result.clone())
|
||||
.map_err(|e| format!("Can't cast result into vector, result: {result:?}, err: {e:?}",))?
|
||||
.into())
|
||||
}
|
||||
|
||||
pub(crate) fn rspy_scalar_arith_op<F>(
|
||||
&self,
|
||||
other: PyObjectRef,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyVector>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>,
|
||||
{
|
||||
// the right operand only support PyInt or PyFloat,
|
||||
let right = {
|
||||
if is_instance::<PyInt>(&other, vm) {
|
||||
other.try_into_value::<i64>(vm).map(value::Value::Int64)?
|
||||
} else if is_instance::<PyFloat>(&other, vm) {
|
||||
other
|
||||
.try_into_value::<f64>(vm)
|
||||
.map(|v| (value::Value::Float64(OrderedFloat(v))))?
|
||||
} else {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"Can't cast right operand into Scalar of Int or Float, actual: {}",
|
||||
other.class().name()
|
||||
)));
|
||||
}
|
||||
};
|
||||
self.scalar_arith_op(right, target_type, op)
|
||||
.map_err(to_type_error(vm))
|
||||
}
|
||||
|
||||
/// Returns the type that should be used for the result of an arithmetic operation
|
||||
fn coerce_types(
|
||||
left_type: &ArrowDataType,
|
||||
right_type: &ArrowDataType,
|
||||
target_type: &Option<ArrowDataType>,
|
||||
) -> ArrowDataType {
|
||||
// TODO(discord9): found better way to cast between signed and unsigned types
|
||||
target_type.clone().unwrap_or_else(|| {
|
||||
if is_signed(left_type) && is_signed(right_type) {
|
||||
ArrowDataType::Int64
|
||||
} else if is_unsigned(left_type) && is_unsigned(right_type) {
|
||||
ArrowDataType::UInt64
|
||||
} else {
|
||||
ArrowDataType::Float64
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn vector_arith_op<F>(
|
||||
&self,
|
||||
right: &Self,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> Result<PyVector, String>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>,
|
||||
{
|
||||
let left = self.to_arrow_array();
|
||||
let right = right.to_arrow_array();
|
||||
|
||||
let left_type = &left.data_type();
|
||||
let right_type = &right.data_type();
|
||||
|
||||
let target_type = Self::coerce_types(left_type, right_type, &target_type);
|
||||
|
||||
let left = cast(left, &target_type)?;
|
||||
let right = cast(right, &target_type)?;
|
||||
|
||||
let result = op(left.as_ref(), right.as_ref())?;
|
||||
|
||||
Ok(Helper::try_into_vector(result.clone())
|
||||
.map_err(|e| format!("Can't cast result into vector, result: {result:?}, err: {e:?}",))?
|
||||
.into())
|
||||
}
|
||||
|
||||
pub(crate) fn rspy_vector_arith_op<F>(
|
||||
&self,
|
||||
other: PyObjectRef,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyVector>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String>,
|
||||
{
|
||||
let right = other.downcast_ref::<PyVector>().ok_or_else(|| {
|
||||
vm.new_type_error(format!(
|
||||
"Can't cast right operand into PyVector, actual type: {}",
|
||||
other.class().name()
|
||||
))
|
||||
})?;
|
||||
self.vector_arith_op(right, target_type, op)
|
||||
.map_err(to_type_error(vm))
|
||||
}
|
||||
|
||||
pub(crate) fn _getitem(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
if let Some(seq) = needle.payload::<PyVector>() {
|
||||
let mask = seq.to_arrow_array();
|
||||
let mask = mask
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.ok_or_else(|| {
|
||||
vm.new_type_error(format!("Can't cast {seq:#?} as a Boolean Array"))
|
||||
})?;
|
||||
let res = compute::filter(self.to_arrow_array().as_ref(), mask)
|
||||
.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
|
||||
vm.new_type_error(format!("Can't cast result into vector, err: {e:?}"))
|
||||
})?;
|
||||
Ok(Self::from(ret).into_pyobject(vm))
|
||||
} else {
|
||||
match SequenceIndex::try_from_borrowed_object(vm, needle, "vector")? {
|
||||
SequenceIndex::Int(i) => self.getitem_by_index(i, vm),
|
||||
SequenceIndex::Slice(slice) => self.getitem_by_slice(&slice, vm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn getitem_by_index(&self, i: isize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
// in the newest version of rustpython_vm, wrapped_at for isize is replace by wrap_index(i, len)
|
||||
let i = i.wrapped_at(self.len()).ok_or_else(|| {
|
||||
vm.new_index_error(format!("PyVector index {i} out of range {}", self.len()))
|
||||
})?;
|
||||
val_to_pyobj(self.as_vector_ref().get(i), vm)
|
||||
}
|
||||
|
||||
/// Return a `PyVector` in `PyObjectRef`
|
||||
fn getitem_by_slice(
|
||||
&self,
|
||||
slice: &SaturatedSlice,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyObjectRef> {
|
||||
// adjust_indices so negative number is transform to usize
|
||||
let (mut range, step, slice_len) = slice.adjust_indices(self.len());
|
||||
let vector = self.as_vector_ref();
|
||||
|
||||
let mut buf = vector.data_type().create_mutable_vector(slice_len);
|
||||
if slice_len == 0 {
|
||||
let v: PyVector = buf.to_vector().into();
|
||||
Ok(v.into_pyobject(vm))
|
||||
} else if step == 1 {
|
||||
let v: PyVector = vector.slice(range.next().unwrap_or(0), slice_len).into();
|
||||
Ok(v.into_pyobject(vm))
|
||||
} else if step.is_negative() {
|
||||
// Negative step require special treatment
|
||||
for i in range.rev().step_by(step.unsigned_abs()) {
|
||||
// Safety: This mutable vector is created from the vector's data type.
|
||||
buf.push_value_ref(vector.get_ref(i));
|
||||
}
|
||||
let v: PyVector = buf.to_vector().into();
|
||||
Ok(v.into_pyobject(vm))
|
||||
} else {
|
||||
for i in range.step_by(step.unsigned_abs()) {
|
||||
// Safety: This mutable vector is created from the vector's data type.
|
||||
buf.push_value_ref(vector.get_ref(i));
|
||||
}
|
||||
let v: PyVector = buf.to_vector().into();
|
||||
Ok(v.into_pyobject(vm))
|
||||
}
|
||||
}
|
||||
|
||||
/// Unsupported
|
||||
/// TODO(discord9): make it work
|
||||
#[allow(unused)]
|
||||
fn setitem_by_index(
|
||||
zelf: PyRef<Self>,
|
||||
i: isize,
|
||||
value: PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<()> {
|
||||
Err(vm.new_not_implemented_error("setitem_by_index unimplemented".to_string()))
|
||||
}
|
||||
|
||||
/// rich compare, return a boolean array, accept type are vec and vec and vec and number
|
||||
pub(crate) fn richcompare(
|
||||
&self,
|
||||
other: PyObjectRef,
|
||||
op: PyComparisonOp,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
let scalar_op = get_arrow_scalar_op(op);
|
||||
self.rspy_scalar_arith_op(other, None, scalar_op, vm)
|
||||
} else {
|
||||
let arr_op = get_arrow_op(op);
|
||||
self.rspy_vector_arith_op(other, None, wrap_result(arr_op), vm)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.as_vector_ref().len()
|
||||
}
|
||||
}
|
||||
|
||||
/// get corresponding arrow op function according to given PyComaprsionOp
|
||||
fn get_arrow_op(op: PyComparisonOp) -> impl Fn(&dyn Array, &dyn Array) -> ArrowResult<ArrayRef> {
|
||||
let op_bool_arr = match op {
|
||||
PyComparisonOp::Eq => comparison::eq_dyn,
|
||||
PyComparisonOp::Ne => comparison::neq_dyn,
|
||||
PyComparisonOp::Gt => comparison::gt_dyn,
|
||||
PyComparisonOp::Lt => comparison::lt_dyn,
|
||||
PyComparisonOp::Ge => comparison::gt_eq_dyn,
|
||||
PyComparisonOp::Le => comparison::lt_eq_dyn,
|
||||
};
|
||||
|
||||
move |a: &dyn Array, b: &dyn Array| -> ArrowResult<ArrayRef> {
|
||||
let array = op_bool_arr(a, b)?;
|
||||
Ok(Arc::new(array))
|
||||
}
|
||||
}
|
||||
|
||||
/// get corresponding arrow scalar op function according to given PyComaprsionOp
|
||||
fn get_arrow_scalar_op(
|
||||
op: PyComparisonOp,
|
||||
) -> impl Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String> {
|
||||
let op_bool_arr = match op {
|
||||
PyComparisonOp::Eq => comparison::eq_dyn,
|
||||
PyComparisonOp::Ne => comparison::neq_dyn,
|
||||
PyComparisonOp::Gt => comparison::gt_dyn,
|
||||
PyComparisonOp::Lt => comparison::lt_dyn,
|
||||
PyComparisonOp::Ge => comparison::gt_eq_dyn,
|
||||
PyComparisonOp::Le => comparison::lt_eq_dyn,
|
||||
};
|
||||
|
||||
move |a: &dyn Array, b: &dyn Array| -> Result<ArrayRef, String> {
|
||||
let array = op_bool_arr(a, b).map_err(|e| format!("scalar op error: {e}"))?;
|
||||
Ok(Arc::new(array))
|
||||
}
|
||||
}
|
||||
|
||||
/// if this pyobj can be cast to a scalar value(i.e Null/Int/Float/Bool)
|
||||
#[inline]
|
||||
pub(crate) fn rspy_is_pyobj_scalar(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
|
||||
is_instance::<PyNone>(obj, vm)
|
||||
|| is_instance::<PyInt>(obj, vm)
|
||||
|| is_instance::<PyFloat>(obj, vm)
|
||||
|| is_instance::<PyBool>(obj, vm)
|
||||
|| is_instance::<PyStr>(obj, vm)
|
||||
}
|
||||
|
||||
/// convert a DataType `Value` into a `PyObjectRef`
|
||||
pub fn val_to_pyobj(val: value::Value, vm: &VirtualMachine) -> PyResult {
|
||||
Ok(match val {
|
||||
// This comes from:https://github.com/RustPython/RustPython/blob/8ab4e770351d451cfdff5dc2bf8cce8df76a60ab/vm/src/builtins/singletons.rs#L37
|
||||
// None in Python is universally singleton so
|
||||
// use `vm.ctx.new_int` and `new_***` is more idiomatic for there are certain optimize can be used in this way(small int pool etc.)
|
||||
value::Value::Null => vm.ctx.none(),
|
||||
value::Value::Boolean(v) => vm.ctx.new_bool(v).into(),
|
||||
value::Value::UInt8(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::UInt16(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::UInt32(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::UInt64(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::Int8(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::Int16(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::Int32(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::Int64(v) => vm.ctx.new_int(v).into(),
|
||||
value::Value::Float32(v) => vm.ctx.new_float(v.0 as f64).into(),
|
||||
value::Value::Float64(v) => vm.ctx.new_float(v.0).into(),
|
||||
value::Value::String(s) => vm.ctx.new_str(s.as_utf8()).into(),
|
||||
// is this copy necessary?
|
||||
value::Value::Binary(b) => vm.ctx.new_bytes(b.deref().to_vec()).into(),
|
||||
// TODO(dennis):is `Date` and `DateTime` supported yet? For now just ad hoc into PyInt, but it's better to be cast into python Date, DateTime objects etc..
|
||||
value::Value::Date(v) => vm.ctx.new_int(v.val()).into(),
|
||||
value::Value::DateTime(v) => vm.ctx.new_int(v.val()).into(),
|
||||
// FIXME(dennis): lose the timestamp unit here
|
||||
Value::Timestamp(v) => vm.ctx.new_int(v.value()).into(),
|
||||
value::Value::List(list) => {
|
||||
let list = list.items().as_ref();
|
||||
match list {
|
||||
Some(list) => {
|
||||
let list: Vec<_> = list
|
||||
.iter()
|
||||
.map(|v| val_to_pyobj(v.clone(), vm))
|
||||
.collect::<Result<_, _>>()?;
|
||||
vm.ctx.new_list(list).into()
|
||||
}
|
||||
None => vm.ctx.new_list(Vec::new()).into(),
|
||||
}
|
||||
}
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => return Err(vm.new_type_error(format!("Convert from {val:?} is not supported yet"))),
|
||||
})
|
||||
}
|
||||
|
||||
impl Default for PyVector {
|
||||
fn default() -> PyVector {
|
||||
PyVector {
|
||||
vector: Arc::new(NullVector::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
161
src/script/src/python/ffi_types/vector/tests.rs
Normal file
161
src/script/src/python/ffi_types/vector/tests.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Here are pair-tests for vector types in both rustpython and cpython
|
||||
//!
|
||||
|
||||
// TODO: sample record batch
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::scalars::ScalarVector;
|
||||
use datatypes::vectors::{BooleanVector, Float64Vector, VectorRef};
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use pyo3::{types::PyDict, Python};
|
||||
use rustpython_compiler::Mode;
|
||||
use rustpython_vm::class::PyClassImpl;
|
||||
use rustpython_vm::{vm, AsObject};
|
||||
|
||||
use crate::python::ffi_types::PyVector;
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
use crate::python::pyo3::{init_cpython_interpreter, vector_impl::into_pyo3_cell};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct TestCase {
|
||||
eval: String,
|
||||
result: VectorRef,
|
||||
}
|
||||
|
||||
/// Evaluates every test case with the RustPython backend and, when the
/// `pyo3_backend` feature is enabled, with the CPython backend as well.
#[test]
fn test_eval_py_vector_in_pairs() {
    let mut locals = HashMap::new();
    for (name, vector) in sample_py_vector() {
        locals.insert(name, PyVector::from(vector));
    }

    for testcase in get_test_cases() {
        #[cfg(feature = "pyo3_backend")]
        eval_pyo3(testcase.clone(), locals.clone());
        eval_rspy(testcase, locals.clone())
    }
}
|
||||
|
||||
fn sample_py_vector() -> HashMap<String, VectorRef> {
|
||||
let b1 = Arc::new(BooleanVector::from_slice(&[false, false, true, true])) as VectorRef;
|
||||
let b2 = Arc::new(BooleanVector::from_slice(&[false, true, false, true])) as VectorRef;
|
||||
let f1 = Arc::new(Float64Vector::from_slice(&[0.0f64, 2.0, 10.0, 42.0])) as VectorRef;
|
||||
let f2 = Arc::new(Float64Vector::from_slice(&[-0.1f64, -42.0, 2., 7.0])) as VectorRef;
|
||||
HashMap::from([
|
||||
("b1".to_owned(), b1),
|
||||
("b2".to_owned(), b2),
|
||||
("f1".to_owned(), f1),
|
||||
("f2".to_owned(), f2),
|
||||
])
|
||||
}
|
||||
|
||||
/// testcases for test basic operations
|
||||
/// this is more powerful&flexible than standalone testcases configure file
|
||||
fn get_test_cases() -> Vec<TestCase> {
|
||||
let testcases = [
|
||||
TestCase {
|
||||
eval: "b1 & b2".to_string(),
|
||||
result: Arc::new(BooleanVector::from_slice(&[false, false, false, true])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "b1 | b2".to_string(),
|
||||
result: Arc::new(BooleanVector::from_slice(&[false, true, true, true])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "~b1".to_string(),
|
||||
result: Arc::new(BooleanVector::from_slice(&[true, true, false, false])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "f1+f2".to_string(),
|
||||
result: Arc::new(Float64Vector::from_slice(&[-0.1f64, -40.0, 12., 49.0])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "f1-f2".to_string(),
|
||||
result: Arc::new(Float64Vector::from_slice(&[0.1f64, 44.0, 8., 35.0])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "f1*f2".to_string(),
|
||||
result: Arc::new(Float64Vector::from_slice(&[
|
||||
-0.0f64,
|
||||
-84.0,
|
||||
20.,
|
||||
42.0 * 7.0,
|
||||
])) as VectorRef,
|
||||
},
|
||||
TestCase {
|
||||
eval: "f1/f2".to_string(),
|
||||
result: Arc::new(Float64Vector::from_slice(&[
|
||||
0.0 / -0.1f64,
|
||||
2. / -42.,
|
||||
10. / 2.,
|
||||
42. / 7.,
|
||||
])) as VectorRef,
|
||||
},
|
||||
];
|
||||
Vec::from(testcases)
|
||||
}
|
||||
/// Evaluates `testcase` in CPython (through PyO3) with `locals` bound as local
/// variables, and panics when the result differs from the expected vector.
#[cfg(feature = "pyo3_backend")]
fn eval_pyo3(testcase: TestCase, locals: HashMap<String, PyVector>) {
    init_cpython_interpreter();
    Python::with_gil(|py| {
        // Expose every sample vector to the expression as a local variable.
        let locals_dict = PyDict::new(py);
        for (name, vector) in locals {
            let cell = into_pyo3_cell(py, vector).unwrap();
            locals_dict.set_item(name, cell).unwrap();
        }

        let res = py.eval(&testcase.eval, None, Some(locals_dict)).unwrap();
        let res_vec = res.extract::<PyVector>().unwrap();

        let raw_arr = res_vec.as_vector_ref().to_arrow_array();
        let expect_arr = testcase.result.to_arrow_array();
        assert!(*raw_arr == *expect_arr, "{raw_arr:?}!={expect_arr:?}");
    })
}
|
||||
|
||||
fn eval_rspy(testcase: TestCase, locals: HashMap<String, PyVector>) {
|
||||
vm::Interpreter::without_stdlib(Default::default()).enter(|vm| {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
let scope = vm.new_scope_with_builtins();
|
||||
locals.into_iter().for_each(|(k, v)| {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(&k, vm.new_pyobj(v), vm)
|
||||
.unwrap();
|
||||
});
|
||||
let code_obj = vm
|
||||
.compile(&testcase.eval, Mode::Eval, "<embedded>".to_owned())
|
||||
.map_err(|err| vm.new_syntax_error(&err))
|
||||
.unwrap();
|
||||
let obj = vm.run_code_obj(code_obj, scope).unwrap();
|
||||
let v = obj.downcast::<PyVector>().unwrap();
|
||||
let result_arr = v.to_arrow_array();
|
||||
let expect_arr = testcase.result.to_arrow_array();
|
||||
if *result_arr != *expect_arr {
|
||||
panic!("{result_arr:?}!={expect_arr:?}")
|
||||
}
|
||||
});
|
||||
}
|
||||
24
src/script/src/python/pyo3.rs
Normal file
24
src/script/src/python/pyo3.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod builtins;
|
||||
pub(crate) mod copr_impl;
|
||||
mod dataframe_impl;
|
||||
mod utils;
|
||||
pub(crate) mod vector_impl;
|
||||
|
||||
#[cfg(feature = "pyo3_backend")]
|
||||
pub(crate) use copr_impl::pyo3_exec_parsed;
|
||||
#[cfg(test)]
|
||||
pub(crate) use utils::init_cpython_interpreter;
|
||||
335
src/script/src/python/pyo3/builtins.rs
Normal file
335
src/script/src/python/pyo3/builtins.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
|
||||
use datafusion::arrow::array::{ArrayRef, NullArray};
|
||||
use datafusion::physical_plan::expressions;
|
||||
use datafusion_expr::ColumnarValue;
|
||||
use datafusion_physical_expr::{math_expressions, AggregateExpr};
|
||||
use datatypes::vectors::VectorRef;
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::PyList;
|
||||
|
||||
use super::utils::scalar_value_to_py_any;
|
||||
use crate::python::ffi_types::utils::all_to_f64;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::pyo3::dataframe_impl::col;
|
||||
use crate::python::pyo3::utils::{
|
||||
columnar_value_to_py_any, try_into_columnar_value, val_to_py_any,
|
||||
};
|
||||
|
||||
/// Registers every listed `#[pyfunction]` on the given module, propagating the
/// first error with `?`.
macro_rules! batch_import {
    ($module:ident, [$($func:ident),*]) => {
        $(
            $module.add_function(wrap_pyfunction!($func, $module)?)?;
        )*
    };
}
|
||||
|
||||
#[pymodule]
|
||||
#[pyo3(name = "greptime")]
|
||||
pub(crate) fn greptime_builtins(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
batch_import!(
|
||||
m,
|
||||
[
|
||||
col,
|
||||
vector,
|
||||
pow,
|
||||
clip,
|
||||
diff,
|
||||
mean,
|
||||
polyval,
|
||||
argmax,
|
||||
argmin,
|
||||
percentile,
|
||||
scipy_stats_norm_cdf,
|
||||
scipy_stats_norm_pdf,
|
||||
sqrt,
|
||||
sin,
|
||||
cos,
|
||||
tan,
|
||||
asin,
|
||||
acos,
|
||||
atan,
|
||||
floor,
|
||||
ceil,
|
||||
round,
|
||||
trunc,
|
||||
abs,
|
||||
signum,
|
||||
exp,
|
||||
ln,
|
||||
log2,
|
||||
log10,
|
||||
random,
|
||||
approx_distinct,
|
||||
median,
|
||||
approx_percentile_cont,
|
||||
array_agg,
|
||||
avg,
|
||||
correlation,
|
||||
count,
|
||||
covariance,
|
||||
covariance_pop,
|
||||
max,
|
||||
min,
|
||||
stddev,
|
||||
stddev_pop,
|
||||
sum,
|
||||
variance,
|
||||
variance_pop
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn eval_func(py: Python<'_>, name: &str, v: &[&PyVector]) -> PyResult<PyVector> {
|
||||
py.allow_threads(|| {
|
||||
let v: Vec<VectorRef> = v.iter().map(|v| v.as_vector_ref()).collect();
|
||||
let func: Option<FunctionRef> = FUNCTION_REGISTRY.get_function(name);
|
||||
let res = match func {
|
||||
Some(f) => f.eval(Default::default(), &v),
|
||||
None => return Err(PyValueError::new_err(format!("Can't find function {name}"))),
|
||||
};
|
||||
match res {
|
||||
Ok(v) => Ok(v.into()),
|
||||
Err(err) => Err(PyValueError::new_err(format!(
|
||||
"Fail to evaluate the function,: {err}"
|
||||
))),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn eval_aggr_func(py: Python<'_>, name: &str, args: &[&PyVector]) -> PyResult<PyObject> {
|
||||
let res = py.allow_threads(|| {
|
||||
let v: Vec<VectorRef> = args.iter().map(|v| v.as_vector_ref()).collect();
|
||||
let func = FUNCTION_REGISTRY.get_aggr_function(name);
|
||||
let f = match func {
|
||||
Some(f) => f.create().creator(),
|
||||
None => return Err(PyValueError::new_err(format!("Can't find function {name}"))),
|
||||
};
|
||||
let types: Vec<_> = v.iter().map(|v| v.data_type()).collect();
|
||||
let acc = f(&types);
|
||||
let mut acc = match acc {
|
||||
Ok(acc) => acc,
|
||||
Err(err) => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Failed to create accumulator: {err}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
match acc.update_batch(&v) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Failed to update batch: {err}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
let res = match acc.evaluate() {
|
||||
Ok(r) => r,
|
||||
Err(err) => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Failed to evaluate accumulator: {err}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
Ok(res)
|
||||
})?;
|
||||
val_to_py_any(py, res)
|
||||
}
|
||||
|
||||
/// evaluate Aggregate Expr using its backing accumulator
|
||||
/// TODO(discord9): cast to f64 before use/Provide cast to f64 function?
|
||||
fn eval_aggr_expr<T: AggregateExpr>(
|
||||
py: Python<'_>,
|
||||
aggr: T,
|
||||
values: &[ArrayRef],
|
||||
) -> PyResult<PyObject> {
|
||||
let res = py.allow_threads(|| -> PyResult<_> {
|
||||
// acquire the accumulator, where the actual implement of aggregate expr layers
|
||||
let mut acc = aggr
|
||||
.create_accumulator()
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
|
||||
acc.update_batch(values)
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
|
||||
let res = acc
|
||||
.evaluate()
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
|
||||
Ok(res)
|
||||
})?;
|
||||
scalar_value_to_py_any(py, res)
|
||||
}
|
||||
|
||||
/// Generates one `#[pyfunction]` per listed name, each forwarding to the DataFusion
/// function of the same name in `math_expressions`.
///
/// The argument is converted to a columnar value and passed through `all_to_f64`
/// before the call; failures are reported as a Python `ValueError`.
macro_rules! bind_call_unary_math_function {
    ($($func:ident),*) => {
        $(
            #[pyfunction]
            fn $func(py: Python<'_>, val: PyObject) -> PyResult<PyObject> {
                let value = try_into_columnar_value(py, val)?;
                let args = [all_to_f64(value).map_err(PyValueError::new_err)?];
                let res = math_expressions::$func(&args)
                    .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
                columnar_value_to_py_any(py, res)
            }
        )*
    };
}
|
||||
|
||||
/// Generates a `#[pyfunction]` that forwards its vector arguments to a function
/// looked up by name in the function registry.
///
/// * `[args...]` binds a scalar function through `eval_func` and returns a `PyVector`.
/// * `AGG[args...]` binds an aggregate function through `eval_aggr_func` and returns
///   a Python object.
macro_rules! simple_vector_fn {
    ($fn_name:ident, $registry_name:tt, [$($param:ident),*]) => {
        #[pyfunction]
        fn $fn_name(py: Python<'_>, $($param: &PyVector),*) -> PyResult<PyVector> {
            eval_func(py, $registry_name, &[$($param),*])
        }
    };
    ($fn_name:ident, $registry_name:tt, AGG[$($param:ident),*]) => {
        #[pyfunction]
        fn $fn_name(py: Python<'_>, $($param: &PyVector),*) -> PyResult<PyObject> {
            eval_aggr_func(py, $registry_name, &[$($param),*])
        }
    };
}
|
||||
|
||||
#[pyfunction]
|
||||
fn vector(iterable: &PyList) -> PyResult<PyVector> {
|
||||
PyVector::py_new(iterable)
|
||||
}
|
||||
|
||||
// TODO(discord9): More Aggr functions& allow threads
|
||||
simple_vector_fn!(pow, "pow", [v0, v1]);
|
||||
simple_vector_fn!(clip, "clip", [v0, v1, v2]);
|
||||
simple_vector_fn!(diff, "diff", AGG[v0]);
|
||||
simple_vector_fn!(mean, "mean", AGG[v0]);
|
||||
simple_vector_fn!(polyval, "polyval", AGG[v0, v1]);
|
||||
simple_vector_fn!(argmax, "argmax", AGG[v0]);
|
||||
simple_vector_fn!(argmin, "argmin", AGG[v0]);
|
||||
simple_vector_fn!(percentile, "percentile", AGG[v0, v1]);
|
||||
simple_vector_fn!(scipy_stats_norm_cdf, "scipystatsnormcdf", AGG[v0, v1]);
|
||||
simple_vector_fn!(scipy_stats_norm_pdf, "scipystatsnormpdf", AGG[v0, v1]);
|
||||
|
||||
/*
|
||||
This macro basically expands to the code below:
|
||||
```rust
|
||||
fn sqrt(py: Python<'_>, val: PyObject) -> PyResult<PyObject> {
|
||||
let args = &[all_to_f64(try_into_columnar_value(py, val)?).map_err(PyValueError::new_err)?];
|
||||
let res = math_expressions::sqrt(args).map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
|
||||
columnar_value_to_py_any(py, res)
|
||||
}
|
||||
```
|
||||
*/
|
||||
bind_call_unary_math_function!(
|
||||
sqrt, sin, cos, tan, asin, acos, atan, floor, ceil, round, trunc, abs, signum, exp, ln, log2,
|
||||
log10
|
||||
);
|
||||
|
||||
/// return a random vector range from 0 to 1 and length of len
|
||||
#[pyfunction]
|
||||
fn random(py: Python<'_>, len: usize) -> PyResult<PyObject> {
|
||||
// This is in a proc macro so using full path to avoid strange things
|
||||
// more info at: https://doc.rust-lang.org/reference/procedural-macros.html#procedural-macro-hygiene
|
||||
let arg = NullArray::new(len);
|
||||
let args = &[ColumnarValue::Array(std::sync::Arc::new(arg) as _)];
|
||||
let res =
|
||||
math_expressions::random(args).map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
|
||||
|
||||
columnar_value_to_py_any(py, res)
|
||||
}
|
||||
|
||||
/// Binds an aggregate expression from `datafusion_physical_expr::expressions` as a
/// `#[pyfunction]`.
///
/// Arguments: the Python function name, the expression type, the vector parameters,
/// the parameter whose arrow data type is passed to the expression constructor, and
/// one `name => index` pair per placeholder column.
macro_rules! bind_aggr_expr {
    ($func_name:ident, $aggr_func:ident, [$($arg:ident),*], $arg_ty:ident, $($expr:ident => $idx:literal),*) => {
        #[pyfunction]
        fn $func_name(py: Python<'_>, $($arg: &PyVector),*) -> PyResult<PyObject> {
            // The columns are placeholders: only the expression's inner
            // accumulator is used, so the exprs themselves are irrelevant.
            let aggr = expressions::$aggr_func::new(
                $(
                    Arc::new(expressions::Column::new(stringify!($expr), $idx)) as _,
                )*
                stringify!($aggr_func),
                $arg_ty.to_arrow_array().data_type().to_owned(),
            );
            eval_aggr_expr(py, aggr, &[$($arg.to_arrow_array()),*])
        }
    };
}
|
||||
/*
|
||||
`bind_aggr_expr!(approx_distinct, ApproxDistinct,[v0], v0, expr0=>0);`
|
||||
expand into:
|
||||
```
|
||||
fn approx_distinct(py: Python<'_>, v0: &PyVector) -> PyResult<PyObject> {
|
||||
return eval_aggr_expr(
|
||||
py,
|
||||
expressions::ApproxDistinct::new(
|
||||
Arc::new(expressions::Column::new("expr0", 0)) as _,
|
||||
"ApproxDistinct",
|
||||
v0.to_arrow_array().data_type().to_owned(),
|
||||
),
|
||||
&[v0.to_arrow_array()],
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
*/
|
||||
bind_aggr_expr!(approx_distinct, ApproxDistinct,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(median, Median,[v0], v0, expr0=>0);
|
||||
|
||||
#[pyfunction]
|
||||
fn approx_percentile_cont(py: Python<'_>, values: &PyVector, percent: f64) -> PyResult<PyObject> {
|
||||
let percent = expressions::Literal::new(datafusion_common::ScalarValue::Float64(Some(percent)));
|
||||
return eval_aggr_expr(
|
||||
py,
|
||||
expressions::ApproxPercentileCont::new(
|
||||
vec![
|
||||
Arc::new(expressions::Column::new("expr0", 0)) as _,
|
||||
Arc::new(percent) as _,
|
||||
],
|
||||
"ApproxPercentileCont",
|
||||
(values.to_arrow_array().data_type()).to_owned(),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?,
|
||||
&[values.to_arrow_array()],
|
||||
);
|
||||
}
|
||||
|
||||
bind_aggr_expr!(array_agg, ArrayAgg,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(avg, Avg,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(correlation, Correlation,[v0, v1], v0, expr0=>0, expr1=>1);
|
||||
|
||||
bind_aggr_expr!(count, Count,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(covariance, Covariance,[v0, v1], v0, expr0=>0, expr1=>1);
|
||||
|
||||
bind_aggr_expr!(covariance_pop, CovariancePop,[v0, v1], v0, expr0=>0, expr1=>1);
|
||||
|
||||
bind_aggr_expr!(max, Max,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(min, Min,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(stddev, Stddev,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(stddev_pop, StddevPop,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(sum, Sum,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(variance, Variance,[v0], v0, expr0=>0);
|
||||
|
||||
bind_aggr_expr!(variance_pop, VariancePop,[v0], v0, expr0=>0);
|
||||
254
src/script/src/python/pyo3/copr_impl.rs
Normal file
254
src/script/src/python/pyo3/copr_impl.rs
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use common_recordbatch::RecordBatch;
|
||||
use datatypes::vectors::{Helper, VectorRef};
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::types::{PyDict, PyList, PyModule, PyTuple};
|
||||
use pyo3::{pymethods, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
|
||||
use snafu::{ensure, Backtrace, GenerateImplicitData, ResultExt};
|
||||
|
||||
use crate::python::error::{self, NewRecordBatchSnafu, OtherSnafu, Result};
|
||||
use crate::python::ffi_types::copr::PyQueryEngine;
|
||||
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
|
||||
use crate::python::pyo3::builtins::greptime_builtins;
|
||||
use crate::python::pyo3::dataframe_impl::PyDataFrame;
|
||||
use crate::python::pyo3::utils::{init_cpython_interpreter, pyo3_obj_try_to_typed_val};
|
||||
|
||||
#[pymethods]
|
||||
impl PyQueryEngine {
|
||||
#[pyo3(name = "sql")]
|
||||
pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult<PyObject> {
|
||||
let query = self.get_ref();
|
||||
let res = self
|
||||
.query_with_new_thread(query, s)
|
||||
.map_err(PyValueError::new_err)?;
|
||||
match res {
|
||||
crate::python::ffi_types::copr::Either::Rb(rbs) => {
|
||||
let mut top_vec = Vec::with_capacity(rbs.iter().count());
|
||||
for rb in rbs.iter() {
|
||||
let mut vec_of_vec = Vec::with_capacity(rb.columns().len());
|
||||
for v in rb.columns() {
|
||||
let v = PyVector::from(v.to_owned());
|
||||
let v = PyCell::new(py, v)?;
|
||||
vec_of_vec.push(v.to_object(py));
|
||||
}
|
||||
let vec_of_vec = PyList::new(py, vec_of_vec);
|
||||
top_vec.push(vec_of_vec);
|
||||
}
|
||||
let top_vec = PyList::new(py, top_vec);
|
||||
Ok(top_vec.to_object(py))
|
||||
}
|
||||
crate::python::ffi_types::copr::Either::AffectedRows(count) => Ok(count.to_object(py)),
|
||||
}
|
||||
}
|
||||
// TODO: put this into greptime module
|
||||
}
|
||||
/// Execute a `Coprocessor` with given `RecordBatch`
|
||||
pub(crate) fn pyo3_exec_parsed(
|
||||
copr: &Coprocessor,
|
||||
rb: &Option<RecordBatch>,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<RecordBatch> {
|
||||
let arg_names = if let Some(names) = &copr.deco_args.arg_names {
|
||||
names
|
||||
} else {
|
||||
return OtherSnafu {
|
||||
reason: "PyO3 Backend doesn't support params yet".to_string(),
|
||||
}
|
||||
.fail();
|
||||
};
|
||||
let args: Vec<PyVector> = if let Some(rb) = rb {
|
||||
let args = select_from_rb(rb, arg_names)?;
|
||||
check_args_anno_real_type(&args, copr, rb)?;
|
||||
args
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
// Just in case cpython is not inited
|
||||
init_cpython_interpreter();
|
||||
Python::with_gil(|py| -> Result<_> {
|
||||
let mut cols = (|| -> PyResult<_> {
|
||||
let dummy_decorator = "
|
||||
# A dummy decorator, actual implementation is in Rust code
|
||||
def copr(*dummy, **kwdummy):
|
||||
def inner(func):
|
||||
return func
|
||||
return inner
|
||||
coprocessor = copr
|
||||
";
|
||||
let gen_call = format!("\n_return_from_coprocessor = {}(*_args_for_coprocessor, **_kwargs_for_coprocessor)", copr.name);
|
||||
let script = format!("{}{}{}", dummy_decorator, copr.script, gen_call);
|
||||
|
||||
let args = args
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|v| PyCell::new(py, v))
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
let args = PyTuple::new(py, args);
|
||||
|
||||
let kwargs = PyDict::new(py);
|
||||
if let Some(_copr_kwargs) = &copr.kwarg {
|
||||
for (k, v) in params {
|
||||
kwargs.set_item(k, v)?;
|
||||
}
|
||||
}
|
||||
|
||||
let py_main = PyModule::import(py, "__main__")?;
|
||||
let globals = py_main.dict();
|
||||
|
||||
let locals = PyDict::new(py);
|
||||
let greptime = PyModule::new(py, "greptime")?;
|
||||
greptime_builtins(py, greptime)?;
|
||||
locals.set_item("greptime", greptime)?;
|
||||
|
||||
if let Some(engine) = &copr.query_engine {
|
||||
let query_engine = PyQueryEngine::from_weakref(engine.clone());
|
||||
let query_engine = PyCell::new(py, query_engine)?;
|
||||
globals.set_item("query", query_engine)?;
|
||||
}
|
||||
|
||||
// TODO(discord9): find out why `dataframe` is not in scope
|
||||
if let Some(rb) = rb {
|
||||
let dataframe = PyDataFrame::from_record_batch(rb.df_record_batch())
|
||||
.map_err(|err|
|
||||
PyValueError::new_err(
|
||||
format!("Can't create dataframe from record batch: {}", err
|
||||
)
|
||||
)
|
||||
)?;
|
||||
let dataframe = PyCell::new(py, dataframe)?;
|
||||
globals.set_item("dataframe", dataframe)?;
|
||||
}
|
||||
|
||||
|
||||
locals.set_item("_args_for_coprocessor", args)?;
|
||||
locals.set_item("_kwargs_for_coprocessor", kwargs)?;
|
||||
|
||||
// TODO(discord9): find a better way to set `dataframe` and `query` in scope/ or set it into module(latter might be impossible and not idomatic even in python)
|
||||
// set `dataframe` and `query` in scope/ or set it into module
|
||||
// could generate a call in python code and use Python::run to run it, just like in RustPython
|
||||
// Expect either: a PyVector Or a List/Tuple of PyVector
|
||||
py.run(&script, Some(globals), Some(locals))?;
|
||||
let result = locals.get_item("_return_from_coprocessor").ok_or(PyValueError::new_err("Can't find return value of coprocessor function"))?;
|
||||
|
||||
let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
|
||||
py_any_to_vec(result, col_len)
|
||||
})()
|
||||
.map_err(|err| error::Error::PyRuntime {
|
||||
msg: err.to_string(),
|
||||
backtrace: Backtrace::generate(),
|
||||
})?;
|
||||
ensure!(
|
||||
cols.len() == copr.deco_args.ret_names.len(),
|
||||
OtherSnafu {
|
||||
reason: format!(
|
||||
"The number of return Vector is wrong, expect {}, found {}",
|
||||
copr.deco_args.ret_names.len(),
|
||||
cols.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
copr.check_and_cast_type(&mut cols)?;
|
||||
let schema = copr.gen_schema(&cols)?;
|
||||
RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
|
||||
})
|
||||
}
|
||||
|
||||
/// Cast return of py script result to `Vec<VectorRef>`,
|
||||
/// constants will be broadcast to length of `col_len`
|
||||
fn py_any_to_vec(obj: &PyAny, col_len: usize) -> PyResult<Vec<VectorRef>> {
|
||||
if let Ok(tuple) = obj.downcast::<PyTuple>() {
|
||||
let len = tuple.len();
|
||||
let v = (0..len)
|
||||
.map(|idx| tuple.get_item(idx))
|
||||
.map(|elem| {
|
||||
elem.map(|any| py_obj_broadcast_to_vec(any, col_len))
|
||||
.and_then(|v| v)
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
Ok(v)
|
||||
} else {
|
||||
let ret = py_obj_broadcast_to_vec(obj, col_len)?;
|
||||
Ok(vec![ret])
|
||||
}
|
||||
}
|
||||
|
||||
fn py_obj_broadcast_to_vec(obj: &PyAny, col_len: usize) -> PyResult<VectorRef> {
|
||||
if let Ok(v) = obj.extract::<PyVector>() {
|
||||
Ok(v.as_vector_ref())
|
||||
} else {
|
||||
let val = pyo3_obj_try_to_typed_val(obj, None)?;
|
||||
let handler = |e: datatypes::Error| PyValueError::new_err(e.to_string());
|
||||
let v = Helper::try_from_scalar_value(
|
||||
val.try_to_scalar_value(&val.data_type()).map_err(handler)?,
|
||||
col_len,
|
||||
)
|
||||
.map_err(handler)?;
|
||||
Ok(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod copr_test {
    use std::collections::HashMap;
    use std::sync::Arc;

    use common_recordbatch::RecordBatch;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float32Vector, Float64Vector, VectorRef};

    use crate::python::ffi_types::copr::{exec_parsed, parse, BackendType};

    /// Smoke test for the PyO3 backend: a coprocessor that receives two
    /// columns plus keyword arguments, touches the injected `dataframe`, and
    /// returns a boolean vector must parse with the CPython backend and run
    /// without error.
    #[test]
    fn simple_test_pyo3_copr() {
        let python_source = r#"
@copr(args=["cpu", "mem"], returns=["ref"], backend="pyo3")
def a(cpu, mem, **kwargs):
    import greptime as gt
    from greptime import vector, log2, sum, pow, col
    for k, v in kwargs.items():
        print("%s == %s" % (k, v))
    print(dataframe.select([col("cpu")]).collect())
    return (0.5 < cpu) & ~( cpu >= 0.75)
    "#;
        let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.3]);
        let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]);
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("cpu", ConcreteDataType::float32_datatype(), false),
            ColumnSchema::new("mem", ConcreteDataType::float64_datatype(), false),
        ]));
        let rb = RecordBatch::new(
            schema,
            [
                Arc::new(cpu_array) as VectorRef,
                Arc::new(mem_array) as VectorRef,
            ],
        )
        .unwrap();
        let copr = parse::parse_and_compile_copr(python_source, None).unwrap();
        assert_eq!(copr.backend, BackendType::CPython);
        let ret = exec_parsed(
            &copr,
            &Some(rb),
            &HashMap::from([("a".to_string(), "1".to_string())]),
        );
        // Print the error only when the run fails, instead of an unconditional `dbg!`.
        assert!(ret.is_ok(), "coprocessor execution failed: {ret:?}");
    }
}
|
||||
275
src/script/src/python/pyo3/dataframe_impl.rs
Normal file
275
src/script/src/python/pyo3/dataframe_impl.rs
Normal file
@@ -0,0 +1,275 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_recordbatch::DfRecordBatch;
|
||||
use datafusion::dataframe::DataFrame as DfDataFrame;
|
||||
use datafusion_expr::Expr as DfExpr;
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::PyList;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::python::error::DataFusionSnafu;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::utils::block_on_async;
|
||||
type PyExprRef = Py<PyExpr>;
|
||||
#[pyclass]
|
||||
pub(crate) struct PyDataFrame {
|
||||
inner: DfDataFrame,
|
||||
}
|
||||
|
||||
impl From<DfDataFrame> for PyDataFrame {
|
||||
fn from(inner: DfDataFrame) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl PyDataFrame {
|
||||
pub(crate) fn from_record_batch(rb: &DfRecordBatch) -> crate::python::error::Result<Self> {
|
||||
let ctx = datafusion::execution::context::SessionContext::new();
|
||||
let inner = ctx.read_batch(rb.clone()).context(DataFusionSnafu)?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyDataFrame {
|
||||
fn select_columns(&self, columns: Vec<String>) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.select_columns(&columns.iter().map(AsRef::as_ref).collect::<Vec<&str>>())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn select(&self, py: Python<'_>, expr_list: Vec<PyExprRef>) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.select(
|
||||
expr_list
|
||||
.iter()
|
||||
.map(|e| e.borrow(py).inner.clone())
|
||||
.collect(),
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn filter(&self, predicate: &PyExpr) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.filter(predicate.inner.clone())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn aggregate(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
group_expr: Vec<PyExprRef>,
|
||||
aggr_expr: Vec<PyExprRef>,
|
||||
) -> PyResult<Self> {
|
||||
let ret = self.inner.clone().aggregate(
|
||||
group_expr
|
||||
.iter()
|
||||
.map(|i| i.borrow(py).inner.clone())
|
||||
.collect(),
|
||||
aggr_expr
|
||||
.iter()
|
||||
.map(|i| i.borrow(py).inner.clone())
|
||||
.collect(),
|
||||
);
|
||||
Ok(ret
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn limit(&self, skip: usize, fetch: Option<usize>) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.limit(skip, fetch)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn union(&self, df: &PyDataFrame) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.union(df.inner.clone())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn union_distinct(&self, df: &PyDataFrame) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.union_distinct(df.inner.clone())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn distinct(&self) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.distinct()
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn sort(&self, py: Python<'_>, expr: Vec<PyExprRef>) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.sort(expr.iter().map(|e| e.borrow(py).inner.clone()).collect())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn join(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
right: &PyDataFrame,
|
||||
join_type: String,
|
||||
left_cols: Vec<String>,
|
||||
right_cols: Vec<String>,
|
||||
filter: Option<PyExprRef>,
|
||||
) -> PyResult<Self> {
|
||||
use datafusion::prelude::JoinType;
|
||||
let join_type = match join_type.as_str() {
|
||||
"inner" | "Inner" => JoinType::Inner,
|
||||
"left" | "Left" => JoinType::Left,
|
||||
"right" | "Right" => JoinType::Right,
|
||||
"full" | "Full" => JoinType::Full,
|
||||
"leftSemi" | "LeftSemi" => JoinType::LeftSemi,
|
||||
"rightSemi" | "RightSemi" => JoinType::RightSemi,
|
||||
"leftAnti" | "LeftAnti" => JoinType::LeftAnti,
|
||||
"rightAnti" | "RightAnti" => JoinType::RightAnti,
|
||||
_ => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Unknown join type: {join_type}"
|
||||
)))
|
||||
}
|
||||
};
|
||||
let left_cols: Vec<&str> = left_cols.iter().map(AsRef::as_ref).collect();
|
||||
let right_cols: Vec<&str> = right_cols.iter().map(AsRef::as_ref).collect();
|
||||
let filter = filter.map(|f| f.borrow(py).inner.clone());
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.join(
|
||||
right.inner.clone(),
|
||||
join_type,
|
||||
&left_cols,
|
||||
&right_cols,
|
||||
filter,
|
||||
)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
fn intersect(&self, py: Python<'_>, df: &PyDataFrame) -> PyResult<Self> {
|
||||
py.allow_threads(|| {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.intersect(df.inner.clone())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
})
|
||||
}
|
||||
fn except(&self, df: &PyDataFrame) -> PyResult<Self> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.except(df.inner.clone())
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?
|
||||
.into())
|
||||
}
|
||||
/// collect `DataFrame` results into `List[List[Vector]]`
|
||||
fn collect<'a>(&self, py: Python<'a>) -> PyResult<&'a PyList> {
|
||||
let inner = self.inner.clone();
|
||||
let res = block_on_async(async { inner.collect().await });
|
||||
let res = res
|
||||
.map_err(|e| PyValueError::new_err(format!("{e:?}")))?
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))?;
|
||||
let outer_list: Vec<PyObject> = res
|
||||
.iter()
|
||||
.map(|elem| -> PyResult<_> {
|
||||
let inner_list: Vec<_> = elem
|
||||
.columns()
|
||||
.iter()
|
||||
.map(|arr| -> PyResult<_> {
|
||||
datatypes::vectors::Helper::try_into_vector(arr)
|
||||
.map(PyVector::from)
|
||||
.map(|v| PyCell::new(py, v))
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
.and_then(|x| x)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let inner_list = PyList::new(py, inner_list);
|
||||
Ok(inner_list.into())
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
Ok(PyList::new(py, outer_list))
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub(crate) struct PyExpr {
|
||||
inner: DfExpr,
|
||||
}
|
||||
|
||||
impl From<datafusion_expr::Expr> for PyExpr {
|
||||
fn from(value: DfExpr) -> Self {
|
||||
Self { inner: value }
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub(crate) fn col(name: String) -> PyExpr {
|
||||
let expr: PyExpr = DfExpr::Column(datafusion_common::Column::from_name(name)).into();
|
||||
expr
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyExpr {
|
||||
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<Self> {
|
||||
let op = match op {
|
||||
CompareOp::Lt => DfExpr::lt,
|
||||
CompareOp::Le => DfExpr::lt_eq,
|
||||
CompareOp::Eq => DfExpr::eq,
|
||||
CompareOp::Ne => DfExpr::not_eq,
|
||||
CompareOp::Gt => DfExpr::gt,
|
||||
CompareOp::Ge => DfExpr::gt_eq,
|
||||
};
|
||||
Ok(op(self.inner.clone(), other.inner.clone()).into())
|
||||
}
|
||||
fn alias(&self, name: String) -> PyResult<PyExpr> {
|
||||
Ok(self.inner.clone().alias(name).into())
|
||||
}
|
||||
fn __and__(&self, py: Python<'_>, other: PyExprRef) -> PyResult<PyExpr> {
|
||||
Ok(self
|
||||
.inner
|
||||
.clone()
|
||||
.and(other.borrow(py).inner.clone())
|
||||
.into())
|
||||
}
|
||||
fn __or__(&self, py: Python<'_>, other: PyExprRef) -> PyResult<PyExpr> {
|
||||
Ok(self.inner.clone().or(other.borrow(py).inner.clone()).into())
|
||||
}
|
||||
fn __invert__(&self) -> PyResult<PyExpr> {
|
||||
Ok(self.inner.clone().not().into())
|
||||
}
|
||||
fn sort(&self, asc: bool, nulls_first: bool) -> PyExpr {
|
||||
self.inner.clone().sort(asc, nulls_first).into()
|
||||
}
|
||||
}
|
||||
287
src/script/src/python/pyo3/utils.rs
Normal file
287
src/script/src/python/pyo3/utils.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use common_telemetry::info;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::ColumnarValue;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::value::{OrderedFloat, Value};
|
||||
use datatypes::vectors::Helper;
|
||||
use once_cell::sync::Lazy;
|
||||
use pyo3::exceptions::PyValueError;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyTuple};
|
||||
|
||||
use crate::python::ffi_types::utils::{collect_diff_types_string, new_item_field};
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::pyo3::builtins::greptime_builtins;
|
||||
|
||||
/// prevent race condition of init cpython
|
||||
static START_PYO3: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));
|
||||
|
||||
pub(crate) fn init_cpython_interpreter() {
|
||||
let mut start = START_PYO3.lock().unwrap();
|
||||
if !*start {
|
||||
pyo3::append_to_inittab!(greptime_builtins);
|
||||
pyo3::prepare_freethreaded_python();
|
||||
*start = true;
|
||||
info!("Started CPython Interpreter");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn val_to_py_any(py: Python<'_>, val: Value) -> PyResult<PyObject> {
|
||||
Ok(match val {
|
||||
Value::Null => py.None(),
|
||||
Value::Boolean(val) => val.to_object(py),
|
||||
Value::UInt8(val) => val.to_object(py),
|
||||
Value::UInt16(val) => val.to_object(py),
|
||||
Value::UInt32(val) => val.to_object(py),
|
||||
Value::UInt64(val) => val.to_object(py),
|
||||
Value::Int8(val) => val.to_object(py),
|
||||
Value::Int16(val) => val.to_object(py),
|
||||
Value::Int32(val) => val.to_object(py),
|
||||
Value::Int64(val) => val.to_object(py),
|
||||
Value::Float32(val) => val.0.to_object(py),
|
||||
Value::Float64(val) => val.0.to_object(py),
|
||||
Value::String(val) => val.as_utf8().to_object(py),
|
||||
Value::Binary(val) => val.to_object(py),
|
||||
Value::Date(val) => val.val().to_object(py),
|
||||
Value::DateTime(val) => val.val().to_object(py),
|
||||
Value::Timestamp(val) => val.value().to_object(py),
|
||||
Value::List(val) => {
|
||||
let list = val.items().clone().unwrap_or(Default::default());
|
||||
let list = list
|
||||
.into_iter()
|
||||
.map(|v| val_to_py_any(py, v))
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
list.to_object(py)
|
||||
}
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"Convert from {val:?} is not supported yet"
|
||||
)))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract a Python object as the Rust type that matches a `ConcreteDataType`
/// and wrap it in the `Value` variant of the same name.
///
/// Two forms:
/// * `Variant => rust_ty` wraps the extracted value directly;
/// * `Variant =ord=> rust_ty` wraps it in `OrderedFloat` first (floats).
///
/// The expansion evaluates to a `PyResult<Value>`. The caller must make sure
/// the data type is one of the listed variants: any other variant reaches
/// `unreachable!()` and panics.
macro_rules! to_con_type {
    ($data_type:ident, $py_obj:ident, $($variant:ident => $rust_ty:ty),* $(,)?) => {
        match $data_type {
            $(
                ConcreteDataType::$variant(_) => {
                    $py_obj.extract::<$rust_ty>().map(Value::$variant)
                }
            )*
            _ => unreachable!(),
        }
    };
    ($data_type:ident, $py_obj:ident, $($variant:ident =ord=> $rust_ty:ty),* $(,)?) => {
        match $data_type {
            $(
                ConcreteDataType::$variant(_) => {
                    $py_obj
                        .extract::<$rust_ty>()
                        .map(OrderedFloat)
                        .map(Value::$variant)
                }
            )*
            _ => unreachable!(),
        }
    };
}
|
||||
|
||||
/// to int/float/boolean, if dtype is None, then convert to highest prec type
|
||||
pub(crate) fn pyo3_obj_try_to_typed_val(
|
||||
obj: &PyAny,
|
||||
dtype: Option<ConcreteDataType>,
|
||||
) -> PyResult<Value> {
|
||||
if let Ok(b) = obj.downcast::<PyBool>() {
|
||||
if let Some(ConcreteDataType::Boolean(_)) = dtype {
|
||||
let dtype = ConcreteDataType::boolean_datatype();
|
||||
let ret = to_con_type!(dtype, b,
|
||||
Boolean => bool
|
||||
)?;
|
||||
Ok(ret)
|
||||
} else {
|
||||
Err(PyValueError::new_err(format!(
|
||||
"Can't cast num to {dtype:?}"
|
||||
)))
|
||||
}
|
||||
} else if let Ok(num) = obj.downcast::<PyInt>() {
|
||||
if let Some(dtype) = dtype {
|
||||
if dtype.is_signed() || dtype.is_unsigned() {
|
||||
let ret = to_con_type!(dtype, num,
|
||||
Int8 => i8,
|
||||
Int16 => i16,
|
||||
Int32 => i32,
|
||||
Int64 => i64,
|
||||
UInt8 => u8,
|
||||
UInt16 => u16,
|
||||
UInt32 => u32,
|
||||
UInt64 => u64,
|
||||
)?;
|
||||
Ok(ret)
|
||||
} else {
|
||||
Err(PyValueError::new_err(format!(
|
||||
"Can't cast num to {dtype:?}"
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
num.extract::<i64>().map(Value::Int64)
|
||||
}
|
||||
} else if let Ok(num) = obj.downcast::<PyFloat>() {
|
||||
if let Some(dtype) = dtype {
|
||||
if dtype.is_float() {
|
||||
let ret = to_con_type!(dtype, num,
|
||||
Float32 =ord=> f32,
|
||||
Float64 =ord=> f64,
|
||||
)?;
|
||||
Ok(ret)
|
||||
} else {
|
||||
Err(PyValueError::new_err(format!(
|
||||
"Can't cast num to {dtype:?}"
|
||||
)))
|
||||
}
|
||||
} else {
|
||||
num.extract::<f64>()
|
||||
.map(|v| Value::Float64(OrderedFloat(v)))
|
||||
}
|
||||
} else if let Ok(s) = obj.extract::<String>() {
|
||||
Ok(Value::String(s.into()))
|
||||
} else {
|
||||
Err(PyValueError::new_err(format!(
|
||||
"Can't cast {obj} to {dtype:?}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// cast a columnar value into python object
|
||||
///
|
||||
/// | Rust | Python |
|
||||
/// | ------ | --------------- |
|
||||
/// | Array | PyVector |
|
||||
/// | Scalar | int/float/bool/str |
|
||||
pub fn columnar_value_to_py_any(py: Python<'_>, val: ColumnarValue) -> PyResult<PyObject> {
|
||||
match val {
|
||||
ColumnarValue::Array(arr) => {
|
||||
let v = PyVector::from(
|
||||
Helper::try_into_vector(arr).map_err(|e| PyValueError::new_err(format!("{e}")))?,
|
||||
);
|
||||
Ok(PyCell::new(py, v)?.into())
|
||||
}
|
||||
ColumnarValue::Scalar(scalar) => scalar_value_to_py_any(py, scalar),
|
||||
}
|
||||
}
|
||||
|
||||
/// turn a ScalarValue into a Python Object, currently support
|
||||
pub fn scalar_value_to_py_any(py: Python<'_>, val: ScalarValue) -> PyResult<PyObject> {
|
||||
macro_rules! to_py_any {
|
||||
($val:ident, [$($scalar_ty:ident),*]) => {
|
||||
match val{
|
||||
ScalarValue::Null => Ok(py.None()),
|
||||
$(ScalarValue::$scalar_ty(Some(v)) => Ok(v.to_object(py)),)*
|
||||
ScalarValue::List(Some(col), _) => {
|
||||
let list:Vec<PyObject> = col
|
||||
.into_iter()
|
||||
.map(|v| scalar_value_to_py_any(py, v))
|
||||
.collect::<PyResult<_>>()?;
|
||||
let list = PyList::new(py, list);
|
||||
Ok(list.into())
|
||||
}
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Can't cast a Scalar Value `{:#?}` of type {:#?} to a Python Object",
|
||||
$val, $val.get_datatype()
|
||||
)))
|
||||
}
|
||||
};
|
||||
}
|
||||
to_py_any!(
|
||||
val,
|
||||
[
|
||||
Boolean, Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
|
||||
Utf8, LargeUtf8
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
pub fn try_into_columnar_value(py: Python<'_>, obj: PyObject) -> PyResult<ColumnarValue> {
|
||||
macro_rules! to_rust_types {
|
||||
($obj: ident, $($ty: ty => $scalar_ty: ident),*) => {
|
||||
$(
|
||||
if let Ok(val) = $obj.extract::<$ty>(py) {
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::$scalar_ty(Some(val))))
|
||||
}
|
||||
)else*
|
||||
else{
|
||||
Err(PyValueError::new_err(format!("Can't cast {} into Columnar Value", $obj)))
|
||||
}
|
||||
};
|
||||
}
|
||||
if let Ok(v) = obj.extract::<PyVector>(py) {
|
||||
Ok(ColumnarValue::Array(v.to_arrow_array()))
|
||||
} else if obj.as_ref(py).is_instance_of::<PyList>()?
|
||||
|| obj.as_ref(py).is_instance_of::<PyTuple>()?
|
||||
{
|
||||
let ret: Vec<ScalarValue> = {
|
||||
if let Ok(val) = obj.downcast::<PyList>(py) {
|
||||
val.iter().map(|v|->PyResult<ScalarValue>{
|
||||
let val = try_into_columnar_value(py, v.into())?;
|
||||
match val{
|
||||
ColumnarValue::Array(arr) => Err(PyValueError::new_err(format!(
|
||||
"Expect only scalar value in a list, found a vector of type {:?} nested in list", arr.data_type()
|
||||
))),
|
||||
ColumnarValue::Scalar(val) => Ok(val),
|
||||
}
|
||||
}).collect::<PyResult<_>>()?
|
||||
} else if let Ok(val) = obj.downcast::<PyTuple>(py) {
|
||||
val.iter().map(|v|->PyResult<ScalarValue>{
|
||||
let val = try_into_columnar_value(py, v.into())?;
|
||||
match val{
|
||||
ColumnarValue::Array(arr) => Err(PyValueError::new_err(format!(
|
||||
"Expect only scalar value in a tuple, found a vector of type {:?} nested in tuple", arr.data_type()
|
||||
))),
|
||||
ColumnarValue::Scalar(val) => Ok(val),
|
||||
}
|
||||
}).collect::<PyResult<_>>()?
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
|
||||
if ret.is_empty() {
|
||||
return Ok(ColumnarValue::Scalar(ScalarValue::List(
|
||||
None,
|
||||
Box::new(new_item_field(ArrowDataType::Null)),
|
||||
)));
|
||||
}
|
||||
let ty = ret[0].get_datatype();
|
||||
|
||||
if ret.iter().any(|i| i.get_datatype() != ty) {
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"All elements in a list should be same type to cast to Datafusion list!\nExpect {ty:?}, found {}",
|
||||
collect_diff_types_string(&ret, &ty)
|
||||
)));
|
||||
}
|
||||
Ok(ColumnarValue::Scalar(ScalarValue::List(
|
||||
Some(ret),
|
||||
Box::new(new_item_field(ty)),
|
||||
)))
|
||||
} else {
|
||||
to_rust_types!(obj,
|
||||
bool => Boolean,
|
||||
i64 => Int64,
|
||||
f64 => Float64,
|
||||
String => Utf8
|
||||
)
|
||||
}
|
||||
}
|
||||
277
src/script/src/python/pyo3/vector_impl.rs
Normal file
277
src/script/src/python/pyo3/vector_impl.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use datafusion::arrow::compute::kernels::{arithmetic, comparison};
|
||||
use datatypes::arrow::array::{Array, ArrayRef};
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::prelude::{ConcreteDataType, DataType};
|
||||
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString};
|
||||
|
||||
use crate::python::ffi_types::vector::{wrap_bool_result, wrap_result, PyVector};
|
||||
use crate::python::pyo3::utils::pyo3_obj_try_to_typed_val;
|
||||
|
||||
/// Map a Python object to a `ConcreteDataType` by testing it against a list of
/// Python types, in the order given.
///
/// Usage: `get_con_type!(obj, PyType => constructor_fn, ...)`, where
/// `constructor_fn` is an associated function of `ConcreteDataType`. The
/// expansion evaluates to `PyResult<ConcreteDataType>` and uses `?`, so it
/// must be used inside a function returning `PyResult`.
///
/// When no listed type matches, a `ValueError` naming the object's Python type
/// is produced. The type is reported rather than the object itself, so the
/// message does not need the object's `repr`.
macro_rules! get_con_type {
    ($obj:ident, $($pyty:ident => $con_ty:ident),*$(,)?) => {
        $(
            if $obj.is_instance_of::<$pyty>()? {
                Ok(ConcreteDataType::$con_ty())
            }
        )else* else {
            // The original message was a plain string literal, so its
            // `{obj:?}` placeholder was emitted verbatim instead of being filled in.
            Err(PyValueError::new_err(format!(
                "Unsupported pyobject type: {:?}",
                $obj.get_type()
            )))
        }
    };
}
|
||||
|
||||
fn get_py_type(obj: &PyAny) -> PyResult<ConcreteDataType> {
|
||||
// Bool need to precede Int because `PyBool` is also a instance of `PyInt`
|
||||
get_con_type!(obj,
|
||||
PyBool => boolean_datatype,
|
||||
PyInt => int64_datatype,
|
||||
PyFloat => float64_datatype,
|
||||
PyString => string_datatype
|
||||
)
|
||||
}
|
||||
|
||||
fn pyo3_is_obj_scalar(obj: &PyAny) -> bool {
|
||||
get_py_type(obj).is_ok()
|
||||
}
|
||||
|
||||
impl PyVector {
|
||||
fn pyo3_scalar_arith_op<F>(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
right: PyObject,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> PyResult<Self>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String> + Send,
|
||||
{
|
||||
let right = pyo3_obj_try_to_typed_val(right.as_ref(py), None)?;
|
||||
py.allow_threads(|| {
|
||||
self.scalar_arith_op(right, target_type, op)
|
||||
.map_err(PyValueError::new_err)
|
||||
})
|
||||
}
|
||||
fn pyo3_vector_arith_op<F>(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
right: PyObject,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> PyResult<Self>
|
||||
where
|
||||
F: Fn(&dyn Array, &dyn Array) -> Result<ArrayRef, String> + Send,
|
||||
{
|
||||
let right = right.extract::<PyVector>(py)?;
|
||||
py.allow_threads(|| {
|
||||
self.vector_arith_op(&right, target_type, op)
|
||||
.map_err(PyValueError::new_err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyVector {
|
||||
/// create a `PyVector` with a `PyList` that contains only elements of same type
|
||||
#[new]
|
||||
pub(crate) fn py_new(iterable: &PyList) -> PyResult<Self> {
|
||||
let dtype = get_py_type(iterable.get_item(0)?)?;
|
||||
let mut buf = dtype.create_mutable_vector(iterable.len());
|
||||
for i in 0..iterable.len() {
|
||||
let element = iterable.get_item(i)?;
|
||||
let val = pyo3_obj_try_to_typed_val(element, Some(dtype.clone()))?;
|
||||
buf.push_value_ref(val.as_value_ref());
|
||||
}
|
||||
Ok(buf.to_vector().into())
|
||||
}
|
||||
fn __richcmp__(&self, py: Python<'_>, other: PyObject, op: CompareOp) -> PyResult<Self> {
|
||||
let op_fn = match op {
|
||||
CompareOp::Lt => comparison::lt_dyn,
|
||||
CompareOp::Le => comparison::lt_eq_dyn,
|
||||
CompareOp::Eq => comparison::eq_dyn,
|
||||
CompareOp::Ne => comparison::neq_dyn,
|
||||
CompareOp::Gt => comparison::gt_dyn,
|
||||
CompareOp::Ge => comparison::gt_eq_dyn,
|
||||
};
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(py, other, None, wrap_bool_result(op_fn))
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(py, other, None, wrap_bool_result(op_fn))
|
||||
}
|
||||
}
|
||||
|
||||
fn __add__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(py, other, None, wrap_result(arithmetic::add_dyn))
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(py, other, None, wrap_result(arithmetic::add_dyn))
|
||||
}
|
||||
}
|
||||
fn __radd__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
self.__add__(py, other)
|
||||
}
|
||||
|
||||
fn __sub__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(py, other, None, wrap_result(arithmetic::subtract_dyn))
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(py, other, None, wrap_result(arithmetic::subtract_dyn))
|
||||
}
|
||||
}
|
||||
fn __rsub__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(
|
||||
py,
|
||||
other,
|
||||
None,
|
||||
wrap_result(|a, b| arithmetic::subtract_dyn(b, a)),
|
||||
)
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(
|
||||
py,
|
||||
other,
|
||||
None,
|
||||
wrap_result(|a, b| arithmetic::subtract_dyn(b, a)),
|
||||
)
|
||||
}
|
||||
}
|
||||
fn __mul__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(py, other, None, wrap_result(arithmetic::multiply_dyn))
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(py, other, None, wrap_result(arithmetic::multiply_dyn))
|
||||
}
|
||||
}
|
||||
fn __rmul__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
self.__mul__(py, other)
|
||||
}
|
||||
fn __truediv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
if pyo3_is_obj_scalar(other.as_ref(py)) {
|
||||
self.pyo3_scalar_arith_op(
|
||||
py,
|
||||
other,
|
||||
Some(ArrowDataType::Float64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
)
|
||||
} else {
|
||||
self.pyo3_vector_arith_op(
|
||||
py,
|
||||
other,
|
||||
Some(ArrowDataType::Float64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
)
|
||||
}
|
||||
}
|
||||
#[allow(unused)]
|
||||
fn __rtruediv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
Err(PyNotImplementedError::new_err(()))
|
||||
}
|
||||
#[allow(unused)]
|
||||
fn __floordiv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
Err(PyNotImplementedError::new_err(()))
|
||||
}
|
||||
#[allow(unused)]
|
||||
fn __rfloordiv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
|
||||
Err(PyNotImplementedError::new_err(()))
|
||||
}
|
||||
fn __and__(&self, other: &Self) -> PyResult<Self> {
|
||||
Self::vector_and(self, other).map_err(PyValueError::new_err)
|
||||
}
|
||||
fn __or__(&self, other: &Self) -> PyResult<Self> {
|
||||
Self::vector_or(self, other).map_err(PyValueError::new_err)
|
||||
}
|
||||
fn __invert__(&self) -> PyResult<Self> {
|
||||
Self::vector_invert(self).map_err(PyValueError::new_err)
|
||||
}
|
||||
fn __len__(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
/// Returns the human-readable description of the vector type.
fn __doc__(&self) -> PyResult<String> {
    let description = "PyVector is like a Python array, a compact array of elem of same datatype, but Readonly for now";
    Ok(description.to_string())
}
|
||||
/// Python `repr(vector)`: the pretty-printed `Debug` rendering of the vector.
fn __repr__(&self) -> PyResult<String> {
    let rendered = format!("{:#?}", self);
    Ok(rendered)
}
|
||||
}
|
||||
|
||||
/// Test helper: moves `val` into a new `PyCell` via `PyCell::new` and returns that call's
/// result unchanged.
#[cfg(test)]
pub(crate) fn into_pyo3_cell(py: Python, val: PyVector) -> PyResult<&PyCell<PyVector>> {
    PyCell::<PyVector>::new(py, val)
}
|
||||
|
||||
#[cfg(test)]
mod test {

    use std::collections::HashMap;
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, Float64Vector, VectorRef};
    use pyo3::types::{PyDict, PyModule};
    use pyo3::{PyCell, Python};

    use crate::python::ffi_types::vector::PyVector;
    use crate::python::pyo3::init_cpython_interpreter;

    /// Builds the named fixture vectors used by the tests: two boolean vectors (`bv1`, `bv2`)
    /// and two float vectors (`fv1`, `fv2`).
    fn sample_vector() -> HashMap<String, PyVector> {
        let bools = |values: &[bool]| -> PyVector {
            (Arc::new(BooleanVector::from_slice(values)) as VectorRef).into()
        };
        let floats = |values: &[f64]| -> PyVector {
            (Arc::new(Float64Vector::from_slice(values)) as VectorRef).into()
        };

        HashMap::from([
            ("bv1".to_string(), bools(&[true, false, true, true])),
            ("bv2".to_string(), bools(&[false, false, false, true])),
            ("fv1".to_string(), floats(&[0.0f64, 1.0, 42.0, 3.0])),
            ("fv2".to_string(), floats(&[1919.810f64, 0.114, 51.4, 3.0])),
        ])
    }

    /// Smoke test: registers `PyVector` in a `gt` module, exposes the fixture vectors as
    /// locals, and checks that constructing a vector and adding two vectors runs without
    /// raising.
    #[test]
    fn test_py_vector_api() {
        init_cpython_interpreter();
        Python::with_gil(|py| {
            let module = PyModule::new(py, "gt").unwrap();
            module.add_class::<PyVector>().unwrap();

            // Register the `gt` module in `sys.modules` so the script below can import it.
            let sys = PyModule::import(py, "sys").unwrap();
            let py_modules: &PyDict = sys.getattr("modules").unwrap().downcast().unwrap();
            py_modules.set_item("gt", module).unwrap();

            // Expose every fixture vector to the script as a local variable.
            let locals = PyDict::new(py);
            for (name, vector) in sample_vector() {
                let cell = PyCell::new(py, vector).unwrap();
                locals.set_item(name, cell).unwrap();
            }

            // Build a vector from a list and add two float vectors.
            py.run(
                r#"
from gt import vector
print(vector([1,2]))
print(fv1+fv2)
"#,
                None,
                Some(locals),
            )
            .unwrap();
        });
    }
}
|
||||
26
src/script/src/python/rspython.rs
Normal file
26
src/script/src/python/rspython.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod copr_impl;
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
pub(crate) mod vector_impl;
|
||||
|
||||
pub(crate) mod builtins;
|
||||
mod dataframe_impl;
|
||||
mod utils;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use copr_impl::init_interpreter;
|
||||
pub(crate) use copr_impl::rspy_exec_parsed;
|
||||
@@ -27,8 +27,15 @@ use datatypes::vectors::Helper as HelperVec;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyList, PyStr};
|
||||
use rustpython_vm::{pymodule, AsObject, PyObjectRef, PyPayload, PyResult, VirtualMachine};
|
||||
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::utils::is_instance;
|
||||
use crate::python::PyVector;
|
||||
|
||||
pub fn init_greptime_builtins(module_name: &str, vm: &mut VirtualMachine) {
|
||||
vm.add_native_module(
|
||||
module_name.to_string(),
|
||||
Box::new(greptime_builtin::make_module),
|
||||
);
|
||||
}
|
||||
|
||||
/// "Can't cast operand of type `{name}` into `{ty}`."
|
||||
fn type_cast_error(name: &str, ty: &str, vm: &VirtualMachine) -> PyBaseExceptionRef {
|
||||
@@ -295,13 +302,13 @@ pub(crate) mod greptime_builtin {
|
||||
use rustpython_vm::function::{FuncArgs, KwArgs, OptionalArg};
|
||||
use rustpython_vm::{AsObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine};
|
||||
|
||||
use crate::python::builtins::{
|
||||
use super::{
|
||||
all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj,
|
||||
type_cast_error,
|
||||
};
|
||||
use crate::python::utils::{is_instance, py_vec_obj_to_array, PyVectorRef};
|
||||
use crate::python::vector::val_to_pyobj;
|
||||
use crate::python::PyVector;
|
||||
use crate::python::ffi_types::vector::val_to_pyobj;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::rspython::utils::{is_instance, py_vec_obj_to_array, PyVectorRef};
|
||||
|
||||
#[pyfunction]
|
||||
fn vector(args: OptionalArg<PyObjectRef>, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
@@ -356,7 +363,7 @@ pub(crate) mod greptime_builtin {
|
||||
return Err(vm.new_runtime_error(format!("Failed to evaluate accumulator: {err}")))
|
||||
}
|
||||
};
|
||||
let res = val_to_pyobj(res, vm);
|
||||
let res = val_to_pyobj(res, vm)?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
@@ -962,7 +969,7 @@ pub(crate) mod greptime_builtin {
|
||||
Ok(obj) => match py_vec_obj_to_array(&obj, vm, 1){
|
||||
Ok(v) => if v.len()==1{
|
||||
Ok(v)
|
||||
} else {
|
||||
}else{
|
||||
Err(vm.new_runtime_error(format!("Expect return's length to be at most one, found to be length of {}.", v.len())))
|
||||
},
|
||||
Err(err) => Err(vm
|
||||
@@ -32,8 +32,8 @@ use rustpython_vm::{AsObject, PyObjectRef, VirtualMachine};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::*;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::utils::{format_py_error, is_instance};
|
||||
use crate::python::PyVector;
|
||||
|
||||
#[test]
|
||||
fn convert_scalar_to_py_obj_and_back() {
|
||||
@@ -307,7 +307,7 @@ impl PyValue {
|
||||
fn run_builtin_fn_testcases() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let loc = Path::new("src/python/builtins/testcases.ron");
|
||||
let loc = Path::new("src/python/rspython/builtins/testcases.ron");
|
||||
let loc = loc.to_str().expect("Fail to parse path");
|
||||
let mut file = File::open(loc).expect("Fail to open file");
|
||||
let mut buf = String::new();
|
||||
220
src/script/src/python/rspython/copr_impl.rs
Normal file
220
src/script/src/python/rspython/copr_impl.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::info;
|
||||
use datatypes::vectors::VectorRef;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyDict, PyStr, PyTuple};
|
||||
use rustpython_vm::class::PyClassImpl;
|
||||
use rustpython_vm::convert::ToPyObject;
|
||||
use rustpython_vm::scope::Scope;
|
||||
use rustpython_vm::{vm, AsObject, Interpreter, PyObjectRef, PyPayload, VirtualMachine};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
|
||||
use crate::python::ffi_types::copr::PyQueryEngine;
|
||||
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
|
||||
use crate::python::rspython::builtins::init_greptime_builtins;
|
||||
use crate::python::rspython::dataframe_impl::data_frame::set_dataframe_in_scope;
|
||||
use crate::python::rspython::dataframe_impl::init_data_frame;
|
||||
use crate::python::rspython::utils::{format_py_error, is_instance, py_vec_obj_to_array};
|
||||
|
||||
thread_local!(static INTERPRETER: RefCell<Option<Arc<Interpreter>>> = RefCell::new(None));
|
||||
|
||||
/// Using `RustPython` to run a parsed `Coprocessor` struct as input to execute python code
|
||||
pub(crate) fn rspy_exec_parsed(
|
||||
copr: &Coprocessor,
|
||||
rb: &Option<RecordBatch>,
|
||||
params: &HashMap<String, String>,
|
||||
) -> Result<RecordBatch> {
|
||||
// 3. get args from `rb`, and cast them into PyVector
|
||||
let args: Vec<PyVector> = if let Some(rb) = rb {
|
||||
let args = select_from_rb(rb, copr.deco_args.arg_names.as_ref().unwrap_or(&vec![]))?;
|
||||
check_args_anno_real_type(&args, copr, rb)?;
|
||||
args
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
let interpreter = init_interpreter();
|
||||
// 4. then set args in scope and compile then run `CodeObject` which already append a new `Call` node
|
||||
exec_with_cached_vm(copr, rb, args, params, &interpreter)
|
||||
}
|
||||
|
||||
/// set arguments with given name and values in python scopes
|
||||
fn set_items_in_scope(
|
||||
scope: &Scope,
|
||||
vm: &VirtualMachine,
|
||||
arg_names: &[String],
|
||||
args: Vec<PyVector>,
|
||||
) -> Result<()> {
|
||||
let _ = arg_names
|
||||
.iter()
|
||||
.zip(args)
|
||||
.map(|(name, vector)| {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(name, vm.new_pyobj(vector), vm)
|
||||
})
|
||||
.collect::<StdResult<Vec<()>, PyBaseExceptionRef>>()
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_query_engine_in_scope(
|
||||
scope: &Scope,
|
||||
vm: &VirtualMachine,
|
||||
query_engine: PyQueryEngine,
|
||||
) -> Result<()> {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item("query", query_engine.to_pyobject(vm), vm)
|
||||
.map_err(|e| format_py_error(e, vm))
|
||||
}
|
||||
|
||||
pub(crate) fn exec_with_cached_vm(
|
||||
copr: &Coprocessor,
|
||||
rb: &Option<RecordBatch>,
|
||||
args: Vec<PyVector>,
|
||||
params: &HashMap<String, String>,
|
||||
vm: &Arc<Interpreter>,
|
||||
) -> Result<RecordBatch> {
|
||||
vm.enter(|vm| -> Result<RecordBatch> {
|
||||
// set arguments with given name and values
|
||||
let scope = vm.new_scope_with_builtins();
|
||||
if let Some(rb) = rb {
|
||||
set_dataframe_in_scope(&scope, vm, "dataframe", rb)?;
|
||||
}
|
||||
|
||||
if let Some(arg_names) = &copr.deco_args.arg_names {
|
||||
assert_eq!(arg_names.len(), args.len());
|
||||
set_items_in_scope(&scope, vm, arg_names, args)?;
|
||||
}
|
||||
|
||||
if let Some(engine) = &copr.query_engine {
|
||||
let query_engine = PyQueryEngine::from_weakref(engine.clone());
|
||||
|
||||
// put a object named with query of class PyQueryEngine in scope
|
||||
set_query_engine_in_scope(&scope, vm, query_engine)?;
|
||||
}
|
||||
|
||||
if let Some(kwarg) = &copr.kwarg {
|
||||
let dict = PyDict::new_ref(&vm.ctx);
|
||||
for (k, v) in params {
|
||||
dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
}
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(kwarg, vm.new_pyobj(dict), vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
}
|
||||
|
||||
// It's safe to unwrap code_object, it's already compiled before.
|
||||
let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
|
||||
let ret = vm
|
||||
.run_code_obj(code_obj, scope)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
// 5. get returns as either a PyVector or a PyTuple, and naming schema them according to `returns`
|
||||
let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
|
||||
let mut cols = try_into_columns(&ret, vm, col_len)?;
|
||||
ensure!(
|
||||
cols.len() == copr.deco_args.ret_names.len(),
|
||||
OtherSnafu {
|
||||
reason: format!(
|
||||
"The number of return Vector is wrong, expect {}, found {}",
|
||||
copr.deco_args.ret_names.len(),
|
||||
cols.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
|
||||
// if cols and schema's data types is not match, try coerce it to given type(if annotated)(if error occur, return relevant error with question mark)
|
||||
copr.check_and_cast_type(&mut cols)?;
|
||||
|
||||
// 6. return a assembled DfRecordBatch
|
||||
let schema = copr.gen_schema(&cols)?;
|
||||
RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
|
||||
})
|
||||
}
|
||||
|
||||
/// convert a tuple of `PyVector` or one `PyVector`(wrapped in a Python Object Ref[`PyObjectRef`])
|
||||
/// to a `Vec<ArrayRef>`
|
||||
/// by default, a constant(int/float/bool) gives the a constant array of same length with input args
|
||||
fn try_into_columns(
|
||||
obj: &PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
col_len: usize,
|
||||
) -> Result<Vec<VectorRef>> {
|
||||
if is_instance::<PyTuple>(obj, vm) {
|
||||
let tuple = obj
|
||||
.payload::<PyTuple>()
|
||||
.with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyTuple)")))?;
|
||||
let cols = tuple
|
||||
.iter()
|
||||
.map(|obj| py_vec_obj_to_array(obj, vm, col_len))
|
||||
.collect::<Result<Vec<VectorRef>>>()?;
|
||||
Ok(cols)
|
||||
} else {
|
||||
let col = py_vec_obj_to_array(obj, vm, col_len)?;
|
||||
Ok(vec![col])
|
||||
}
|
||||
}
|
||||
|
||||
/// init interpreter with type PyVector and Module: greptime
|
||||
pub(crate) fn init_interpreter() -> Arc<Interpreter> {
|
||||
INTERPRETER.with(|i| {
|
||||
i.borrow_mut()
|
||||
.get_or_insert_with(|| {
|
||||
// we limit stdlib imports for safety reason, i.e `fcntl` is not allowed here
|
||||
let native_module_allow_list = HashSet::from([
|
||||
"array", "cmath", "gc", "hashlib", "_json", "_random", "math",
|
||||
]);
|
||||
// edge cases, can't use "..Default::default" because Settings is `#[non_exhaustive]`
|
||||
// so more in here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868
|
||||
let mut settings = vm::Settings::default();
|
||||
// disable SIG_INT handler so our own binary can take ctrl_c handler
|
||||
settings.no_sig_int = true;
|
||||
let interpreter = Arc::new(vm::Interpreter::with_init(settings, |vm| {
|
||||
// not using full stdlib to prevent security issue, instead filter out a few simple util module
|
||||
vm.add_native_modules(
|
||||
rustpython_stdlib::get_module_inits()
|
||||
.filter(|(k, _)| native_module_allow_list.contains(k.as_ref())),
|
||||
);
|
||||
|
||||
// We are freezing the stdlib to include the standard library inside the binary.
|
||||
// so according to this issue:
|
||||
// https://github.com/RustPython/RustPython/issues/4292
|
||||
// add this line for stdlib, so rustpython can found stdlib's python part in bytecode format
|
||||
vm.add_frozen(rustpython_pylib::frozen_stdlib());
|
||||
// add our own custom datatype and module
|
||||
PyVector::make_class(&vm.ctx);
|
||||
PyQueryEngine::make_class(&vm.ctx);
|
||||
init_greptime_builtins("greptime", vm);
|
||||
init_data_frame("data_frame", vm);
|
||||
}));
|
||||
info!("Initialized Python interpreter.");
|
||||
interpreter
|
||||
})
|
||||
.clone()
|
||||
})
|
||||
}
|
||||
@@ -12,8 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use rustpython_vm::pymodule as rspymodule;
|
||||
|
||||
use rustpython_vm::class::PyClassImpl;
|
||||
use rustpython_vm::{pymodule as rspymodule, VirtualMachine};
|
||||
/// Creates the `PyDataFrame` and `PyExpr` classes in the VM context and registers the
/// `data_frame` native module under `module_name`.
pub(crate) fn init_data_frame(module_name: &str, vm: &mut VirtualMachine) {
    let ctx = &vm.ctx;
    data_frame::PyDataFrame::make_class(ctx);
    data_frame::PyExpr::make_class(ctx);

    let name = module_name.to_owned();
    vm.add_native_module(name, Box::new(data_frame::make_module));
}
|
||||
/// with `register_batch`, and then wrap DataFrame API in it
|
||||
#[rspymodule]
|
||||
pub(crate) mod data_frame {
|
||||
@@ -29,6 +34,7 @@ pub(crate) mod data_frame {
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::python::error::DataFusionSnafu;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::utils::block_on_async;
|
||||
#[rspyclass(module = "data_frame", name = "DataFrame")]
|
||||
#[derive(PyPayload, Debug)]
|
||||
@@ -232,7 +238,7 @@ pub(crate) mod data_frame {
|
||||
.iter()
|
||||
.map(|arr| -> PyResult<_> {
|
||||
datatypes::vectors::Helper::try_into_vector(arr)
|
||||
.map(crate::python::PyVector::from)
|
||||
.map(PyVector::from)
|
||||
.map(|v| vm.new_pyobj(v))
|
||||
.map_err(|e| vm.new_runtime_error(e.to_string()))
|
||||
})
|
||||
@@ -28,11 +28,10 @@ use ron::from_str as from_ron_string;
|
||||
use rustpython_parser::parser;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::error::{get_error_reason_loc, visualize_loc};
|
||||
use crate::python::coprocessor;
|
||||
use crate::python::coprocessor::parse::parse_and_compile_copr;
|
||||
use crate::python::coprocessor::{AnnotationInfo, Coprocessor};
|
||||
use crate::python::error::{pretty_print_error_in_src, Error};
|
||||
use crate::python::error::{get_error_reason_loc, pretty_print_error_in_src, visualize_loc, Error};
|
||||
use crate::python::ffi_types::copr::parse::parse_and_compile_copr;
|
||||
use crate::python::ffi_types::copr::{exec_coprocessor, AnnotationInfo};
|
||||
use crate::python::ffi_types::Coprocessor;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct TestCase {
|
||||
@@ -91,7 +90,7 @@ fn create_sample_recordbatch() -> RecordBatch {
|
||||
fn run_ron_testcases() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let loc = Path::new("src/python/testcases.ron");
|
||||
let loc = Path::new("src/python/rspython/testcases.ron");
|
||||
let loc = loc.to_str().expect("Fail to parse path");
|
||||
let mut file = File::open(loc).expect("Fail to open file");
|
||||
let mut buf = String::new();
|
||||
@@ -126,7 +125,7 @@ fn run_ron_testcases() {
|
||||
}
|
||||
Predicate::ExecIsOk { fields, columns } => {
|
||||
let rb = create_sample_recordbatch();
|
||||
let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb)).unwrap();
|
||||
let res = exec_coprocessor(&testcase.code, &Some(rb)).unwrap();
|
||||
fields
|
||||
.iter()
|
||||
.zip(res.schema.column_schemas())
|
||||
@@ -152,7 +151,7 @@ fn run_ron_testcases() {
|
||||
reason: part_reason,
|
||||
} => {
|
||||
let rb = create_sample_recordbatch();
|
||||
let res = coprocessor::exec_coprocessor(&testcase.code, &Some(rb));
|
||||
let res = exec_coprocessor(&testcase.code, &Some(rb));
|
||||
assert!(res.is_err(), "{res:#?}\nExpect Err(...), actual Ok(...)");
|
||||
if let Err(res) = res {
|
||||
error!(
|
||||
@@ -254,7 +253,7 @@ def calc_rvs(open_time, close):
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
|
||||
let ret = exec_coprocessor(python_source, &Some(rb));
|
||||
if let Err(Error::PyParse {
|
||||
backtrace: _,
|
||||
source,
|
||||
@@ -304,7 +303,7 @@ def a(cpu, mem):
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let ret = coprocessor::exec_coprocessor(python_source, &Some(rb));
|
||||
let ret = exec_coprocessor(python_source, &Some(rb));
|
||||
if let Err(Error::PyParse {
|
||||
backtrace: _,
|
||||
source,
|
||||
@@ -274,7 +274,7 @@ def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)],
|
||||
"#,
|
||||
predicate: ParseIsErr(
|
||||
reason:
|
||||
" keyword argument, found "
|
||||
"Expect a list of String, found one element to be"
|
||||
)
|
||||
),
|
||||
(
|
||||
131
src/script/src/python/rspython/utils.rs
Normal file
131
src/script/src/python/rspython/utils.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::ColumnarValue as DFColValue;
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{
|
||||
BooleanVector, Float64Vector, Helper, Int64Vector, NullVector, StringVector, VectorRef,
|
||||
};
|
||||
use futures::Future;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyList, PyStr};
|
||||
use rustpython_vm::{PyObjectRef, PyPayload, PyRef, VirtualMachine};
|
||||
use snafu::{Backtrace, GenerateImplicitData, OptionExt, ResultExt};
|
||||
|
||||
use crate::python::error;
|
||||
use crate::python::error::ret_other_error_with;
|
||||
use crate::python::ffi_types::PyVector;
|
||||
use crate::python::rspython::builtins::try_into_columnar_value;
|
||||
|
||||
pub(crate) type PyVectorRef = PyRef<PyVector>;
|
||||
|
||||
/// use `rustpython`'s `is_instance` method to check if a PyObject is a instance of class.
|
||||
/// if `PyResult` is Err, then this function return `false`
|
||||
pub fn is_instance<T: PyPayload>(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
|
||||
obj.is_instance(T::class(vm).into(), vm).unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error::Error {
|
||||
let mut msg = String::new();
|
||||
if let Err(e) = vm.write_exception(&mut msg, &excep) {
|
||||
return error::Error::PyRuntime {
|
||||
msg: format!("Failed to write exception msg, err: {e}"),
|
||||
backtrace: Backtrace::generate(),
|
||||
};
|
||||
}
|
||||
|
||||
error::Error::PyRuntime {
|
||||
msg,
|
||||
backtrace: Backtrace::generate(),
|
||||
}
|
||||
}
|
||||
|
||||
/// convert a single PyVector or a number(a constant)(wrapping in PyObjectRef) into a Array(or a constant array)
|
||||
pub fn py_vec_obj_to_array(
|
||||
obj: &PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
col_len: usize,
|
||||
) -> Result<VectorRef, error::Error> {
|
||||
// It's ugly, but we can't find a better way right now.
|
||||
if is_instance::<PyVector>(obj, vm) {
|
||||
let pyv = obj
|
||||
.payload::<PyVector>()
|
||||
.with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyVector")))?;
|
||||
Ok(pyv.as_vector_ref())
|
||||
} else if is_instance::<PyInt>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<i64>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
let ret = Int64Vector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyFloat>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<f64>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
let ret = Float64Vector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyBool>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<bool>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
let ret = BooleanVector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyStr>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<String>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
let ret = StringVector::from_iterator(std::iter::repeat(val.as_str()).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyList>(obj, vm) {
|
||||
let columnar_value =
|
||||
try_into_columnar_value(obj.clone(), vm).map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
match columnar_value {
|
||||
DFColValue::Scalar(ScalarValue::List(scalars, _datatype)) => match scalars {
|
||||
Some(scalars) => {
|
||||
let array = ScalarValue::iter_to_array(scalars.into_iter())
|
||||
.context(error::DataFusionSnafu)?;
|
||||
|
||||
Helper::try_into_vector(array).context(error::TypeCastSnafu)
|
||||
}
|
||||
None => Ok(Arc::new(NullVector::new(0))),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
ret_other_error_with(format!("Expect a vector or a constant, found {obj:?}")).fail()
|
||||
}
|
||||
}
|
||||
|
||||
/// a terrible hack to call async from sync by:
|
||||
/// TODO(discord9): find a better way
|
||||
/// 1. spawn a new thread
|
||||
/// 2. create a new runtime in new thread and call `block_on` on it
|
||||
#[allow(unused)]
|
||||
pub fn block_on_async<T, F>(f: F) -> std::thread::Result<T>
|
||||
where
|
||||
F: Future<Output = T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let rt = tokio::runtime::Runtime::new().map_err(|e| Box::new(e) as _)?;
|
||||
std::thread::spawn(move || rt.block_on(f)).join()
|
||||
}
|
||||
514
src/script/src/python/rspython/vector_impl.rs
Normal file
514
src/script/src/python/rspython/vector_impl.rs
Normal file
@@ -0,0 +1,514 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_time::date::Date;
|
||||
use common_time::datetime::DateTime;
|
||||
use common_time::timestamp::Timestamp;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use datatypes::arrow::array::{Array, BooleanArray};
|
||||
use datatypes::arrow::compute;
|
||||
use datatypes::arrow::compute::kernels::arithmetic;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::data_type::{ConcreteDataType, DataType};
|
||||
use datatypes::value::{self, OrderedFloat};
|
||||
use datatypes::vectors::Helper;
|
||||
use once_cell::sync::Lazy;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyBytes, PyFloat, PyInt, PyNone, PyStr};
|
||||
use rustpython_vm::function::{Either, OptionalArg, PyComparisonValue};
|
||||
use rustpython_vm::protocol::{PyMappingMethods, PySequenceMethods};
|
||||
use rustpython_vm::types::{AsMapping, AsSequence, Comparable, PyComparisonOp};
|
||||
use rustpython_vm::{
|
||||
atomic_func, pyclass as rspyclass, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
|
||||
VirtualMachine,
|
||||
};
|
||||
|
||||
use crate::python::ffi_types::vector::{
|
||||
arrow_rfloordiv, arrow_rsub, arrow_rtruediv, rspy_is_pyobj_scalar, wrap_result, PyVector,
|
||||
};
|
||||
use crate::python::utils::is_instance;
|
||||
/// PyVectors' rustpython specify methods
|
||||
|
||||
fn to_type_error(vm: &'_ VirtualMachine) -> impl FnOnce(String) -> PyBaseExceptionRef + '_ {
|
||||
|msg: String| vm.new_type_error(msg)
|
||||
}
|
||||
|
||||
pub(crate) type PyVectorRef = PyRef<PyVector>;
|
||||
/// PyVector type wraps a greptime vector, impl multiply/div/add/sub opeerators etc.
|
||||
#[rspyclass(with(AsMapping, AsSequence, Comparable))]
|
||||
impl PyVector {
|
||||
#[pymethod]
|
||||
pub(crate) fn new(
|
||||
iterable: OptionalArg<PyObjectRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<PyVector> {
|
||||
if let OptionalArg::Present(iterable) = iterable {
|
||||
let mut elements: Vec<PyObjectRef> = iterable.try_to_value(vm)?;
|
||||
|
||||
if elements.is_empty() {
|
||||
return Ok(PyVector::default());
|
||||
}
|
||||
|
||||
let datatype = get_concrete_type(&elements[0], vm)?;
|
||||
let mut buf = datatype.create_mutable_vector(elements.len());
|
||||
|
||||
for obj in elements.drain(..) {
|
||||
let val = if let Some(v) =
|
||||
pyobj_try_to_typed_val(obj.clone(), vm, Some(datatype.clone()))
|
||||
{
|
||||
v
|
||||
} else {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"Can't cast pyobject {obj:?} into concrete type {datatype:?}",
|
||||
)));
|
||||
};
|
||||
// Safety: `pyobj_try_to_typed_val()` has checked the data type.
|
||||
buf.push_value_ref(val.as_value_ref());
|
||||
}
|
||||
|
||||
Ok(PyVector {
|
||||
vector: buf.to_vector(),
|
||||
})
|
||||
} else {
|
||||
Ok(PyVector::default())
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(name = "__radd__")]
|
||||
#[pymethod(magic)]
|
||||
fn add(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(other, None, wrap_result(arithmetic::add_dyn), vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(other, None, wrap_result(arithmetic::add_dyn), vm)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(other, None, wrap_result(arithmetic::subtract_dyn), vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(other, None, wrap_result(arithmetic::subtract_dyn), vm)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn rsub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(other, None, arrow_rsub, vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(
|
||||
other,
|
||||
None,
|
||||
wrap_result(|a, b| arithmetic::subtract_dyn(b, a)),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(name = "__rmul__")]
|
||||
#[pymethod(magic)]
|
||||
fn mul(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(other, None, wrap_result(arithmetic::multiply_dyn), vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(other, None, wrap_result(arithmetic::multiply_dyn), vm)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn truediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Float64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
vm,
|
||||
)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Float64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(other, Some(ArrowDataType::Float64), arrow_rtruediv, vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Float64),
|
||||
wrap_result(|a, b| arithmetic::divide_dyn(b, a)),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn floordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
self.rspy_scalar_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Int64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
vm,
|
||||
)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Int64),
|
||||
wrap_result(arithmetic::divide_dyn),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
if rspy_is_pyobj_scalar(&other, vm) {
|
||||
// FIXME: DataType convert problem, target_type should be inferred?
|
||||
self.rspy_scalar_arith_op(other, Some(ArrowDataType::Int64), arrow_rfloordiv, vm)
|
||||
} else {
|
||||
self.rspy_vector_arith_op(
|
||||
other,
|
||||
Some(ArrowDataType::Int64),
|
||||
wrap_result(|a, b| arithmetic::divide_dyn(b, a)),
|
||||
vm,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn and(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
Self::vector_and(self, &other).map_err(to_type_error(vm))
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn or(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
Self::vector_or(self, &other).map_err(to_type_error(vm))
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn invert(&self, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
Self::vector_invert(self).map_err(to_type_error(vm))
|
||||
}
|
||||
|
||||
#[pymethod(name = "__len__")]
|
||||
fn len_rspy(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
#[pymethod(name = "concat")]
|
||||
fn concat(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
let left = self.to_arrow_array();
|
||||
let right = other.to_arrow_array();
|
||||
|
||||
let res = compute::concat(&[left.as_ref(), right.as_ref()]);
|
||||
let res = res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
|
||||
vm.new_type_error(format!(
|
||||
"Can't cast result into vector, result: {res:?}, err: {e:?}",
|
||||
))
|
||||
})?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
|
||||
/// take a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true).
|
||||
#[pymethod(name = "filter")]
|
||||
fn filter(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
let left = self.to_arrow_array();
|
||||
let right = other.to_arrow_array();
|
||||
let filter = right.as_any().downcast_ref::<BooleanArray>();
|
||||
match filter {
|
||||
Some(filter) => {
|
||||
let res = compute::filter(left.as_ref(), filter);
|
||||
|
||||
let res =
|
||||
res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
|
||||
vm.new_type_error(format!(
|
||||
"Can't cast result into vector, result: {res:?}, err: {e:?}",
|
||||
))
|
||||
})?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
None => Err(vm.new_runtime_error(format!(
|
||||
"Can't cast operand into a Boolean Array, which is {right:#?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn doc(&self) -> PyResult<PyStr> {
|
||||
Ok(PyStr::from(
|
||||
"PyVector is like a Python array, a compact array of elem of same datatype, but Readonly for now",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMapping for PyVector {
|
||||
fn as_mapping() -> &'static PyMappingMethods {
|
||||
static AS_MAPPING: PyMappingMethods = PyMappingMethods {
|
||||
length: atomic_func!(|mapping, _vm| Ok(PyVector::mapping_downcast(mapping).len())),
|
||||
subscript: atomic_func!(
|
||||
|mapping, needle, vm| PyVector::mapping_downcast(mapping)._getitem(needle, vm)
|
||||
),
|
||||
ass_subscript: AtomicCell::new(None),
|
||||
};
|
||||
&AS_MAPPING
|
||||
}
|
||||
}
|
||||
|
||||
impl AsSequence for PyVector {
|
||||
fn as_sequence() -> &'static PySequenceMethods {
|
||||
static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
|
||||
length: atomic_func!(|seq, _vm| Ok(PyVector::sequence_downcast(seq).len())),
|
||||
item: atomic_func!(|seq, i, vm| {
|
||||
let zelf = PyVector::sequence_downcast(seq);
|
||||
zelf.getitem_by_index(i, vm)
|
||||
}),
|
||||
ass_item: atomic_func!(|_seq, _i, _value, vm| {
|
||||
Err(vm.new_type_error("PyVector object doesn't support item assigns".to_string()))
|
||||
}),
|
||||
..PySequenceMethods::NOT_IMPLEMENTED
|
||||
});
|
||||
&AS_SEQUENCE
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparable for PyVector {
|
||||
fn slot_richcompare(
|
||||
zelf: &PyObject,
|
||||
other: &PyObject,
|
||||
op: PyComparisonOp,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
|
||||
if let Some(zelf) = zelf.downcast_ref::<Self>() {
|
||||
let ret: PyVector = zelf.richcompare(other.to_owned(), op, vm)?;
|
||||
let ret = ret.into_pyobject(vm);
|
||||
Ok(Either::A(ret))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!(
|
||||
"unexpected payload {:?} for {}",
|
||||
zelf,
|
||||
op.method_name(&vm.ctx).as_str()
|
||||
)))
|
||||
}
|
||||
}
|
||||
fn cmp(
|
||||
_zelf: &rustpython_vm::Py<Self>,
|
||||
_other: &PyObject,
|
||||
_op: PyComparisonOp,
|
||||
_vm: &VirtualMachine,
|
||||
) -> PyResult<PyComparisonValue> {
|
||||
Ok(PyComparisonValue::NotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_concrete_type(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<ConcreteDataType> {
|
||||
if is_instance::<PyNone>(obj, vm) {
|
||||
Ok(ConcreteDataType::null_datatype())
|
||||
} else if is_instance::<PyBool>(obj, vm) {
|
||||
Ok(ConcreteDataType::boolean_datatype())
|
||||
} else if is_instance::<PyInt>(obj, vm) {
|
||||
Ok(ConcreteDataType::int64_datatype())
|
||||
} else if is_instance::<PyFloat>(obj, vm) {
|
||||
Ok(ConcreteDataType::float64_datatype())
|
||||
} else if is_instance::<PyStr>(obj, vm) {
|
||||
Ok(ConcreteDataType::string_datatype())
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Unsupported pyobject type: {obj:?}")))
|
||||
}
|
||||
}
|
||||
|
||||
/// convert a `PyObjectRef` into a `datatypes::Value`(is that ok?)
|
||||
/// if `obj` can be convert to given ConcreteDataType then return inner `Value` else return None
|
||||
/// if dtype is None, return types with highest precision
|
||||
/// Not used for now but may be use in future
|
||||
pub(crate) fn pyobj_try_to_typed_val(
|
||||
obj: PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
dtype: Option<ConcreteDataType>,
|
||||
) -> Option<value::Value> {
|
||||
// TODO(discord9): use `PyResult` instead of `Option` for better error handling
|
||||
if let Some(dtype) = dtype {
|
||||
match dtype {
|
||||
ConcreteDataType::Null(_) => {
|
||||
if is_instance::<PyNone>(&obj, vm) {
|
||||
Some(value::Value::Null)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::Boolean(_) => {
|
||||
if is_instance::<PyBool>(&obj, vm) || is_instance::<PyInt>(&obj, vm) {
|
||||
Some(value::Value::Boolean(
|
||||
obj.try_into_value::<bool>(vm).unwrap_or(false),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::Int8(_)
|
||||
| ConcreteDataType::Int16(_)
|
||||
| ConcreteDataType::Int32(_)
|
||||
| ConcreteDataType::Int64(_) => {
|
||||
if is_instance::<PyInt>(&obj, vm) {
|
||||
match dtype {
|
||||
ConcreteDataType::Int8(_) => {
|
||||
obj.try_into_value::<i8>(vm).ok().map(value::Value::Int8)
|
||||
}
|
||||
ConcreteDataType::Int16(_) => {
|
||||
obj.try_into_value::<i16>(vm).ok().map(value::Value::Int16)
|
||||
}
|
||||
ConcreteDataType::Int32(_) => {
|
||||
obj.try_into_value::<i32>(vm).ok().map(value::Value::Int32)
|
||||
}
|
||||
ConcreteDataType::Int64(_) => {
|
||||
obj.try_into_value::<i64>(vm).ok().map(value::Value::Int64)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::UInt8(_)
|
||||
| ConcreteDataType::UInt16(_)
|
||||
| ConcreteDataType::UInt32(_)
|
||||
| ConcreteDataType::UInt64(_) => {
|
||||
if is_instance::<PyInt>(&obj, vm)
|
||||
&& obj.clone().try_into_value::<i64>(vm).unwrap_or(-1) >= 0
|
||||
{
|
||||
match dtype {
|
||||
ConcreteDataType::UInt8(_) => {
|
||||
obj.try_into_value::<u8>(vm).ok().map(value::Value::UInt8)
|
||||
}
|
||||
ConcreteDataType::UInt16(_) => {
|
||||
obj.try_into_value::<u16>(vm).ok().map(value::Value::UInt16)
|
||||
}
|
||||
ConcreteDataType::UInt32(_) => {
|
||||
obj.try_into_value::<u32>(vm).ok().map(value::Value::UInt32)
|
||||
}
|
||||
ConcreteDataType::UInt64(_) => {
|
||||
obj.try_into_value::<u64>(vm).ok().map(value::Value::UInt64)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
|
||||
if is_instance::<PyFloat>(&obj, vm) {
|
||||
match dtype {
|
||||
ConcreteDataType::Float32(_) => obj
|
||||
.try_into_value::<f32>(vm)
|
||||
.ok()
|
||||
.map(|v| value::Value::Float32(OrderedFloat(v))),
|
||||
ConcreteDataType::Float64(_) => obj
|
||||
.try_into_value::<f64>(vm)
|
||||
.ok()
|
||||
.map(|v| value::Value::Float64(OrderedFloat(v))),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
ConcreteDataType::String(_) => {
|
||||
if is_instance::<PyStr>(&obj, vm) {
|
||||
obj.try_into_value::<String>(vm)
|
||||
.ok()
|
||||
.map(|v| value::Value::String(v.into()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::Binary(_) => {
|
||||
if is_instance::<PyBytes>(&obj, vm) {
|
||||
obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
|
||||
String::from_utf8(v)
|
||||
.ok()
|
||||
.map(|v| value::Value::String(v.into()))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConcreteDataType::Date(_)
|
||||
| ConcreteDataType::DateTime(_)
|
||||
| ConcreteDataType::Timestamp(_) => {
|
||||
if is_instance::<PyInt>(&obj, vm) {
|
||||
match dtype {
|
||||
ConcreteDataType::Date(_) => obj
|
||||
.try_into_value::<i32>(vm)
|
||||
.ok()
|
||||
.map(Date::new)
|
||||
.map(value::Value::Date),
|
||||
ConcreteDataType::DateTime(_) => obj
|
||||
.try_into_value::<i64>(vm)
|
||||
.ok()
|
||||
.map(DateTime::new)
|
||||
.map(value::Value::DateTime),
|
||||
ConcreteDataType::Timestamp(_) => {
|
||||
// FIXME(dennis): we always consider the timestamp unit is millis, it's not correct if user define timestamp column with other units.
|
||||
obj.try_into_value::<i64>(vm)
|
||||
.ok()
|
||||
.map(Timestamp::new_millisecond)
|
||||
.map(value::Value::Timestamp)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
} else if is_instance::<PyNone>(&obj, vm) {
|
||||
// if Untyped then by default return types with highest precision
|
||||
Some(value::Value::Null)
|
||||
} else if is_instance::<PyBool>(&obj, vm) {
|
||||
Some(value::Value::Boolean(
|
||||
obj.try_into_value::<bool>(vm).unwrap_or(false),
|
||||
))
|
||||
} else if is_instance::<PyInt>(&obj, vm) {
|
||||
obj.try_into_value::<i64>(vm).ok().map(value::Value::Int64)
|
||||
} else if is_instance::<PyFloat>(&obj, vm) {
|
||||
obj.try_into_value::<f64>(vm)
|
||||
.ok()
|
||||
.map(|v| value::Value::Float64(OrderedFloat(v)))
|
||||
} else if is_instance::<PyStr>(&obj, vm) {
|
||||
obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
|
||||
String::from_utf8(v)
|
||||
.ok()
|
||||
.map(|v| value::Value::String(v.into()))
|
||||
})
|
||||
} else if is_instance::<PyBytes>(&obj, vm) {
|
||||
obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
|
||||
String::from_utf8(v)
|
||||
.ok()
|
||||
.map(|v| value::Value::String(v.into()))
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
@@ -12,24 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_expr::ColumnarValue as DFColValue;
|
||||
use datatypes::prelude::ScalarVector;
|
||||
use datatypes::vectors::{
|
||||
BooleanVector, Float64Vector, Helper, Int64Vector, NullVector, StringVector, VectorRef,
|
||||
};
|
||||
use futures::Future;
|
||||
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyList, PyStr};
|
||||
use rustpython_vm::{PyObjectRef, PyPayload, PyRef, VirtualMachine};
|
||||
use snafu::{Backtrace, GenerateImplicitData, OptionExt, ResultExt};
|
||||
use rustpython_vm::builtins::PyBaseExceptionRef;
|
||||
use rustpython_vm::{PyObjectRef, PyPayload, VirtualMachine};
|
||||
use snafu::{Backtrace, GenerateImplicitData};
|
||||
|
||||
use crate::python::builtins::try_into_columnar_value;
|
||||
use crate::python::error::ret_other_error_with;
|
||||
use crate::python::{error, PyVector};
|
||||
|
||||
pub(crate) type PyVectorRef = PyRef<PyVector>;
|
||||
use crate::python::error;
|
||||
|
||||
/// use `rustpython`'s `is_instance` method to check if a PyObject is an instance of a class.
|
||||
/// if the `PyResult` is `Err`, then this function returns `false`
|
||||
@@ -52,69 +40,6 @@ pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error:
|
||||
}
|
||||
}
|
||||
|
||||
/// convert a single PyVector or a number(a constant)(wrapping in PyObjectRef) into a Array(or a constant array)
|
||||
pub fn py_vec_obj_to_array(
|
||||
obj: &PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
col_len: usize,
|
||||
) -> Result<VectorRef, error::Error> {
|
||||
// It's ugly, but we can't find a better way right now.
|
||||
if is_instance::<PyVector>(obj, vm) {
|
||||
let pyv = obj
|
||||
.payload::<PyVector>()
|
||||
.with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyVector")))?;
|
||||
Ok(pyv.as_vector_ref())
|
||||
} else if is_instance::<PyInt>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<i64>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
let ret = Int64Vector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyFloat>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<f64>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
let ret = Float64Vector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyBool>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<bool>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
let ret = BooleanVector::from_iterator(std::iter::repeat(val).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyStr>(obj, vm) {
|
||||
let val = obj
|
||||
.to_owned()
|
||||
.try_into_value::<String>(vm)
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
let ret = StringVector::from_iterator(std::iter::repeat(val.as_str()).take(col_len));
|
||||
Ok(Arc::new(ret) as _)
|
||||
} else if is_instance::<PyList>(obj, vm) {
|
||||
let columnar_value =
|
||||
try_into_columnar_value(obj.clone(), vm).map_err(|e| format_py_error(e, vm))?;
|
||||
|
||||
match columnar_value {
|
||||
DFColValue::Scalar(ScalarValue::List(scalars, _datatype)) => match scalars {
|
||||
Some(scalars) => {
|
||||
let array = ScalarValue::iter_to_array(scalars.into_iter())
|
||||
.context(error::DataFusionSnafu)?;
|
||||
|
||||
Helper::try_into_vector(array).context(error::TypeCastSnafu)
|
||||
}
|
||||
None => Ok(Arc::new(NullVector::new(0))),
|
||||
},
|
||||
_ => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
ret_other_error_with(format!("Expect a vector or a constant, found {obj:?}")).fail()
|
||||
}
|
||||
}
|
||||
|
||||
/// a terrible hack to call async from sync by:
|
||||
/// TODO(discord9): find a better way
|
||||
/// 1. spawn a new thread
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user