Compare commits

..

71 Commits

Author SHA1 Message Date
Lei, HUANG
a8630cdb38 fix: clippy errors 2022-12-15 18:12:05 +08:00
Ruihang Xia
0f3dcc1b38 fix: Fix All The Tests! (#752)
* fix: Fix several tests compile errors

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix: some compile errors in tests

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix: compile errors in frontend tests

* fix: compile errors in frontend tests

* test: Fix tests in api and common-query

* test: Fix test in sql crate

* fix: resolve substrait error

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* chore: add more test

* test: Fix tests in servers

* fix instance_test

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* test: Fix tests in tests-integration

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: Lei, HUANG <mrsatangel@gmail.com>
Co-authored-by: evenyag <realevenyag@gmail.com>
2022-12-15 17:47:14 +08:00
evenyag
7c696dae08 Merge branch 'develop' into replace-arrow2 2022-12-15 15:29:35 +08:00
LFC
61d8bc2ea1 refactor(frontend): minor changes around FrontendInstance constructor (#748)
* refactor: minor changes in some testing codes

Co-authored-by: luofucong <luofucong@greptime.com>
2022-12-15 14:34:40 +08:00
Yingwen
142dee41d6 fix: Fix compiler errors in script crate (#749)
* fix: Fix compiler errors in state.rs

* fix: fix compiler errors in state

* feat: upgrade sqlparser to 0.26

* fix: fix datafusion engine compiler errors

* fix: Fix some tests in query crate

* fix: Fix all warnings in tests

* feat: Remove `Type` from timestamp's type name

* fix: fix query tests

Now that datafusion supports median, this commit also removes the median function

* style: Fix clippy

* feat: Remove RecordBatch::pretty_print

* chore: Address CR comments

* feat: Add column_by_name to RecordBatch

* feat: modify select_from_rb

* feat: Fix some compiler errors in vector.rs

* feat: Fix more compiler errors in vector.rs

* fix: fix table.rs

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix: Fix compiler errors in coprocessor

* fix: Fix some compiler errors

* fix: Fix compiler errors in script

* chore: Remove unused imports and format code

* test: disable interval tests

* test: Fix test_compile_execute test

* style: Fix clippy

* feat: Support interval

* feat: Add RecordBatch::columns and fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-15 14:20:35 +08:00
Ruihang Xia
e3785fca70 docs: change logo in readme automatically based on github theme (#743)
* docs: adaptive logo on theme

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* switch logos

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* align center

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* adjust style

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* use new logo image

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-14 19:32:51 +08:00
Lei, HUANG
ce6d1cb7d1 fix: frontend compile errors (#747)
fix: fix compile errors in frontend
2022-12-14 18:30:16 +08:00
Yingwen
dbb3034ecb fix: Fix compiler errors in query crate (#746)
* fix: Fix compiler errors in state.rs

* fix: fix compiler errors in state

* feat: upgrade sqlparser to 0.26

* fix: fix datafusion engine compiler errors

* fix: Fix some tests in query crate

* fix: Fix all warnings in tests

* feat: Remove `Type` from timestamp's type name

* fix: fix query tests

Now that datafusion supports median, this commit also removes the median function

* style: Fix clippy

* feat: Remove RecordBatch::pretty_print

* chore: Address CR comments

* Update src/query/src/query_engine/state.rs

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-14 17:42:07 +08:00
shuiyisong
fda9e80cbf feat: impl static_user_provider (#739)
* feat: add MemUserProvider and impl auth

* feat: impl user_provider option in fe and standalone mode

* chore: add file impl for mem provider

* chore: remove mem opts

* chore: minor change

* chore: refac pg server to use user_provider as indicator for using pwd auth

* chore: fix test

* chore: extract common code

* chore: add unit test

* chore: rebase develop

* chore: add user provider to http server

* chore: minor rename

* chore: change to ref when convert to anymap

* chore: fix according to clippy

* chore: remove clone on startcommand

* chore: fix cr issue

* chore: update tempdir use

* chore: change TryFrom to normal func while parsing anymap

* chore: minor change

* chore: remove to_lowercase
2022-12-14 16:38:29 +08:00
Lei, HUANG
756c068166 feat: logstore compaction (#740)
* feat: add benchmark for wal

* add bin

* feat: impl wal compaction

* chore: This reverts commit ef9f2326

* chore: This reverts commit 9142ec0e

* fix: remove empty files

* fix: failing tests

* fix: CR comments

* fix: Mark log as stable after writer applies manifest

* fix: some cr comments and namings

* chore: rename all stable_xxx to obsolete_xxx

* chore: error message
2022-12-14 16:15:29 +08:00
Lei, HUANG
652d59a643 fix: remove unwrap 2022-12-13 17:51:14 +08:00
Lei, HUANG
fa971c6513 fix: errors in optimizer 2022-12-13 17:44:37 +08:00
evenyag
36c929e1a7 fix: Fix imports in optimizer.rs 2022-12-13 17:27:44 +08:00
dennis zhuang
6a4e2e5975 feat: promql create and skeleton (#720)
* feat: adds promql crate

* feat: adds promql-parser dependency and rfc doc

* fix: dependencies order in servers crate

* fix: forgot error.rs

* fix: comment

* fix: license header

* fix: remove docs/rfc/20221207_promql.md
2022-12-13 17:08:22 +08:00
Ruihang Xia
a712382fba Merge pull request #745
* fix nyc-taxi and util

* Merge branch 'replace-arrow2' into fix-others

* fix substrait

* fix warnings and error in test
2022-12-13 16:59:28 +08:00
Yingwen
4b644aa482 fix: Fix compiler errors in catalog and mito crates (#742)
* fix: Fix compiler errors in mito

* fix: Fix compiler errors in catalog crate

* style: Fix clippy

* chore: Fix use
2022-12-13 15:53:55 +08:00
Lei, HUANG
9ad6ddb26e fix: remove useless metaclient field from datanode Instance (#744) 2022-12-13 14:26:26 +08:00
Lei, HUANG
4defde055c feat: upgrade storage crate to arrow and parquet official impl (#738)
* fix: compile errors

* fix: parquet reader and writer

* fix: parquet reader and writer

* fix: WriteBatch IPC encode/decode

* fix: clippy errors in storage subcrate

* chore: remove suspicious unwrap

* fix: some cr comments

* fix: CR comments

* fix: CR comments
2022-12-13 11:58:50 +08:00
fys
c5661ee362 feat: support http basic authentication (#733)
* feat: support http auth

* add some unit test and log

* fix

* cr

* remove unused #[derive(Clone)]
2022-12-13 10:44:33 +08:00
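For context, the HTTP basic authentication added above follows the standard scheme: the client sends an Authorization header carrying a base64-encoded "user:password" pair. A minimal client-side sketch (illustrative only, not GreptimeDB's own API; the base64 crate's encode function is assumed):

// Illustrative only: builds the standard `Authorization: Basic ...` header value that a
// server with basic auth enabled would validate. Assumes the `base64` crate (0.13-style API).
fn basic_auth_header(user: &str, password: &str) -> String {
    let credentials = base64::encode(format!("{user}:{password}"));
    format!("Basic {credentials}")
}

fn main() {
    // Prints: Authorization: Basic dXNlcjpwYXNzd29yZA==
    println!("Authorization: {}", basic_auth_header("user", "password"));
}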
zyy17
9b093463cc feat: add Makefile to aggregate the commands that developers always use (#736)
* feat: add Makefile to aggregate the commands that developers always use

* refactor: add 'clean' and 'unit-test' target

* refactor: add sqlness-test target and modify some descriptions format

Signed-off-by: zyy17 <zyylsxm@gmail.com>
2022-12-12 13:03:49 +08:00
zyy17
61e0f1a11c refactor: add tls option in frontend cli options (#735)
* refactor: add tls option in frontend cli options

* fix: add 'Eq' trait for fixing clippy error

* fix: remove redundant clone

Signed-off-by: zyy17 <zyylsxm@gmail.com>
2022-12-12 10:02:17 +08:00
Ruihang Xia
95b2d8654f fix: pre-cast to avoid tremendous match arms (#734)
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-09 17:20:03 +08:00
Ning Sun
249ebc6937 feat: update pgwire and refactor pg auth handler (#732) 2022-12-09 17:01:55 +08:00
Ruihang Xia
42fdc7251a fix: Fix common grpc expr (#730)
* fix compile errors

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* rename fn names

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix styles

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix warnings in common-time

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-09 14:24:04 +08:00
elijah
c1b8981f61 refactor(mito): change the table path to schema/table_id (#728)
refactor: change the table path to `schema/table_id`
2022-12-09 12:59:16 +08:00
Jiachun Feng
949cd3e3af feat: move_value & delete_route (#707)
* feat: move_value & delete_route

* chore: minor refactor

* chore: refactor unit test of metaclient

* chore: map to kv

* Update src/meta-srv/src/service/router.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

* Update src/meta-srv/src/service/router.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

* chore: by code review

Co-authored-by: Yingwen <realevenyag@gmail.com>
2022-12-09 11:07:48 +08:00
SSebo
b26982c5d7 feat: support timestamp new syntax (#697)
* feat: support timestamp new syntax

* fix: not null at end of new time stamp index syntax

* chore: simplify code
2022-12-09 10:52:14 +08:00
Ruihang Xia
d0892bf0b7 fix: Fix compile error in server subcrate (#727)
* fix: Fix compile error in server subcrate

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* remove unused type alias

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* explicitly panic

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* Update src/storage/src/sst/parquet.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>
2022-12-08 20:27:53 +08:00
Ruihang Xia
fff530cb50 fix common record batch
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-08 17:58:53 +08:00
Yingwen
b936d8b18a fix: Fix common::grpc compiler errors (#722)
* fix: Fix common::grpc compiler errors

This commit refactors RecordBatch to hold vectors in the RecordBatch
struct, so we don't need to cast arrays to vectors when doing
serialization or iterating the batch.

Now we use the vector API instead of the arrow API in the grpc crate.

* chore: Address CR comments
2022-12-08 17:51:20 +08:00
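The refactor described in this commit can be summarized with a small, self-contained sketch (illustrative stand-ins, not the actual GreptimeDB types): the batch owns decoded column vectors next to the column names, so callers read columns through the vector API instead of downcasting arrow arrays.

use std::sync::Arc;

// Stand-in for the vector trait; illustrative only.
pub trait Vector {
    fn len(&self) -> usize;
}
pub type VectorRef = Arc<dyn Vector>;

// Sketch of a RecordBatch that holds vectors directly, as the commit message describes.
pub struct RecordBatch {
    column_names: Vec<String>,
    columns: Vec<VectorRef>,
}

impl RecordBatch {
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    // `column_by_name` mirrors the accessor mentioned elsewhere in this commit list.
    pub fn column_by_name(&self, name: &str) -> Option<&VectorRef> {
        let idx = self.column_names.iter().position(|n| n == name)?;
        self.columns.get(idx)
    }
}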
Lei, HUANG
1bde1ba399 fix: row group pruning (#725)
* fix: row group pruning

* chore: use macro to simplify stats implementation

* fix: CR comments

* fix: row group metadata length mismatch

* fix: simplify code
2022-12-08 17:44:04 +08:00
Ruihang Xia
3687bc7346 fix: Fix tests and clippy for common-function subcrate (#726)
* further fixing

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix all compile errors in common function

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix tests

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* revert test changes

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-08 17:01:54 +08:00
fys
4fdf26810c feat: support auth in frontend (#688)
* feat: add UserProvider trait

* chore: minor fix

* support pg mysql

* refactor and add some logs

* chore: add license

Co-authored-by: shuiyisong <xixing.sys@gmail.com>
2022-12-08 11:51:52 +08:00
Ruihang Xia
587bdc9800 fix: fix other compile error in common-function (#719)
* further fixing

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix all compile errors in common function

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-08 11:38:07 +08:00
dennis zhuang
7f59758e69 feat: bump opendal version to 0.22 (#721)
* feat: bump opendal version to 0.22

* fix: LoggingLayer
2022-12-08 11:19:21 +08:00
Yingwen
58c26def6b fix: fix argmin/percentile/clip/interp/scipy_stats_norm_pdf errors (#718)
fix: fix argmin/percentile/clip/interp/scipy_stats_norm_pdf compiler errors
2022-12-07 19:55:07 +08:00
Ruihang Xia
6f3baf96b0 fix: fix compile error for mean/polyval/pow/interp ops (#717)
* fix: fix compile error for mean/polyval/pow/interp ops

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* simplify type bounds

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-07 16:38:43 +08:00
Yingwen
a898f846d1 fix: Fix compiler errors in argmax/rate/median/norm_cdf (#716)
* fix: Fix compiler errors in argmax/rate/median/norm_cdf

* chore: Address CR comments
2022-12-07 15:28:27 +08:00
Ruihang Xia
a562199455 Revert "fix: fix compile error for mean/polyval/pow/interp ops"
This reverts commit fb0b4eb826.
2022-12-07 15:13:58 +08:00
Ruihang Xia
fb0b4eb826 fix: fix compile error for mean/polyval/pow/interp ops
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-07 15:12:28 +08:00
Zheming Li
a521ab5041 fix: set default value when failing to get git info instead of panicking (#696)
fix: set default value when failing to get git info instead of panicking
2022-12-07 13:16:27 +08:00
LFC
833216d317 refactor: directly invoke Datanode methods in standalone mode (part 1) (#694)
* refactor: directly invoke Datanode methods in standalone mode

* test: add more unit tests

* fix: get rid of `println` in testing codes

* fix: resolve PR comments

* fix: resolve PR comments

Co-authored-by: luofucong <luofucong@greptime.com>
2022-12-07 11:37:59 +08:00
Yingwen
2ba99259e1 feat: Implements diff accumulator using WrapperType (#715)
* feat: Remove usage of opaque error from common::recordbatch

* feat: Remove opaque error from common::query

* feat: Fix diff compiler errors

Now common_function just uses common_query's Error and Result. Adds
a LargestType associated type to LogicalPrimitiveType to get the largest
type a logical primitive type can cast to.

* feat: Remove LargestType from NativeType trait

* chore: Update comments

* feat: Restrict Scalar::RefType of WrapperType to itself

Add trait bound `for<'a> Scalar<RefType<'a> = Self>` to WrapperType

* chore: Address CR comments

* chore: Format codes
2022-12-07 11:13:24 +08:00
Ruihang Xia
551cde23b1 Merge branch 'dev' into replace-arrow2 2022-12-07 10:50:27 +08:00
Ruihang Xia
90c832b33d refactor: drop support of physical plan query interface (#714)
* refactor: drop support of physical plan query interface

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* refactor: collapse server/grpc sub-module

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* refactor: remove unused errors

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-06 19:23:32 +08:00
LFC
8959dbcef8 feat: Substrait logical plan (#704)
* feat: use Substrait logical plan to query data from Datanode in Frontend in distributed mode

* fix: resolve PR comments

* fix: resolve PR comments

* fix: resolve PR comments

Co-authored-by: luofucong <luofucong@greptime.com>
2022-12-06 19:21:57 +08:00
Yingwen
653906d4fa fix: Fix common::query compiler errors (#713)
* feat: Move conversion to ScalarValue to value.rs

* fix: Fix common::query compiler errors

This commit also makes InnerError pub(crate)
2022-12-06 16:45:54 +08:00
Ruihang Xia
829ff491c4 fix: common-query subcrate (#712)
* fix: record batch adapter

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix error enum

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-06 16:32:52 +08:00
Yingwen
b32438e78c feat: Fix some compiler errors in common::query (#710)
* feat: Fix some compiler errors in common::query

* feat: test_collect use vectors api
2022-12-06 15:32:12 +08:00
Lei, HUANG
0ccb8b4302 chore: delete datatypes based on arrow2 2022-12-06 15:01:57 +08:00
Lei, HUANG
b48ae21b71 fix: api crate (#708)
* fix: rename ConcreteDataType::timestamp_millis_type to ConcreteDataType::timestamp_millisecond_type. fix other warnings regarding timestamp

* fix: revert changes in datatypes2

* fix: helper
2022-12-06 14:56:59 +08:00
discord9
2034b40f33 chore: update RustPython dependence(With a tweaked fork) (#655)
* refactor: update RsPy

* depend: add `rustpython-pylib`

* feat: add_frozen stdlib for every vm init

* feat: limit stdlib to a selected few

* chore: use `rev` instead of `branch` in the dependency

* refactor: rename to allow_list

* feat: use opt level one

* doc: add username for TODO&change optimize to 0

* style: fmt .toml
2022-12-06 14:15:00 +08:00
evenyag
3c0adb00f3 feat: Fix recordbatch test compiling issue 2022-12-06 12:03:06 +08:00
evenyag
8c66b7d000 feat: Fix common::recordbatch compiler errors 2022-12-06 11:55:19 +08:00
evenyag
99371fd31b chore: sort Cargo.toml 2022-12-06 11:39:15 +08:00
evenyag
fe505fecfd feat: Make recordbatch compile 2022-12-06 11:38:59 +08:00
SSebo
55e6be7af1 fix: test_server_require_secure_client_secure (#701) 2022-12-06 10:38:54 +08:00
discord9
f9bfb121db feat: add rate() udf (#508)
* feat: rewrite `rate` UDF

* feat: rename to `prom_rate`

* refactor: solve conflict&add license

* refactor: import arrow
2022-12-06 10:30:13 +08:00
evenyag
cc1ec26416 feat: Switch to datatypes2 2022-12-05 20:30:47 +08:00
Ruihang Xia
504059a699 chore: fix wrong merge commit
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-05 20:11:22 +08:00
Ruihang Xia
7151deb4ed Merge branch 'dev' into replace-arrow2 2022-12-05 20:10:37 +08:00
Ruihang Xia
6fb413ae50 ci: add toml format linter (#706)
* chore: run taplo format

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* ci: add workflow to check toml

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* rerun formatter with indent set to 4 spaces

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* update check command

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-12-05 20:03:10 +08:00
Ruihang Xia
beb07fc895 feat: new datatypes subcrate based on the official arrow (#705)
* feat: Init datatypes2 crate

* chore: Remove some unimplemented types

* feat: Implements PrimitiveType and PrimitiveVector for datatypes2 (#633)

* feat: Implement primitive types and vectors

* feat: Implement a wrapper type

* feat: Remove VectorType from ScalarRef

* feat: Move some trait bound from NativeType to WrapperType

* feat: pub use  primitive vectors and builders

* feat: Returns error in try_from when type mismatch

* feat: Impl PartialEq for some vectors

* test: Pass vector tests

* chore: Add license header

* test: Pass more vector tests

* feat: Implement some methods of vector Helper

* test: Pass more tests

* style: Fix clippy

* chore: Add license header

* feat: Remove IntoValueRef trait

* feat: Add NativeType trait bound to WrapperType::Native

* docs: Explain what is wrapper type

* chore: Fix typos

* refactor: LogicalPrimitiveType::type_name returns str

* feat: Implements DateType and DateVector (#651)

* feat: Implement DateType and DateVector

* test: Pass more value and data type tests

* chore: Address CR comments

* test: Skip list value test

* feat: datatypes2 datetime (#661)

* feat: impl DateTime type and vector

* fix: add license header

* fix: CR comments and add more tests

* fix: customized serialization for wrapper type

* feat: Implements NullType and NullVector (#658)

* feat: Implements NullType and NullVector

* chore: Address CR comment

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>

* chore: Address CR comment

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>

* feat: Implements StringType and StringVector (#659)

* feat: implement string vector

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* add more test and from

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix clippy

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* cover NUL

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* feat: impl datatypes2/timestamp (#686)

* feat: add timestamp datatype and vectors

* fix: cr comments and reformat code

* chore: add some tests

* feat: Implements ListType and ListVector (#681)

* feat: Implement ListType and ListVector

* test: Pass more tests

* style: Fix clippy

* chore: Fix comment

* chore: Address CR comments

* feat: impl constant vector (#680)

* feat: impl constant vector

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* fix tests

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* Apply suggestions from code review

Co-authored-by: Yingwen <realevenyag@gmail.com>

* rename fn names

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

* remove println

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>

* feat: Implements Validity (#684)

* feat: Implements Validity

* chore: remove pub from sub mod in vectors

* feat: Implements schema for datatypes2 (#695)

* feat: Add is_timestamp_compatible to DataType

* feat: Implement ColumnSchema and Schema

* feat: Impl RawSchema

* chore: Remove useless codes and run more tests

* chore: Fix clippy

* feat: Impl from_arrow_time_unit and pass schema tests

* chore: add more tests for timestamp (#702)

* chore: add more tests for timestamp

* chore: add replicate test for timestamps

* feat: Implements helper methods for vectors/values (#703)

* feat: Implement helper methods for vectors/values

* chore: Address CR comments

* chore: add more test for timestamp

Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
Co-authored-by: evenyag <realevenyag@gmail.com>
Co-authored-by: Lei, HUANG <6406592+v0y4g3r@users.noreply.github.com>
Co-authored-by: Lei, HUANG <mrsatangel@gmail.com>
2022-12-05 19:59:23 +08:00
Ning Sun
4275e47bdb refactor: use updated mysql_async client (#698) 2022-12-05 11:18:32 +08:00
dennis zhuang
6720bc5f7c fix: validate create table request in mito engine (#690)
* fix: validate create table request in mito engine

* fix: comment

* chore: remove TIMESTAMP_INDEX in system.rs
2022-12-05 11:01:43 +08:00
Ruihang Xia
d0686f9c19 Merge branch 'replace-arrow2' of github.com:GreptimeTeam/greptimedb into replace-arrow2 2022-11-21 17:43:40 +08:00
Ruihang Xia
221f3e9d2e Merge branch 'dev' into replace-arrow2 2022-11-21 17:42:15 +08:00
evenyag
61c4a3691a chore: update dep of binary vector 2022-11-21 15:55:07 +08:00
evenyag
d7626fd6af feat: arrow_array switch to arrow 2022-11-21 15:39:41 +08:00
Ruihang Xia
e3201a4705 chore: replace one last datafusion dep
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-11-21 14:29:59 +08:00
Ruihang Xia
571a84d91b chore: kick off. change datafusion/arrow/parquet to target version
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2022-11-21 14:19:39 +08:00
332 changed files with 13293 additions and 10806 deletions

View File

@@ -49,6 +49,23 @@ jobs:
- name: Run cargo check - name: Run cargo check
run: cargo check --workspace --all-targets run: cargo check --workspace --all-targets
toml:
name: Toml Check
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@master
with:
toolchain: ${{ env.RUST_TOOLCHAIN }}
- name: Rust Cache
uses: Swatinem/rust-cache@v2
- name: Install taplo
run: cargo install taplo-cli --version ^0.8 --locked
- name: Run taplo
run: taplo format --check --option "indent_string= "
# Use coverage to run test. # Use coverage to run test.
# test: # test:
# name: Test Suite # name: Test Suite

Cargo.lock (generated, 2484 changed lines): file diff suppressed because it is too large.

View File

@@ -26,6 +26,7 @@ members = [
"src/meta-srv", "src/meta-srv",
"src/mito", "src/mito",
"src/object-store", "src/object-store",
"src/promql",
"src/query", "src/query",
"src/script", "src/script",
"src/servers", "src/servers",
@@ -34,9 +35,9 @@ members = [
"src/storage", "src/storage",
"src/store-api", "src/store-api",
"src/table", "src/table",
"tests-integration" "tests-integration",
, "tests/runner",
"tests/runner"] ]
[profile.release] [profile.release]
debug = true debug = true

Makefile (new file, 67 lines)
View File

@@ -0,0 +1,67 @@
IMAGE_REGISTRY ?= greptimedb
IMAGE_TAG ?= latest
##@ Build
.PHONY: build
build: ## Build debug version greptime.
cargo build
.PHONY: release
release: ## Build release version greptime.
cargo build --release
.PHONY: clean
clean: ## Clean the project.
cargo clean
.PHONY: fmt
fmt: ## Format all the Rust code.
cargo fmt --all
.PHONY: docker-image
docker-image: ## Build docker image.
docker build --network host -f docker/Dockerfile -t ${IMAGE_REGISTRY}:${IMAGE_TAG} .
##@ Test
.PHONY: unit-test
unit-test: ## Run unit test.
cargo test --workspace
.PHONY: integration-test
integration-test: ## Run integration test.
cargo test integration
.PHONY: sqlness-test
sqlness-test: ## Run sqlness test.
cargo run --bin sqlness-runner
.PHONY: check
check: ## Cargo check all the targets.
cargo check --workspace --all-targets
.PHONY: clippy
clippy: ## Check clippy rules.
cargo clippy --workspace --all-targets -- -D warnings -D clippy::print_stdout -D clippy::print_stderr
.PHONY: fmt-check
fmt-check: ## Check code format.
cargo fmt --all -- --check
##@ General
# The help target prints out all targets with their descriptions organized
# beneath their categories. The categories are represented by '##@' and the
# target descriptions by '##'. The awk command is responsible for reading the
# entire set of makefiles included in this invocation, looking for lines of the
# file as xyz: ## something, and then pretty-format the target and help. Then,
# if there's a line with ##@ something, that gets pretty-printed as a category.
# More info on the usage of ANSI control characters for terminal formatting:
# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters
# More info on the awk command:
# https://linuxcommand.org/lc3_adv_awk.php
.PHONY: help
help: ## Display help messages.
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)

View File

@@ -1,7 +1,12 @@
<p align="center"> <p align="center">
<img src="/docs/logo-text-padding.png" alt="GreptimeDB Logo" width="400px"></img> <picture>
<source media="(prefers-color-scheme: light)" srcset="/docs/logo-text-padding.png">
<source media="(prefers-color-scheme: dark)" srcset="/docs/logo-text-padding-dark.png">
<img alt="GreptimeDB Logo" src="/docs/logo-text-padding.png" width="400px">
</picture>
</p> </p>
<h3 align="center"> <h3 align="center">
The next-generation hybrid timeseries/analytics processing database in the cloud The next-generation hybrid timeseries/analytics processing database in the cloud
</h3> </h3>

View File

@@ -5,10 +5,10 @@ edition = "2021"
license = "Apache-2.0" license = "Apache-2.0"
[dependencies] [dependencies]
arrow = "10" arrow = "26.0.0"
clap = { version = "4.0", features = ["derive"] } clap = { version = "4.0", features = ["derive"] }
client = { path = "../src/client" } client = { path = "../src/client" }
indicatif = "0.17.1" indicatif = "0.17.1"
itertools = "0.10.5" itertools = "0.10.5"
parquet = { version = "*" } parquet = "26.0.0"
tokio = { version = "1.21", features = ["full"] } tokio = { version = "1.21", features = ["full"] }

View File

@@ -20,7 +20,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use arrow::array::{ArrayRef, PrimitiveArray, StringArray, TimestampNanosecondArray}; use arrow::array::{ArrayRef, PrimitiveArray, StringArray, TimestampNanosecondArray};
@@ -32,9 +31,7 @@ use client::api::v1::column::Values;
use client::api::v1::{Column, ColumnDataType, ColumnDef, CreateExpr, InsertExpr}; use client::api::v1::{Column, ColumnDataType, ColumnDef, CreateExpr, InsertExpr};
use client::{Client, Database, Select}; use client::{Client, Database, Select};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use parquet::arrow::{ArrowReader, ParquetFileArrowReader}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::file::reader::FileReader;
use parquet::file::serialized_reader::SerializedFileReader;
use tokio::task::JoinSet; use tokio::task::JoinSet;
const DATABASE_NAME: &str = "greptime"; const DATABASE_NAME: &str = "greptime";
@@ -86,10 +83,14 @@ async fn write_data(
pb_style: ProgressStyle, pb_style: ProgressStyle,
) -> u128 { ) -> u128 {
let file = std::fs::File::open(&path).unwrap(); let file = std::fs::File::open(&path).unwrap();
let file_reader = Arc::new(SerializedFileReader::new(file).unwrap()); let record_batch_reader_builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();
let row_num = file_reader.metadata().file_metadata().num_rows(); let row_num = record_batch_reader_builder
let record_batch_reader = ParquetFileArrowReader::new(file_reader) .metadata()
.get_record_reader(batch_size) .file_metadata()
.num_rows();
let record_batch_reader = record_batch_reader_builder
.with_batch_size(batch_size)
.build()
.unwrap(); .unwrap();
let progress_bar = mpb.add(ProgressBar::new(row_num as _)); let progress_bar = mpb.add(ProgressBar::new(row_num as _));
progress_bar.set_style(pb_style); progress_bar.set_style(pb_style);
@@ -210,9 +211,10 @@ fn build_values(column: &ArrayRef) -> Values {
| DataType::FixedSizeList(_, _) | DataType::FixedSizeList(_, _)
| DataType::LargeList(_) | DataType::LargeList(_)
| DataType::Struct(_) | DataType::Struct(_)
| DataType::Union(_, _) | DataType::Union(_, _, _)
| DataType::Dictionary(_, _) | DataType::Dictionary(_, _)
| DataType::Decimal(_, _) | DataType::Decimal128(_, _)
| DataType::Decimal256(_, _)
| DataType::Map(_, _) => todo!(), | DataType::Map(_, _) => todo!(),
} }
} }
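The reader change above moves the benchmark from parquet's old SerializedFileReader/ParquetFileArrowReader usage to the builder-based arrow reader. A standalone sketch of that reading pattern (the file path and batch size are placeholders; error handling is collapsed into unwrap as in the benchmark):

// Minimal sketch of the parquet 26 arrow reader API used in the diff above.
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

fn main() {
    let file = std::fs::File::open("/path/to/data.parquet").unwrap();
    let builder = ParquetRecordBatchReaderBuilder::try_new(file).unwrap();

    // Row count comes from the parquet footer, matching the progress bar setup above.
    let row_num = builder.metadata().file_metadata().num_rows();
    println!("rows: {row_num}");

    let reader = builder.with_batch_size(1000).build().unwrap();
    for batch in reader {
        let batch = batch.unwrap();
        println!("read a batch of {} rows", batch.num_rows());
    }
}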

Binary file not shown (image, 25 KiB after change).

View File

@@ -21,7 +21,6 @@ fn main() {
.compile( .compile(
&[ &[
"greptime/v1/select.proto", "greptime/v1/select.proto",
"greptime/v1/physical_plan.proto",
"greptime/v1/greptime.proto", "greptime/v1/greptime.proto",
"greptime/v1/meta/common.proto", "greptime/v1/meta/common.proto",
"greptime/v1/meta/heartbeat.proto", "greptime/v1/meta/heartbeat.proto",

View File

@@ -32,7 +32,10 @@ message Column {
repeated int32 date_values = 14; repeated int32 date_values = 14;
repeated int64 datetime_values = 15; repeated int64 datetime_values = 15;
repeated int64 ts_millis_values = 16; repeated int64 ts_second_values = 16;
repeated int64 ts_millisecond_values = 17;
repeated int64 ts_microsecond_values = 18;
repeated int64 ts_nanosecond_values = 19;
} }
// The array of non-null values in this column. // The array of non-null values in this column.
// //
@@ -75,5 +78,8 @@ enum ColumnDataType {
STRING = 12; STRING = 12;
DATE = 13; DATE = 13;
DATETIME = 14; DATETIME = 14;
TIMESTAMP = 15; TIMESTAMP_SECOND = 15;
TIMESTAMP_MILLISECOND = 16;
TIMESTAMP_MICROSECOND = 17;
TIMESTAMP_NANOSECOND = 18;
} }
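Each timestamp unit now gets its own repeated value field and ColumnDataType variant, so a writer fills exactly one of them per column. A hedged sketch of building such a column on the client side (type paths follow the benchmark imports earlier in this diff; the helper itself is illustrative):

// Illustrative only: builds a millisecond-precision timestamp column using the new
// per-unit fields above. Any unmentioned prost fields fall back to their defaults.
use client::api::v1::column::Values;
use client::api::v1::{Column, ColumnDataType};

fn timestamp_millis_column(name: &str, millis: Vec<i64>) -> Column {
    Column {
        column_name: name.to_string(),
        values: Some(Values {
            ts_millisecond_values: millis,
            ..Default::default()
        }),
        datatype: ColumnDataType::TimestampMillisecond as i32,
        ..Default::default()
    }
}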

View File

@@ -29,15 +29,9 @@ message SelectExpr {
oneof expr { oneof expr {
string sql = 1; string sql = 1;
bytes logical_plan = 2; bytes logical_plan = 2;
PhysicalPlan physical_plan = 15;
} }
} }
message PhysicalPlan {
bytes original_ql = 1;
bytes plan = 2;
}
message InsertExpr { message InsertExpr {
string schema_name = 1; string schema_name = 1;
string table_name = 2; string table_name = 2;

View File

@@ -5,6 +5,8 @@ package greptime.v1.meta;
import "greptime/v1/meta/common.proto"; import "greptime/v1/meta/common.proto";
service Router { service Router {
rpc Create(CreateRequest) returns (RouteResponse) {}
// Fetch routing information for tables. The smallest unit is the complete // Fetch routing information for tables. The smallest unit is the complete
// routing information(all regions) of a table. // routing information(all regions) of a table.
// //
@@ -26,7 +28,14 @@ service Router {
// //
rpc Route(RouteRequest) returns (RouteResponse) {} rpc Route(RouteRequest) returns (RouteResponse) {}
rpc Create(CreateRequest) returns (RouteResponse) {} rpc Delete(DeleteRequest) returns (RouteResponse) {}
}
message CreateRequest {
RequestHeader header = 1;
TableName table_name = 2;
repeated Partition partitions = 3;
} }
message RouteRequest { message RouteRequest {
@@ -35,6 +44,12 @@ message RouteRequest {
repeated TableName table_names = 2; repeated TableName table_names = 2;
} }
message DeleteRequest {
RequestHeader header = 1;
TableName table_name = 2;
}
message RouteResponse { message RouteResponse {
ResponseHeader header = 1; ResponseHeader header = 1;
@@ -42,13 +57,6 @@ message RouteResponse {
repeated TableRoute table_routes = 3; repeated TableRoute table_routes = 3;
} }
message CreateRequest {
RequestHeader header = 1;
TableName table_name = 2;
repeated Partition partitions = 3;
}
message TableRoute { message TableRoute {
Table table = 1; Table table = 1;
repeated RegionRoute region_routes = 2; repeated RegionRoute region_routes = 2;

View File

@@ -20,6 +20,9 @@ service Store {
// DeleteRange deletes the given range from the key-value store. // DeleteRange deletes the given range from the key-value store.
rpc DeleteRange(DeleteRangeRequest) returns (DeleteRangeResponse); rpc DeleteRange(DeleteRangeRequest) returns (DeleteRangeResponse);
// MoveValue atomically renames the key to the given updated key.
rpc MoveValue(MoveValueRequest) returns (MoveValueResponse);
} }
message RangeRequest { message RangeRequest {
@@ -136,3 +139,21 @@ message DeleteRangeResponse {
// returned. // returned.
repeated KeyValue prev_kvs = 3; repeated KeyValue prev_kvs = 3;
} }
message MoveValueRequest {
RequestHeader header = 1;
// If from_key does not exist, return the value of to_key (if it exists).
// If from_key exists, move the value of from_key to to_key (i.e. rename),
// and return the value.
bytes from_key = 2;
bytes to_key = 3;
}
message MoveValueResponse {
ResponseHeader header = 1;
// If from_key does not exist, return the value of to_key (if it exists).
// If from_key exists, return the value of from_key.
KeyValue kv = 2;
}
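MoveValue is an atomic rename with read-back semantics: if from_key exists, its value is moved to to_key and returned; otherwise the current value of to_key (if any) is returned. A hedged sketch of building the request with the generated type (module path assumed; the client that actually sends it is out of scope here):

// Illustrative only: request construction for the new MoveValue RPC. Field names follow
// the proto above; `api::v1::meta` is the assumed module path for the generated types.
use api::v1::meta::MoveValueRequest;

fn move_value_request(from_key: &[u8], to_key: &[u8]) -> MoveValueRequest {
    MoveValueRequest {
        from_key: from_key.to_vec(),
        to_key: to_key.to_vec(),
        // The request header is normally filled in by the client wrapper; note the
        // `gen_set_header!(MoveValueRequest)` addition later in this diff.
        ..Default::default()
    }
}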

View File

@@ -1,33 +0,0 @@
syntax = "proto3";
package greptime.v1.codec;
message PhysicalPlanNode {
oneof PhysicalPlanType {
ProjectionExecNode projection = 1;
MockInputExecNode mock = 99;
// TODO(fys): impl other physical plan node
}
}
message ProjectionExecNode {
PhysicalPlanNode input = 1;
repeated PhysicalExprNode expr = 2;
repeated string expr_name = 3;
}
message PhysicalExprNode {
oneof ExprType {
PhysicalColumn column = 1;
// TODO(fys): impl other physical expr node
}
}
message PhysicalColumn {
string name = 1;
uint64 index = 2;
}
message MockInputExecNode {
string name = 1;
}

View File

@@ -15,6 +15,7 @@
use common_base::BitVec; use common_base::BitVec;
use common_time::timestamp::TimeUnit; use common_time::timestamp::TimeUnit;
use datatypes::prelude::ConcreteDataType; use datatypes::prelude::ConcreteDataType;
use datatypes::types::TimestampType;
use datatypes::value::Value; use datatypes::value::Value;
use datatypes::vectors::VectorRef; use datatypes::vectors::VectorRef;
use snafu::prelude::*; use snafu::prelude::*;
@@ -56,7 +57,16 @@ impl From<ColumnDataTypeWrapper> for ConcreteDataType {
ColumnDataType::String => ConcreteDataType::string_datatype(), ColumnDataType::String => ConcreteDataType::string_datatype(),
ColumnDataType::Date => ConcreteDataType::date_datatype(), ColumnDataType::Date => ConcreteDataType::date_datatype(),
ColumnDataType::Datetime => ConcreteDataType::datetime_datatype(), ColumnDataType::Datetime => ConcreteDataType::datetime_datatype(),
ColumnDataType::Timestamp => ConcreteDataType::timestamp_millis_datatype(), ColumnDataType::TimestampSecond => ConcreteDataType::timestamp_second_datatype(),
ColumnDataType::TimestampMillisecond => {
ConcreteDataType::timestamp_millisecond_datatype()
}
ColumnDataType::TimestampMicrosecond => {
ConcreteDataType::timestamp_microsecond_datatype()
}
ColumnDataType::TimestampNanosecond => {
ConcreteDataType::timestamp_nanosecond_datatype()
}
} }
} }
} }
@@ -81,7 +91,12 @@ impl TryFrom<ConcreteDataType> for ColumnDataTypeWrapper {
ConcreteDataType::String(_) => ColumnDataType::String, ConcreteDataType::String(_) => ColumnDataType::String,
ConcreteDataType::Date(_) => ColumnDataType::Date, ConcreteDataType::Date(_) => ColumnDataType::Date,
ConcreteDataType::DateTime(_) => ColumnDataType::Datetime, ConcreteDataType::DateTime(_) => ColumnDataType::Datetime,
ConcreteDataType::Timestamp(_) => ColumnDataType::Timestamp, ConcreteDataType::Timestamp(unit) => match unit {
TimestampType::Second(_) => ColumnDataType::TimestampSecond,
TimestampType::Millisecond(_) => ColumnDataType::TimestampMillisecond,
TimestampType::Microsecond(_) => ColumnDataType::TimestampMicrosecond,
TimestampType::Nanosecond(_) => ColumnDataType::TimestampNanosecond,
},
ConcreteDataType::Null(_) | ConcreteDataType::List(_) => { ConcreteDataType::Null(_) | ConcreteDataType::List(_) => {
return error::IntoColumnDataTypeSnafu { from: datatype }.fail() return error::IntoColumnDataTypeSnafu { from: datatype }.fail()
} }
@@ -153,8 +168,20 @@ impl Values {
datetime_values: Vec::with_capacity(capacity), datetime_values: Vec::with_capacity(capacity),
..Default::default() ..Default::default()
}, },
ColumnDataType::Timestamp => Values { ColumnDataType::TimestampSecond => Values {
ts_millis_values: Vec::with_capacity(capacity), ts_second_values: Vec::with_capacity(capacity),
..Default::default()
},
ColumnDataType::TimestampMillisecond => Values {
ts_millisecond_values: Vec::with_capacity(capacity),
..Default::default()
},
ColumnDataType::TimestampMicrosecond => Values {
ts_microsecond_values: Vec::with_capacity(capacity),
..Default::default()
},
ColumnDataType::TimestampNanosecond => Values {
ts_nanosecond_values: Vec::with_capacity(capacity),
..Default::default() ..Default::default()
}, },
} }
@@ -187,9 +214,12 @@ impl Column {
Value::Binary(val) => values.binary_values.push(val.to_vec()), Value::Binary(val) => values.binary_values.push(val.to_vec()),
Value::Date(val) => values.date_values.push(val.val()), Value::Date(val) => values.date_values.push(val.val()),
Value::DateTime(val) => values.datetime_values.push(val.val()), Value::DateTime(val) => values.datetime_values.push(val.val()),
Value::Timestamp(val) => values Value::Timestamp(val) => match val.unit() {
.ts_millis_values TimeUnit::Second => values.ts_second_values.push(val.value()),
.push(val.convert_to(TimeUnit::Millisecond)), TimeUnit::Millisecond => values.ts_millisecond_values.push(val.value()),
TimeUnit::Microsecond => values.ts_microsecond_values.push(val.value()),
TimeUnit::Nanosecond => values.ts_nanosecond_values.push(val.value()),
},
Value::List(_) => unreachable!(), Value::List(_) => unreachable!(),
}); });
self.null_mask = null_mask.into_vec(); self.null_mask = null_mask.into_vec();
@@ -200,7 +230,10 @@ impl Column {
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use datatypes::vectors::BooleanVector; use datatypes::vectors::{
BooleanVector, TimestampMicrosecondVector, TimestampMillisecondVector,
TimestampNanosecondVector, TimestampSecondVector,
};
use super::*; use super::*;
@@ -258,8 +291,8 @@ mod tests {
let values = values.datetime_values; let values = values.datetime_values;
assert_eq!(2, values.capacity()); assert_eq!(2, values.capacity());
let values = Values::with_capacity(ColumnDataType::Timestamp, 2); let values = Values::with_capacity(ColumnDataType::TimestampMillisecond, 2);
let values = values.ts_millis_values; let values = values.ts_millisecond_values;
assert_eq!(2, values.capacity()); assert_eq!(2, values.capacity());
} }
@@ -326,8 +359,8 @@ mod tests {
ColumnDataTypeWrapper(ColumnDataType::Datetime).into() ColumnDataTypeWrapper(ColumnDataType::Datetime).into()
); );
assert_eq!( assert_eq!(
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
ColumnDataTypeWrapper(ColumnDataType::Timestamp).into() ColumnDataTypeWrapper(ColumnDataType::TimestampMillisecond).into()
); );
} }
@@ -394,8 +427,8 @@ mod tests {
ConcreteDataType::datetime_datatype().try_into().unwrap() ConcreteDataType::datetime_datatype().try_into().unwrap()
); );
assert_eq!( assert_eq!(
ColumnDataTypeWrapper(ColumnDataType::Timestamp), ColumnDataTypeWrapper(ColumnDataType::TimestampMillisecond),
ConcreteDataType::timestamp_millis_datatype() ConcreteDataType::timestamp_millisecond_datatype()
.try_into() .try_into()
.unwrap() .unwrap()
); );
@@ -412,7 +445,48 @@ mod tests {
assert!(result.is_err()); assert!(result.is_err());
assert_eq!( assert_eq!(
result.unwrap_err().to_string(), result.unwrap_err().to_string(),
"Failed to create column datatype from List(ListType { inner: Boolean(BooleanType) })" "Failed to create column datatype from List(ListType { item_type: Boolean(BooleanType) })"
);
}
#[test]
fn test_column_put_timestamp_values() {
let mut column = Column {
column_name: "test".to_string(),
semantic_type: 0,
values: Some(Values {
..Default::default()
}),
null_mask: vec![],
datatype: 0,
};
let vector = Arc::new(TimestampNanosecondVector::from_vec(vec![1, 2, 3]));
column.push_vals(3, vector);
assert_eq!(
vec![1, 2, 3],
column.values.as_ref().unwrap().ts_nanosecond_values
);
let vector = Arc::new(TimestampMillisecondVector::from_vec(vec![4, 5, 6]));
column.push_vals(3, vector);
assert_eq!(
vec![4, 5, 6],
column.values.as_ref().unwrap().ts_millisecond_values
);
let vector = Arc::new(TimestampMicrosecondVector::from_vec(vec![7, 8, 9]));
column.push_vals(3, vector);
assert_eq!(
vec![7, 8, 9],
column.values.as_ref().unwrap().ts_microsecond_values
);
let vector = Arc::new(TimestampSecondVector::from_vec(vec![10, 11, 12]));
column.push_vals(3, vector);
assert_eq!(
vec![10, 11, 12],
column.values.as_ref().unwrap().ts_second_values
); );
} }

View File

@@ -15,7 +15,7 @@
pub use prost::DecodeError; pub use prost::DecodeError;
use prost::Message; use prost::Message;
use crate::v1::codec::{PhysicalPlanNode, SelectResult}; use crate::v1::codec::SelectResult;
use crate::v1::meta::TableRouteValue; use crate::v1::meta::TableRouteValue;
macro_rules! impl_convert_with_bytes { macro_rules! impl_convert_with_bytes {
@@ -37,7 +37,6 @@ macro_rules! impl_convert_with_bytes {
} }
impl_convert_with_bytes!(SelectResult); impl_convert_with_bytes!(SelectResult);
impl_convert_with_bytes!(PhysicalPlanNode);
impl_convert_with_bytes!(TableRouteValue); impl_convert_with_bytes!(TableRouteValue);
#[cfg(test)] #[cfg(test)]

View File

@@ -145,10 +145,12 @@ gen_set_header!(HeartbeatRequest);
gen_set_header!(RouteRequest); gen_set_header!(RouteRequest);
gen_set_header!(CreateRequest); gen_set_header!(CreateRequest);
gen_set_header!(RangeRequest); gen_set_header!(RangeRequest);
gen_set_header!(DeleteRequest);
gen_set_header!(PutRequest); gen_set_header!(PutRequest);
gen_set_header!(BatchPutRequest); gen_set_header!(BatchPutRequest);
gen_set_header!(CompareAndPutRequest); gen_set_header!(CompareAndPutRequest);
gen_set_header!(DeleteRangeRequest); gen_set_header!(DeleteRangeRequest);
gen_set_header!(MoveValueRequest);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {

View File

@@ -19,9 +19,7 @@ common-recordbatch = { path = "../common/recordbatch" }
common-runtime = { path = "../common/runtime" } common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" } common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" } common-time = { path = "../common/time" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ datafusion = "14.0.0"
"simd",
] }
datatypes = { path = "../datatypes" } datatypes = { path = "../datatypes" }
futures = "0.3" futures = "0.3"
futures-util = "0.3" futures-util = "0.3"

View File

@@ -17,7 +17,7 @@ use std::any::Any;
use common_error::ext::{BoxedError, ErrorExt}; use common_error::ext::{BoxedError, ErrorExt};
use common_error::prelude::{Snafu, StatusCode}; use common_error::prelude::{Snafu, StatusCode};
use datafusion::error::DataFusionError; use datafusion::error::DataFusionError;
use datatypes::arrow; use datatypes::prelude::ConcreteDataType;
use datatypes::schema::RawSchema; use datatypes::schema::RawSchema;
use snafu::{Backtrace, ErrorCompat}; use snafu::{Backtrace, ErrorCompat};
@@ -51,14 +51,12 @@ pub enum Error {
SystemCatalog { msg: String, backtrace: Backtrace }, SystemCatalog { msg: String, backtrace: Backtrace },
#[snafu(display( #[snafu(display(
"System catalog table type mismatch, expected: binary, found: {:?} source: {}", "System catalog table type mismatch, expected: binary, found: {:?}",
data_type, data_type,
source
))] ))]
SystemCatalogTypeMismatch { SystemCatalogTypeMismatch {
data_type: arrow::datatypes::DataType, data_type: ConcreteDataType,
#[snafu(backtrace)] backtrace: Backtrace,
source: datatypes::error::Error,
}, },
#[snafu(display("Invalid system catalog entry type: {:?}", entry_type))] #[snafu(display("Invalid system catalog entry type: {:?}", entry_type))]
@@ -222,10 +220,11 @@ impl ErrorExt for Error {
| Error::ValueDeserialize { .. } | Error::ValueDeserialize { .. }
| Error::Io { .. } => StatusCode::StorageUnavailable, | Error::Io { .. } => StatusCode::StorageUnavailable,
Error::RegisterTable { .. } => StatusCode::Internal, Error::RegisterTable { .. } | Error::SystemCatalogTypeMismatch { .. } => {
StatusCode::Internal
}
Error::ReadSystemCatalog { source, .. } => source.status_code(), Error::ReadSystemCatalog { source, .. } => source.status_code(),
Error::SystemCatalogTypeMismatch { source, .. } => source.status_code(),
Error::InvalidCatalogValue { source, .. } => source.status_code(), Error::InvalidCatalogValue { source, .. } => source.status_code(),
Error::TableExists { .. } => StatusCode::TableAlreadyExists, Error::TableExists { .. } => StatusCode::TableAlreadyExists,
@@ -265,7 +264,6 @@ impl From<Error> for DataFusionError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use common_error::mock::MockError; use common_error::mock::MockError;
use datatypes::arrow::datatypes::DataType;
use snafu::GenerateImplicitData; use snafu::GenerateImplicitData;
use super::*; use super::*;
@@ -314,11 +312,8 @@ mod tests {
assert_eq!( assert_eq!(
StatusCode::Internal, StatusCode::Internal,
Error::SystemCatalogTypeMismatch { Error::SystemCatalogTypeMismatch {
data_type: DataType::Boolean, data_type: ConcreteDataType::binary_datatype(),
source: datatypes::error::Error::UnsupportedArrowType { backtrace: Backtrace::generate(),
arrow_type: DataType::Boolean,
backtrace: Backtrace::generate()
}
} }
.status_code() .status_code()
); );

View File

@@ -15,18 +15,19 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{Display, Formatter}; use std::fmt::{Display, Formatter};
use common_catalog::error::{
DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu,
};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Serialize, Serializer}; use serde::{Deserialize, Serialize, Serializer};
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
use table::metadata::{RawTableInfo, TableId, TableVersion}; use table::metadata::{RawTableInfo, TableId, TableVersion};
use crate::consts::{ const CATALOG_KEY_PREFIX: &str = "__c";
CATALOG_KEY_PREFIX, SCHEMA_KEY_PREFIX, TABLE_GLOBAL_KEY_PREFIX, TABLE_REGIONAL_KEY_PREFIX, const SCHEMA_KEY_PREFIX: &str = "__s";
}; const TABLE_GLOBAL_KEY_PREFIX: &str = "__tg";
use crate::error::{ const TABLE_REGIONAL_KEY_PREFIX: &str = "__tr";
DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu,
};
const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*"; const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*";
@@ -137,7 +138,7 @@ impl TableGlobalKey {
/// Table global info contains necessary info for a datanode to create table regions, including /// Table global info contains necessary info for a datanode to create table regions, including
/// table id, table meta(schema...), region id allocation across datanodes. /// table id, table meta(schema...), region id allocation across datanodes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TableGlobalValue { pub struct TableGlobalValue {
/// Id of datanode that created the global table info kv. only for debugging. /// Id of datanode that created the global table info kv. only for debugging.
pub node_id: u64, pub node_id: u64,

View File

@@ -29,6 +29,7 @@ use crate::error::{CreateTableSnafu, Result};
pub use crate::schema::{SchemaProvider, SchemaProviderRef}; pub use crate::schema::{SchemaProvider, SchemaProviderRef};
pub mod error; pub mod error;
pub mod helper;
pub mod local; pub mod local;
pub mod remote; pub mod remote;
pub mod schema; pub mod schema;

View File

@@ -145,27 +145,34 @@ impl LocalCatalogManager {
/// Convert `RecordBatch` to a vector of `Entry`. /// Convert `RecordBatch` to a vector of `Entry`.
fn record_batch_to_entry(rb: RecordBatch) -> Result<Vec<Entry>> { fn record_batch_to_entry(rb: RecordBatch) -> Result<Vec<Entry>> {
ensure!( ensure!(
rb.df_recordbatch.columns().len() >= 6, rb.num_columns() >= 6,
SystemCatalogSnafu { SystemCatalogSnafu {
msg: format!("Length mismatch: {}", rb.df_recordbatch.columns().len()) msg: format!("Length mismatch: {}", rb.num_columns())
} }
); );
let entry_type = UInt8Vector::try_from_arrow_array(&rb.df_recordbatch.columns()[0]) let entry_type = rb
.with_context(|_| SystemCatalogTypeMismatchSnafu { .column(ENTRY_TYPE_INDEX)
data_type: rb.df_recordbatch.columns()[ENTRY_TYPE_INDEX] .as_any()
.data_type() .downcast_ref::<UInt8Vector>()
.clone(), .with_context(|| SystemCatalogTypeMismatchSnafu {
data_type: rb.column(ENTRY_TYPE_INDEX).data_type(),
})?; })?;
let key = BinaryVector::try_from_arrow_array(&rb.df_recordbatch.columns()[1]) let key = rb
.with_context(|_| SystemCatalogTypeMismatchSnafu { .column(KEY_INDEX)
data_type: rb.df_recordbatch.columns()[KEY_INDEX].data_type().clone(), .as_any()
.downcast_ref::<BinaryVector>()
.with_context(|| SystemCatalogTypeMismatchSnafu {
data_type: rb.column(KEY_INDEX).data_type(),
})?; })?;
let value = BinaryVector::try_from_arrow_array(&rb.df_recordbatch.columns()[3]) let value = rb
.with_context(|_| SystemCatalogTypeMismatchSnafu { .column(VALUE_INDEX)
data_type: rb.df_recordbatch.columns()[VALUE_INDEX].data_type().clone(), .as_any()
.downcast_ref::<BinaryVector>()
.with_context(|| SystemCatalogTypeMismatchSnafu {
data_type: rb.column(VALUE_INDEX).data_type(),
})?; })?;
let mut res = Vec::with_capacity(rb.num_rows()); let mut res = Vec::with_capacity(rb.num_rows());

View File

@@ -20,10 +20,6 @@ use std::sync::Arc;
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use async_stream::stream; use async_stream::stream;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID};
use common_catalog::{
build_catalog_prefix, build_schema_prefix, build_table_global_prefix, CatalogKey, CatalogValue,
SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue, TableRegionalKey, TableRegionalValue,
};
use common_telemetry::{debug, info}; use common_telemetry::{debug, info};
use futures::Stream; use futures::Stream;
use futures_util::StreamExt; use futures_util::StreamExt;
@@ -39,6 +35,10 @@ use crate::error::{
CatalogNotFoundSnafu, CreateTableSnafu, InvalidCatalogValueSnafu, InvalidTableSchemaSnafu, CatalogNotFoundSnafu, CreateTableSnafu, InvalidCatalogValueSnafu, InvalidTableSchemaSnafu,
OpenTableSnafu, Result, SchemaNotFoundSnafu, TableExistsSnafu, UnimplementedSnafu, OpenTableSnafu, Result, SchemaNotFoundSnafu, TableExistsSnafu, UnimplementedSnafu,
}; };
use crate::helper::{
build_catalog_prefix, build_schema_prefix, build_table_global_prefix, CatalogKey, CatalogValue,
SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue, TableRegionalKey, TableRegionalValue,
};
use crate::remote::{Kv, KvBackendRef}; use crate::remote::{Kv, KvBackendRef};
use crate::{ use crate::{
handle_system_table_request, CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, handle_system_table_request, CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef,

View File

@@ -21,14 +21,13 @@ use common_catalog::consts::{
SYSTEM_CATALOG_TABLE_ID, SYSTEM_CATALOG_TABLE_NAME, SYSTEM_CATALOG_TABLE_ID, SYSTEM_CATALOG_TABLE_NAME,
}; };
use common_query::logical_plan::Expr; use common_query::logical_plan::Expr;
use common_query::physical_plan::{PhysicalPlanRef, RuntimeEnv}; use common_query::physical_plan::{PhysicalPlanRef, SessionContext};
use common_recordbatch::SendableRecordBatchStream; use common_recordbatch::SendableRecordBatchStream;
use common_telemetry::debug; use common_telemetry::debug;
use common_time::timestamp::Timestamp;
use common_time::util; use common_time::util;
use datatypes::prelude::{ConcreteDataType, ScalarVector}; use datatypes::prelude::{ConcreteDataType, ScalarVector};
use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder, SchemaRef}; use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder, SchemaRef};
use datatypes::vectors::{BinaryVector, TimestampVector, UInt8Vector}; use datatypes::vectors::{BinaryVector, TimestampMillisecondVector, UInt8Vector};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
use table::engine::{EngineContext, TableEngineRef}; use table::engine::{EngineContext, TableEngineRef};
@@ -43,7 +42,6 @@ use crate::error::{
pub const ENTRY_TYPE_INDEX: usize = 0; pub const ENTRY_TYPE_INDEX: usize = 0;
pub const KEY_INDEX: usize = 1; pub const KEY_INDEX: usize = 1;
pub const TIMESTAMP_INDEX: usize = 2;
pub const VALUE_INDEX: usize = 3; pub const VALUE_INDEX: usize = 3;
pub struct SystemCatalogTable { pub struct SystemCatalogTable {
@@ -111,7 +109,7 @@ impl SystemCatalogTable {
desc: Some("System catalog table".to_string()), desc: Some("System catalog table".to_string()),
schema: schema.clone(), schema: schema.clone(),
region_numbers: vec![0], region_numbers: vec![0],
primary_key_indices: vec![ENTRY_TYPE_INDEX, KEY_INDEX, TIMESTAMP_INDEX], primary_key_indices: vec![ENTRY_TYPE_INDEX, KEY_INDEX],
create_if_not_exists: true, create_if_not_exists: true,
table_options: HashMap::new(), table_options: HashMap::new(),
}; };
@@ -128,13 +126,14 @@ impl SystemCatalogTable {
/// Create a stream of all entries inside system catalog table /// Create a stream of all entries inside system catalog table
pub async fn records(&self) -> Result<SendableRecordBatchStream> { pub async fn records(&self) -> Result<SendableRecordBatchStream> {
let full_projection = None; let full_projection = None;
let ctx = SessionContext::new();
let scan = self let scan = self
.table .table
.scan(&full_projection, &[], None) .scan(&full_projection, &[], None)
.await .await
.context(error::SystemCatalogTableScanSnafu)?; .context(error::SystemCatalogTableScanSnafu)?;
let stream = scan let stream = scan
.execute(0, Arc::new(RuntimeEnv::default())) .execute(0, ctx.task_ctx())
.context(error::SystemCatalogTableScanExecSnafu)?; .context(error::SystemCatalogTableScanExecSnafu)?;
Ok(stream) Ok(stream)
} }
@@ -162,7 +161,7 @@ fn build_system_catalog_schema() -> Schema {
), ),
ColumnSchema::new( ColumnSchema::new(
"timestamp".to_string(), "timestamp".to_string(),
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
false, false,
) )
.with_time_index(true), .with_time_index(true),
@@ -173,12 +172,12 @@ fn build_system_catalog_schema() -> Schema {
), ),
ColumnSchema::new( ColumnSchema::new(
"gmt_created".to_string(), "gmt_created".to_string(),
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
false, false,
), ),
ColumnSchema::new( ColumnSchema::new(
"gmt_modified".to_string(), "gmt_modified".to_string(),
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
false, false,
), ),
]; ];
@@ -223,7 +222,7 @@ pub fn build_insert_request(entry_type: EntryType, key: &[u8], value: &[u8]) ->
// Timestamp in key part is intentionally left to 0 // Timestamp in key part is intentionally left to 0
columns_values.insert( columns_values.insert(
"timestamp".to_string(), "timestamp".to_string(),
Arc::new(TimestampVector::from_slice(&[Timestamp::from_millis(0)])) as _, Arc::new(TimestampMillisecondVector::from_slice(&[0])) as _,
); );
columns_values.insert( columns_values.insert(
@@ -231,18 +230,15 @@ pub fn build_insert_request(entry_type: EntryType, key: &[u8], value: &[u8]) ->
Arc::new(BinaryVector::from_slice(&[value])) as _, Arc::new(BinaryVector::from_slice(&[value])) as _,
); );
let now = util::current_time_millis();
columns_values.insert( columns_values.insert(
"gmt_created".to_string(), "gmt_created".to_string(),
Arc::new(TimestampVector::from_slice(&[Timestamp::from_millis( Arc::new(TimestampMillisecondVector::from_slice(&[now])) as _,
util::current_time_millis(),
)])) as _,
); );
columns_values.insert( columns_values.insert(
"gmt_modified".to_string(), "gmt_modified".to_string(),
Arc::new(TimestampVector::from_slice(&[Timestamp::from_millis( Arc::new(TimestampMillisecondVector::from_slice(&[now])) as _,
util::current_time_millis(),
)])) as _,
); );
InsertRequest { InsertRequest {

View File

@@ -26,9 +26,9 @@ use common_query::logical_plan::Expr;
use common_query::physical_plan::PhysicalPlanRef; use common_query::physical_plan::PhysicalPlanRef;
use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::error::Result as RecordBatchResult;
use common_recordbatch::{RecordBatch, RecordBatchStream}; use common_recordbatch::{RecordBatch, RecordBatchStream};
use datatypes::prelude::{ConcreteDataType, VectorBuilder}; use datatypes::prelude::{ConcreteDataType, DataType};
use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::value::Value; use datatypes::value::ValueRef;
use datatypes::vectors::VectorRef; use datatypes::vectors::VectorRef;
use futures::Stream; use futures::Stream;
use snafu::ResultExt; use snafu::ResultExt;
@@ -149,26 +149,33 @@ fn tables_to_record_batch(
engine: &str, engine: &str,
) -> Vec<VectorRef> { ) -> Vec<VectorRef> {
let mut catalog_vec = let mut catalog_vec =
VectorBuilder::with_capacity(ConcreteDataType::string_datatype(), table_names.len()); ConcreteDataType::string_datatype().create_mutable_vector(table_names.len());
let mut schema_vec = let mut schema_vec =
VectorBuilder::with_capacity(ConcreteDataType::string_datatype(), table_names.len()); ConcreteDataType::string_datatype().create_mutable_vector(table_names.len());
let mut table_name_vec = let mut table_name_vec =
VectorBuilder::with_capacity(ConcreteDataType::string_datatype(), table_names.len()); ConcreteDataType::string_datatype().create_mutable_vector(table_names.len());
let mut engine_vec = let mut engine_vec =
VectorBuilder::with_capacity(ConcreteDataType::string_datatype(), table_names.len()); ConcreteDataType::string_datatype().create_mutable_vector(table_names.len());
for table_name in table_names { for table_name in table_names {
catalog_vec.push(&Value::String(catalog_name.into())); // Safety: All these vectors are string type.
schema_vec.push(&Value::String(schema_name.into())); catalog_vec
table_name_vec.push(&Value::String(table_name.into())); .push_value_ref(ValueRef::String(catalog_name))
engine_vec.push(&Value::String(engine.into())); .unwrap();
schema_vec
.push_value_ref(ValueRef::String(schema_name))
.unwrap();
table_name_vec
.push_value_ref(ValueRef::String(&table_name))
.unwrap();
engine_vec.push_value_ref(ValueRef::String(engine)).unwrap();
} }
vec![ vec![
catalog_vec.finish(), catalog_vec.to_vector(),
schema_vec.finish(), schema_vec.to_vector(),
table_name_vec.finish(), table_name_vec.to_vector(),
engine_vec.finish(), engine_vec.to_vector(),
] ]
} }
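For reference, a minimal standalone sketch of the builder pattern this hunk switches to, assuming only the datatypes APIs already imported and used above (create_mutable_vector, push_value_ref, to_vector); the helper name is hypothetical and not part of the diff:

use datatypes::prelude::{ConcreteDataType, DataType};
use datatypes::value::ValueRef;
use datatypes::vectors::VectorRef;

// Hypothetical helper: build a string vector from borrowed &str values
// using the MutableVector API shown in the hunk above.
fn build_string_vector(names: &[&str]) -> VectorRef {
    let mut builder = ConcreteDataType::string_datatype().create_mutable_vector(names.len());
    for &name in names {
        // Safety (as in the diff): the builder is a string builder, so pushing
        // a string ValueRef cannot fail.
        builder.push_value_ref(ValueRef::String(name)).unwrap();
    }
    builder.to_vector()
}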
@@ -340,9 +347,7 @@ fn build_schema_for_tables() -> Schema {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::physical_plan::RuntimeEnv; use common_query::physical_plan::SessionContext;
use datatypes::arrow::array::Utf8Array;
use datatypes::arrow::datatypes::DataType;
use futures_util::StreamExt; use futures_util::StreamExt;
use table::table::numbers::NumbersTable; use table::table::numbers::NumbersTable;
@@ -366,56 +371,47 @@ mod tests {
let tables = Tables::new(catalog_list, "test_engine".to_string()); let tables = Tables::new(catalog_list, "test_engine".to_string());
let tables_stream = tables.scan(&None, &[], None).await.unwrap(); let tables_stream = tables.scan(&None, &[], None).await.unwrap();
let mut tables_stream = tables_stream let session_ctx = SessionContext::new();
.execute(0, Arc::new(RuntimeEnv::default())) let mut tables_stream = tables_stream.execute(0, session_ctx.task_ctx()).unwrap();
.unwrap();
if let Some(t) = tables_stream.next().await { if let Some(t) = tables_stream.next().await {
let batch = t.unwrap().df_recordbatch; let batch = t.unwrap();
assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_rows());
assert_eq!(4, batch.num_columns()); assert_eq!(4, batch.num_columns());
assert_eq!(&DataType::Utf8, batch.column(0).data_type()); assert_eq!(
assert_eq!(&DataType::Utf8, batch.column(1).data_type()); ConcreteDataType::string_datatype(),
assert_eq!(&DataType::Utf8, batch.column(2).data_type()); batch.column(0).data_type()
assert_eq!(&DataType::Utf8, batch.column(3).data_type()); );
assert_eq!(
ConcreteDataType::string_datatype(),
batch.column(1).data_type()
);
assert_eq!(
ConcreteDataType::string_datatype(),
batch.column(2).data_type()
);
assert_eq!(
ConcreteDataType::string_datatype(),
batch.column(3).data_type()
);
assert_eq!( assert_eq!(
"greptime", "greptime",
batch batch.column(0).get_ref(0).as_string().unwrap().unwrap()
.column(0)
.as_any()
.downcast_ref::<Utf8Array<i32>>()
.unwrap()
.value(0)
); );
assert_eq!( assert_eq!(
"public", "public",
batch batch.column(1).get_ref(0).as_string().unwrap().unwrap()
.column(1)
.as_any()
.downcast_ref::<Utf8Array<i32>>()
.unwrap()
.value(0)
); );
assert_eq!( assert_eq!(
"test_table", "test_table",
batch batch.column(2).get_ref(0).as_string().unwrap().unwrap()
.column(2)
.as_any()
.downcast_ref::<Utf8Array<i32>>()
.unwrap()
.value(0)
); );
assert_eq!( assert_eq!(
"test_engine", "test_engine",
batch batch.column(3).get_ref(0).as_string().unwrap().unwrap()
.column(3)
.as_any()
.downcast_ref::<Utf8Array<i32>>()
.unwrap()
.value(0)
); );
} else { } else {
panic!("Record batch should not be empty!") panic!("Record batch should not be empty!")

View File

@@ -22,12 +22,12 @@ mod tests {
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use catalog::helper::{CatalogKey, CatalogValue, SchemaKey, SchemaValue};
use catalog::remote::{ use catalog::remote::{
KvBackend, KvBackendRef, RemoteCatalogManager, RemoteCatalogProvider, RemoteSchemaProvider, KvBackend, KvBackendRef, RemoteCatalogManager, RemoteCatalogProvider, RemoteSchemaProvider,
}; };
use catalog::{CatalogList, CatalogManager, RegisterTableRequest}; use catalog::{CatalogList, CatalogManager, RegisterTableRequest};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::{CatalogKey, CatalogValue, SchemaKey, SchemaValue};
use datatypes::schema::Schema; use datatypes::schema::Schema;
use futures_util::StreamExt; use futures_util::StreamExt;
use table::engine::{EngineContext, TableEngineRef}; use table::engine::{EngineContext, TableEngineRef};

View File

@@ -15,9 +15,7 @@ common-grpc-expr = { path = "../common/grpc-expr" }
common-query = { path = "../common/query" } common-query = { path = "../common/query" }
common-recordbatch = { path = "../common/recordbatch" } common-recordbatch = { path = "../common/recordbatch" }
common-time = { path = "../common/time" } common-time = { path = "../common/time" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ datafusion = "14.0.0"
"simd",
] }
datatypes = { path = "../datatypes" } datatypes = { path = "../datatypes" }
enum_dispatch = "0.3" enum_dispatch = "0.3"
parking_lot = "0.12" parking_lot = "0.12"

View File

@@ -41,7 +41,7 @@ async fn run() {
column_defs: vec![ column_defs: vec![
ColumnDef { ColumnDef {
name: "timestamp".to_string(), name: "timestamp".to_string(),
datatype: ColumnDataType::Timestamp as i32, datatype: ColumnDataType::TimestampMillisecond as i32,
is_nullable: false, is_nullable: false,
default_constraint: None, default_constraint: None,
}, },

View File

@@ -1,51 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use client::{Client, Database};
use common_grpc::MockExecution;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{ExecutionPlan, PhysicalExpr};
use tracing::{event, Level};
fn main() {
tracing::subscriber::set_global_default(tracing_subscriber::FmtSubscriber::builder().finish())
.unwrap();
run();
}
#[tokio::main]
async fn run() {
let client = Client::with_urls(vec!["127.0.0.1:3001"]);
let db = Database::new("greptime", client);
let physical = mock_physical_plan();
let result = db.physical_plan(physical, None).await;
event!(Level::INFO, "result: {:#?}", result);
}
fn mock_physical_plan() -> Arc<dyn ExecutionPlan> {
let id_expr = Arc::new(Column::new("id", 0)) as Arc<dyn PhysicalExpr>;
let age_expr = Arc::new(Column::new("age", 2)) as Arc<dyn PhysicalExpr>;
let expr = vec![(id_expr, "id".to_string()), (age_expr, "age".to_string())];
let input =
Arc::new(MockExecution::new("mock_input_exec".to_string())) as Arc<dyn ExecutionPlan>;
let projection = ProjectionExec::try_new(expr, input).unwrap();
Arc::new(projection)
}

View File

@@ -18,22 +18,17 @@ use api::v1::codec::SelectResult as GrpcSelectResult;
use api::v1::column::SemanticType; use api::v1::column::SemanticType;
use api::v1::{ use api::v1::{
object_expr, object_result, select_expr, DatabaseRequest, ExprHeader, InsertExpr, object_expr, object_result, select_expr, DatabaseRequest, ExprHeader, InsertExpr,
MutateResult as GrpcMutateResult, ObjectExpr, ObjectResult as GrpcObjectResult, PhysicalPlan, MutateResult as GrpcMutateResult, ObjectExpr, ObjectResult as GrpcObjectResult, SelectExpr,
SelectExpr,
}; };
use common_error::status_code::StatusCode; use common_error::status_code::StatusCode;
use common_grpc::{AsExecutionPlan, DefaultAsPlanImpl};
use common_grpc_expr::column_to_vector; use common_grpc_expr::column_to_vector;
use common_query::Output; use common_query::Output;
use common_recordbatch::{RecordBatch, RecordBatches}; use common_recordbatch::{RecordBatch, RecordBatches};
use datafusion::physical_plan::ExecutionPlan;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema}; use datatypes::schema::{ColumnSchema, Schema};
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
use crate::error::{ use crate::error::{ColumnToVectorSnafu, ConvertSchemaSnafu, DatanodeSnafu, DecodeSelectSnafu};
ColumnToVectorSnafu, ConvertSchemaSnafu, DatanodeSnafu, DecodeSelectSnafu, EncodePhysicalSnafu,
};
use crate::{error, Client, Result}; use crate::{error, Client, Result};
pub const PROTOCOL_VERSION: u32 = 1; pub const PROTOCOL_VERSION: u32 = 1;
@@ -94,24 +89,6 @@ impl Database {
self.do_select(select_expr).await self.do_select(select_expr).await
} }
pub async fn physical_plan(
&self,
physical: Arc<dyn ExecutionPlan>,
original_ql: Option<String>,
) -> Result<ObjectResult> {
let plan = DefaultAsPlanImpl::try_from_physical_plan(physical.clone())
.context(EncodePhysicalSnafu { physical })?
.bytes;
let original_ql = original_ql.unwrap_or_default();
let select_expr = SelectExpr {
expr: Some(select_expr::Expr::PhysicalPlan(PhysicalPlan {
original_ql: original_ql.into_bytes(),
plan,
})),
};
self.do_select(select_expr).await
}
pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<ObjectResult> { pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<ObjectResult> {
let select_expr = SelectExpr { let select_expr = SelectExpr {
expr: Some(select_expr::Expr::LogicalPlan(logical_plan)), expr: Some(select_expr::Expr::LogicalPlan(logical_plan)),
@@ -341,12 +318,11 @@ mod tests {
fn create_test_column(vector: VectorRef) -> Column { fn create_test_column(vector: VectorRef) -> Column {
let wrapper: ColumnDataTypeWrapper = vector.data_type().try_into().unwrap(); let wrapper: ColumnDataTypeWrapper = vector.data_type().try_into().unwrap();
let array = vector.to_arrow_array();
Column { Column {
column_name: "test".to_string(), column_name: "test".to_string(),
semantic_type: 1, semantic_type: 1,
values: Some(values(&[array.clone()]).unwrap()), values: Some(values(&[vector.clone()]).unwrap()),
null_mask: null_mask(&vec![array], vector.len()), null_mask: null_mask(&[vector.clone()], vector.len()),
datatype: wrapper.datatype() as i32, datatype: wrapper.datatype() as i32,
} }
} }

View File

@@ -10,6 +10,7 @@ name = "greptime"
path = "src/bin/greptime.rs" path = "src/bin/greptime.rs"
[dependencies] [dependencies]
anymap = "1.0.0-beta.2"
clap = { version = "3.1", features = ["derive"] } clap = { version = "3.1", features = ["derive"] }
common-error = { path = "../common/error" } common-error = { path = "../common/error" }
common-telemetry = { path = "../common/telemetry", features = [ common-telemetry = { path = "../common/telemetry", features = [

View File

@@ -12,8 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
const DEFAULT_VALUE: &str = "unknown";
fn main() { fn main() {
build_data::set_GIT_BRANCH(); println!(
build_data::set_GIT_COMMIT(); "cargo:rustc-env=GIT_COMMIT={}",
build_data::set_GIT_DIRTY(); build_data::get_git_commit().unwrap_or_else(|_| DEFAULT_VALUE.to_string())
);
println!(
"cargo:rustc-env=GIT_BRANCH={}",
build_data::get_git_branch().unwrap_or_else(|_| DEFAULT_VALUE.to_string())
);
println!(
"cargo:rustc-env=GIT_DIRTY={}",
build_data::get_git_dirty().map_or(DEFAULT_VALUE.to_string(), |v| v.to_string())
);
} }

View File

@@ -77,7 +77,9 @@ fn print_version() -> &'static str {
"\ncommit: ", "\ncommit: ",
env!("GIT_COMMIT"), env!("GIT_COMMIT"),
"\ndirty: ", "\ndirty: ",
env!("GIT_DIRTY") env!("GIT_DIRTY"),
"\nversion: ",
env!("CARGO_PKG_VERSION")
) )
} }

View File

@@ -25,12 +25,6 @@ pub enum Error {
source: datanode::error::Error, source: datanode::error::Error,
}, },
#[snafu(display("Failed to build frontend, source: {}", source))]
BuildFrontend {
#[snafu(backtrace)]
source: frontend::error::Error,
},
#[snafu(display("Failed to start frontend, source: {}", source))] #[snafu(display("Failed to start frontend, source: {}", source))]
StartFrontend { StartFrontend {
#[snafu(backtrace)] #[snafu(backtrace)]
@@ -61,6 +55,12 @@ pub enum Error {
#[snafu(display("Illegal config: {}", msg))] #[snafu(display("Illegal config: {}", msg))]
IllegalConfig { msg: String, backtrace: Backtrace }, IllegalConfig { msg: String, backtrace: Backtrace },
#[snafu(display("Illegal auth config: {}", source))]
IllegalAuthConfig {
#[snafu(backtrace)]
source: servers::auth::Error,
},
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@@ -75,7 +75,7 @@ impl ErrorExt for Error {
StatusCode::InvalidArguments StatusCode::InvalidArguments
} }
Error::IllegalConfig { .. } => StatusCode::InvalidArguments, Error::IllegalConfig { .. } => StatusCode::InvalidArguments,
Error::BuildFrontend { source, .. } => source.status_code(), Error::IllegalAuthConfig { .. } => StatusCode::InvalidArguments,
} }
} }

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anymap::AnyMap;
use clap::Parser; use clap::Parser;
use frontend::frontend::{Frontend, FrontendOptions}; use frontend::frontend::{Frontend, FrontendOptions};
use frontend::grpc::GrpcOptions; use frontend::grpc::GrpcOptions;
@@ -21,11 +22,13 @@ use frontend::mysql::MysqlOptions;
use frontend::opentsdb::OpentsdbOptions; use frontend::opentsdb::OpentsdbOptions;
use frontend::postgres::PostgresOptions; use frontend::postgres::PostgresOptions;
use meta_client::MetaClientOpts; use meta_client::MetaClientOpts;
use servers::auth::UserProviderRef;
use servers::http::HttpOptions; use servers::http::HttpOptions;
use servers::Mode; use servers::tls::{TlsMode, TlsOption};
use servers::{auth, Mode};
use snafu::ResultExt; use snafu::ResultExt;
use crate::error::{self, Result}; use crate::error::{self, IllegalAuthConfigSnafu, Result};
use crate::toml_loader; use crate::toml_loader;
#[derive(Parser)] #[derive(Parser)]
@@ -71,21 +74,41 @@ pub struct StartCommand {
influxdb_enable: Option<bool>, influxdb_enable: Option<bool>,
#[clap(long)] #[clap(long)]
metasrv_addr: Option<String>, metasrv_addr: Option<String>,
#[clap(long)]
tls_mode: Option<TlsMode>,
#[clap(long)]
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
} }
impl StartCommand { impl StartCommand {
async fn run(self) -> Result<()> { async fn run(self) -> Result<()> {
let plugins = load_frontend_plugins(&self.user_provider)?;
let opts: FrontendOptions = self.try_into()?; let opts: FrontendOptions = self.try_into()?;
let mut frontend = Frontend::new( let mut frontend = Frontend::new(
opts.clone(), opts.clone(),
Instance::try_new(&opts) Instance::try_new_distributed(&opts)
.await .await
.context(error::StartFrontendSnafu)?, .context(error::StartFrontendSnafu)?,
plugins,
); );
frontend.start().await.context(error::StartFrontendSnafu) frontend.start().await.context(error::StartFrontendSnafu)
} }
} }
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<AnyMap> {
let mut plugins = AnyMap::new();
if let Some(provider) = user_provider {
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;
plugins.insert::<UserProviderRef>(provider);
}
Ok(plugins)
}
impl TryFrom<StartCommand> for FrontendOptions { impl TryFrom<StartCommand> for FrontendOptions {
type Error = error::Error; type Error = error::Error;
@@ -96,6 +119,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
FrontendOptions::default() FrontendOptions::default()
}; };
let tls_option = TlsOption::new(cmd.tls_mode, cmd.tls_cert_path, cmd.tls_key_path);
if let Some(addr) = cmd.http_addr { if let Some(addr) = cmd.http_addr {
opts.http_options = Some(HttpOptions { opts.http_options = Some(HttpOptions {
addr, addr,
@@ -111,12 +136,14 @@ impl TryFrom<StartCommand> for FrontendOptions {
if let Some(addr) = cmd.mysql_addr { if let Some(addr) = cmd.mysql_addr {
opts.mysql_options = Some(MysqlOptions { opts.mysql_options = Some(MysqlOptions {
addr, addr,
tls: tls_option.clone(),
..Default::default() ..Default::default()
}); });
} }
if let Some(addr) = cmd.postgres_addr { if let Some(addr) = cmd.postgres_addr {
opts.postgres_options = Some(PostgresOptions { opts.postgres_options = Some(PostgresOptions {
addr, addr,
tls: tls_option,
..Default::default() ..Default::default()
}); });
} }
@@ -147,6 +174,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests { mod tests {
use std::time::Duration; use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*; use super::*;
#[test] #[test]
@@ -160,6 +189,10 @@ mod tests {
influxdb_enable: Some(false), influxdb_enable: Some(false),
config_file: None, config_file: None,
metasrv_addr: None, metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
}; };
let opts: FrontendOptions = command.try_into().unwrap(); let opts: FrontendOptions = command.try_into().unwrap();
@@ -209,11 +242,14 @@ mod tests {
std::env::current_dir().unwrap().as_path().to_str().unwrap() std::env::current_dir().unwrap().as_path().to_str().unwrap()
)), )),
metasrv_addr: None, metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
}; };
let fe_opts = FrontendOptions::try_from(command).unwrap(); let fe_opts = FrontendOptions::try_from(command).unwrap();
assert_eq!(Mode::Distributed, fe_opts.mode); assert_eq!(Mode::Distributed, fe_opts.mode);
assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr);
assert_eq!( assert_eq!(
"127.0.0.1:4000".to_string(), "127.0.0.1:4000".to_string(),
fe_opts.http_options.as_ref().unwrap().addr fe_opts.http_options.as_ref().unwrap().addr
@@ -223,4 +259,34 @@ mod tests {
fe_opts.http_options.as_ref().unwrap().timeout fe_opts.http_options.as_ref().unwrap().timeout
); );
} }
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
grpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
influxdb_enable: None,
config_file: None,
metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
} }

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use anymap::AnyMap;
use clap::Parser; use clap::Parser;
use common_telemetry::info; use common_telemetry::info;
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig}; use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig};
@@ -26,13 +27,12 @@ use frontend::postgres::PostgresOptions;
use frontend::prometheus::PrometheusOptions; use frontend::prometheus::PrometheusOptions;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use servers::http::HttpOptions; use servers::http::HttpOptions;
use servers::tls::{TlsMode, TlsOption};
use servers::Mode; use servers::Mode;
use snafu::ResultExt; use snafu::ResultExt;
use tokio::try_join;
use crate::error::{ use crate::error::{Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu};
BuildFrontendSnafu, Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu, use crate::frontend::load_frontend_plugins;
};
use crate::toml_loader; use crate::toml_loader;
#[derive(Parser)] #[derive(Parser)]
@@ -104,7 +104,6 @@ impl StandaloneOptions {
influxdb_options: self.influxdb_options, influxdb_options: self.influxdb_options,
prometheus_options: self.prometheus_options, prometheus_options: self.prometheus_options,
mode: self.mode, mode: self.mode,
datanode_rpc_addr: "127.0.0.1:3001".to_string(),
meta_client_opts: None, meta_client_opts: None,
} }
} }
@@ -137,12 +136,21 @@ struct StartCommand {
config_file: Option<String>, config_file: Option<String>,
#[clap(short = 'm', long = "memory-catalog")] #[clap(short = 'm', long = "memory-catalog")]
enable_memory_catalog: bool, enable_memory_catalog: bool,
#[clap(long)]
tls_mode: Option<TlsMode>,
#[clap(long)]
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
} }
impl StartCommand { impl StartCommand {
async fn run(self) -> Result<()> { async fn run(self) -> Result<()> {
let enable_memory_catalog = self.enable_memory_catalog; let enable_memory_catalog = self.enable_memory_catalog;
let config_file = self.config_file.clone(); let config_file = self.config_file.clone();
let plugins = load_frontend_plugins(&self.user_provider)?;
let fe_opts = FrontendOptions::try_from(self)?; let fe_opts = FrontendOptions::try_from(self)?;
let dn_opts: DatanodeOptions = { let dn_opts: DatanodeOptions = {
let mut opts: StandaloneOptions = if let Some(path) = config_file { let mut opts: StandaloneOptions = if let Some(path) = config_file {
@@ -162,7 +170,7 @@ impl StartCommand {
let mut datanode = Datanode::new(dn_opts.clone()) let mut datanode = Datanode::new(dn_opts.clone())
.await .await
.context(StartDatanodeSnafu)?; .context(StartDatanodeSnafu)?;
let mut frontend = build_frontend(fe_opts, &dn_opts, datanode.get_instance()).await?; let mut frontend = build_frontend(fe_opts, plugins, datanode.get_instance()).await?;
// Start datanode instance before starting services, to avoid requests coming in before internal components are started. // Start datanode instance before starting services, to avoid requests coming in before internal components are started.
datanode datanode
@@ -171,11 +179,7 @@ impl StartCommand {
.context(StartDatanodeSnafu)?; .context(StartDatanodeSnafu)?;
info!("Datanode instance started"); info!("Datanode instance started");
try_join!( frontend.start().await.context(StartFrontendSnafu)?;
async { datanode.start_services().await.context(StartDatanodeSnafu) },
async { frontend.start().await.context(StartFrontendSnafu) }
)?;
Ok(()) Ok(())
} }
} }
@@ -183,20 +187,12 @@ impl StartCommand {
/// Build frontend instance in standalone mode /// Build frontend instance in standalone mode
async fn build_frontend( async fn build_frontend(
fe_opts: FrontendOptions, fe_opts: FrontendOptions,
dn_opts: &DatanodeOptions, plugins: AnyMap,
datanode_instance: InstanceRef, datanode_instance: InstanceRef,
) -> Result<Frontend<FeInstance>> { ) -> Result<Frontend<FeInstance>> {
let grpc_server_addr = &dn_opts.rpc_addr; let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone());
info!(
"Build frontend with datanode gRPC addr: {}",
grpc_server_addr
);
let mut frontend_instance = FeInstance::try_new(&fe_opts)
.await
.context(BuildFrontendSnafu)?;
frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone());
frontend_instance.set_script_handler(datanode_instance); frontend_instance.set_script_handler(datanode_instance);
Ok(Frontend::new(fe_opts, frontend_instance)) Ok(Frontend::new(fe_opts, frontend_instance, plugins))
} }
impl TryFrom<StartCommand> for FrontendOptions { impl TryFrom<StartCommand> for FrontendOptions {
@@ -261,6 +257,18 @@ impl TryFrom<StartCommand> for FrontendOptions {
opts.influxdb_options = Some(InfluxdbOptions { enable: true }); opts.influxdb_options = Some(InfluxdbOptions { enable: true });
} }
let tls_option = TlsOption::new(cmd.tls_mode, cmd.tls_cert_path, cmd.tls_key_path);
if let Some(mut mysql_options) = opts.mysql_options {
mysql_options.tls = tls_option.clone();
opts.mysql_options = Some(mysql_options);
}
if let Some(mut postgres_options) = opts.postgres_options {
postgres_options.tls = tls_option;
opts.postgres_options = Some(postgres_options);
}
Ok(opts) Ok(opts)
} }
} }
@@ -269,6 +277,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests { mod tests {
use std::time::Duration; use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*; use super::*;
#[test] #[test]
@@ -285,11 +295,14 @@ mod tests {
)), )),
influxdb_enable: false, influxdb_enable: false,
enable_memory_catalog: false, enable_memory_catalog: false,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
}; };
let fe_opts = FrontendOptions::try_from(cmd).unwrap(); let fe_opts = FrontendOptions::try_from(cmd).unwrap();
assert_eq!(Mode::Standalone, fe_opts.mode); assert_eq!(Mode::Standalone, fe_opts.mode);
assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr);
assert_eq!( assert_eq!(
"127.0.0.1:4000".to_string(), "127.0.0.1:4000".to_string(),
fe_opts.http_options.as_ref().unwrap().addr fe_opts.http_options.as_ref().unwrap().addr
@@ -309,4 +322,33 @@ mod tests {
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size); assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable); assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
} }
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
rpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
config_file: None,
influxdb_enable: false,
enable_memory_catalog: false,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
} }

View File

@@ -11,4 +11,3 @@ common-error = { path = "../error" }
paste = "1.0" paste = "1.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
snafu = { version = "0.7", features = ["backtraces"] } snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -14,7 +14,6 @@ regex = "1.6"
serde = "1.0" serde = "1.0"
serde_json = "1.0" serde_json = "1.0"
snafu = { version = "0.7", features = ["backtraces"] } snafu = { version = "0.7", features = ["backtraces"] }
table = { path = "../../table" }
[dev-dependencies] [dev-dependencies]
chrono = "0.4" chrono = "0.4"

View File

@@ -25,9 +25,3 @@ pub const MIN_USER_TABLE_ID: u32 = 1024;
pub const SYSTEM_CATALOG_TABLE_ID: u32 = 0; pub const SYSTEM_CATALOG_TABLE_ID: u32 = 0;
/// scripts table id /// scripts table id
pub const SCRIPTS_TABLE_ID: u32 = 1; pub const SCRIPTS_TABLE_ID: u32 = 1;
pub(crate) const CATALOG_KEY_PREFIX: &str = "__c";
pub(crate) const SCHEMA_KEY_PREFIX: &str = "__s";
pub(crate) const TABLE_GLOBAL_KEY_PREFIX: &str = "__tg";
pub(crate) const TABLE_REGIONAL_KEY_PREFIX: &str = "__tr";
pub const TABLE_ID_KEY_PREFIX: &str = "__tid";

View File

@@ -14,10 +14,3 @@
pub mod consts; pub mod consts;
pub mod error; pub mod error;
mod helper;
pub use helper::{
build_catalog_prefix, build_schema_prefix, build_table_global_prefix,
build_table_regional_prefix, CatalogKey, CatalogValue, SchemaKey, SchemaValue, TableGlobalKey,
TableGlobalValue, TableRegionalKey, TableRegionalValue,
};

View File

@@ -62,6 +62,19 @@ pub enum StatusCode {
/// Runtime resources exhausted, like creating threads failed. /// Runtime resources exhausted, like creating threads failed.
RuntimeResourcesExhausted = 6000, RuntimeResourcesExhausted = 6000,
// ====== End of server related status code ======= // ====== End of server related status code =======
// ====== Begin of auth related status code =====
/// User does not exist
UserNotFound = 7000,
/// Unsupported password type
UnsupportedPasswordType = 7001,
/// Username and password do not match
UserPasswordMismatch = 7002,
/// Missing http authorization header
AuthHeaderNotFound = 7003,
/// Invalid http authorization header
InvalidAuthHeader = 7004,
// ====== End of auth related status code =====
} }
impl StatusCode { impl StatusCode {

View File

@@ -11,7 +11,7 @@ common-error = { path = "../error" }
common-function-macro = { path = "../function-macro" } common-function-macro = { path = "../function-macro" }
common-query = { path = "../query" } common-query = { path = "../query" }
common-time = { path = "../time" } common-time = { path = "../time" }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" } datafusion-common = "14.0.0"
datatypes = { path = "../../datatypes" } datatypes = { path = "../../datatypes" }
libc = "0.2" libc = "0.2"
num = "0.4" num = "0.4"

View File

@@ -12,5 +12,4 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
pub mod error;
pub mod scalars; pub mod scalars;

View File

@@ -23,6 +23,5 @@ pub(crate) mod test;
mod timestamp; mod timestamp;
pub mod udf; pub mod udf;
pub use aggregate::MedianAccumulatorCreator;
pub use function::{Function, FunctionRef}; pub use function::{Function, FunctionRef};
pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY}; pub use function_registry::{FunctionRegistry, FUNCTION_REGISTRY};

View File

@@ -16,7 +16,6 @@ mod argmax;
mod argmin; mod argmin;
mod diff; mod diff;
mod mean; mod mean;
mod median;
mod percentile; mod percentile;
mod polyval; mod polyval;
mod scipy_stats_norm_cdf; mod scipy_stats_norm_cdf;
@@ -29,7 +28,6 @@ pub use argmin::ArgminAccumulatorCreator;
use common_query::logical_plan::AggregateFunctionCreatorRef; use common_query::logical_plan::AggregateFunctionCreatorRef;
pub use diff::DiffAccumulatorCreator; pub use diff::DiffAccumulatorCreator;
pub use mean::MeanAccumulatorCreator; pub use mean::MeanAccumulatorCreator;
pub use median::MedianAccumulatorCreator;
pub use percentile::PercentileAccumulatorCreator; pub use percentile::PercentileAccumulatorCreator;
pub use polyval::PolyvalAccumulatorCreator; pub use polyval::PolyvalAccumulatorCreator;
pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator; pub use scipy_stats_norm_cdf::ScipyStatsNormCdfAccumulatorCreator;
@@ -88,7 +86,6 @@ impl AggregateFunctions {
}; };
} }
register_aggr_func!("median", 1, MedianAccumulatorCreator);
register_aggr_func!("diff", 1, DiffAccumulatorCreator); register_aggr_func!("diff", 1, DiffAccumulatorCreator);
register_aggr_func!("mean", 1, MeanAccumulatorCreator); register_aggr_func!("mean", 1, MeanAccumulatorCreator);
register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator); register_aggr_func!("polyval", 2, PolyvalAccumulatorCreator);

View File

@@ -20,24 +20,22 @@ use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Resul
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::vectors::ConstantVector; use datatypes::types::{LogicalPrimitiveType, WrapperType};
use datatypes::vectors::{ConstantVector, Helper};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use snafu::ensure; use snafu::ensure;
// https://numpy.org/doc/stable/reference/generated/numpy.argmax.html // https://numpy.org/doc/stable/reference/generated/numpy.argmax.html
// return the index of the max value // return the index of the max value
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Argmax<T> pub struct Argmax<T> {
where
T: Primitive + PartialOrd,
{
max: Option<T>, max: Option<T>,
n: u64, n: u64,
} }
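Read on its own, the argmax semantics this accumulator implements can be sketched without the vector plumbing; the following is a hypothetical standalone helper, not part of the diff:

// Return the index of the first maximum among the non-null values,
// mirroring the "only update on strictly greater" rule in Argmax::update.
fn argmax<T: PartialOrd + Copy>(values: &[Option<T>]) -> Option<u64> {
    let mut best: Option<(u64, T)> = None;
    for (i, v) in values.iter().enumerate() {
        if let Some(v) = *v {
            match best {
                Some((_, max)) if !(max < v) => {}
                _ => best = Some((i as u64, v)),
            }
        }
    }
    best.map(|(i, _)| i)
}

fn main() {
    assert_eq!(argmax(&[Some(-1i32), Some(1), Some(3)]), Some(2));
    assert_eq!(argmax::<i32>(&[None]), None);
}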
impl<T> Argmax<T> impl<T> Argmax<T>
where where
T: Primitive + PartialOrd, T: PartialOrd + Copy,
{ {
fn update(&mut self, value: T, index: u64) { fn update(&mut self, value: T, index: u64) {
if let Some(Ordering::Less) = self.max.partial_cmp(&Some(value)) { if let Some(Ordering::Less) = self.max.partial_cmp(&Some(value)) {
@@ -49,8 +47,7 @@ where
impl<T> Accumulator for Argmax<T> impl<T> Accumulator for Argmax<T>
where where
T: Primitive + PartialOrd, T: WrapperType + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
match self.max { match self.max {
@@ -66,10 +63,10 @@ where
let column = &values[0]; let column = &values[0];
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
for (i, v) in column.iter_data().enumerate() { for (i, v) in column.iter_data().enumerate() {
if let Some(value) = v { if let Some(value) = v {
@@ -93,8 +90,8 @@ where
let max = &states[0]; let max = &states[0];
let index = &states[1]; let index = &states[1];
let max: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(max) }; let max: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(max) };
let index: &<u64 as Scalar>::VectorType = unsafe { VectorHelper::static_cast(index) }; let index: &<u64 as Scalar>::VectorType = unsafe { Helper::static_cast(index) };
index index
.iter_data() .iter_data()
.flatten() .flatten()
@@ -122,7 +119,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Argmax::<$S>::default())) Ok(Box::new(Argmax::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -154,7 +151,7 @@ impl AggregateFunctionCreator for ArgmaxAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::Int32Vector;
use super::*; use super::*;
#[test] #[test]
@@ -166,21 +163,19 @@ mod test {
// test update one not-null value // test update one not-null value
let mut argmax = Argmax::<i32>::default(); let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))]; let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(argmax.update_batch(&v).is_ok()); assert!(argmax.update_batch(&v).is_ok());
assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap()); assert_eq!(Value::from(0_u64), argmax.evaluate().unwrap());
// test update one null value // test update one null value
let mut argmax = Argmax::<i32>::default(); let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
Option::<i32>::None,
]))];
assert!(argmax.update_batch(&v).is_ok()); assert!(argmax.update_batch(&v).is_ok());
assert_eq!(Value::Null, argmax.evaluate().unwrap()); assert_eq!(Value::Null, argmax.evaluate().unwrap());
// test update no null-value batch // test update no null-value batch
let mut argmax = Argmax::<i32>::default(); let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32), Some(-1i32),
Some(1), Some(1),
Some(3), Some(3),
@@ -190,7 +185,7 @@ mod test {
// test update null-value batch // test update null-value batch
let mut argmax = Argmax::<i32>::default(); let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32), Some(-2i32),
None, None,
Some(4), Some(4),
@@ -201,7 +196,7 @@ mod test {
// test update with constant vector // test update with constant vector
let mut argmax = Argmax::<i32>::default(); let mut argmax = Argmax::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new( let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
10, 10,
))]; ))];
assert!(argmax.update_batch(&v).is_ok()); assert!(argmax.update_batch(&v).is_ok());

View File

@@ -20,23 +20,20 @@ use common_query::error::{BadAccumulatorImplSnafu, CreateAccumulatorSnafu, Resul
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::vectors::ConstantVector; use datatypes::vectors::{ConstantVector, Helper};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use snafu::ensure; use snafu::ensure;
// // https://numpy.org/doc/stable/reference/generated/numpy.argmin.html // // https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Argmin<T> pub struct Argmin<T> {
where
T: Primitive + PartialOrd,
{
min: Option<T>, min: Option<T>,
n: u32, n: u32,
} }
impl<T> Argmin<T> impl<T> Argmin<T>
where where
T: Primitive + PartialOrd, T: Copy + PartialOrd,
{ {
fn update(&mut self, value: T, index: u32) { fn update(&mut self, value: T, index: u32) {
match self.min { match self.min {
@@ -56,8 +53,7 @@ where
impl<T> Accumulator for Argmin<T> impl<T> Accumulator for Argmin<T>
where where
T: Primitive + PartialOrd, T: WrapperType + PartialOrd,
for<'a> T: Scalar<RefType<'a> = T>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
match self.min { match self.min {
@@ -75,10 +71,10 @@ where
let column = &values[0]; let column = &values[0];
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
for (i, v) in column.iter_data().enumerate() { for (i, v) in column.iter_data().enumerate() {
if let Some(value) = v { if let Some(value) = v {
@@ -102,8 +98,8 @@ where
let min = &states[0]; let min = &states[0];
let index = &states[1]; let index = &states[1];
let min: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(min) }; let min: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(min) };
let index: &<u32 as Scalar>::VectorType = unsafe { VectorHelper::static_cast(index) }; let index: &<u32 as Scalar>::VectorType = unsafe { Helper::static_cast(index) };
index index
.iter_data() .iter_data()
.flatten() .flatten()
@@ -131,7 +127,7 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Argmin::<$S>::default())) Ok(Box::new(Argmin::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -163,7 +159,7 @@ impl AggregateFunctionCreator for ArgminAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::Int32Vector;
use super::*; use super::*;
#[test] #[test]
@@ -175,21 +171,19 @@ mod test {
// test update one not-null value // test update one not-null value
let mut argmin = Argmin::<i32>::default(); let mut argmin = Argmin::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))]; let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(argmin.update_batch(&v).is_ok()); assert!(argmin.update_batch(&v).is_ok());
assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap()); assert_eq!(Value::from(0_u32), argmin.evaluate().unwrap());
// test update one null value // test update one null value
let mut argmin = Argmin::<i32>::default(); let mut argmin = Argmin::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
Option::<i32>::None,
]))];
assert!(argmin.update_batch(&v).is_ok()); assert!(argmin.update_batch(&v).is_ok());
assert_eq!(Value::Null, argmin.evaluate().unwrap()); assert_eq!(Value::Null, argmin.evaluate().unwrap());
// test update no null-value batch // test update no null-value batch
let mut argmin = Argmin::<i32>::default(); let mut argmin = Argmin::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32), Some(-1i32),
Some(1), Some(1),
Some(3), Some(3),
@@ -199,7 +193,7 @@ mod test {
// test update null-value batch // test update null-value batch
let mut argmin = Argmin::<i32>::default(); let mut argmin = Argmin::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32), Some(-2i32),
None, None,
Some(4), Some(4),
@@ -210,7 +204,7 @@ mod test {
// test update with constant vector // test update with constant vector
let mut argmin = Argmin::<i32>::default(); let mut argmin = Argmin::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new( let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
10, 10,
))]; ))];
assert!(argmin.update_batch(&v).is_ok()); assert!(argmin.update_batch(&v).is_ok());

View File

@@ -22,40 +22,32 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::types::PrimitiveType;
use datatypes::value::ListValue; use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector}; use datatypes::vectors::{ConstantVector, Helper, ListVector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
// https://numpy.org/doc/stable/reference/generated/numpy.diff.html // https://numpy.org/doc/stable/reference/generated/numpy.diff.html
// I is the input type, O is the output type.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Diff<T, SubT> pub struct Diff<I, O> {
where values: Vec<I>,
T: Primitive + AsPrimitive<SubT>, _phantom: PhantomData<O>,
SubT: Primitive + std::ops::Sub<Output = SubT>,
{
values: Vec<T>,
_phantom: PhantomData<SubT>,
} }
impl<T, SubT> Diff<T, SubT> impl<I, O> Diff<I, O> {
where fn push(&mut self, value: I) {
T: Primitive + AsPrimitive<SubT>,
SubT: Primitive + std::ops::Sub<Output = SubT>,
{
fn push(&mut self, value: T) {
self.values.push(value); self.values.push(value);
} }
} }
impl<T, SubT> Accumulator for Diff<T, SubT> impl<I, O> Accumulator for Diff<I, O>
where where
T: Primitive + AsPrimitive<SubT>, I: WrapperType,
for<'a> T: Scalar<RefType<'a> = T>, O: WrapperType,
SubT: Primitive + std::ops::Sub<Output = SubT>, I::Native: AsPrimitive<O::Native>,
for<'a> SubT: Scalar<RefType<'a> = SubT>, O::Native: std::ops::Sub<Output = O::Native>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
let nums = self let nums = self
@@ -65,7 +57,7 @@ where
.collect::<Vec<Value>>(); .collect::<Vec<Value>>();
Ok(vec![Value::List(ListValue::new( Ok(vec![Value::List(ListValue::new(
Some(Box::new(nums)), Some(Box::new(nums)),
T::default().into().data_type(), I::LogicalType::build_data_type(),
))]) ))])
} }
@@ -78,12 +70,12 @@ where
let column = &values[0]; let column = &values[0];
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<I as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
(0..len).for_each(|_| { (0..len).for_each(|_| {
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
@@ -109,8 +101,9 @@ where
), ),
})?; })?;
for state in states.values_iter() { for state in states.values_iter() {
let state = state.context(FromScalarValueSnafu)?; if let Some(state) = state.context(FromScalarValueSnafu)? {
self.update_batch(&[state])? self.update_batch(&[state])?;
}
} }
Ok(()) Ok(())
} }
@@ -122,11 +115,14 @@ where
let diff = self let diff = self
.values .values
.windows(2) .windows(2)
.map(|x| (x[1].as_() - x[0].as_()).into()) .map(|x| {
let native = x[1].into_native().as_() - x[0].into_native().as_();
O::from_native(native).into()
})
.collect::<Vec<Value>>(); .collect::<Vec<Value>>();
let diff = Value::List(ListValue::new( let diff = Value::List(ListValue::new(
Some(Box::new(diff)), Some(Box::new(diff)),
SubT::default().into().data_type(), O::LogicalType::build_data_type(),
)); ));
Ok(diff) Ok(diff)
} }
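The evaluate body above reduces to a sliding-window subtraction in the wider output type; a minimal standalone sketch follows (a hypothetical helper fixed to i32 -> i64 for brevity, not part of the diff):

// Each output element is values[i + 1] - values[i], computed as i64.
fn diff_i64(values: &[i32]) -> Vec<i64> {
    values
        .windows(2)
        .map(|w| i64::from(w[1]) - i64::from(w[0]))
        .collect()
}

fn main() {
    // Three inputs yield two differences; fewer than two inputs yield none.
    assert_eq!(diff_i64(&[-1, 1, 2]), vec![2, 1]);
    assert!(diff_i64(&[42]).is_empty());
}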
@@ -143,7 +139,7 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Diff::<$S,<$S as Primitive>::LargestType>::default())) Ok(Box::new(Diff::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -163,7 +159,7 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_types[0].logical_type_id(), input_types[0].logical_type_id(),
|$S| { |$S| {
Ok(ConcreteDataType::list_datatype(PrimitiveType::<<$S as Primitive>::LargestType>::default().into())) Ok(ConcreteDataType::list_datatype($S::default().into()))
}, },
{ {
unreachable!() unreachable!()
@@ -177,7 +173,7 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_types[0].logical_type_id(), input_types[0].logical_type_id(),
|$S| { |$S| {
Ok(vec![ConcreteDataType::list_datatype(PrimitiveType::<$S>::default().into())]) Ok(vec![ConcreteDataType::list_datatype($S::default().into())])
}, },
{ {
unreachable!() unreachable!()
@@ -188,9 +184,10 @@ impl AggregateFunctionCreator for DiffAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::Int32Vector;
use super::*; use super::*;
#[test] #[test]
fn test_update_batch() { fn test_update_batch() {
// test update empty batch, expect not updating anything // test update empty batch, expect not updating anything
@@ -201,21 +198,19 @@ mod test {
// test update one not-null value // test update one not-null value
let mut diff = Diff::<i32, i64>::default(); let mut diff = Diff::<i32, i64>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))]; let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(diff.update_batch(&v).is_ok()); assert!(diff.update_batch(&v).is_ok());
assert_eq!(Value::Null, diff.evaluate().unwrap()); assert_eq!(Value::Null, diff.evaluate().unwrap());
// test update one null value // test update one null value
let mut diff = Diff::<i32, i64>::default(); let mut diff = Diff::<i32, i64>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
Option::<i32>::None,
]))];
assert!(diff.update_batch(&v).is_ok()); assert!(diff.update_batch(&v).is_ok());
assert_eq!(Value::Null, diff.evaluate().unwrap()); assert_eq!(Value::Null, diff.evaluate().unwrap());
// test update no null-value batch // test update no null-value batch
let mut diff = Diff::<i32, i64>::default(); let mut diff = Diff::<i32, i64>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32), Some(-1i32),
Some(1), Some(1),
Some(2), Some(2),
@@ -232,7 +227,7 @@ mod test {
// test update null-value batch // test update null-value batch
let mut diff = Diff::<i32, i64>::default(); let mut diff = Diff::<i32, i64>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32), Some(-2i32),
None, None,
Some(3), Some(3),
@@ -251,7 +246,7 @@ mod test {
// test update with constant vector // test update with constant vector
let mut diff = Diff::<i32, i64>::default(); let mut diff = Diff::<i32, i64>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new( let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
4, 4,
))]; ))];
let values = vec![Value::from(0_i64), Value::from(0_i64), Value::from(0_i64)]; let values = vec![Value::from(0_i64), Value::from(0_i64), Value::from(0_i64)];

View File

@@ -22,16 +22,14 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::vectors::{ConstantVector, Float64Vector, UInt64Vector}; use datatypes::types::WrapperType;
use datatypes::vectors::{ConstantVector, Float64Vector, Helper, UInt64Vector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt}; use snafu::{ensure, OptionExt};
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Mean<T> pub struct Mean<T> {
where
T: Primitive + AsPrimitive<f64>,
{
sum: f64, sum: f64,
n: u64, n: u64,
_phantom: PhantomData<T>, _phantom: PhantomData<T>,
@@ -39,11 +37,12 @@ where
impl<T> Mean<T> impl<T> Mean<T>
where where
T: Primitive + AsPrimitive<f64>, T: WrapperType,
T::Native: AsPrimitive<f64>,
{ {
#[inline(always)] #[inline(always)]
fn push(&mut self, value: T) { fn push(&mut self, value: T) {
self.sum += value.as_(); self.sum += value.into_native().as_();
self.n += 1; self.n += 1;
} }
@@ -56,8 +55,8 @@ where
impl<T> Accumulator for Mean<T> impl<T> Accumulator for Mean<T>
where where
T: Primitive + AsPrimitive<f64>, T: WrapperType,
for<'a> T: Scalar<RefType<'a> = T>, T::Native: AsPrimitive<f64>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
Ok(vec![self.sum.into(), self.n.into()]) Ok(vec![self.sum.into(), self.n.into()])
@@ -73,10 +72,10 @@ where
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
(0..len).for_each(|_| { (0..len).for_each(|_| {
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
@@ -150,7 +149,7 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Mean::<$S>::default())) Ok(Box::new(Mean::<<$S as LogicalPrimitiveType>::Native>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -182,7 +181,7 @@ impl AggregateFunctionCreator for MeanAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::Int32Vector;
use super::*; use super::*;
#[test] #[test]
@@ -194,21 +193,19 @@ mod test {
// test update one not-null value // test update one not-null value
let mut mean = Mean::<i32>::default(); let mut mean = Mean::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))]; let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Some(42)]))];
assert!(mean.update_batch(&v).is_ok()); assert!(mean.update_batch(&v).is_ok());
assert_eq!(Value::from(42.0_f64), mean.evaluate().unwrap()); assert_eq!(Value::from(42.0_f64), mean.evaluate().unwrap());
// test update one null value // test update one null value
let mut mean = Mean::<i32>::default(); let mut mean = Mean::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![Option::<i32>::None]))];
Option::<i32>::None,
]))];
assert!(mean.update_batch(&v).is_ok()); assert!(mean.update_batch(&v).is_ok());
assert_eq!(Value::Null, mean.evaluate().unwrap()); assert_eq!(Value::Null, mean.evaluate().unwrap());
// test update no null-value batch // test update no null-value batch
let mut mean = Mean::<i32>::default(); let mut mean = Mean::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-1i32), Some(-1i32),
Some(1), Some(1),
Some(2), Some(2),
@@ -218,7 +215,7 @@ mod test {
// test update null-value batch // test update null-value batch
let mut mean = Mean::<i32>::default(); let mut mean = Mean::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![ let v: Vec<VectorRef> = vec![Arc::new(Int32Vector::from(vec![
Some(-2i32), Some(-2i32),
None, None,
Some(3), Some(3),
@@ -230,7 +227,7 @@ mod test {
// test update with constant vector // test update with constant vector
let mut mean = Mean::<i32>::default(); let mut mean = Mean::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new( let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
10, 10,
))]; ))];
assert!(mean.update_batch(&v).is_ok()); assert!(mean.update_batch(&v).is_ok());

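The mean accumulator above only needs a running f64 sum and a row count; evaluate divides the two, or yields null when nothing was pushed. A minimal sketch of that logic, stripped of the vector and WrapperType plumbing (the RunningMean name is illustrative, not part of the crate):

// Minimal sketch of the running-mean logic used by `Mean<T>`.
#[derive(Default)]
struct RunningMean {
    sum: f64,
    n: u64,
}

impl RunningMean {
    fn push(&mut self, value: impl Into<f64>) {
        self.sum += value.into();
        self.n += 1;
    }

    /// Returns `None` when no value was pushed, mirroring `Value::Null`.
    fn evaluate(&self) -> Option<f64> {
        (self.n > 0).then(|| self.sum / self.n as f64)
    }
}

fn main() {
    let mut mean = RunningMean::default();
    for v in [-1i32, 1, 2] {
        mean.push(v);
    }
    assert_eq!(mean.evaluate(), Some(2.0 / 3.0));
}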

@@ -1,289 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use common_function_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{
CreateAccumulatorSnafu, DowncastVectorSnafu, FromScalarValueSnafu, Result,
};
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*;
use datatypes::prelude::*;
use datatypes::types::OrdPrimitive;
use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, ListVector};
use datatypes::with_match_primitive_type_id;
use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt};
// This median calculation algorithm's details can be found at
// https://leetcode.cn/problems/find-median-from-data-stream/
//
// Basically, it uses two heaps, a maximum heap and a minimum heap. The maximum heap stores numbers
// that are not greater than the median, and the minimum heap stores the greater ones. As numbers
// stream in, we rebalance the heaps so that either one heap's top is the median, or the two tops
// can be averaged to get the median.
//
// Updating the median costs O(log n), retrieving it costs O(1), and the space complexity is O(n)
// (ignoring the cost of heap expansion).
//
// Algorithmically, [quick select](https://en.wikipedia.org/wiki/Quickselect) might be better. But to
// use quick select here, we would need a mutable self in the final calculation (`evaluate`) to swap
// the stored numbers in the states vector. Though we could make our `evaluate` receive `&mut self`,
// DataFusion calls our accumulator with `&self` (see `DfAccumulatorAdaptor`). That means we would
// have to introduce some kind of interior mutability, and the overhead is not negligible.
//
// TODO(LFC): Use quick select to get median when we can modify DataFusion's code, and benchmark with two-heap algorithm.
#[derive(Debug, Default)]
pub struct Median<T>
where
T: Primitive,
{
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>,
}
impl<T> Median<T>
where
T: Primitive,
{
fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value);
if self.not_greater.is_empty() {
self.not_greater.push(value);
return;
}
// The `unwrap`s below are safe because there are `push`s before them.
if value <= *self.not_greater.peek().unwrap() {
self.not_greater.push(value);
if self.not_greater.len() > self.greater.len() + 1 {
self.greater.push(Reverse(self.not_greater.pop().unwrap()));
}
} else {
self.greater.push(Reverse(value));
if self.greater.len() > self.not_greater.len() {
self.not_greater.push(self.greater.pop().unwrap().0);
}
}
}
}
// UDAFs are built using the trait `Accumulator`, which offers DataFusion the necessary functions
// to use them.
impl<T> Accumulator for Median<T>
where
T: Primitive,
for<'a> T: Scalar<RefType<'a> = T>,
{
// This function serializes our state to `ScalarValue`, which DataFusion uses to pass this
// state between execution stages. Note that this can be arbitrary data.
//
// The `ScalarValue`s returned here will be passed in as argument `states: &[VectorRef]` to
// `merge_batch` function.
fn state(&self) -> Result<Vec<Value>> {
let nums = self
.greater
.iter()
.map(|x| &x.0)
.chain(self.not_greater.iter())
.map(|&n| n.into())
.collect::<Vec<Value>>();
Ok(vec![Value::List(ListValue::new(
Some(Box::new(nums)),
T::default().into().data_type(),
))])
}
// DataFusion calls this function to update the accumulator's state for a batch of input rows.
// This function is expected to update the accumulator's state accordingly.
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
ensure!(values.len() == 1, InvalidInputStateSnafu);
// This is a unary accumulator, so only one column is provided.
let column = &values[0];
let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) }
} else {
unsafe { VectorHelper::static_cast(column) }
};
(0..len).for_each(|_| {
for v in column.iter_data().flatten() {
self.push(v);
}
});
Ok(())
}
// DataFusion executes accumulators in partitions. At some execution stage, DataFusion will
// merge states from other accumulators (returned by the `state()` method).
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
// The states here are returned by the `state` method. Since we only returned a vector
// with one value in that method, `states[0]` is fine.
let states = &states[0];
let states = states
.as_any()
.downcast_ref::<ListVector>()
.with_context(|| DowncastVectorSnafu {
err_msg: format!(
"expect ListVector, got vector type {}",
states.vector_type_name()
),
})?;
for state in states.values_iter() {
let state = state.context(FromScalarValueSnafu)?;
// merging states simply accumulates the numbers stored by other accumulators, so just call update
self.update_batch(&[state])?
}
Ok(())
}
// DataFusion expects this function to return the final value of this aggregator.
fn evaluate(&self) -> Result<Value> {
if self.not_greater.is_empty() {
assert!(
self.greater.is_empty(),
"not expected in two-heap median algorithm, there must be a bug when implementing it"
);
return Ok(Value::Null);
}
// unwrap is safe because we checked not_greater heap's len above
let not_greater = *self.not_greater.peek().unwrap();
let median = if self.not_greater.len() > self.greater.len() {
not_greater.into()
} else {
// unwrap is safe because greater heap len >= not_greater heap len, which is > 0
let greater = self.greater.peek().unwrap();
// the following three NumCast's `unwrap`s are safe because T is primitive
let not_greater_v: f64 = NumCast::from(not_greater.as_primitive()).unwrap();
let greater_v: f64 = NumCast::from(greater.0.as_primitive()).unwrap();
let median: T = NumCast::from((not_greater_v + greater_v) / 2.0).unwrap();
median.into()
};
Ok(median)
}
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct MedianAccumulatorCreator {}
impl AggregateFunctionCreator for MedianAccumulatorCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let input_type = &types[0];
with_match_primitive_type_id!(
input_type.logical_type_id(),
|$S| {
Ok(Box::new(Median::<$S>::default()))
},
{
let err_msg = format!(
"\"MEDIAN\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
)
});
creator
}
fn output_type(&self) -> Result<ConcreteDataType> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 1, InvalidInputStateSnafu);
// unwrap is safe because we have checked that input_types' length equals 1
Ok(input_types.into_iter().next().unwrap())
}
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
Ok(vec![ConcreteDataType::list_datatype(self.output_type()?)])
}
}
#[cfg(test)]
mod test {
use datatypes::vectors::PrimitiveVector;
use super::*;
#[test]
fn test_update_batch() {
// test update empty batch, expect not updating anything
let mut median = Median::<i32>::default();
assert!(median.update_batch(&[]).is_ok());
assert!(median.not_greater.is_empty());
assert!(median.greater.is_empty());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update one not-null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(42), median.evaluate().unwrap());
// test update one null value
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Option::<i32>::None,
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Null, median.evaluate().unwrap());
// test update no null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-1i32),
Some(1),
Some(2),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(1), median.evaluate().unwrap());
// test update null-value batch
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(PrimitiveVector::<i32>::from(vec![
Some(-2i32),
None,
Some(3),
Some(4),
]))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(3), median.evaluate().unwrap());
// test update with constant vector
let mut median = Median::<i32>::default();
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])),
10,
))];
assert!(median.update_batch(&v).is_ok());
assert_eq!(Value::Int32(4), median.evaluate().unwrap());
}
}

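The removed median.rs above centres on the two-heap trick described in its comments: a max-heap for the lower half and a min-heap (via Reverse) for the upper half, kept within one element of each other. A compact sketch of just that balancing on plain integers (illustrative only, not the crate's API):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

#[derive(Default)]
struct TwoHeapMedian {
    not_greater: BinaryHeap<i64>,      // max-heap: the lower half
    greater: BinaryHeap<Reverse<i64>>, // min-heap: the upper half
}

impl TwoHeapMedian {
    fn push(&mut self, value: i64) {
        if self.not_greater.peek().map_or(true, |&top| value <= top) {
            self.not_greater.push(value);
            if self.not_greater.len() > self.greater.len() + 1 {
                self.greater.push(Reverse(self.not_greater.pop().unwrap()));
            }
        } else {
            self.greater.push(Reverse(value));
            if self.greater.len() > self.not_greater.len() {
                self.not_greater.push(self.greater.pop().unwrap().0);
            }
        }
    }

    fn median(&self) -> Option<f64> {
        let lower = *self.not_greater.peek()?;
        Some(if self.not_greater.len() > self.greater.len() {
            lower as f64
        } else {
            // even count: average the two middle values
            (lower as f64 + self.greater.peek().unwrap().0 as f64) / 2.0
        })
    }
}

fn main() {
    // Mirrors the "null-value batch" test above: the median of {-2, 3, 4} is 3.
    let mut m = TwoHeapMedian::default();
    for v in [-2, 3, 4] {
        m.push(v);
    }
    assert_eq!(m.median(), Some(3.0));
}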

@@ -26,7 +26,7 @@ use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::types::OrdPrimitive; use datatypes::types::OrdPrimitive;
use datatypes::value::{ListValue, OrderedFloat}; use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num::NumCast; use num::NumCast;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
@@ -44,15 +44,15 @@ use snafu::{ensure, OptionExt, ResultExt};
// This optional method parameter specifies the method to use when the desired quantile lies between two data points i < j. // This optional method parameter specifies the method to use when the desired quantile lies between two data points i < j.
// If g is the fractional part of the index surrounded by i and alpha and beta are correction constants modifying i and j. // If g is the fractional part of the index surrounded by i and alpha and beta are correction constants modifying i and j.
// i+g = (q-alpha)/(n-alpha-beta+1) // i+g = (q-alpha)/(n-alpha-beta+1)
// Below, q is the quantile value, n is the sample size and alpha and beta are constants. The following formula gives an interpolation i + g of where the quantile would be in the sorted sample. // Below, 'q' is the quantile value, 'n' is the sample size and alpha and beta are constants. The following formula gives an interpolation "i + g" of where the quantile would be in the sorted sample.
// With i being the floor and g the fractional part of the result. // With 'i' being the floor and 'g' the fractional part of the result.
// the default method is linear where // the default method is linear where
// alpha = 1 // alpha = 1
// beta = 1 // beta = 1
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Percentile<T> pub struct Percentile<T>
where where
T: Primitive, T: WrapperType,
{ {
greater: BinaryHeap<Reverse<OrdPrimitive<T>>>, greater: BinaryHeap<Reverse<OrdPrimitive<T>>>,
not_greater: BinaryHeap<OrdPrimitive<T>>, not_greater: BinaryHeap<OrdPrimitive<T>>,
@@ -62,7 +62,7 @@ where
impl<T> Percentile<T> impl<T> Percentile<T>
where where
T: Primitive, T: WrapperType,
{ {
fn push(&mut self, value: T) { fn push(&mut self, value: T) {
let value = OrdPrimitive::<T>(value); let value = OrdPrimitive::<T>(value);
@@ -93,8 +93,7 @@ where
impl<T> Accumulator for Percentile<T> impl<T> Accumulator for Percentile<T>
where where
T: Primitive, T: WrapperType,
for<'a> T: Scalar<RefType<'a> = T>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
let nums = self let nums = self
@@ -107,7 +106,7 @@ where
Ok(vec![ Ok(vec![
Value::List(ListValue::new( Value::List(ListValue::new(
Some(Box::new(nums)), Some(Box::new(nums)),
T::default().into().data_type(), T::LogicalType::build_data_type(),
)), )),
self.p.into(), self.p.into(),
]) ])
@@ -129,14 +128,14 @@ where
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
let x = &values[1]; let x = &values[1];
let x = VectorHelper::check_get_scalar::<f64>(x).context(error::InvalidInputsSnafu { let x = Helper::check_get_scalar::<f64>(x).context(error::InvalidInputTypeSnafu {
err_msg: "expecting \"POLYVAL\" function's second argument to be float64", err_msg: "expecting \"POLYVAL\" function's second argument to be float64",
})?; })?;
// `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0`
@@ -209,10 +208,11 @@ where
), ),
})?; })?;
for value in values.values_iter() { for value in values.values_iter() {
let value = value.context(FromScalarValueSnafu)?; if let Some(value) = value.context(FromScalarValueSnafu)? {
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&value) }; let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
self.push(v); self.push(v);
}
} }
} }
Ok(()) Ok(())
@@ -259,7 +259,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Percentile::<$S>::default())) Ok(Box::new(Percentile::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -292,7 +292,7 @@ impl AggregateFunctionCreator for PercentileAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::{Float64Vector, Int32Vector};
use super::*; use super::*;
#[test] #[test]
@@ -307,8 +307,8 @@ mod test {
// test update one not-null value // test update one not-null value
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![Some(42)])), Arc::new(Int32Vector::from(vec![Some(42)])),
Arc::new(PrimitiveVector::<f64>::from(vec![Some(100.0_f64)])), Arc::new(Float64Vector::from(vec![Some(100.0_f64)])),
]; ];
assert!(percentile.update_batch(&v).is_ok()); assert!(percentile.update_batch(&v).is_ok());
assert_eq!(Value::from(42.0_f64), percentile.evaluate().unwrap()); assert_eq!(Value::from(42.0_f64), percentile.evaluate().unwrap());
@@ -316,8 +316,8 @@ mod test {
// test update one null value // test update one null value
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![Option::<i32>::None])), Arc::new(Int32Vector::from(vec![Option::<i32>::None])),
Arc::new(PrimitiveVector::<f64>::from(vec![Some(100.0_f64)])), Arc::new(Float64Vector::from(vec![Some(100.0_f64)])),
]; ];
assert!(percentile.update_batch(&v).is_ok()); assert!(percentile.update_batch(&v).is_ok());
assert_eq!(Value::Null, percentile.evaluate().unwrap()); assert_eq!(Value::Null, percentile.evaluate().unwrap());
@@ -325,12 +325,8 @@ mod test {
// test update no null-value batch // test update no null-value batch
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
@@ -342,13 +338,8 @@ mod test {
// test update null-value batch // test update null-value batch
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])),
Some(-2i32), Arc::new(Float64Vector::from(vec![
None,
Some(3),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
@@ -362,13 +353,10 @@ mod test {
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
2, 2,
)), )),
Arc::new(PrimitiveVector::<f64>::from(vec![ Arc::new(Float64Vector::from(vec![Some(100.0_f64), Some(100.0_f64)])),
Some(100.0_f64),
Some(100.0_f64),
])),
]; ];
assert!(percentile.update_batch(&v).is_ok()); assert!(percentile.update_batch(&v).is_ok());
assert_eq!(Value::from(4_f64), percentile.evaluate().unwrap()); assert_eq!(Value::from(4_f64), percentile.evaluate().unwrap());
@@ -376,12 +364,8 @@ mod test {
// test left border // test left border
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(0.0_f64), Some(0.0_f64),
Some(0.0_f64), Some(0.0_f64),
Some(0.0_f64), Some(0.0_f64),
@@ -393,12 +377,8 @@ mod test {
// test medium // test medium
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(50.0_f64), Some(50.0_f64),
Some(50.0_f64), Some(50.0_f64),
Some(50.0_f64), Some(50.0_f64),
@@ -410,12 +390,8 @@ mod test {
// test right border // test right border
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
Some(100.0_f64), Some(100.0_f64),
@@ -431,12 +407,8 @@ mod test {
// >> 6.400000000000 // >> 6.400000000000
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(10i32), Some(7), Some(4)])),
Some(10i32), Arc::new(Float64Vector::from(vec![
Some(7),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(40.0_f64), Some(40.0_f64),
Some(40.0_f64), Some(40.0_f64),
Some(40.0_f64), Some(40.0_f64),
@@ -451,12 +423,8 @@ mod test {
// >> 9.7000000000000011 // >> 9.7000000000000011
let mut percentile = Percentile::<i32>::default(); let mut percentile = Percentile::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(10i32), Some(7), Some(4)])),
Some(10i32), Arc::new(Float64Vector::from(vec![
Some(7),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(95.0_f64), Some(95.0_f64),
Some(95.0_f64), Some(95.0_f64),
Some(95.0_f64), Some(95.0_f64),

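The interpolation described in the percentile comments is the usual "linear" method: with alpha = beta = 1 the virtual index is p/100 * (n - 1), split into an integer part i and a fractional part g. A small sketch on a plain slice reproduces the 6.4 and 9.7 expectations quoted in the tests (an illustrative helper, not the crate's API):

/// Linear-interpolation percentile (alpha = beta = 1), the numpy default.
fn percentile(values: &[f64], p: f64) -> Option<f64> {
    if values.is_empty() {
        return None;
    }
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let rank = p / 100.0 * (sorted.len() - 1) as f64; // the virtual index i + g
    let i = rank.floor() as usize;
    let g = rank - rank.floor();
    if i + 1 < sorted.len() {
        Some(sorted[i] + g * (sorted[i + 1] - sorted[i]))
    } else {
        Some(sorted[i])
    }
}

fn main() {
    // Same data as the tests: [10, 7, 4] at p = 40 gives 6.4, at p = 95 gives ~9.7.
    let v = [10.0, 7.0, 4.0];
    assert!((percentile(&v, 40.0).unwrap() - 6.4).abs() < 1e-9);
    assert!((percentile(&v, 95.0).unwrap() - 9.7).abs() < 1e-9);
}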

@@ -23,9 +23,9 @@ use common_query::error::{
use common_query::logical_plan::{Accumulator, AggregateFunctionCreator}; use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::types::PrimitiveType; use datatypes::types::{LogicalPrimitiveType, WrapperType};
use datatypes::value::ListValue; use datatypes::value::ListValue;
use datatypes::vectors::{ConstantVector, Int64Vector, ListVector}; use datatypes::vectors::{ConstantVector, Helper, Int64Vector, ListVector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
@@ -34,8 +34,10 @@ use snafu::{ensure, OptionExt, ResultExt};
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct Polyval<T, PolyT> pub struct Polyval<T, PolyT>
where where
T: Primitive + AsPrimitive<PolyT>, T: WrapperType,
PolyT: Primitive + std::ops::Mul<Output = PolyT>, T::Native: AsPrimitive<PolyT::Native>,
PolyT: WrapperType,
PolyT::Native: std::ops::Mul<Output = PolyT::Native>,
{ {
values: Vec<T>, values: Vec<T>,
// DataFusion casts constants into the i64 type. // DataFusion casts constants into the i64 type.
@@ -45,8 +47,10 @@ where
impl<T, PolyT> Polyval<T, PolyT> impl<T, PolyT> Polyval<T, PolyT>
where where
T: Primitive + AsPrimitive<PolyT>, T: WrapperType,
PolyT: Primitive + std::ops::Mul<Output = PolyT>, T::Native: AsPrimitive<PolyT::Native>,
PolyT: WrapperType,
PolyT::Native: std::ops::Mul<Output = PolyT::Native>,
{ {
fn push(&mut self, value: T) { fn push(&mut self, value: T) {
self.values.push(value); self.values.push(value);
@@ -55,11 +59,11 @@ where
impl<T, PolyT> Accumulator for Polyval<T, PolyT> impl<T, PolyT> Accumulator for Polyval<T, PolyT>
where where
T: Primitive + AsPrimitive<PolyT>, T: WrapperType,
PolyT: Primitive + std::ops::Mul<Output = PolyT> + std::iter::Sum<PolyT>, T::Native: AsPrimitive<PolyT::Native>,
for<'a> T: Scalar<RefType<'a> = T>, PolyT: WrapperType + std::iter::Sum<<PolyT as WrapperType>::Native>,
for<'a> PolyT: Scalar<RefType<'a> = PolyT>, PolyT::Native: std::ops::Mul<Output = PolyT::Native> + std::iter::Sum<PolyT::Native>,
i64: AsPrimitive<PolyT>, i64: AsPrimitive<<PolyT as WrapperType>::Native>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
let nums = self let nums = self
@@ -70,7 +74,7 @@ where
Ok(vec![ Ok(vec![
Value::List(ListValue::new( Value::List(ListValue::new(
Some(Box::new(nums)), Some(Box::new(nums)),
T::default().into().data_type(), T::LogicalType::build_data_type(),
)), )),
self.x.into(), self.x.into(),
]) ])
@@ -91,10 +95,10 @@ where
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
(0..len).for_each(|_| { (0..len).for_each(|_| {
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
@@ -103,7 +107,7 @@ where
}); });
let x = &values[1]; let x = &values[1];
let x = VectorHelper::check_get_scalar::<i64>(x).context(error::InvalidInputsSnafu { let x = Helper::check_get_scalar::<i64>(x).context(error::InvalidInputTypeSnafu {
err_msg: "expecting \"POLYVAL\" function's second argument to be a positive integer", err_msg: "expecting \"POLYVAL\" function's second argument to be a positive integer",
})?; })?;
// `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0` // `get(0)` is safe because we have checked `values[1].len() == values[0].len() != 0`
@@ -172,12 +176,14 @@ where
), ),
})?; })?;
for value in values.values_iter() { for value in values.values_iter() {
let value = value.context(FromScalarValueSnafu)?; if let Some(value) = value.context(FromScalarValueSnafu)? {
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&value) }; let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
self.push(v); self.push(v);
}
} }
} }
Ok(()) Ok(())
} }
@@ -196,7 +202,7 @@ where
.values .values
.iter() .iter()
.enumerate() .enumerate()
.map(|(i, &value)| value.as_() * (x.pow((len - 1 - i) as u32)).as_()) .map(|(i, &value)| value.into_native().as_() * x.pow((len - 1 - i) as u32).as_())
.sum(); .sum();
Ok(polyval.into()) Ok(polyval.into())
} }
@@ -213,7 +219,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(Polyval::<$S,<$S as Primitive>::LargestType>::default())) Ok(Box::new(Polyval::<<$S as LogicalPrimitiveType>::Wrapper, <<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -234,7 +240,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type, input_type,
|$S| { |$S| {
Ok(PrimitiveType::<<$S as Primitive>::LargestType>::default().into()) Ok(<<$S as LogicalPrimitiveType>::LargestType as LogicalPrimitiveType>::build_data_type())
}, },
{ {
unreachable!() unreachable!()
@@ -254,7 +260,7 @@ impl AggregateFunctionCreator for PolyvalAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::Int32Vector;
use super::*; use super::*;
#[test] #[test]
@@ -268,8 +274,8 @@ mod test {
// test update one not-null value // test update one not-null value
let mut polyval = Polyval::<i32, i64>::default(); let mut polyval = Polyval::<i32, i64>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![Some(3)])), Arc::new(Int32Vector::from(vec![Some(3)])),
Arc::new(PrimitiveVector::<i64>::from(vec![Some(2_i64)])), Arc::new(Int64Vector::from(vec![Some(2_i64)])),
]; ];
assert!(polyval.update_batch(&v).is_ok()); assert!(polyval.update_batch(&v).is_ok());
assert_eq!(Value::Int64(3), polyval.evaluate().unwrap()); assert_eq!(Value::Int64(3), polyval.evaluate().unwrap());
@@ -277,8 +283,8 @@ mod test {
// test update one null value // test update one null value
let mut polyval = Polyval::<i32, i64>::default(); let mut polyval = Polyval::<i32, i64>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![Option::<i32>::None])), Arc::new(Int32Vector::from(vec![Option::<i32>::None])),
Arc::new(PrimitiveVector::<i64>::from(vec![Some(2_i64)])), Arc::new(Int64Vector::from(vec![Some(2_i64)])),
]; ];
assert!(polyval.update_batch(&v).is_ok()); assert!(polyval.update_batch(&v).is_ok());
assert_eq!(Value::Null, polyval.evaluate().unwrap()); assert_eq!(Value::Null, polyval.evaluate().unwrap());
@@ -286,12 +292,8 @@ mod test {
// test update no null-value batch // test update no null-value batch
let mut polyval = Polyval::<i32, i64>::default(); let mut polyval = Polyval::<i32, i64>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(3), Some(0), Some(1)])),
Some(3), Arc::new(Int64Vector::from(vec![
Some(0),
Some(1),
])),
Arc::new(PrimitiveVector::<i64>::from(vec![
Some(2_i64), Some(2_i64),
Some(2_i64), Some(2_i64),
Some(2_i64), Some(2_i64),
@@ -303,13 +305,8 @@ mod test {
// test update null-value batch // test update null-value batch
let mut polyval = Polyval::<i32, i64>::default(); let mut polyval = Polyval::<i32, i64>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(3), Some(0), None, Some(1)])),
Some(3), Arc::new(Int64Vector::from(vec![
Some(0),
None,
Some(1),
])),
Arc::new(PrimitiveVector::<i64>::from(vec![
Some(2_i64), Some(2_i64),
Some(2_i64), Some(2_i64),
Some(2_i64), Some(2_i64),
@@ -323,10 +320,10 @@ mod test {
let mut polyval = Polyval::<i32, i64>::default(); let mut polyval = Polyval::<i32, i64>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(PrimitiveVector::<i32>::from_vec(vec![4])), Arc::new(Int32Vector::from_vec(vec![4])),
2, 2,
)), )),
Arc::new(PrimitiveVector::<i64>::from(vec![Some(5_i64), Some(5_i64)])), Arc::new(Int64Vector::from(vec![Some(5_i64), Some(5_i64)])),
]; ];
assert!(polyval.update_batch(&v).is_ok()); assert!(polyval.update_batch(&v).is_ok());
assert_eq!(Value::Int64(24), polyval.evaluate().unwrap()); assert_eq!(Value::Int64(24), polyval.evaluate().unwrap());

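polyval treats the first input column as polynomial coefficients from the highest degree down and the constant second argument as x, so evaluate computes the sum of c_i * x^(len - 1 - i). The same arithmetic on a plain slice, mirroring the tests above (an illustrative helper; Horner's rule would work equally well):

/// Evaluate a polynomial whose coefficients are given from highest degree down.
fn polyval(coefficients: &[i64], x: i64) -> i64 {
    let len = coefficients.len();
    coefficients
        .iter()
        .enumerate()
        .map(|(i, &c)| c * x.pow((len - 1 - i) as u32))
        .sum()
}

fn main() {
    // Mirrors the constant-vector test: coefficients [4, 4] at x = 5 give 4*5 + 4 = 24.
    assert_eq!(polyval(&[4, 4], 5), 24);
    // And the single-value test: [3] at x = 2 is just 3.
    assert_eq!(polyval(&[3], 2), 3);
}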

@@ -23,7 +23,7 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::value::{ListValue, OrderedFloat}; use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
@@ -33,18 +33,12 @@ use statrs::statistics::Statistics;
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct ScipyStatsNormCdf<T> pub struct ScipyStatsNormCdf<T> {
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
values: Vec<T>, values: Vec<T>,
x: Option<f64>, x: Option<f64>,
} }
impl<T> ScipyStatsNormCdf<T> impl<T> ScipyStatsNormCdf<T> {
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
fn push(&mut self, value: T) { fn push(&mut self, value: T) {
self.values.push(value); self.values.push(value);
} }
@@ -52,8 +46,8 @@ where
impl<T> Accumulator for ScipyStatsNormCdf<T> impl<T> Accumulator for ScipyStatsNormCdf<T>
where where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>, T: WrapperType + std::iter::Sum<T>,
for<'a> T: Scalar<RefType<'a> = T>, T::Native: AsPrimitive<f64>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
let nums = self let nums = self
@@ -64,7 +58,7 @@ where
Ok(vec![ Ok(vec![
Value::List(ListValue::new( Value::List(ListValue::new(
Some(Box::new(nums)), Some(Box::new(nums)),
T::default().into().data_type(), T::LogicalType::build_data_type(),
)), )),
self.x.into(), self.x.into(),
]) ])
@@ -86,14 +80,14 @@ where
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
let x = &values[1]; let x = &values[1];
let x = VectorHelper::check_get_scalar::<f64>(x).context(error::InvalidInputsSnafu { let x = Helper::check_get_scalar::<f64>(x).context(error::InvalidInputTypeSnafu {
err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer", err_msg: "expecting \"SCIPYSTATSNORMCDF\" function's second argument to be a positive integer",
})?; })?;
let first = x.get(0); let first = x.get(0);
@@ -160,19 +154,19 @@ where
), ),
})?; })?;
for value in values.values_iter() { for value in values.values_iter() {
let value = value.context(FromScalarValueSnafu)?; if let Some(value) = value.context(FromScalarValueSnafu)? {
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&value) }; let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
self.push(v); self.push(v);
}
} }
} }
Ok(()) Ok(())
} }
fn evaluate(&self) -> Result<Value> { fn evaluate(&self) -> Result<Value> {
let values = self.values.iter().map(|&v| v.as_()).collect::<Vec<_>>(); let mean = self.values.iter().map(|v| v.into_native().as_()).mean();
let mean = values.clone().mean(); let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev();
let std_dev = values.std_dev();
if mean.is_nan() || std_dev.is_nan() { if mean.is_nan() || std_dev.is_nan() {
Ok(Value::Null) Ok(Value::Null)
} else { } else {
@@ -198,7 +192,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(ScipyStatsNormCdf::<$S>::default())) Ok(Box::new(ScipyStatsNormCdf::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -230,7 +224,7 @@ impl AggregateFunctionCreator for ScipyStatsNormCdfAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::{Float64Vector, Int32Vector};
use super::*; use super::*;
#[test] #[test]
@@ -244,12 +238,8 @@ mod test {
// test update no null-value batch // test update no null-value batch
let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default(); let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(2.0_f64), Some(2.0_f64),
Some(2.0_f64), Some(2.0_f64),
Some(2.0_f64), Some(2.0_f64),
@@ -264,13 +254,8 @@ mod test {
// test update null-value batch // test update null-value batch
let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default(); let mut scipy_stats_norm_cdf = ScipyStatsNormCdf::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])),
Some(-2i32), Arc::new(Float64Vector::from(vec![
None,
Some(3),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(2.0_f64), Some(2.0_f64),
None, None,
Some(2.0_f64), Some(2.0_f64),

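Once the accumulator has the sample mean and standard deviation, the CDF itself comes from statrs' normal distribution. A hedged sketch of that final step, assuming a statrs version where Normal::new and the ContinuousCDF trait are available (the norm_cdf helper name is illustrative):

use statrs::distribution::{ContinuousCDF, Normal}; // ContinuousCDF assumed from a recent statrs
use statrs::statistics::Statistics;

/// Fit a normal distribution to `values` and evaluate its CDF at `x`.
/// Returns None for an empty or degenerate sample (NaN mean / std dev).
fn norm_cdf(values: &[f64], x: f64) -> Option<f64> {
    let mean = values.iter().copied().mean();
    let std_dev = values.iter().copied().std_dev();
    if mean.is_nan() || std_dev.is_nan() {
        return None;
    }
    Normal::new(mean, std_dev).ok().map(|dist| dist.cdf(x))
}

fn main() {
    // With samples [-1.0, 1.0, 2.0], most of the fitted mass lies below 2.0.
    let p = norm_cdf(&[-1.0, 1.0, 2.0], 2.0).unwrap();
    assert!(p > 0.5 && p < 1.0);
}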

@@ -23,7 +23,7 @@ use common_query::logical_plan::{Accumulator, AggregateFunctionCreator};
use common_query::prelude::*; use common_query::prelude::*;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::value::{ListValue, OrderedFloat}; use datatypes::value::{ListValue, OrderedFloat};
use datatypes::vectors::{ConstantVector, Float64Vector, ListVector}; use datatypes::vectors::{ConstantVector, Float64Vector, Helper, ListVector};
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
@@ -33,18 +33,12 @@ use statrs::statistics::Statistics;
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html // https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.norm.html
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct ScipyStatsNormPdf<T> pub struct ScipyStatsNormPdf<T> {
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
values: Vec<T>, values: Vec<T>,
x: Option<f64>, x: Option<f64>,
} }
impl<T> ScipyStatsNormPdf<T> impl<T> ScipyStatsNormPdf<T> {
where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>,
{
fn push(&mut self, value: T) { fn push(&mut self, value: T) {
self.values.push(value); self.values.push(value);
} }
@@ -52,8 +46,8 @@ where
impl<T> Accumulator for ScipyStatsNormPdf<T> impl<T> Accumulator for ScipyStatsNormPdf<T>
where where
T: Primitive + AsPrimitive<f64> + std::iter::Sum<T>, T: WrapperType,
for<'a> T: Scalar<RefType<'a> = T>, T::Native: AsPrimitive<f64> + std::iter::Sum<T>,
{ {
fn state(&self) -> Result<Vec<Value>> { fn state(&self) -> Result<Vec<Value>> {
let nums = self let nums = self
@@ -64,7 +58,7 @@ where
Ok(vec![ Ok(vec![
Value::List(ListValue::new( Value::List(ListValue::new(
Some(Box::new(nums)), Some(Box::new(nums)),
T::default().into().data_type(), T::LogicalType::build_data_type(),
)), )),
self.x.into(), self.x.into(),
]) ])
@@ -86,14 +80,14 @@ where
let mut len = 1; let mut len = 1;
let column: &<T as Scalar>::VectorType = if column.is_const() { let column: &<T as Scalar>::VectorType = if column.is_const() {
len = column.len(); len = column.len();
let column: &ConstantVector = unsafe { VectorHelper::static_cast(column) }; let column: &ConstantVector = unsafe { Helper::static_cast(column) };
unsafe { VectorHelper::static_cast(column.inner()) } unsafe { Helper::static_cast(column.inner()) }
} else { } else {
unsafe { VectorHelper::static_cast(column) } unsafe { Helper::static_cast(column) }
}; };
let x = &values[1]; let x = &values[1];
let x = VectorHelper::check_get_scalar::<f64>(x).context(error::InvalidInputsSnafu { let x = Helper::check_get_scalar::<f64>(x).context(error::InvalidInputTypeSnafu {
err_msg: "expecting \"SCIPYSTATSNORMPDF\" function's second argument to be a positive integer", err_msg: "expecting \"SCIPYSTATSNORMPDF\" function's second argument to be a positive integer",
})?; })?;
let first = x.get(0); let first = x.get(0);
@@ -160,19 +154,20 @@ where
), ),
})?; })?;
for value in values.values_iter() { for value in values.values_iter() {
let value = value.context(FromScalarValueSnafu)?; if let Some(value) = value.context(FromScalarValueSnafu)? {
let column: &<T as Scalar>::VectorType = unsafe { VectorHelper::static_cast(&value) }; let column: &<T as Scalar>::VectorType = unsafe { Helper::static_cast(&value) };
for v in column.iter_data().flatten() { for v in column.iter_data().flatten() {
self.push(v); self.push(v);
}
} }
} }
Ok(()) Ok(())
} }
fn evaluate(&self) -> Result<Value> { fn evaluate(&self) -> Result<Value> {
let values = self.values.iter().map(|&v| v.as_()).collect::<Vec<_>>(); let mean = self.values.iter().map(|v| v.into_native().as_()).mean();
let mean = values.clone().mean(); let std_dev = self.values.iter().map(|v| v.into_native().as_()).std_dev();
let std_dev = values.std_dev();
if mean.is_nan() || std_dev.is_nan() { if mean.is_nan() || std_dev.is_nan() {
Ok(Value::Null) Ok(Value::Null)
} else { } else {
@@ -198,7 +193,7 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
with_match_primitive_type_id!( with_match_primitive_type_id!(
input_type.logical_type_id(), input_type.logical_type_id(),
|$S| { |$S| {
Ok(Box::new(ScipyStatsNormPdf::<$S>::default())) Ok(Box::new(ScipyStatsNormPdf::<<$S as LogicalPrimitiveType>::Wrapper>::default()))
}, },
{ {
let err_msg = format!( let err_msg = format!(
@@ -230,7 +225,7 @@ impl AggregateFunctionCreator for ScipyStatsNormPdfAccumulatorCreator {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use datatypes::vectors::PrimitiveVector; use datatypes::vectors::{Float64Vector, Int32Vector};
use super::*; use super::*;
#[test] #[test]
@@ -244,12 +239,8 @@ mod test {
// test update no null-value batch // test update no null-value batch
let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::<i32>::default(); let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-1i32), Some(1), Some(2)])),
Some(-1i32), Arc::new(Float64Vector::from(vec![
Some(1),
Some(2),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(2.0_f64), Some(2.0_f64),
Some(2.0_f64), Some(2.0_f64),
Some(2.0_f64), Some(2.0_f64),
@@ -264,13 +255,8 @@ mod test {
// test update null-value batch // test update null-value batch
let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::<i32>::default(); let mut scipy_stats_norm_pdf = ScipyStatsNormPdf::<i32>::default();
let v: Vec<VectorRef> = vec![ let v: Vec<VectorRef> = vec![
Arc::new(PrimitiveVector::<i32>::from(vec![ Arc::new(Int32Vector::from(vec![Some(-2i32), None, Some(3), Some(4)])),
Some(-2i32), Arc::new(Float64Vector::from(vec![
None,
Some(3),
Some(4),
])),
Arc::new(PrimitiveVector::<f64>::from(vec![
Some(2.0_f64), Some(2.0_f64),
None, None,
Some(2.0_f64), Some(2.0_f64),

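The PDF variant differs from the CDF one only in its last step: after fitting mean mu and standard deviation sigma it evaluates the normal density at x. A minimal sketch using the closed form pdf(x) = exp(-(x - mu)^2 / (2 sigma^2)) / (sigma * sqrt(2 pi)) rather than statrs (illustrative only):

use std::f64::consts::PI;

/// Normal probability density at `x` for the given mean and standard deviation.
fn norm_pdf(mean: f64, std_dev: f64, x: f64) -> f64 {
    let z = (x - mean) / std_dev;
    (-0.5 * z * z).exp() / (std_dev * (2.0 * PI).sqrt())
}

fn main() {
    // The standard normal density at 0 is 1/sqrt(2*pi) ~= 0.3989.
    assert!((norm_pdf(0.0, 1.0, 0.0) - 0.39894228).abs() < 1e-6);
}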

@@ -14,10 +14,10 @@
use std::iter; use std::iter;
use common_query::error::Result;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::vectors::ConstantVector; use datatypes::vectors::{ConstantVector, Helper};
use crate::error::Result;
use crate::scalars::expression::ctx::EvalContext; use crate::scalars::expression::ctx::EvalContext;
pub fn scalar_binary_op<L: Scalar, R: Scalar, O: Scalar, F>( pub fn scalar_binary_op<L: Scalar, R: Scalar, O: Scalar, F>(
@@ -36,10 +36,9 @@ where
let result = match (l.is_const(), r.is_const()) { let result = match (l.is_const(), r.is_const()) {
(false, true) => { (false, true) => {
let left: &<L as Scalar>::VectorType = unsafe { VectorHelper::static_cast(l) }; let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; let right: &ConstantVector = unsafe { Helper::static_cast(r) };
let right: &<R as Scalar>::VectorType = let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
unsafe { VectorHelper::static_cast(right.inner()) };
let b = right.get_data(0); let b = right.get_data(0);
let it = left.iter_data().map(|a| f(a, b, ctx)); let it = left.iter_data().map(|a| f(a, b, ctx));
@@ -47,8 +46,8 @@ where
} }
(false, false) => { (false, false) => {
let left: &<L as Scalar>::VectorType = unsafe { VectorHelper::static_cast(l) }; let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(l) };
let right: &<R as Scalar>::VectorType = unsafe { VectorHelper::static_cast(r) }; let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
let it = left let it = left
.iter_data() .iter_data()
@@ -58,25 +57,22 @@ where
} }
(true, false) => { (true, false) => {
let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; let left: &ConstantVector = unsafe { Helper::static_cast(l) };
let left: &<L as Scalar>::VectorType = let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
unsafe { VectorHelper::static_cast(left.inner()) };
let a = left.get_data(0); let a = left.get_data(0);
let right: &<R as Scalar>::VectorType = unsafe { VectorHelper::static_cast(r) }; let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(r) };
let it = right.iter_data().map(|b| f(a, b, ctx)); let it = right.iter_data().map(|b| f(a, b, ctx));
<O as Scalar>::VectorType::from_owned_iterator(it) <O as Scalar>::VectorType::from_owned_iterator(it)
} }
(true, true) => { (true, true) => {
let left: &ConstantVector = unsafe { VectorHelper::static_cast(l) }; let left: &ConstantVector = unsafe { Helper::static_cast(l) };
let left: &<L as Scalar>::VectorType = let left: &<L as Scalar>::VectorType = unsafe { Helper::static_cast(left.inner()) };
unsafe { VectorHelper::static_cast(left.inner()) };
let a = left.get_data(0); let a = left.get_data(0);
let right: &ConstantVector = unsafe { VectorHelper::static_cast(r) }; let right: &ConstantVector = unsafe { Helper::static_cast(r) };
let right: &<R as Scalar>::VectorType = let right: &<R as Scalar>::VectorType = unsafe { Helper::static_cast(right.inner()) };
unsafe { VectorHelper::static_cast(right.inner()) };
let b = right.get_data(0); let b = right.get_data(0);
let it = iter::repeat(a) let it = iter::repeat(a)

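scalar_binary_op's four arms all reduce to the same idea: a ConstantVector contributes a single value that is broadcast against the other side's length before the element-wise function is applied. A plain sketch of that dispatch on Option slices, without the unsafe static_cast plumbing (names are illustrative):

/// Apply `f` element-wise, broadcasting whichever side is a single constant.
fn binary_broadcast<F>(left: &[Option<f64>], right: &[Option<f64>], f: F) -> Vec<Option<f64>>
where
    F: Fn(Option<f64>, Option<f64>) -> Option<f64>,
{
    match (left.len() == 1, right.len() == 1) {
        (false, true) => left.iter().map(|&a| f(a, right[0])).collect(),
        (true, false) => right.iter().map(|&b| f(left[0], b)).collect(),
        // both plain, or both constant: walk them in lockstep
        _ => left.iter().zip(right).map(|(&a, &b)| f(a, b)).collect(),
    }
}

fn main() {
    let sum = binary_broadcast(&[Some(1.0), None, Some(3.0)], &[Some(10.0)], |a, b| {
        Some(a? + b?)
    });
    assert_eq!(sum, vec![Some(11.0), None, Some(13.0)]);
}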

@@ -13,8 +13,7 @@
// limitations under the License. // limitations under the License.
use chrono_tz::Tz; use chrono_tz::Tz;
use common_query::error::Error;
use crate::error::Error;
pub struct EvalContext { pub struct EvalContext {
_tz: Tz, _tz: Tz,


@@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use common_query::error::{self, Result};
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::vectors::Helper;
use snafu::ResultExt; use snafu::ResultExt;
use crate::error::{GetScalarVectorSnafu, Result};
use crate::scalars::expression::ctx::EvalContext; use crate::scalars::expression::ctx::EvalContext;
/// TODO: remove the allow_unused when it's used. /// TODO: remove the allow_unused when it's used.
@@ -28,7 +29,7 @@ pub fn scalar_unary_op<L: Scalar, O: Scalar, F>(
where where
F: Fn(Option<L::RefType<'_>>, &mut EvalContext) -> Option<O>, F: Fn(Option<L::RefType<'_>>, &mut EvalContext) -> Option<O>,
{ {
let left = VectorHelper::check_get_scalar::<L>(l).context(GetScalarVectorSnafu)?; let left = Helper::check_get_scalar::<L>(l).context(error::GetScalarVectorSnafu)?;
let it = left.iter_data().map(|a| f(a, ctx)); let it = left.iter_data().map(|a| f(a, ctx));
let result = <O as Scalar>::VectorType::from_owned_iterator(it); let result = <O as Scalar>::VectorType::from_owned_iterator(it);


@@ -16,12 +16,11 @@ use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use chrono_tz::Tz; use chrono_tz::Tz;
use common_query::error::Result;
use common_query::prelude::Signature; use common_query::prelude::Signature;
use datatypes::data_type::ConcreteDataType; use datatypes::data_type::ConcreteDataType;
use datatypes::vectors::VectorRef; use datatypes::vectors::VectorRef;
use crate::error::Result;
#[derive(Clone)] #[derive(Clone)]
pub struct FunctionContext { pub struct FunctionContext {
pub tz: Tz, pub tz: Tz,


@@ -13,10 +13,12 @@
// limitations under the License. // limitations under the License.
mod pow; mod pow;
mod rate;
use std::sync::Arc; use std::sync::Arc;
pub use pow::PowFunction; pub use pow::PowFunction;
pub use rate::RateFunction;
use crate::scalars::function_registry::FunctionRegistry; use crate::scalars::function_registry::FunctionRegistry;
@@ -25,5 +27,6 @@ pub(crate) struct MathFunction;
impl MathFunction { impl MathFunction {
pub fn register(registry: &FunctionRegistry) { pub fn register(registry: &FunctionRegistry) {
registry.register(Arc::new(PowFunction::default())); registry.register(Arc::new(PowFunction::default()));
registry.register(Arc::new(RateFunction::default()))
} }
} }


@@ -15,15 +15,16 @@
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use common_query::error::Result;
use common_query::prelude::{Signature, Volatility}; use common_query::prelude::{Signature, Volatility};
use datatypes::data_type::DataType; use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType; use datatypes::prelude::ConcreteDataType;
use datatypes::types::LogicalPrimitiveType;
use datatypes::vectors::VectorRef; use datatypes::vectors::VectorRef;
use datatypes::with_match_primitive_type_id; use datatypes::with_match_primitive_type_id;
use num::traits::Pow; use num::traits::Pow;
use num_traits::AsPrimitive; use num_traits::AsPrimitive;
use crate::error::Result;
use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::expression::{scalar_binary_op, EvalContext};
use crate::scalars::function::{Function, FunctionContext}; use crate::scalars::function::{Function, FunctionContext};
@@ -46,7 +47,7 @@ impl Function for PowFunction {
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> { fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| {
let col = scalar_binary_op::<$S, $T, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?; let col = scalar_binary_op::<<$S as LogicalPrimitiveType>::Native, <$T as LogicalPrimitiveType>::Native, f64, _>(&columns[0], &columns[1], scalar_pow, &mut EvalContext::default())?;
Ok(Arc::new(col)) Ok(Arc::new(col))
},{ },{
unreachable!() unreachable!()

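Behind the double with_match_primitive_type_id dispatch, pow itself just casts both operands to f64 and raises one to the other. A sketch of that per-element kernel, using num-traits' AsPrimitive the way the function generics do (the scalar_pow helper here is illustrative, not the crate's exact signature):

use num_traits::AsPrimitive;

/// Per-element pow kernel: cast both sides to f64, propagate null inputs.
fn scalar_pow<A, B>(base: Option<A>, exp: Option<B>) -> Option<f64>
where
    A: AsPrimitive<f64>,
    B: AsPrimitive<f64>,
{
    Some(base?.as_().powf(exp?.as_()))
}

fn main() {
    assert_eq!(scalar_pow(Some(2u32), Some(10i8)), Some(1024.0));
    assert_eq!(scalar_pow::<u32, i8>(None, Some(3)), None);
}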

@@ -0,0 +1,106 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::fmt;
use common_query::error::{self, Result};
use common_query::prelude::{Signature, Volatility};
use datatypes::arrow::compute::kernels::{arithmetic, cast};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::*;
use datatypes::vectors::{Helper, VectorRef};
use snafu::ResultExt;
use crate::scalars::function::{Function, FunctionContext};
/// Generates rates from a sequence of adjacent data points.
#[derive(Clone, Debug, Default)]
pub struct RateFunction;
impl fmt::Display for RateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "RATE")
}
}
impl Function for RateFunction {
fn name(&self) -> &str {
"prom_rate"
}
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::float64_datatype())
}
fn signature(&self) -> Signature {
Signature::uniform(2, ConcreteDataType::numerics(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
let val = &columns[0].to_arrow_array();
let val_0 = val.slice(0, val.len() - 1);
let val_1 = val.slice(1, val.len() - 1);
let dv = arithmetic::subtract_dyn(&val_1, &val_0).context(error::ArrowComputeSnafu)?;
let ts = &columns[1].to_arrow_array();
let ts_0 = ts.slice(0, ts.len() - 1);
let ts_1 = ts.slice(1, ts.len() - 1);
let dt = arithmetic::subtract_dyn(&ts_1, &ts_0).context(error::ArrowComputeSnafu)?;
let dv = cast::cast(&dv, &DataType::Float64).context(error::TypeCastSnafu {
typ: DataType::Float64,
})?;
let dt = cast::cast(&dt, &DataType::Float64).context(error::TypeCastSnafu {
typ: DataType::Float64,
})?;
let rate = arithmetic::divide_dyn(&dv, &dt).context(error::ArrowComputeSnafu)?;
let v = Helper::try_into_vector(&rate).context(error::FromArrowArraySnafu)?;
Ok(v)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{Float32Vector, Float64Vector, Int64Vector};
use super::*;
#[test]
fn test_rate_function() {
let rate = RateFunction::default();
assert_eq!("prom_rate", rate.name());
assert_eq!(
ConcreteDataType::float64_datatype(),
rate.return_type(&[]).unwrap()
);
assert!(matches!(rate.signature(),
Signature {
type_signature: TypeSignature::Uniform(2, valid_types),
volatility: Volatility::Immutable
} if valid_types == ConcreteDataType::numerics()
));
let values = vec![1.0, 3.0, 6.0];
let ts = vec![0, 1, 2];
let args: Vec<VectorRef> = vec![
Arc::new(Float32Vector::from_vec(values)),
Arc::new(Int64Vector::from_vec(ts)),
];
let vector = rate.eval(FunctionContext::default(), &args).unwrap();
let expect: VectorRef = Arc::new(Float64Vector::from_vec(vec![2.0, 3.0]));
assert_eq!(expect, vector);
}
}

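The new prom_rate function computes element-wise differences of adjacent values and timestamps and divides them, so n input points yield n - 1 rates. The same arithmetic on plain slices, without the Arrow subtract/divide kernels (an illustrative helper):

/// Rate between adjacent samples: (v[i+1] - v[i]) / (t[i+1] - t[i]).
fn rate(values: &[f64], timestamps: &[i64]) -> Vec<f64> {
    values
        .windows(2)
        .zip(timestamps.windows(2))
        .map(|(v, t)| (v[1] - v[0]) / (t[1] - t[0]) as f64)
        .collect()
}

fn main() {
    // Mirrors the unit test above: values [1, 3, 6] at ts [0, 1, 2] give rates [2, 3].
    assert_eq!(rate(&[1.0, 3.0, 6.0], &[0, 1, 2]), vec![2.0, 3.0]);
}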

@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
mod clip; mod clip;
#[allow(unused)]
mod interp; mod interp;
use std::sync::Arc; use std::sync::Arc;


@@ -15,14 +15,15 @@
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use common_query::error::Result;
use common_query::prelude::{Signature, Volatility}; use common_query::prelude::{Signature, Volatility};
use datatypes::data_type::{ConcreteDataType, DataType}; use datatypes::arrow::compute;
use datatypes::prelude::{Scalar, VectorRef}; use datatypes::arrow::datatypes::ArrowPrimitiveType;
use datatypes::with_match_primitive_type_id; use datatypes::data_type::ConcreteDataType;
use num_traits::AsPrimitive; use datatypes::prelude::*;
use datatypes::vectors::PrimitiveVector;
use paste::paste; use paste::paste;
use crate::error::Result;
use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::expression::{scalar_binary_op, EvalContext};
use crate::scalars::function::{Function, FunctionContext}; use crate::scalars::function::{Function, FunctionContext};
@@ -34,25 +35,32 @@ macro_rules! define_eval {
($O: ident) => { ($O: ident) => {
paste! { paste! {
fn [<eval_ $O>](columns: &[VectorRef]) -> Result<VectorRef> { fn [<eval_ $O>](columns: &[VectorRef]) -> Result<VectorRef> {
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| { fn cast_vector(input: &VectorRef) -> VectorRef {
with_match_primitive_type_id!(columns[1].data_type().logical_type_id(), |$T| { Arc::new(PrimitiveVector::<<$O as WrapperType>::LogicalType>::try_from_arrow_array(
with_match_primitive_type_id!(columns[2].data_type().logical_type_id(), |$R| { compute::cast(&input.to_arrow_array(), &<<<$O as WrapperType>::LogicalType as LogicalPrimitiveType>::ArrowPrimitive as ArrowPrimitiveType>::DATA_TYPE).unwrap()
// clip(a, min, max) is equals to min(max(a, min), max) ).unwrap()) as _
let col: VectorRef = Arc::new(scalar_binary_op::<$S, $T, $O, _>(&columns[0], &columns[1], scalar_max, &mut EvalContext::default())?); }
let col = scalar_binary_op::<$O, $R, $O, _>(&col, &columns[2], scalar_min, &mut EvalContext::default())?; let operator_1 = cast_vector(&columns[0]);
Ok(Arc::new(col)) let operator_2 = cast_vector(&columns[1]);
}, { let operator_3 = cast_vector(&columns[2]);
unreachable!()
}) // clip(a, min, max) is equals to min(max(a, min), max)
}, { let col: VectorRef = Arc::new(scalar_binary_op::<$O, $O, $O, _>(
unreachable!() &operator_1,
}) &operator_2,
}, { scalar_max,
unreachable!() &mut EvalContext::default(),
}) )?);
let col = scalar_binary_op::<$O, $O, $O, _>(
&col,
&operator_3,
scalar_min,
&mut EvalContext::default(),
)?;
Ok(Arc::new(col))
} }
} }
} };
} }
define_eval!(i64); define_eval!(i64);
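The rewritten macro casts all three operands to the output type and then applies the identity clip(a, lo, hi) = min(max(a, lo), hi) element-wise. A rough standalone sketch of that identity (scalar form, names illustrative):

fn clip(a: i64, lo: i64, hi: i64) -> i64 {
    // Values below lo are raised to lo, values above hi are lowered to hi.
    a.max(lo).min(hi)
}

// Mirrors the signed test below: clipping 0..10 into [3, 6] gives 3,3,3,3,4,5,6,6,6,6.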
@@ -108,27 +116,23 @@ pub fn max<T: PartialOrd>(input: T, max: T) -> T {
} }
#[inline] #[inline]
fn scalar_min<S, T, O>(left: Option<S>, right: Option<T>, _ctx: &mut EvalContext) -> Option<O> fn scalar_min<O>(left: Option<O>, right: Option<O>, _ctx: &mut EvalContext) -> Option<O>
where where
S: AsPrimitive<O>,
T: AsPrimitive<O>,
O: Scalar + Copy + PartialOrd, O: Scalar + Copy + PartialOrd,
{ {
match (left, right) { match (left, right) {
(Some(left), Some(right)) => Some(min(left.as_(), right.as_())), (Some(left), Some(right)) => Some(min(left, right)),
_ => None, _ => None,
} }
} }
#[inline] #[inline]
fn scalar_max<S, T, O>(left: Option<S>, right: Option<T>, _ctx: &mut EvalContext) -> Option<O> fn scalar_max<O>(left: Option<O>, right: Option<O>, _ctx: &mut EvalContext) -> Option<O>
where where
S: AsPrimitive<O>,
T: AsPrimitive<O>,
O: Scalar + Copy + PartialOrd, O: Scalar + Copy + PartialOrd,
{ {
match (left, right) { match (left, right) {
(Some(left), Some(right)) => Some(max(left.as_(), right.as_())), (Some(left), Some(right)) => Some(max(left, right)),
_ => None, _ => None,
} }
} }
@@ -143,11 +147,15 @@ impl fmt::Display for ClipFunction {
mod tests { mod tests {
use common_query::prelude::TypeSignature; use common_query::prelude::TypeSignature;
use datatypes::value::Value; use datatypes::value::Value;
use datatypes::vectors::{ConstantVector, Float32Vector, Int32Vector, UInt32Vector}; use datatypes::vectors::{
ConstantVector, Float32Vector, Int16Vector, Int32Vector, Int8Vector, UInt16Vector,
UInt32Vector, UInt8Vector,
};
use super::*; use super::*;
#[test] #[test]
fn test_clip_function() { fn test_clip_signature() {
let clip = ClipFunction::default(); let clip = ClipFunction::default();
assert_eq!("clip", clip.name()); assert_eq!("clip", clip.name());
@@ -190,16 +198,21 @@ mod tests {
volatility: Volatility::Immutable volatility: Volatility::Immutable
} if valid_types == ConcreteDataType::numerics() } if valid_types == ConcreteDataType::numerics()
)); ));
}
#[test]
fn test_clip_fn_signed() {
let clip = ClipFunction::default();
// eval with signed integers // eval with signed integers
let args: Vec<VectorRef> = vec![ let args: Vec<VectorRef> = vec![
Arc::new(Int32Vector::from_values(0..10)), Arc::new(Int32Vector::from_values(0..10)),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(Int32Vector::from_vec(vec![3])), Arc::new(Int8Vector::from_vec(vec![3])),
10, 10,
)), )),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(Int32Vector::from_vec(vec![6])), Arc::new(Int16Vector::from_vec(vec![6])),
10, 10,
)), )),
]; ];
@@ -217,16 +230,21 @@ mod tests {
assert!(matches!(vector.get(i), Value::Int64(v) if v == 6)); assert!(matches!(vector.get(i), Value::Int64(v) if v == 6));
} }
} }
}
#[test]
fn test_clip_fn_unsigned() {
let clip = ClipFunction::default();
// eval with unsigned integers // eval with unsigned integers
let args: Vec<VectorRef> = vec![ let args: Vec<VectorRef> = vec![
Arc::new(UInt32Vector::from_values(0..10)), Arc::new(UInt8Vector::from_values(0..10)),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(UInt32Vector::from_vec(vec![3])), Arc::new(UInt32Vector::from_vec(vec![3])),
10, 10,
)), )),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(UInt32Vector::from_vec(vec![6])), Arc::new(UInt16Vector::from_vec(vec![6])),
10, 10,
)), )),
]; ];
@@ -244,12 +262,17 @@ mod tests {
assert!(matches!(vector.get(i), Value::UInt64(v) if v == 6)); assert!(matches!(vector.get(i), Value::UInt64(v) if v == 6));
} }
} }
}
#[test]
fn test_clip_fn_float() {
let clip = ClipFunction::default();
// eval with floats // eval with floats
let args: Vec<VectorRef> = vec![ let args: Vec<VectorRef> = vec![
Arc::new(Int32Vector::from_values(0..10)), Arc::new(Int8Vector::from_values(0..10)),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(
Arc::new(Int32Vector::from_vec(vec![3])), Arc::new(UInt32Vector::from_vec(vec![3])),
10, 10,
)), )),
Arc::new(ConstantVector::new( Arc::new(ConstantVector::new(


@@ -14,41 +14,18 @@
use std::sync::Arc; use std::sync::Arc;
use datatypes::arrow::array::PrimitiveArray; use common_query::error::{self, Result};
use datatypes::arrow::compute::cast::primitive_to_primitive; use datatypes::arrow::compute::cast;
use datatypes::arrow::datatypes::DataType::Float64; use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::data_type::DataType; use datatypes::data_type::DataType;
use datatypes::prelude::ScalarVector; use datatypes::prelude::ScalarVector;
use datatypes::type_id::LogicalTypeId;
use datatypes::value::Value; use datatypes::value::Value;
use datatypes::vectors::{Float64Vector, PrimitiveVector, Vector, VectorRef}; use datatypes::vectors::{Float64Vector, Vector, VectorRef};
use datatypes::{arrow, with_match_primitive_type_id}; use datatypes::with_match_primitive_type_id;
use snafu::{ensure, Snafu}; use snafu::{ensure, ResultExt};
#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display(
"The length of the args is not enough, expect at least: {}, have: {}",
expect,
actual,
))]
ArgsLenNotEnough { expect: usize, actual: usize },
#[snafu(display("The sample {} is empty", name))]
SampleEmpty { name: String },
#[snafu(display(
"The length of the len1: {} don't match the length of the len2: {}",
len1,
len2,
))]
LenNotEquals { len1: usize, len2: usize },
}
pub type Result<T> = std::result::Result<T, Error>;
/* search the biggest number that smaller than x in xp */ /* search the biggest number that smaller than x in xp */
fn linear_search_ascending_vector(x: Value, xp: &PrimitiveVector<f64>) -> usize { fn linear_search_ascending_vector(x: Value, xp: &Float64Vector) -> usize {
for i in 0..xp.len() { for i in 0..xp.len() {
if x < xp.get(i) { if x < xp.get(i) {
return i - 1; return i - 1;
@@ -58,7 +35,7 @@ fn linear_search_ascending_vector(x: Value, xp: &PrimitiveVector<f64>) -> usize
} }
/* search the biggest number that smaller than x in xp */ /* search the biggest number that smaller than x in xp */
fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector<f64>) -> usize { fn binary_search_ascending_vector(key: Value, xp: &Float64Vector) -> usize {
let mut left = 0; let mut left = 0;
let mut right = xp.len(); let mut right = xp.len();
/* If len <= 4 use linear search. */ /* If len <= 4 use linear search. */
@@ -77,27 +54,33 @@ fn binary_search_ascending_vector(key: Value, xp: &PrimitiveVector<f64>) -> usiz
left - 1 left - 1
} }
fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result<PrimitiveVector<f64>> { fn concrete_type_to_primitive_vector(arg: &VectorRef) -> Result<Float64Vector> {
with_match_primitive_type_id!(arg.data_type().logical_type_id(), |$S| { with_match_primitive_type_id!(arg.data_type().logical_type_id(), |$S| {
let tmp = arg.to_arrow_array(); let tmp = arg.to_arrow_array();
let from = tmp.as_any().downcast_ref::<PrimitiveArray<$S>>().expect("cast failed"); let array = cast(&tmp, &ArrowDataType::Float64).context(error::TypeCastSnafu {
let array = primitive_to_primitive(from, &Float64); typ: ArrowDataType::Float64,
Ok(PrimitiveVector::new(array)) })?;
// Safety: array has been cast to Float64Array.
Ok(Float64Vector::try_from_arrow_array(array).unwrap())
},{ },{
unreachable!() unreachable!()
}) })
} }
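The helper now routes every primitive input through arrow's generic cast kernel and rewraps the result as a Float64Vector, instead of the old arrow2 primitive_to_primitive path. A hedged sketch of the cast step alone, assuming the datatypes::arrow re-export used elsewhere in this diff (error handling unwrapped for brevity):

use datatypes::arrow::array::{ArrayRef, Int32Array};
use datatypes::arrow::compute::cast;
use datatypes::arrow::datatypes::DataType;

fn to_f64(input: &ArrayRef) -> ArrayRef {
    // cast() accepts any numeric source type and returns a Float64 array here.
    cast(input, &DataType::Float64).unwrap()
}

// Casting an Int32Array of [1, 2, 3] this way yields a Float64Array of [1.0, 2.0, 3.0].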
/// https://github.com/numpy/numpy/blob/b101756ac02e390d605b2febcded30a1da50cc2c/numpy/core/src/multiarray/compiled_base.c#L491 /// https://github.com/numpy/numpy/blob/b101756ac02e390d605b2febcded30a1da50cc2c/numpy/core/src/multiarray/compiled_base.c#L491
#[allow(unused)]
pub fn interp(args: &[VectorRef]) -> Result<VectorRef> { pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
let mut left = None; let mut left = None;
let mut right = None; let mut right = None;
ensure!( ensure!(
args.len() >= 3, args.len() >= 3,
ArgsLenNotEnoughSnafu { error::InvalidFuncArgsSnafu {
expect: 3_usize, err_msg: format!(
actual: args.len() "The length of the args is not enough, expect at least: {}, have: {}",
3,
args.len()
),
} }
); );
@@ -109,9 +92,12 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
if args.len() > 3 { if args.len() > 3 {
ensure!( ensure!(
args.len() == 5, args.len() == 5,
ArgsLenNotEnoughSnafu { error::InvalidFuncArgsSnafu {
expect: 5_usize, err_msg: format!(
actual: args.len() "The length of the args is not enough, expect at least: {}, have: {}",
5,
args.len()
),
} }
); );
@@ -123,14 +109,32 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
.get_data(0); .get_data(0);
} }
ensure!(x.len() != 0, SampleEmptySnafu { name: "x" }); ensure!(
ensure!(xp.len() != 0, SampleEmptySnafu { name: "xp" }); x.len() != 0,
ensure!(fp.len() != 0, SampleEmptySnafu { name: "fp" }); error::InvalidFuncArgsSnafu {
err_msg: "The sample x is empty",
}
);
ensure!(
xp.len() != 0,
error::InvalidFuncArgsSnafu {
err_msg: "The sample xp is empty",
}
);
ensure!(
fp.len() != 0,
error::InvalidFuncArgsSnafu {
err_msg: "The sample fp is empty",
}
);
ensure!( ensure!(
xp.len() == fp.len(), xp.len() == fp.len(),
LenNotEqualsSnafu { error::InvalidFuncArgsSnafu {
len1: xp.len(), err_msg: format!(
len2: fp.len(), "The length of the len1: {} don't match the length of the len2: {}",
xp.len(),
fp.len()
),
} }
); );
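These checks now reuse the shared InvalidFuncArgs error instead of a module-local enum. A hedged sketch of the general ensure! pattern with a made-up leaf error, not the actual common_query::error type:

use snafu::{ensure, Snafu};

#[derive(Debug, Snafu)]
#[snafu(display("Invalid function args: {}", err_msg))]
struct InvalidArgs {
    err_msg: String,
}

fn check_samples(xp: &[f64], fp: &[f64]) -> Result<(), InvalidArgs> {
    ensure!(
        !xp.is_empty(),
        InvalidArgsSnafu {
            err_msg: "The sample xp is empty",
        }
    );
    ensure!(
        xp.len() == fp.len(),
        InvalidArgsSnafu {
            err_msg: format!("xp has {} points but fp has {}", xp.len(), fp.len()),
        }
    );
    Ok(())
}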
@@ -147,7 +151,7 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
let res; let res;
if xp.len() == 1 { if xp.len() == 1 {
res = x let datas = x
.iter_data() .iter_data()
.map(|x| { .map(|x| {
if Value::from(x) < xp.get(0) { if Value::from(x) < xp.get(0) {
@@ -158,7 +162,8 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
fp.get_data(0) fp.get_data(0)
} }
}) })
.collect::<Float64Vector>(); .collect::<Vec<_>>();
res = Float64Vector::from(datas);
} else { } else {
let mut j = 0; let mut j = 0;
/* only pre-calculate slopes if there are relatively few of them. */ /* only pre-calculate slopes if there are relatively few of them. */
@@ -185,7 +190,7 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
} }
slopes = Some(slopes_tmp); slopes = Some(slopes_tmp);
} }
res = x let datas = x
.iter_data() .iter_data()
.map(|x| match x { .map(|x| match x {
Some(xi) => { Some(xi) => {
@@ -248,7 +253,8 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
} }
_ => None, _ => None,
}) })
.collect::<Float64Vector>(); .collect::<Vec<_>>();
res = Float64Vector::from(datas);
} }
Ok(Arc::new(res) as _) Ok(Arc::new(res) as _)
} }
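interp() follows the linked numpy reference: with xp sorted ascending and fp the sampled values, a query x inside the range gets fp[j] + slope_j * (x - xp[j]) for the bracketing interval j, while out-of-range queries take the boundary (or explicit left/right) values. A compact, hedged sketch of that piecewise-linear rule for a single point (no null handling, names illustrative):

fn interp_one(x: f64, xp: &[f64], fp: &[f64]) -> f64 {
    // Clamp to the boundary samples outside the xp range, like numpy's defaults.
    if x <= xp[0] {
        return fp[0];
    }
    if x >= xp[xp.len() - 1] {
        return fp[fp.len() - 1];
    }
    // Largest j with xp[j] <= x, then interpolate on [xp[j], xp[j + 1]].
    let j = xp.partition_point(|&v| v <= x) - 1;
    let slope = (fp[j + 1] - fp[j]) / (xp[j + 1] - xp[j]);
    fp[j] + slope * (x - xp[j])
}

// interp_one(1.5, &[1.0, 2.0, 3.0], &[3.0, 2.0, 0.0]) == 2.5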
@@ -257,8 +263,7 @@ pub fn interp(args: &[VectorRef]) -> Result<VectorRef> {
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use datatypes::prelude::ScalarVectorBuilder; use datatypes::vectors::{Int32Vector, Int64Vector};
use datatypes::vectors::{Int32Vector, Int64Vector, PrimitiveVectorBuilder};
use super::*; use super::*;
#[test] #[test]
@@ -341,12 +346,8 @@ mod tests {
assert!(matches!(vector.get(0), Value::Float64(v) if v==x[0] as f64)); assert!(matches!(vector.get(0), Value::Float64(v) if v==x[0] as f64));
// x=None output:Null // x=None output:Null
let input = [None, Some(0.0), Some(0.3)]; let input = vec![None, Some(0.0), Some(0.3)];
let mut builder = PrimitiveVectorBuilder::with_capacity(input.len()); let x = Float64Vector::from(input);
for v in input {
builder.push(v);
}
let x = builder.finish();
let args: Vec<VectorRef> = vec![ let args: Vec<VectorRef> = vec![
Arc::new(x), Arc::new(x),
Arc::new(Int64Vector::from_vec(xp)), Arc::new(Int64Vector::from_vec(xp)),


@@ -15,11 +15,11 @@
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use common_query::error::Result;
use common_query::prelude::{Signature, Volatility}; use common_query::prelude::{Signature, Volatility};
use datatypes::data_type::ConcreteDataType; use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::VectorRef; use datatypes::prelude::VectorRef;
use crate::error::Result;
use crate::scalars::expression::{scalar_binary_op, EvalContext}; use crate::scalars::expression::{scalar_binary_op, EvalContext};
use crate::scalars::function::{Function, FunctionContext}; use crate::scalars::function::{Function, FunctionContext};


@@ -17,16 +17,17 @@
use std::fmt; use std::fmt;
use std::sync::Arc; use std::sync::Arc;
use common_query::error::{IntoVectorSnafu, UnsupportedInputDataTypeSnafu}; use common_query::error::{
ArrowComputeSnafu, IntoVectorSnafu, Result, TypeCastSnafu, UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, Volatility}; use common_query::prelude::{Signature, Volatility};
use datatypes::arrow::compute::arithmetics; use datatypes::arrow::compute;
use datatypes::arrow::datatypes::DataType as ArrowDatatype; use datatypes::arrow::datatypes::{DataType as ArrowDatatype, Int64Type};
use datatypes::arrow::scalar::PrimitiveScalar; use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType; use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{TimestampVector, VectorRef}; use datatypes::vectors::{TimestampMillisecondVector, VectorRef};
use snafu::ResultExt; use snafu::ResultExt;
use crate::error::Result;
use crate::scalars::function::{Function, FunctionContext}; use crate::scalars::function::{Function, FunctionContext};
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@@ -40,7 +41,7 @@ impl Function for FromUnixtimeFunction {
} }
fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> { fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::timestamp_millis_datatype()) Ok(ConcreteDataType::timestamp_millisecond_datatype())
} }
fn signature(&self) -> Signature { fn signature(&self) -> Signature {
@@ -56,14 +57,18 @@ impl Function for FromUnixtimeFunction {
ConcreteDataType::Int64(_) => { ConcreteDataType::Int64(_) => {
let array = columns[0].to_arrow_array(); let array = columns[0].to_arrow_array();
// Our timestamp vector's time unit is millisecond // Our timestamp vector's time unit is millisecond
let array = arithmetics::mul_scalar( let array = compute::multiply_scalar_dyn::<Int64Type>(&array, 1000i64)
&*array, .context(ArrowComputeSnafu)?;
&PrimitiveScalar::new(ArrowDatatype::Int64, Some(1000i64)),
);
let arrow_datatype = &self.return_type(&[]).unwrap().as_arrow_type();
Ok(Arc::new( Ok(Arc::new(
TimestampVector::try_from_arrow_array(array).context(IntoVectorSnafu { TimestampMillisecondVector::try_from_arrow_array(
data_type: ArrowDatatype::Int64, compute::cast(&array, arrow_datatype).context(TypeCastSnafu {
typ: ArrowDatatype::Int64,
})?,
)
.context(IntoVectorSnafu {
data_type: arrow_datatype.clone(),
})?, })?,
)) ))
} }
@@ -71,8 +76,7 @@ impl Function for FromUnixtimeFunction {
function: NAME, function: NAME,
datatypes: columns.iter().map(|c| c.data_type()).collect::<Vec<_>>(), datatypes: columns.iter().map(|c| c.data_type()).collect::<Vec<_>>(),
} }
.fail() .fail(),
.map_err(|e| e.into()),
} }
} }
} }
@@ -96,7 +100,7 @@ mod tests {
let f = FromUnixtimeFunction::default(); let f = FromUnixtimeFunction::default();
assert_eq!("from_unixtime", f.name()); assert_eq!("from_unixtime", f.name());
assert_eq!( assert_eq!(
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
f.return_type(&[]).unwrap() f.return_type(&[]).unwrap()
); );
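The new from_unixtime path multiplies the Int64 seconds by 1000 with an arrow kernel and casts the product to the millisecond timestamp type before wrapping it in a TimestampMillisecondVector. A hedged sketch mirroring those two kernel calls (unwraps stand in for the snafu contexts):

use datatypes::arrow::array::ArrayRef;
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::{DataType, Int64Type, TimeUnit};

fn seconds_to_millis_timestamp(seconds: &ArrayRef) -> ArrayRef {
    // Scale to millisecond resolution, then reinterpret as a timestamp array.
    let millis = compute::multiply_scalar_dyn::<Int64Type>(seconds, 1000i64).unwrap();
    compute::cast(&millis, &DataType::Timestamp(TimeUnit::Millisecond, None)).unwrap()
}

// An input of [1_667_211_230] seconds becomes a millisecond timestamp array
// holding 1_667_211_230_000.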


@@ -19,7 +19,8 @@ use common_query::prelude::{
ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf, ScalarValue, ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUdf, ScalarValue,
}; };
use datatypes::error::Error as DataTypeError; use datatypes::error::Error as DataTypeError;
use datatypes::prelude::{ConcreteDataType, VectorHelper}; use datatypes::prelude::*;
use datatypes::vectors::Helper;
use snafu::ResultExt; use snafu::ResultExt;
use crate::scalars::function::{FunctionContext, FunctionRef}; use crate::scalars::function::{FunctionContext, FunctionRef};
@@ -47,7 +48,7 @@ pub fn create_udf(func: FunctionRef) -> ScalarUdf {
let args: Result<Vec<_>, DataTypeError> = args let args: Result<Vec<_>, DataTypeError> = args
.iter() .iter()
.map(|arg| match arg { .map(|arg| match arg {
ColumnarValue::Scalar(v) => VectorHelper::try_from_scalar_value(v.clone(), rows), ColumnarValue::Scalar(v) => Helper::try_from_scalar_value(v.clone(), rows),
ColumnarValue::Vector(v) => Ok(v.clone()), ColumnarValue::Vector(v) => Ok(v.clone()),
}) })
.collect(); .collect();


@@ -22,11 +22,11 @@ use api::v1::{AddColumn, AddColumns, Column, ColumnDataType, ColumnDef, CreateEx
use common_base::BitVec; use common_base::BitVec;
use common_time::timestamp::Timestamp; use common_time::timestamp::Timestamp;
use common_time::{Date, DateTime}; use common_time::{Date, DateTime};
use datatypes::data_type::ConcreteDataType; use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::prelude::{ValueRef, VectorRef}; use datatypes::prelude::{ValueRef, VectorRef};
use datatypes::schema::SchemaRef; use datatypes::schema::SchemaRef;
use datatypes::value::Value; use datatypes::value::Value;
use datatypes::vectors::VectorBuilder; use datatypes::vectors::MutableVector;
use snafu::{ensure, OptionExt, ResultExt}; use snafu::{ensure, OptionExt, ResultExt};
use table::metadata::TableId; use table::metadata::TableId;
use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest, InsertRequest}; use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest, InsertRequest};
@@ -99,7 +99,7 @@ pub fn column_to_vector(column: &Column, rows: u32) -> Result<VectorRef> {
let column_datatype = wrapper.datatype(); let column_datatype = wrapper.datatype();
let rows = rows as usize; let rows = rows as usize;
let mut vector = VectorBuilder::with_capacity(wrapper.into(), rows); let mut vector = ConcreteDataType::from(wrapper).create_mutable_vector(rows);
if let Some(values) = &column.values { if let Some(values) = &column.values {
let values = collect_column_values(column_datatype, values); let values = collect_column_values(column_datatype, values);
@@ -110,21 +110,31 @@ pub fn column_to_vector(column: &Column, rows: u32) -> Result<VectorRef> {
for i in 0..rows { for i in 0..rows {
if let Some(true) = nulls_iter.next() { if let Some(true) = nulls_iter.next() {
vector.push_null(); vector
.push_value_ref(ValueRef::Null)
.context(CreateVectorSnafu)?;
} else { } else {
let value_ref = values_iter.next().context(InvalidColumnProtoSnafu { let value_ref = values_iter
err_msg: format!( .next()
"value not found at position {} of column {}", .with_context(|| InvalidColumnProtoSnafu {
i, &column.column_name err_msg: format!(
), "value not found at position {} of column {}",
})?; i, &column.column_name
vector.try_push_ref(value_ref).context(CreateVectorSnafu)?; ),
})?;
vector
.push_value_ref(value_ref)
.context(CreateVectorSnafu)?;
} }
} }
} else { } else {
(0..rows).for_each(|_| vector.push_null()); (0..rows).try_for_each(|_| {
vector
.push_value_ref(ValueRef::Null)
.context(CreateVectorSnafu)
})?;
} }
Ok(vector.finish()) Ok(vector.to_vector())
} }
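column_to_vector() walks the declared row count once: when the null-mask bit for a row is set it pushes a null, otherwise it consumes the next packed value, because the values arrays only carry non-null entries. A small standalone sketch of that interleaving, assuming the LSB-first bit order exercised by the null_mask tests further down (types illustrative):

fn rebuild_column(values: &[i64], null_mask: &[u8], rows: usize) -> Vec<Option<i64>> {
    let mut out = Vec::with_capacity(rows);
    let mut packed = values.iter();
    for row in 0..rows {
        let is_null = !null_mask.is_empty() && (null_mask[row / 8] >> (row % 8)) & 1 == 1;
        if is_null {
            out.push(None);
        } else {
            // Non-null rows consume packed values in order; a real build would
            // report a missing value instead of silently yielding None.
            out.push(packed.next().copied());
        }
    }
    out
}

// rebuild_column(&[7, 9], &[0b0000_0010], 3) == vec![Some(7), None, Some(9)]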
fn collect_column_values(column_datatype: ColumnDataType, values: &Values) -> Vec<ValueRef> { fn collect_column_values(column_datatype: ColumnDataType, values: &Values) -> Vec<ValueRef> {
@@ -174,9 +184,24 @@ fn collect_column_values(column_datatype: ColumnDataType, values: &Values) -> Ve
DateTime::new(*v) DateTime::new(*v)
)) ))
} }
ColumnDataType::Timestamp => { ColumnDataType::TimestampSecond => {
collect_values!(values.ts_millis_values, |v| ValueRef::Timestamp( collect_values!(values.ts_second_values, |v| ValueRef::Timestamp(
Timestamp::from_millis(*v) Timestamp::new_second(*v)
))
}
ColumnDataType::TimestampMillisecond => {
collect_values!(values.ts_millisecond_values, |v| ValueRef::Timestamp(
Timestamp::new_millisecond(*v)
))
}
ColumnDataType::TimestampMicrosecond => {
collect_values!(values.ts_millisecond_values, |v| ValueRef::Timestamp(
Timestamp::new_microsecond(*v)
))
}
ColumnDataType::TimestampNanosecond => {
collect_values!(values.ts_millisecond_values, |v| ValueRef::Timestamp(
Timestamp::new_nanosecond(*v)
)) ))
} }
} }
@@ -289,10 +314,7 @@ pub fn insertion_expr_to_request(
}, },
)?; )?;
let data_type = &column_schema.data_type; let data_type = &column_schema.data_type;
entry.insert(VectorBuilder::with_capacity( entry.insert(data_type.create_mutable_vector(row_count as usize))
data_type.clone(),
row_count as usize,
))
} }
}; };
add_values_to_builder(vector_builder, values, row_count as usize, null_mask)?; add_values_to_builder(vector_builder, values, row_count as usize, null_mask)?;
@@ -300,7 +322,7 @@ pub fn insertion_expr_to_request(
} }
let columns_values = columns_builders let columns_values = columns_builders
.into_iter() .into_iter()
.map(|(column_name, mut vector_builder)| (column_name, vector_builder.finish())) .map(|(column_name, mut vector_builder)| (column_name, vector_builder.to_vector()))
.collect(); .collect();
Ok(InsertRequest { Ok(InsertRequest {
@@ -312,7 +334,7 @@ pub fn insertion_expr_to_request(
} }
fn add_values_to_builder( fn add_values_to_builder(
builder: &mut VectorBuilder, builder: &mut Box<dyn MutableVector>,
values: Values, values: Values,
row_count: usize, row_count: usize,
null_mask: Vec<u8>, null_mask: Vec<u8>,
@@ -323,9 +345,11 @@ fn add_values_to_builder(
if null_mask.is_empty() { if null_mask.is_empty() {
ensure!(values.len() == row_count, IllegalInsertDataSnafu); ensure!(values.len() == row_count, IllegalInsertDataSnafu);
values.iter().for_each(|value| { values.iter().try_for_each(|value| {
builder.push(value); builder
}); .push_value_ref(value.as_value_ref())
.context(CreateVectorSnafu)
})?;
} else { } else {
let null_mask = BitVec::from_vec(null_mask); let null_mask = BitVec::from_vec(null_mask);
ensure!( ensure!(
@@ -336,9 +360,13 @@ fn add_values_to_builder(
let mut idx_of_values = 0; let mut idx_of_values = 0;
for idx in 0..row_count { for idx in 0..row_count {
match is_null(&null_mask, idx) { match is_null(&null_mask, idx) {
Some(true) => builder.push(&Value::Null), Some(true) => builder
.push_value_ref(ValueRef::Null)
.context(CreateVectorSnafu)?,
_ => { _ => {
builder.push(&values[idx_of_values]); builder
.push_value_ref(values[idx_of_values].as_value_ref())
.context(CreateVectorSnafu)?;
idx_of_values += 1 idx_of_values += 1
} }
} }
@@ -418,9 +446,9 @@ fn convert_values(data_type: &ConcreteDataType, values: Values) -> Vec<Value> {
.map(|v| Value::Date(v.into())) .map(|v| Value::Date(v.into()))
.collect(), .collect(),
ConcreteDataType::Timestamp(_) => values ConcreteDataType::Timestamp(_) => values
.ts_millis_values .ts_millisecond_values
.into_iter() .into_iter()
.map(|v| Value::Timestamp(Timestamp::from_millis(v))) .map(|v| Value::Timestamp(Timestamp::new_millisecond(v)))
.collect(), .collect(),
ConcreteDataType::Null(_) => unreachable!(), ConcreteDataType::Null(_) => unreachable!(),
ConcreteDataType::List(_) => unreachable!(), ConcreteDataType::List(_) => unreachable!(),
@@ -543,7 +571,7 @@ mod tests {
); );
assert_eq!( assert_eq!(
ConcreteDataType::timestamp_millis_datatype(), ConcreteDataType::timestamp_millisecond_datatype(),
ConcreteDataType::from( ConcreteDataType::from(
ColumnDataTypeWrapper::try_new( ColumnDataTypeWrapper::try_new(
column_defs column_defs
@@ -624,8 +652,8 @@ mod tests {
assert_eq!(Value::Float64(0.1.into()), memory.get(1)); assert_eq!(Value::Float64(0.1.into()), memory.get(1));
let ts = insert_req.columns_values.get("ts").unwrap(); let ts = insert_req.columns_values.get("ts").unwrap();
assert_eq!(Value::Timestamp(Timestamp::from_millis(100)), ts.get(0)); assert_eq!(Value::Timestamp(Timestamp::new_millisecond(100)), ts.get(0));
assert_eq!(Value::Timestamp(Timestamp::from_millis(101)), ts.get(1)); assert_eq!(Value::Timestamp(Timestamp::new_millisecond(101)), ts.get(1));
} }
#[test] #[test]
@@ -675,8 +703,12 @@ mod tests {
ColumnSchema::new("host", ConcreteDataType::string_datatype(), false), ColumnSchema::new("host", ConcreteDataType::string_datatype(), false),
ColumnSchema::new("cpu", ConcreteDataType::float64_datatype(), true), ColumnSchema::new("cpu", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new("memory", ConcreteDataType::float64_datatype(), true), ColumnSchema::new("memory", ConcreteDataType::float64_datatype(), true),
ColumnSchema::new("ts", ConcreteDataType::timestamp_millis_datatype(), true) ColumnSchema::new(
.with_time_index(true), "ts",
ConcreteDataType::timestamp_millisecond_datatype(),
true,
)
.with_time_index(true),
]; ];
Arc::new( Arc::new(
@@ -741,7 +773,7 @@ mod tests {
}; };
let ts_vals = column::Values { let ts_vals = column::Values {
ts_millis_values: vec![100, 101], ts_millisecond_values: vec![100, 101],
..Default::default() ..Default::default()
}; };
let ts_column = Column { let ts_column = Column {
@@ -749,7 +781,7 @@ mod tests {
semantic_type: TIMESTAMP_SEMANTIC_TYPE, semantic_type: TIMESTAMP_SEMANTIC_TYPE,
values: Some(ts_vals), values: Some(ts_vals),
null_mask: vec![0], null_mask: vec![0],
datatype: ColumnDataType::Timestamp as i32, datatype: ColumnDataType::TimestampMillisecond as i32,
}; };
( (


@@ -13,9 +13,7 @@ common-query = { path = "../query" }
common-recordbatch = { path = "../recordbatch" } common-recordbatch = { path = "../recordbatch" }
common-runtime = { path = "../runtime" } common-runtime = { path = "../runtime" }
dashmap = "5.4" dashmap = "5.4"
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ datafusion = "14.0.0"
"simd",
] }
datatypes = { path = "../../datatypes" } datatypes = { path = "../../datatypes" }
snafu = { version = "0.7", features = ["backtraces"] } snafu = { version = "0.7", features = ["backtraces"] }
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }


@@ -14,9 +14,7 @@
use std::any::Any; use std::any::Any;
use api::DecodeError;
use common_error::prelude::{ErrorExt, StatusCode}; use common_error::prelude::{ErrorExt, StatusCode};
use datafusion::error::DataFusionError;
use snafu::{Backtrace, ErrorCompat, Snafu}; use snafu::{Backtrace, ErrorCompat, Snafu};
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
@@ -24,33 +22,9 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Snafu)] #[derive(Debug, Snafu)]
#[snafu(visibility(pub))] #[snafu(visibility(pub))]
pub enum Error { pub enum Error {
#[snafu(display("Unexpected empty physical plan type: {}", name))]
EmptyPhysicalPlan { name: String, backtrace: Backtrace },
#[snafu(display("Unexpected empty physical expr: {}", name))]
EmptyPhysicalExpr { name: String, backtrace: Backtrace },
#[snafu(display("Unsupported datafusion execution plan: {}", name))]
UnsupportedDfPlan { name: String, backtrace: Backtrace },
#[snafu(display("Unsupported datafusion physical expr: {}", name))]
UnsupportedDfExpr { name: String, backtrace: Backtrace },
#[snafu(display("Missing required field in protobuf, field: {}", field))] #[snafu(display("Missing required field in protobuf, field: {}", field))]
MissingField { field: String, backtrace: Backtrace }, MissingField { field: String, backtrace: Backtrace },
#[snafu(display("Failed to new datafusion projection exec, source: {}", source))]
NewProjection {
source: DataFusionError,
backtrace: Backtrace,
},
#[snafu(display("Failed to decode physical plan node, source: {}", source))]
DecodePhysicalPlanNode {
source: DecodeError,
backtrace: Backtrace,
},
#[snafu(display( #[snafu(display(
"Write type mismatch, column name: {}, expected: {}, actual: {}", "Write type mismatch, column name: {}, expected: {}, actual: {}",
column_name, column_name,
@@ -89,17 +63,8 @@ pub enum Error {
impl ErrorExt for Error { impl ErrorExt for Error {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match self { match self {
Error::EmptyPhysicalPlan { .. } Error::MissingField { .. } | Error::TypeMismatch { .. } => StatusCode::InvalidArguments,
| Error::EmptyPhysicalExpr { .. } Error::CreateChannel { .. } | Error::Conversion { .. } => StatusCode::Internal,
| Error::MissingField { .. }
| Error::TypeMismatch { .. } => StatusCode::InvalidArguments,
Error::UnsupportedDfPlan { .. } | Error::UnsupportedDfExpr { .. } => {
StatusCode::Unsupported
}
Error::NewProjection { .. }
| Error::DecodePhysicalPlanNode { .. }
| Error::CreateChannel { .. }
| Error::Conversion { .. } => StatusCode::Internal,
Error::CollectRecordBatches { source } => source.status_code(), Error::CollectRecordBatches { source } => source.status_code(),
Error::ColumnDataType { source } => source.status_code(), Error::ColumnDataType { source } => source.status_code(),
} }
@@ -126,50 +91,6 @@ mod tests {
None None
} }
#[test]
fn test_empty_physical_plan_error() {
let e = throw_none_option()
.context(EmptyPhysicalPlanSnafu { name: "test" })
.err()
.unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::InvalidArguments);
}
#[test]
fn test_empty_physical_expr_error() {
let e = throw_none_option()
.context(EmptyPhysicalExprSnafu { name: "test" })
.err()
.unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::InvalidArguments);
}
#[test]
fn test_unsupported_df_plan_error() {
let e = throw_none_option()
.context(UnsupportedDfPlanSnafu { name: "test" })
.err()
.unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::Unsupported);
}
#[test]
fn test_unsupported_df_expr_error() {
let e = throw_none_option()
.context(UnsupportedDfExprSnafu { name: "test" })
.err()
.unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::Unsupported);
}
#[test] #[test]
fn test_missing_field_error() { fn test_missing_field_error() {
let e = throw_none_option() let e = throw_none_option()
@@ -181,33 +102,6 @@ mod tests {
assert_eq!(e.status_code(), StatusCode::InvalidArguments); assert_eq!(e.status_code(), StatusCode::InvalidArguments);
} }
#[test]
fn test_new_projection_error() {
fn throw_df_error() -> StdResult<DataFusionError> {
Err(DataFusionError::NotImplemented("".to_string()))
}
let e = throw_df_error().context(NewProjectionSnafu).err().unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::Internal);
}
#[test]
fn test_decode_physical_plan_node_error() {
fn throw_decode_error() -> StdResult<DecodeError> {
Err(DecodeError::new("test"))
}
let e = throw_decode_error()
.context(DecodePhysicalPlanNodeSnafu)
.err()
.unwrap();
assert!(e.backtrace_opt().is_some());
assert_eq!(e.status_code(), StatusCode::Internal);
}
#[test] #[test]
fn test_type_mismatch_error() { fn test_type_mismatch_error() {
let e = throw_none_option() let e = throw_none_option()


@@ -14,10 +14,7 @@
pub mod channel_manager; pub mod channel_manager;
pub mod error; pub mod error;
pub mod physical;
pub mod select; pub mod select;
pub mod writer; pub mod writer;
pub use error::Error; pub use error::Error;
pub use physical::plan::{DefaultAsPlanImpl, MockExecution};
pub use physical::AsExecutionPlan;


@@ -1,100 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::result::Result;
use std::sync::Arc;
use api::v1::codec;
use datafusion::physical_plan::expressions::Column as DfColumn;
use datafusion::physical_plan::PhysicalExpr as DfPhysicalExpr;
use snafu::OptionExt;
use crate::error::{EmptyPhysicalExprSnafu, Error, UnsupportedDfExprSnafu};
// grpc -> datafusion (physical expr)
pub(crate) fn parse_grpc_physical_expr(
proto: &codec::PhysicalExprNode,
) -> Result<Arc<dyn DfPhysicalExpr>, Error> {
let expr_type = proto.expr_type.as_ref().context(EmptyPhysicalExprSnafu {
name: format!("{:?}", proto),
})?;
// TODO(fys): impl other physical expr
let pexpr: Arc<dyn DfPhysicalExpr> = match expr_type {
codec::physical_expr_node::ExprType::Column(c) => {
let pcol = DfColumn::new(&c.name, c.index as usize);
Arc::new(pcol)
}
};
Ok(pexpr)
}
// datafusion -> grpc (physical expr)
pub(crate) fn parse_df_physical_expr(
df_expr: Arc<dyn DfPhysicalExpr>,
) -> Result<codec::PhysicalExprNode, Error> {
let expr = df_expr.as_any();
// TODO(fys): impl other physical expr
if let Some(expr) = expr.downcast_ref::<DfColumn>() {
Ok(codec::PhysicalExprNode {
expr_type: Some(codec::physical_expr_node::ExprType::Column(
codec::PhysicalColumn {
name: expr.name().to_string(),
index: expr.index() as u64,
},
)),
})
} else {
UnsupportedDfExprSnafu {
name: df_expr.to_string(),
}
.fail()?
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::codec::physical_expr_node::ExprType::Column;
use api::v1::codec::{PhysicalColumn, PhysicalExprNode};
use datafusion::physical_plan::expressions::Column as DfColumn;
use datafusion::physical_plan::PhysicalExpr;
use crate::physical::expr::{parse_df_physical_expr, parse_grpc_physical_expr};
#[test]
fn test_column_convert() {
// mock df_column_expr
let df_column = DfColumn::new("name", 11);
let df_column_clone = df_column.clone();
let df_expr = Arc::new(df_column) as Arc<dyn PhysicalExpr>;
// mock grpc_column_expr
let grpc_expr = PhysicalExprNode {
expr_type: Some(Column(PhysicalColumn {
name: "name".to_owned(),
index: 11,
})),
};
let result = parse_df_physical_expr(df_expr).unwrap();
assert_eq!(grpc_expr, result);
let result = parse_grpc_physical_expr(&grpc_expr).unwrap();
let df_column = result.as_any().downcast_ref::<DfColumn>().unwrap();
assert_eq!(df_column_clone, df_column.to_owned());
}
}


@@ -1,280 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::ops::Deref;
use std::result::Result;
use std::sync::Arc;
use api::v1::codec::physical_plan_node::PhysicalPlanType;
use api::v1::codec::{MockInputExecNode, PhysicalPlanNode, ProjectionExecNode};
use async_trait::async_trait;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::field_util::SchemaExt;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::physical_plan::{
ExecutionPlan, PhysicalExpr, SendableRecordBatchStream, Statistics,
};
use datafusion::record_batch::RecordBatch;
use datatypes::arrow::array::{PrimitiveArray, Utf8Array};
use datatypes::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use snafu::{OptionExt, ResultExt};
use crate::error::{
DecodePhysicalPlanNodeSnafu, EmptyPhysicalPlanSnafu, Error, MissingFieldSnafu,
NewProjectionSnafu, UnsupportedDfPlanSnafu,
};
use crate::physical::{expr, AsExecutionPlan, ExecutionPlanRef};
pub struct DefaultAsPlanImpl {
pub bytes: Vec<u8>,
}
impl AsExecutionPlan for DefaultAsPlanImpl {
type Error = Error;
// Vec<u8> -> PhysicalPlanNode -> ExecutionPlanRef
fn try_into_physical_plan(&self) -> Result<ExecutionPlanRef, Self::Error> {
let physicalplan_node: PhysicalPlanNode = self
.bytes
.deref()
.try_into()
.context(DecodePhysicalPlanNodeSnafu)?;
physicalplan_node.try_into_physical_plan()
}
// ExecutionPlanRef -> PhysicalPlanNode -> Vec<u8>
fn try_from_physical_plan(plan: ExecutionPlanRef) -> Result<Self, Self::Error>
where
Self: Sized,
{
let bytes: Vec<u8> = PhysicalPlanNode::try_from_physical_plan(plan)?.into();
Ok(DefaultAsPlanImpl { bytes })
}
}
impl AsExecutionPlan for PhysicalPlanNode {
type Error = Error;
fn try_into_physical_plan(&self) -> Result<ExecutionPlanRef, Self::Error> {
let plan = self
.physical_plan_type
.as_ref()
.context(EmptyPhysicalPlanSnafu {
name: format!("{:?}", self),
})?;
// TODO(fys): impl other physical plan type
match plan {
PhysicalPlanType::Projection(projection) => {
let input = if let Some(input) = &projection.input {
input.as_ref().try_into_physical_plan()?
} else {
MissingFieldSnafu { field: "input" }.fail()?
};
let exprs = projection
.expr
.iter()
.zip(projection.expr_name.iter())
.map(|(expr, name)| {
Ok((expr::parse_grpc_physical_expr(expr)?, name.to_string()))
})
.collect::<Result<Vec<(Arc<dyn PhysicalExpr>, String)>, Error>>()?;
let projection =
ProjectionExec::try_new(exprs, input).context(NewProjectionSnafu)?;
Ok(Arc::new(projection))
}
PhysicalPlanType::Mock(mock) => Ok(Arc::new(MockExecution {
name: mock.name.to_string(),
})),
}
}
fn try_from_physical_plan(plan: ExecutionPlanRef) -> Result<Self, Self::Error>
where
Self: Sized,
{
let plan = plan.as_any();
if let Some(exec) = plan.downcast_ref::<ProjectionExec>() {
let input = PhysicalPlanNode::try_from_physical_plan(exec.input().to_owned())?;
let expr = exec
.expr()
.iter()
.map(|expr| expr::parse_df_physical_expr(expr.0.clone()))
.collect::<Result<Vec<_>, Error>>()?;
let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect();
Ok(PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::Projection(Box::new(
ProjectionExecNode {
input: Some(Box::new(input)),
expr,
expr_name,
},
))),
})
} else if let Some(exec) = plan.downcast_ref::<MockExecution>() {
Ok(PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::Mock(MockInputExecNode {
name: exec.name.clone(),
})),
})
} else {
UnsupportedDfPlanSnafu {
name: format!("{:?}", plan),
}
.fail()?
}
}
}
// TODO(fys): use "test" feature to enable it
#[derive(Debug)]
pub struct MockExecution {
name: String,
}
impl MockExecution {
pub fn new(name: String) -> Self {
Self { name }
}
}
#[async_trait]
impl ExecutionPlan for MockExecution {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn schema(&self) -> SchemaRef {
let field1 = Field::new("id", DataType::UInt32, false);
let field2 = Field::new("name", DataType::Utf8, false);
let field3 = Field::new("age", DataType::UInt32, false);
Arc::new(Schema::new(vec![field1, field2, field3]))
}
fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning {
unimplemented!()
}
fn output_ordering(
&self,
) -> Option<&[datafusion::physical_plan::expressions::PhysicalSortExpr]> {
unimplemented!()
}
fn children(&self) -> Vec<ExecutionPlanRef> {
unimplemented!()
}
fn with_new_children(
&self,
_children: Vec<ExecutionPlanRef>,
) -> datafusion::error::Result<ExecutionPlanRef> {
unimplemented!()
}
async fn execute(
&self,
_partition: usize,
_runtime: Arc<RuntimeEnv>,
) -> datafusion::error::Result<SendableRecordBatchStream> {
let id_array = Arc::new(PrimitiveArray::from_slice([1u32, 2, 3, 4, 5]));
let name_array = Arc::new(Utf8Array::<i32>::from_slice([
"zhangsan", "lisi", "wangwu", "Tony", "Mike",
]));
let age_array = Arc::new(PrimitiveArray::from_slice([25u32, 28, 27, 35, 25]));
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("name", DataType::Utf8, false),
Field::new("age", DataType::UInt32, false),
]));
let record_batch =
RecordBatch::try_new(schema, vec![id_array, name_array, age_array]).unwrap();
let data: Vec<RecordBatch> = vec![record_batch];
let projection = Some(vec![0, 1, 2]);
let stream = MemoryStream::try_new(data, self.schema(), projection).unwrap();
Ok(Box::pin(stream))
}
fn statistics(&self) -> Statistics {
todo!()
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::codec::PhysicalPlanNode;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::projection::ProjectionExec;
use crate::physical::plan::{DefaultAsPlanImpl, MockExecution};
use crate::physical::{AsExecutionPlan, ExecutionPlanRef};
#[test]
fn test_convert_df_projection_with_bytes() {
let projection_exec = mock_df_projection();
let bytes = DefaultAsPlanImpl::try_from_physical_plan(projection_exec).unwrap();
let exec = bytes.try_into_physical_plan().unwrap();
verify_df_projection(exec);
}
#[test]
fn test_convert_df_with_grpc_projection() {
let projection_exec = mock_df_projection();
let projection_node = PhysicalPlanNode::try_from_physical_plan(projection_exec).unwrap();
let exec = projection_node.try_into_physical_plan().unwrap();
verify_df_projection(exec);
}
fn mock_df_projection() -> Arc<ProjectionExec> {
let mock_input = Arc::new(MockExecution {
name: "mock_input".to_string(),
});
let column1 = Arc::new(Column::new("id", 0));
let column2 = Arc::new(Column::new("name", 1));
Arc::new(
ProjectionExec::try_new(
vec![(column1, "id".to_string()), (column2, "name".to_string())],
mock_input,
)
.unwrap(),
)
}
fn verify_df_projection(exec: ExecutionPlanRef) {
let projection_exec = exec.as_any().downcast_ref::<ProjectionExec>().unwrap();
let mock_input = projection_exec
.input()
.as_any()
.downcast_ref::<MockExecution>()
.unwrap();
assert_eq!("mock_input", mock_input.name);
assert_eq!(2, projection_exec.expr().len());
assert_eq!("id", projection_exec.expr()[0].1);
assert_eq!("name", projection_exec.expr()[1].1);
}
}


@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::sync::Arc;
use api::helper::ColumnDataTypeWrapper; use api::helper::ColumnDataTypeWrapper;
use api::result::{build_err_result, ObjectResultBuilder}; use api::result::{build_err_result, ObjectResultBuilder};
use api::v1::codec::SelectResult; use api::v1::codec::SelectResult;
@@ -23,10 +21,15 @@ use common_base::BitVec;
use common_error::prelude::ErrorExt; use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode; use common_error::status_code::StatusCode;
use common_query::Output; use common_query::Output;
use common_recordbatch::{util, RecordBatches, SendableRecordBatchStream}; use common_recordbatch::{RecordBatches, SendableRecordBatchStream};
use datatypes::arrow::array::{Array, BooleanArray, PrimitiveArray};
use datatypes::arrow_array::{BinaryArray, StringArray};
use datatypes::schema::SchemaRef; use datatypes::schema::SchemaRef;
use datatypes::types::{TimestampType, WrapperType};
use datatypes::vectors::{
BinaryVector, BooleanVector, DateTimeVector, DateVector, Float32Vector, Float64Vector,
Int16Vector, Int32Vector, Int64Vector, Int8Vector, StringVector, TimestampMicrosecondVector,
TimestampMillisecondVector, TimestampNanosecondVector, TimestampSecondVector, UInt16Vector,
UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
};
use snafu::{OptionExt, ResultExt}; use snafu::{OptionExt, ResultExt};
use crate::error::{self, ConversionSnafu, Result}; use crate::error::{self, ConversionSnafu, Result};
@@ -46,14 +49,11 @@ pub async fn to_object_result(output: std::result::Result<Output, impl ErrorExt>
Err(e) => build_err_result(&e), Err(e) => build_err_result(&e),
} }
} }
async fn collect(stream: SendableRecordBatchStream) -> Result<ObjectResult> { async fn collect(stream: SendableRecordBatchStream) -> Result<ObjectResult> {
let schema = stream.schema(); let recordbatches = RecordBatches::try_collect(stream)
let recordbatches = util::collect(stream)
.await .await
.and_then(|batches| RecordBatches::try_new(schema, batches))
.context(error::CollectRecordBatchesSnafu)?; .context(error::CollectRecordBatchesSnafu)?;
let object_result = build_result(recordbatches)?; let object_result = build_result(recordbatches)?;
Ok(object_result) Ok(object_result)
} }
@@ -82,10 +82,7 @@ fn try_convert(record_batches: RecordBatches) -> Result<SelectResult> {
let schema = record_batches.schema(); let schema = record_batches.schema();
let record_batches = record_batches.take(); let record_batches = record_batches.take();
let row_count: usize = record_batches let row_count: usize = record_batches.iter().map(|r| r.num_rows()).sum();
.iter()
.map(|r| r.df_recordbatch.num_rows())
.sum();
let schemas = schema.column_schemas(); let schemas = schema.column_schemas();
let mut columns = Vec::with_capacity(schemas.len()); let mut columns = Vec::with_capacity(schemas.len());
@@ -93,9 +90,9 @@ fn try_convert(record_batches: RecordBatches) -> Result<SelectResult> {
for (idx, column_schema) in schemas.iter().enumerate() { for (idx, column_schema) in schemas.iter().enumerate() {
let column_name = column_schema.name.clone(); let column_name = column_schema.name.clone();
let arrays: Vec<Arc<dyn Array>> = record_batches let arrays: Vec<_> = record_batches
.iter() .iter()
.map(|r| r.df_recordbatch.columns()[idx].clone()) .map(|r| r.column(idx).clone())
.collect(); .collect();
let column = Column { let column = Column {
@@ -116,7 +113,7 @@ fn try_convert(record_batches: RecordBatches) -> Result<SelectResult> {
}) })
} }
pub fn null_mask(arrays: &Vec<Arc<dyn Array>>, row_count: usize) -> Vec<u8> { pub fn null_mask(arrays: &[VectorRef], row_count: usize) -> Vec<u8> {
let null_count: usize = arrays.iter().map(|a| a.null_count()).sum(); let null_count: usize = arrays.iter().map(|a| a.null_count()).sum();
if null_count == 0 { if null_count == 0 {
@@ -126,10 +123,12 @@ pub fn null_mask(arrays: &Vec<Arc<dyn Array>>, row_count: usize) -> Vec<u8> {
let mut null_mask = BitVec::with_capacity(row_count); let mut null_mask = BitVec::with_capacity(row_count);
for array in arrays { for array in arrays {
let validity = array.validity(); let validity = array.validity();
if let Some(v) = validity { if validity.is_all_valid() {
v.iter().for_each(|x| null_mask.push(!x));
} else {
null_mask.extend_from_bitslice(&BitVec::repeat(false, array.len())); null_mask.extend_from_bitslice(&BitVec::repeat(false, array.len()));
} else {
for i in 0..array.len() {
null_mask.push(!validity.is_set(i));
}
} }
} }
null_mask.into_vec() null_mask.into_vec()
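null_mask() packs one bit per row across all chunks of a column, set when the row is null, LSB-first within each byte, and returns an empty Vec when nothing is null; the tests below cover each case. A hedged sketch of the same packing over plain Option slices:

fn pack_null_mask(columns: &[Vec<Option<i64>>]) -> Vec<u8> {
    let total_rows: usize = columns.iter().map(|c| c.len()).sum();
    // Fast path: no nulls anywhere means no mask at all.
    if columns.iter().flatten().all(|v| v.is_some()) {
        return Vec::new();
    }
    let mut mask = vec![0u8; (total_rows + 7) / 8];
    let mut bit = 0;
    for column in columns {
        for value in column {
            if value.is_none() {
                mask[bit / 8] |= 1 << (bit % 8);
            }
            bit += 1;
        }
    }
    mask
}

// Matches the first case in test_null_mask below:
// pack_null_mask(&[vec![None, Some(2), None], vec![Some(1), Some(2), None, Some(4)]])
//     == vec![0b0010_0101]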
@@ -137,7 +136,9 @@ pub fn null_mask(arrays: &Vec<Arc<dyn Array>>, row_count: usize) -> Vec<u8> {
macro_rules! convert_arrow_array_to_grpc_vals { macro_rules! convert_arrow_array_to_grpc_vals {
($data_type: expr, $arrays: ident, $(($Type: pat, $CastType: ty, $field: ident, $MapFunction: expr)), +) => {{ ($data_type: expr, $arrays: ident, $(($Type: pat, $CastType: ty, $field: ident, $MapFunction: expr)), +) => {{
use datatypes::arrow::datatypes::{DataType, TimeUnit}; use datatypes::data_type::{ConcreteDataType};
use datatypes::prelude::ScalarVector;
match $data_type { match $data_type {
$( $(
$Type => { $Type => {
@@ -147,52 +148,114 @@ macro_rules! convert_arrow_array_to_grpc_vals {
from: format!("{:?}", $data_type), from: format!("{:?}", $data_type),
})?; })?;
vals.$field.extend(array vals.$field.extend(array
.iter() .iter_data()
.filter_map(|i| i.map($MapFunction)) .filter_map(|i| i.map($MapFunction))
.collect::<Vec<_>>()); .collect::<Vec<_>>());
} }
return Ok(vals); return Ok(vals);
}, },
)+ )+
_ => unimplemented!(), ConcreteDataType::Null(_) | ConcreteDataType::List(_) => unreachable!("Should not send {:?} in gRPC", $data_type),
} }
}}; }};
} }
pub fn values(arrays: &[Arc<dyn Array>]) -> Result<Values> { pub fn values(arrays: &[VectorRef]) -> Result<Values> {
if arrays.is_empty() { if arrays.is_empty() {
return Ok(Values::default()); return Ok(Values::default());
} }
let data_type = arrays[0].data_type(); let data_type = arrays[0].data_type();
convert_arrow_array_to_grpc_vals!( convert_arrow_array_to_grpc_vals!(
data_type, arrays, data_type,
arrays,
(DataType::Boolean, BooleanArray, bool_values, |x| {x}), (
ConcreteDataType::Boolean(_),
(DataType::Int8, PrimitiveArray<i8>, i8_values, |x| {*x as i32}), BooleanVector,
(DataType::Int16, PrimitiveArray<i16>, i16_values, |x| {*x as i32}), bool_values,
(DataType::Int32, PrimitiveArray<i32>, i32_values, |x| {*x}), |x| { x }
(DataType::Int64, PrimitiveArray<i64>, i64_values, |x| {*x}), ),
(ConcreteDataType::Int8(_), Int8Vector, i8_values, |x| {
(DataType::UInt8, PrimitiveArray<u8>, u8_values, |x| {*x as u32}), i32::from(x)
(DataType::UInt16, PrimitiveArray<u16>, u16_values, |x| {*x as u32}), }),
(DataType::UInt32, PrimitiveArray<u32>, u32_values, |x| {*x}), (ConcreteDataType::Int16(_), Int16Vector, i16_values, |x| {
(DataType::UInt64, PrimitiveArray<u64>, u64_values, |x| {*x}), i32::from(x)
}),
(DataType::Float32, PrimitiveArray<f32>, f32_values, |x| {*x}), (ConcreteDataType::Int32(_), Int32Vector, i32_values, |x| {
(DataType::Float64, PrimitiveArray<f64>, f64_values, |x| {*x}), x
}),
(DataType::Binary, BinaryArray, binary_values, |x| {x.into()}), (ConcreteDataType::Int64(_), Int64Vector, i64_values, |x| {
(DataType::LargeBinary, BinaryArray, binary_values, |x| {x.into()}), x
}),
(DataType::Utf8, StringArray, string_values, |x| {x.into()}), (ConcreteDataType::UInt8(_), UInt8Vector, u8_values, |x| {
(DataType::LargeUtf8, StringArray, string_values, |x| {x.into()}), u32::from(x)
}),
(DataType::Date32, PrimitiveArray<i32>, date_values, |x| {*x as i32}), (ConcreteDataType::UInt16(_), UInt16Vector, u16_values, |x| {
(DataType::Date64, PrimitiveArray<i64>, datetime_values,|x| {*x as i64}), u32::from(x)
}),
(DataType::Timestamp(TimeUnit::Millisecond, _), PrimitiveArray<i64>, ts_millis_values, |x| {*x}) (ConcreteDataType::UInt32(_), UInt32Vector, u32_values, |x| {
x
}),
(ConcreteDataType::UInt64(_), UInt64Vector, u64_values, |x| {
x
}),
(
ConcreteDataType::Float32(_),
Float32Vector,
f32_values,
|x| { x }
),
(
ConcreteDataType::Float64(_),
Float64Vector,
f64_values,
|x| { x }
),
(
ConcreteDataType::Binary(_),
BinaryVector,
binary_values,
|x| { x.into() }
),
(
ConcreteDataType::String(_),
StringVector,
string_values,
|x| { x.into() }
),
(ConcreteDataType::Date(_), DateVector, date_values, |x| {
x.val()
}),
(
ConcreteDataType::DateTime(_),
DateTimeVector,
datetime_values,
|x| { x.val() }
),
(
ConcreteDataType::Timestamp(TimestampType::Second(_)),
TimestampSecondVector,
ts_second_values,
|x| { x.into_native() }
),
(
ConcreteDataType::Timestamp(TimestampType::Millisecond(_)),
TimestampMillisecondVector,
ts_millisecond_values,
|x| { x.into_native() }
),
(
ConcreteDataType::Timestamp(TimestampType::Microsecond(_)),
TimestampMicrosecondVector,
ts_microsecond_values,
|x| { x.into_native() }
),
(
ConcreteDataType::Timestamp(TimestampType::Nanosecond(_)),
TimestampNanosecondVector,
ts_nanosecond_values,
|x| { x.into_native() }
)
) )
} }
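values() dispatches on the column's ConcreteDataType so each vector lands in the matching typed field of the protobuf Values message, dropping nulls along the way (their positions are carried separately by the null mask). A stripped-down sketch of that dispatch with made-up types standing in for the generated protobuf struct:

#[derive(Default)]
struct Values {
    i32_values: Vec<i32>,
    f64_values: Vec<f64>,
    string_values: Vec<String>,
}

enum TypedColumn {
    Int32(Vec<Option<i32>>),
    Float64(Vec<Option<f64>>),
    Utf8(Vec<Option<String>>),
}

fn to_values(column: TypedColumn) -> Values {
    let mut vals = Values::default();
    // Nulls are skipped here; the mask records where they were.
    match column {
        TypedColumn::Int32(v) => vals.i32_values.extend(v.into_iter().flatten()),
        TypedColumn::Float64(v) => vals.f64_values.extend(v.into_iter().flatten()),
        TypedColumn::Utf8(v) => vals.string_values.extend(v.into_iter().flatten()),
    }
    vals
}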
@@ -201,14 +264,10 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use common_recordbatch::{RecordBatch, RecordBatches}; use common_recordbatch::{RecordBatch, RecordBatches};
use datafusion::field_util::SchemaExt; use datatypes::data_type::ConcreteDataType;
use datatypes::arrow::array::{Array, BooleanArray, PrimitiveArray}; use datatypes::schema::{ColumnSchema, Schema};
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::arrow_array::StringArray;
use datatypes::schema::Schema;
use datatypes::vectors::{UInt32Vector, VectorRef};
use crate::select::{null_mask, try_convert, values}; use super::*;
#[test] #[test]
fn test_convert_record_batches_to_select_result() { fn test_convert_record_batches_to_select_result() {
@@ -234,9 +293,8 @@ mod tests {
#[test] #[test]
fn test_convert_arrow_arrays_i32() { fn test_convert_arrow_arrays_i32() {
let array: PrimitiveArray<i32> = let array = Int32Vector::from(vec![Some(1), Some(2), None, Some(3)]);
PrimitiveArray::from(vec![Some(1), Some(2), None, Some(3)]); let array: VectorRef = Arc::new(array);
let array: Arc<dyn Array> = Arc::new(array);
let values = values(&[array]).unwrap(); let values = values(&[array]).unwrap();
@@ -245,14 +303,14 @@ mod tests {
#[test] #[test]
fn test_convert_arrow_arrays_string() { fn test_convert_arrow_arrays_string() {
let array = StringArray::from(vec![ let array = StringVector::from(vec![
Some("1".to_string()), Some("1".to_string()),
Some("2".to_string()), Some("2".to_string()),
None, None,
Some("3".to_string()), Some("3".to_string()),
None, None,
]); ]);
let array: Arc<dyn Array> = Arc::new(array); let array: VectorRef = Arc::new(array);
let values = values(&[array]).unwrap(); let values = values(&[array]).unwrap();
@@ -261,8 +319,8 @@ mod tests {
#[test] #[test]
fn test_convert_arrow_arrays_bool() { fn test_convert_arrow_arrays_bool() {
let array = BooleanArray::from(vec![Some(true), Some(false), None, Some(false), None]); let array = BooleanVector::from(vec![Some(true), Some(false), None, Some(false), None]);
let array: Arc<dyn Array> = Arc::new(array); let array: VectorRef = Arc::new(array);
let values = values(&[array]).unwrap(); let values = values(&[array]).unwrap();
@@ -271,43 +329,42 @@ mod tests {
#[test] #[test]
fn test_convert_arrow_arrays_empty() { fn test_convert_arrow_arrays_empty() {
let array = BooleanArray::from(vec![None, None, None, None, None]); let array = BooleanVector::from(vec![None, None, None, None, None]);
let array: Arc<dyn Array> = Arc::new(array); let array: VectorRef = Arc::new(array);
let values = values(&[array]).unwrap(); let values = values(&[array]).unwrap();
assert_eq!(Vec::<bool>::default(), values.bool_values); assert!(values.bool_values.is_empty());
} }
#[test] #[test]
fn test_null_mask() { fn test_null_mask() {
let a1: Arc<dyn Array> = Arc::new(PrimitiveArray::from(vec![None, Some(2), None])); let a1: VectorRef = Arc::new(Int32Vector::from(vec![None, Some(2), None]));
let a2: Arc<dyn Array> = let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), None, Some(4)]));
Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), None, Some(4)])); let mask = null_mask(&[a1, a2], 3 + 4);
let mask = null_mask(&vec![a1, a2], 3 + 4);
assert_eq!(vec![0b0010_0101], mask); assert_eq!(vec![0b0010_0101], mask);
let empty: Arc<dyn Array> = Arc::new(PrimitiveArray::<i32>::from(vec![None, None, None])); let empty: VectorRef = Arc::new(Int32Vector::from(vec![None, None, None]));
let mask = null_mask(&vec![empty.clone(), empty.clone(), empty], 9); let mask = null_mask(&[empty.clone(), empty.clone(), empty], 9);
assert_eq!(vec![0b1111_1111, 0b0000_0001], mask); assert_eq!(vec![0b1111_1111, 0b0000_0001], mask);
let a1: Arc<dyn Array> = Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), Some(3)])); let a1: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), Some(3)]));
let a2: Arc<dyn Array> = Arc::new(PrimitiveArray::from(vec![Some(4), Some(5), Some(6)])); let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(4), Some(5), Some(6)]));
let mask = null_mask(&vec![a1, a2], 3 + 3); let mask = null_mask(&[a1, a2], 3 + 3);
assert_eq!(Vec::<u8>::default(), mask); assert_eq!(Vec::<u8>::default(), mask);
let a1: Arc<dyn Array> = Arc::new(PrimitiveArray::from(vec![Some(1), Some(2), Some(3)])); let a1: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), Some(2), Some(3)]));
let a2: Arc<dyn Array> = Arc::new(PrimitiveArray::from(vec![Some(4), Some(5), None])); let a2: VectorRef = Arc::new(Int32Vector::from(vec![Some(4), Some(5), None]));
let mask = null_mask(&vec![a1, a2], 3 + 3); let mask = null_mask(&[a1, a2], 3 + 3);
assert_eq!(vec![0b0010_0000], mask); assert_eq!(vec![0b0010_0000], mask);
} }
fn mock_record_batch() -> RecordBatch { fn mock_record_batch() -> RecordBatch {
let arrow_schema = Arc::new(ArrowSchema::new(vec![ let column_schemas = vec![
Field::new("c1", DataType::UInt32, false), ColumnSchema::new("c1", ConcreteDataType::uint32_datatype(), true),
Field::new("c2", DataType::UInt32, false), ColumnSchema::new("c2", ConcreteDataType::uint32_datatype(), true),
])); ];
let schema = Arc::new(Schema::try_from(arrow_schema).unwrap()); let schema = Arc::new(Schema::try_new(column_schemas).unwrap());
let v1 = Arc::new(UInt32Vector::from(vec![Some(1), Some(2), None])); let v1 = Arc::new(UInt32Vector::from(vec![Some(1), Some(2), None]));
let v2 = Arc::new(UInt32Vector::from(vec![Some(1), None, None])); let v2 = Arc::new(UInt32Vector::from(vec![Some(1), None, None]));


@@ -45,11 +45,11 @@ impl LinesWriter {
pub fn write_ts(&mut self, column_name: &str, value: (i64, Precision)) -> Result<()> { pub fn write_ts(&mut self, column_name: &str, value: (i64, Precision)) -> Result<()> {
let (idx, column) = self.mut_column( let (idx, column) = self.mut_column(
column_name, column_name,
ColumnDataType::Timestamp, ColumnDataType::TimestampMillisecond,
SemanticType::Timestamp, SemanticType::Timestamp,
); );
ensure!( ensure!(
column.datatype == ColumnDataType::Timestamp as i32, column.datatype == ColumnDataType::TimestampMillisecond as i32,
TypeMismatchSnafu { TypeMismatchSnafu {
column_name, column_name,
expected: "timestamp", expected: "timestamp",
@@ -58,7 +58,9 @@ impl LinesWriter {
); );
// It is safe to use unwrap here, because values has been initialized in mut_column() // It is safe to use unwrap here, because values has been initialized in mut_column()
let values = column.values.as_mut().unwrap(); let values = column.values.as_mut().unwrap();
values.ts_millis_values.push(to_ms_ts(value.1, value.0)); values
.ts_millisecond_values
.push(to_ms_ts(value.1, value.0));
self.null_masks[idx].push(false); self.null_masks[idx].push(false);
Ok(()) Ok(())
} }
@@ -224,23 +226,23 @@ impl LinesWriter {
pub fn to_ms_ts(p: Precision, ts: i64) -> i64 { pub fn to_ms_ts(p: Precision, ts: i64) -> i64 {
match p { match p {
Precision::NANOSECOND => ts / 1_000_000, Precision::Nanosecond => ts / 1_000_000,
Precision::MICROSECOND => ts / 1000, Precision::Microsecond => ts / 1000,
Precision::MILLISECOND => ts, Precision::Millisecond => ts,
Precision::SECOND => ts * 1000, Precision::Second => ts * 1000,
Precision::MINUTE => ts * 1000 * 60, Precision::Minute => ts * 1000 * 60,
Precision::HOUR => ts * 1000 * 60 * 60, Precision::Hour => ts * 1000 * 60 * 60,
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision { pub enum Precision {
NANOSECOND, Nanosecond,
MICROSECOND, Microsecond,
MILLISECOND, Millisecond,
SECOND, Second,
MINUTE, Minute,
HOUR, Hour,
} }
#[cfg(test)] #[cfg(test)]
@@ -261,13 +263,13 @@ mod tests {
writer.write_f64("memory", 0.4).unwrap(); writer.write_f64("memory", 0.4).unwrap();
writer.write_string("name", "name1").unwrap(); writer.write_string("name", "name1").unwrap();
writer writer
.write_ts("ts", (101011000, Precision::MILLISECOND)) .write_ts("ts", (101011000, Precision::Millisecond))
.unwrap(); .unwrap();
writer.commit(); writer.commit();
writer.write_tag("host", "host2").unwrap(); writer.write_tag("host", "host2").unwrap();
writer writer
.write_ts("ts", (102011001, Precision::MILLISECOND)) .write_ts("ts", (102011001, Precision::Millisecond))
.unwrap(); .unwrap();
writer.write_bool("enable_reboot", true).unwrap(); writer.write_bool("enable_reboot", true).unwrap();
writer.write_u64("year_of_service", 2).unwrap(); writer.write_u64("year_of_service", 2).unwrap();
@@ -278,7 +280,7 @@ mod tests {
writer.write_f64("cpu", 0.4).unwrap(); writer.write_f64("cpu", 0.4).unwrap();
writer.write_u64("cpu_core_num", 16).unwrap(); writer.write_u64("cpu_core_num", 16).unwrap();
writer writer
.write_ts("ts", (103011002, Precision::MILLISECOND)) .write_ts("ts", (103011002, Precision::Millisecond))
.unwrap(); .unwrap();
writer.commit(); writer.commit();
@@ -321,11 +323,11 @@ mod tests {
let column = &columns[4]; let column = &columns[4];
assert_eq!("ts", column.column_name); assert_eq!("ts", column.column_name);
assert_eq!(ColumnDataType::Timestamp as i32, column.datatype); assert_eq!(ColumnDataType::TimestampMillisecond as i32, column.datatype);
assert_eq!(SemanticType::Timestamp as i32, column.semantic_type); assert_eq!(SemanticType::Timestamp as i32, column.semantic_type);
assert_eq!( assert_eq!(
vec![101011000, 102011001, 103011002], vec![101011000, 102011001, 103011002],
column.values.as_ref().unwrap().ts_millis_values column.values.as_ref().unwrap().ts_millisecond_values
); );
verify_null_mask(&column.null_mask, vec![false, false, false]); verify_null_mask(&column.null_mask, vec![false, false, false]);
@@ -367,16 +369,16 @@ mod tests {
#[test] #[test]
fn test_to_ms() { fn test_to_ms() {
assert_eq!(100, to_ms_ts(Precision::NANOSECOND, 100110000)); assert_eq!(100, to_ms_ts(Precision::Nanosecond, 100110000));
assert_eq!(100110, to_ms_ts(Precision::MICROSECOND, 100110000)); assert_eq!(100110, to_ms_ts(Precision::Microsecond, 100110000));
assert_eq!(100110000, to_ms_ts(Precision::MILLISECOND, 100110000)); assert_eq!(100110000, to_ms_ts(Precision::Millisecond, 100110000));
assert_eq!( assert_eq!(
100110000 * 1000 * 60, 100110000 * 1000 * 60,
to_ms_ts(Precision::MINUTE, 100110000) to_ms_ts(Precision::Minute, 100110000)
); );
assert_eq!( assert_eq!(
100110000 * 1000 * 60 * 60, 100110000 * 1000 * 60 * 60,
to_ms_ts(Precision::HOUR, 100110000) to_ms_ts(Precision::Hour, 100110000)
); );
} }
} }
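One behavior worth keeping in mind with the conversion above: for precisions finer than a millisecond, the integer division truncates toward zero. A couple of illustrative values beyond the ones already tested, as a hypothetical test using the same to_ms_ts and Precision from this hunk:

#[test]
fn test_to_ms_truncation_sketch() {
    assert_eq!(0, to_ms_ts(Precision::Nanosecond, 999_999)); // < 1 ms truncates to 0
    assert_eq!(1, to_ms_ts(Precision::Microsecond, 1_999)); // 1.999 ms -> 1 ms
    assert_eq!(2_000, to_ms_ts(Precision::Second, 2)); // coarser units only scale up
}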


@@ -9,11 +9,9 @@ async-trait = "0.1"
common-error = { path = "../error" } common-error = { path = "../error" }
common-recordbatch = { path = "../recordbatch" } common-recordbatch = { path = "../recordbatch" }
common-time = { path = "../time" } common-time = { path = "../time" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ datafusion = "14.0.0"
"simd", datafusion-common = "14.0.0"
] } datafusion-expr = "14.0.0"
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
datatypes = { path = "../../datatypes" } datatypes = { path = "../../datatypes" }
snafu = { version = "0.7", features = ["backtraces"] } snafu = { version = "0.7", features = ["backtraces"] }
statrs = "0.15" statrs = "0.15"


@@ -14,18 +14,18 @@
use std::any::Any; use std::any::Any;
use arrow::error::ArrowError;
use common_error::prelude::*; use common_error::prelude::*;
use datafusion_common::DataFusionError; use datafusion_common::DataFusionError;
use datatypes::arrow;
use datatypes::arrow::datatypes::DataType as ArrowDatatype; use datatypes::arrow::datatypes::DataType as ArrowDatatype;
use datatypes::error::Error as DataTypeError; use datatypes::error::Error as DataTypeError;
use datatypes::prelude::ConcreteDataType; use datatypes::prelude::ConcreteDataType;
use statrs::StatsError; use statrs::StatsError;
common_error::define_opaque_error!(Error);
#[derive(Debug, Snafu)] #[derive(Debug, Snafu)]
#[snafu(visibility(pub))] #[snafu(visibility(pub))]
pub enum InnerError { pub enum Error {
#[snafu(display("Fail to execute function, source: {}", source))] #[snafu(display("Fail to execute function, source: {}", source))]
ExecuteFunction { ExecuteFunction {
source: DataFusionError, source: DataFusionError,
@@ -51,6 +51,12 @@ pub enum InnerError {
source: DataTypeError, source: DataTypeError,
}, },
#[snafu(display("Fail to cast arrow array into vector: {}", source))]
FromArrowArray {
#[snafu(backtrace)]
source: DataTypeError,
},
#[snafu(display("Fail to cast arrow array into vector: {:?}, {}", data_type, source))] #[snafu(display("Fail to cast arrow array into vector: {:?}, {}", data_type, source))]
IntoVector { IntoVector {
#[snafu(backtrace)] #[snafu(backtrace)]
@@ -70,8 +76,8 @@ pub enum InnerError {
backtrace: Backtrace, backtrace: Backtrace,
}, },
#[snafu(display("Invalid inputs: {}", err_msg))] #[snafu(display("Invalid input type: {}", err_msg))]
InvalidInputs { InvalidInputType {
#[snafu(backtrace)] #[snafu(backtrace)]
source: DataTypeError, source: DataTypeError,
err_msg: String, err_msg: String,
@@ -120,34 +126,74 @@ pub enum InnerError {
#[snafu(backtrace)] #[snafu(backtrace)]
source: BoxedError, source: BoxedError,
}, },
#[snafu(display("Failed to cast array to {:?}, source: {}", typ, source))]
TypeCast {
source: ArrowError,
typ: arrow::datatypes::DataType,
backtrace: Backtrace,
},
#[snafu(display(
"Failed to perform compute operation on arrow arrays, source: {}",
source
))]
ArrowCompute {
source: ArrowError,
backtrace: Backtrace,
},
#[snafu(display("Query engine fail to cast value: {}", source))]
ToScalarValue {
#[snafu(backtrace)]
source: DataTypeError,
},
#[snafu(display("Failed to get scalar vector, {}", source))]
GetScalarVector {
#[snafu(backtrace)]
source: DataTypeError,
},
#[snafu(display("Invalid function args: {}", err_msg))]
InvalidFuncArgs {
err_msg: String,
backtrace: Backtrace,
},
} }
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for InnerError { impl ErrorExt for Error {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match self { match self {
InnerError::ExecuteFunction { .. } Error::ExecuteFunction { .. }
| InnerError::GenerateFunction { .. } | Error::GenerateFunction { .. }
| InnerError::CreateAccumulator { .. } | Error::CreateAccumulator { .. }
| InnerError::DowncastVector { .. } | Error::DowncastVector { .. }
| InnerError::InvalidInputState { .. } | Error::InvalidInputState { .. }
| InnerError::InvalidInputCol { .. } | Error::InvalidInputCol { .. }
| InnerError::BadAccumulatorImpl { .. } => StatusCode::EngineExecuteQuery, | Error::BadAccumulatorImpl { .. }
| Error::ToScalarValue { .. }
| Error::GetScalarVector { .. }
| Error::ArrowCompute { .. } => StatusCode::EngineExecuteQuery,
InnerError::InvalidInputs { source, .. } Error::InvalidInputType { source, .. }
| InnerError::IntoVector { source, .. } | Error::IntoVector { source, .. }
| InnerError::FromScalarValue { source } | Error::FromScalarValue { source }
| InnerError::ConvertArrowSchema { source } => source.status_code(), | Error::ConvertArrowSchema { source }
| Error::FromArrowArray { source } => source.status_code(),
InnerError::ExecuteRepeatedly { .. } Error::ExecuteRepeatedly { .. }
| InnerError::GeneralDataFusion { .. } | Error::GeneralDataFusion { .. }
| InnerError::DataFusionExecutionPlan { .. } => StatusCode::Unexpected, | Error::DataFusionExecutionPlan { .. } => StatusCode::Unexpected,
InnerError::UnsupportedInputDataType { .. } => StatusCode::InvalidArguments, Error::UnsupportedInputDataType { .. }
| Error::TypeCast { .. }
| Error::InvalidFuncArgs { .. } => StatusCode::InvalidArguments,
InnerError::ConvertDfRecordBatchStream { source, .. } => source.status_code(), Error::ConvertDfRecordBatchStream { source, .. } => source.status_code(),
InnerError::ExecutePhysicalPlan { source } => source.status_code(), Error::ExecutePhysicalPlan { source } => source.status_code(),
} }
} }
@@ -160,12 +206,6 @@ impl ErrorExt for InnerError {
} }
} }
impl From<InnerError> for Error {
fn from(e: InnerError) -> Error {
Error::new(e)
}
}
impl From<Error> for DataFusionError { impl From<Error> for DataFusionError {
fn from(e: Error) -> DataFusionError { fn from(e: Error) -> DataFusionError {
DataFusionError::External(Box::new(e)) DataFusionError::External(Box::new(e))
@@ -174,7 +214,7 @@ impl From<Error> for DataFusionError {
impl From<BoxedError> for Error { impl From<BoxedError> for Error {
fn from(source: BoxedError) -> Self { fn from(source: BoxedError) -> Self {
InnerError::ExecutePhysicalPlan { source }.into() Error::ExecutePhysicalPlan { source }
} }
} }
@@ -190,60 +230,51 @@ mod tests {
} }
fn assert_error(err: &Error, code: StatusCode) { fn assert_error(err: &Error, code: StatusCode) {
let inner_err = err.as_any().downcast_ref::<InnerError>().unwrap(); let inner_err = err.as_any().downcast_ref::<Error>().unwrap();
assert_eq!(code, inner_err.status_code()); assert_eq!(code, inner_err.status_code());
assert!(inner_err.backtrace_opt().is_some()); assert!(inner_err.backtrace_opt().is_some());
} }
#[test] #[test]
fn test_datafusion_as_source() { fn test_datafusion_as_source() {
let err: Error = throw_df_error() let err = throw_df_error()
.context(ExecuteFunctionSnafu) .context(ExecuteFunctionSnafu)
.err() .err()
.unwrap() .unwrap();
.into();
assert_error(&err, StatusCode::EngineExecuteQuery); assert_error(&err, StatusCode::EngineExecuteQuery);
let err: Error = throw_df_error() let err: Error = throw_df_error()
.context(GeneralDataFusionSnafu) .context(GeneralDataFusionSnafu)
.err() .err()
.unwrap() .unwrap();
.into();
assert_error(&err, StatusCode::Unexpected); assert_error(&err, StatusCode::Unexpected);
let err: Error = throw_df_error() let err = throw_df_error()
.context(DataFusionExecutionPlanSnafu) .context(DataFusionExecutionPlanSnafu)
.err() .err()
.unwrap() .unwrap();
.into();
assert_error(&err, StatusCode::Unexpected); assert_error(&err, StatusCode::Unexpected);
} }
#[test] #[test]
fn test_execute_repeatedly_error() { fn test_execute_repeatedly_error() {
let error: Error = None::<i32> let error = None::<i32>.context(ExecuteRepeatedlySnafu).err().unwrap();
.context(ExecuteRepeatedlySnafu) assert_eq!(error.status_code(), StatusCode::Unexpected);
.err()
.unwrap()
.into();
assert_eq!(error.inner.status_code(), StatusCode::Unexpected);
assert!(error.backtrace_opt().is_some()); assert!(error.backtrace_opt().is_some());
} }
#[test] #[test]
fn test_convert_df_recordbatch_stream_error() { fn test_convert_df_recordbatch_stream_error() {
let result: std::result::Result<i32, common_recordbatch::error::Error> = let result: std::result::Result<i32, common_recordbatch::error::Error> =
Err(common_recordbatch::error::InnerError::PollStream { Err(common_recordbatch::error::Error::PollStream {
source: ArrowError::Overflow, source: ArrowError::DivideByZero,
backtrace: Backtrace::generate(), backtrace: Backtrace::generate(),
} });
.into()); let error = result
let error: Error = result
.context(ConvertDfRecordBatchStreamSnafu) .context(ConvertDfRecordBatchStreamSnafu)
.err() .err()
.unwrap() .unwrap();
.into(); assert_eq!(error.status_code(), StatusCode::Internal);
assert_eq!(error.inner.status_code(), StatusCode::Internal);
assert!(error.backtrace_opt().is_some()); assert!(error.backtrace_opt().is_some());
} }
@@ -256,13 +287,12 @@ mod tests {
#[test] #[test]
fn test_into_vector_error() { fn test_into_vector_error() {
let err: Error = raise_datatype_error() let err = raise_datatype_error()
.context(IntoVectorSnafu { .context(IntoVectorSnafu {
data_type: ArrowDatatype::Int32, data_type: ArrowDatatype::Int32,
}) })
.err() .err()
.unwrap() .unwrap();
.into();
assert!(err.backtrace_opt().is_some()); assert!(err.backtrace_opt().is_some());
let datatype_err = raise_datatype_error().err().unwrap(); let datatype_err = raise_datatype_error().err().unwrap();
assert_eq!(datatype_err.status_code(), err.status_code()); assert_eq!(datatype_err.status_code(), err.status_code());
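The net effect of dropping define_opaque_error! and the InnerError indirection, sketched as a hypothetical caller that assumes only the ExecuteFunctionSnafu selector and the Result alias from this file: a snafu context now yields the public Error directly, with no trailing .into() or .map_err(Error::from).

use datafusion_common::DataFusionError;
use datatypes::arrow::array::ArrayRef;
use snafu::ResultExt;

// Hypothetical helper, not part of the diff.
fn attach_context(
    df_result: std::result::Result<ArrayRef, DataFusionError>,
) -> Result<ArrayRef> {
    df_result.context(ExecuteFunctionSnafu)
}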


@@ -22,7 +22,7 @@ use std::sync::Arc;
use datatypes::prelude::ConcreteDataType; use datatypes::prelude::ConcreteDataType;
pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef}; pub use self::accumulator::{Accumulator, AggregateFunctionCreator, AggregateFunctionCreatorRef};
pub use self::expr::Expr; pub use self::expr::{DfExpr, Expr};
pub use self::udaf::AggregateFunction; pub use self::udaf::AggregateFunction;
pub use self::udf::ScalarUdf; pub use self::udf::ScalarUdf;
use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation}; use crate::function::{ReturnTypeFunction, ScalarFunctionImplementation};
@@ -148,9 +148,7 @@ mod tests {
let args = vec![ let args = vec![
DfColumnarValue::Scalar(ScalarValue::Boolean(Some(true))), DfColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
DfColumnarValue::Array(Arc::new(BooleanArray::from_slice(vec![ DfColumnarValue::Array(Arc::new(BooleanArray::from(vec![true, false, false, true]))),
true, false, false, true,
]))),
]; ];
// call the function // call the function


@@ -17,12 +17,10 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc; use std::sync::Arc;
use common_time::timestamp::TimeUnit;
use datafusion_common::Result as DfResult; use datafusion_common::Result as DfResult;
use datafusion_expr::Accumulator as DfAccumulator; use datafusion_expr::{Accumulator as DfAccumulator, AggregateState};
use datatypes::arrow::array::ArrayRef; use datatypes::arrow::array::ArrayRef;
use datatypes::prelude::*; use datatypes::prelude::*;
use datatypes::value::ListValue;
use datatypes::vectors::{Helper as VectorHelper, VectorRef}; use datatypes::vectors::{Helper as VectorHelper, VectorRef};
use snafu::ResultExt; use snafu::ResultExt;
@@ -128,356 +126,53 @@ impl DfAccumulatorAdaptor {
} }
impl DfAccumulator for DfAccumulatorAdaptor { impl DfAccumulator for DfAccumulatorAdaptor {
fn state(&self) -> DfResult<Vec<ScalarValue>> { fn state(&self) -> DfResult<Vec<AggregateState>> {
let state_values = self.accumulator.state()?; let state_values = self.accumulator.state()?;
let state_types = self.creator.state_types()?; let state_types = self.creator.state_types()?;
if state_values.len() != state_types.len() { if state_values.len() != state_types.len() {
return error::BadAccumulatorImplSnafu { return error::BadAccumulatorImplSnafu {
err_msg: format!("Accumulator {:?} returned state values size do not match its state types size.", self), err_msg: format!("Accumulator {:?} returned state values size do not match its state types size.", self),
} }
.fail() .fail()?;
.map_err(Error::from)?;
} }
Ok(state_values Ok(state_values
.into_iter() .into_iter()
.zip(state_types.iter()) .zip(state_types.iter())
.map(|(v, t)| try_into_scalar_value(v, t)) .map(|(v, t)| {
.collect::<Result<Vec<_>>>() let scalar = v
.map_err(Error::from)?) .try_to_scalar_value(t)
.context(error::ToScalarValueSnafu)?;
Ok(AggregateState::Scalar(scalar))
})
.collect::<Result<Vec<_>>>()?)
} }
fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> { fn update_batch(&mut self, values: &[ArrayRef]) -> DfResult<()> {
let vectors = VectorHelper::try_into_vectors(values) let vectors = VectorHelper::try_into_vectors(values).context(FromScalarValueSnafu)?;
.context(FromScalarValueSnafu) self.accumulator.update_batch(&vectors)?;
.map_err(Error::from)?; Ok(())
self.accumulator
.update_batch(&vectors)
.map_err(|e| e.into())
} }
fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> { fn merge_batch(&mut self, states: &[ArrayRef]) -> DfResult<()> {
let mut vectors = Vec::with_capacity(states.len()); let mut vectors = Vec::with_capacity(states.len());
for array in states.iter() { for array in states.iter() {
vectors.push( vectors.push(
VectorHelper::try_into_vector(array) VectorHelper::try_into_vector(array).context(IntoVectorSnafu {
.context(IntoVectorSnafu { data_type: array.data_type().clone(),
data_type: array.data_type().clone(), })?,
})
.map_err(Error::from)?,
); );
} }
self.accumulator.merge_batch(&vectors).map_err(|e| e.into()) self.accumulator.merge_batch(&vectors)?;
Ok(())
} }
fn evaluate(&self) -> DfResult<ScalarValue> { fn evaluate(&self) -> DfResult<ScalarValue> {
let value = self.accumulator.evaluate()?; let value = self.accumulator.evaluate()?;
let output_type = self.creator.output_type()?; let output_type = self.creator.output_type()?;
Ok(try_into_scalar_value(value, &output_type)?) let scalar_value = value
} .try_to_scalar_value(&output_type)
} .context(error::ToScalarValueSnafu)
.map_err(Error::from)?;
fn try_into_scalar_value(value: Value, datatype: &ConcreteDataType) -> Result<ScalarValue> { Ok(scalar_value)
if !matches!(value, Value::Null) && datatype != &value.data_type() {
return error::BadAccumulatorImplSnafu {
err_msg: format!(
"expect value to return datatype {:?}, actual: {:?}",
datatype,
value.data_type()
),
}
.fail()?;
}
Ok(match value {
Value::Boolean(v) => ScalarValue::Boolean(Some(v)),
Value::UInt8(v) => ScalarValue::UInt8(Some(v)),
Value::UInt16(v) => ScalarValue::UInt16(Some(v)),
Value::UInt32(v) => ScalarValue::UInt32(Some(v)),
Value::UInt64(v) => ScalarValue::UInt64(Some(v)),
Value::Int8(v) => ScalarValue::Int8(Some(v)),
Value::Int16(v) => ScalarValue::Int16(Some(v)),
Value::Int32(v) => ScalarValue::Int32(Some(v)),
Value::Int64(v) => ScalarValue::Int64(Some(v)),
Value::Float32(v) => ScalarValue::Float32(Some(v.0)),
Value::Float64(v) => ScalarValue::Float64(Some(v.0)),
Value::String(v) => ScalarValue::Utf8(Some(v.as_utf8().to_string())),
Value::Binary(v) => ScalarValue::LargeBinary(Some(v.to_vec())),
Value::Date(v) => ScalarValue::Date32(Some(v.val())),
Value::DateTime(v) => ScalarValue::Date64(Some(v.val())),
Value::Null => try_convert_null_value(datatype)?,
Value::List(list) => try_convert_list_value(list)?,
Value::Timestamp(t) => timestamp_to_scalar_value(t.unit(), Some(t.value())),
})
}
fn timestamp_to_scalar_value(unit: TimeUnit, val: Option<i64>) -> ScalarValue {
match unit {
TimeUnit::Second => ScalarValue::TimestampSecond(val, None),
TimeUnit::Millisecond => ScalarValue::TimestampMillisecond(val, None),
TimeUnit::Microsecond => ScalarValue::TimestampMicrosecond(val, None),
TimeUnit::Nanosecond => ScalarValue::TimestampNanosecond(val, None),
}
}
fn try_convert_null_value(datatype: &ConcreteDataType) -> Result<ScalarValue> {
Ok(match datatype {
ConcreteDataType::Boolean(_) => ScalarValue::Boolean(None),
ConcreteDataType::Int8(_) => ScalarValue::Int8(None),
ConcreteDataType::Int16(_) => ScalarValue::Int16(None),
ConcreteDataType::Int32(_) => ScalarValue::Int32(None),
ConcreteDataType::Int64(_) => ScalarValue::Int64(None),
ConcreteDataType::UInt8(_) => ScalarValue::UInt8(None),
ConcreteDataType::UInt16(_) => ScalarValue::UInt16(None),
ConcreteDataType::UInt32(_) => ScalarValue::UInt32(None),
ConcreteDataType::UInt64(_) => ScalarValue::UInt64(None),
ConcreteDataType::Float32(_) => ScalarValue::Float32(None),
ConcreteDataType::Float64(_) => ScalarValue::Float64(None),
ConcreteDataType::Binary(_) => ScalarValue::LargeBinary(None),
ConcreteDataType::String(_) => ScalarValue::Utf8(None),
ConcreteDataType::Timestamp(t) => timestamp_to_scalar_value(t.unit, None),
_ => {
return error::BadAccumulatorImplSnafu {
err_msg: format!(
"undefined transition from null value to datatype {:?}",
datatype
),
}
.fail()?
}
})
}
fn try_convert_list_value(list: ListValue) -> Result<ScalarValue> {
let vs = if let Some(items) = list.items() {
Some(Box::new(
items
.iter()
.map(|v| try_into_scalar_value(v.clone(), list.datatype()))
.collect::<Result<Vec<_>>>()?,
))
} else {
None
};
Ok(ScalarValue::List(
vs,
Box::new(list.datatype().as_arrow_type()),
))
}
#[cfg(test)]
mod tests {
use common_base::bytes::{Bytes, StringBytes};
use datafusion_common::ScalarValue;
use datatypes::arrow::datatypes::DataType;
use datatypes::value::{ListValue, OrderedFloat};
use super::*;
#[test]
fn test_not_null_value_to_scalar_value() {
assert_eq!(
ScalarValue::Boolean(Some(true)),
try_into_scalar_value(Value::Boolean(true), &ConcreteDataType::boolean_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::Boolean(Some(false)),
try_into_scalar_value(Value::Boolean(false), &ConcreteDataType::boolean_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::UInt8(Some(u8::MIN + 1)),
try_into_scalar_value(
Value::UInt8(u8::MIN + 1),
&ConcreteDataType::uint8_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt16(Some(u16::MIN + 2)),
try_into_scalar_value(
Value::UInt16(u16::MIN + 2),
&ConcreteDataType::uint16_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt32(Some(u32::MIN + 3)),
try_into_scalar_value(
Value::UInt32(u32::MIN + 3),
&ConcreteDataType::uint32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::UInt64(Some(u64::MIN + 4)),
try_into_scalar_value(
Value::UInt64(u64::MIN + 4),
&ConcreteDataType::uint64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int8(Some(i8::MIN + 4)),
try_into_scalar_value(Value::Int8(i8::MIN + 4), &ConcreteDataType::int8_datatype())
.unwrap()
);
assert_eq!(
ScalarValue::Int16(Some(i16::MIN + 5)),
try_into_scalar_value(
Value::Int16(i16::MIN + 5),
&ConcreteDataType::int16_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int32(Some(i32::MIN + 6)),
try_into_scalar_value(
Value::Int32(i32::MIN + 6),
&ConcreteDataType::int32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Int64(Some(i64::MIN + 7)),
try_into_scalar_value(
Value::Int64(i64::MIN + 7),
&ConcreteDataType::int64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Float32(Some(8.0f32)),
try_into_scalar_value(
Value::Float32(OrderedFloat(8.0f32)),
&ConcreteDataType::float32_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Float64(Some(9.0f64)),
try_into_scalar_value(
Value::Float64(OrderedFloat(9.0f64)),
&ConcreteDataType::float64_datatype()
)
.unwrap()
);
assert_eq!(
ScalarValue::Utf8(Some("hello".to_string())),
try_into_scalar_value(
Value::String(StringBytes::from("hello")),
&ConcreteDataType::string_datatype(),
)
.unwrap()
);
assert_eq!(
ScalarValue::LargeBinary(Some("world".as_bytes().to_vec())),
try_into_scalar_value(
Value::Binary(Bytes::from("world".as_bytes())),
&ConcreteDataType::binary_datatype()
)
.unwrap()
);
}
#[test]
fn test_null_value_to_scalar_value() {
assert_eq!(
ScalarValue::Boolean(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::boolean_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint8_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt16(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint16_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::UInt64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::uint64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int8_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int16(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int16_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Int64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::int64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Float32(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::float32_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Float64(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::float64_datatype()).unwrap()
);
assert_eq!(
ScalarValue::Utf8(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::string_datatype()).unwrap()
);
assert_eq!(
ScalarValue::LargeBinary(None),
try_into_scalar_value(Value::Null, &ConcreteDataType::binary_datatype()).unwrap()
);
}
#[test]
fn test_list_value_to_scalar_value() {
let items = Some(Box::new(vec![Value::Int32(-1), Value::Null]));
let list = Value::List(ListValue::new(items, ConcreteDataType::int32_datatype()));
let df_list = try_into_scalar_value(
list,
&ConcreteDataType::list_datatype(ConcreteDataType::int32_datatype()),
)
.unwrap();
assert!(matches!(df_list, ScalarValue::List(_, _)));
match df_list {
ScalarValue::List(vs, datatype) => {
assert_eq!(*datatype, DataType::Int32);
assert!(vs.is_some());
let vs = *vs.unwrap();
assert_eq!(
vs,
vec![ScalarValue::Int32(Some(-1)), ScalarValue::Int32(None)]
);
}
_ => unreachable!(),
}
}
#[test]
pub fn test_timestamp_to_scalar_value() {
assert_eq!(
ScalarValue::TimestampSecond(Some(1), None),
timestamp_to_scalar_value(TimeUnit::Second, Some(1))
);
assert_eq!(
ScalarValue::TimestampMillisecond(Some(1), None),
timestamp_to_scalar_value(TimeUnit::Millisecond, Some(1))
);
assert_eq!(
ScalarValue::TimestampMicrosecond(Some(1), None),
timestamp_to_scalar_value(TimeUnit::Microsecond, Some(1))
);
assert_eq!(
ScalarValue::TimestampNanosecond(Some(1), None),
timestamp_to_scalar_value(TimeUnit::Nanosecond, Some(1))
);
} }
} }
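The removed try_into_scalar_value above, together with its unit tests, now lives behind Value::try_to_scalar_value. A rough sketch of the replacement path, assuming that method keeps the old Value to ScalarValue mapping:

use datafusion_common::ScalarValue;
use datatypes::prelude::*;
use snafu::ResultExt;

// Sketch of the conversion DfAccumulatorAdaptor::evaluate performs now.
fn value_to_df_scalar(value: Value, output_type: &ConcreteDataType) -> Result<ScalarValue> {
    value
        .try_to_scalar_value(output_type)
        .context(error::ToScalarValueSnafu)
}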


@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use datafusion::logical_plan::Expr as DfExpr; pub use datafusion_expr::expr::Expr as DfExpr;
/// Central struct of query API. /// Central struct of query API.
/// Represent logical expressions such as `A + 1`, or `CAST(c1 AS int)`. /// Represent logical expressions such as `A + 1`, or `CAST(c1 AS int)`.
#[derive(Clone, PartialEq, Hash, Debug)] #[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Expr { pub struct Expr {
df_expr: DfExpr, df_expr: DfExpr,
} }


@@ -104,7 +104,7 @@ fn to_df_accumulator_func(
accumulator: AccumulatorFunctionImpl, accumulator: AccumulatorFunctionImpl,
creator: AggregateFunctionCreatorRef, creator: AggregateFunctionCreatorRef,
) -> DfAccumulatorFunctionImplementation { ) -> DfAccumulatorFunctionImplementation {
Arc::new(move || { Arc::new(move |_| {
let accumulator = accumulator()?; let accumulator = accumulator()?;
let creator = creator.clone(); let creator = creator.clone();
Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator))) Ok(Box::new(DfAccumulatorAdaptor::new(accumulator, creator)))


@@ -16,12 +16,11 @@ use std::any::Any;
use std::fmt::Debug; use std::fmt::Debug;
use std::sync::Arc; use std::sync::Arc;
use async_trait::async_trait; use common_recordbatch::adapter::{DfRecordBatchStreamAdapter, RecordBatchStreamAdapter};
use common_recordbatch::adapter::{AsyncRecordBatchStreamAdapter, DfRecordBatchStreamAdapter};
use common_recordbatch::{DfSendableRecordBatchStream, SendableRecordBatchStream}; use common_recordbatch::{DfSendableRecordBatchStream, SendableRecordBatchStream};
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef; use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::error::Result as DfResult; use datafusion::error::Result as DfResult;
pub use datafusion::execution::runtime_env::RuntimeEnv; pub use datafusion::execution::context::{SessionContext, TaskContext};
use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::expressions::PhysicalSortExpr;
pub use datafusion::physical_plan::Partitioning; pub use datafusion::physical_plan::Partitioning;
use datafusion::physical_plan::Statistics; use datafusion::physical_plan::Statistics;
@@ -63,7 +62,7 @@ pub trait PhysicalPlan: Debug + Send + Sync {
fn execute( fn execute(
&self, &self,
partition: usize, partition: usize,
runtime: Arc<RuntimeEnv>, context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream>; ) -> Result<SendableRecordBatchStream>;
} }
@@ -111,6 +110,7 @@ impl PhysicalPlan for PhysicalPlanAdapter {
.collect(); .collect();
let plan = self let plan = self
.df_plan .df_plan
.clone()
.with_new_children(children) .with_new_children(children)
.context(error::GeneralDataFusionSnafu)?; .context(error::GeneralDataFusionSnafu)?;
Ok(Arc::new(PhysicalPlanAdapter::new(self.schema(), plan))) Ok(Arc::new(PhysicalPlanAdapter::new(self.schema(), plan)))
@@ -119,20 +119,22 @@ impl PhysicalPlan for PhysicalPlanAdapter {
fn execute( fn execute(
&self, &self,
partition: usize, partition: usize,
runtime: Arc<RuntimeEnv>, context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
let df_plan = self.df_plan.clone(); let df_plan = self.df_plan.clone();
let stream = Box::pin(async move { df_plan.execute(partition, runtime).await }); let stream = df_plan
let stream = AsyncRecordBatchStreamAdapter::new(self.schema(), stream); .execute(partition, context)
.context(error::GeneralDataFusionSnafu)?;
let adapter = RecordBatchStreamAdapter::try_new(stream)
.context(error::ConvertDfRecordBatchStreamSnafu)?;
Ok(Box::pin(stream)) Ok(Box::pin(adapter))
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub struct DfPhysicalPlanAdapter(pub PhysicalPlanRef); pub struct DfPhysicalPlanAdapter(pub PhysicalPlanRef);
#[async_trait]
impl DfPhysicalPlan for DfPhysicalPlanAdapter { impl DfPhysicalPlan for DfPhysicalPlanAdapter {
fn as_any(&self) -> &dyn Any { fn as_any(&self) -> &dyn Any {
self self
@@ -159,15 +161,14 @@ impl DfPhysicalPlan for DfPhysicalPlanAdapter {
} }
fn with_new_children( fn with_new_children(
&self, self: Arc<Self>,
children: Vec<Arc<dyn DfPhysicalPlan>>, children: Vec<Arc<dyn DfPhysicalPlan>>,
) -> DfResult<Arc<dyn DfPhysicalPlan>> { ) -> DfResult<Arc<dyn DfPhysicalPlan>> {
let df_schema = self.schema(); let df_schema = self.schema();
let schema: SchemaRef = Arc::new( let schema: SchemaRef = Arc::new(
df_schema df_schema
.try_into() .try_into()
.context(error::ConvertArrowSchemaSnafu) .context(error::ConvertArrowSchemaSnafu)?,
.map_err(error::Error::from)?,
); );
let children = children let children = children
.into_iter() .into_iter()
@@ -177,12 +178,12 @@ impl DfPhysicalPlan for DfPhysicalPlanAdapter {
Ok(Arc::new(DfPhysicalPlanAdapter(plan))) Ok(Arc::new(DfPhysicalPlanAdapter(plan)))
} }
async fn execute( fn execute(
&self, &self,
partition: usize, partition: usize,
runtime: Arc<RuntimeEnv>, context: Arc<TaskContext>,
) -> DfResult<DfSendableRecordBatchStream> { ) -> DfResult<DfSendableRecordBatchStream> {
let stream = self.0.execute(partition, runtime)?; let stream = self.0.execute(partition, context)?;
Ok(Box::pin(DfRecordBatchStreamAdapter::new(stream))) Ok(Box::pin(DfRecordBatchStreamAdapter::new(stream)))
} }
@@ -194,16 +195,16 @@ impl DfPhysicalPlan for DfPhysicalPlanAdapter {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use async_trait::async_trait;
use common_recordbatch::{RecordBatch, RecordBatches}; use common_recordbatch::{RecordBatch, RecordBatches};
use datafusion::arrow_print; use datafusion::datasource::{DefaultTableSource, TableProvider as DfTableProvider, TableType};
use datafusion::datasource::TableProvider as DfTableProvider; use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::logical_plan::LogicalPlanBuilder;
use datafusion::physical_plan::collect; use datafusion::physical_plan::collect;
use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::empty::EmptyExec;
use datafusion::prelude::ExecutionContext; use datafusion_expr::logical_plan::builder::LogicalPlanBuilder;
use datafusion_common::field_util::SchemaExt; use datafusion_expr::{Expr, TableSource};
use datafusion_expr::Expr;
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::arrow::util::pretty;
use datatypes::schema::Schema; use datatypes::schema::Schema;
use datatypes::vectors::Int32Vector; use datatypes::vectors::Int32Vector;
@@ -225,8 +226,13 @@ mod test {
)])) )]))
} }
fn table_type(&self) -> TableType {
TableType::Base
}
async fn scan( async fn scan(
&self, &self,
_ctx: &SessionState,
_projection: &Option<Vec<usize>>, _projection: &Option<Vec<usize>>,
_filters: &[Expr], _filters: &[Expr],
_limit: Option<usize>, _limit: Option<usize>,
@@ -240,6 +246,14 @@ mod test {
} }
} }
impl MyDfTableProvider {
fn table_source() -> Arc<dyn TableSource> {
Arc::new(DefaultTableSource {
table_provider: Arc::new(Self),
})
}
}
#[derive(Debug)] #[derive(Debug)]
struct MyExecutionPlan { struct MyExecutionPlan {
schema: SchemaRef, schema: SchemaRef,
@@ -269,7 +283,7 @@ mod test {
fn execute( fn execute(
&self, &self,
_partition: usize, _partition: usize,
_runtime: Arc<RuntimeEnv>, _context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> { ) -> Result<SendableRecordBatchStream> {
let schema = self.schema(); let schema = self.schema();
let recordbatches = RecordBatches::try_new( let recordbatches = RecordBatches::try_new(
@@ -295,20 +309,26 @@ mod test {
// Test our physical plan can be executed by DataFusion, through adapters. // Test our physical plan can be executed by DataFusion, through adapters.
#[tokio::test] #[tokio::test]
async fn test_execute_physical_plan() { async fn test_execute_physical_plan() {
let ctx = ExecutionContext::new(); let ctx = SessionContext::new();
let logical_plan = LogicalPlanBuilder::scan("test", Arc::new(MyDfTableProvider), None) let logical_plan =
.unwrap() LogicalPlanBuilder::scan("test", MyDfTableProvider::table_source(), None)
.build() .unwrap()
.unwrap(); .build()
.unwrap();
let physical_plan = ctx.create_physical_plan(&logical_plan).await.unwrap(); let physical_plan = ctx.create_physical_plan(&logical_plan).await.unwrap();
let df_recordbatches = collect(physical_plan, Arc::new(RuntimeEnv::default())) let df_recordbatches = collect(physical_plan, Arc::new(TaskContext::from(&ctx)))
.await .await
.unwrap(); .unwrap();
let pretty_print = arrow_print::write(&df_recordbatches); let pretty_print = pretty::pretty_format_batches(&df_recordbatches).unwrap();
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!( assert_eq!(
pretty_print, pretty_print.to_string(),
vec!["+---+", "| a |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+",] r#"+---+
| a |
+---+
| 1 |
| 2 |
| 3 |
+---+"#
); );
} }
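For orientation, a sketch of the new calling convention from the consumer side; the driver is hypothetical and `plan` stands for any PhysicalPlanRef built elsewhere. DataFusion 14 passes a TaskContext derived from the session, and execute returns its stream synchronously instead of through the old async adapter.

use std::sync::Arc;

use common_recordbatch::{util, RecordBatch};
use datafusion::execution::context::{SessionContext, TaskContext};

async fn run_partition(plan: PhysicalPlanRef) -> Vec<RecordBatch> {
    let ctx = SessionContext::new();
    let task_ctx = Arc::new(TaskContext::from(&ctx));
    // Partition 0 only; errors are unwrapped for brevity in this sketch.
    let stream = plan.execute(0, task_ctx).unwrap();
    util::collect(stream).await.unwrap()
}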


@@ -15,7 +15,7 @@
//! Signature module contains foundational types that are used to represent signatures, types, //! Signature module contains foundational types that are used to represent signatures, types,
//! and return types of functions. //! and return types of functions.
//! Copied and modified from datafusion. //! Copied and modified from datafusion.
pub use datafusion::physical_plan::functions::Volatility; pub use datafusion_expr::Volatility;
use datafusion_expr::{Signature as DfSignature, TypeSignature as DfTypeSignature}; use datafusion_expr::{Signature as DfSignature, TypeSignature as DfTypeSignature};
use datatypes::arrow::datatypes::DataType as ArrowDataType; use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::data_type::DataType; use datatypes::data_type::DataType;


@@ -6,10 +6,8 @@ license = "Apache-2.0"
[dependencies] [dependencies]
common-error = { path = "../error" } common-error = { path = "../error" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ datafusion = "14.0.0"
"simd", datafusion-common = "14.0.0"
] }
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
datatypes = { path = "../../datatypes" } datatypes = { path = "../../datatypes" }
futures = "0.3" futures = "0.3"
paste = "1.0" paste = "1.0"


@@ -19,7 +19,6 @@ use std::task::{Context, Poll};
use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef; use datafusion::arrow::datatypes::SchemaRef as DfSchemaRef;
use datafusion::physical_plan::RecordBatchStream as DfRecordBatchStream; use datafusion::physical_plan::RecordBatchStream as DfRecordBatchStream;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datafusion_common::DataFusionError; use datafusion_common::DataFusionError;
use datatypes::arrow::error::{ArrowError, Result as ArrowResult}; use datatypes::arrow::error::{ArrowError, Result as ArrowResult};
use datatypes::schema::{Schema, SchemaRef}; use datatypes::schema::{Schema, SchemaRef};
@@ -28,7 +27,8 @@ use snafu::ResultExt;
use crate::error::{self, Result}; use crate::error::{self, Result};
use crate::{ use crate::{
DfSendableRecordBatchStream, RecordBatch, RecordBatchStream, SendableRecordBatchStream, Stream, DfRecordBatch, DfSendableRecordBatchStream, RecordBatch, RecordBatchStream,
SendableRecordBatchStream, Stream,
}; };
type FutureStream = Pin< type FutureStream = Pin<
@@ -63,8 +63,8 @@ impl Stream for DfRecordBatchStreamAdapter {
match Pin::new(&mut self.stream).poll_next(cx) { match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(Some(recordbatch)) => match recordbatch { Poll::Ready(Some(recordbatch)) => match recordbatch {
Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.df_recordbatch))), Ok(recordbatch) => Poll::Ready(Some(Ok(recordbatch.into_df_record_batch()))),
Err(e) => Poll::Ready(Some(Err(ArrowError::External("".to_owned(), Box::new(e))))), Err(e) => Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new(e))))),
}, },
Poll::Ready(None) => Poll::Ready(None), Poll::Ready(None) => Poll::Ready(None),
} }
@@ -102,10 +102,13 @@ impl Stream for RecordBatchStreamAdapter {
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.stream).poll_next(cx) { match Pin::new(&mut self.stream).poll_next(cx) {
Poll::Pending => Poll::Pending, Poll::Pending => Poll::Pending,
Poll::Ready(Some(df_recordbatch)) => Poll::Ready(Some(Ok(RecordBatch { Poll::Ready(Some(df_record_batch)) => {
schema: self.schema(), let df_record_batch = df_record_batch.context(error::PollStreamSnafu)?;
df_recordbatch: df_recordbatch.context(error::PollStreamSnafu)?, Poll::Ready(Some(RecordBatch::try_from_df_record_batch(
}))), self.schema(),
df_record_batch,
)))
}
Poll::Ready(None) => Poll::Ready(None), Poll::Ready(None) => Poll::Ready(None),
} }
} }
@@ -157,10 +160,8 @@ impl Stream for AsyncRecordBatchStreamAdapter {
AsyncRecordBatchStreamAdapterState::Inited(stream) => match stream { AsyncRecordBatchStreamAdapterState::Inited(stream) => match stream {
Ok(stream) => { Ok(stream) => {
return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)).map(|df| { return Poll::Ready(ready!(Pin::new(stream).poll_next(cx)).map(|df| {
Ok(RecordBatch { let df_record_batch = df.context(error::PollStreamSnafu)?;
schema: self.schema(), RecordBatch::try_from_df_record_batch(self.schema(), df_record_batch)
df_recordbatch: df.context(error::PollStreamSnafu)?,
})
})); }));
} }
Err(e) => { Err(e) => {
@@ -168,8 +169,7 @@ impl Stream for AsyncRecordBatchStreamAdapter {
error::CreateRecordBatchesSnafu { error::CreateRecordBatchesSnafu {
reason: format!("Read error {:?} from stream", e), reason: format!("Read error {:?} from stream", e),
} }
.fail() .fail(),
.map_err(|e| e.into()),
)) ))
} }
}, },
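A compact sketch of how the adapter above is meant to be used (hypothetical glue code; the try_new call mirrors its use in the physical-plan hunk earlier): it turns a DataFusion stream into the crate's own stream of Result<RecordBatch> items.

fn wrap_df_stream(
    df_stream: DfSendableRecordBatchStream,
) -> Result<SendableRecordBatchStream> {
    // try_new is fallible; per-batch conversion happens lazily in poll_next, as shown above.
    let adapter = RecordBatchStreamAdapter::try_new(df_stream)?;
    Ok(Box::pin(adapter))
}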


@@ -17,13 +17,12 @@ use std::any::Any;
use common_error::ext::BoxedError; use common_error::ext::BoxedError;
use common_error::prelude::*; use common_error::prelude::*;
common_error::define_opaque_error!(Error);
pub type Result<T> = std::result::Result<T, Error>; pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, Snafu)] #[derive(Debug, Snafu)]
#[snafu(visibility(pub))] #[snafu(visibility(pub))]
pub enum InnerError { pub enum Error {
#[snafu(display("Fail to create datafusion record batch, source: {}", source))] #[snafu(display("Fail to create datafusion record batch, source: {}", source))]
NewDfRecordBatch { NewDfRecordBatch {
source: datatypes::arrow::error::ArrowError, source: datatypes::arrow::error::ArrowError,
@@ -59,20 +58,27 @@ pub enum InnerError {
source: datatypes::arrow::error::ArrowError, source: datatypes::arrow::error::ArrowError,
backtrace: Backtrace, backtrace: Backtrace,
}, },
#[snafu(display("Fail to format record batch, source: {}", source))]
Format {
source: datatypes::arrow::error::ArrowError,
backtrace: Backtrace,
},
} }
impl ErrorExt for InnerError { impl ErrorExt for Error {
fn status_code(&self) -> StatusCode { fn status_code(&self) -> StatusCode {
match self { match self {
InnerError::NewDfRecordBatch { .. } => StatusCode::InvalidArguments, Error::NewDfRecordBatch { .. } => StatusCode::InvalidArguments,
InnerError::DataTypes { .. } Error::DataTypes { .. }
| InnerError::CreateRecordBatches { .. } | Error::CreateRecordBatches { .. }
| InnerError::PollStream { .. } => StatusCode::Internal, | Error::PollStream { .. }
| Error::Format { .. } => StatusCode::Internal,
InnerError::External { source } => source.status_code(), Error::External { source } => source.status_code(),
InnerError::SchemaConversion { source, .. } => source.status_code(), Error::SchemaConversion { source, .. } => source.status_code(),
} }
} }
@@ -84,9 +90,3 @@ impl ErrorExt for InnerError {
self self
} }
} }
impl From<InnerError> for Error {
fn from(e: InnerError) -> Error {
Error::new(e)
}
}


@@ -20,16 +20,17 @@ pub mod util;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use datafusion::arrow_print;
use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::memory::MemoryStream;
pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream; pub use datafusion::physical_plan::SendableRecordBatchStream as DfSendableRecordBatchStream;
pub use datatypes::arrow::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::util::pretty;
use datatypes::prelude::VectorRef; use datatypes::prelude::VectorRef;
use datatypes::schema::{Schema, SchemaRef}; use datatypes::schema::{Schema, SchemaRef};
use error::Result; use error::Result;
use futures::task::{Context, Poll}; use futures::task::{Context, Poll};
use futures::Stream; use futures::{Stream, TryStreamExt};
pub use recordbatch::RecordBatch; pub use recordbatch::RecordBatch;
use snafu::ensure; use snafu::{ensure, ResultExt};
pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> { pub trait RecordBatchStream: Stream<Item = Result<RecordBatch>> {
fn schema(&self) -> SchemaRef; fn schema(&self) -> SchemaRef;
@@ -65,7 +66,7 @@ impl Stream for EmptyRecordBatchStream {
} }
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct RecordBatches { pub struct RecordBatches {
schema: SchemaRef, schema: SchemaRef,
batches: Vec<RecordBatch>, batches: Vec<RecordBatch>,
@@ -80,6 +81,12 @@ impl RecordBatches {
Ok(Self { schema, batches }) Ok(Self { schema, batches })
} }
pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
let schema = stream.schema();
let batches = stream.try_collect::<Vec<_>>().await?;
Ok(Self { schema, batches })
}
#[inline] #[inline]
pub fn empty() -> Self { pub fn empty() -> Self {
Self { Self {
@@ -92,17 +99,18 @@ impl RecordBatches {
self.batches.iter() self.batches.iter()
} }
pub fn pretty_print(&self) -> String { pub fn pretty_print(&self) -> Result<String> {
arrow_print::write( let df_batches = &self
&self .iter()
.iter() .map(|x| x.df_record_batch().clone())
.map(|x| x.df_recordbatch.clone()) .collect::<Vec<_>>();
.collect::<Vec<_>>(), let result = pretty::pretty_format_batches(df_batches).context(error::FormatSnafu)?;
)
Ok(result.to_string())
} }
pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> { pub fn try_new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Result<Self> {
for batch in batches.iter() { for batch in &batches {
ensure!( ensure!(
batch.schema == schema, batch.schema == schema,
error::CreateRecordBatchesSnafu { error::CreateRecordBatchesSnafu {
@@ -138,7 +146,7 @@ impl RecordBatches {
let df_record_batches = self let df_record_batches = self
.batches .batches
.into_iter() .into_iter()
.map(|batch| batch.df_recordbatch) .map(|batch| batch.into_df_record_batch())
.collect(); .collect();
// unwrap safety: `MemoryStream::try_new` won't fail // unwrap safety: `MemoryStream::try_new` won't fail
Box::pin( Box::pin(
@@ -236,7 +244,7 @@ mod tests {
| 1 | hello | | 1 | hello |
| 2 | world | | 2 | world |
+---+-------+"; +---+-------+";
assert_eq!(batches.pretty_print(), expected); assert_eq!(batches.pretty_print().unwrap(), expected);
assert_eq!(schema1, batches.schema()); assert_eq!(schema1, batches.schema());
assert_eq!(vec![batch1], batches.take()); assert_eq!(vec![batch1], batches.take());
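A short usage sketch tying together the two additions above, try_collect and the now-fallible pretty_print; the dump function itself is hypothetical:

async fn dump(stream: SendableRecordBatchStream) -> Result<String> {
    let batches = RecordBatches::try_collect(stream).await?;
    batches.pretty_print()
}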


@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow_array::arrow_array_get;
use datatypes::schema::SchemaRef; use datatypes::schema::SchemaRef;
use datatypes::value::Value; use datatypes::value::Value;
use datatypes::vectors::{Helper, VectorRef}; use datatypes::vectors::{Helper, VectorRef};
@@ -22,31 +20,88 @@ use serde::{Serialize, Serializer};
use snafu::ResultExt; use snafu::ResultExt;
use crate::error::{self, Result}; use crate::error::{self, Result};
use crate::DfRecordBatch;
/// A two-dimensional batch of column-oriented data with a defined schema.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct RecordBatch { pub struct RecordBatch {
pub schema: SchemaRef, pub schema: SchemaRef,
pub df_recordbatch: DfRecordBatch, columns: Vec<VectorRef>,
df_record_batch: DfRecordBatch,
} }
impl RecordBatch { impl RecordBatch {
/// Create a new [`RecordBatch`] from `schema` and `columns`.
pub fn new<I: IntoIterator<Item = VectorRef>>( pub fn new<I: IntoIterator<Item = VectorRef>>(
schema: SchemaRef, schema: SchemaRef,
columns: I, columns: I,
) -> Result<RecordBatch> { ) -> Result<RecordBatch> {
let arrow_arrays = columns.into_iter().map(|v| v.to_arrow_array()).collect(); let columns: Vec<_> = columns.into_iter().collect();
let arrow_arrays = columns.iter().map(|v| v.to_arrow_array()).collect();
let df_recordbatch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays) let df_record_batch = DfRecordBatch::try_new(schema.arrow_schema().clone(), arrow_arrays)
.context(error::NewDfRecordBatchSnafu)?; .context(error::NewDfRecordBatchSnafu)?;
Ok(RecordBatch { Ok(RecordBatch {
schema, schema,
df_recordbatch, columns,
df_record_batch,
}) })
} }
/// Create a new [`RecordBatch`] from `schema` and `df_record_batch`.
///
/// This method doesn't check the schema.
pub fn try_from_df_record_batch(
schema: SchemaRef,
df_record_batch: DfRecordBatch,
) -> Result<RecordBatch> {
let columns = df_record_batch
.columns()
.iter()
.map(|c| Helper::try_into_vector(c.clone()).context(error::DataTypesSnafu))
.collect::<Result<Vec<_>>>()?;
Ok(RecordBatch {
schema,
columns,
df_record_batch,
})
}
#[inline]
pub fn df_record_batch(&self) -> &DfRecordBatch {
&self.df_record_batch
}
#[inline]
pub fn into_df_record_batch(self) -> DfRecordBatch {
self.df_record_batch
}
#[inline]
pub fn columns(&self) -> &[VectorRef] {
&self.columns
}
#[inline]
pub fn column(&self, idx: usize) -> &VectorRef {
&self.columns[idx]
}
pub fn column_by_name(&self, name: &str) -> Option<&VectorRef> {
let idx = self.schema.column_index_by_name(name)?;
Some(&self.columns[idx])
}
#[inline]
pub fn num_columns(&self) -> usize {
self.columns.len()
}
#[inline]
pub fn num_rows(&self) -> usize { pub fn num_rows(&self) -> usize {
self.df_recordbatch.num_rows() self.df_record_batch.num_rows()
} }
/// Create an iterator to traverse the data by row /// Create an iterator to traverse the data by row
@@ -60,14 +115,15 @@ impl Serialize for RecordBatch {
where where
S: Serializer, S: Serializer,
{ {
// TODO(yingwen): arrow and arrow2's schemas have different fields, so
// it might be better to use our `RawSchema` as serialized field.
let mut s = serializer.serialize_struct("record", 2)?; let mut s = serializer.serialize_struct("record", 2)?;
s.serialize_field("schema", &self.schema.arrow_schema())?; s.serialize_field("schema", &**self.schema.arrow_schema())?;
let df_columns = self.df_recordbatch.columns(); let vec = self
.columns
let vec = df_columns
.iter() .iter()
.map(|c| Helper::try_into_vector(c.clone())?.serialize_to_json()) .map(|c| c.serialize_to_json())
.collect::<std::result::Result<Vec<_>, _>>() .collect::<std::result::Result<Vec<_>, _>>()
.map_err(S::Error::custom)?; .map_err(S::Error::custom)?;
@@ -87,8 +143,8 @@ impl<'a> RecordBatchRowIterator<'a> {
fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator { fn new(record_batch: &'a RecordBatch) -> RecordBatchRowIterator {
RecordBatchRowIterator { RecordBatchRowIterator {
record_batch, record_batch,
rows: record_batch.df_recordbatch.num_rows(), rows: record_batch.df_record_batch.num_rows(),
columns: record_batch.df_recordbatch.num_columns(), columns: record_batch.df_record_batch.num_columns(),
row_cursor: 0, row_cursor: 0,
} }
} }
@@ -104,13 +160,8 @@ impl<'a> Iterator for RecordBatchRowIterator<'a> {
let mut row = Vec::with_capacity(self.columns); let mut row = Vec::with_capacity(self.columns);
for col in 0..self.columns { for col in 0..self.columns {
let column_array = self.record_batch.df_recordbatch.column(col); let column = self.record_batch.column(col);
match arrow_array_get(column_array.as_ref(), self.row_cursor) row.push(column.get(self.row_cursor));
.context(error::DataTypesSnafu)
{
Ok(field) => row.push(field),
Err(e) => return Some(Err(e.into())),
}
} }
self.row_cursor += 1; self.row_cursor += 1;
@@ -123,63 +174,60 @@ impl<'a> Iterator for RecordBatchRowIterator<'a> {
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use datafusion_common::field_util::SchemaExt;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::array::UInt32Array;
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::prelude::*; use datatypes::data_type::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema}; use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{StringVector, UInt32Vector, Vector}; use datatypes::vectors::{StringVector, UInt32Vector};
use super::*; use super::*;
#[test] #[test]
fn test_new_record_batch() { fn test_record_batch() {
let arrow_schema = Arc::new(ArrowSchema::new(vec![ let arrow_schema = Arc::new(ArrowSchema::new(vec![
Field::new("c1", DataType::UInt32, false), Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false), Field::new("c2", DataType::UInt32, false),
])); ]));
let schema = Arc::new(Schema::try_from(arrow_schema).unwrap()); let schema = Arc::new(Schema::try_from(arrow_schema).unwrap());
let v = Arc::new(UInt32Vector::from_slice(&[1, 2, 3])); let c1 = Arc::new(UInt32Vector::from_slice(&[1, 2, 3]));
let columns: Vec<VectorRef> = vec![v.clone(), v.clone()]; let c2 = Arc::new(UInt32Vector::from_slice(&[4, 5, 6]));
let columns: Vec<VectorRef> = vec![c1, c2];
let batch = RecordBatch::new(schema.clone(), columns).unwrap(); let batch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
let expect = v.to_arrow_array(); assert_eq!(3, batch.num_rows());
for column in batch.df_recordbatch.columns() { assert_eq!(&columns, batch.columns());
let array = column.as_any().downcast_ref::<UInt32Array>().unwrap(); for (i, expect) in columns.iter().enumerate().take(batch.num_columns()) {
assert_eq!( let column = batch.column(i);
expect.as_any().downcast_ref::<UInt32Array>().unwrap(), assert_eq!(expect, column);
array
);
} }
assert_eq!(schema, batch.schema); assert_eq!(schema, batch.schema);
assert_eq!(columns[0], *batch.column_by_name("c1").unwrap());
assert_eq!(columns[1], *batch.column_by_name("c2").unwrap());
assert!(batch.column_by_name("c3").is_none());
let converted =
RecordBatch::try_from_df_record_batch(schema, batch.df_record_batch().clone()).unwrap();
assert_eq!(batch, converted);
assert_eq!(*batch.df_record_batch(), converted.into_df_record_batch());
} }
#[test] #[test]
pub fn test_serialize_recordbatch() { pub fn test_serialize_recordbatch() {
let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new( let column_schemas = vec![ColumnSchema::new(
"number", "number",
DataType::UInt32, ConcreteDataType::uint32_datatype(),
false, false,
)])); )];
let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap()); let schema = Arc::new(Schema::try_new(column_schemas).unwrap());
let numbers: Vec<u32> = (0..10).collect(); let numbers: Vec<u32> = (0..10).collect();
let df_batch = DfRecordBatch::try_new( let columns = vec![Arc::new(UInt32Vector::from_slice(&numbers)) as VectorRef];
arrow_schema, let batch = RecordBatch::new(schema, columns).unwrap();
vec![Arc::new(UInt32Array::from_slice(&numbers))],
)
.unwrap();
let batch = RecordBatch {
schema,
df_recordbatch: df_batch,
};
let output = serde_json::to_string(&batch).unwrap(); let output = serde_json::to_string(&batch).unwrap();
assert_eq!( assert_eq!(
r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","is_nullable":false,"metadata":{}}],"metadata":{}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#, r#"{"schema":{"fields":[{"name":"number","data_type":"UInt32","nullable":false,"dict_id":0,"dict_is_ordered":false}],"metadata":{"greptime:version":"0"}},"columns":[[0,1,2,3,4,5,6,7,8,9]]}"#,
output output
); );
} }
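Pulling the new surface together, a minimal sketch that only uses constructors and accessors shown in this hunk:

#[test]
fn record_batch_surface_sketch() {
    let column_schemas = vec![ColumnSchema::new(
        "n",
        ConcreteDataType::uint32_datatype(),
        false,
    )];
    let schema = Arc::new(Schema::try_new(column_schemas).unwrap());
    let columns = vec![Arc::new(UInt32Vector::from_slice(&[1, 2, 3])) as VectorRef];
    let batch = RecordBatch::new(schema, columns).unwrap();

    assert_eq!(3, batch.num_rows());
    assert_eq!(1, batch.num_columns());
    assert_eq!(batch.column(0), batch.column_by_name("n").unwrap());
}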


@@ -15,23 +15,29 @@
use futures::TryStreamExt;

use crate::error::Result;
-use crate::{RecordBatch, SendableRecordBatchStream};
+use crate::{RecordBatch, RecordBatches, SendableRecordBatchStream};

+/// Collect all the items from the stream into a vector of [`RecordBatch`].
pub async fn collect(stream: SendableRecordBatchStream) -> Result<Vec<RecordBatch>> {
    stream.try_collect::<Vec<_>>().await
}

+/// Collect all the items from the stream into [RecordBatches].
+pub async fn collect_batches(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
+    let schema = stream.schema();
+    let batches = stream.try_collect::<Vec<_>>().await?;
+    RecordBatches::try_new(schema, batches)
+}

#[cfg(test)]
mod tests {
    use std::mem;
    use std::pin::Pin;
    use std::sync::Arc;

-    use datafusion_common::field_util::SchemaExt;
-    use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
-    use datatypes::arrow::array::UInt32Array;
-    use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
-    use datatypes::schema::{Schema, SchemaRef};
+    use datatypes::prelude::*;
+    use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
+    use datatypes::vectors::UInt32Vector;
    use futures::task::{Context, Poll};
    use futures::Stream;

@@ -65,12 +71,13 @@ mod tests {
    #[tokio::test]
    async fn test_collect() {
-        let arrow_schema = Arc::new(ArrowSchema::new(vec![Field::new(
+        let column_schemas = vec![ColumnSchema::new(
            "number",
-            DataType::UInt32,
+            ConcreteDataType::uint32_datatype(),
            false,
-        )]));
-        let schema = Arc::new(Schema::try_from(arrow_schema.clone()).unwrap());
+        )];
+        let schema = Arc::new(Schema::try_new(column_schemas).unwrap());

        let stream = MockRecordBatchStream {
            schema: schema.clone(),

@@ -81,24 +88,23 @@ mod tests {
        assert_eq!(0, batches.len());

        let numbers: Vec<u32> = (0..10).collect();
-        let df_batch = DfRecordBatch::try_new(
-            arrow_schema.clone(),
-            vec![Arc::new(UInt32Array::from_slice(&numbers))],
-        )
-        .unwrap();
-        let batch = RecordBatch {
-            schema: schema.clone(),
-            df_recordbatch: df_batch,
-        };
+        let columns = [Arc::new(UInt32Vector::from_vec(numbers)) as _];
+        let batch = RecordBatch::new(schema.clone(), columns).unwrap();

        let stream = MockRecordBatchStream {
-            schema: Arc::new(Schema::try_from(arrow_schema).unwrap()),
+            schema: schema.clone(),
            batch: Some(batch.clone()),
        };
        let batches = collect(Box::pin(stream)).await.unwrap();
        assert_eq!(1, batches.len());
        assert_eq!(batch, batches[0]);
+
+        let stream = MockRecordBatchStream {
+            schema: schema.clone(),
+            batch: Some(batch.clone()),
+        };
+        let batches = collect_batches(Box::pin(stream)).await.unwrap();
+        let expect_batches = RecordBatches::try_new(schema.clone(), vec![batch]).unwrap();
+        assert_eq!(expect_batches, batches);
    }
}
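A hedged usage sketch of the new helper (the crate and module paths `common_recordbatch::util` are assumed from this repository's layout): `collect_batches` keeps the stream's schema even when the stream yields no batch, which plain `collect` cannot, since it only returns a `Vec<RecordBatch>`.

use common_recordbatch::error::Result;
use common_recordbatch::util::collect_batches;
use common_recordbatch::{RecordBatches, SendableRecordBatchStream};

// Drain a stream into RecordBatches; an empty stream still produces an empty
// RecordBatches that remembers the schema taken from the stream itself.
async fn drain(stream: SendableRecordBatchStream) -> Result<RecordBatches> {
    collect_batches(stream).await
}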

View File

@@ -10,10 +10,8 @@ catalog = { path = "../../catalog" }
common-catalog = { path = "../catalog" }
common-error = { path = "../error" }
common-telemetry = { path = "../telemetry" }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
-    "simd",
-] }
-datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
+datafusion = "14.0.0"
+datafusion-expr = "14.0.0"
datatypes = { path = "../../datatypes" }
futures = "0.3"
prost = "0.9"

View File

@@ -14,6 +14,7 @@
use std::collections::HashMap;

+use datafusion::common::DFSchemaRef;
use substrait_proto::protobuf::extensions::simple_extension_declaration::{
    ExtensionFunction, MappingType,
};

@@ -23,6 +24,7 @@ use substrait_proto::protobuf::extensions::SimpleExtensionDeclaration;
pub struct ConvertorContext {
    scalar_fn_names: HashMap<String, u32>,
    scalar_fn_map: HashMap<u32, String>,
+    df_schema: Option<DFSchemaRef>,
}

impl ConvertorContext {

@@ -63,4 +65,13 @@ impl ConvertorContext {
        }
        result
    }
+
+    pub(crate) fn set_df_schema(&mut self, schema: DFSchemaRef) {
+        debug_assert!(self.df_schema.is_none());
+        self.df_schema.get_or_insert(schema);
+    }
+
+    pub(crate) fn df_schema(&self) -> Option<&DFSchemaRef> {
+        self.df_schema.as_ref()
+    }
}
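`set_df_schema` leans on the write-once behaviour of `Option::get_or_insert`, with a `debug_assert!` so that double initialization is caught in debug builds and ignored in release. A std-only sketch of the same pattern (names are illustrative, not from the codebase):

#[derive(Default)]
struct Slot {
    value: Option<String>,
}

impl Slot {
    fn set(&mut self, v: String) {
        debug_assert!(self.value.is_none());
        // Only writes when the slot is still empty; a later call keeps the first value.
        self.value.get_or_insert(v);
    }
}

fn main() {
    let mut slot = Slot::default();
    slot.set("projected schema".to_string());
    // A second `set` would trip the debug_assert in debug builds and be a
    // no-op in release builds, so the first value always wins.
    assert_eq!(slot.value.as_deref(), Some("projected schema"));
}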

View File

@@ -15,8 +15,8 @@
use std::collections::VecDeque;
use std::str::FromStr;

-use datafusion::logical_plan::{Column, Expr};
-use datafusion_expr::{expr_fn, BuiltinScalarFunction, Operator};
+use datafusion::common::Column;
+use datafusion_expr::{expr_fn, lit, Between, BinaryExpr, BuiltinScalarFunction, Expr, Operator};
use datatypes::schema::Schema;
use snafu::{ensure, OptionExt};
use substrait_proto::protobuf::expression::field_reference::ReferenceType as FieldReferenceType;

@@ -24,7 +24,7 @@ use substrait_proto::protobuf::expression::reference_segment::{
    ReferenceType as SegReferenceType, StructField,
};
use substrait_proto::protobuf::expression::{
-    FieldReference, ReferenceSegment, RexType, ScalarFunction,
+    FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction,
};
use substrait_proto::protobuf::function_argument::ArgType;
use substrait_proto::protobuf::Expression;

@@ -33,15 +33,24 @@ use crate::context::ConvertorContext;
use crate::error::{
    EmptyExprSnafu, InvalidParametersSnafu, MissingFieldSnafu, Result, UnsupportedExprSnafu,
};
+use crate::types::{literal_type_to_scalar_value, scalar_value_as_literal_type};

/// Convert substrait's `Expression` to DataFusion's `Expr`.
-pub fn to_df_expr(ctx: &ConvertorContext, expression: Expression, schema: &Schema) -> Result<Expr> {
+pub(crate) fn to_df_expr(
+    ctx: &ConvertorContext,
+    expression: Expression,
+    schema: &Schema,
+) -> Result<Expr> {
    let expr_rex_type = expression.rex_type.context(EmptyExprSnafu)?;
    match expr_rex_type {
-        RexType::Literal(_) => UnsupportedExprSnafu {
-            name: "substrait Literal expression",
+        RexType::Literal(l) => {
+            let t = l.literal_type.context(MissingFieldSnafu {
+                field: "LiteralType",
+                plan: "Literal",
+            })?;
+            let v = literal_type_to_scalar_value(t)?;
+            Ok(lit(v))
        }
-        .fail()?,
        RexType::Selection(selection) => convert_selection_rex(*selection, schema),
        RexType::ScalarFunction(scalar_fn) => convert_scalar_function(ctx, scalar_fn, schema),
        RexType::WindowFunction(_)
@@ -302,21 +311,21 @@ pub fn convert_scalar_function(
        // skip GetIndexedField, unimplemented.
        "between" => {
            ensure_arg_len(3)?;
-            Expr::Between {
+            Expr::Between(Between {
                expr: Box::new(inputs.pop_front().unwrap()),
                negated: false,
                low: Box::new(inputs.pop_front().unwrap()),
                high: Box::new(inputs.pop_front().unwrap()),
-            }
+            })
        }
        "not_between" => {
            ensure_arg_len(3)?;
-            Expr::Between {
+            Expr::Between(Between {
                expr: Box::new(inputs.pop_front().unwrap()),
                negated: true,
                low: Box::new(inputs.pop_front().unwrap()),
                high: Box::new(inputs.pop_front().unwrap()),
-            }
+            })
        }
        // skip Case, is covered in substrait::SwitchExpression.
        // skip Cast and TryCast, is covered in substrait::Cast.
@@ -453,11 +462,22 @@ pub fn expression_from_df_expr(
            }
        }
        // Don't merge them with other unsupported expr arms to preserve the ordering.
-        Expr::ScalarVariable(..) | Expr::Literal(..) => UnsupportedExprSnafu {
+        Expr::ScalarVariable(..) => UnsupportedExprSnafu {
            name: expr.to_string(),
        }
        .fail()?,
-        Expr::BinaryExpr { left, op, right } => {
+        Expr::Literal(v) => {
+            let t = scalar_value_as_literal_type(v)?;
+            let l = Literal {
+                nullable: true,
+                type_variation_reference: 0,
+                literal_type: Some(t),
+            };
+            Expression {
+                rex_type: Some(RexType::Literal(l)),
+            }
+        }
+        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
            let left = expression_from_df_expr(ctx, left, schema)?;
            let right = expression_from_df_expr(ctx, right, schema)?;
            let arguments = utils::expression_to_argument(vec![left, right]);

@@ -498,12 +518,12 @@ pub fn expression_from_df_expr(
            name: expr.to_string(),
        }
        .fail()?,
-        Expr::Between {
+        Expr::Between(Between {
            expr,
            negated,
            low,
            high,
-        } => {
+        }) => {
            let expr = expression_from_df_expr(ctx, expr, schema)?;
            let low = expression_from_df_expr(ctx, low, schema)?;
            let high = expression_from_df_expr(ctx, high, schema)?;

@@ -544,7 +564,21 @@ pub fn expression_from_df_expr(
        | Expr::WindowFunction { .. }
        | Expr::AggregateUDF { .. }
        | Expr::InList { .. }
-        | Expr::Wildcard => UnsupportedExprSnafu {
+        | Expr::Wildcard
+        | Expr::Like(_)
+        | Expr::ILike(_)
+        | Expr::SimilarTo(_)
+        | Expr::IsTrue(_)
+        | Expr::IsFalse(_)
+        | Expr::IsUnknown(_)
+        | Expr::IsNotTrue(_)
+        | Expr::IsNotFalse(_)
+        | Expr::IsNotUnknown(_)
+        | Expr::Exists { .. }
+        | Expr::InSubquery { .. }
+        | Expr::ScalarSubquery(..)
+        | Expr::QualifiedWildcard { .. } => todo!(),
+        Expr::GroupingSet(_) => UnsupportedExprSnafu {
            name: expr.to_string(),
        }
        .fail()?,
@@ -608,6 +642,10 @@ mod utils {
            Operator::RegexNotIMatch => "regex_not_i_match",
            Operator::BitwiseAnd => "bitwise_and",
            Operator::BitwiseOr => "bitwise_or",
+            Operator::BitwiseXor => "bitwise_xor",
+            Operator::BitwiseShiftRight => "bitwise_shift_right",
+            Operator::BitwiseShiftLeft => "bitwise_shift_left",
+            Operator::StringConcat => "string_concat",
        }
    }

@@ -659,7 +697,6 @@ mod utils {
            BuiltinScalarFunction::Sqrt => "sqrt",
            BuiltinScalarFunction::Tan => "tan",
            BuiltinScalarFunction::Trunc => "trunc",
-            BuiltinScalarFunction::Array => "make_array",
            BuiltinScalarFunction::Ascii => "ascii",
            BuiltinScalarFunction::BitLength => "bit_length",
            BuiltinScalarFunction::Btrim => "btrim",

@@ -703,6 +740,17 @@ mod utils {
            BuiltinScalarFunction::Trim => "trim",
            BuiltinScalarFunction::Upper => "upper",
            BuiltinScalarFunction::RegexpMatch => "regexp_match",
+            BuiltinScalarFunction::Atan2 => "atan2",
+            BuiltinScalarFunction::Coalesce => "coalesce",
+            BuiltinScalarFunction::Power => "power",
+            BuiltinScalarFunction::MakeArray => "make_array",
+            BuiltinScalarFunction::DateBin => "date_bin",
+            BuiltinScalarFunction::FromUnixtime => "from_unixtime",
+            BuiltinScalarFunction::CurrentDate => "current_date",
+            BuiltinScalarFunction::CurrentTime => "current_time",
+            BuiltinScalarFunction::Uuid => "uuid",
+            BuiltinScalarFunction::Struct => "struct",
+            BuiltinScalarFunction::ArrowTypeof => "arrow_type_of",
        }
    }
}
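Most of the churn in this file comes from datafusion-expr 14 turning several `Expr` variants into dedicated structs (`Between`, `BinaryExpr`, and friends). A hedged sketch of building `number BETWEEN 1 AND 10` against that API, assuming only a `datafusion-expr = "14.0.0"` dependency:

use datafusion_expr::{expr_fn, lit, Between, Expr};

fn main() {
    // The variant now wraps a struct instead of carrying named fields directly.
    let between = Expr::Between(Between {
        expr: Box::new(expr_fn::col("number")),
        negated: false,
        low: Box::new(lit(1i64)),
        high: Box::new(lit(10i64)),
    });
    println!("{:?}", between);
}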

View File

@@ -18,9 +18,11 @@ use bytes::{Buf, Bytes, BytesMut};
use catalog::CatalogManagerRef;
use common_error::prelude::BoxedError;
use common_telemetry::debug;
-use datafusion::datasource::TableProvider;
-use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema};
+use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef;
+use datafusion::common::ToDFSchema;
+use datafusion::datasource::DefaultTableSource;
use datafusion::physical_plan::project_schema;
+use datafusion_expr::{Filter, LogicalPlan, TableScan, TableSource};
use prost::Message;
use snafu::{ensure, OptionExt, ResultExt};
use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect};

@@ -29,31 +31,33 @@ use substrait_proto::protobuf::extensions::simple_extension_declaration::Mapping
use substrait_proto::protobuf::plan_rel::RelType as PlanRelType;
use substrait_proto::protobuf::read_rel::{NamedTable, ReadType};
use substrait_proto::protobuf::rel::RelType;
-use substrait_proto::protobuf::{Plan, PlanRel, ReadRel, Rel};
+use substrait_proto::protobuf::{FilterRel, Plan, PlanRel, ReadRel, Rel};
use table::table::adapter::DfTableProviderAdapter;

use crate::context::ConvertorContext;
use crate::df_expr::{expression_from_df_expr, to_df_expr};
use crate::error::{
-    DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu,
+    self, DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu,
    InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu,
    UnknownPlanSnafu, UnsupportedExprSnafu, UnsupportedPlanSnafu,
};
use crate::schema::{from_schema, to_schema};
use crate::SubstraitPlan;

-pub struct DFLogicalSubstraitConvertor {
-    catalog_manager: CatalogManagerRef,
-}
+pub struct DFLogicalSubstraitConvertor;

impl SubstraitPlan for DFLogicalSubstraitConvertor {
    type Error = Error;

    type Plan = LogicalPlan;

-    fn decode<B: Buf + Send>(&self, message: B) -> Result<Self::Plan, Self::Error> {
+    fn decode<B: Buf + Send>(
+        &self,
+        message: B,
+        catalog_manager: CatalogManagerRef,
+    ) -> Result<Self::Plan, Self::Error> {
        let plan = Plan::decode(message).context(DecodeRelSnafu)?;
-        self.convert_plan(plan)
+        self.convert_plan(plan, catalog_manager)
    }

    fn encode(&self, plan: Self::Plan) -> Result<Bytes, Self::Error> {
@@ -67,13 +71,11 @@ impl SubstraitPlan for DFLogicalSubstraitConvertor {
}

impl DFLogicalSubstraitConvertor {
-    pub fn new(catalog_manager: CatalogManagerRef) -> Self {
-        Self { catalog_manager }
-    }
-}
-
-impl DFLogicalSubstraitConvertor {
-    pub fn convert_plan(&self, mut plan: Plan) -> Result<LogicalPlan, Error> {
+    fn convert_plan(
+        &self,
+        mut plan: Plan,
+        catalog_manager: CatalogManagerRef,
+    ) -> Result<LogicalPlan, Error> {
        // prepare convertor context
        let mut ctx = ConvertorContext::default();
        for simple_ext in plan.extensions {

@@ -99,15 +101,51 @@ impl DFLogicalSubstraitConvertor {
            }
            .fail()?
        };
+        self.rel_to_logical_plan(&mut ctx, Box::new(rel), catalog_manager)
+    }
+
+    fn rel_to_logical_plan(
+        &self,
+        ctx: &mut ConvertorContext,
+        rel: Box<Rel>,
+        catalog_manager: CatalogManagerRef,
+    ) -> Result<LogicalPlan, Error> {
        let rel_type = rel.rel_type.context(EmptyPlanSnafu)?;
        // build logical plan
        let logical_plan = match rel_type {
-            RelType::Read(read_rel) => self.convert_read_rel(&mut ctx, read_rel),
-            RelType::Filter(_filter_rel) => UnsupportedPlanSnafu {
-                name: "Filter Relation",
+            RelType::Read(read_rel) => self.convert_read_rel(ctx, read_rel, catalog_manager)?,
+            RelType::Filter(filter) => {
+                let FilterRel {
+                    common: _,
+                    input,
+                    condition,
+                    advanced_extension: _,
+                } = *filter;
+                let input = input.context(MissingFieldSnafu {
+                    field: "input",
+                    plan: "Filter",
+                })?;
+                let input = Arc::new(self.rel_to_logical_plan(ctx, input, catalog_manager)?);
+                let condition = condition.context(MissingFieldSnafu {
+                    field: "condition",
+                    plan: "Filter",
+                })?;
+                let schema = ctx.df_schema().context(InvalidParametersSnafu {
+                    reason: "the underlying TableScan plan should have included a table schema",
+                })?;
+                let schema = schema
+                    .clone()
+                    .try_into()
+                    .context(error::ConvertDfSchemaSnafu)?;
+                let predicate = to_df_expr(ctx, *condition, &schema)?;
+                LogicalPlan::Filter(Filter::try_new(predicate, input).context(DFInternalSnafu)?)
            }
-            .fail()?,
            RelType::Fetch(_fetch_rel) => UnsupportedPlanSnafu {
                name: "Fetch Relation",
            }

@@ -148,7 +186,7 @@ impl DFLogicalSubstraitConvertor {
                name: "Cross Relation",
            }
            .fail()?,
-        }?;
+        };

        Ok(logical_plan)
    }
@@ -157,6 +195,7 @@ impl DFLogicalSubstraitConvertor {
        &self,
        ctx: &mut ConvertorContext,
        read_rel: Box<ReadRel>,
+        catalog_manager: CatalogManagerRef,
    ) -> Result<LogicalPlan, Error> {
        // Extract the catalog, schema and table name from NamedTable. Assume the first three are those names.
        let read_type = read_rel.read_type.context(MissingFieldSnafu {

@@ -192,22 +231,23 @@ impl DFLogicalSubstraitConvertor {
            .map(|mask_expr| self.convert_mask_expression(mask_expr));

        // Get table handle from catalog manager
-        let table_ref = self
-            .catalog_manager
+        let table_ref = catalog_manager
            .table(&catalog_name, &schema_name, &table_name)
            .map_err(BoxedError::new)
            .context(InternalSnafu)?
            .context(TableNotFoundSnafu {
                name: format!("{}.{}.{}", catalog_name, schema_name, table_name),
            })?;
-        let adapter = Arc::new(DfTableProviderAdapter::new(table_ref));
+        let adapter = Arc::new(DefaultTableSource::new(Arc::new(
+            DfTableProviderAdapter::new(table_ref),
+        )));

        // Get schema directly from the table, and compare it with the schema retrieved from substrait proto.
        let stored_schema = adapter.schema();
        let retrieved_schema = to_schema(read_rel.base_schema.unwrap_or_default())?;
        let retrieved_arrow_schema = retrieved_schema.arrow_schema();
        ensure!(
-            stored_schema.fields == retrieved_arrow_schema.fields,
+            same_schema_without_metadata(&stored_schema, retrieved_arrow_schema),
            SchemaNotMatchSnafu {
                substrait_schema: retrieved_arrow_schema.clone(),
                storage_schema: stored_schema

@@ -227,14 +267,16 @@ impl DFLogicalSubstraitConvertor {
            .to_dfschema_ref()
            .context(DFInternalSnafu)?;

-        // TODO(ruihang): Support filters and limit
+        ctx.set_df_schema(projected_schema.clone());
+        // TODO(ruihang): Support limit(fetch)
        Ok(LogicalPlan::TableScan(TableScan {
-            table_name,
+            table_name: format!("{}.{}.{}", catalog_name, schema_name, table_name),
            source: adapter,
            projection,
            projected_schema,
            filters,
-            limit: None,
+            fetch: None,
        }))
    }
@@ -250,20 +292,42 @@ impl DFLogicalSubstraitConvertor {
}

impl DFLogicalSubstraitConvertor {
-    pub fn convert_df_plan(&self, plan: LogicalPlan) -> Result<Plan, Error> {
-        let mut ctx = ConvertorContext::default();
-        // TODO(ruihang): extract this translation logic into a separated function
-        // convert PlanRel
-        let rel = match plan {
+    fn logical_plan_to_rel(
+        &self,
+        ctx: &mut ConvertorContext,
+        plan: Arc<LogicalPlan>,
+    ) -> Result<Rel, Error> {
+        Ok(match &*plan {
            LogicalPlan::Projection(_) => UnsupportedPlanSnafu {
                name: "DataFusion Logical Projection",
            }
            .fail()?,
-            LogicalPlan::Filter(_) => UnsupportedPlanSnafu {
-                name: "DataFusion Logical Filter",
+            LogicalPlan::Filter(filter) => {
+                let input = Some(Box::new(
+                    self.logical_plan_to_rel(ctx, filter.input().clone())?,
+                ));
+                let schema = plan
+                    .schema()
+                    .clone()
+                    .try_into()
+                    .context(error::ConvertDfSchemaSnafu)?;
+                let condition = Some(Box::new(expression_from_df_expr(
+                    ctx,
+                    filter.predicate(),
+                    &schema,
+                )?));
+                let rel = FilterRel {
+                    common: None,
+                    input,
+                    condition,
+                    advanced_extension: None,
+                };
+                Rel {
+                    rel_type: Some(RelType::Filter(Box::new(rel))),
+                }
            }
-            .fail()?,
            LogicalPlan::Window(_) => UnsupportedPlanSnafu {
                name: "DataFusion Logical Window",
            }

@@ -293,7 +357,7 @@ impl DFLogicalSubstraitConvertor {
            }
            .fail()?,
            LogicalPlan::TableScan(table_scan) => {
-                let read_rel = self.convert_table_scan_plan(&mut ctx, table_scan)?;
+                let read_rel = self.convert_table_scan_plan(ctx, table_scan)?;
                Rel {
                    rel_type: Some(RelType::Read(Box::new(read_rel))),
                }

@@ -306,7 +370,16 @@ impl DFLogicalSubstraitConvertor {
                name: "DataFusion Logical Limit",
            }
            .fail()?,
-            LogicalPlan::CreateExternalTable(_)
+            LogicalPlan::Subquery(_)
+            | LogicalPlan::SubqueryAlias(_)
+            | LogicalPlan::CreateView(_)
+            | LogicalPlan::CreateCatalogSchema(_)
+            | LogicalPlan::CreateCatalog(_)
+            | LogicalPlan::DropView(_)
+            | LogicalPlan::Distinct(_)
+            | LogicalPlan::SetVariable(_)
+            | LogicalPlan::CreateExternalTable(_)
            | LogicalPlan::CreateMemoryTable(_)
            | LogicalPlan::DropTable(_)
            | LogicalPlan::Values(_)

@@ -319,7 +392,13 @@ impl DFLogicalSubstraitConvertor {
                ),
            }
            .fail()?,
-        };
+        })
+    }
+
+    fn convert_df_plan(&self, plan: LogicalPlan) -> Result<Plan, Error> {
+        let mut ctx = ConvertorContext::default();
+        let rel = self.logical_plan_to_rel(&mut ctx, Arc::new(plan))?;

        // convert extension
        let extensions = ctx.generate_function_extension();
@@ -341,11 +420,15 @@ impl DFLogicalSubstraitConvertor {
    pub fn convert_table_scan_plan(
        &self,
        ctx: &mut ConvertorContext,
-        table_scan: TableScan,
+        table_scan: &TableScan,
    ) -> Result<ReadRel, Error> {
        let provider = table_scan
            .source
            .as_any()
+            .downcast_ref::<DefaultTableSource>()
+            .context(UnknownPlanSnafu)?
+            .table_provider
+            .as_any()
            .downcast_ref::<DfTableProviderAdapter>()
            .context(UnknownPlanSnafu)?;
        let table_info = provider.table().table_info();

@@ -363,7 +446,8 @@ impl DFLogicalSubstraitConvertor {
        // assemble projection
        let projection = table_scan
            .projection
-            .map(|proj| self.convert_schema_projection(&proj));
+            .as_ref()
+            .map(|x| self.convert_schema_projection(x));

        // assemble base (unprojected) schema using Table's schema.
        let base_schema = from_schema(&provider.table().schema())?;

@@ -371,7 +455,8 @@ impl DFLogicalSubstraitConvertor {
        // make conjunction over a list of filters and convert the result to substrait
        let filter = if let Some(conjunction) = table_scan
            .filters
-            .into_iter()
+            .iter()
+            .cloned()
            .reduce(|accum, expr| accum.and(expr))
        {
            Some(Box::new(expression_from_df_expr(

@@ -412,12 +497,21 @@ impl DFLogicalSubstraitConvertor {
    }
}

+fn same_schema_without_metadata(lhs: &ArrowSchemaRef, rhs: &ArrowSchemaRef) -> bool {
+    lhs.fields.len() == rhs.fields.len()
+        && lhs.fields.iter().zip(rhs.fields.iter()).all(|(x, y)| {
+            x.name() == y.name()
+                && x.data_type() == y.data_type()
+                && x.is_nullable() == y.is_nullable()
+        })
+}
#[cfg(test)]
mod test {
    use catalog::local::{LocalCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
    use catalog::{CatalogList, CatalogProvider, RegisterTableRequest};
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
-    use datafusion::logical_plan::DFSchema;
+    use datafusion::common::{DFSchema, ToDFSchema};
    use datatypes::schema::Schema;
    use table::requests::CreateTableRequest;
    use table::test_util::{EmptyTable, MockTableEngine};

@@ -463,10 +557,10 @@ mod test {
    }

    async fn logical_plan_round_trip(plan: LogicalPlan, catalog: CatalogManagerRef) {
-        let convertor = DFLogicalSubstraitConvertor::new(catalog);
+        let convertor = DFLogicalSubstraitConvertor;
        let proto = convertor.encode(plan.clone()).unwrap();
-        let tripped_plan = convertor.decode(proto).unwrap();
+        let tripped_plan = convertor.decode(proto, catalog).unwrap();
        assert_eq!(format!("{:?}", plan), format!("{:?}", tripped_plan));
    }

@@ -487,7 +581,10 @@ mod test {
        })
        .await
        .unwrap();
-        let adapter = Arc::new(DfTableProviderAdapter::new(table_ref));
+        let adapter = Arc::new(DefaultTableSource::new(Arc::new(
+            DfTableProviderAdapter::new(table_ref),
+        )));

        let projection = vec![1, 3, 5];
        let df_schema = adapter.schema().to_dfschema().unwrap();
        let projected_fields = projection

@@ -498,12 +595,15 @@ mod test {
            Arc::new(DFSchema::new_with_metadata(projected_fields, Default::default()).unwrap());

        let table_scan_plan = LogicalPlan::TableScan(TableScan {
-            table_name: DEFAULT_TABLE_NAME.to_string(),
+            table_name: format!(
+                "{}.{}.{}",
+                DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, DEFAULT_TABLE_NAME
+            ),
            source: adapter,
            projection: Some(projection),
            projected_schema,
            filters: vec![],
-            limit: None,
+            fetch: None,
        });

        logical_plan_round_trip(table_scan_plan, catalog_manager).await;
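The `same_schema_without_metadata` helper above compares only field name, data type, and nullability, so a schema reconstructed from the substrait proto (which carries no metadata) still matches the stored one that carries `greptime:version`. A small sketch of that behaviour with the helper reused verbatim; only the arrow-rs schema types re-exported as `datafusion::arrow` are assumed:

use std::collections::HashMap;
use std::sync::Arc;

use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};

fn same_schema_without_metadata(lhs: &SchemaRef, rhs: &SchemaRef) -> bool {
    lhs.fields.len() == rhs.fields.len()
        && lhs.fields.iter().zip(rhs.fields.iter()).all(|(x, y)| {
            x.name() == y.name()
                && x.data_type() == y.data_type()
                && x.is_nullable() == y.is_nullable()
        })
}

fn main() {
    let fields = vec![Field::new("number", DataType::UInt32, false)];
    let plain: SchemaRef = Arc::new(Schema::new(fields.clone()));
    let tagged: SchemaRef = Arc::new(Schema::new_with_metadata(
        fields,
        HashMap::from([("greptime:version".to_string(), "0".to_string())]),
    ));
    // Full equality is broken by the metadata; the relaxed comparison is not.
    assert_ne!(plain, tagged);
    assert!(same_schema_without_metadata(&plain, &tagged));
}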

View File

@@ -99,6 +99,12 @@ pub enum Error {
        storage_schema: datafusion::arrow::datatypes::SchemaRef,
        backtrace: Backtrace,
    },
+
+    #[snafu(display("Failed to convert DataFusion schema, source: {}", source))]
+    ConvertDfSchema {
+        #[snafu(backtrace)]
+        source: datatypes::error::Error,
+    },
}

pub type Result<T> = std::result::Result<T, Error>;

@@ -120,6 +126,7 @@ impl ErrorExt for Error {
            | Error::TableNotFound { .. }
            | Error::SchemaNotMatch { .. } => StatusCode::InvalidArguments,
            Error::DFInternal { .. } | Error::Internal { .. } => StatusCode::Internal,
+            Error::ConvertDfSchema { source } => source.status_code(),
        }
    }

View File

@@ -22,6 +22,7 @@ mod schema;
mod types;

use bytes::{Buf, Bytes};
+use catalog::CatalogManagerRef;

pub use crate::df_logical::DFLogicalSubstraitConvertor;

@@ -30,7 +31,11 @@ pub trait SubstraitPlan {
    type Plan;

-    fn decode<B: Buf + Send>(&self, message: B) -> Result<Self::Plan, Self::Error>;
+    fn decode<B: Buf + Send>(
+        &self,
+        message: B,
+        catalog_manager: CatalogManagerRef,
+    ) -> Result<Self::Plan, Self::Error>;

    fn encode(&self, plan: Self::Plan) -> Result<Bytes, Self::Error>;
}

View File

@@ -18,11 +18,13 @@
//! Current we only have variations on integer types. Variation 0 (system preferred) are the same with base types, which
//! are signed integer (i.e. I8 -> [i8]), and Variation 1 stands for unsigned integer (i.e. I8 -> [u8]).

+use datafusion::scalar::ScalarValue;
use datatypes::prelude::ConcreteDataType;
+use substrait_proto::protobuf::expression::literal::LiteralType;
use substrait_proto::protobuf::r#type::{self as s_type, Kind, Nullability};
-use substrait_proto::protobuf::Type as SType;
+use substrait_proto::protobuf::{Type as SType, Type};

-use crate::error::{Result, UnsupportedConcreteTypeSnafu, UnsupportedSubstraitTypeSnafu};
+use crate::error::{self, Result, UnsupportedConcreteTypeSnafu, UnsupportedSubstraitTypeSnafu};

macro_rules! substrait_kind {
    ($desc:ident, $concrete_ty:ident) => {{

@@ -134,3 +136,67 @@ pub fn from_concrete_type(ty: ConcreteDataType, nullability: Option<bool>) -> Re
    Ok(SType { kind })
}
+
+pub(crate) fn scalar_value_as_literal_type(v: &ScalarValue) -> Result<LiteralType> {
+    Ok(if v.is_null() {
+        LiteralType::Null(Type { kind: None })
+    } else {
+        match v {
+            ScalarValue::Boolean(Some(v)) => LiteralType::Boolean(*v),
+            ScalarValue::Float32(Some(v)) => LiteralType::Fp32(*v),
+            ScalarValue::Float64(Some(v)) => LiteralType::Fp64(*v),
+            ScalarValue::Int8(Some(v)) => LiteralType::I8(*v as i32),
+            ScalarValue::Int16(Some(v)) => LiteralType::I16(*v as i32),
+            ScalarValue::Int32(Some(v)) => LiteralType::I32(*v),
+            ScalarValue::Int64(Some(v)) => LiteralType::I64(*v),
+            ScalarValue::LargeUtf8(Some(v)) => LiteralType::String(v.clone()),
+            ScalarValue::LargeBinary(Some(v)) => LiteralType::Binary(v.clone()),
+            // TODO(LFC): Implement other conversions: ScalarValue => LiteralType
+            _ => {
+                return error::UnsupportedExprSnafu {
+                    name: format!("{:?}", v),
+                }
+                .fail()
+            }
+        }
+    })
+}
+
+pub(crate) fn literal_type_to_scalar_value(t: LiteralType) -> Result<ScalarValue> {
+    Ok(match t {
+        LiteralType::Null(Type { kind: Some(kind) }) => match kind {
+            Kind::Bool(_) => ScalarValue::Boolean(None),
+            Kind::I8(_) => ScalarValue::Int8(None),
+            Kind::I16(_) => ScalarValue::Int16(None),
+            Kind::I32(_) => ScalarValue::Int32(None),
+            Kind::I64(_) => ScalarValue::Int64(None),
+            Kind::Fp32(_) => ScalarValue::Float32(None),
+            Kind::Fp64(_) => ScalarValue::Float64(None),
+            Kind::String(_) => ScalarValue::LargeUtf8(None),
+            Kind::Binary(_) => ScalarValue::LargeBinary(None),
+            // TODO(LFC): Implement other conversions: Kind => ScalarValue
+            _ => {
+                return error::UnsupportedSubstraitTypeSnafu {
+                    ty: format!("{:?}", kind),
+                }
+                .fail()
+            }
+        },
+        LiteralType::Boolean(v) => ScalarValue::Boolean(Some(v)),
+        LiteralType::I8(v) => ScalarValue::Int8(Some(v as i8)),
+        LiteralType::I16(v) => ScalarValue::Int16(Some(v as i16)),
+        LiteralType::I32(v) => ScalarValue::Int32(Some(v)),
+        LiteralType::I64(v) => ScalarValue::Int64(Some(v)),
+        LiteralType::Fp32(v) => ScalarValue::Float32(Some(v)),
+        LiteralType::Fp64(v) => ScalarValue::Float64(Some(v)),
+        LiteralType::String(v) => ScalarValue::LargeUtf8(Some(v)),
+        LiteralType::Binary(v) => ScalarValue::LargeBinary(Some(v)),
+        // TODO(LFC): Implement other conversions: LiteralType => ScalarValue
+        _ => {
+            return error::UnsupportedSubstraitTypeSnafu {
+                ty: format!("{:?}", t),
+            }
+            .fail()
+        }
+    })

View File

@@ -55,8 +55,11 @@ impl From<i32> for Date {
impl Display for Date {
    /// [Date] is formatted according to ISO-8601 standard.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        let abs_date = NaiveDate::from_num_days_from_ce(UNIX_EPOCH_FROM_CE + self.0);
-        f.write_str(&abs_date.format("%F").to_string())
+        if let Some(abs_date) = NaiveDate::from_num_days_from_ce_opt(UNIX_EPOCH_FROM_CE + self.0) {
+            write!(f, "{}", abs_date.format("%F"))
+        } else {
+            write!(f, "Date({})", self.0)
+        }
    }
}

@@ -95,7 +98,7 @@ mod tests {
            Date::from_str("1969-01-01").unwrap().to_string()
        );

-        let now = Utc::now().date().format("%F").to_string();
+        let now = Utc::now().date_naive().format("%F").to_string();
        assert_eq!(now, Date::from_str(&now).unwrap().to_string());
    }
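The `_opt` constructor returns `None` instead of panicking on out-of-range input, which is what lets `Display` degrade to `Date(n)`. A minimal chrono 0.4 sketch; day number 719_163 (1970-01-01) is assumed to match the `UNIX_EPOCH_FROM_CE` constant used above:

use chrono::NaiveDate;

fn main() {
    // Day 719163 of the common era is the Unix epoch, 1970-01-01.
    let epoch = NaiveDate::from_num_days_from_ce_opt(719_163);
    assert_eq!(
        epoch.map(|d| d.format("%F").to_string()),
        Some("1970-01-01".to_string())
    );
    // Out-of-range day counts yield None rather than panicking.
    assert!(NaiveDate::from_num_days_from_ce_opt(i32::MAX).is_none());
}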

View File

@@ -31,8 +31,11 @@ pub struct DateTime(i64);
impl Display for DateTime {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        let abs_time = NaiveDateTime::from_timestamp(self.0, 0);
+        if let Some(abs_time) = NaiveDateTime::from_timestamp_opt(self.0, 0) {
            write!(f, "{}", abs_time.format(DATETIME_FORMAT))
+        } else {
+            write!(f, "DateTime({})", self.0)
+        }
    }
}
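The same fallback applies to `DateTime`: `NaiveDateTime::from_timestamp_opt` returns `None` for seconds outside chrono's representable range. A minimal chrono 0.4 sketch; the `%Y-%m-%d %H:%M:%S` format string stands in for `DATETIME_FORMAT`, which is assumed to be equivalent:

use chrono::NaiveDateTime;

fn main() {
    let epoch = NaiveDateTime::from_timestamp_opt(0, 0).unwrap();
    assert_eq!(
        epoch.format("%Y-%m-%d %H:%M:%S").to_string(),
        "1970-01-01 00:00:00"
    );
    // i64::MAX seconds is far outside chrono's range, so the fallback branch is taken.
    assert!(NaiveDateTime::from_timestamp_opt(i64::MAX, 0).is_none());
}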

Some files were not shown because too many files have changed in this diff.