mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
feat: script engine and python impl (#219)
* feat: improve try_into_vector function
* Impl python mod and PyVector to execute script
* add AsSeq (BUT not IMPL)
* add & test pythonic_index, add into_py_obj (UNTESTED)
* add into_datatypes_value (UNTESTED)
* inplace setitem_by_index unsupported
* still struggling with testing AsSeq
* actually pyimpl AsSeq & AsMap
* add slice for PyVector
* improve visibility for testing
* adjust for clippy
* add assert for test_execute_script
* add type anno in test
* feat: basic support for PyVector's operators with scalars (#64)
* feat: memory size of vector (#53)
* feat: improve try_into_vector function
* feat: impl memory_size function for vectors
* fix: forgotten memory_size assertion in null vector test
* feat: use LargeUtf8 instead of utf8 for string, and rename LargeBinaryArray to BinaryArray
* feat: memory_size only calculates heap size
* feat: impl bytes_allocated for memtable (#55)
* add init and constructor
* rename type cast and add test
* fix bug in pyobj_to_val
* add default cast when no type is specified
* add basic add/sub/mul for array and scalar (value)
* cargo clippy
* comment out some println
* stricter clippy
* style: cargo fmt
* fix: string & bool support in val2pyobj & back
* style: remove println in test
* style: rm println in test mod in python.rs
* refactor: use wrap_index instead of pythonic_index
* refactor: right op in scalar_arith_op
* fix: stronger types & better tests
* style: remove println
* fix: scalar signed/unsigned cast
* style: remove instead of commenting out
* style: remove more commented-out code
* feat: support scalar div vector
* style: cargo fmt
* style: typo
* refactor: rename to correct var name
* refactor: directly use arrow2::array
* refactor: mv rsub & rdiv's op into a function
* test: add python expr test
* test: add test for PyList
* refactor: tweak order of arithmetics in rtruediv
* style: remove some `use`
* refactor: move `is_instance` to mod
* refactor: move fn to mod & move `use` to head
* style: cargo fmt
* fix: correct signed/unsigned cast
* refactor: wrap err msg in another fn
* style: cargo fmt
* style: remove ok_or_else for readability
* feat: add coprocessor fn (not yet impl)
* refactor: change back to wrapped_at
* fix: update Cargo.lock
* fix: update rustc version
* Update Rust Toolchain to nightly-2022-07-14
* feat: derive Eq when possible
* style: use `from` to avoid `needless_borrow` lint

Co-authored-by: dennis zhuang <killme2008@gmail.com>

* feat: python coprocessor with type annotation (#96)
* feat: add coprocessor fn

Signed-off-by: discord9 <zglzy29yzdk@gmail.com>

* feat: cast args into PyVector
* feat: incomplete coprocessor
* feat: erase decorator in python ast
* feat: strip decorator in ast
* fix: change parse to `Interactive`
* style: format Cargo.toml
* feat: make coprocessor actually work
* feat: move coprocessor fn out of test mod
* feat: add error handling
* style: add some comments
* feat: rm type annotation
* feat: add type annotation support
* style: move compile method to vm closure
* feat: annotation for nullable
* feat: type coercion cast in annotation
* feat: actually cast (NOT TESTED)
* fix: allow single into(type)
* refactor: extract parse_type from parser
* style: cargo fmt
* feat: change to Expr to preserve location info
* feat: add CoprParse to handle parse/check errors
* style: add type anno doc for coprocessor
* test: add some tests
* feat: add underscore as any type in annotation
* test: add parse & runtime testcases
* style: rm dbg! remnant
* style: cargo fmt
* feat: add more error prompt info
* style: cargo fmt
* style: add doc tests' missing `use`
* fix: doc test for coprocessor
* style: cargo fmt
* fix: add missing `use` for `cargo test --doc`
* refactor: according to reviews
* refactor: more tweaks according to reviews
* refactor: merge match arms
* refactor: move into different files (INCOMPLETE)
* refactor: split parse_copr into more functions
* refactor: split `exec_coprocessor` into more fns
* style: cargo fmt
* feat: print Py Exceptions as String
* feat: error handling conforms to standards
* test: fix test_coprocessor
* feat: remove `into` in python
* test: remove all `into` in python tests
* style: update comment
* refactor: move strip compile fn to impl Copr
* refactor: move `gen_schema` to impl copr
* refactor: move `check_cast_type` to impl copr
* refactor: if let to match
* style: cargo fmt
* refactor: better parse of keyword arg list
* style: cargo fmt
* refactor: some error handling (INCOMPLETE)
* refactor: error handling to general Error type
* refactor: rm some Vec::new()
* test: modify all tests to ok
* style: reorder items
* refactor: fetch using iter
* style: cargo fmt
* style: fmt macro by hand
* refactor: rename InnerError to Error
* test: use ron to write tests
* test: add test for exec_copr
* refactor: add parse_bin_op
* feat: add check_anno
* refactor: add some checker functions
* refactor: exec_copr into smaller funcs
* style: add some comments
* refactor: add check for bin_op
* refactor: rm useless Result
* style: add pretty print for error with location
* feat: more info for pretty print
* refactor: mv pretty print to error.rs
* refactor: rm execute_script
* feat: add pretty print
* feat: add constant column support
* test: add test for constant column
* feat: add pretty print exec fn
* style: cargo fmt
* feat: add macro to chain call `.fail()`
* style: update doc for constant columns
* style: add lint to allow print in test fn
* style: cargo fmt
* docs: update some comments
* fix: ignore doctest for now
* refactor: check_bin_op
* refactor: parse_in_op, check ret anno fn
* refactor: rm check_decorator
* doc: explain added newline for loc
* style: cargo fmt
* refactor: use Helper::try_into_vec in try_into_vec
* style: cargo fmt
* test: add ret anno test
* style: cargo fmt
* test: add names for .ron tests for better debugging
* test: print emoji in test
* style: rm some commented-out lines
* style: rename `into` to `try_into` fn
* style: cargo fmt
* refactor: rm unused serialize derive
* fix: pretty print out-of-bounds fix
* fix: rm some space in pretty print
* style: cargo fmt
* test: not even a python fn def
* style: cargo fmt
* fix: pretty print off-by-one space
* fix: allow `eprint` in clippy lint
* fix: compile error after rebasing develop
* feat: port 35 functions from DataFusion to Python Coprocessor (#137)
* refactor: `cargo clippy`
* feat: create a module
* style: cargo fmt
* feat: bind `pow()` function (UNTESTED)
* test: add test for udf mod
* style: allow part eq not eq for gen code
* style: allow print in test lint
* feat: use PyObjectRef to handle more types
* feat: add cargo feature for udf modules
* style: rename feature to udf-builtins
* refactor: move away from mod.rs
* feat: add all_to_f64 cast fn
* feat: add bind_math_fn macro
* feat: add all simple math UDFs
* feat: add `random(len)` math fn
* feat: port `avg()` from datafusion
* refactor: add `eval_aggr_fn`
* feat: add bind_aggr_fn macro
* doc: add comment for args of macro
* feat: add all UDAFs from datafusion
* refactor: extract test to separate file
* style: cargo fmt
* test: add incomplete test
* test: add .ron test fn
* feat: support scalar::list
* doc: add comments
* style: rename VagueFloat/Int to LenFloat/IntVec
* test: for all fns (except approx_median)
* test: better print
* doc: add comment for FloatWithError
* refactor: move test.rs out of builtins/
* style: cargo fmt
* doc: add comment for .ron file
* doc: update some comments
* test: EPS=1e-12 for float eq
* test: use f64::EPSILON instead
* test: change to 2*EPS
* test: cache interpreter for fast testing
* doc: remove a TODO which is done
* test: refactor to_py_obj fn
* fix: pow fn
* doc: add a TODO for type_.rs
* test: use new_int/float in test serde
* test: for str case
* style: cargo fmt
* feat: cast PyList to ScalarValue::List
* test: cast scalar to py obj and back
* feat: cast to PyList
* test: cast from PyList
* test: nested PyVector unsupported
* doc: remove unrunnable doctest
* test: replace PartialEq with impl just_as_expect
* doc: add name for discord9's TODO
* refactor: change to vm.ctx.new_** instead
* doc: complete a TODO
* refactor: is_instance and other minor problems
* refactor: remove type_::is_instance
* style: cargo fmt
* feat: rename to `greptime_builtin`
* fix: error handling for PyList datatype
* style: fix clippy warning
* test: for PyList
* feat: Python Coprocessor MVP (#180)
* feat: add get_arrow_op
* feat: add comparison op (UNTESTED)
* doc: explain why no rich compare
* refactor: py_str2str & parse_keywords
* feat: add DecoratorArgs
* refactor: parse_keywords ret Deco Args
* style: remove unused
* doc: add todo
* style: remove some unused fns
* doc: add comment for copr's fields
* feat: add copr_engine module
* refactor: move to `script` crate
* style: clean up cargo.toml
* feat: add query engine for copr engine
* refactor: deco args into separate struct
* test: update corresponding test
* feat: async coprocessor engine
* refactor: add `exec_parsed` fn
* feat: sync version of coprocessor (UNTESTED)
* refactor: remove useless lifetime
* feat: new type for async stream record batch
* merge: from PR#137 add py builtins
* toolchain: update rustc to nightly-08-16
* feat: add `exec_with_cached_vm` fn (can't compile)
* toolchain: revert to 07-14
* fix: `exec_with_cached_vm`
* fix: allow vector[_] in params
* style: cargo fmt
* doc: update comment on `_` & `_|None`
* fix: allow import & ignored type anno is ok
* feat: allow ignoring return types
* refactor: remove unused py files in functions/
* style: fmt & clippy
* refactor: python modules (#186)
* refactor: move common/script to script
* fix: clippy warnings and refactor python modules
* refactor: remove modules mod, rename tests mod
* feat: adds Script and ScriptEngine traits, then impl PyScript/PyScriptEngine
* refactor: remove pub use of some functions in script
* refactor: python error mod
* refactor: coprocessor and vector
* feat: adds engine test and greptime.vector function to create vector from iterable
* fix: adds a blank line to cargo file end
* fix: compile error after rebasing develop
* feat: script endpoint for http server (#206)
* feat: impl /scripts API for http server
* feat: adds http api version
* test: add test for scripts handler and endpoint
* feat: python side mock module and more builtin functions (#209)
* feat: add python side module (for both mock and real upload script)
* style: add *.pyc to gitignore
* feat: move copr decorator (in .py) to greptime.py
* doc: update comment for `datetime` & `mock_tester` & gitignore
* feat: `filter()` an array with bool array (UNTESTED)
* feat: `prev()`ious elem in array returned as new array (UNTESTED)
* feat: `datetime()` parses date time string and rets integer (UNTESTED)
* fix: add missing return & fmt
* fix: allow f32 cast to PyFloat
* fix: `datetime()`'s last token now parsed
* test: `calc_rvs` now can run with builtin module
* feat: allow rich compare which rets bool array
* feat: logic and (`&`) for bool array
* style: cargo fmt
* feat: index PyVector by bool array
* feat: alias `ln` as `log` in builtin modules
* feat: logic or (`|`) & not (`~`) for bool array
* feat: add `post` for @copr in py side mod
* feat: change datetime return to i64
* feat: py side mod `post`s script to given address
* fix: add `engine` field in `post` in py side mod
* refactor: use `ConstantVector` in `pow()` builtin
* fix: prev ret err for zero-length array
* doc: rm commented-out code
* test: incomplete pyside mod test case
* git: ignore all __pycache__
* style: fmt & clippy
* refactor: split py side module into example & greptime
* feat: init_table in py using `v1/sql` api
* feat: calc_rvs now runs both locally and remotely
* doc: add doc for how to run it
* fix: comment out start server code in test
* fix: clippy warnings
* fix: http test url
* fix: some CR problems
* fix: some CR problems
* refactor: script executor for instance
* refactor: remove engine param in execute_script
* chore: Remove unnecessary allow attributes

Co-authored-by: Dennis Zhuang <killme2008@gmail.com>
Co-authored-by: Discord9 <discord9@163.com>
Co-authored-by: discord9 <zglzy29yzdk@gmail.com>
Co-authored-by: discord9 <55937128+discord9@users.noreply.github.com>
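To make the new API concrete: a coprocessor is a Python function bound to the result of a SQL query through a decorator. A minimal sketch, using the same decorator shape as the tests added in this commit (the `+ 2` body is illustrative):

@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
def test(n):
    # `n` arrives as a vector built from the `number` column;
    # the returned vector becomes the `n` output column
    return n + 2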
2 .github/workflows/coverage.yml (vendored)
@@ -9,7 +9,7 @@ on:
 name: Code coverage
 
 env:
-  RUST_TOOLCHAIN: nightly-2022-04-03
+  RUST_TOOLCHAIN: nightly-2022-07-14
 
 jobs:
   grcov:
2 .github/workflows/develop.yml (vendored)
@@ -5,7 +5,7 @@ on:
 name: Continuous integration for developing
 
 env:
-  RUST_TOOLCHAIN: nightly-2022-04-03
+  RUST_TOOLCHAIN: nightly-2022-07-14
 
 jobs:
   check:
6 .gitignore (vendored)
@@ -22,3 +22,9 @@ debug/
 # Logs
 **/__unittest_logs
 logs/
+
+.DS_store
+.gitignore
+
+# cpython's generated python byte code
+**/__pycache__/
835 Cargo.lock (generated)
File diff suppressed because it is too large
Cargo.toml
@@ -19,6 +19,7 @@ members = [
     "src/logical-plans",
     "src/object-store",
     "src/query",
+    "src/script",
     "src/servers",
     "src/sql",
     "src/storage",
0 component/script/python/__init__.py (new file)
0 component/script/python/example/__init__.py (new file)
69 component/script/python/example/calc_rv.py (new file)
@@ -0,0 +1,69 @@
import sys
# for annoying relative import beyond top-level package
sys.path.insert(0, "../")
from greptime import mock_tester, coprocessor, greptime as gt_builtin
from greptime.greptime import interval, vector, log, prev, sqrt, datetime
import greptime.greptime as greptime
import json
import numpy as np


def data_sample(k_lines, symbol, density=5 * 30 * 86400):
    """
    Only return close data for simplicity for now
    """
    k_lines = k_lines["result"] if k_lines["ret_msg"] == "OK" else None
    if k_lines is None:
        raise Exception("Expect an `OK`ed message")
    close = [float(i["close"]) for i in k_lines]

    return interval(close, density, "prev")


def as_table(kline: list):
    col_len = len(kline)
    ret = {
        k: vector([fn(row[k]) for row in kline], str(ty))
        for k, fn, ty in
        [
            ("symbol", str, "str"),
            ("period", str, "str"),
            ("open_time", int, "int"),
            ("open", float, "float"),
            ("high", float, "float"),
            ("low", float, "float"),
            ("close", float, "float")
        ]
    }
    return ret


@coprocessor(args=["open_time", "close"], returns=[
    "rv_7d",
    "rv_15d",
    "rv_30d",
    "rv_60d",
    "rv_90d",
    "rv_180d"
],
    sql="select open_time, close from k_line")
def calc_rvs(open_time, close):
    from greptime import vector, log, prev, sqrt, datetime, pow, sum

    def calc_rv(close, open_time, time, interval):
        mask = (open_time < time) & (open_time > time - interval)
        close = close[mask]

        avg_time_interval = (open_time[-1] - open_time[0]) / (len(open_time) - 1)
        ref = log(close / prev(close))
        var = sum(pow(ref, 2) / (len(ref) - 1))
        return sqrt(var / avg_time_interval)

    # how to get env var,
    # maybe through accessing scope and serde then send to remote?
    timepoint = open_time[-1]
    rv_7d = calc_rv(close, open_time, timepoint, datetime("7d"))
    rv_15d = calc_rv(close, open_time, timepoint, datetime("15d"))
    rv_30d = calc_rv(close, open_time, timepoint, datetime("30d"))
    rv_60d = calc_rv(close, open_time, timepoint, datetime("60d"))
    rv_90d = calc_rv(close, open_time, timepoint, datetime("90d"))
    rv_180d = calc_rv(close, open_time, timepoint, datetime("180d"))
    return rv_7d, rv_15d, rv_30d, rv_60d, rv_90d, rv_180d
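For reference, the statistic `calc_rv` computes can be written as follows: with masked close prices c_i, log returns r_i = log(c_i / c_{i-1}) (the first return is NaN from `prev`, which `np.nansum` ignores), n the number of returns, and the average spacing of the full `open_time` column \(\overline{\Delta t} = (t_m - t_1)/(m - 1)\),

\[
\mathrm{rv} = \sqrt{\frac{1}{\overline{\Delta t}} \sum_i \frac{r_i^2}{n - 1}}
\]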
1 component/script/python/example/fetch_kline.sh (executable file)
@@ -0,0 +1 @@
curl "https://api.bybit.com/v2/public/index-price-kline?symbol=BTCUSD&interval=1&limit=$1&from=1581231260" > kline.json
108 component/script/python/example/kline.json (new file)
@@ -0,0 +1,108 @@
{
    "ret_code": 0,
    "ret_msg": "OK",
    "ext_code": "",
    "ext_info": "",
    "result": [
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231300,
            "open": "10107",
            "high": "10109.34",
            "low": "10106.71",
            "close": "10106.79"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231360,
            "open": "10106.79",
            "high": "10109.27",
            "low": "10105.92",
            "close": "10106.09"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231420,
            "open": "10106.09",
            "high": "10108.75",
            "low": "10104.66",
            "close": "10108.73"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231480,
            "open": "10108.73",
            "high": "10109.52",
            "low": "10106.07",
            "close": "10106.38"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231540,
            "open": "10106.38",
            "high": "10109.48",
            "low": "10104.81",
            "close": "10106.95"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231600,
            "open": "10106.95",
            "high": "10109.48",
            "low": "10106.6",
            "close": "10107.55"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231660,
            "open": "10107.55",
            "high": "10109.28",
            "low": "10104.68",
            "close": "10104.68"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231720,
            "open": "10104.68",
            "high": "10109.18",
            "low": "10104.14",
            "close": "10108.8"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231780,
            "open": "10108.8",
            "high": "10117.36",
            "low": "10108.8",
            "close": "10115.96"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231840,
            "open": "10115.96",
            "high": "10119.19",
            "low": "10115.96",
            "close": "10117.08"
        },
        {
            "symbol": "BTCUSD",
            "period": "1",
            "open_time": 1581231900,
            "open": "10117.08",
            "high": "10120.73",
            "low": "10116.96",
            "close": "10120.43"
        }
    ],
    "time_now": "1661225351.158190"
}
4 component/script/python/greptime/__init__.py (new file)
@@ -0,0 +1,4 @@
from .greptime import coprocessor, copr
from .greptime import vector, log, prev, sqrt, pow, datetime, sum
from .mock import mock_tester
from .cfg import set_conn_addr, get_conn_addr
11 component/script/python/greptime/cfg.py (new file)
@@ -0,0 +1,11 @@
GREPTIME_DB_CONN_ADDRESS = "localhost:3000"
"""Global variable holding the address used to connect to the database"""


def set_conn_addr(addr: str):
    """set database address to the given `addr`"""
    global GREPTIME_DB_CONN_ADDRESS
    GREPTIME_DB_CONN_ADDRESS = addr


def get_conn_addr() -> str:
    global GREPTIME_DB_CONN_ADDRESS
    return GREPTIME_DB_CONN_ADDRESS
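A minimal usage sketch of this config module (the address below is hypothetical):

from greptime.cfg import set_conn_addr, get_conn_addr

set_conn_addr("127.0.0.1:3000")  # hypothetical host:port
assert get_conn_addr() == "127.0.0.1:3000"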
215 component/script/python/greptime/greptime.py (new file)
@@ -0,0 +1,215 @@
"""
Note that this is a mock library: if not connected to a database,
it can only run on mock data and mock functions backed by numpy
"""
import functools
import numpy as np
import json
from urllib import request
import inspect
import requests

from .cfg import set_conn_addr, get_conn_addr

log = np.log
sum = np.nansum
sqrt = np.sqrt
pow = np.power
nan = np.nan


class TimeStamp(str):
    """
    TODO: impl date time
    """
    pass


class i32(int):
    """
    For Python Coprocessor Type Annotation ONLY
    A signed 32-bit integer.
    """

    def __repr__(self) -> str:
        return "i32"


class i64(int):
    """
    For Python Coprocessor Type Annotation ONLY
    A signed 64-bit integer.
    """

    def __repr__(self) -> str:
        return "i64"


class f32(float):
    """
    For Python Coprocessor Type Annotation ONLY
    A 32-bit floating point number.
    """

    def __repr__(self) -> str:
        return "f32"


class f64(float):
    """
    For Python Coprocessor Type Annotation ONLY
    A 64-bit floating point number.
    """

    def __repr__(self) -> str:
        return "f64"


class vector(np.ndarray):
    """
    A compact vector with all elements of the same data type.
    """
    _datatype: str | None = None

    def __new__(
        cls,
        lst,
        dtype=None
    ) -> ...:
        self = np.asarray(lst).view(cls)
        self._datatype = dtype
        return self

    def __str__(self) -> str:
        return "vector({}, \"{}\")".format(super().__str__(), self.datatype())

    def datatype(self):
        return self._datatype

    def filter(self, lst_bool):
        return self[lst_bool]


def prev(lst):
    ret = np.zeros(len(lst))
    ret[1:] = lst[0:-1]
    ret[0] = nan
    return ret


def query(sql: str):
    pass


def interval(arr: list, duration: int, fill, step: None | int = None, explicitOffset=False):
    """
    Note that this is a mock function with the same functionality as the actual Python Coprocessor

    `arr` is a vector of integral or temporal type.

    `duration` is the length of the sliding window

    `step` is the length the sliding window advances on each step

    `fill` indicates how to fill missing values:
    - "prev": use the previous value
    - "post": use the next value
    - "linear": linear interpolation; if certain types cannot be interpolated, fall back to "prev"
    - "null": use null
    - "none": do not interpolate
    """
    if step is None:
        step = duration

    tot_len = int(np.ceil(len(arr) // step))
    slices = np.zeros((tot_len, int(duration)))
    for idx, start in enumerate(range(0, len(arr), step)):
        slices[idx] = arr[start:(start + duration)]
    return slices


def factor(unit: str) -> int:
    if unit == "d":
        return 24 * 60 * 60
    elif unit == "h":
        return 60 * 60
    elif unit == "m":
        return 60
    elif unit == "s":
        return 1
    else:
        raise Exception("Only d,h,m,s are supported, found {}".format(unit))


def datetime(input_time: str) -> int:
    """
    supports `d`(day) `h`(hour) `m`(minute) `s`(second)

    supported formats:
    `12s` `7d` `12d2h7m`
    """

    prev = 0
    cur = 0
    state = "Num"
    parse_res = []
    for idx, ch in enumerate(input_time):
        if ch.isdigit():
            cur = idx

            if state != "Num":
                parse_res.append((state, input_time[prev:cur], (prev, cur)))
                prev = idx
                state = "Num"
        else:
            cur = idx
            if state != "Symbol":
                parse_res.append((state, input_time[prev:cur], (prev, cur)))
                prev = idx
                state = "Symbol"
    parse_res.append((state, input_time[prev:cur + 1], (prev, cur + 1)))

    cur_idx = 0
    res_time = 0
    while cur_idx < len(parse_res):
        pair = parse_res[cur_idx]
        if pair[0] == "Num":
            val = int(pair[1])
            nxt = parse_res[cur_idx + 1]
            res_time += val * factor(nxt[1])
            cur_idx += 2
        else:
            raise Exception("Two symbols in a row are impossible")

    return res_time


def coprocessor(args=None, returns=None, sql=None):
    """
    The actual coprocessor, which will connect to the database and upload
    whatever function is decorated with `@coprocessor(args=[...], returns=[...], sql=...)`
    """
    def decorator_copr(func):
        @functools.wraps(func)
        def wrapper_do_actual(*args, **kwargs):
            if len(args) != 0 or len(kwargs) != 0:
                raise Exception("Expect a call with no arguments (all args are given by the coprocessor itself)")
            source = inspect.getsource(func)
            url = "http://{}/v1/scripts".format(get_conn_addr())
            print("Posting to {}".format(url))
            data = {
                "script": source,
                "engine": None,
            }

            res = requests.post(
                url,
                headers={"Content-Type": "application/json"},
                json=data
            )
            return res
        return wrapper_do_actual
    return decorator_copr


# an alias for short
copr = coprocessor
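As a quick illustration of the duration parser above (these assertions follow directly from `factor` and `datetime`; they are not part of the commit):

from greptime.greptime import datetime

assert datetime("7d") == 7 * 24 * 60 * 60
# alternating number/unit tokens are summed up:
assert datetime("12d2h7m") == 12 * 86400 + 2 * 3600 + 7 * 60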
82 component/script/python/greptime/mock.py (new file)
@@ -0,0 +1,82 @@
"""
Note this is a mock library: if not connected to a database,
it can only run on mock data, backed by numpy
"""
from typing import Any
import numpy as np
from .greptime import i32, i64, f32, f64, vector, interval, query, prev, datetime, log, sum, sqrt, pow, nan, copr, coprocessor

import inspect
import functools
import ast


def mock_tester(
    func,
    env: dict,
    table=None
):
    """
    Mock tester helper function.
    What it does is replace `@coprocessor` with `@mock_copr` and add a keyword `env=env`,
    like `@mock_copr(args=..., returns=..., env=env)`
    """
    code = inspect.getsource(func)
    tree = ast.parse(code)
    tree = HackyReplaceDecorator("env").visit(tree)
    new_func = tree.body[0]
    fn_name = new_func.name

    code_obj = compile(tree, "<embedded>", "exec")
    exec(code_obj)

    ret = eval("{}()".format(fn_name))
    return ret


def mock_copr(args, returns, sql=None, env: None | dict = None):
    """
    This should not be used directly by the user
    """
    def decorator_copr(func):
        @functools.wraps(func)
        def wrapper_do_actual(*fn_args, **fn_kwargs):

            real_args = [env[name] for name in args]
            ret = func(*real_args)
            return ret

        return wrapper_do_actual
    return decorator_copr


class HackyReplaceDecorator(ast.NodeTransformer):
    """
    This class accepts an `env` dict as the environment to extract args from,
    and puts the `env` dict in the param list of the `mock_copr` decorator, i.e.:

    a `@copr(args=["a", "b"], returns=["c"])` with a call like mock_helper(abc, env={"a":2, "b":3})

    will be transformed into `@mock_copr(args=["a", "b"], returns=["c"], env={"a":2, "b":3})`
    """
    def __init__(self, env: str) -> None:
        # just for adding the `env` keyword
        self.env = env

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        new_node = node
        decorator_list = new_node.decorator_list
        if len(decorator_list) != 1:
            return node

        deco = decorator_list[0]
        if deco.func.id != "coprocessor" and deco.func.id != "copr":
            raise Exception("Expect a @copr or @coprocessor, found {}.".format(deco.func.id))
        deco.func = ast.Name(id="mock_copr", ctx=ast.Load())
        new_kw = ast.keyword(arg="env", value=ast.Name(id=self.env, ctx=ast.Load()))
        deco.keywords.append(new_kw)

        # Tie up loose ends in the AST.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(node)
        return new_node
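A hypothetical end-to-end use of the mock path (function and values invented for illustration): `mock_tester` re-parses the function's source, swaps the decorator for `@mock_copr(..., env=env)`, and the wrapper then pulls arguments from `env` instead of from a SQL query.

from greptime import copr, mock_tester

@copr(args=["a", "b"], returns=["c"])
def add_cols(a, b):
    return a + b

env = {"a": 2, "b": 3}                  # keys must match the decorator's args
print(mock_tester(add_cols, env=env))   # -> 5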
55 component/script/python/test.py (new file)
@@ -0,0 +1,55 @@
from example.calc_rv import as_table, calc_rvs
from greptime import coprocessor, set_conn_addr, get_conn_addr, mock_tester
import sys
import json
import requests
'''
To run this script, you need to first start an http server of greptime, then run
`
python3 component/script/python/test.py address:port
`

'''
@coprocessor(sql='select number from numbers limit 10', args=['number'], returns=['n'])
def test(n):
    return n + 2


def init_table(close, open_time):
    req_init = "/v1/sql?sql=create table k_line (close double, open_time bigint, TIME INDEX (open_time))"
    print(get_db(req_init).text)
    for c1, c2 in zip(close, open_time):
        req = "/v1/sql?sql=INSERT INTO k_line(close, open_time) VALUES ({}, {})".format(c1, c2)
        print(get_db(req).text)
    print(get_db("/v1/sql?sql=select * from k_line").text)


def get_db(req: str):
    return requests.get("http://{}{}".format(get_conn_addr(), req))


if __name__ == "__main__":
    if len(sys.argv) != 2:
        raise Exception("Expect exactly one address as the command's argument")
    set_conn_addr(sys.argv[1])
    res = test()
    print(res.headers)
    print(res.text)
    with open("component/script/python/example/kline.json", "r") as kline_file:
        kline = json.load(kline_file)
    # vec = vector([1,2,3], int)
    # print(vec, vec.datatype())
    table = as_table(kline["result"])
    # print(table)
    close = table["close"]
    open_time = table["open_time"]
    init_table(close, open_time)

    # print(repr(close), repr(open_time))
    # print("calc_rv:", calc_rv(close, open_time, open_time[-1]+datetime("10m"), datetime("7d")))
    env = {"close": close, "open_time": open_time}
    # print("env:", env)
    print("Mock result:", mock_tester(calc_rvs, env=env))
    real = calc_rvs()
    print(real)
    try:
        print(real.text["error"])
    except:
        print(real.text)
@@ -1,4 +1,4 @@
-#![allow(clippy::all)]
+#![allow(clippy::derive_partial_eq_without_eq)]
 tonic::include_proto!("greptime.v1");
 
 pub mod codec {
src/common/function/Cargo.toml
@@ -1,14 +1,12 @@
 [package]
-edition = "2021"
 name = "common-function"
 version = "0.1.0"
+edition = "2021"
 
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 [dependencies.arrow]
-package = "arrow2"
-version="0.10"
-features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute", "serde_types"]
+package = "arrow2"
+version = "0.10"
 
@@ -17,9 +15,20 @@ common-error = { path = "../error" }
 common-query = { path = "../query" }
 datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git" , branch = "arrow2" }
 datatypes = { path = "../../datatypes" }
-num = "0.4.0"
-num-traits = "0.2.14"
+libc = "0.2"
+num = "0.4"
+num-traits = "0.2"
 once_cell = "1.10"
+paste = "1.0"
+rustpython-ast = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
+rustpython-bytecode = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
+rustpython-compiler = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
+rustpython-compiler-core = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
+rustpython-parser = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
+rustpython-vm = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
 snafu = { version = "0.7", features = ["backtraces"] }
 statrs = "0.15"
 
 [dev-dependencies]
 ron = "0.7"
 serde = {version = "1.0", features = ["derive"]}
@@ -2,7 +2,7 @@ mod pow;
 
 use std::sync::Arc;
 
-use pow::PowFunction;
+pub use pow::PowFunction;
 
 use crate::scalars::function_registry::FunctionRegistry;
@@ -1,8 +1,8 @@
 //! Error of record batch.
 use std::any::Any;
 
+use common_error::ext::BoxedError;
 use common_error::prelude::*;
 
 common_error::define_opaque_error!(Error);
 
 pub type Result<T> = std::result::Result<T, Error>;
@@ -21,6 +21,12 @@ pub enum InnerError {
         #[snafu(backtrace)]
         source: datatypes::error::Error,
     },
+
+    #[snafu(display("External error, source: {}", source))]
+    External {
+        #[snafu(backtrace)]
+        source: BoxedError,
+    },
 }
 
 impl ErrorExt for InnerError {
@@ -28,6 +34,7 @@ impl ErrorExt for InnerError {
         match self {
             InnerError::NewDfRecordBatch { .. } => StatusCode::InvalidArguments,
             InnerError::DataTypes { .. } => StatusCode::Internal,
+            InnerError::External { source } => source.status_code(),
         }
     }
src/datanode/Cargo.toml
@@ -10,6 +10,12 @@ package = "arrow2"
 version = "0.10"
 features = ["io_csv", "io_json", "io_parquet", "io_parquet_compression", "io_ipc", "ahash", "compute", "serde_types"]
 
+[features]
+default = ["python"]
+python = [
+    "dep:script"
+]
+
 [dependencies]
 api = { path = "../api" }
 async-trait = "0.1"
@@ -30,6 +36,7 @@ log-store = { path = "../log-store" }
 metrics = "0.20"
 object-store = { path = "../object-store" }
 query = { path = "../query" }
+script = { path = "../script", features = ["python"], optional = true }
 serde = "1.0"
 serde_json = "1.0"
 servers = { path = "../servers" }
src/datanode/src/instance.rs
@@ -23,6 +23,7 @@ use crate::error::{
     UnsupportedExprSnafu,
 };
 use crate::metric;
+use crate::script::ScriptExecutor;
 use crate::server::grpc::handler::{build_err_result, ObjectResultBuilder};
 use crate::server::grpc::insert::insertion_expr_to_request;
 use crate::server::grpc::plan::PhysicalPlanner;
@@ -39,6 +40,7 @@ pub struct Instance {
     // Catalog list
     catalog_manager: CatalogManagerRef,
     physical_planner: PhysicalPlanner,
+    script_executor: ScriptExecutor,
 }
 
 pub type InstanceRef = Arc<Instance>;
@@ -64,12 +66,14 @@ impl Instance {
         );
         let factory = QueryEngineFactory::new(catalog_manager.clone());
         let query_engine = factory.query_engine().clone();
+        let script_executor = ScriptExecutor::new(query_engine.clone());
 
         Ok(Self {
             query_engine: query_engine.clone(),
             sql_handler: SqlHandler::new(table_engine, catalog_manager.clone()),
             catalog_manager,
             physical_planner: PhysicalPlanner::new(query_engine),
+            script_executor,
         })
     }
@@ -251,6 +255,10 @@ impl SqlQueryHandler for Instance {
         })
         .context(servers::error::ExecuteQuerySnafu { query })
     }
+
+    async fn execute_script(&self, script: &str) -> servers::error::Result<Output> {
+        self.script_executor.execute_script(script).await
+    }
 }
 
 #[async_trait]
src/datanode/src/lib.rs
@@ -4,6 +4,7 @@ pub mod datanode;
 pub mod error;
 pub mod instance;
 mod metric;
+mod script;
 pub mod server;
 mod sql;
 #[cfg(test)]
70 src/datanode/src/script.rs (new file)
@@ -0,0 +1,70 @@
use query::Output;
use query::QueryEngineRef;

#[cfg(not(feature = "python"))]
mod dummy {
    use super::*;

    pub struct ScriptExecutor;

    impl ScriptExecutor {
        pub fn new(_query_engine: QueryEngineRef) -> Self {
            Self {}
        }

        pub async fn execute_script(&self, _script: &str) -> servers::error::Result<Output> {
            servers::error::NotSupportedSnafu { feat: "script" }.fail()
        }
    }
}

#[cfg(feature = "python")]
mod python {
    use common_error::prelude::BoxedError;
    use common_telemetry::logging::error;
    use script::{
        engine::{CompileContext, EvalContext, Script, ScriptEngine},
        python::PyEngine,
    };
    use snafu::ResultExt;

    use super::*;

    pub struct ScriptExecutor {
        py_engine: PyEngine,
    }

    impl ScriptExecutor {
        pub fn new(query_engine: QueryEngineRef) -> Self {
            Self {
                py_engine: PyEngine::new(query_engine),
            }
        }

        pub async fn execute_script(&self, script: &str) -> servers::error::Result<Output> {
            let py_script = self
                .py_engine
                .compile(script, CompileContext::default())
                .await
                .map_err(|e| {
                    error!(e; "Instance failed to execute script");
                    BoxedError::new(e)
                })
                .context(servers::error::ExecuteScriptSnafu { script })?;

            py_script
                .evaluate(EvalContext::default())
                .await
                .map_err(|e| {
                    error!(e; "Instance failed to execute script");
                    BoxedError::new(e)
                })
                .context(servers::error::ExecuteScriptSnafu { script })
        }
    }
}

#[cfg(not(feature = "python"))]
pub use self::dummy::*;
#[cfg(feature = "python")]
pub use self::python::*;
@@ -1,9 +1,12 @@
+use std::net::SocketAddr;
 use std::sync::Arc;
 
 use axum::http::StatusCode;
 use axum::Router;
 use axum_test_helper::TestClient;
+use servers::http::handler::ScriptExecution;
 use servers::http::HttpServer;
 use servers::server::Server;
+use test_util::TestGuard;
 
 use crate::instance::Instance;
@@ -23,7 +26,7 @@ async fn test_sql_api() {
     common_telemetry::init_default_ut_logging();
     let (app, _guard) = make_test_app().await;
     let client = TestClient::new(app);
-    let res = client.get("/sql").send().await;
+    let res = client.get("/v1/sql").send().await;
     assert_eq!(res.status(), StatusCode::OK);
 
     let body = res.text().await;
@@ -33,7 +36,7 @@ async fn test_sql_api() {
     );
 
     let res = client
-        .get("/sql?sql=select * from numbers limit 10")
+        .get("/v1/sql?sql=select * from numbers limit 10")
         .send()
         .await;
     assert_eq!(res.status(), StatusCode::OK);
@@ -46,14 +49,14 @@ async fn test_sql_api() {
 
     // test insert and select
     let res = client
-        .get("/sql?sql=insert into demo values('host', 66.6, 1024, 0)")
+        .get("/v1/sql?sql=insert into demo values('host', 66.6, 1024, 0)")
         .send()
         .await;
     assert_eq!(res.status(), StatusCode::OK);
 
     // select *
     let res = client
-        .get("/sql?sql=select * from demo limit 10")
+        .get("/v1/sql?sql=select * from demo limit 10")
         .send()
         .await;
     assert_eq!(res.status(), StatusCode::OK);
@@ -66,7 +69,7 @@ async fn test_sql_api() {
 
     // select with projections
     let res = client
-        .get("/sql?sql=select cpu, ts from demo limit 10")
+        .get("/v1/sql?sql=select cpu, ts from demo limit 10")
         .send()
         .await;
     assert_eq!(res.status(), StatusCode::OK);
@@ -87,7 +90,7 @@ async fn test_metrics_api() {
 
     // Send a sql
     let res = client
-        .get("/sql?sql=select * from numbers limit 10")
+        .get("/v1/sql?sql=select * from numbers limit 10")
         .send()
         .await;
     assert_eq!(res.status(), StatusCode::OK);
@@ -98,3 +101,50 @@ async fn test_metrics_api() {
     let body = res.text().await;
     assert!(body.contains("datanode_handle_sql_elapsed"));
 }
+
+#[tokio::test]
+async fn test_scripts_api() {
+    common_telemetry::init_default_ut_logging();
+    let (app, _guard) = make_test_app().await;
+    let client = TestClient::new(app);
+    let res = client
+        .post("/v1/scripts")
+        .json(&ScriptExecution {
+            script: r#"
+@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
+def test(n):
+    return n;
+"#
+            .to_string(),
+        })
+        .send()
+        .await;
+    assert_eq!(res.status(), StatusCode::OK);
+
+    let body = res.text().await;
+    assert_eq!(
+        body,
+        r#"{"success":true,"output":{"Rows":[{"schema":{"fields":[{"name":"n","data_type":"UInt32","is_nullable":false,"metadata":{}}],"metadata":{}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}]}}"#
+    );
+}
+
+async fn start_test_app(addr: &str) -> (SocketAddr, TestGuard) {
+    let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts();
+    let instance = Arc::new(Instance::new(&opts).await.unwrap());
+    instance.start().await.unwrap();
+    let mut http_server = HttpServer::new(instance);
+    (
+        http_server.start(addr.parse().unwrap()).await.unwrap(),
+        guard,
+    )
+}
+
+#[allow(unused)]
+#[tokio::test]
+async fn test_py_side_scripts_api() {
+    // TODO(discord9): make a working test case; it will require python3 with numpy installed, so complex environment setup is expected...
+    common_telemetry::init_default_ut_logging();
+    let server = start_test_app("127.0.0.1:21830");
+    // let (app, _guard) = server.await;
+    // dbg!(app);
+}
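The same endpoint the test above exercises can also be driven from Python; this sketch mirrors exactly what the `coprocessor` decorator in greptime.py posts (the host and port are hypothetical):

import requests

script = '''
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
def test(n):
    return n
'''

# POST the coprocessor source to the /v1/scripts endpoint
res = requests.post(
    "http://127.0.0.1:3000/v1/scripts",
    headers={"Content-Type": "application/json"},
    json={"script": script, "engine": None},
)
print(res.text)  # e.g. {"success":true,"output":{...}}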
50 src/script/Cargo.toml (new file)
@@ -0,0 +1,50 @@
[package]
edition = "2021"
name = "script"
version = "0.1.0"

[features]
default = ["python"]
python = [
    "dep:datafusion",
    "dep:datafusion-expr",
    "dep:datafusion-physical-expr",
    "dep:rustpython-vm",
    "dep:rustpython-parser",
    "dep:rustpython-compiler",
    "dep:rustpython-compiler-core",
    "dep:rustpython-bytecode",
    "dep:rustpython-ast",
]

[dependencies]
async-trait = "0.1"
common-error = {path = "../common/error"}
common-function = { path = "../common/function" }
common-query = {path = "../common/query"}
common-recordbatch = {path = "../common/recordbatch" }
console = "0.15"
datafusion = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", optional = true}
datafusion-common = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2"}
datafusion-expr = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", optional = true}
datafusion-physical-expr = {git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", optional = true}
datatypes = {path = "../datatypes"}
futures-util = "0.3"
futures = "0.3"
query = { path = "../query" }
rustpython-ast = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
rustpython-bytecode = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
rustpython-compiler = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
rustpython-compiler-core = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
rustpython-parser = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
rustpython-vm = {git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d"}
snafu = {version = "0.7", features = ["backtraces"]}
sql = { path = "../sql" }

[dev-dependencies]
catalog = { path = "../catalog" }
ron = "0.7"
serde = {version = "1.0", features = ["derive"]}
table = { path = "../table" }
tokio = { version = "1.18", features = ["full"] }
tokio-test = "0.4"
46 src/script/src/engine.rs (new file)
@@ -0,0 +1,46 @@
//! Script engine

use std::any::Any;

use async_trait::async_trait;
use common_error::ext::ErrorExt;
use query::Output;

#[async_trait]
pub trait Script {
    type Error: ErrorExt + Send + Sync;

    /// Returns the script engine name such as `python` etc.
    fn engine_name(&self) -> &str;

    fn as_any(&self) -> &dyn Any;

    /// Evaluates the script and returns the output.
    async fn evaluate(&self, ctx: EvalContext) -> std::result::Result<Output, Self::Error>;
}

#[async_trait]
pub trait ScriptEngine {
    type Error: ErrorExt + Send + Sync;
    type Script: Script<Error = Self::Error>;

    /// Returns the script engine name such as `python` etc.
    fn name(&self) -> &str;

    fn as_any(&self) -> &dyn Any;

    /// Compiles a script text into a script instance.
    async fn compile(
        &self,
        script: &str,
        ctx: CompileContext,
    ) -> std::result::Result<Self::Script, Self::Error>;
}

/// Script evaluation context
#[derive(Debug, Default)]
pub struct EvalContext {}

/// Script compilation context
#[derive(Debug, Default)]
pub struct CompileContext {}
3 src/script/src/lib.rs (new file)
@@ -0,0 +1,3 @@
pub mod engine;
#[cfg(feature = "python")]
pub mod python;
13 src/script/src/python.rs (new file)
@@ -0,0 +1,13 @@
//! Python script coprocessor

mod builtins;
pub(crate) mod coprocessor;
mod engine;
pub mod error;
#[cfg(test)]
mod test;
pub(crate) mod utils;
mod vector;

pub use self::engine::{PyEngine, PyScript};
pub use self::vector::PyVector;
768 src/script/src/python/builtins/mod.rs (new file)
@@ -0,0 +1,768 @@
//! Builtin module contains GreptimeDB builtin udf/udaf
#[cfg(test)]
#[allow(clippy::print_stdout)]
mod test;

use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::ColumnarValue as DFColValue;
use datafusion_physical_expr::AggregateExpr;
use datatypes::arrow;
use datatypes::arrow::array::ArrayRef;
use datatypes::arrow::compute::cast::CastOptions;
use datatypes::arrow::datatypes::DataType;
use datatypes::vectors::Helper as HelperVec;
use rustpython_vm::builtins::PyList;
use rustpython_vm::pymodule;
use rustpython_vm::{
    builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt},
    AsObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};

use crate::python::utils::is_instance;
use crate::python::PyVector;

/// "Can't cast operand of type `{name}` into `{ty}`."
fn type_cast_error(name: &str, ty: &str, vm: &VirtualMachine) -> PyBaseExceptionRef {
    vm.new_type_error(format!("Can't cast operand of type `{name}` into `{ty}`."))
}

fn collect_diff_types_string(values: &[ScalarValue], ty: &DataType) -> String {
    values
        .iter()
        .enumerate()
        .filter_map(|(idx, val)| {
            if val.get_datatype() != *ty {
                Some((idx, val.get_datatype()))
            } else {
                None
            }
        })
        .map(|(idx, ty)| format!(" {:?} at {}th location\n", ty, idx + 1))
        .reduce(|mut acc, item| {
            acc.push_str(&item);
            acc
        })
        .unwrap_or_else(|| "Nothing".to_string())
}

/// try to turn a Python object into a PyVector or a scalar that can be used for calculation
///
/// supported scalars are (left side is the Python data type, right side is the Rust type):
///
/// | Python | Rust |
/// | ------ | ---- |
/// | integer| i64  |
/// | float  | f64  |
/// | bool   | bool |
/// | vector | array|
/// | list   | `ScalarValue::List` |
fn try_into_columnar_value(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<DFColValue> {
    if is_instance::<PyVector>(&obj, vm) {
        let ret = obj
            .payload::<PyVector>()
            .ok_or_else(|| type_cast_error(&obj.class().name(), "vector", vm))?;
        Ok(DFColValue::Array(ret.to_arrow_array()))
    } else if is_instance::<PyBool>(&obj, vm) {
        // Note that a `PyBool` is also a `PyInt`, so check if it is a bool first to get a more precise type
        let ret = obj.try_into_value::<bool>(vm)?;
        Ok(DFColValue::Scalar(ScalarValue::Boolean(Some(ret))))
    } else if is_instance::<PyInt>(&obj, vm) {
        let ret = obj.try_into_value::<i64>(vm)?;
        Ok(DFColValue::Scalar(ScalarValue::Int64(Some(ret))))
    } else if is_instance::<PyFloat>(&obj, vm) {
        let ret = obj.try_into_value::<f64>(vm)?;
        Ok(DFColValue::Scalar(ScalarValue::Float64(Some(ret))))
    } else if is_instance::<PyList>(&obj, vm) {
        let ret = obj
            .payload::<PyList>()
            .ok_or_else(|| type_cast_error(&obj.class().name(), "vector", vm))?;
        let ret: Vec<ScalarValue> = ret
            .borrow_vec()
            .iter()
            .map(|obj| -> PyResult<ScalarValue> {
                let col = try_into_columnar_value(obj.to_owned(), vm)?;
                match col {
                    DFColValue::Array(arr) => Err(vm.new_type_error(format!(
                        "Expect only scalar value in a list, found a vector of type {:?} nested in list", arr.data_type()
                    ))),
                    DFColValue::Scalar(val) => Ok(val),
                }
            })
            .collect::<Result<_, _>>()?;

        if ret.is_empty() {
            //TODO(dennis): empty list, we set type as f64.
            return Ok(DFColValue::Scalar(ScalarValue::List(
                None,
                Box::new(DataType::Float64),
            )));
        }

        let ty = ret[0].get_datatype();
        if ret.iter().any(|i| i.get_datatype() != ty) {
            return Err(vm.new_type_error(format!(
                "All elements in a list should be of the same type to cast to a Datafusion list!\nExpect {ty:?}, found {}",
                collect_diff_types_string(&ret, &ty)
            )));
        }
        Ok(DFColValue::Scalar(ScalarValue::List(
            Some(Box::new(ret)),
            Box::new(ty),
        )))
    } else {
        Err(vm.new_type_error(format!(
            "Can't cast object of type {} into vector or scalar",
            obj.class().name()
        )))
    }
}

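Seen from the Python side, this cast is what lets a single builtin accept either a vector or a plain scalar. An illustrative sketch (inside a coprocessor, with the `greptime_builtin` functions defined below in scope):

v = vector([1.0, 4.0, 9.0])
sqrt(v)      # vector -> DFColValue::Array -> result comes back as a vector
sqrt(4)      # int    -> ScalarValue::Int64, cast to f64 -> returns a float
# [1, "a"]  -> would raise: all list elements must share one type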
/// cast a columnar value into python object
|
||||
///
|
||||
/// | Rust | Python |
|
||||
/// | ------ | --------------- |
|
||||
/// | Array | PyVector |
|
||||
/// | Scalar | int/float/bool |
|
||||
fn try_into_py_obj(col: DFColValue, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
match col {
|
||||
DFColValue::Array(arr) => {
|
||||
let ret = PyVector::from(
|
||||
HelperVec::try_into_vector(arr)
|
||||
.map_err(|err| vm.new_type_error(format!("Unsupported type: {:#?}", err)))?,
|
||||
)
|
||||
.into_pyobject(vm);
|
||||
Ok(ret)
|
||||
}
|
||||
DFColValue::Scalar(val) => scalar_val_try_into_py_obj(val, vm),
|
||||
}
|
||||
}
|
||||
|
||||
/// turn a ScalarValue into a Python Object, currently support
|
||||
///
|
||||
/// ScalarValue -> Python Type
|
||||
/// - Float64 -> PyFloat
|
||||
/// - Int64 -> PyInt
|
||||
/// - UInt64 -> PyInt
|
||||
/// - List -> PyList(of inner ScalarValue)
|
||||
fn scalar_val_try_into_py_obj(val: ScalarValue, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
match val {
|
||||
ScalarValue::Float32(Some(v)) => Ok(vm.ctx.new_float(v.into()).into()),
|
||||
ScalarValue::Float64(Some(v)) => Ok(PyFloat::from(v).into_pyobject(vm)),
|
||||
ScalarValue::Int64(Some(v)) => Ok(PyInt::from(v).into_pyobject(vm)),
|
||||
ScalarValue::UInt64(Some(v)) => Ok(PyInt::from(v).into_pyobject(vm)),
|
||||
ScalarValue::List(Some(col), _) => {
|
||||
let list = col
|
||||
.into_iter()
|
||||
.map(|v| scalar_val_try_into_py_obj(v, vm))
|
||||
.collect::<Result<_, _>>()?;
|
||||
let list = vm.ctx.new_list(list);
|
||||
Ok(list.into())
|
||||
}
|
||||
_ => Err(vm.new_type_error(format!(
|
||||
"Can't cast a Scalar Value `{val:#?}` of type {:#?} to a Python Object",
|
||||
val.get_datatype()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Because most of the datafusion's UDF only support f32/64, so cast all to f64 to use datafusion's UDF
|
||||
fn all_to_f64(col: DFColValue, vm: &VirtualMachine) -> PyResult<DFColValue> {
|
||||
match col {
|
||||
DFColValue::Array(arr) => {
|
||||
let res = arrow::compute::cast::cast(
|
||||
arr.as_ref(),
|
||||
&DataType::Float64,
|
||||
CastOptions {
|
||||
wrapped: true,
|
||||
partial: true,
|
||||
},
|
||||
)
|
||||
.map_err(|err| {
|
||||
vm.new_type_error(format!(
|
||||
"Arrow Type Cast Fail(from {:#?} to {:#?}): {err:#?}",
|
||||
arr.data_type(),
|
||||
DataType::Float64
|
||||
))
|
||||
})?;
|
||||
Ok(DFColValue::Array(res.into()))
|
||||
}
|
||||
DFColValue::Scalar(val) => {
|
||||
let val_in_f64 = match val {
|
||||
ScalarValue::Float64(Some(v)) => v,
|
||||
ScalarValue::Int64(Some(v)) => v as f64,
|
||||
ScalarValue::Boolean(Some(v)) => v as i64 as f64,
|
||||
_ => {
|
||||
return Err(vm.new_type_error(format!(
|
||||
"Can't cast type {:#?} to {:#?}",
|
||||
val.get_datatype(),
|
||||
DataType::Float64
|
||||
)))
|
||||
}
|
||||
};
|
||||
Ok(DFColValue::Scalar(ScalarValue::Float64(Some(val_in_f64))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// use to bind to Data Fusion's UDF function
|
||||
/// P.S: seems due to proc macro issues, can't just use #[pyfunction] in here
|
||||
macro_rules! bind_call_unary_math_function {
|
||||
($DF_FUNC: ident, $vm: ident $(,$ARG: ident)*) => {
|
||||
fn inner_fn($($ARG: PyObjectRef,)* vm: &VirtualMachine) -> PyResult<PyObjectRef> {
|
||||
let args = &[$(all_to_f64(try_into_columnar_value($ARG, vm)?, vm)?,)*];
|
||||
let res = math_expressions::$DF_FUNC(args).map_err(|err| from_df_err(err, vm))?;
|
||||
let ret = try_into_py_obj(res, vm)?;
|
||||
Ok(ret)
|
||||
}
|
||||
return inner_fn($($ARG,)* $vm);
|
||||
};
|
||||
}

/// The macro for binding functions in `datafusion_physical_expr::expressions` (most of them are aggregate functions)
///
/// - the first argument is the name of the datafusion expression function, like `Avg`
/// - the second is the python virtual machine ident `vm`
/// - the following is the actual args passed in (as a slice), i.e. `&[values.to_arrow_array()]`
/// - then the data type of the args passed in, i.e. `DataType::Float64`
/// - lastly ARE the names given to the exprs of those functions, i.e. `expr0, expr1,`...
macro_rules! bind_aggr_fn {
    ($AGGR_FUNC: ident, $VM: ident, $ARGS:expr, $DATA_TYPE: expr $(, $EXPR_ARGS: ident)*) => {
        // just a placeholder; we only want the inner `XXXAccumulator`'s function,
        // so its expr is irrelevant
        return eval_aggr_fn(
            expressions::$AGGR_FUNC::new(
                $(
                    Arc::new(expressions::Column::new(stringify!($EXPR_ARGS), 0)) as _,
                )*
                stringify!($AGGR_FUNC), $DATA_TYPE.to_owned()),
            $ARGS, $VM)
    };
}
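// For illustration, `bind_aggr_fn!(Avg, vm, &[values.to_arrow_array()], values.to_arrow_array().data_type(), expr0)`
// expands to roughly:
//
//     return eval_aggr_fn(
//         expressions::Avg::new(
//             Arc::new(expressions::Column::new("expr0", 0)) as _,
//             "Avg",
//             values.to_arrow_array().data_type().to_owned(),
//         ),
//         &[values.to_arrow_array()],
//         vm,
//     );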

#[inline]
fn from_df_err(err: DataFusionError, vm: &VirtualMachine) -> PyBaseExceptionRef {
    vm.new_runtime_error(format!("Data Fusion Error: {err:#?}"))
}

/// evaluate an Aggregate Expr using its backing accumulator
fn eval_aggr_fn<T: AggregateExpr>(
    aggr: T,
    values: &[ArrayRef],
    vm: &VirtualMachine,
) -> PyResult<PyObjectRef> {
    // acquire the accumulator, where the actual implementation of the aggregate expr lives
    let mut acc = aggr
        .create_accumulator()
        .map_err(|err| from_df_err(err, vm))?;
    acc.update_batch(values)
        .map_err(|err| from_df_err(err, vm))?;
    let res = acc.evaluate().map_err(|err| from_df_err(err, vm))?;
    scalar_val_try_into_py_obj(res, vm)
}

/// GrepTime User-Defined Function module
///
/// allows a Python Coprocessor Function to use the udf functions already implemented in datafusion and GrepTime DB itself
///
#[pymodule]
pub(crate) mod greptime_builtin {
    // P.S.: not extracted into a separate file because a non-inline module with proc macro attributes is *unstable*
    use std::sync::Arc;

    use common_function::scalars::math::PowFunction;
    use common_function::scalars::{function::FunctionContext, Function};
    use datafusion::physical_plan::expressions;
    use datafusion_expr::ColumnarValue as DFColValue;
    use datafusion_physical_expr::math_expressions;
    use datatypes::arrow;
    use datatypes::arrow::array::{ArrayRef, NullArray};
    use datatypes::arrow::compute;
    use datatypes::vectors::{ConstantVector, Float64Vector, Helper, Int64Vector};
    use rustpython_vm::builtins::{PyFloat, PyInt, PyStr};
    use rustpython_vm::function::OptionalArg;
    use rustpython_vm::{AsObject, PyObjectRef, PyResult, VirtualMachine};

    use crate::python::builtins::{
        all_to_f64, eval_aggr_fn, from_df_err, try_into_columnar_value, try_into_py_obj,
        type_cast_error,
    };
    use crate::python::utils::is_instance;
    use crate::python::utils::PyVectorRef;
    use crate::python::PyVector;

    #[pyfunction]
    fn vector(args: OptionalArg<PyObjectRef>, vm: &VirtualMachine) -> PyResult<PyVector> {
        PyVector::new(args, vm)
    }

    // the main binding code; due to proc macro limitations a simpler macro can't be used directly,
    // because `pyfunction` is not an ordinary attribute

    // The math functions return a general PyObjectRef,
    // so they can return either a PyVector or a scalar PyInt/Float/Bool

    /// simple math function; the backing implementation is datafusion's `sqrt` math function
    #[pyfunction]
    fn sqrt(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(sqrt, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `sin` math function
    #[pyfunction]
    fn sin(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(sin, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `cos` math function
    #[pyfunction]
    fn cos(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(cos, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `tan` math function
    #[pyfunction]
    fn tan(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(tan, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `asin` math function
    #[pyfunction]
    fn asin(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(asin, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `acos` math function
    #[pyfunction]
    fn acos(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(acos, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `atan` math function
    #[pyfunction]
    fn atan(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(atan, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `floor` math function
    #[pyfunction]
    fn floor(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(floor, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `ceil` math function
    #[pyfunction]
    fn ceil(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(ceil, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `round` math function
    #[pyfunction]
    fn round(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(round, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `trunc` math function
    #[pyfunction]
    fn trunc(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(trunc, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `abs` math function
    #[pyfunction]
    fn abs(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(abs, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `signum` math function
    #[pyfunction]
    fn signum(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(signum, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `exp` math function
    #[pyfunction]
    fn exp(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(exp, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `ln` math function
    #[pyfunction(name = "log")]
    #[pyfunction]
    fn ln(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(ln, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `log2` math function
    #[pyfunction]
    fn log2(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(log2, vm, val);
    }

    /// simple math function; the backing implementation is datafusion's `log10` math function
    #[pyfunction]
    fn log10(val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_call_unary_math_function!(log10, vm, val);
    }

    /// return a random vector with values in the range [0, 1) and the given length
    #[pyfunction]
    fn random(len: usize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        // This is in a proc macro, so use full paths to avoid strange things;
        // more info at: https://doc.rust-lang.org/reference/procedural-macros.html#procedural-macro-hygiene
        let arg = NullArray::new(arrow::datatypes::DataType::Null, len);
        let args = &[DFColValue::Array(std::sync::Arc::new(arg) as _)];
        let res = math_expressions::random(args).map_err(|err| from_df_err(err, vm))?;
        let ret = try_into_py_obj(res, vm)?;
        Ok(ret)
    }

    // UDAF(User Defined Aggregate Function) in datafusion

    #[pyfunction]
    fn approx_distinct(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            ApproxDistinct,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    /// Not implemented in datafusion
    /// TODO(discord9): use greptime's own impl instead
    /*
    #[pyfunction]
    fn approx_median(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            ApproxMedian,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }
    */

    #[pyfunction]
    fn approx_percentile_cont(
        values: PyVectorRef,
        percent: f64,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        let percent =
            expressions::Literal::new(datafusion_common::ScalarValue::Float64(Some(percent)));
        return eval_aggr_fn(
            expressions::ApproxPercentileCont::new(
                vec![
                    Arc::new(expressions::Column::new("expr0", 0)) as _,
                    Arc::new(percent) as _,
                ],
                "ApproxPercentileCont",
                (values.to_arrow_array().data_type()).to_owned(),
            )
            .map_err(|err| from_df_err(err, vm))?,
            &[values.to_arrow_array()],
            vm,
        );
    }

    /// effectively equivalent to `list(vector)`
    #[pyfunction]
    fn array_agg(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            ArrayAgg,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    /// directly ported from datafusion's `avg` function
    #[pyfunction]
    fn avg(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Avg,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn correlation(
        arg0: PyVectorRef,
        arg1: PyVectorRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Correlation,
            vm,
            &[arg0.to_arrow_array(), arg1.to_arrow_array()],
            arg0.to_arrow_array().data_type(),
            expr0,
            expr1
        );
    }

    #[pyfunction]
    fn count(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Count,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn covariance(
        arg0: PyVectorRef,
        arg1: PyVectorRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Covariance,
            vm,
            &[arg0.to_arrow_array(), arg1.to_arrow_array()],
            arg0.to_arrow_array().data_type(),
            expr0,
            expr1
        );
    }

    #[pyfunction]
    fn covariance_pop(
        arg0: PyVectorRef,
        arg1: PyVectorRef,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            CovariancePop,
            vm,
            &[arg0.to_arrow_array(), arg1.to_arrow_array()],
            arg0.to_arrow_array().data_type(),
            expr0,
            expr1
        );
    }

    #[pyfunction]
    fn max(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Max,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn min(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Min,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn stddev(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Stddev,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn stddev_pop(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            StddevPop,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn sum(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Sum,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn variance(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            Variance,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    #[pyfunction]
    fn variance_pop(values: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        bind_aggr_fn!(
            VariancePop,
            vm,
            &[values.to_arrow_array()],
            values.to_arrow_array().data_type(),
            expr0
        );
    }

    /// Pow function, bound from GrepTime's own [`PowFunction`]
    #[pyfunction]
    fn pow(base: PyObjectRef, pow: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let base = base
            .payload::<PyVector>()
            .ok_or_else(|| type_cast_error(&base.class().name(), "vector", vm))?;
        let len_base = base.as_vector_ref().len();
        let arg_pow = if is_instance::<PyVector>(&pow, vm) {
            let pow = pow
                .payload::<PyVector>()
                .ok_or_else(|| type_cast_error(&pow.class().name(), "vector", vm))?;
            pow.as_vector_ref()
        } else if is_instance::<PyFloat>(&pow, vm) {
            let pow = pow.try_into_value::<f64>(vm)?;
            let ret =
                ConstantVector::new(Arc::new(Float64Vector::from_vec(vec![pow])) as _, len_base);
            Arc::new(ret) as _
        } else if is_instance::<PyInt>(&pow, vm) {
            let pow = pow.try_into_value::<i64>(vm)?;
            let ret =
                ConstantVector::new(Arc::new(Int64Vector::from_vec(vec![pow])) as _, len_base);
            Arc::new(ret) as _
        } else {
            return Err(vm.new_type_error(format!("Unsupported type({:#?}) for pow()", pow)));
        };
        // a pyfunction can return PyResult<...>; its args can be PyObjectRef or anything that
        // impls IntoPyNativeFunc, see rustpython-vm's function module for more details
        let args = vec![base.as_vector_ref(), arg_pow];
        let res = PowFunction::default()
            .eval(FunctionContext::default(), &args)
            .unwrap();
        Ok(res.into())
    }
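    // Usage note: pow(v, 2) and pow(v, 0.5) broadcast the scalar exponent over `v` by
    // wrapping it in a ConstantVector of `v`'s length, while pow(v, w) pairs the two
    // vectors element-wise.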

    // TODO: prev, sum, pow, sqrt, datetime, slice, and filter(through a boolean array)

    /// TODO: for now prev(arr)[0] == arr[0]; a better fill method is needed
    #[pyfunction]
    fn prev(cur: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let cur: ArrayRef = cur.to_arrow_array();
        if cur.len() == 0 {
            return Err(
                vm.new_runtime_error("Can't give prev for a zero length array!".to_string())
            );
        }
        let cur = cur.slice(0, cur.len() - 1); // everything except the last element
        let fill = cur.slice(0, 1);
        let ret = compute::concatenate::concatenate(&[&*fill, &*cur]).map_err(|err| {
            vm.new_runtime_error(format!("Can't concat array[0] with array[0:-1]! {err:#?}"))
        })?;
        let ret = Helper::try_into_vector(&*ret).map_err(|e| {
            vm.new_type_error(format!(
                "Can't cast result into vector, result: {:?}, err: {:?}",
                ret, e
            ))
        })?;
        Ok(ret.into())
    }
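    // For illustration: prev([1, 2, 3]) == [1, 1, 2]; the array is shifted right by one
    // and the first element is reused to fill the head.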

    /// parse a duration string like `"1d2h3m4s"` into a total number of seconds
    #[pyfunction]
    fn datetime(input: &PyStr, vm: &VirtualMachine) -> PyResult<i64> {
        let mut parsed = Vec::new();
        let mut prev = 0;
        #[derive(Debug)]
        enum State {
            Num(i64),
            Separator(String),
        }
        let mut state = State::Num(Default::default());
        let input = input.as_str();
        for (idx, ch) in input.chars().enumerate() {
            match (ch.is_ascii_digit(), &state) {
                (true, State::Separator(_)) => {
                    let res = &input[prev..idx];
                    let res = State::Separator(res.to_owned());
                    parsed.push(res);
                    prev = idx;
                    state = State::Num(Default::default());
                }
                (false, State::Num(_)) => {
                    let res = str::parse(&input[prev..idx]).map_err(|err| {
                        vm.new_runtime_error(format!("Fail to parse num: {err:#?}"))
                    })?;
                    let res = State::Num(res);
                    parsed.push(res);
                    prev = idx;
                    state = State::Separator(Default::default());
                }
                _ => continue,
            };
        }
        let last = match state {
            State::Num(_) => {
                let res = str::parse(&input[prev..])
                    .map_err(|err| vm.new_runtime_error(format!("Fail to parse num: {err:#?}")))?;
                State::Num(res)
            }
            State::Separator(_) => {
                let res = &input[prev..];
                State::Separator(res.to_owned())
            }
        };
        parsed.push(last);
        let mut cur_idx = 0;
        let mut tot_time = 0;
        fn factor(unit: &str, vm: &VirtualMachine) -> PyResult<i64> {
            let ret = match unit {
                "d" => 24 * 60 * 60,
                "h" => 60 * 60,
                "m" => 60,
                "s" => 1,
                _ => return Err(vm.new_type_error(format!("Unknown time unit: {unit}"))),
            };
            Ok(ret)
        }
        while cur_idx < parsed.len() {
            match &parsed[cur_idx] {
                State::Num(v) => {
                    // `>=` guards the `parsed[cur_idx + 1]` access below
                    if cur_idx + 1 >= parsed.len() {
                        return Err(vm.new_runtime_error(
                            "Expect a separator after number, found nothing!".to_string(),
                        ));
                    }
                    let nxt = &parsed[cur_idx + 1];
                    if let State::Separator(sep) = nxt {
                        tot_time += v * factor(sep, vm)?;
                    } else {
                        return Err(vm.new_runtime_error(format!(
                            "Expect a separator after number, found `{nxt:#?}`"
                        )));
                    }
                    cur_idx += 2;
                }
                State::Separator(sep) => {
                    return Err(vm.new_runtime_error(format!("Expect a number, found `{sep}`")))
                }
            }
        }
        Ok(tot_time)
    }
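    // For illustration: datetime("1d2h3m4s") returns
    // 1 * 86400 + 2 * 3600 + 3 * 60 + 4 = 93784 (seconds).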
}

src/script/src/python/builtins/test.rs (new file, 77 lines)
@@ -0,0 +1,77 @@
use std::sync::Arc;

use arrow::array::PrimitiveArray;
use rustpython_vm::class::PyClassImpl;

use super::*;
use crate::python::utils::format_py_error;

#[test]
fn convert_scalar_to_py_obj_and_back() {
    rustpython_vm::Interpreter::with_init(Default::default(), |vm| {
        // this could live in the `.enter()` closure, but for clarity, put it in `with_init()`
        PyVector::make_class(&vm.ctx);
    })
    .enter(|vm| {
        let col = DFColValue::Scalar(ScalarValue::Float64(Some(1.0)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Float64(Some(v))) = back {
            if (v - 1.0).abs() > 2.0 * f64::EPSILON {
                panic!("Expect 1.0, found {v}")
            }
        } else {
            panic!("Convert errors, expect 1.0")
        }
        let col = DFColValue::Scalar(ScalarValue::Int64(Some(1)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Int64(Some(v))) = back {
            assert_eq!(v, 1);
        } else {
            panic!("Convert errors, expect 1")
        }
        let col = DFColValue::Scalar(ScalarValue::UInt64(Some(1)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Int64(Some(v))) = back {
            assert_eq!(v, 1);
        } else {
            panic!("Convert errors, expect 1")
        }
        let col = DFColValue::Scalar(ScalarValue::List(
            Some(Box::new(vec![
                ScalarValue::Int64(Some(1)),
                ScalarValue::Int64(Some(2)),
            ])),
            Box::new(DataType::Int64),
        ));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::List(Some(list), ty)) = back {
            assert_eq!(list.len(), 2);
            assert_eq!(ty.as_ref(), &DataType::Int64);
        }
        let list: Vec<PyObjectRef> = vec![vm.ctx.new_int(1).into(), vm.ctx.new_int(2).into()];
        let nested_list: Vec<PyObjectRef> =
            vec![vm.ctx.new_list(list).into(), vm.ctx.new_int(3).into()];
        let list_obj = vm.ctx.new_list(nested_list).into();
        let col = try_into_columnar_value(list_obj, vm);
        if let Err(err) = col {
            let reason = format_py_error(err, vm);
            assert!(format!("{}", reason).contains(
                "TypeError: All elements in a list should be same type to cast to Datafusion list!"
            ));
        }

        let list: PyVector = PyVector::from(
            HelperVec::try_into_vector(
                Arc::new(PrimitiveArray::from_slice([0.1f64, 0.2, 0.3, 0.4])) as ArrayRef,
            )
            .unwrap(),
        );
        let nested_list: Vec<PyObjectRef> = vec![list.into_pyobject(vm), vm.ctx.new_int(3).into()];
        let list_obj = vm.ctx.new_list(nested_list).into();
        let expect_err = try_into_columnar_value(list_obj, vm);
        assert!(expect_err.is_err());
    })
}

src/script/src/python/builtins/testcases.ron (new file, 784 lines)
@@ -0,0 +1,784 @@
// This is the file for the UDF & UDAF bindings from datafusion,
// including tests for most of those functions (except ApproxMedian, which datafusion didn't implement).
// Check src/scalars/py_udf_module/test.rs for more information.
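// Each entry has the shape:
// TestCase(
//     input: { "name": Var(ty: <DataType>, value: <Value>) },
//     script: "<python source using the greptime module>",
//     expect: Ok((value: <Value>, ty: <DataType>)) or Err("<expected error substring>"),
// )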
[
    // math expressions
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "pows": Var(
                ty: Int8,
                value: IntVec([0, -1, 3])
            )
        },
        script: r#"
from greptime import *
sqrt(values)"#,
        expect: Ok((
            value: FloatVec([1.0, 1.4142135623730951, 1.7320508075688772]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
sin(values)"#,
        expect: Ok((
            value: FloatVec([0.8414709848078965, 0.9092974268256817, 0.1411200080598672]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
cos(values)"#,
        expect: Ok((
            value: FloatVec([0.5403023058681398, -0.4161468365471424, -0.9899924966004454]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
tan(values)"#,
        expect: Ok((
            value: FloatVec([1.557407724654902, -2.185039863261519, -0.1425465430742778]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.0])
            )
        },
        script: r#"
from greptime import *
asin(values)"#,
        expect: Ok((
            value: FloatVec([0.30469265401539747, 0.5235987755982988, 1.5707963267948966]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.0])
            )
        },
        script: r#"
from greptime import *
acos(values)"#,
        expect: Ok((
            value: FloatVec([1.266103672779499, 1.0471975511965976, 0.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.1])
            )
        },
        script: r#"
from greptime import *
atan(values)"#,
        expect: Ok((
            value: FloatVec([0.2914567944778671, 0.46364760900080615, 0.8329812666744317]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.1])
            )
        },
        script: r#"
from greptime import *
floor(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 0.0, 1.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.1])
            )
        },
        script: r#"
from greptime import *
ceil(values)"#,
        expect: Ok((
            value: FloatVec([1.0, 1.0, 2.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.1])
            )
        },
        script: r#"
from greptime import *
round(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 1.0, 1.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0.3, 0.5, 1.1])
            )
        },
        script: r#"
from greptime import *
trunc(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 0.0, 1.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([-0.3, 0.5, -1.1])
            )
        },
        script: r#"
from greptime import *
abs(values)"#,
        expect: Ok((
            value: FloatVec([0.3, 0.5, 1.1]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([-0.3, 0.5, -1.1])
            )
        },
        script: r#"
from greptime import *
signum(values)"#,
        expect: Ok((
            value: FloatVec([-1.0, 1.0, -1.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([0, 1.0, 2.0])
            )
        },
        script: r#"
from greptime import *
exp(values)"#,
        expect: Ok((
            value: FloatVec([1.0, 2.718281828459045, 7.38905609893065]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
ln(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 0.6931471805599453, 1.0986122886681098]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
log2(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 1.0, 1.584962500721156]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
log10(values)"#,
        expect: Ok((
            value: FloatVec([0.0, 0.3010299956639812, 0.47712125471966244]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {},
        script: r#"
from greptime import *
random(42)"#,
        expect: Ok((
            value: LenFloatVec(42),
            ty: Float64
        ))
    ),

    // UDAF (Aggregate functions)
    // the approx functions are nondeterministic
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 2, 3])
            )
        },
        script: r#"
from greptime import *
approx_distinct(values)"#,
        expect: Ok((
            value: Int(3),
            ty: Float64
        ))
    ),
    // not implemented in datafusion
    /*
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 2, 3])
            )
        },
        script: r#"
from greptime import *
approx_median(values)"#,
        expect: Ok((
            value: Int(2),
            ty: Float64
        ))
    ),
    */
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
approx_percentile_cont(values, 0.6)"#,
        expect: Ok((
            value: Int(6),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
array_agg(values)"#,
        expect: Ok((
            value: FloatVec([1.0, 2.0, 3.0]),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            )
        },
        script: r#"
from greptime import *
avg(values)"#,
        expect: Ok((
            value: Float(2.0),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "a": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "b": Var(
                ty: Float64,
                value: FloatVec([1.0, 0.0, -1.0])
            ),
        },
        script: r#"
from greptime import *
correlation(a, b)"#,
        expect: Ok((
            value: Float(-1.0),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
count(values)"#,
        expect: Ok((
            value: Int(10),
            ty: Int64
        ))
    ),
    TestCase(
        input: {
            "a": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "b": Var(
                ty: Float64,
                value: FloatVec([1.0, 0.0, -1.0])
            ),
        },
        script: r#"
from greptime import *
covariance(a, b)"#,
        expect: Ok((
            value: Float(-1.0),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "a": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "b": Var(
                ty: Float64,
                value: FloatVec([1.0, 0.0, -1.0])
            ),
        },
        script: r#"
from greptime import *
covariance_pop(a, b)"#,
        expect: Ok((
            value: Float(-0.6666666666666666),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
max(values)"#,
        expect: Ok((
            value: Int(10),
            ty: Int64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: IntVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
min(values)"#,
        expect: Ok((
            value: Int(1),
            ty: Int64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
stddev(values)"#,
        expect: Ok((
            value: Float(3.0276503540974917),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
stddev_pop(values)"#,
        expect: Ok((
            value: Float(2.8722813232690143),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
sum(values)"#,
        expect: Ok((
            value: Float(55),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
variance(values)"#,
        expect: Ok((
            value: Float(9.166666666666666),
            ty: Float64
        ))
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
            )
        },
        script: r#"
from greptime import *
variance_pop(values)"#,
        expect: Ok((
            value: Float(8.25),
            ty: Float64
        ))
    ),

    // GrepTime's own UDF
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "pows": Var(
                ty: Int8,
                value: IntVec([0, -1, 3])
            )
        },
        script: r#"
from greptime import *
pow(values, pows)"#,
        expect: Ok((
            value: FloatVec([1.0, 0.5, 27.0]),
            ty: Float64
        ))
    ),

    // Error handling tests
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "pows": Var(
                ty: Int8,
                value: IntVec([0, 0, 0])
            )
        },
        script: r#"
from greptime import *
pow(values, 1)"#,
        expect: Err("TypeError: Can't cast operand of type `int` into `vector`.")
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "pows": Var(
                ty: Int8,
                value: IntVec([0, 0, 0])
            ),
            "num": Var(
                ty: Int64,
                value: Int(1)
            )
        },
        script: r#"
from greptime import *
pow(num, pows)"#,
        expect: Err("TypeError: Can't cast operand of type `int` into `vector`")
    ),
    TestCase(
        input: {
            "values": Var(
                ty: Float64,
                value: FloatVec([1.0, 2.0, 3.0])
            ),
            "pows": Var(
                ty: Int8,
                value: IntVec([0, 0, 0])
            ),
            "num": Var(
                ty: Int64,
                value: Int(1)
            )
        },
        script: r#"
from greptime import *
asin(num, pows)"#,
        expect: Err("TypeError: Expected at most 1 arguments (2 given)")
    ),
    // Test type casts between float, int and bool
    TestCase(
        input: {
            "num": Var(
                ty: Int64,
                value: Int(1)
            )
        },
        script: r#"
from greptime import *
sin(num)"#,
        expect: Ok((
            ty: Float64,
            value: Float(0.8414709848078965)
        ))
    ),
    TestCase(
        input: {
            "num": Var(
                ty: Float64,
                value: Float(1.0)
            )
        },
        script: r#"
from greptime import *
sin(num)"#,
        expect: Ok((
            ty: Float64,
            value: Float(0.8414709848078965)
        ))
    ),
    TestCase(
        input: {},
        script: r#"
from greptime import *
sin(True)"#,
        expect: Ok((
            ty: Float64,
            value: Float(0.8414709848078965)
        ))
    ),
    TestCase(
        input: {
            "num": Var(
                ty: Boolean,
                value: Bool(false)
            )
        },
        script: r#"
from greptime import *
sin(num)"#,
        expect: Ok((
            ty: Float64,
            value: Float(0.0)
        ))
    ),
    // test that a string input returns an error correctly
    TestCase(
        input: {
            "num": Var(
                ty: Boolean,
                value: Str("42")
            )
        },
        script: r#"
from greptime import *
sin(num)"#,
        expect: Err("Can't cast object of type str into vector or scalar")
    ),
]

src/script/src/python/coprocessor.rs (new file, 610 lines)
@@ -0,0 +1,610 @@
pub mod parse;

use std::collections::HashMap;
use std::result::Result as StdResult;
use std::sync::Arc;

use common_recordbatch::RecordBatch;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow;
use datatypes::arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray};
use datatypes::arrow::compute::cast::CastOptions;
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::schema::Schema;
use datatypes::vectors::Helper;
use datatypes::vectors::{BooleanVector, Vector, VectorRef};
use rustpython_bytecode::CodeObject;
use rustpython_compiler_core::compile;
use rustpython_parser::{
    ast,
    ast::{Located, Location},
    parser,
};
use rustpython_vm as vm;
use rustpython_vm::{class::PyClassImpl, AsObject};
use snafu::{OptionExt, ResultExt};
use vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyTuple};
use vm::scope::Scope;
use vm::{Interpreter, PyObjectRef, VirtualMachine};

use crate::fail_parse_error;
use crate::python::builtins::greptime_builtin;
use crate::python::coprocessor::parse::{ret_parse_error, DecoratorArgs};
use crate::python::error::{
    ensure, ArrowSnafu, CoprParseSnafu, OtherSnafu, PyCompileSnafu, PyParseSnafu, Result,
    TypeCastSnafu,
};
use crate::python::utils::format_py_error;
use crate::python::{utils::is_instance, PyVector};

fn ret_other_error_with(reason: String) -> OtherSnafu<String> {
    OtherSnafu { reason }
}

#[cfg(test)]
use serde::Deserialize;

#[cfg_attr(test, derive(Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnotationInfo {
    /// if None, use the type inferred by PyVector
    pub datatype: Option<DataType>,
    pub is_nullable: bool,
}

pub type CoprocessorRef = Arc<Coprocessor>;

#[cfg_attr(test, derive(Deserialize))]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Coprocessor {
    pub name: String,
    pub deco_args: DecoratorArgs,
    /// from the python function's argument annotations: the annotated type and whether it is nullable
    pub arg_types: Vec<Option<AnnotationInfo>>,
    /// from the python function's return annotations: the annotated type and whether it is nullable
    pub return_types: Vec<Option<AnnotationInfo>>,
    /// the corresponding script; serde is skipped in `cfg(test)` to reduce work in comparisons
    #[cfg_attr(test, serde(skip))]
    pub script: String,
}
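// For illustration (an inferred example, not a definitive spec of the parser): a script like
//     @copr(args=["cpu"], returns=["perf"])
//     def a(cpu: vector[f32]) -> vector[f64]:
//         return cpu
// should parse to roughly name = "a", deco_args with arg_names ["cpu"] and ret_names ["perf"],
// arg_types = [Some(AnnotationInfo { datatype: Some(Float32), is_nullable: false })], and
// return_types = [Some(AnnotationInfo { datatype: Some(Float64), is_nullable: false })].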

impl Coprocessor {
    /// generate a call to the coprocessor function
    /// with the arguments given in the decorator's `args` list,
    /// and set its location in the source code to `loc`
    fn gen_call(&self, loc: &Location) -> ast::Stmt<()> {
        let mut loc = loc.to_owned();
        // add a new line so that, if any error occurs when calling the function,
        // the pretty print will point to this extra last line
        // instead of pointing at any existing code written by the user.
        loc.newline();
        let args: Vec<Located<ast::ExprKind>> = self
            .deco_args
            .arg_names
            .iter()
            .map(|v| {
                let node = ast::ExprKind::Name {
                    id: v.to_owned(),
                    ctx: ast::ExprContext::Load,
                };
                create_located(node, loc)
            })
            .collect();
        let func = ast::ExprKind::Call {
            func: Box::new(create_located(
                ast::ExprKind::Name {
                    id: self.name.to_owned(),
                    ctx: ast::ExprContext::Load,
                },
                loc,
            )),
            args,
            keywords: Vec::new(),
        };
        let stmt = ast::StmtKind::Expr {
            value: Box::new(create_located(func, loc)),
        };
        create_located(stmt, loc)
    }
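    // For illustration: for a coprocessor named `a` with `deco_args.arg_names == ["cpu", "mem"]`,
    // gen_call builds the AST of the single statement `a(cpu, mem)`.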

    /// check that `Mod` consists of exactly one statement, and that it is a function definition
    fn check_before_compile(top: &ast::Mod) -> Result<()> {
        if let ast::Mod::Interactive { body: code } = top {
            ensure!(
                code.len() == 1,
                CoprParseSnafu {
                    reason: format!(
                        "Expect only one statement in script, found {} statements",
                        code.len()
                    ),
                    loc: code.first().map(|s| s.location)
                }
            );

            if let ast::StmtKind::FunctionDef {
                name: _,
                args: _,
                body: _,
                decorator_list: _,
                returns: _,
                type_comment: _,
            } = &code[0].node
            {
            } else {
                return fail_parse_error!(
                    format!("Expect the one and only statement in script to be a function def, but instead found: {:?}", code[0].node),
                    Some(code[0].location)
                );
            }
        } else {
            return fail_parse_error!(
                format!("Expect statement in script, found: {:?}", top),
                None,
            );
        }
        Ok(())
    }

    /// strip the decorator (`@xxxx`) and the type annotations (type checking is done on the Rust side),
    /// append one line to the AST that calls the function with the given parameters, and compile it into a `CodeObject`
    ///
    /// The rationale is that rustpython's vm is not very efficient according to the [official benchmarks](https://rustpython.github.io/benchmarks),
    /// so we should avoid running too much Python bytecode; hence in this function we delete the `@` decorator (instead of actually writing a decorator in python),
    /// add a function call at the end, and also
    /// strip the type annotations
    fn strip_append_and_compile(&self) -> Result<CodeObject> {
        let script = &self.script;
        // note that it's important to use `parser::Mode::Interactive` so the ast can be compiled to return a result, instead of returning None as in eval mode
        let mut top = parser::parse(script, parser::Mode::Interactive).context(PyParseSnafu)?;
        Self::check_before_compile(&top)?;
        // erase the decorator
        if let ast::Mod::Interactive { body } = &mut top {
            let code = body;
            if let ast::StmtKind::FunctionDef {
                name: _,
                args,
                body: _,
                decorator_list,
                returns,
                type_comment: _,
            } = &mut code[0].node
            {
                *decorator_list = Vec::new();
                // strip type annotations:
                // def a(b: int, c: int) -> int
                // will become
                // def a(b, c)
                *returns = None;
                for arg in &mut args.args {
                    arg.node.annotation = None;
                }
            } else {
                // already checked in check_before_compile
                unreachable!()
            }
            let loc = code[0].location;

            // This manually constructed ast has no corresponding code
            // in the script, so just give it a location that doesn't exist in the original script
            // (which doesn't matter, because Location is usually only used to pretty print errors)
            code.push(self.gen_call(&loc));
        } else {
            // already checked in check_before_compile
            unreachable!()
        }
        // use `compile::Mode::BlockExpr` so it returns the result of the statement
        compile::compile_top(
            &top,
            "<embedded>".to_owned(),
            compile::Mode::BlockExpr,
            compile::CompileOpts { optimize: 0 },
        )
        .context(PyCompileSnafu)
    }
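    // For illustration, a script like
    //     @copr(args=["cpu"], returns=["perf"])
    //     def a(cpu: vector[f32]) -> vector[f32]:
    //         return cpu
    // is rewritten, before compilation, to roughly:
    //     def a(cpu):
    //         return cpu
    //     a(cpu)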

    /// generate a [`Schema`] according to the return names and types;
    /// if there is no annotation,
    /// the datatypes of the actual columns are used directly
    fn gen_schema(&self, cols: &[ArrayRef]) -> Result<Arc<ArrowSchema>> {
        let names = &self.deco_args.ret_names;
        let anno = &self.return_types;
        ensure!(
            cols.len() == names.len() && names.len() == anno.len(),
            OtherSnafu {
                reason: format!(
                    "Unmatched length for cols({}), names({}) and anno({})",
                    cols.len(),
                    names.len(),
                    anno.len()
                )
            }
        );
        Ok(Arc::new(ArrowSchema::from(
            names
                .iter()
                .enumerate()
                .map(|(idx, name)| {
                    let real_ty = cols[idx].data_type().to_owned();
                    let AnnotationInfo {
                        datatype: ty,
                        is_nullable,
                    } = anno[idx].to_owned().unwrap_or_else(||
                        // default to not nullable and use the DataType inferred by PyVector itself
                        AnnotationInfo {
                            datatype: Some(real_ty.to_owned()),
                            is_nullable: false
                        });
                    Field::new(
                        name,
                        // if the type is like `_` or `_ | None`
                        ty.unwrap_or(real_ty),
                        is_nullable,
                    )
                })
                .collect::<Vec<Field>>(),
        )))
    }

    /// check if the real types and the annotated types (if any) are the same; if not, try to cast the columns to the annotated types
    fn check_and_cast_type(&self, cols: &mut [ArrayRef]) -> Result<()> {
        let return_types = &self.return_types;
        // allow the return type annotation to be omitted
        if return_types.is_empty() {
            return Ok(());
        }
        ensure!(
            cols.len() == return_types.len(),
            OtherSnafu {
                reason: format!(
                    "The number of return Vector is wrong, expect {}, found {}",
                    return_types.len(),
                    cols.len()
                )
            }
        );
        for (col, anno) in cols.iter_mut().zip(return_types) {
            if let Some(AnnotationInfo {
                datatype: Some(datatype),
                is_nullable: _,
            }) = anno
            {
                let real_ty = col.data_type();
                let anno_ty = datatype;
                if real_ty != anno_ty {
                    // This `CastOptions` allows wrapping overflow casts, lossy int-to-float casts etc.;
                    // check its docs for more information
                    *col = arrow::compute::cast::cast(
                        col.as_ref(),
                        anno_ty,
                        CastOptions {
                            wrapped: true,
                            partial: true,
                        },
                    )
                    .context(ArrowSnafu)?
                    .into();
                }
            }
        }
        Ok(())
    }
}

fn create_located<T>(node: T, loc: Location) -> Located<T> {
    Located::new(loc, node)
}

/// cast a `dyn Array` of type unsigned/int/float into a `dyn Vector`
fn try_into_vector<T: datatypes::types::Primitive + datatypes::types::DataTypeBuilder>(
    arg: Arc<dyn Array>,
) -> Result<Arc<dyn Vector>> {
    // wrap try_into_vector here to convert `datatypes::error::Error` to `python::error::Error`
    Helper::try_into_vector(arg).context(TypeCastSnafu)
}

/// convert a `Vec<ArrayRef>` into a `Vec<PyVector>`, but only for the supported types;
/// PyVector currently only supports unsigned/signed int8/16/32/64, float32/64 and bool for meaningful arithmetic operations
fn try_into_py_vector(fetch_args: Vec<ArrayRef>) -> Result<Vec<PyVector>> {
    let mut args: Vec<PyVector> = Vec::with_capacity(fetch_args.len());
    for (idx, arg) in fetch_args.into_iter().enumerate() {
        let v: VectorRef = match arg.data_type() {
            DataType::Float32 => try_into_vector::<f32>(arg)?,
            DataType::Float64 => try_into_vector::<f64>(arg)?,
            DataType::UInt8 => try_into_vector::<u8>(arg)?,
            DataType::UInt16 => try_into_vector::<u16>(arg)?,
            DataType::UInt32 => try_into_vector::<u32>(arg)?,
            DataType::UInt64 => try_into_vector::<u64>(arg)?,
            DataType::Int8 => try_into_vector::<i8>(arg)?,
            DataType::Int16 => try_into_vector::<i16>(arg)?,
            DataType::Int32 => try_into_vector::<i32>(arg)?,
            DataType::Int64 => try_into_vector::<i64>(arg)?,
            DataType::Boolean => {
                let v: VectorRef =
                    Arc::new(BooleanVector::try_from_arrow_array(arg).context(TypeCastSnafu)?);
                v
            }
            _ => {
                return ret_other_error_with(format!(
                    "Unsupported data type at column {idx}: {:?} for coprocessor",
                    arg.data_type()
                ))
                .fail()
            }
        };
        args.push(PyVector::from(v));
    }
    Ok(args)
}

/// convert a single PyVector or a number (a constant) into an Array (or a constant array)
fn py_vec_to_array_ref(obj: &PyObjectRef, vm: &VirtualMachine, col_len: usize) -> Result<ArrayRef> {
    if is_instance::<PyVector>(obj, vm) {
        let pyv = obj.payload::<PyVector>().with_context(|| {
            ret_other_error_with(format!("can't cast obj {:?} to PyVector", obj))
        })?;
        Ok(pyv.to_arrow_array())
    } else if is_instance::<PyInt>(obj, vm) {
        let val = obj
            .to_owned()
            .try_into_value::<i64>(vm)
            .map_err(|e| format_py_error(e, vm))?;

        let ret = PrimitiveArray::from_vec(vec![val; col_len]);
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyFloat>(obj, vm) {
        let val = obj
            .to_owned()
            .try_into_value::<f64>(vm)
            .map_err(|e| format_py_error(e, vm))?;
        let ret = PrimitiveArray::from_vec(vec![val; col_len]);
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyBool>(obj, vm) {
        let val = obj
            .to_owned()
            .try_into_value::<bool>(vm)
            .map_err(|e| format_py_error(e, vm))?;

        // repeat the constant for every row of the column
        let ret = BooleanArray::from_iter(std::iter::repeat(Some(val)).take(col_len));
        Ok(Arc::new(ret) as _)
    } else {
        ret_other_error_with(format!("Expect a vector or a constant, found {:?}", obj)).fail()
    }
}
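// For illustration: with col_len == 4, a script returning the constant `42` yields a
// PrimitiveArray of [42, 42, 42, 42], and returning `True` yields a BooleanArray of four `true`s.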
|
||||
|
||||
/// convert a tuple of `PyVector` or one `PyVector`(wrapped in a Python Object Ref[`PyObjectRef`])
|
||||
/// to a `Vec<ArrayRef>`
|
||||
fn try_into_columns(
|
||||
obj: &PyObjectRef,
|
||||
vm: &VirtualMachine,
|
||||
col_len: usize,
|
||||
) -> Result<Vec<ArrayRef>> {
|
||||
if is_instance::<PyTuple>(obj, vm) {
|
||||
let tuple = obj.payload::<PyTuple>().with_context(|| {
|
||||
ret_other_error_with(format!("can't cast obj {:?} to PyTuple)", obj))
|
||||
})?;
|
||||
let cols = tuple
|
||||
.iter()
|
||||
.map(|obj| py_vec_to_array_ref(obj, vm, col_len))
|
||||
.collect::<Result<Vec<ArrayRef>>>()?;
|
||||
Ok(cols)
|
||||
} else {
|
||||
let col = py_vec_to_array_ref(obj, vm, col_len)?;
|
||||
Ok(vec![col])
|
||||
}
|
||||
}
|
||||
|
||||
/// select columns according to `fetch_names` from `rb`
|
||||
/// and cast them into a Vec of PyVector
|
||||
fn select_from_rb(rb: &DfRecordBatch, fetch_names: &[String]) -> Result<Vec<PyVector>> {
|
||||
let field_map: HashMap<&String, usize> = rb
|
||||
.schema()
|
||||
.fields
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, field)| (&field.name, idx))
|
||||
.collect();
|
||||
let fetch_idx: Vec<usize> = fetch_names
|
||||
.iter()
|
||||
.map(|field| {
|
||||
field_map.get(field).copied().context(OtherSnafu {
|
||||
reason: format!("Can't found field name {field}"),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
let fetch_args: Vec<Arc<dyn Array>> = fetch_idx
|
||||
.into_iter()
|
||||
.map(|idx| rb.column(idx).clone())
|
||||
.collect();
|
||||
try_into_py_vector(fetch_args)
|
||||
}
|
||||
|
||||
/// match between arguments' real type and annotation types
|
||||
/// if type anno is vector[_] then use real type
|
||||
fn check_args_anno_real_type(
|
||||
args: &[PyVector],
|
||||
copr: &Coprocessor,
|
||||
rb: &DfRecordBatch,
|
||||
) -> Result<()> {
|
||||
for (idx, arg) in args.iter().enumerate() {
|
||||
let anno_ty = copr.arg_types[idx].to_owned();
|
||||
let real_ty = arg.to_arrow_array().data_type().to_owned();
|
||||
let is_nullable: bool = rb.schema().fields[idx].is_nullable;
|
||||
ensure!(
|
||||
anno_ty
|
||||
.to_owned()
|
||||
.map(|v| v.datatype == None // like a vector[_]
|
||||
|| v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
|
||||
.unwrap_or(true),
|
||||
OtherSnafu {
|
||||
reason: format!(
|
||||
"column {}'s Type annotation is {:?}, but actual type is {:?}",
|
||||
copr.deco_args.arg_names[idx], anno_ty, real_ty
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// set arguments with given name and values in python scopes
|
||||
fn set_items_in_scope(
|
||||
scope: &Scope,
|
||||
vm: &VirtualMachine,
|
||||
arg_names: &[String],
|
||||
args: Vec<PyVector>,
|
||||
) -> Result<()> {
|
||||
let _ = arg_names
|
||||
.iter()
|
||||
.zip(args)
|
||||
.map(|(name, vector)| {
|
||||
scope
|
||||
.locals
|
||||
.as_object()
|
||||
.set_item(name, vm.new_pyobj(vector), vm)
|
||||
})
|
||||
.collect::<StdResult<Vec<()>, PyBaseExceptionRef>>()
|
||||
.map_err(|e| format_py_error(e, vm))?;
|
||||
Ok(())
|
||||
}

/// The coprocessor function accepts a Python script and a [`DfRecordBatch`]:
/// ## What it does
/// 1. it takes a Python script and a [`DfRecordBatch`], extracting columns and annotation info according to `args` given in the decorator in the Python script
/// 2. it executes the Python code, which returns a vector or a tuple of vectors,
/// 3. the returned vector(s) are assembled into a new [`DfRecordBatch`] according to `returns` in the Python decorator and returned to the caller
///
/// # Example
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
/// use arrow::array::PrimitiveArray;
/// use arrow::datatypes::{DataType, Field, Schema};
/// use common_function::scalars::python::exec_coprocessor;
/// let python_source = r#"
/// @copr(args=["cpu", "mem"], returns=["perf", "what"])
/// def a(cpu, mem):
///     return cpu + mem, cpu - mem
/// "#;
/// let cpu_array = PrimitiveArray::from_slice([0.9f32, 0.8, 0.7, 0.6]);
/// let mem_array = PrimitiveArray::from_slice([0.1f64, 0.2, 0.3, 0.4]);
/// let schema = Arc::new(Schema::from(vec![
///     Field::new("cpu", DataType::Float32, false),
///     Field::new("mem", DataType::Float64, false),
/// ]));
/// let rb =
///     DfRecordBatch::try_new(schema, vec![Arc::new(cpu_array), Arc::new(mem_array)]).unwrap();
/// let ret = exec_coprocessor(python_source, &rb).unwrap();
/// assert_eq!(ret.column(0).len(), 4);
/// ```
///
/// # Type Annotation
/// You can use type annotations in args and returns to designate types, so the coprocessor will check for corresponding types.
///
/// Currently supported types are `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64` and `f16`, `f32`, `f64`.
///
/// Use `f64 | None` to mark a returning column as nullable, like the `is_nullable` flag in a [`DfRecordBatch`] schema's [`Field`].
///
/// You can also use a single underscore `_` to let the coprocessor infer the type, so `_` and `_ | None` are both valid in type annotations.
/// Note: using `_` means a non-nullable column, while `_ | None` means a nullable column.
///
/// An example (of a Python script) is given below:
/// ```python
/// @copr(args=["cpu", "mem"], returns=["perf", "minus", "mul", "div"])
/// def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
///     return cpu + mem, cpu - mem, cpu * mem, cpu / mem
/// ```
///
/// # Return Constant columns
/// You can return constants from Python code like `return 1, 1.0, True`,
/// which creates a constant array (filled with the same value; currently int, float and bool are supported) as the column on return.
#[cfg(test)]
pub fn exec_coprocessor(script: &str, rb: &DfRecordBatch) -> Result<RecordBatch> {
    // 1. parse the script and check that it is only a function with a `@coprocessor` decorator, and get `args` and `returns`,
    // 2. also check that every name in `args` exists in `rb`; if not found, return an error
    // TODO(discord9): cache the result of parse_copr
    let copr = parse::parse_copr(script)?;
    exec_parsed(&copr, rb)
}

pub(crate) fn exec_with_cached_vm(
    copr: &Coprocessor,
    rb: &DfRecordBatch,
    args: Vec<PyVector>,
    vm: &Interpreter,
) -> Result<RecordBatch> {
    vm.enter(|vm| -> Result<RecordBatch> {
        PyVector::make_class(&vm.ctx);
        // set arguments with the given names and values
        let scope = vm.new_scope_with_builtins();
        set_items_in_scope(&scope, vm, &copr.deco_args.arg_names, args)?;

        let code_obj = copr.strip_append_and_compile()?;
        let code_obj = vm.ctx.new_code(code_obj);
        let ret = vm
            .run_code_obj(code_obj, scope)
            .map_err(|e| format_py_error(e, vm))?;

        // 5. get the returns as either a PyVector or a PyTuple, and name them according to `returns`
        let col_len = rb.num_rows();
        let mut cols: Vec<ArrayRef> = try_into_columns(&ret, vm, col_len)?;
        ensure!(
            cols.len() == copr.deco_args.ret_names.len(),
            OtherSnafu {
                reason: format!(
                    "The number of return Vector is wrong, expect {}, found {}",
                    copr.deco_args.ret_names.len(),
                    cols.len()
                )
            }
        );

        // if the columns' data types don't match the schema, try to coerce them to the annotated types
        // (if an error occurs, return the relevant error via the question mark)
        copr.check_and_cast_type(&mut cols)?;
        // 6. return an assembled DfRecordBatch
        let schema = copr.gen_schema(&cols)?;
        let res_rb = DfRecordBatch::try_new(schema.clone(), cols).context(ArrowSnafu)?;
        Ok(RecordBatch {
            schema: Arc::new(Schema::try_from(schema).context(TypeCastSnafu)?),
            df_recordbatch: res_rb,
        })
    })
}

/// Initialize an interpreter with the `PyVector` type and the `greptime` module.
pub(crate) fn init_interpreter() -> Interpreter {
    vm::Interpreter::with_init(Default::default(), |vm| {
        PyVector::make_class(&vm.ctx);
        vm.add_native_module("greptime", Box::new(greptime_builtin::make_module));
    })
}

/// Execute Python code using a parsed `Coprocessor` struct as input.
pub(crate) fn exec_parsed(copr: &Coprocessor, rb: &DfRecordBatch) -> Result<RecordBatch> {
    // 3. get args from `rb`, and cast them into PyVector
    let args: Vec<PyVector> = select_from_rb(rb, &copr.deco_args.arg_names)?;
    check_args_anno_real_type(&args, copr, rb)?;
    let interpreter = init_interpreter();
    // 4. then set args in scope, compile and run the `CodeObject`, which already has a new `Call` node appended
    exec_with_cached_vm(copr, rb, args, &interpreter)
}

/// Execute a script just like [`exec_coprocessor`] does,
/// but instead of returning the internal [`Error`] type,
/// return a friendly String format of the error.
///
/// Use `ln_offset` and `filename` to offset the line number and mark the file name in the error prompt.
#[cfg(test)]
#[allow(dead_code)]
pub fn exec_copr_print(
    script: &str,
    rb: &DfRecordBatch,
    ln_offset: usize,
    filename: &str,
) -> StdResult<RecordBatch, String> {
    let res = exec_coprocessor(script, rb);
    res.map_err(|e| {
        crate::python::error::pretty_print_error_in_src(script, &e, ln_offset, filename)
    })
}

src/script/src/python/coprocessor/parse.rs (new file, 521 lines)
@@ -0,0 +1,521 @@
use std::collections::HashSet;

use datatypes::arrow::datatypes::DataType;
use rustpython_parser::{
    ast,
    ast::{Arguments, Location},
    parser,
};
#[cfg(test)]
use serde::Deserialize;
use snafu::ResultExt;

use crate::python::coprocessor::AnnotationInfo;
use crate::python::coprocessor::Coprocessor;
use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};

#[cfg_attr(test, derive(Deserialize))]
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct DecoratorArgs {
    pub arg_names: Vec<String>,
    pub ret_names: Vec<String>,
    pub sql: Option<String>,
    // maybe add a URL for connecting, or what?
    // also a predicate for timed or conditional triggering?
}

/// Return a `CoprParseSnafu` so you can chain `.fail()` to produce the correct `Err` Result type.
pub(crate) fn ret_parse_error(
    reason: String,
    loc: Option<Location>,
) -> CoprParseSnafu<String, Option<Location>> {
    CoprParseSnafu { reason, loc }
}

/// Append a `.fail()` after `ret_parse_error`, so the compiler can return an `Err` of this error.
#[macro_export]
macro_rules! fail_parse_error {
    ($reason:expr, $loc:expr $(,)*) => {
        ret_parse_error($reason, $loc).fail()
    };
}
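
// A minimal usage sketch (hypothetical function): the macro expands to
// `ret_parse_error(...).fail()`, so inside any function returning `Result<T>`
// one can write, e.g.:
//
// fn expect_list(expr: &ast::Expr<()>) -> Result<()> {
//     fail_parse_error!("Expect a list".to_string(), Some(expr.location))
// }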

fn py_str_to_string(s: &ast::Expr<()>) -> Result<String> {
    if let ast::ExprKind::Constant {
        value: ast::Constant::Str(v),
        kind: _,
    } = &s.node
    {
        Ok(v.to_owned())
    } else {
        fail_parse_error!(
            format!(
                "Expect a list of String, found one element to be: \n{:#?}",
                &s.node
            ),
            Some(s.location)
        )
    }
}

/// Turn a Python list of strings in AST form (an `ast::Expr`) into a `Vec<String>`.
fn pylist_to_vec(lst: &ast::Expr<()>) -> Result<Vec<String>> {
    if let ast::ExprKind::List { elts, ctx: _ } = &lst.node {
        let ret = elts.iter().map(py_str_to_string).collect::<Result<_>>()?;
        Ok(ret)
    } else {
        fail_parse_error!(
            format!("Expect a list, found \n{:#?}", &lst.node),
            Some(lst.location)
        )
    }
}

fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<DataType>> {
    match ty {
        "bool" => Ok(Some(DataType::Boolean)),
        "u8" => Ok(Some(DataType::UInt8)),
        "u16" => Ok(Some(DataType::UInt16)),
        "u32" => Ok(Some(DataType::UInt32)),
        "u64" => Ok(Some(DataType::UInt64)),
        "i8" => Ok(Some(DataType::Int8)),
        "i16" => Ok(Some(DataType::Int16)),
        "i32" => Ok(Some(DataType::Int32)),
        "i64" => Ok(Some(DataType::Int64)),
        "f16" => Ok(Some(DataType::Float16)),
        "f32" => Ok(Some(DataType::Float32)),
        "f64" => Ok(Some(DataType::Float64)),
        // stands for any datatype
        "_" => Ok(None),
        // note the difference between "_" and _ below
        _ => fail_parse_error!(
            format!("Unknown datatype: {ty} at {}", loc),
            Some(loc.to_owned())
        ),
    }
}
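
// A minimal sketch of the mapping above (locations are hypothetical): `"f64"`
// becomes `Some(DataType::Float64)`, the wildcard `"_"` becomes `None` (type to
// be inferred later), and anything else is a parse error.
//
// let loc = Location::new(1, 1);
// assert_eq!(try_into_datatype("f64", &loc).unwrap(), Some(DataType::Float64));
// assert_eq!(try_into_datatype("_", &loc).unwrap(), None);
// assert!(try_into_datatype("g32", &loc).is_err());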

/// Item => NativeType
/// Defaults to not nullable.
fn parse_native_type(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
    match &sub.node {
        ast::ExprKind::Name { id, .. } => Ok(AnnotationInfo {
            datatype: try_into_datatype(id, &sub.location)?,
            is_nullable: false,
        }),
        _ => fail_parse_error!(
            format!("Expect types' name, found \n{:#?}", &sub.node),
            Some(sub.location)
        ),
    }
}

/// Check if a binary op expr is legal (has one type name and one `None`).
fn check_bin_op(bin_op: &ast::Expr<()>) -> Result<()> {
    if let ast::ExprKind::BinOp { left, op: _, right } = &bin_op.node {
        // 1. first check if this BinOp is legal (has one type name and an optional `None`)
        let is_none = |node: &ast::Expr<()>| -> bool {
            matches!(
                &node.node,
                ast::ExprKind::Constant {
                    value: ast::Constant::None,
                    kind: _,
                }
            )
        };
        let is_type = |node: &ast::Expr<()>| {
            if let ast::ExprKind::Name { id, ctx: _ } = &node.node {
                try_into_datatype(id, &node.location).is_ok()
            } else {
                false
            }
        };
        let left_is_ty = is_type(left);
        let left_is_none = is_none(left);
        let right_is_ty = is_type(right);
        let right_is_none = is_none(right);
        if left_is_ty && right_is_ty || left_is_none && right_is_none {
            fail_parse_error!(
                "Expect one typenames and one `None`".to_string(),
                Some(bin_op.location)
            )?;
        } else if !(left_is_none && right_is_ty || left_is_ty && right_is_none) {
            fail_parse_error!(
                format!(
                    "Expect a type name and a `None`, found left: \n{:#?} \nand right: \n{:#?}",
                    &left.node, &right.node
                ),
                Some(bin_op.location)
            )?;
        }
        Ok(())
    } else {
        fail_parse_error!(
            format!(
                "Expect binary ops like `DataType | None`, found \n{:#?}",
                bin_op.node
            ),
            Some(bin_op.location)
        )
    }
}

/// Parse a `DataType | None` or a single `DataType`.
fn parse_bin_op(bin_op: &ast::Expr<()>) -> Result<AnnotationInfo> {
    // 1. first check if this BinOp is legal (has one type name and an optional `None`)
    check_bin_op(bin_op)?;
    if let ast::ExprKind::BinOp { left, op: _, right } = &bin_op.node {
        // then get the type from this BinOp
        let left_ty = parse_native_type(left).ok();
        let right_ty = parse_native_type(right).ok();
        let mut ty_anno = if let Some(left_ty) = left_ty {
            left_ty
        } else if let Some(right_ty) = right_ty {
            right_ty
        } else {
            // handle the error anyway, in case the code above changes but someone forgets to modify this
            return fail_parse_error!(
                "Expect a type name, not two `None`".into(),
                Some(bin_op.location),
            );
        };
        // because check_bin_op ensures a `None` exists
        ty_anno.is_nullable = true;
        return Ok(ty_anno);
    }
    unreachable!()
}

/// Check the grammar of the annotation; also return the slice of the subscript for further parsing.
fn check_annotation_ret_slice(sub: &ast::Expr<()>) -> Result<&ast::Expr<()>> {
    // TODO(discord9): allow a single annotation like `vector`
    if let ast::ExprKind::Subscript {
        value,
        slice,
        ctx: _,
    } = &sub.node
    {
        if let ast::ExprKind::Name { id, ctx: _ } = &value.node {
            ensure!(
                id == "vector",
                ret_parse_error(
                    format!(
                        "Wrong type annotation, expect `vector[...]`, found `{}`",
                        id
                    ),
                    Some(value.location)
                )
            );
        } else {
            return fail_parse_error!(
                format!("Expect \"vector\", found \n{:#?}", &value.node),
                Some(value.location)
            );
        }
        Ok(slice)
    } else {
        fail_parse_error!(
            format!("Expect type annotation, found \n{:#?}", &sub),
            Some(sub.location)
        )
    }
}

/// Parse a type annotation, where:
///
/// Start => vector`[`TYPE`]`
///
/// TYPE => Item | Item `|` None
///
/// Item => NativeType
fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
    let slice = check_annotation_ret_slice(sub)?;

    {
        // i.e.: vector[f64]
        match &slice.node {
            ast::ExprKind::Name { .. } => parse_native_type(slice),
            ast::ExprKind::BinOp {
                left: _,
                op: _,
                right: _,
            } => parse_bin_op(slice),
            _ => {
                fail_parse_error!(
                    format!("Expect type in `vector[...]`, found \n{:#?}", &slice.node),
                    Some(slice.location),
                )
            }
        }
    }
}
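
// A small sketch of what the grammar above accepts (inputs are illustrative):
//
//   vector[f64]        -> AnnotationInfo { datatype: Some(Float64), is_nullable: false }
//   vector[f64 | None] -> AnnotationInfo { datatype: Some(Float64), is_nullable: true }
//   vector[_]          -> AnnotationInfo { datatype: None,          is_nullable: false }
//   vector[_ | None]   -> AnnotationInfo { datatype: None,          is_nullable: true }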

/// Parse a list of keywords and return the args and returns lists from the keywords.
fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
    // more keys may be added to this list of `avail_key` (like `sql` for querying, and maybe config for connecting to a database?); a `HashSet` is used here for easier extension
    let avail_key = HashSet::from(["args", "returns", "sql"]);
    let opt_keys = HashSet::from(["sql"]);
    let mut visited_key = HashSet::new();
    let len_min = avail_key.len() - opt_keys.len();
    let len_max = avail_key.len();
    ensure!(
        // "sql" is optional (for now)
        keywords.len() >= len_min && keywords.len() <= len_max,
        CoprParseSnafu {
            reason: format!(
                "Expect between {len_min} and {len_max} keyword argument, found {}.",
                keywords.len()
            ),
            loc: keywords.get(0).map(|s| s.location)
        }
    );
    let mut ret_args = DecoratorArgs::default();
    for kw in keywords {
        match &kw.node.arg {
            Some(s) => {
                let s = s.as_str();
                if visited_key.contains(s) {
                    return fail_parse_error!(
                        format!("`{s}` occurs multiple times in the decorator's arguments' list."),
                        Some(kw.location),
                    );
                }
                if !avail_key.contains(s) {
                    return fail_parse_error!(
                        format!("Expect one of {:?}, found `{}`", &avail_key, s),
                        Some(kw.location),
                    );
                } else {
                    visited_key.insert(s);
                }
                match s {
                    "args" => ret_args.arg_names = pylist_to_vec(&kw.node.value)?,
                    "returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
                    "sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
                    _ => unreachable!(),
                }
            }
            None => {
                return fail_parse_error!(
                    format!(
                        "Expect explicitly set both `args` and `returns`, found \n{:#?}",
                        &kw.node
                    ),
                    Some(kw.location),
                )
            }
        }
    }
    let loc = keywords[0].location;
    for key in avail_key {
        if !visited_key.contains(key) && !opt_keys.contains(key) {
            return fail_parse_error!(format!("Expect `{key}` keyword"), Some(loc));
        }
    }
    Ok(ret_args)
}
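
// For reference, an illustrative decorator that exercises all three accepted
// keywords (the `sql` value here is a made-up example, not from this PR):
//
// @copr(args=["cpu", "mem"], returns=["perf"], sql="select cpu, mem from metrics")
// def perf(cpu, mem):
//     return cpu + mem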

/// Return the args and returns as Vecs of Strings from the decorator.
fn parse_decorator(decorator: &ast::Expr<()>) -> Result<DecoratorArgs> {
    //check_decorator(decorator)?;
    if let ast::ExprKind::Call {
        func,
        args: _,
        keywords,
    } = &decorator.node
    {
        ensure!(
            func.node
                == ast::ExprKind::Name {
                    id: "copr".to_string(),
                    ctx: ast::ExprContext::Load
                }
                || func.node
                    == ast::ExprKind::Name {
                        id: "coprocessor".to_string(),
                        ctx: ast::ExprContext::Load
                    },
            CoprParseSnafu {
                reason: format!(
                    "Expect decorator with name `copr` or `coprocessor`, found \n{:#?}",
                    &func.node
                ),
                loc: Some(func.location)
            }
        );
        parse_keywords(keywords)
    } else {
        fail_parse_error!(
            format!(
                "Expect decorator to be a function call (like `@copr(...)`), found \n{:#?}",
                decorator.node
            ),
            Some(decorator.location),
        )
    }
}

// get the type annotations of the arguments
fn get_arg_annotations(args: &Arguments) -> Result<Vec<Option<AnnotationInfo>>> {
    // get arg types from type annotations
    args.args
        .iter()
        .map(|arg| {
            if let Some(anno) = &arg.node.annotation {
                // error handling is done inside parse_annotation
                parse_annotation(anno).map(Some)
            } else {
                Ok(None)
            }
        })
        .collect::<Result<Vec<Option<_>>>>()
}

fn get_return_annotations(rets: &ast::Expr<()>) -> Result<Vec<Option<AnnotationInfo>>> {
    let mut return_types = Vec::with_capacity(match &rets.node {
        ast::ExprKind::Tuple { elts, ctx: _ } => elts.len(),
        ast::ExprKind::Subscript {
            value: _,
            slice: _,
            ctx: _,
        } => 1,
        _ => {
            return fail_parse_error!(
                format!(
                    "Expect `(vector[...], vector[...], ...)` or `vector[...]`, found \n{:#?}",
                    &rets.node
                ),
                Some(rets.location),
            )
        }
    });
    match &rets.node {
        // python: -> (vector[...], vector[...], ...)
        ast::ExprKind::Tuple { elts, .. } => {
            for elem in elts {
                return_types.push(Some(parse_annotation(elem)?))
            }
        }
        // python: -> vector[...]
        ast::ExprKind::Subscript {
            value: _,
            slice: _,
            ctx: _,
        } => return_types.push(Some(parse_annotation(rets)?)),
        _ => {
            return fail_parse_error!(
                format!(
                    "Expect one or many type annotations for the return type, found \n{:#?}",
                    &rets.node
                ),
                Some(rets.location),
            )
        }
    }
    Ok(return_types)
}

/// Check that the list of statements contains only one statement and
/// that statement is a function definition with one decorator.
fn check_copr(stmts: &Vec<ast::Stmt<()>>) -> Result<()> {
    ensure!(
        stmts.len() == 1,
        CoprParseSnafu {
            reason:
                "Expect one and only one python function with `@coprocessor` or `@copr` decorator"
                    .to_string(),
            loc: stmts.first().map(|s| s.location)
        }
    );
    if let ast::StmtKind::FunctionDef {
        name: _,
        args: _,
        body: _,
        decorator_list,
        returns: _,
        type_comment: _,
    } = &stmts[0].node
    {
        ensure!(
            decorator_list.len() == 1,
            CoprParseSnafu {
                reason: "Expect one decorator",
                loc: decorator_list.first().map(|s| s.location)
            }
        );
    } else {
        return fail_parse_error!(
            format!(
                "Expect a function definition, found a \n{:#?}",
                &stmts[0].node
            ),
            Some(stmts[0].location),
        );
    }
    Ok(())
}

/// Parse the script and return a `Coprocessor` struct with info extracted from the AST.
pub fn parse_copr(script: &str) -> Result<Coprocessor> {
    let python_ast = parser::parse_program(script).context(PyParseSnafu)?;
    check_copr(&python_ast)?;
    if let ast::StmtKind::FunctionDef {
        name,
        args: fn_args,
        body: _,
        decorator_list,
        returns,
        type_comment: _,
    } = &python_ast[0].node
    {
        let decorator = &decorator_list[0];
        let deco_args = parse_decorator(decorator)?;

        // get arg types from type annotations
        let arg_types = get_arg_annotations(fn_args)?;

        // get return types from type annotations
        let return_types = if let Some(rets) = returns {
            get_return_annotations(rets)?
        } else {
            // if there is no annotation at all, set them all to None
            std::iter::repeat(None)
                .take(deco_args.ret_names.len())
                .collect()
        };

        // make sure the arguments and returns in the function
        // and in the decorator have the same length
        ensure!(
            deco_args.arg_names.len() == arg_types.len(),
            CoprParseSnafu {
                reason: format!(
                    "args number in decorator({}) and function({}) doesn't match",
                    deco_args.arg_names.len(),
                    arg_types.len()
                ),
                loc: None
            }
        );
        ensure!(
            deco_args.ret_names.len() == return_types.len(),
            CoprParseSnafu {
                reason: format!(
                    "returns number in decorator( {} ) and function annotation( {} ) doesn't match",
                    deco_args.ret_names.len(),
                    return_types.len()
                ),
                loc: None
            }
        );
        Ok(Coprocessor {
            name: name.to_string(),
            deco_args,
            arg_types,
            return_types,
            script: script.to_owned(),
        })
    } else {
        unreachable!()
    }
}

src/script/src/python/engine.rs (new file, 230 lines)
@@ -0,0 +1,230 @@
//! Python script engine
use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_trait::async_trait;
use common_error::prelude::BoxedError;
use common_recordbatch::{
    error::ExternalSnafu, error::Result as RecordBatchResult, RecordBatch, RecordBatchStream,
    SendableRecordBatchStream,
};
use datatypes::schema::SchemaRef;
use futures::Stream;
use query::Output;
use query::QueryEngineRef;
use snafu::{ensure, ResultExt};
use sql::statements::statement::Statement;

use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::python::coprocessor::{exec_parsed, parse::parse_copr};
use crate::python::{
    coprocessor::CoprocessorRef,
    error::{self, Result},
};

const PY_ENGINE: &str = "python";

pub struct PyScript {
    query_engine: QueryEngineRef,
    copr: CoprocessorRef,
}

pub struct CoprStream {
    stream: SendableRecordBatchStream,
    copr: CoprocessorRef,
}

impl RecordBatchStream for CoprStream {
    fn schema(&self) -> SchemaRef {
        self.stream.schema()
    }
}

impl Stream for CoprStream {
    type Item = RecordBatchResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Ok(recordbatch))) => {
                let batch = exec_parsed(&self.copr, &recordbatch.df_recordbatch)
                    .map_err(BoxedError::new)
                    .context(ExternalSnafu)?;

                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(other) => Poll::Ready(other),
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
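
// Editor's note (a reading of the impl above, not part of the original diff):
// `CoprStream` wraps the SQL result stream and runs the parsed coprocessor over
// every record batch it yields, so script execution is streaming rather than
// materializing the whole query result first.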

#[async_trait]
impl Script for PyScript {
    type Error = error::Error;

    fn engine_name(&self) -> &str {
        PY_ENGINE
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn evaluate(&self, _ctx: EvalContext) -> Result<Output> {
        if let Some(sql) = &self.copr.deco_args.sql {
            let stmt = self.query_engine.sql_to_statement(sql)?;
            ensure!(
                matches!(stmt, Statement::Query { .. }),
                error::UnsupportedSqlSnafu { sql }
            );
            let plan = self.query_engine.statement_to_plan(stmt)?;
            let res = self.query_engine.execute(&plan).await?;
            let copr = self.copr.clone();
            match res {
                query::Output::RecordBatch(stream) => {
                    Ok(Output::RecordBatch(Box::pin(CoprStream { copr, stream })))
                }
                _ => unreachable!(),
            }
        } else {
            // TODO(boyan): try to retrieve sql from user request
            error::MissingSqlSnafu {}.fail()
        }
    }
}

pub struct PyEngine {
    query_engine: QueryEngineRef,
}

impl PyEngine {
    pub fn new(query_engine: QueryEngineRef) -> Self {
        Self { query_engine }
    }
}

#[async_trait]
impl ScriptEngine for PyEngine {
    type Error = error::Error;
    type Script = PyScript;

    fn name(&self) -> &str {
        PY_ENGINE
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn compile(&self, script: &str, _ctx: CompileContext) -> Result<PyScript> {
        let copr = Arc::new(parse_copr(script)?);

        Ok(PyScript {
            copr,
            query_engine: self.query_engine.clone(),
        })
    }
}

#[cfg(test)]
mod tests {
    use catalog::memory::{MemoryCatalogProvider, MemorySchemaProvider};
    use catalog::{
        CatalogList, CatalogProvider, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME,
    };
    use common_recordbatch::util;
    use datafusion_common::field_util::FieldExt;
    use datafusion_common::field_util::SchemaExt;
    use datatypes::arrow::array::Float64Array;
    use datatypes::arrow::array::Int64Array;
    use query::QueryEngineFactory;
    use table::table::numbers::NumbersTable;

    use super::*;

    #[tokio::test]
    async fn test_compile_evaluate() {
        let catalog_list = catalog::memory::new_memory_catalog_list().unwrap();

        let default_schema = Arc::new(MemorySchemaProvider::new());
        default_schema
            .register_table("numbers".to_string(), Arc::new(NumbersTable::default()))
            .unwrap();
        let default_catalog = Arc::new(MemoryCatalogProvider::new());
        default_catalog.register_schema(DEFAULT_SCHEMA_NAME.to_string(), default_schema);
        catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), default_catalog);

        let factory = QueryEngineFactory::new(catalog_list);
        let query_engine = factory.query_engine();

        let script_engine = PyEngine::new(query_engine.clone());

        let script = r#"
@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
def test(a, b, c):
    import greptime as g
    return (a + b) / g.sqrt(c)
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script.evaluate(EvalContext::default()).await.unwrap();
        match output {
            Output::RecordBatch(stream) => {
                let numbers = util::collect(stream).await.unwrap();

                assert_eq!(1, numbers.len());
                let number = &numbers[0];
                assert_eq!(number.df_recordbatch.num_columns(), 1);
                assert_eq!("r", number.schema.arrow_schema().field(0).name());

                let columns = number.df_recordbatch.columns();
                assert_eq!(1, columns.len());
                assert_eq!(100, columns[0].len());
                let rows = columns[0].as_any().downcast_ref::<Float64Array>().unwrap();
                assert!(rows.value(0).is_nan());
                assert_eq!((99f64 + 99f64) / 99f64.sqrt(), rows.value(99))
            }
            _ => unreachable!(),
        }

        // test list comprehension
        let script = r#"
@copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
def test(a):
    import greptime as gt
    return gt.vector([x for x in a if x % 2 == 0])
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script.evaluate(EvalContext::default()).await.unwrap();
        match output {
            Output::RecordBatch(stream) => {
                let numbers = util::collect(stream).await.unwrap();

                assert_eq!(1, numbers.len());
                let number = &numbers[0];
                assert_eq!(number.df_recordbatch.num_columns(), 1);
                assert_eq!("r", number.schema.arrow_schema().field(0).name());

                let columns = number.df_recordbatch.columns();
                assert_eq!(1, columns.len());
                assert_eq!(50, columns[0].len());
                let rows = columns[0].as_any().downcast_ref::<Int64Array>().unwrap();
                assert_eq!(0, rows.value(0));
                assert_eq!(98, rows.value(49))
            }
            _ => unreachable!(),
        }
    }
}

src/script/src/python/error.rs (new file, 189 lines)
@@ -0,0 +1,189 @@
use common_error::prelude::{ErrorCompat, ErrorExt, StatusCode};
use console::{style, Style};
use datatypes::arrow::error::ArrowError;
use datatypes::error::Error as DataTypeError;
use query::error::Error as QueryError;
use rustpython_compiler_core::error::CompileError as CoreCompileError;
use rustpython_parser::{ast::Location, error::ParseError};
pub use snafu::ensure;
use snafu::{prelude::Snafu, Backtrace};
pub type Result<T> = std::result::Result<T, Error>;

#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error {
    #[snafu(display("Datatype error: {}", source))]
    TypeCast {
        #[snafu(backtrace)]
        source: DataTypeError,
    },

    #[snafu(display("Failed to query, source: {}", source))]
    DatabaseQuery {
        #[snafu(backtrace)]
        source: QueryError,
    },

    #[snafu(display("Failed to parse script, source: {}", source))]
    PyParse {
        backtrace: Backtrace,
        source: ParseError,
    },

    #[snafu(display("Failed to compile script, source: {}", source))]
    PyCompile {
        backtrace: Backtrace,
        source: CoreCompileError,
    },

    /// rustpython problem, using the Python virtual machine's backtrace instead
    #[snafu(display("Python Runtime error, error: {}", msg))]
    PyRuntime { msg: String, backtrace: Backtrace },

    #[snafu(display("Arrow error: {}", source))]
    Arrow {
        backtrace: Backtrace,
        source: ArrowError,
    },

    /// errors in the coprocessor's parse checks for types etc.
    #[snafu(display("Coprocessor error: {} {}.", reason,
        if let Some(loc) = loc {
            format!("at {loc}")
        } else {
            "".into()
        }))]
    CoprParse {
        backtrace: Backtrace,
        reason: String,
        // the location is optional because some errors may not have a clear location
        loc: Option<Location>,
    },

    /// Any other type of error that isn't covered above
    #[snafu(display("Coprocessor's Internal error: {}", reason))]
    Other {
        backtrace: Backtrace,
        reason: String,
    },

    #[snafu(display("Unsupported sql in coprocessor: {}", sql))]
    UnsupportedSql { sql: String, backtrace: Backtrace },

    #[snafu(display("Missing sql in coprocessor"))]
    MissingSql { backtrace: Backtrace },

    #[snafu(display("Failed to retrieve record batches, source: {}", source))]
    RecordBatch {
        #[snafu(backtrace)]
        source: common_recordbatch::error::Error,
    },
}

impl From<QueryError> for Error {
    fn from(source: QueryError) -> Self {
        Self::DatabaseQuery { source }
    }
}

impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        match self {
            Error::Arrow { .. }
            | Error::TypeCast { .. }
            | Error::DatabaseQuery { .. }
            | Error::PyRuntime { .. }
            | Error::RecordBatch { .. }
            | Error::Other { .. } => StatusCode::Internal,

            Error::PyParse { .. }
            | Error::PyCompile { .. }
            | Error::CoprParse { .. }
            | Error::UnsupportedSql { .. }
            | Error::MissingSql { .. } => StatusCode::InvalidArguments,
        }
    }
    fn backtrace_opt(&self) -> Option<&common_error::snafu::Backtrace> {
        ErrorCompat::backtrace(self)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
// implement From for those errors so one can use the question mark to implicitly convert into `Error`
impl From<DataTypeError> for Error {
    fn from(e: DataTypeError) -> Self {
        Self::TypeCast { source: e }
    }
}

/// Pretty-print an [`Error`] in the given script;
/// basically print an arrow which points to where the error occurs (if a location is available).
pub fn pretty_print_error_in_src(
    script: &str,
    err: &Error,
    ln_offset: usize,
    filename: &str,
) -> String {
    let (reason, loc) = get_error_reason_loc(err);
    if let Some(loc) = loc {
        visualize_loc(script, &loc, &err.to_string(), &reason, ln_offset, filename)
    } else {
        // No location provided
        format!("\n{}: {}", style("error").red().bold(), err)
    }
}

/// Pretty-print a location in the script with a description.
///
/// `ln_offset` is the line offset added to `loc`'s row; `filename` is the file name displayed along with the row and column info.
pub fn visualize_loc(
    script: &str,
    loc: &Location,
    err_ty: &str,
    desc: &str,
    ln_offset: usize,
    filename: &str,
) -> String {
    let lines: Vec<&str> = script.split('\n').collect();
    let (row, col) = (loc.row(), loc.column());
    let red_bold = Style::new().red().bold();
    let blue_bold = Style::new().blue().bold();
    let col_space = (ln_offset + row).to_string().len().max(1);
    let space: String = " ".repeat(col_space - 1);
    let indicate = format!(
        "
{error}: {err_ty}
{space}{r_arrow}{filename}:{row}:{col}
{prow:col_space$}{ln_pad} {line}
{space} {ln_pad} {arrow:>pad$} {desc}
",
        error = red_bold.apply_to("error"),
        err_ty = style(err_ty).bold(),
        r_arrow = blue_bold.apply_to("-->"),
        filename = filename,
        row = ln_offset + row,
        col = col,
        line = lines[loc.row() - 1],
        pad = loc.column(),
        arrow = red_bold.apply_to("^"),
        desc = red_bold.apply_to(desc),
        ln_pad = blue_bold.apply_to("|"),
        prow = blue_bold.apply_to(ln_offset + row),
        space = space
    );
    indicate
}
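
// An illustrative, hand-constructed rendering of the template above (not actual
// program output) for a parse error at row 3, column 18 of "copr.py" with
// ln_offset 0:
//
// error: Failed to parse script, source: invalid syntax.
//  -->copr.py:3:18
// 3| def a(cpu, mem):
//   |                  ^ invalid syntax. Got unexpected token ...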

/// Extract a reason for an [`Error`] in String format; also return a location if possible.
pub fn get_error_reason_loc(err: &Error) -> (String, Option<Location>) {
    match err {
        Error::CoprParse { reason, loc, .. } => (reason.clone(), loc.to_owned()),
        Error::Other { reason, .. } => (reason.clone(), None),
        Error::PyRuntime { msg, .. } => (msg.clone(), None),
        Error::PyParse { source, .. } => (source.error.to_string(), Some(source.location)),
        Error::PyCompile { source, .. } => (source.error.to_string(), Some(source.location)),
        _ => (format!("Unknown error: {:?}", err), None),
    }
}

src/script/src/python/test.rs (new file, 321 lines)
@@ -0,0 +1,321 @@
#![allow(clippy::print_stdout, clippy::print_stderr)]
// for debug purposes; also, this is already a
// test module, so allowing print_stdout shouldn't be a problem
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
use std::sync::Arc;

use console::style;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::array::PrimitiveArray;
use datatypes::arrow::datatypes::{DataType, Field, Schema};
use ron::from_str as from_ron_string;
use rustpython_parser::parser;
use serde::{Deserialize, Serialize};

use super::error::{get_error_reason_loc, visualize_loc};
use crate::python::coprocessor::AnnotationInfo;
use crate::python::error::pretty_print_error_in_src;
use crate::python::{
    coprocessor, coprocessor::parse::parse_copr, coprocessor::Coprocessor, error::Error,
};

#[derive(Deserialize, Debug)]
struct TestCase {
    name: String,
    code: String,
    predicate: Predicate,
}

#[derive(Deserialize, Debug)]
enum Predicate {
    ParseIsOk {
        result: Coprocessor,
    },
    ParseIsErr {
        /// used to check that, after serializing the [`Error`] into a String, that string contains `reason`
        reason: String,
    },
    ExecIsOk {
        fields: Vec<AnnotationInfo>,
        columns: Vec<ColumnInfo>,
    },
    ExecIsErr {
        /// used to check that, after serializing the [`Error`] into a String, that string contains `reason`
        reason: String,
    },
}

#[derive(Serialize, Deserialize, Debug)]
struct ColumnInfo {
    pub ty: DataType,
    pub len: usize,
}

fn create_sample_recordbatch() -> DfRecordBatch {
    let cpu_array = PrimitiveArray::from_slice([0.9f32, 0.8, 0.7, 0.6]);
    let mem_array = PrimitiveArray::from_slice([0.1f64, 0.2, 0.3, 0.4]);
    let schema = Arc::new(Schema::from(vec![
        Field::new("cpu", DataType::Float32, false),
        Field::new("mem", DataType::Float64, false),
    ]));

    DfRecordBatch::try_new(schema, vec![Arc::new(cpu_array), Arc::new(mem_array)]).unwrap()
}

/// Test cases which are read from a .ron file and deserialized,
///
/// then executed/parsed (depending on the type of predicate) to decide whether the result is as expected.
#[test]
fn run_ron_testcases() {
    let loc = Path::new("src/python/testcases.ron");
    let loc = loc.to_str().expect("Fail to parse path");
    let mut file = File::open(loc).expect("Fail to open file");
    let mut buf = String::new();
    file.read_to_string(&mut buf)
        .expect("Fail to read to string");
    let testcases: Vec<TestCase> = from_ron_string(&buf).expect("Fail to convert to testcases");
    println!("Read {} testcases from {}", testcases.len(), loc);
    for testcase in testcases {
        print!(".ron test {}", testcase.name);
        match testcase.predicate {
            Predicate::ParseIsOk { result } => {
                let copr = parse_copr(&testcase.code);
                let mut copr = copr.unwrap();
                copr.script = "".into();
                assert_eq!(copr, result);
            }
            Predicate::ParseIsErr { reason } => {
                let copr = parse_copr(&testcase.code);
                if copr.is_ok() {
                    eprintln!("Expect to be err, found {copr:#?}");
                    panic!()
                }
                let res = &copr.unwrap_err();
                println!(
                    "{}",
                    pretty_print_error_in_src(&testcase.code, res, 0, "<embedded>")
                );
                let (res, _) = get_error_reason_loc(res);
                if !res.contains(&reason) {
                    eprintln!("{}", testcase.code);
                    eprintln!("Parse Error, expect \"{reason}\" in \"{res}\", but not found.");
                    panic!()
                }
            }
            Predicate::ExecIsOk { fields, columns } => {
                let rb = create_sample_recordbatch();
                let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
                if res.is_err() {
                    dbg!(&res);
                }
                assert!(res.is_ok());
                let res = res.unwrap();
                fields
                    .iter()
                    .zip(&res.schema.arrow_schema().fields)
                    .map(|(anno, real)| {
                        if !(anno.datatype.clone().unwrap() == real.data_type
                            && anno.is_nullable == real.is_nullable)
                        {
                            eprintln!("fields expect to be {anno:#?}, found to be {real:#?}.");
                            panic!()
                        }
                    })
                    .count();
                columns
                    .iter()
                    .zip(res.df_recordbatch.columns())
                    .map(|(anno, real)| {
                        if !(&anno.ty == real.data_type() && anno.len == real.len()) {
                            eprintln!(
                                "Unmatched type or length! Expect [{:#?}; {}], found [{:#?}; {}]",
                                anno.ty,
                                anno.len,
                                real.data_type(),
                                real.len()
                            );
                            panic!()
                        }
                    })
                    .count();
            }
            Predicate::ExecIsErr {
                reason: part_reason,
            } => {
                let rb = create_sample_recordbatch();
                let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
                if let Err(res) = res {
                    println!(
                        "{}",
                        pretty_print_error_in_src(&testcase.code, &res, 1120, "<embedded>")
                    );
                    let (reason, _) = get_error_reason_loc(&res);
                    if !reason.contains(&part_reason) {
                        eprintln!(
                            "{}\nExecute error, expect \"{reason}\" in \"{res}\", but not found.",
                            testcase.code,
                            reason = style(reason).green(),
                            res = style(res).red()
                        );
                        panic!()
                    }
                } else {
                    eprintln!("{:#?}\nExpect Err(...), found Ok(...)", res);
                    panic!();
                }
            }
        }
        println!(" ... {}", style("ok✅").green());
    }
}

#[test]
#[allow(unused)]
fn test_type_anno() {
    let python_source = r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu, mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[ _ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#;
    let pyast = parser::parse(python_source, parser::Mode::Interactive).unwrap();
    let copr = parse_copr(python_source);
    dbg!(copr);
}

#[test]
#[allow(clippy::print_stdout, unused_must_use)]
// allow print in test functions for debug purposes (like quickly testing syntax & ideas)
fn test_calc_rvs() {
    let python_source = r#"
@coprocessor(args=["open_time", "close"], returns=[
    "rv_7d",
    "rv_15d",
    "rv_30d",
    "rv_60d",
    "rv_90d",
    "rv_180d"
])
def calc_rvs(open_time, close):
    from greptime import vector, log, prev, sqrt, datetime, pow, sum
    def calc_rv(close, open_time, time, interval):
        mask = (open_time < time) & (open_time > time - interval)
        close = close[mask]

        avg_time_interval = (open_time[-1] - open_time[0])/(len(open_time)-1)
        ref = log(close/prev(close))
        var = sum(pow(ref, 2)/(len(ref)-1))
        return sqrt(var/avg_time_interval)

    # how to get an env var?
    # maybe through accessing the scope and serde, then send to remote?
    timepoint = open_time[-1]
    rv_7d = calc_rv(close, open_time, timepoint, datetime("7d"))
    rv_15d = calc_rv(close, open_time, timepoint, datetime("15d"))
    rv_30d = calc_rv(close, open_time, timepoint, datetime("30d"))
    rv_60d = calc_rv(close, open_time, timepoint, datetime("60d"))
    rv_90d = calc_rv(close, open_time, timepoint, datetime("90d"))
    rv_180d = calc_rv(close, open_time, timepoint, datetime("180d"))
    return rv_7d, rv_15d, rv_30d, rv_60d, rv_90d, rv_180d
"#;
    let close_array = PrimitiveArray::from_slice([
        10106.79f32,
        10106.09,
        10108.73,
        10106.38,
        10106.95,
        10107.55,
        10104.68,
        10108.8,
        10115.96,
        10117.08,
        10120.43,
    ]);
    let open_time_array = PrimitiveArray::from_slice([
        1581231300i64,
        1581231360,
        1581231420,
        1581231480,
        1581231540,
        1581231600,
        1581231660,
        1581231720,
        1581231780,
        1581231840,
        1581231900,
    ]);
    let schema = Arc::new(Schema::from(vec![
        Field::new("close", DataType::Float32, false),
        Field::new("open_time", DataType::Int64, false),
    ]));
    let rb = DfRecordBatch::try_new(
        schema,
        vec![Arc::new(close_array), Arc::new(open_time_array)],
    )
    .unwrap();
    let ret = coprocessor::exec_coprocessor(python_source, &rb);
    if let Err(Error::PyParse {
        backtrace: _,
        source,
    }) = ret
    {
        let res = visualize_loc(
            python_source,
            &source.location,
            "unknown tokens",
            source.error.to_string().as_str(),
            0,
            "copr.py",
        );
        println!("{res}");
    } else if let Ok(res) = ret {
        dbg!(&res);
    } else {
        dbg!(ret);
    }
}

#[test]
#[allow(clippy::print_stdout, unused_must_use)]
// allow print in test functions for debug purposes (like quickly testing syntax & ideas)
fn test_coprocessor() {
    let python_source = r#"
@copr(args=["cpu", "mem"], returns=["ref"])
def a(cpu, mem):
    import greptime as gt
    from greptime import vector, log2, prev, sum, pow, sqrt, datetime
    abc = vector([v[0] > v[1] for v in zip(cpu, mem)])
    fed = cpu.filter(abc)
    ref = log2(fed/prev(fed))
    return (0.5 < cpu) & ~( cpu >= 0.75)
"#;
    let cpu_array = PrimitiveArray::from_slice([0.9f32, 0.8, 0.7, 0.3]);
    let mem_array = PrimitiveArray::from_slice([0.1f64, 0.2, 0.3, 0.4]);
    let schema = Arc::new(Schema::from(vec![
        Field::new("cpu", DataType::Float32, false),
        Field::new("mem", DataType::Float64, false),
    ]));
    let rb =
        DfRecordBatch::try_new(schema, vec![Arc::new(cpu_array), Arc::new(mem_array)]).unwrap();
    let ret = coprocessor::exec_coprocessor(python_source, &rb);
    if let Err(Error::PyParse {
        backtrace: _,
        source,
    }) = ret
    {
        let res = visualize_loc(
            python_source,
            &source.location,
            "unknown tokens",
            source.error.to_string().as_str(),
            0,
            "copr.py",
        );
        println!("{res}");
    } else if let Ok(res) = ret {
        dbg!(&res);
    } else {
        dbg!(ret);
    }
}

src/script/src/python/testcases.ron (new file, 413 lines)
@@ -0,0 +1,413 @@
// This is the file for the python coprocessor's testcases,
// including coprocessor parsing tests and execution tests.
// Check src/script/src/python/test.rs for more information.
[
    (
        name: "correct_parse",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsOk(
            result: (
                name: "a",
                deco_args: (
                    arg_names: ["cpu", "mem"],
                    ret_names: ["perf", "what", "how", "why"],
                ),
                arg_types: [
                    Some((
                        datatype: Some(Float32),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64),
                        is_nullable: false
                    )),
                ],
                return_types: [
                    Some((
                        datatype: Some(Float64),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64),
                        is_nullable: true
                    )),
                    Some((
                        datatype: None,
                        is_nullable: false
                    )),
                    Some((
                        datatype: None,
                        is_nullable: true
                    )),
                ]
            )
        )
    ),
    (
        name: "missing_decorator",
        code: r#"
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one decorator"
        )
    ),
    (
        name: "not_a_list_of_string",
        code: r#"
@copr(args=["cpu", 3], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a list of String, found"
        )
    ),
    (
        name: "not_even_a_list",
        code: r#"
@copr(args=42, returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a list, found"
        )
    ),
    (
        // unknown type names
        name: "unknown_type_names",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[g32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Unknown datatype:"
        )
    ),
    (
        // two type names
        name: "two_type_names",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32 | f64], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one typenames and one `None`"
        )
    ),
    (
        name: "two_none",
        // two `None`
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[None | None], mem: vector[f64])->(vector[f64], vector[None|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one typenames and one `None`"
        )
    ),
    (
        // expect a type's name
        name: "unknown_type_names_in_ret",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64|None], mem: vector[f64])->(vector[g64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Unknown datatype:"
        )
    ),
    (
        // no more `into`
        name: "call_deprecated_for_cast_into",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[cast(f64)], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect type in `vector[...]`, found "
        )
    ),
    (
        // expect `vector`, not `vec`
        name: "vector_not_vec",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vec[f64], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Wrong type annotation, expect `vector[...]`, found"
        )
    ),
    (
        // expect `None`
        name: "expect_none",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64|1], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a type name and a `None`, found left: "
        )
    ),
    (
        // more than one statement
        name: "two_stmt",
        code: r#"
print("hello world")
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[None|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one and only one python function with `@coprocessor` or `@copr` decorator"
        )
    ),
    (
        // wrong decorator name
        name: "typo_copr",
        code: r#"
@corp(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[None|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect decorator with name `copr` or `coprocessor`, found"
        )
    ),
    (
        name: "extra_keywords",
        code: r#"
@copr(args=["cpu", "mem"], sql=3,psql = 4,rets=5)
def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: " keyword argument, found "
        )
    ),
    (
        name: "missing_keywords",
        code: r#"
@copr(args=["cpu", "mem"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: " keyword argument, found "
        )
    ),
|
||||
(
|
||||
// exec_coprocessor
|
||||
name: "correct_exec",
|
||||
code: r#"
|
||||
@copr(args=["cpu", "mem"], returns=["perf", "what"])
|
||||
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
vector[f32]):
|
||||
return cpu + mem, cpu - mem
|
||||
"#,
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
is_nullable: true
|
||||
),
|
||||
(
|
||||
datatype: Some(Float32),
|
||||
is_nullable: false
|
||||
),
|
||||
],
|
||||
columns: [
|
||||
(
|
||||
ty: Float64,
|
||||
len: 4
|
||||
),
|
||||
(
|
||||
ty: Float32,
|
||||
len: 4
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
(
|
||||
// constant column(float)
|
||||
name: "constant_float_col",
|
||||
code: r#"
|
||||
@copr(args=["cpu", "mem"], returns=["perf", "what"])
|
||||
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
|
||||
vector[f32]):
|
||||
return cpu + mem, 1.0
|
||||
"#,
|
||||
predicate: ExecIsOk(
|
||||
fields: [
|
||||
(
|
||||
datatype: Some(Float64),
|
||||
is_nullable: true
|
||||
),
|
||||
                (
                    datatype: Some(Float32),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // constant column(int)
        name: "constant_int_col",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f32]):
    return cpu + mem, 1
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // constant column(bool)
        name: "constant_bool_col",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f32]):
    return cpu + mem, True
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // declares 6 return vectors, but the body returns only 5
        name: "ret_nums_wrong",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why", "whatever", "nihilism"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f64], vector[f64], vector[f64 | None], vector[bool], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem, cpu
"#,
        predicate: ExecIsErr(
            reason: "The number of return Vector is wrong, expect"
        )
    ),
    (
        name: "div_by_zero",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f32]):
    return cpu + mem, cpu - mem*(1/0)
"#,
        predicate: ExecIsErr(
            reason: "ZeroDivisionError: division by zero"
        )
    ),
    (
        name: "unexpected_token",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f32]):
    return cpu + mem, cpu - mem***
"#,
        predicate: ParseIsErr(
            reason: "invalid syntax. Got unexpected token "
        )
    ),
    (
        name: "wrong_return_anno",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->f32:
    return cpu + mem, cpu - mem
"#,
        predicate: ParseIsErr(
            reason: "Expect `(vector[...], vector[...], ...)` or `vector[...]`, found "
        )
    ),
    (
        name: "break_outside_loop",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64]):
    break
    return cpu + mem, cpu - mem
"#,
        predicate: ExecIsErr(
            reason: "'break' outside loop"
        )
    ),
    (
        name: "not_even_wrong",
        code: r#"
42
"#,
        predicate: ParseIsErr(
            reason: "Expect a function definition, found a"
        )
    )
]
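The fixtures above are plain data. For orientation, a minimal, hypothetical sketch (not part of this PR) of how one case's `code` could be driven through the script engine API this diff introduces elsewhere, via `compile` plus `evaluate`; the `run_case` helper, its signature, and its pass/fail folding are assumptions:

use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::PyEngine;

// Returns true only when the script both compiles and evaluates, which is
// roughly the `ExecIsOk` predicate; `ParseIsErr`/`ExecIsErr` cases expect
// `false`. Note this sketch folds parse and runtime failures together,
// unlike the real harness, which distinguishes them.
async fn run_case(py_engine: &PyEngine, code: &str) -> bool {
    match py_engine.compile(code, CompileContext::default()).await {
        Ok(compiled) => compiled.evaluate(EvalContext::default()).await.is_ok(),
        Err(_) => false,
    }
}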
src/script/src/python/utils.rs (new file, 27 lines)
@@ -0,0 +1,27 @@
use rustpython_vm::{builtins::PyBaseExceptionRef, PyObjectRef, PyPayload, PyRef, VirtualMachine};
use snafu::{Backtrace, GenerateImplicitData};

use crate::python::error;
use crate::python::PyVector;
pub(crate) type PyVectorRef = PyRef<PyVector>;

/// Use `rustpython`'s `is_instance` method to check whether a PyObject is an instance of class `T`.
/// If the underlying `PyResult` is `Err`, this function returns `false`.
pub fn is_instance<T: PyPayload>(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
    obj.is_instance(T::class(vm).into(), vm).unwrap_or(false)
}

pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error::Error {
    let mut msg = String::new();
    if let Err(e) = vm.write_exception(&mut msg, &excep) {
        return error::Error::PyRuntime {
            msg: format!("Failed to write exception msg, err: {}", e),
            backtrace: Backtrace::generate(),
        };
    }

    error::Error::PyRuntime {
        msg,
        backtrace: Backtrace::generate(),
    }
}
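For context, a hypothetical caller (not in this file) showing how the two helpers combine: gate on `is_instance` before downcasting, and turn a raised Python exception into the crate's error type via `format_py_error`. The `expect_vector` name and the TypeError messages are assumptions:

use rustpython_vm::{PyObjectRef, VirtualMachine};

use crate::python::error;
use crate::python::utils::{format_py_error, is_instance, PyVectorRef};
use crate::python::PyVector;

// Hypothetical helper: accept `obj` only if it is a PyVector, otherwise
// surface a Python TypeError through `format_py_error`.
fn expect_vector(obj: PyObjectRef, vm: &VirtualMachine) -> Result<PyVectorRef, error::Error> {
    if is_instance::<PyVector>(&obj, vm) {
        obj.downcast::<PyVector>()
            .map_err(|_| format_py_error(vm.new_type_error("not a vector".to_string()), vm))
    } else {
        Err(format_py_error(
            vm.new_type_error("expected a vector".to_string()),
            vm,
        ))
    }
}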
src/script/src/python/vector.rs (new file, 1197 lines)
File diff suppressed because it is too large.
@@ -34,3 +34,4 @@ catalog = { path = "../catalog" }
 mysql_async = "0.30"
 rand = "0.8"
 test-util = { path = "../../test-util" }
+script = { path = "../script", features = ["python"] }
@@ -49,6 +49,13 @@ pub enum Error {
         source: BoxedError,
     },
 
+    #[snafu(display("Failed to execute script: {}, source: {}", script, source))]
+    ExecuteScript {
+        script: String,
+        #[snafu(backtrace)]
+        source: BoxedError,
+    },
+
     #[snafu(display("Not supported: {}", feat))]
     NotSupported { feat: String },
 }
@@ -66,7 +73,11 @@ impl ErrorExt for Error {
             | Error::StartHttp { .. }
             | Error::StartGrpc { .. }
             | Error::TcpBind { .. } => StatusCode::Internal,
-            Error::ExecuteQuery { source, .. } => source.status_code(),
+
+            Error::ExecuteScript { source, .. } | Error::ExecuteQuery { source, .. } => {
+                source.status_code()
+            }
+
             Error::NotSupported { .. } => StatusCode::InvalidArguments,
         }
     }
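As an aside, a self-contained sketch of how a snafu variant shaped like `ExecuteScript` is raised at a call site. Everything below is illustrative, not this repository's code: the demo enum, the `ParseIntError` source standing in for the crate's `BoxedError`, and the selector name following snafu 0.7's `...Snafu` convention.

use snafu::{ResultExt, Snafu};

#[derive(Debug, Snafu)]
enum DemoError {
    // Mirrors the shape of the ExecuteScript variant above, minus the
    // boxed source and backtrace.
    #[snafu(display("Failed to execute script: {}, source: {}", script, source))]
    ExecuteScript {
        script: String,
        source: std::num::ParseIntError,
    },
}

fn run(script: &str) -> Result<i64, DemoError> {
    // Stand-in for the real engine call: a fallible Result attaches the
    // variant's fields via the generated `ExecuteScriptSnafu` selector.
    script.parse::<i64>().context(ExecuteScriptSnafu { script })
}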
@@ -8,8 +8,7 @@ use axum::{
     error_handling::HandleErrorLayer,
     response::IntoResponse,
     response::{Json, Response},
-    routing::get,
-    BoxError, Extension, Router,
+    routing, BoxError, Extension, Router,
 };
 use common_recordbatch::{util, RecordBatch};
 use common_telemetry::logging::info;
@@ -23,6 +22,8 @@ use crate::error::{Result, StartHttpSnafu};
 use crate::query_handler::SqlQueryHandlerRef;
 use crate::server::Server;
 
+const HTTP_API_VERSION: &str = "v1";
+
 pub struct HttpServer {
     query_handler: SqlQueryHandlerRef,
 }
@@ -116,9 +117,14 @@ impl HttpServer {
 
     pub fn make_app(&self) -> Router {
         Router::new()
-            // handlers
-            .route("/sql", get(handler::sql))
-            .route("/metrics", get(handler::metrics))
+            .nest(
+                &format!("/{}", HTTP_API_VERSION),
+                Router::new()
+                    // handlers
+                    .route("/sql", routing::get(handler::sql))
+                    .route("/scripts", routing::post(handler::scripts)),
+            )
+            .route("/metrics", routing::get(handler::metrics))
             // middlewares
             .layer(
                 ServiceBuilder::new()
@@ -1,7 +1,8 @@
 use std::collections::HashMap;
 
-use axum::extract::{Extension, Query};
+use axum::extract::{Extension, Json, Query};
 use common_telemetry::metric;
+use serde::{Deserialize, Serialize};
 
 use crate::http::{HttpResponse, JsonResponse};
 use crate::query_handler::SqlQueryHandlerRef;
@@ -33,3 +34,23 @@ pub async fn metrics(
         HttpResponse::Text("Prometheus handle not initialized.".to_string())
     }
 }
+
+#[derive(Deserialize, Serialize)]
+pub struct ScriptExecution {
+    pub script: String,
+}
+
+/// Handler to execute scripts
+#[axum_macros::debug_handler]
+pub async fn scripts(
+    Extension(query_handler): Extension<SqlQueryHandlerRef>,
+    Json(payload): Json<ScriptExecution>,
+) -> HttpResponse {
+    if payload.script.is_empty() {
+        return HttpResponse::Json(JsonResponse::with_error(Some("Invalid script".to_string())));
+    }
+
+    HttpResponse::Json(
+        JsonResponse::from_output(query_handler.execute_script(&payload.script).await).await,
+    )
+}
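With the routing change above, the new handler is reachable at POST /{HTTP_API_VERSION}/scripts, i.e. /v1/scripts. A hypothetical client call, assuming reqwest with its json feature, tokio, and serde_json as dependencies, plus a locally bound address; the script body reuses the one from this PR's tests:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Address and port are assumptions; the JSON body must deserialize
    // into `ScriptExecution { script }` above.
    let resp = reqwest::Client::new()
        .post("http://127.0.0.1:3000/v1/scripts")
        .json(&json!({
            "script": "@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])\ndef test(n):\n    return n"
        }))
        .send()
        .await?;
    println!("{}", resp.text().await?);
    Ok(())
}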
@@ -22,6 +22,7 @@ pub type GrpcQueryHandlerRef = Arc<dyn GrpcQueryHandler + Send + Sync>;
 #[async_trait]
 pub trait SqlQueryHandler {
     async fn do_query(&self, query: &str) -> Result<Output>;
+    async fn execute_script(&self, script: &str) -> Result<Output>;
 }
 
 #[async_trait]
@@ -1,10 +1,11 @@
 use std::collections::HashMap;
 
-use axum::extract::Query;
+use axum::extract::{Json, Query};
 use axum::Extension;
 use common_telemetry::metric;
 use metrics::counter;
 use servers::http::handler as http_handler;
+use servers::http::handler::ScriptExecution;
 use servers::http::{HttpResponse, JsonOutput};
 use test_util::MemTable;
 
@@ -70,6 +71,41 @@ async fn test_metrics() {
     }
 }
 
+#[tokio::test]
+async fn test_scripts() {
+    common_telemetry::init_default_ut_logging();
+
+    let exec = create_script_payload();
+    let query_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
+    let extension = Extension(query_handler);
+
+    let json = http_handler::scripts(extension, exec).await;
+    match json {
+        HttpResponse::Json(json) => {
+            assert!(json.success(), "{:?}", json);
+            assert!(json.error().is_none());
+            match json.output().expect("assertion failed") {
+                JsonOutput::Rows(rows) => {
+                    assert_eq!(1, rows.len());
+                }
+                _ => unreachable!(),
+            }
+        }
+        _ => unreachable!(),
+    }
+}
+
+fn create_script_payload() -> Json<ScriptExecution> {
+    Json(ScriptExecution {
+        script: r#"
+@copr(sql='select uint32s as number from numbers', args=['number'], returns=['n'])
+def test(n):
+    return n;
+"#
+        .to_string(),
+    })
+}
+
 fn create_query() -> Query<HashMap<String, String>> {
     Query(HashMap::from([(
         "sql".to_string(),
@@ -12,9 +12,23 @@ use test_util::MemTable;
 
 mod http;
 mod mysql;
+use script::{
+    engine::{CompileContext, EvalContext, Script, ScriptEngine},
+    python::PyEngine,
+};
 
 struct DummyInstance {
     query_engine: QueryEngineRef,
+    py_engine: Arc<PyEngine>,
+}
+
+impl DummyInstance {
+    fn new(query_engine: QueryEngineRef) -> Self {
+        Self {
+            py_engine: Arc::new(PyEngine::new(query_engine.clone())),
+            query_engine,
+        }
+    }
 }
 
 #[async_trait]
@@ -23,6 +37,15 @@ impl SqlQueryHandler for DummyInstance {
         let plan = self.query_engine.sql_to_plan(query).unwrap();
         Ok(self.query_engine.execute(&plan).await.unwrap())
     }
+    async fn execute_script(&self, script: &str) -> Result<Output> {
+        let py_script = self
+            .py_engine
+            .compile(script, CompileContext::default())
+            .await
+            .unwrap();
+
+        Ok(py_script.evaluate(EvalContext::default()).await.unwrap())
+    }
 }
 
 fn create_testing_sql_query_handler(table: MemTable) -> SqlQueryHandlerRef {
@@ -38,5 +61,5 @@ fn create_testing_sql_query_handler(table: MemTable) -> SqlQueryHandlerRef {
 
     let factory = QueryEngineFactory::new(catalog_list);
     let query_engine = factory.query_engine().clone();
-    Arc::new(DummyInstance { query_engine })
+    Arc::new(DummyInstance::new(query_engine))
 }
@@ -322,7 +322,7 @@ struct RegionInner<S: LogStore> {
 impl<S: LogStore> RegionInner<S> {
     #[inline]
     fn version_control(&self) -> &VersionControl {
-        self.shared.version_control.as_ref()
+        &self.shared.version_control
     }
 
     fn in_memory_metadata(&self) -> RegionMetaImpl {