Compare commits

...

35 Commits

Author SHA1 Message Date
discord9
03954e8b3b chore: update proto (#6992)
* chore: update proto

Signed-off-by: discord9 <discord9@163.com>

* update lockfile

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
2025-09-19 09:01:46 +00:00
Yingwen
bd8f5d2b71 fix: print the output message of the error in admin fn macro (#6994)
Signed-off-by: evenyag <realevenyag@gmail.com>
2025-09-19 08:11:19 +00:00
Weny Xu
74721a06ba chore: improve error logging in WAL prune manager (#6993)
Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-19 07:08:28 +00:00
discord9
18e4839a17 feat: datanode side local gc worker (#6940)
* wip

Signed-off-by: discord9 <discord9@163.com>

* docs for behavior

Signed-off-by: discord9 <discord9@163.com>

* wip: handle outdated version

Signed-off-by: discord9 <discord9@163.com>

* feat: just retry

Signed-off-by: discord9 <discord9@163.com>

* feat: smaller lingering time

Signed-off-by: discord9 <discord9@163.com>

* refactor: partial per review

Signed-off-by: discord9 <discord9@163.com>

* refactor: rm tmp file cnt

Signed-off-by: discord9 <discord9@163.com>

* chore: per review

Signed-off-by: discord9 <discord9@163.com>

* chore: opt partial

Signed-off-by: discord9 <discord9@163.com>

* chore: rebase fix

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
2025-09-19 03:20:13 +00:00
LFC
cbe0cf4a74 refactor: rewrite some UDFs to DataFusion style (part 2) (#6967)
* refactor: rewrite some UDFs to DataFusion style (part 2)

Signed-off-by: luofucong <luofc@foxmail.com>

* deal with the vector UDFs `(scalar, scalar)` situation, and try getting the scalar value reference every time

Signed-off-by: luofucong <luofc@foxmail.com>

* reduce some vector literal parsing

Signed-off-by: luofucong <luofc@foxmail.com>

* fix ci

Signed-off-by: luofucong <luofc@foxmail.com>

---------

Signed-off-by: luofucong <luofc@foxmail.com>
2025-09-18 06:37:27 +00:00
discord9
e26b98f452 refactor: put FileId to store-api (#6988)
* refactor: put FileId to store-api

Signed-off-by: discord9 <discord9@163.com>

* per review

Signed-off-by: discord9 <discord9@163.com>

* chore: lock file

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
2025-09-18 03:20:42 +00:00
localhost
d8b967408e chore: modify LogExpr AggrFunc (#6948)
* chore: modify LogExpr AggrFunc

* chore: change AggrFunc range field

* chore: remove range from aggrfunc
2025-09-17 12:19:48 +00:00
Weny Xu
c35407fdce refactor: region follower management with unified interface (#6986)
Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-17 10:01:03 +00:00
Lei, HUANG
edf4b3f7f8 chore: unset tz env in test (#6984)
chore/unset-tz-env-in-test:
 ### Commit Message

 Add environment variable cleanup in timezone tests

 - Updated `timezone.rs` to include removal of the `TZ` environment variable in the `test_from_tz_string` function to ensure a clean test environment.
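
A minimal sketch of the cleanup described above, assuming a standard Rust unit test; the test name comes from the commit message, and the body is illustrative rather than the actual test in `timezone.rs`:

```rust
#[cfg(test)]
mod tests {
    #[test]
    fn test_from_tz_string() {
        // Clear any inherited TZ so the assertions do not depend on the host
        // environment (on Rust edition 2024 this call must be wrapped in
        // `unsafe { ... }`).
        std::env::remove_var("TZ");
        // ... timezone-from-string assertions would follow here ...
    }
}
```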

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>
2025-09-17 08:48:38 +00:00
Yingwen
14550429e9 chore: reduce SeriesScan sender timeout (#6983)
Signed-off-by: evenyag <realevenyag@gmail.com>
2025-09-17 07:02:47 +00:00
shuiyisong
ff2da4903e fix: OTel metrics naming with Prometheus style (#6982)
* fix: otel metrics naming

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* fix: otel metrics naming & add some tests

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

---------

Signed-off-by: shuiyisong <xixing.sys@gmail.com>
2025-09-17 06:11:38 +00:00
Lei, HUANG
c92ab4217f fix: avoid truncating SST statistics during flush (#6977)
fix/disable-parquet-stats-truncate:
 - **Update `memcomparable` Dependency**: Switched `memcomparable` from crates.io to a Git repository in `Cargo.lock` and `mito-codec/Cargo.toml`, and removed it from `mito2/Cargo.toml`.
 - **Enhance Parquet Writer Properties**: Added `set_statistics_truncate_length` and `set_column_index_truncate_length` to `WriterProperties` in `parquet.rs`, `bulk/part.rs`, `partition_tree/data.rs`, and `writer.rs`.
 - **Add Test for Corrupt Scan**: Introduced a new test module `scan_corrupt.rs` in `mito2/src/engine` to verify handling of corrupt data.
 - **Update Test Data**: Modified test data in `flush.rs` to reflect changes in file sizes and sequences.
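
A hedged sketch of the writer-properties change this commit describes (the real call sites are in the mito2 SST writers; the `stream_to_parquet` diff further below shows the same calls). Passing `None` disables truncation so statistics keep their full values:

```rust
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;

/// Writer properties that keep full-length statistics and column index values.
fn full_stats_writer_props() -> WriterProperties {
    WriterProperties::builder()
        .set_compression(Compression::ZSTD(ZstdLevel::default()))
        .set_statistics_truncate_length(None) // do not truncate min/max statistics
        .set_column_index_truncate_length(None) // do not truncate column index values
        .build()
}
```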

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>
2025-09-17 03:02:52 +00:00
Zhenchi
77981a7de5 fix: clean intm ignore notfound (#6971)
* fix: clean intm ignore notfound

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

* address comments

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>

---------

Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
2025-09-17 02:58:03 +00:00
Lei, HUANG
9096c5ebbf chore: bump sequence on region edit (#6947)
* chore/update-sequence-on-region-edit:
 ### Commit Message

 Refactor `get_last_seq_num` Method Across Engines

 - **Change Return Type**: Updated the `get_last_seq_num` method to return `Result<SequenceNumber, BoxedError>` instead of `Result<Option<SequenceNumber>, BoxedError>` in the following files:
   - `src/datanode/src/tests.rs`
   - `src/file-engine/src/engine.rs`
   - `src/metric-engine/src/engine.rs`
   - `src/metric-engine/src/engine/read.rs`
   - `src/mito2/src/engine.rs`
   - `src/query/src/optimizer/test_util.rs`
   - `src/store-api/src/region_engine.rs`

 - **Enhance Region Edit Handling**: Modified `RegionWorkerLoop` in `src/mito2/src/worker/handle_manifest.rs` to update file sequences during region edits.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* add committed_sequence to RegionEdit

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* chore/update-sequence-on-region-edit:
 ### Commit Message

 Refactor sequence retrieval method

 - **Renamed Method**: Changed `get_last_seq_num` to `get_committed_sequence` across multiple files to better reflect its purpose of retrieving the latest committed sequence.
   - Affected files: `tests.rs`, `engine.rs` in `file-engine`, `metric-engine`, `mito2`, `test_util.rs`, and `region_engine.rs`.
 - **Removed Unused Struct**: Deleted `RegionSequencesRequest` struct from `region_request.rs` as it is no longer needed.
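
A rough sketch of the renamed API as described above; the trait below is a stand-in for illustration, not the actual `RegionEngine` definition in `store-api`:

```rust
// Placeholder types standing in for the real ones.
type SequenceNumber = u64;
#[derive(Debug)]
struct BoxedError;

trait RegionEngineSketch {
    /// Was `get_last_seq_num(...) -> Result<Option<SequenceNumber>, BoxedError>`;
    /// now returns the latest committed sequence directly.
    fn get_committed_sequence(&self) -> Result<SequenceNumber, BoxedError>;
}
```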

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* chore/update-sequence-on-region-edit:
 **Add Committed Sequence Handling in Region Engine**

 - **`engine.rs`**: Introduced a new test module `bump_committed_sequence_test` to verify committed sequence handling.
 - **`bump_committed_sequence_test.rs`**: Added a test to ensure the committed sequence is correctly updated and persisted across region reopenings.
 - **`action.rs`**: Updated `RegionManifest` and `RegionManifestBuilder` to include `committed_sequence` for tracking.
 - **`manager.rs`**: Adjusted manifest size assertion to accommodate new committed sequence data.
 - **`opener.rs`**: Implemented logic to override committed sequence during region opening.
 - **`version.rs`**: Added `set_committed_sequence` method to update the committed sequence in `VersionControl`.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* chore/update-sequence-on-region-edit:
 **Enhance `test_bump_committed_sequence` in `bump_committed_sequence_test.rs`**

 - Updated the test to include row operations using `build_rows`, `put_rows`, and `rows_schema` to verify the committed sequence behavior.
 - Adjusted assertions to reflect changes in committed sequence after row operations and region edits.
 - Added comments to clarify the expected behavior of committed sequence after reopening the region and replaying the WAL.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* chore/update-sequence-on-region-edit:
 **Enhance Region Sequence Management**

 - **`bump_committed_sequence_test.rs`**: Updated test to handle region reopening and sequence management, ensuring committed sequences are correctly set and verified after edits.
 - **`opener.rs`**: Improved committed sequence handling by overriding it only if the manifest's sequence is greater than the replayed sequence. Added logging for mutation sequence replay.
 - **`region_write_ctx.rs`**: Modified `push_mutation` and `push_bulk` methods to adopt sequence numbers from parameters, enhancing sequence management during write operations.
 - **`handle_write.rs`**: Updated `RegionWorkerLoop` to pass sequence numbers in `push_bulk` and `push_mutation` methods, ensuring consistent sequence handling.
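
An illustrative reduction of the override rule in the `opener.rs` bullet above (an assumption about the exact logic, not the actual implementation): the manifest's committed sequence wins only when it is ahead of what WAL replay recovered:

```rust
/// Choose the committed sequence when opening a region: prefer the manifest
/// value only if it is greater than the sequence recovered by replaying the WAL.
fn resolve_committed_sequence(manifest_seq: u64, replayed_seq: u64) -> u64 {
    if manifest_seq > replayed_seq {
        manifest_seq
    } else {
        replayed_seq
    }
}
```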

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

* chore/update-sequence-on-region-edit:
 ### Remove Debug Logging from `opener.rs`

 - Removed debug logging for mutation sequences in `opener.rs` to clean up the output and improve performance.

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>

---------

Signed-off-by: Lei, HUANG <mrsatangel@gmail.com>
2025-09-16 16:22:25 +00:00
Weny Xu
0a959f9920 feat: add TLS support for mysql backend (#6979)
* refactor: move etcd tls code to `common-meta`

Signed-off-by: WenyXu <wenymedia@gmail.com>

* refactor: move postgre pool logic to `utils::postgre`

Signed-off-by: WenyXu <wenymedia@gmail.com>

* feat: setup mysql ssl options

Signed-off-by: WenyXu <wenymedia@gmail.com>

* feat: add test for mysql backend with tls

Signed-off-by: WenyXu <wenymedia@gmail.com>

* refactor: simplify certs generation

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions

Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-16 13:46:37 +00:00
discord9
85c1a91bae feat: support SubqueryAlias pushdown (#6963)
* wip enforce dist requirement rewriter

Signed-off-by: discord9 <discord9@163.com>

* feat: enforce dist req

Signed-off-by: discord9 <discord9@163.com>

* test: sqlness result

Signed-off-by: discord9 <discord9@163.com>

* fix: double projection

Signed-off-by: discord9 <discord9@163.com>

* test: fix sqlness

Signed-off-by: discord9 <discord9@163.com>

* refactor: per review

Signed-off-by: discord9 <discord9@163.com>

* docs: use btree map

Signed-off-by: discord9 <discord9@163.com>

* test: sqlness explain&comment

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
2025-09-16 13:27:35 +00:00
Weny Xu
7aba9a18fd chore: add tests for postgre backend with tls (#6973)
* chore: add tests for postgre backend with tls

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: minor

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions

Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-16 11:03:11 +00:00
shuiyisong
4c18d140b4 fix: deadlock in dashmap (#6978)
* fix: deadlock in dashmap

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

* Update src/frontend/src/instance.rs

Co-authored-by: Yingwen <realevenyag@gmail.com>

* chore: extract fast cache check and add test

Signed-off-by: shuiyisong <xixing.sys@gmail.com>

---------

Signed-off-by: shuiyisong <xixing.sys@gmail.com>
Co-authored-by: Yingwen <realevenyag@gmail.com>
2025-09-16 10:49:28 +00:00
Yingwen
b8e0c49cb4 feat: add a flag to enable the experimental flat format (#6976)
* feat: add enable_experimental_flat_format flag to enable flat format

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: extract build_scan_input for CompactionSstReaderBuilder

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: add compact memtable cost to flush metrics

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: Sets compact dispatcher for bulk memtable

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: Cast dictionary to target type in FlatProjectionMapper

Signed-off-by: evenyag <realevenyag@gmail.com>

* fix: add time index to FlatProjectionMapper::batch_schema

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: update config toml

Signed-off-by: evenyag <realevenyag@gmail.com>

* fix: pass flat_format to ProjectionMapper in CompactionSstReaderBuilder

Signed-off-by: evenyag <realevenyag@gmail.com>

---------

Signed-off-by: evenyag <realevenyag@gmail.com>
2025-09-16 09:33:12 +00:00
Zhenchi
db42ad42dc feat: add visible to sst entry for staging mode (#6964)
Signed-off-by: Zhenchi <zhongzc_arch@outlook.com>
2025-09-15 09:05:54 +00:00
shuiyisong
8ce963f63e fix: shorten lock time (#6968)
2025-09-15 03:37:36 +00:00
Yingwen
b3aabb6706 feat: support flush and compact flat format files (#6949)
* feat: basic functions for flush/compact flat format

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: bridge flush and compaction for flat format

Signed-off-by: evenyag <realevenyag@gmail.com>

* feat: add write cache support

Signed-off-by: evenyag <realevenyag@gmail.com>

* style: fix clippy

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: change log level to debug

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: wrap duplicated code to merge and dedup iter

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: wrap some code into flush_flat_mem_ranges

Signed-off-by: evenyag <realevenyag@gmail.com>

* refactor: extract logic into do_flush_memtables

Signed-off-by: evenyag <realevenyag@gmail.com>

---------

Signed-off-by: evenyag <realevenyag@gmail.com>
2025-09-14 13:36:24 +00:00
Ning Sun
028effe952 docs: update memory profiling description doc (#6960)
doc: update memory profiling description doc
2025-09-12 08:30:22 +00:00
Ruihang Xia
d86f489a74 fix: staging mode with proper region edit operations (#6962)
Signed-off-by: Ruihang Xia <waynestxia@gmail.com>
2025-09-12 04:39:42 +00:00
dennis zhuang
6c066c1a4a test: migrate join tests from duckdb, part3 (#6881)
* test: migrate join tests

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* chore: update test results after rebasing main branch

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: unstable query sort results and natural_join test

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: count(*) with joining

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

* fix: unstable query sort results and style

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>

---------

Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
2025-09-12 04:20:00 +00:00
LFC
9ab87e11a4 refactor: rewrite h3 functions to DataFusion style (#6942)
* refactor: rewrite h3 functions to DataFusion style

Signed-off-by: luofucong <luofc@foxmail.com>

* resolve PR comments

Signed-off-by: luofucong <luofc@foxmail.com>

---------

Signed-off-by: luofucong <luofc@foxmail.com>
2025-09-12 02:27:24 +00:00
Weny Xu
9fe7069146 feat: add postgres tls support for CLI (#6941)
* feat: add postgres tls support for cli

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions

Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-11 12:18:13 +00:00
fys
733a1afcd1 fix: correct jemalloc metrics (#6959)
The allocated and resident metrics were swapped in the set calls. This commit
fixes the issue by ensuring each metric receives its corresponding value.
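
A self-contained illustration of the swap this commit describes; the struct and field names below are placeholders, not the actual GreptimeDB metric gauges:

```rust
/// Stand-in for the two Prometheus gauges reported from jemalloc stats.
struct MemGauges {
    allocated: u64,
    resident: u64,
}

fn update_gauges(gauges: &mut MemGauges, stats_allocated: u64, stats_resident: u64) {
    // Buggy version (what this commit fixes): the values were crossed over.
    // gauges.allocated = stats_resident;
    // gauges.resident = stats_allocated;

    // Fixed version: each metric receives its corresponding value.
    gauges.allocated = stats_allocated;
    gauges.resident = stats_resident;
}
```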
2025-09-11 06:37:19 +00:00
Yingwen
5e65581f94 feat: support flat format for SeriesScan (#6938)
* feat: Support flat format for SeriesScan

Signed-off-by: evenyag <realevenyag@gmail.com>

* test: simplify tests

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: update comment

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: only accumulate fetch time to scan_cost in SeriesDistributor of
the SeriesScan

Signed-off-by: evenyag <realevenyag@gmail.com>

* chore: update comment

Signed-off-by: evenyag <realevenyag@gmail.com>

---------

Signed-off-by: evenyag <realevenyag@gmail.com>
2025-09-11 06:12:53 +00:00
ZonaHe
e75e5baa63 feat: update dashboard to v0.11.4 (#6956)
Co-authored-by: sunchanglong <sunchanglong@users.noreply.github.com>
2025-09-11 04:34:25 +00:00
zyy17
c4b89df523 fix: use pull_request_target to fix add labels 403 error (#6958)
Signed-off-by: zyy17 <zyylsxm@gmail.com>
2025-09-11 03:53:14 +00:00
Weny Xu
6a15e62719 feat: expose workload filter to selector options (#6951)
* feat: add workload filtering support to selector options

Signed-off-by: WenyXu <wenymedia@gmail.com>

* chore: apply suggestions

Signed-off-by: WenyXu <wenymedia@gmail.com>

---------

Signed-off-by: WenyXu <wenymedia@gmail.com>
2025-09-11 03:11:13 +00:00
discord9
2bddbe8c47 feat(query): better alias tracker (#6909)
* better resolve

Signed-off-by: discord9 <discord9@163.com>

feat: layered alias tracker

Signed-off-by: discord9 <discord9@163.com>

refactor

Signed-off-by: discord9 <discord9@163.com>

docs: explain why there is no offset by one

Signed-off-by: discord9 <discord9@163.com>

test: more

Signed-off-by: discord9 <discord9@163.com>

simplify api

Signed-off-by: discord9 <discord9@163.com>

wip

Signed-off-by: discord9 <discord9@163.com>

fix: filter non-exist columns

Signed-off-by: discord9 <discord9@163.com>

feat: stuff

Signed-off-by: discord9 <discord9@163.com>

feat: cache partition columns

Signed-off-by: discord9 <discord9@163.com>

refactor: rm unused fn

Signed-off-by: discord9 <discord9@163.com>

no need res

Signed-off-by: discord9 <discord9@163.com>

chore: rm unwrap&docs update

Signed-off-by: discord9 <discord9@163.com>

* chore: after rebase fix

Signed-off-by: discord9 <discord9@163.com>

* refactor: per review

Signed-off-by: discord9 <discord9@163.com>

* fix: unsupported part

Signed-off-by: discord9 <discord9@163.com>

* err msg

Signed-off-by: discord9 <discord9@163.com>

* fix: pass correct partition cols

Signed-off-by: discord9 <discord9@163.com>

* fix? use column name only

Signed-off-by: discord9 <discord9@163.com>

* fix: merge scan has partition columns no alias/no partition diff

Signed-off-by: discord9 <discord9@163.com>

* refactor: loop instead of recursive

Signed-off-by: discord9 <discord9@163.com>

* refactor: per review

Signed-off-by: discord9 <discord9@163.com>

* feat: overlaps

Signed-off-by: discord9 <discord9@163.com>

---------

Signed-off-by: discord9 <discord9@163.com>
2025-09-11 02:30:51 +00:00
discord9
ea8125aafb fix: count(1) instead of count(ts) when >1 inputs (#6952)
Signed-off-by: discord9 <discord9@163.com>
2025-09-10 21:30:43 +00:00
dennis zhuang
49722951c6 fix: unstable query sort results (#6944)
Signed-off-by: Dennis Zhuang <killme2008@gmail.com>
2025-09-10 20:41:10 +00:00
271 changed files with 13774 additions and 4079 deletions

View File

@@ -1,7 +1,7 @@
name: "Semantic Pull Request"
on:
pull_request:
pull_request_target:
types:
- opened
- reopened
@@ -12,9 +12,9 @@ concurrency:
cancel-in-progress: true
permissions:
issues: write
contents: write
contents: read
pull-requests: write
issues: write
jobs:
check:

Cargo.lock generated
View File

@@ -5302,7 +5302,7 @@ dependencies = [
[[package]]
name = "greptime-proto"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36#f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3e821d0d405e6733690a4e4352812ba2ff780a3e#3e821d0d405e6733690a4e4352812ba2ff780a3e"
dependencies = [
"prost 0.13.5",
"prost-types 0.13.5",
@@ -7287,8 +7287,7 @@ checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0"
[[package]]
name = "memcomparable"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "376101dbd964fc502d5902216e180f92b3d003b5cc3d2e40e044eb5470fca677"
source = "git+https://github.com/v0y4g3r/memcomparable.git?rev=a07122dc03556bbd88ad66234cbea7efd3b23efb#a07122dc03556bbd88ad66234cbea7efd3b23efb"
dependencies = [
"bytes",
"serde",
@@ -7607,7 +7606,6 @@ dependencies = [
"itertools 0.14.0",
"lazy_static",
"log-store",
"memcomparable",
"mito-codec",
"moka",
"object-store",
@@ -12360,6 +12358,7 @@ dependencies = [
"sqlparser 0.55.0-greptime",
"strum 0.27.1",
"tokio",
"uuid",
]
[[package]]

View File

@@ -145,7 +145,7 @@ etcd-client = { git = "https://github.com/GreptimeTeam/etcd-client", rev = "f62d
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "f9836cf8aab30e672f640c6ef4c1cfd2cf9fbc36" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3e821d0d405e6733690a4e4352812ba2ff780a3e" }
hex = "0.4"
http = "1"
humantime = "2.1"

View File

@@ -151,6 +151,7 @@
| `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |
| `region_engine.mito.allow_stale_entries` | Bool | `false` | Whether to allow stale WAL entries read during replay. |
| `region_engine.mito.min_compaction_interval` | String | `0m` | Minimum time interval between two compactions.<br/>To align with the old behavior, the default value is 0 (no restrictions). |
| `region_engine.mito.enable_experimental_flat_format` | Bool | `false` | Whether to enable experimental flat format. |
| `region_engine.mito.index` | -- | -- | The options for index in Mito engine. |
| `region_engine.mito.index.aux_path` | String | `""` | Auxiliary directory path for the index in filesystem, used to store intermediate files for<br/>creating the index and staging files for searching the index, defaults to `{data_home}/index_intermediate`.<br/>The default name for this directory is `index_intermediate` for backward compatibility.<br/><br/>This path contains two subdirectories:<br/>- `__intm`: for storing intermediate files used during creating index.<br/>- `staging`: for storing staging files used during searching index. |
| `region_engine.mito.index.staging_size` | String | `2GB` | The max capacity of the staging directory. |
@@ -543,6 +544,7 @@
| `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |
| `region_engine.mito.allow_stale_entries` | Bool | `false` | Whether to allow stale WAL entries read during replay. |
| `region_engine.mito.min_compaction_interval` | String | `0m` | Minimum time interval between two compactions.<br/>To align with the old behavior, the default value is 0 (no restrictions). |
| `region_engine.mito.enable_experimental_flat_format` | Bool | `false` | Whether to enable experimental flat format. |
| `region_engine.mito.index` | -- | -- | The options for index in Mito engine. |
| `region_engine.mito.index.aux_path` | String | `""` | Auxiliary directory path for the index in filesystem, used to store intermediate files for<br/>creating the index and staging files for searching the index, defaults to `{data_home}/index_intermediate`.<br/>The default name for this directory is `index_intermediate` for backward compatibility.<br/><br/>This path contains two subdirectories:<br/>- `__intm`: for storing intermediate files used during creating index.<br/>- `staging`: for storing staging files used during searching index. |
| `region_engine.mito.index.staging_size` | String | `2GB` | The max capacity of the staging directory. |

View File

@@ -497,6 +497,9 @@ allow_stale_entries = false
## To align with the old behavior, the default value is 0 (no restrictions).
min_compaction_interval = "0m"
## Whether to enable experimental flat format.
enable_experimental_flat_format = false
## The options for index in Mito engine.
[region_engine.mito.index]

View File

@@ -576,6 +576,9 @@ allow_stale_entries = false
## To align with the old behavior, the default value is 0 (no restrictions).
min_compaction_interval = "0m"
## Whether to enable experimental flat format.
enable_experimental_flat_format = false
## The options for index in Mito engine.
[region_engine.mito.index]

View File

@@ -30,22 +30,7 @@ curl https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph
## Profiling
### Configuration
You can control heap profiling activation through configuration. Add the following to your configuration file:
```toml
[memory]
# Whether to enable heap profiling activation during startup.
# When enabled, heap profiling will be activated if the `MALLOC_CONF` environment variable
# is set to "prof:true,prof_active:false". The official image adds this env variable.
# Default is true.
enable_heap_profiling = true
```
By default, if you set `MALLOC_CONF=prof:true,prof_active:false`, the database will enable profiling during startup. You can disable this behavior by setting `enable_heap_profiling = false` in the configuration.
### Starting with environment variables
### Enable memory profiling for greptimedb binary
Start GreptimeDB instance with environment variables:
@@ -57,6 +42,22 @@ MALLOC_CONF=prof:true ./target/debug/greptime standalone start
_RJEM_MALLOC_CONF=prof:true ./target/debug/greptime standalone start
```
### Memory profiling for greptimedb docker image
We have memory profiling enabled and activated by default in our official docker
image.
This behavior is controlled by configuration `enable_heap_profiling`:
```toml
[memory]
# Whether to enable heap profiling activation during startup.
# Default is true.
enable_heap_profiling = true
```
To disable memory profiling, set `enable_heap_profiling` to `false`.
### Memory profiling control
You can control heap profiling activation using the new HTTP APIs:

scripts/generate_certs.sh (new executable file, 41 lines)
View File

@@ -0,0 +1,41 @@
#!/usr/bin/env bash
set -euo pipefail
CERT_DIR="${1:-$(dirname "$0")/../tests-integration/fixtures/certs}"
DAYS="${2:-365}"
mkdir -p "${CERT_DIR}"
cd "${CERT_DIR}"
echo "Generating CA certificate..."
openssl req -new -x509 -days "${DAYS}" -nodes -text \
-out root.crt -keyout root.key \
-subj "/CN=GreptimeDBRootCA"
echo "Generating server certificate..."
openssl req -new -nodes -text \
-out server.csr -keyout server.key \
-subj "/CN=greptime"
openssl x509 -req -in server.csr -text -days "${DAYS}" \
-CA root.crt -CAkey root.key -CAcreateserial \
-out server.crt \
-extensions v3_req -extfile <(printf "[v3_req]\nsubjectAltName=DNS:localhost,IP:127.0.0.1")
echo "Generating client certificate..."
# Make sure the client certificate is for the greptimedb user
openssl req -new -nodes -text \
-out client.csr -keyout client.key \
-subj "/CN=greptimedb"
openssl x509 -req -in client.csr -CA root.crt -CAkey root.key -CAcreateserial \
-out client.crt -days 365 -extensions v3_req -extfile <(printf "[v3_req]\nsubjectAltName=DNS:localhost")
rm -f *.csr
echo "TLS certificates generated successfully in ${CERT_DIR}"
chmod 644 root.key
chmod 644 client.key
chmod 644 server.key

View File

@@ -19,8 +19,8 @@ use common_error::ext::BoxedError;
use common_meta::kv_backend::KvBackendRef;
use common_meta::kv_backend::chroot::ChrootKvBackend;
use common_meta::kv_backend::etcd::EtcdStore;
use meta_srv::bootstrap::create_etcd_client_with_tls;
use meta_srv::metasrv::BackendImpl;
use meta_srv::utils::etcd::create_etcd_client_with_tls;
use servers::tls::{TlsMode, TlsOption};
use crate::error::{EmptyStoreAddrsSnafu, UnsupportedMemoryBackendSnafu};
@@ -83,6 +83,20 @@ pub(crate) struct StoreConfig {
}
impl StoreConfig {
pub fn tls_config(&self) -> Option<TlsOption> {
if self.backend_tls_mode != TlsMode::Disable {
Some(TlsOption {
mode: self.backend_tls_mode.clone(),
cert_path: self.backend_tls_cert_path.clone(),
key_path: self.backend_tls_key_path.clone(),
ca_cert_path: self.backend_tls_ca_cert_path.clone(),
watch: self.backend_tls_watch,
})
} else {
None
}
}
/// Builds a [`KvBackendRef`] from the store configuration.
pub async fn build(&self) -> Result<KvBackendRef, BoxedError> {
let max_txn_ops = self.max_txn_ops;
@@ -92,17 +106,7 @@ impl StoreConfig {
} else {
let kvbackend = match self.backend {
BackendImpl::EtcdStore => {
let tls_config = if self.backend_tls_mode != TlsMode::Disable {
Some(TlsOption {
mode: self.backend_tls_mode.clone(),
cert_path: self.backend_tls_cert_path.clone(),
key_path: self.backend_tls_key_path.clone(),
ca_cert_path: self.backend_tls_ca_cert_path.clone(),
watch: self.backend_tls_watch,
})
} else {
None
};
let tls_config = self.tls_config();
let etcd_client = create_etcd_client_with_tls(store_addrs, tls_config.as_ref())
.await
.map_err(BoxedError::new)?;
@@ -111,9 +115,14 @@ impl StoreConfig {
#[cfg(feature = "pg_kvbackend")]
BackendImpl::PostgresStore => {
let table_name = &self.meta_table_name;
let pool = meta_srv::bootstrap::create_postgres_pool(store_addrs, None)
.await
.map_err(BoxedError::new)?;
let tls_config = self.tls_config();
let pool = meta_srv::utils::postgres::create_postgres_pool(
store_addrs,
None,
tls_config,
)
.await
.map_err(BoxedError::new)?;
let schema_name = self.meta_schema_name.as_deref();
Ok(common_meta::kv_backend::rds::PgStore::with_pg_pool(
pool,
@@ -127,9 +136,11 @@ impl StoreConfig {
#[cfg(feature = "mysql_kvbackend")]
BackendImpl::MysqlStore => {
let table_name = &self.meta_table_name;
let pool = meta_srv::bootstrap::create_mysql_pool(store_addrs)
.await
.map_err(BoxedError::new)?;
let tls_config = self.tls_config();
let pool =
meta_srv::utils::mysql::create_mysql_pool(store_addrs, tls_config.as_ref())
.await
.map_err(BoxedError::new)?;
Ok(common_meta::kv_backend::rds::MySqlStore::with_mysql_pool(
pool,
table_name,

View File

@@ -196,7 +196,10 @@ pub async fn stream_to_parquet(
concurrency: usize,
) -> Result<usize> {
let write_props = column_wise_config(
WriterProperties::builder().set_compression(Compression::ZSTD(ZstdLevel::default())),
WriterProperties::builder()
.set_compression(Compression::ZSTD(ZstdLevel::default()))
.set_statistics_truncate_length(None)
.set_column_index_truncate_length(None),
schema,
)
.build();

View File

@@ -12,23 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod add_region_follower;
mod flush_compact_region;
mod flush_compact_table;
mod migrate_region;
mod reconcile_catalog;
mod reconcile_database;
mod reconcile_table;
mod remove_region_follower;
use add_region_follower::AddRegionFollowerFunction;
use flush_compact_region::{CompactRegionFunction, FlushRegionFunction};
use flush_compact_table::{CompactTableFunction, FlushTableFunction};
use migrate_region::MigrateRegionFunction;
use reconcile_catalog::ReconcileCatalogFunction;
use reconcile_database::ReconcileDatabaseFunction;
use reconcile_table::ReconcileTableFunction;
use remove_region_follower::RemoveRegionFollowerFunction;
use crate::flush_flow::FlushFlowFunction;
use crate::function_registry::FunctionRegistry;
@@ -40,8 +36,6 @@ impl AdminFunction {
/// Register all admin functions to [`FunctionRegistry`].
pub fn register(registry: &FunctionRegistry) {
registry.register(MigrateRegionFunction::factory());
registry.register(AddRegionFollowerFunction::factory());
registry.register(RemoveRegionFollowerFunction::factory());
registry.register(FlushRegionFunction::factory());
registry.register(CompactRegionFunction::factory());
registry.register(FlushTableFunction::factory());

View File

@@ -1,155 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_macro::admin_fn;
use common_meta::rpc::procedure::AddRegionFollowerRequest;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
use snafu::ensure;
use crate::handlers::ProcedureServiceHandlerRef;
use crate::helper::cast_u64;
/// A function to add a follower to a region.
/// Only available in cluster mode.
///
/// - `add_region_follower(region_id, peer_id)`.
///
/// The parameters:
/// - `region_id`: the region id
/// - `peer_id`: the peer id
#[admin_fn(
name = AddRegionFollowerFunction,
display_name = add_region_follower,
sig_fn = signature,
ret = uint64
)]
pub(crate) async fn add_region_follower(
procedure_service_handler: &ProcedureServiceHandlerRef,
_ctx: &QueryContextRef,
params: &[ValueRef<'_>],
) -> Result<Value> {
ensure!(
params.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
params.len()
),
}
);
let Some(region_id) = cast_u64(&params[0])? else {
return UnsupportedInputDataTypeSnafu {
function: "add_region_follower",
datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
}
.fail();
};
let Some(peer_id) = cast_u64(&params[1])? else {
return UnsupportedInputDataTypeSnafu {
function: "add_region_follower",
datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
}
.fail();
};
procedure_service_handler
.add_region_follower(AddRegionFollowerRequest { region_id, peer_id })
.await?;
Ok(Value::from(0u64))
}
fn signature() -> Signature {
Signature::one_of(
vec![
// add_region_follower(region_id, peer)
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
],
Volatility::Immutable,
)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
#[test]
fn test_add_region_follower_misc() {
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("add_region_follower", f.name());
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
} if sigs.len() == 1));
}
#[tokio::test]
async fn test_add_region_follower() {
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![2]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
}
}

View File

@@ -1,155 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_macro::admin_fn;
use common_meta::rpc::procedure::RemoveRegionFollowerRequest;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
use snafu::ensure;
use crate::handlers::ProcedureServiceHandlerRef;
use crate::helper::cast_u64;
/// A function to remove a follower from a region.
//// Only available in cluster mode.
///
/// - `remove_region_follower(region_id, peer_id)`.
///
/// The parameters:
/// - `region_id`: the region id
/// - `peer_id`: the peer id
#[admin_fn(
name = RemoveRegionFollowerFunction,
display_name = remove_region_follower,
sig_fn = signature,
ret = uint64
)]
pub(crate) async fn remove_region_follower(
procedure_service_handler: &ProcedureServiceHandlerRef,
_ctx: &QueryContextRef,
params: &[ValueRef<'_>],
) -> Result<Value> {
ensure!(
params.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
params.len()
),
}
);
let Some(region_id) = cast_u64(&params[0])? else {
return UnsupportedInputDataTypeSnafu {
function: "add_region_follower",
datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
}
.fail();
};
let Some(peer_id) = cast_u64(&params[1])? else {
return UnsupportedInputDataTypeSnafu {
function: "add_region_follower",
datatypes: params.iter().map(|v| v.data_type()).collect::<Vec<_>>(),
}
.fail();
};
procedure_service_handler
.remove_region_follower(RemoveRegionFollowerRequest { region_id, peer_id })
.await?;
Ok(Value::from(0u64))
}
fn signature() -> Signature {
Signature::one_of(
vec![
// remove_region_follower(region_id, peer_id)
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
],
Volatility::Immutable,
)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
#[test]
fn test_remove_region_follower_misc() {
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
assert_eq!("remove_region_follower", f.name());
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
} if sigs.len() == 1));
}
#[tokio::test]
async fn test_remove_region_follower() {
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
}
}

View File

@@ -15,11 +15,15 @@
use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{ExecuteSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::Signature;
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::vectors::VectorRef;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::ResultExt;
use crate::state::FunctionState;
@@ -68,8 +72,26 @@ pub trait Function: fmt::Display + Sync + Send {
/// The signature of function.
fn signature(&self) -> Signature;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
// TODO(LFC): Remove default implementation once all UDFs have implemented this function.
let _ = args;
Err(datafusion_common::DataFusionError::NotImplemented(
"invoke_with_args".to_string(),
))
}
/// Evaluate the function, e.g. run/execute the function.
fn eval(&self, ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef>;
/// TODO(LFC): Remove `eval` when all UDFs are rewritten to `invoke_with_args`
fn eval(&self, _: &FunctionContext, _: &[VectorRef]) -> Result<VectorRef> {
Err(BoxedError::new(PlainError::new(
"unsupported".to_string(),
StatusCode::Unsupported,
)))
.context(ExecuteSnafu)
}
fn aliases(&self) -> &[String] {
&[]

View File

@@ -19,8 +19,7 @@ use async_trait::async_trait;
use catalog::CatalogManagerRef;
use common_base::AffectedRows;
use common_meta::rpc::procedure::{
AddRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
RemoveRegionFollowerRequest,
ManageRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
};
use common_query::Output;
use common_query::error::Result;
@@ -72,11 +71,8 @@ pub trait ProcedureServiceHandler: Send + Sync {
/// Query the procedure' state by its id
async fn query_procedure_state(&self, pid: &str) -> Result<ProcedureStateResponse>;
/// Add a region follower to a region.
async fn add_region_follower(&self, request: AddRegionFollowerRequest) -> Result<()>;
/// Remove a region follower from a region.
async fn remove_region_follower(&self, request: RemoveRegionFollowerRequest) -> Result<()>;
/// Manage a region follower to a region.
async fn manage_region_follower(&self, request: ManageRegionFollowerRequest) -> Result<()>;
/// Get the catalog manager
fn catalog_manager(&self) -> &CatalogManagerRef;

View File

@@ -14,14 +14,15 @@
use std::fmt;
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datatypes::vectors::{Helper, VectorRef};
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
/// A function adds an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateAddFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 2, have: {}",
columns.len()
),
}
);
let left = columns[0].to_arrow_array();
let right = columns[1].to_arrow_array();
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let result = numeric::add(&left, &right).context(ArrowComputeSnafu)?;
let arrow_type = result.data_type().clone();
Helper::try_into_vector(result).context(IntoVectorSnafu {
data_type: arrow_type,
})
Ok(ColumnarValue::Array(result))
}
}
@@ -90,12 +81,14 @@ impl fmt::Display for DateAddFunction {
mod tests {
use std::sync::Arc;
use datafusion_expr::{TypeSignature, Volatility};
use datatypes::arrow::datatypes::IntervalDayTime;
use datatypes::value::Value;
use datatypes::vectors::{
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
use arrow_schema::Field;
use datafusion::arrow::array::{
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
TimestampSecondArray,
};
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{TypeSignature, Volatility};
use super::{DateAddFunction, *};
@@ -142,25 +135,37 @@ mod tests {
];
let results = [Some(124), None, Some(45), None];
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<TimestampSecondType>();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Timestamp(ts) => {
assert_eq!(ts.value(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}
@@ -174,25 +179,37 @@ mod tests {
let intervals = vec![1, 2, 3, 1];
let results = [Some(154), None, Some(131), None];
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<Date32Type>();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Date(date) => {
assert_eq!(date.val(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}

View File

@@ -14,14 +14,15 @@
use std::fmt;
use common_query::error::{ArrowComputeSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::{ArrowComputeSnafu, Result};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::utils;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::compute::kernels::numeric;
use datatypes::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
use datatypes::vectors::{Helper, VectorRef};
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
/// A function subtracts an interval value to Timestamp, Date, and return the result.
@@ -58,25 +59,15 @@ impl Function for DateSubFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 2, have: {}",
columns.len()
),
}
);
let left = columns[0].to_arrow_array();
let right = columns[1].to_arrow_array();
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [left, right] = utils::take_function_args(self.name(), args)?;
let result = numeric::sub(&left, &right).context(ArrowComputeSnafu)?;
let arrow_type = result.data_type().clone();
Helper::try_into_vector(result).context(IntoVectorSnafu {
data_type: arrow_type,
})
Ok(ColumnarValue::Array(result))
}
}
@@ -90,12 +81,14 @@ impl fmt::Display for DateSubFunction {
mod tests {
use std::sync::Arc;
use datafusion_expr::{TypeSignature, Volatility};
use datatypes::arrow::datatypes::IntervalDayTime;
use datatypes::value::Value;
use datatypes::vectors::{
DateVector, IntervalDayTimeVector, IntervalYearMonthVector, TimestampSecondVector,
use arrow_schema::Field;
use datafusion::arrow::array::{
Array, AsArray, Date32Array, IntervalDayTimeArray, IntervalYearMonthArray,
TimestampSecondArray,
};
use datafusion::arrow::datatypes::{Date32Type, IntervalDayTime, TimestampSecondType};
use datafusion_common::config::ConfigOptions;
use datafusion_expr::{TypeSignature, Volatility};
use super::{DateSubFunction, *};
@@ -142,25 +135,37 @@ mod tests {
];
let results = [Some(122), None, Some(39), None];
let time_vector = TimestampSecondVector::from(times.clone());
let interval_vector = IntervalDayTimeVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(time_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(TimestampSecondArray::from(times.clone()))),
ColumnarValue::Array(Arc::new(IntervalDayTimeArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<TimestampSecondType>();
assert_eq!(4, vector.len());
for (i, _t) in times.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Timestamp(ts) => {
assert_eq!(ts.value(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}
@@ -180,25 +185,37 @@ mod tests {
let intervals = vec![1, 2, 3, 1];
let results = [Some(3659), None, Some(1168), None];
let date_vector = DateVector::from(dates.clone());
let interval_vector = IntervalYearMonthVector::from_vec(intervals);
let args: Vec<VectorRef> = vec![Arc::new(date_vector), Arc::new(interval_vector)];
let vector = f.eval(&FunctionContext::default(), &args).unwrap();
let args = vec![
ColumnarValue::Array(Arc::new(Date32Array::from(dates.clone()))),
ColumnarValue::Array(Arc::new(IntervalYearMonthArray::from(intervals))),
];
let vector = f
.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new(
"x",
DataType::Timestamp(TimeUnit::Second, None),
true,
)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let vector = vector.as_primitive::<Date32Type>();
assert_eq!(4, vector.len());
for (i, _t) in dates.iter().enumerate() {
let v = vector.get(i);
let result = results.get(i).unwrap();
if result.is_none() {
assert_eq!(Value::Null, v);
continue;
}
match v {
Value::Date(date) => {
assert_eq!(date.val(), result.unwrap());
}
_ => unreachable!(),
if let Some(x) = result {
assert!(vector.is_valid(i));
assert_eq!(vector.value(i), *x);
} else {
assert!(vector.is_null(i));
}
}
}

View File

@@ -17,62 +17,26 @@ use std::sync::Arc;
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_query::error::{self, InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::Field;
use common_query::error::{self, Result};
use datafusion::arrow::array::{Array, AsArray, ListBuilder, StringViewBuilder};
use datafusion::arrow::datatypes::{DataType, Field, Float64Type, UInt8Type};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, utils};
use datafusion_expr::type_coercion::aggregates::INTEGERS;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::{Scalar, ScalarVectorBuilder};
use datatypes::value::{ListValue, Value};
use datatypes::vectors::{ListVectorBuilder, MutableVector, StringVectorBuilder, VectorRef};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use geohash::Coord;
use snafu::{ResultExt, ensure};
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::scalars::geo::helpers;
macro_rules! ensure_resolution_usize {
($v: ident) => {
if !($v > 0 && $v <= 12) {
Err(BoxedError::new(PlainError::new(
format!("Invalid geohash resolution {}, expect value: [1, 12]", $v),
StatusCode::EngineExecuteQuery,
)))
.context(error::ExecuteSnafu)
} else {
Ok($v as usize)
}
};
}
fn try_into_resolution(v: Value) -> Result<usize> {
match v {
Value::Int8(v) => {
ensure_resolution_usize!(v)
}
Value::Int16(v) => {
ensure_resolution_usize!(v)
}
Value::Int32(v) => {
ensure_resolution_usize!(v)
}
Value::Int64(v) => {
ensure_resolution_usize!(v)
}
Value::UInt8(v) => {
ensure_resolution_usize!(v)
}
Value::UInt16(v) => {
ensure_resolution_usize!(v)
}
Value::UInt32(v) => {
ensure_resolution_usize!(v)
}
Value::UInt64(v) => {
ensure_resolution_usize!(v)
}
_ => unreachable!(),
fn ensure_resolution_usize(v: u8) -> datafusion_common::Result<usize> {
if v == 0 || v > 12 {
return Err(DataFusionError::Execution(format!(
"Invalid geohash resolution {v}, valid value range: [1, 12]"
)));
}
Ok(v as usize)
}
/// Function that return geohash string for a given geospatial coordinate.
@@ -109,31 +73,33 @@ impl Function for GeohashFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
let mut results = StringVectorBuilder::with_capacity(size);
let mut builder = StringViewBuilder::with_capacity(size);
for i in 0..size {
let lat = lat_vec.get(i).as_f64_lossy();
let lon = lon_vec.get(i).as_f64_lossy();
let r = try_into_resolution(resolution_vec.get(i))?;
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
let r = resolutions
.is_valid(i)
.then(|| ensure_resolution_usize(resolutions.value(i)))
.transpose()?;
let result = match (lat, lon) {
(Some(lat), Some(lon)) => {
let result = match (lat, lon, r) {
(Some(lat), Some(lon), Some(r)) => {
let coord = Coord { x: lon, y: lat };
let encoded = geohash::encode(coord, r)
.map_err(|e| {
@@ -148,10 +114,10 @@ impl Function for GeohashFunction {
_ => None,
};
results.push(result.as_deref());
builder.append_option(result);
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -176,8 +142,8 @@ impl Function for GeohashNeighboursFunction {
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::List(Arc::new(Field::new(
"x",
DataType::Utf8,
"item",
DataType::Utf8View,
false,
))))
}
@@ -199,32 +165,33 @@ impl Function for GeohashNeighboursFunction {
Signature::one_of(signatures, Volatility::Stable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect 3, provided : {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [lat_vec, lon_vec, resolutions] = utils::take_function_args(self.name(), args)?;
let lat_vec = &columns[0];
let lon_vec = &columns[1];
let resolution_vec = &columns[2];
let lat_vec = helpers::cast::<Float64Type>(&lat_vec)?;
let lat_vec = lat_vec.as_primitive::<Float64Type>();
let lon_vec = helpers::cast::<Float64Type>(&lon_vec)?;
let lon_vec = lon_vec.as_primitive::<Float64Type>();
let resolutions = helpers::cast::<UInt8Type>(&resolutions)?;
let resolutions = resolutions.as_primitive::<UInt8Type>();
let size = lat_vec.len();
let mut results =
ListVectorBuilder::with_type_capacity(ConcreteDataType::string_datatype(), size);
let mut builder = ListBuilder::new(StringViewBuilder::new());
for i in 0..size {
let lat = lat_vec.get(i).as_f64_lossy();
let lon = lon_vec.get(i).as_f64_lossy();
let r = try_into_resolution(resolution_vec.get(i))?;
let lat = lat_vec.is_valid(i).then(|| lat_vec.value(i));
let lon = lon_vec.is_valid(i).then(|| lon_vec.value(i));
let r = resolutions
.is_valid(i)
.then(|| ensure_resolution_usize(resolutions.value(i)))
.transpose()?;
let result = match (lat, lon) {
(Some(lat), Some(lon)) => {
match (lat, lon, r) {
(Some(lat), Some(lon), Some(r)) => {
let coord = Coord { x: lon, y: lat };
let encoded = geohash::encode(coord, r)
.map_err(|e| {
@@ -242,8 +209,8 @@ impl Function for GeohashNeighboursFunction {
))
})
.context(error::ExecuteSnafu)?;
Some(ListValue::new(
vec![
builder.append_value(
[
neighbours.n,
neighbours.nw,
neighbours.w,
@@ -254,22 +221,14 @@ impl Function for GeohashNeighboursFunction {
neighbours.ne,
]
.into_iter()
.map(Value::from)
.collect(),
ConcreteDataType::string_datatype(),
))
.map(Some),
);
}
_ => None,
_ => builder.append_null(),
};
if let Some(list_value) = result {
results.push(Some(list_value.as_scalar_ref()));
} else {
results.push(None);
}
}
Ok(results.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
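The two geohash functions above follow the migration shape used throughout this compare: the `eval(&FunctionContext, &[VectorRef])` entry point is dropped in favour of DataFusion's `invoke_with_args`. A minimal sketch of that shape, with illustrative names (not part of the diff):

use datafusion_common::utils;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};

fn invoke_two_arg_udf(
    name: &str,
    args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
    // Scalar arguments are broadcast to full-length arrays.
    let arrays = ColumnarValue::values_to_arrays(&args.args)?;
    // Destructure into the expected arity; a mismatch errors with the function name.
    let [_a, _b] = utils::take_function_args(name, arrays)?;
    // ...cast to concrete Arrow arrays, loop with validity checks, and return
    // `ColumnarValue::Array(Arc::new(builder.finish()))`, as the hunks above do.
    unimplemented!("per-function body")
}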

File diff suppressed because it is too large


@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use datafusion::arrow::array::{ArrayRef, ArrowPrimitiveType};
use datafusion::arrow::compute;
macro_rules! ensure_columns_len {
($columns:ident) => {
snafu::ensure!(
@@ -73,3 +76,15 @@ macro_rules! ensure_and_coerce {
}
pub(crate) use ensure_and_coerce;
pub(crate) fn cast<T: ArrowPrimitiveType>(array: &ArrayRef) -> datafusion_common::Result<ArrayRef> {
let x = compute::cast_with_options(
array.as_ref(),
&T::DATA_TYPE,
&compute::CastOptions {
safe: false,
..Default::default()
},
)?;
Ok(x)
}
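A sketch of the typical call site for this helper, matching its use in the geohash functions above (the `helpers::` path and the `Float64Type` choice mirror those hunks):

use datafusion::arrow::array::{ArrayRef, AsArray, Float64Array};
use datafusion::arrow::datatypes::Float64Type;

fn to_f64_array(col: &ArrayRef) -> datafusion_common::Result<Float64Array> {
    // `safe: false` in the helper means invalid conversions error out instead of becoming nulls.
    let col = helpers::cast::<Float64Type>(col)?;
    Ok(col.as_primitive::<Float64Type>().clone())
}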


@@ -16,23 +16,20 @@ use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use common_query::error::{
GeneralDataFusionSnafu, IntoVectorSnafu, InvalidFuncArgsSnafu, InvalidInputTypeSnafu, Result,
};
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion};
use datafusion::common::{DFSchema, Result as DfResult};
use datafusion::execution::SessionStateBuilder;
use datafusion::logical_expr::{self, Expr, Volatility};
use datafusion::logical_expr::{self, ColumnarValue, Expr, Volatility};
use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
use datafusion_expr::Signature;
use datafusion_common::{DataFusionError, utils};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::array::RecordBatch;
use datatypes::arrow::datatypes::{DataType, Field};
use datatypes::prelude::VectorRef;
use datatypes::vectors::BooleanVector;
use snafu::{OptionExt, ResultExt, ensure};
use store_api::storage::ConcreteDataType;
use snafu::{OptionExt, ensure};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
/// `matches` for full text search.
@@ -67,38 +64,36 @@ impl Function for MatchesFunction {
}
// TODO: read case-sensitive config
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DfResult<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [data_column, patterns] = utils::take_function_args(self.name(), args)?;
let data_column = &columns[0];
if data_column.is_empty() {
return Ok(Arc::new(BooleanVector::from(Vec::<bool>::with_capacity(0))));
return Ok(ColumnarValue::Array(Arc::new(BooleanArray::from(
Vec::<bool>::with_capacity(0),
))));
}
let pattern_vector = &columns[1]
.cast(&ConcreteDataType::string_datatype())
.context(InvalidInputTypeSnafu {
err_msg: "cannot cast `pattern` to string",
})?;
// Safety: both length and type are checked before
let pattern = pattern_vector.get(0).as_string().unwrap();
let pattern = match patterns.data_type() {
DataType::Utf8View => patterns.as_string_view().value(0),
DataType::Utf8 => patterns.as_string::<i32>().value(0),
DataType::LargeUtf8 => patterns.as_string::<i64>().value(0),
t => {
return Err(DataFusionError::Execution(format!(
"unsupported datatype {t}"
)));
}
};
self.eval(data_column, pattern)
}
}
impl MatchesFunction {
fn eval(&self, data: &VectorRef, pattern: String) -> Result<VectorRef> {
fn eval(&self, data_array: ArrayRef, pattern: &str) -> DfResult<ColumnarValue> {
let col_name = "data";
let parser_context = ParserContext::default();
let raw_ast = parser_context.parse_pattern(&pattern)?;
let raw_ast = parser_context.parse_pattern(pattern)?;
let ast = raw_ast.transform_ast()?;
let like_expr = ast.into_like_expr(col_name);
@@ -106,27 +101,17 @@ impl MatchesFunction {
let input_schema = Self::input_schema();
let session_state = SessionStateBuilder::new().with_default_features().build();
let planner = DefaultPhysicalPlanner::default();
let physical_expr = planner
.create_physical_expr(&like_expr, &input_schema, &session_state)
.context(GeneralDataFusionSnafu)?;
let physical_expr =
planner.create_physical_expr(&like_expr, &input_schema, &session_state)?;
let data_array = data.to_arrow_array();
let arrow_schema = Arc::new(input_schema.as_arrow().clone());
let input_record_batch = RecordBatch::try_new(arrow_schema, vec![data_array]).unwrap();
let num_rows = input_record_batch.num_rows();
let result = physical_expr
.evaluate(&input_record_batch)
.context(GeneralDataFusionSnafu)?;
let result_array = result
.into_array(num_rows)
.context(GeneralDataFusionSnafu)?;
let result_vector =
BooleanVector::try_from_arrow_array(result_array).context(IntoVectorSnafu {
data_type: DataType::Boolean,
})?;
let result = physical_expr.evaluate(&input_record_batch)?;
let result_array = result.into_array(num_rows)?;
Ok(Arc::new(result_vector))
Ok(ColumnarValue::Array(Arc::new(result_array)))
}
fn input_schema() -> DFSchema {
@@ -210,14 +195,12 @@ impl PatternAst {
/// Transform this AST with preset rules to make it correct.
fn transform_ast(self) -> Result<Self> {
self.transform_up(Self::collapse_binary_branch_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)?
.transform_up(Self::eliminate_optional_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)?
.transform_down(Self::eliminate_single_child_fn)
.context(GeneralDataFusionSnafu)
.map(|data| data.data)
.map_err(Into::into)
}
/// Collapse binary branch with the same operator. I.e., this transformer
@@ -842,7 +825,9 @@ impl Tokenizer {
#[cfg(test)]
mod test {
use datatypes::vectors::StringVector;
use datafusion::arrow::array::StringArray;
use datafusion_common::ScalarValue;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -1309,7 +1294,7 @@ mod test {
"The quick brown fox jumps over dog",
"The quick brown fox jumps over the dog",
];
let input_vector: VectorRef = Arc::new(StringVector::from(input_data));
let col: ArrayRef = Arc::new(StringArray::from(input_data));
let cases = [
// basic cases
("quick", vec![true, false, true, true, true, true, true]),
@@ -1400,9 +1385,22 @@ mod test {
let f = MatchesFunction;
for (pattern, expected) in cases {
let actual: VectorRef = f.eval(&input_vector, pattern.to_string()).unwrap();
let expected: VectorRef = Arc::new(BooleanVector::from(expected)) as _;
assert_eq!(expected, actual, "{pattern}");
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(col.clone()),
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(pattern.to_string()))),
],
arg_fields: vec![],
number_rows: col.len(),
return_field: Arc::new(Field::new("x", col.data_type().clone(), true)),
config_options: Arc::new(ConfigOptions::new()),
};
let actual = f
.invoke_with_args(args)
.and_then(|x| x.to_array(col.len()))
.unwrap();
let expected: ArrayRef = Arc::new(BooleanArray::from(expected));
assert_eq!(expected.as_ref(), actual.as_ref(), "{pattern}");
}
}
}


@@ -19,15 +19,13 @@ mod rate;
use std::fmt;
pub use clamp::{ClampFunction, ClampMaxFunction, ClampMinFunction};
use common_query::error::{GeneralDataFusionSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::DataFusionError;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::VectorRef;
pub use rate::RateFunction;
use snafu::ResultExt;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::function_registry::FunctionRegistry;
use crate::scalars::math::modulo::ModuloFunction;
@@ -68,7 +66,7 @@ impl Function for RangeFunction {
.ok_or(DataFusionError::Internal(
"No expr found in range_fn".into(),
))
.context(GeneralDataFusionSnafu)
.map_err(Into::into)
}
/// `range_fn` will never be used. As long as a legal signature is returned, the specific content of the signature does not matter.
@@ -76,11 +74,4 @@ impl Function for RangeFunction {
fn signature(&self) -> Signature {
Signature::variadic_any(Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
Err(DataFusionError::Internal(
"range_fn just a empty function used in range select, It should not be eval!".into(),
))
.context(GeneralDataFusionSnafu)
}
}


@@ -15,54 +15,21 @@
use std::fmt::{self, Display};
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::array::{ArrayIter, PrimitiveArray};
use common_query::error::Result;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, PrimitiveArray};
use datafusion::arrow::datatypes::DataType as ArrowDataType;
use datafusion::logical_expr::Volatility;
use datafusion_expr::Signature;
use datafusion::logical_expr::{ColumnarValue, Volatility};
use datafusion_common::{DataFusionError, ScalarValue, utils};
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datatypes::data_type::DataType;
use datatypes::prelude::VectorRef;
use datatypes::types::LogicalPrimitiveType;
use datatypes::value::TryAsPrimitive;
use datatypes::vectors::PrimitiveVector;
use datatypes::with_match_primitive_type_id;
use snafu::{OptionExt, ensure};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
#[derive(Clone, Debug, Default)]
pub struct ClampFunction;
const CLAMP_NAME: &str = "clamp";
/// Ensure the vector is constant and not empty (i.e., all values are identical)
fn ensure_constant_vector(vector: &VectorRef) -> Result<()> {
ensure!(
!vector.is_empty(),
InvalidFuncArgsSnafu {
err_msg: "Expect at least one value",
}
);
if vector.is_const() {
return Ok(());
}
let first = vector.get_ref(0);
for i in 1..vector.len() {
let v = vector.get_ref(i);
if first != v {
return InvalidFuncArgsSnafu {
err_msg: "All values in min/max argument must be identical",
}
.fail();
}
}
Ok(())
}
impl Function for ClampFunction {
fn name(&self) -> &str {
CLAMP_NAME
@@ -78,76 +45,12 @@ impl Function for ClampFunction {
Signature::uniform(3, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 3, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type()
&& columns[1].data_type() == columns[2].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}, {}",
columns[0].data_type(),
columns[1].data_type(),
columns[2].data_type()
),
}
);
ensure_constant_vector(&columns[1])?;
ensure_constant_vector(&columns[2])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg should not be none",
}
})?;
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[2].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The third arg should not be none",
}
})?;
// ensure min <= max
ensure!(
min <= max,
InvalidFuncArgsSnafu {
err_msg: format!(
"The second arg should be less than or equal to the third arg, have: {:?}, {:?}",
columns[1], columns[2]
),
}
);
clamp_impl::<$S, true, true>(input, min, max)
},{
unreachable!()
})
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, min, max] = utils::take_function_args(self.name(), args.args)?;
clamp_impl(col, min, max)
}
}
@@ -157,25 +60,155 @@ impl Display for ClampFunction {
}
}
fn clamp_impl<T: LogicalPrimitiveType, const CLAMP_MIN: bool, const CLAMP_MAX: bool>(
input: &PrimitiveArray<T::ArrowPrimitive>,
min: T::Native,
max: T::Native,
) -> Result<VectorRef> {
let iter = ArrayIter::new(input);
let result = iter.map(|x| {
x.map(|x| {
if CLAMP_MIN && x < min {
min
} else if CLAMP_MAX && x > max {
max
} else {
x
fn clamp_impl(
col: ColumnarValue,
min: ColumnarValue,
max: ColumnarValue,
) -> datafusion_common::Result<ColumnarValue> {
if col.data_type() != min.data_type() || min.data_type() != max.data_type() {
return Err(DataFusionError::Execution(format!(
"argument data types mismatch: {}, {}, {}",
col.data_type(),
min.data_type(),
max.data_type(),
)));
}
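// Dispatch on the runtime Arrow numeric type: the matching arm expands `$body` for that
// concrete `ArrowPrimitiveType`, and non-numeric inputs fall through to an execution error.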
macro_rules! with_match_numerics_types {
($data_type:expr, | $_:tt $T:ident | $body:tt) => {{
macro_rules! __with_ty__ {
( $_ $T:ident ) => {
$body
};
}
})
});
let result = PrimitiveArray::<T::ArrowPrimitive>::from_iter(result);
Ok(Arc::new(PrimitiveVector::<T>::from(result)))
use datafusion::arrow::datatypes::{
Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, UInt8Type,
UInt16Type, UInt32Type, UInt64Type,
};
match $data_type {
ArrowDataType::Int8 => Ok(__with_ty__! { Int8Type }),
ArrowDataType::Int16 => Ok(__with_ty__! { Int16Type }),
ArrowDataType::Int32 => Ok(__with_ty__! { Int32Type }),
ArrowDataType::Int64 => Ok(__with_ty__! { Int64Type }),
ArrowDataType::UInt8 => Ok(__with_ty__! { UInt8Type }),
ArrowDataType::UInt16 => Ok(__with_ty__! { UInt16Type }),
ArrowDataType::UInt32 => Ok(__with_ty__! { UInt32Type }),
ArrowDataType::UInt64 => Ok(__with_ty__! { UInt64Type }),
ArrowDataType::Float32 => Ok(__with_ty__! { Float32Type }),
ArrowDataType::Float64 => Ok(__with_ty__! { Float64Type }),
_ => Err(DataFusionError::Execution(format!(
"unsupported numeric data type: '{}'",
$data_type
))),
}
}};
}
macro_rules! clamp {
($v: ident, $min: ident, $max: ident) => {
if $v < $min {
$min
} else if $v > $max {
$max
} else {
$v
}
};
}
match (col, min, max) {
(ColumnarValue::Scalar(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Ok(ColumnarValue::Scalar(clamp!(col, min, max)))
}
(ColumnarValue::Array(col), ColumnarValue::Array(min), ColumnarValue::Array(max)) => {
if col.len() != min.len() || col.len() != max.len() {
return Err(DataFusionError::Internal(
"arguments not of same length".to_string(),
));
}
let result = with_match_numerics_types!(
col.data_type(),
|$S| {
let col = col.as_primitive::<$S>();
let min = min.as_primitive::<$S>();
let max = max.as_primitive::<$S>();
Arc::new(PrimitiveArray::<$S>::from(
(0..col.len())
.map(|i| {
let v = col.is_valid(i).then(|| col.value(i));
// Index safety: checked above, all have same length.
let min = min.is_valid(i).then(|| min.value(i));
let max = max.is_valid(i).then(|| max.value(i));
Ok(match (v, min, max) {
(Some(v), Some(min), Some(max)) => {
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Some(clamp!(v, min, max))
},
_ => None,
})
})
.collect::<datafusion_common::Result<Vec<_>>>()?,
)
) as ArrayRef
}
)?;
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Array(col), ColumnarValue::Scalar(min), ColumnarValue::Scalar(max)) => {
if min.is_null() || max.is_null() {
return Err(DataFusionError::Execution(
"argument 'min' or 'max' is null".to_string(),
));
}
let min = min.to_array()?;
let max = max.to_array()?;
let result = with_match_numerics_types!(
col.data_type(),
|$S| {
let col = col.as_primitive::<$S>();
// Index safety: checked above; both are non-null.
let min = min.as_primitive::<$S>().value(0);
let max = max.as_primitive::<$S>().value(0);
if min > max {
return Err(DataFusionError::Execution(format!(
"min '{}' > max '{}'",
min, max
)));
}
Arc::new(PrimitiveArray::<$S>::from(
(0..col.len())
.map(|x| {
col.is_valid(x).then(|| {
let v = col.value(x);
clamp!(v, min, max)
})
})
.collect::<Vec<_>>(),
)
) as ArrayRef
}
)?;
Ok(ColumnarValue::Array(result))
}
_ => Err(DataFusionError::Internal(
"argument column types mismatch".to_string(),
)),
}
}
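A small sketch of the all-scalar path through `clamp_impl`, with hypothetical values (the array paths are exercised by the tests further down in this file):

use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;

let clamped = clamp_impl(
    ColumnarValue::Scalar(ScalarValue::Int64(Some(7))),
    ColumnarValue::Scalar(ScalarValue::Int64(Some(0))),
    ColumnarValue::Scalar(ScalarValue::Int64(Some(5))),
)
.unwrap();
// 7 clamped into [0, 5] yields ColumnarValue::Scalar(ScalarValue::Int64(Some(5))).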
#[derive(Clone, Debug, Default)]
@@ -197,59 +230,19 @@ impl Function for ClampMinFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}",
columns[0].data_type(),
columns[1].data_type()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, min] = utils::take_function_args(self.name(), args.args)?;
ensure_constant_vector(&columns[1])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let min = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg (min) should not be none",
}
})?;
// For clamp_min, max is effectively infinity, so we don't use it in the clamp_impl logic.
// We pass a default/dummy value for max.
let max_dummy = <$S as LogicalPrimitiveType>::Native::default();
clamp_impl::<$S, true, false>(input, min, max_dummy)
},{
unreachable!()
})
let Some(max) = ScalarValue::max(&min.data_type()) else {
return Err(DataFusionError::Internal(format!(
"cannot find a max value for numeric data type {}",
min.data_type()
)));
};
clamp_impl(col, min, ColumnarValue::Scalar(max))
}
}
@@ -278,59 +271,19 @@ impl Function for ClampMaxFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly 2, have: {}",
columns.len()
),
}
);
ensure!(
columns[0].data_type().is_numeric(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The first arg's type is not numeric, have: {}",
columns[0].data_type()
),
}
);
ensure!(
columns[0].data_type() == columns[1].data_type(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Arguments don't have identical types: {}, {}",
columns[0].data_type(),
columns[1].data_type()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let [col, max] = utils::take_function_args(self.name(), args.args)?;
ensure_constant_vector(&columns[1])?;
with_match_primitive_type_id!(columns[0].data_type().logical_type_id(), |$S| {
let input_array = columns[0].to_arrow_array();
let input = input_array
.as_any()
.downcast_ref::<PrimitiveArray<<$S as LogicalPrimitiveType>::ArrowPrimitive>>()
.unwrap();
let max = TryAsPrimitive::<$S>::try_as_primitive(&columns[1].get(0))
.with_context(|| {
InvalidFuncArgsSnafu {
err_msg: "The second arg (max) should not be none",
}
})?;
// For clamp_max, min is effectively -infinity, so we don't use it in the clamp_impl logic.
// We pass a default/dummy value for min.
let min_dummy = <$S as LogicalPrimitiveType>::Native::default();
clamp_impl::<$S, false, true>(input, min_dummy, max)
},{
unreachable!()
})
let Some(min) = ScalarValue::min(&max.data_type()) else {
return Err(DataFusionError::Internal(format!(
"cannot find a min value for numeric data type {}",
max.data_type()
)));
};
clamp_impl(col, ColumnarValue::Scalar(min), max)
}
}
@@ -345,55 +298,80 @@ mod test {
use std::sync::Arc;
use datatypes::prelude::ScalarVector;
use datatypes::vectors::{
ConstantVector, Float64Vector, Int64Vector, StringVector, UInt64Vector,
};
use arrow_schema::Field;
use datafusion_common::config::ConfigOptions;
use datatypes::arrow::array::{ArrayRef, Float64Array, Int64Array, UInt64Array};
use datatypes::arrow_array::StringArray;
use super::*;
use crate::function::FunctionContext;
macro_rules! impl_test_eval {
($func: ty) => {
impl $func {
fn test_eval(
&self,
args: Vec<ColumnarValue>,
number_rows: usize,
) -> datafusion_common::Result<ArrayRef> {
let input_type = args[0].data_type();
self.invoke_with_args(ScalarFunctionArgs {
args,
arg_fields: vec![],
number_rows,
return_field: Arc::new(Field::new("x", input_type, false)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]).map_err(Into::into))
.map(|mut a| a.remove(0))
}
}
};
}
impl_test_eval!(ClampFunction);
impl_test_eval!(ClampMinFunction);
impl_test_eval!(ClampMaxFunction);
#[test]
fn clamp_i64() {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
-1,
10,
-1i64,
10i64,
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
),
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
0,
0,
0i64,
0i64,
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
-2,
1,
-2i64,
1i64,
vec![Some(-2), None, Some(-1), None, None, Some(1)],
),
(
vec![None, None, None, None, None],
0,
1,
0i64,
1i64,
vec![None, None, None, None, None],
),
];
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -402,42 +380,41 @@ mod test {
let inputs = [
(
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
1,
3,
1u64,
3u64,
vec![Some(1), Some(1), Some(2), Some(3), Some(3), Some(3)],
),
(
vec![Some(0), Some(1), Some(2), Some(3), Some(4), Some(5)],
0,
0,
0u64,
0u64,
vec![Some(0), Some(0), Some(0), Some(0), Some(0), Some(0)],
),
(
vec![Some(0), None, Some(2), None, None, Some(5)],
1,
3,
1u64,
3u64,
vec![Some(1), None, Some(2), None, None, Some(3)],
),
(
vec![None, None, None, None, None],
0,
1,
0u64,
1u64,
vec![None, None, None, None, None],
),
];
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(UInt64Vector::from(in_data)) as _,
Arc::new(UInt64Vector::from_vec(vec![min])) as _,
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(UInt64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(UInt64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -472,38 +449,18 @@ mod test {
let func = ClampFunction;
for (in_data, min, max, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
#[test]
fn clamp_const_i32() {
let input = vec![Some(5)];
let min = 2;
let max = 4;
let func = ClampFunction;
let args = [
Arc::new(ConstantVector::new(Arc::new(Int64Vector::from(input)), 1)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(vec![Some(4)]));
assert_eq!(expected, result);
}
#[test]
fn clamp_invalid_min_max() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
@@ -511,28 +468,30 @@ mod test {
let max = -1.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
#[test]
fn clamp_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let min = -1;
let max = 10;
let min = -1i64;
let max = 10u64;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
Arc::new(UInt64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -543,12 +502,13 @@ mod test {
let max = 1.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min, max])) as _,
Arc::new(Float64Vector::from_vec(vec![max, min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Array(Arc::new(Float64Array::from(vec![min, max]))),
ColumnarValue::Array(Arc::new(Float64Array::from(vec![max, min]))),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -558,11 +518,12 @@ mod test {
let min = -10.0;
let func = ClampFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -571,12 +532,13 @@ mod test {
let input = vec![Some("foo"), Some("foo"), Some("foo"), Some("foo")];
let func = ClampFunction;
let args = [
Arc::new(StringVector::from(input)) as _,
Arc::new(StringVector::from_vec(vec!["bar"])) as _,
Arc::new(StringVector::from_vec(vec!["baz"])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(StringArray::from(input))),
ColumnarValue::Scalar("bar".into()),
ColumnarValue::Scalar("baz".into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
@@ -585,27 +547,26 @@ mod test {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
-1,
-1i64,
vec![Some(-1), Some(-1), Some(-1), Some(0), Some(1), Some(2)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
-2,
-2i64,
vec![Some(-2), None, Some(-1), None, None, Some(2)],
),
];
let func = ClampMinFunction;
for (in_data, min, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -614,27 +575,26 @@ mod test {
let inputs = [
(
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(2)],
1,
1i64,
vec![Some(-3), Some(-2), Some(-1), Some(0), Some(1), Some(1)],
),
(
vec![Some(-3), None, Some(-1), None, None, Some(2)],
0,
0i64,
vec![Some(-3), None, Some(-1), None, None, Some(0)],
),
];
let func = ClampMaxFunction;
for (in_data, max, expected) in inputs {
let args = [
Arc::new(Int64Vector::from(in_data)) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Int64Array::from(in_data))),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Int64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Int64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -648,15 +608,14 @@ mod test {
let func = ClampMinFunction;
for (in_data, min, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![min])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(min.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
@@ -670,43 +629,44 @@ mod test {
let func = ClampMaxFunction;
for (in_data, max, expected) in inputs {
let args = [
Arc::new(Float64Vector::from(in_data)) as _,
Arc::new(Float64Vector::from_vec(vec![max])) as _,
let number_rows = in_data.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(in_data))),
ColumnarValue::Scalar(max.into()),
];
let result = func
.eval(&FunctionContext::default(), args.as_slice())
.unwrap();
let expected: VectorRef = Arc::new(Float64Vector::from(expected));
assert_eq!(expected, result);
let result = func.test_eval(args, number_rows).unwrap();
let expected: ArrayRef = Arc::new(Float64Array::from(expected));
assert_eq!(expected.as_ref(), result.as_ref());
}
}
#[test]
fn clamp_min_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let min = -1;
let min = -1i64;
let func = ClampMinFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![min])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(min.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
#[test]
fn clamp_max_type_not_match() {
let input = vec![Some(-3.0), Some(-2.0), Some(-1.0), Some(0.0), Some(1.0)];
let max = 1;
let max = 1i64;
let func = ClampMaxFunction;
let args = [
Arc::new(Float64Vector::from(input)) as _,
Arc::new(Int64Vector::from_vec(vec![max])) as _,
let number_rows = input.len();
let args = vec![
ColumnarValue::Array(Arc::new(Float64Array::from(input))),
ColumnarValue::Scalar(max.into()),
];
let result = func.eval(&FunctionContext::default(), args.as_slice());
let result = func.test_eval(args, number_rows);
assert!(result.is_err());
}
}


@@ -65,6 +65,14 @@ impl ScalarUDFImpl for ScalarUdf {
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
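// Prefer the DataFusion-style `invoke_with_args`; only fall back to the legacy
// column-based path below when the function reports `NotImplemented`.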
let result = self.function.invoke_with_args(args.clone());
if !matches!(
result,
Err(datafusion_common::DataFusionError::NotImplemented(_))
) {
return result;
}
let columns = args
.args
.iter()


@@ -28,7 +28,14 @@ mod vector_norm;
mod vector_sub;
mod vector_subvector;
use std::borrow::Cow;
use datafusion_common::{DataFusionError, Result, ScalarValue, utils};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
use crate::function_registry::FunctionRegistry;
use crate::scalars::vector::impl_conv::as_veclit;
pub(crate) struct VectorFunction;
impl VectorFunction {
@@ -59,3 +66,155 @@ impl VectorFunction {
registry.register_scalar(elem_product::ElemProductFunction);
}
}
// Use a macro instead of a function to "return" the reference to `ScalarValue` in the
// `ColumnarValue::Array` match arm.
macro_rules! try_get_scalar_value {
($col: ident, $i: ident) => {
match $col {
datafusion::logical_expr::ColumnarValue::Array(a) => {
&datafusion_common::ScalarValue::try_from_array(a.as_ref(), $i)?
}
datafusion::logical_expr::ColumnarValue::Scalar(v) => v,
}
};
}
pub(crate) fn ensure_same_length(values: &[&ColumnarValue]) -> Result<usize> {
if values.is_empty() {
return Ok(0);
}
let mut array_len = None;
for v in values {
array_len = match (v, array_len) {
(ColumnarValue::Array(a), None) => Some(a.len()),
(ColumnarValue::Array(a), Some(array_len)) => {
if array_len == a.len() {
Some(array_len)
} else {
return Err(DataFusionError::Internal(format!(
"Arguments has mixed length. Expected length: {array_len}, found length: {}",
a.len()
)));
}
}
(ColumnarValue::Scalar(_), array_len) => array_len,
}
}
// If `array_len` is `None`, only scalars were provided; treat each as a one-element array.
let array_len = array_len.unwrap_or(1);
Ok(array_len)
}
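A brief illustration of the length rules, with hypothetical values:

use std::sync::Arc;
use datafusion::arrow::array::Int64Array;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;

let arr = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1, 2, 3])));
let scalar = ColumnarValue::Scalar(ScalarValue::Int64(Some(10)));
assert_eq!(ensure_same_length(&[&arr, &scalar]).unwrap(), 3); // arrays fix the row count
assert_eq!(ensure_same_length(&[&scalar, &scalar]).unwrap(), 1); // scalars alone count as one row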
struct VectorCalculator<'a, F> {
name: &'a str,
func: F,
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
let result = (self.func)(v0, v1)?;
return Ok(ColumnarValue::Scalar(result));
}
let len = ensure_same_length(&[arg0, arg1])?;
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v0 = try_get_scalar_value!(arg0, i);
let v1 = try_get_scalar_value!(arg1, i);
results.push((self.func)(v0, v1)?);
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&Option<Cow<[f32]>>, &Option<Cow<[f32]>>) -> Result<ScalarValue>,
{
fn invoke_with_vectors(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0, arg1] = utils::take_function_args(self.name, &args.args)?;
if let (ColumnarValue::Scalar(v0), ColumnarValue::Scalar(v1)) = (arg0, arg1) {
let v0 = as_veclit(v0)?;
let v1 = as_veclit(v1)?;
let result = (self.func)(&v0, &v1)?;
return Ok(ColumnarValue::Scalar(result));
}
let len = ensure_same_length(&[arg0, arg1])?;
let mut results = Vec::with_capacity(len);
match (arg0, arg1) {
(ColumnarValue::Scalar(v0), ColumnarValue::Array(a1)) => {
let v0 = as_veclit(v0)?;
for i in 0..len {
let v1 = ScalarValue::try_from_array(a1, i)?;
let v1 = as_veclit(&v1)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Array(a0), ColumnarValue::Scalar(v1)) => {
let v1 = as_veclit(v1)?;
for i in 0..len {
let v0 = ScalarValue::try_from_array(a0, i)?;
let v0 = as_veclit(&v0)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Array(a0), ColumnarValue::Array(a1)) => {
for i in 0..len {
let v0 = ScalarValue::try_from_array(a0, i)?;
let v0 = as_veclit(&v0)?;
let v1 = ScalarValue::try_from_array(a1, i)?;
let v1 = as_veclit(&v1)?;
results.push((self.func)(&v0, &v1)?);
}
}
(ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => {
// unreachable because this arm has been separately dealt with above
unreachable!()
}
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}
impl<F> VectorCalculator<'_, F>
where
F: Fn(&ScalarValue) -> Result<ScalarValue>,
{
fn invoke_with_single_argument(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg0] = utils::take_function_args(self.name, &args.args)?;
let arg0 = match arg0 {
ColumnarValue::Scalar(v) => {
let result = (self.func)(v)?;
return Ok(ColumnarValue::Scalar(result));
}
ColumnarValue::Array(a) => a,
};
let len = arg0.len();
let mut results = Vec::with_capacity(len);
for i in 0..len {
let v = ScalarValue::try_from_array(arg0, i)?;
results.push((self.func)(&v)?);
}
let results = ScalarValue::iter_to_array(results.into_iter())?;
Ok(ColumnarValue::Array(results))
}
}
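For orientation, this is how a two-argument UDF body hands off to the calculator; `my_pairwise_op` is a hypothetical closure, and the distance functions below use the `invoke_with_vectors` variant the same way:

let calculator = VectorCalculator {
    name: self.name(),
    // a closure of type Fn(&ScalarValue, &ScalarValue) -> datafusion_common::Result<ScalarValue>
    func: my_pairwise_op,
};
calculator.invoke_with_args(args)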


@@ -17,7 +17,7 @@ use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::type_coercion::aggregates::BINARYS;
use datafusion_expr::{Signature, Volatility};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::types::vector_type_value_to_string;
use datatypes::value::Value;
@@ -41,7 +41,13 @@ impl Function for VectorToStringFunction {
}
fn signature(&self) -> Signature {
Signature::uniform(1, BINARYS.to_vec(), Volatility::Immutable)
Signature::one_of(
vec![
TypeSignature::Uniform(1, vec![DataType::BinaryView]),
TypeSignature::Uniform(1, BINARYS.to_vec()),
],
Volatility::Immutable,
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {

View File

@@ -19,20 +19,17 @@ mod l2sq;
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
macro_rules! define_distance_function {
($StructName:ident, $display_name:expr, $similarity_method:path) => {
/// A function that calculates the distance between two vectors.
#[derive(Debug, Clone, Default)]
@@ -54,59 +51,34 @@ macro_rules! define_distance_function {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
if v0.len() != v1.len() {
return Err(datafusion_common::DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let size = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(size);
if size == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..size {
let vec0 = match arg0_const.as_ref() {
Some(a) => Some(Cow::Borrowed(a.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let vec1 = match arg1_const.as_ref() {
Some(b) => Some(Cow::Borrowed(b.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(vec0), Some(vec1)) = (vec0, vec1) {
ensure!(
vec0.len() == vec1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match to calculate distance, have: {} vs {}",
vec0.len(),
vec1.len()
),
}
);
// Checked if the length of the vectors match
let d = $similarity_method(vec0.as_ref(), vec1.as_ref());
result.push(Some(d));
let d = $similarity_method(v0, v1);
Some(d)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::Float32(result))
};
return Ok(result.to_vector());
let calculator = $crate::scalars::vector::VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -115,7 +87,7 @@ macro_rules! define_distance_function {
write!(f, "{}", $display_name.to_ascii_uppercase())
}
}
}
};
}
define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos);
@@ -126,10 +98,29 @@ define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot);
mod tests {
use std::sync::Arc;
use datatypes::vectors::{BinaryVector, ConstantVector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringViewArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
fn test_invoke(func: &dyn Function, args: &[ArrayRef]) -> datafusion_common::Result<ArrayRef> {
let number_rows = args[0].len();
let args = ScalarFunctionArgs {
args: args
.iter()
.map(|x| ColumnarValue::Array(x.clone()))
.collect::<Vec<_>>(),
arg_fields: vec![],
number_rows,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
func.invoke_with_args(args)
.and_then(|x| x.to_array(number_rows))
}
#[test]
fn test_distance_string_string() {
let funcs = [
@@ -139,36 +130,34 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
Some("[0.0, 1.0]"),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -181,37 +170,35 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(BinaryVector::from(vec![
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
None,
Some(vec![0, 0, 128, 63, 0, 0, 0, 0]),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -224,115 +211,35 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec![
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
]));
let result = func
.eval(&FunctionContext::default(), &[vec1.clone(), vec2.clone()])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec1.clone(), vec2.clone()]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
let result = func
.eval(&FunctionContext::default(), &[vec2, vec1])
.unwrap();
let result = test_invoke(func.as_ref(), &[vec2, vec1]).unwrap();
let result = result.as_primitive::<Float32Type>();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(result.get(3).is_null());
}
}
#[test]
fn test_distance_const_string() {
let funcs = [
Box::new(CosDistanceFunction {}) as Box<dyn Function>,
Box::new(L2SqDistanceFunction {}) as Box<dyn Function>,
Box::new(DotProductFunction {}) as Box<dyn Function>,
];
for func in funcs {
let const_str = Arc::new(ConstantVector::new(
Arc::new(StringVector::from(vec!["[0.0, 1.0]"])),
4,
));
let vec1 = Arc::new(StringVector::from(vec![
Some("[0.0, 1.0]"),
Some("[1.0, 0.0]"),
None,
Some("[1.0, 0.0]"),
])) as VectorRef;
let vec2 = Arc::new(BinaryVector::from(vec![
// [0.0, 1.0]
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
Some(vec![0, 0, 0, 0, 0, 0, 128, 63]),
None,
])) as VectorRef;
let result = func
.eval(
&FunctionContext::default(),
&[const_str.clone(), vec1.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[vec1.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(result.get(2).is_null());
assert!(!result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[const_str.clone(), vec2.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
let result = func
.eval(
&FunctionContext::default(),
&[vec2.clone(), const_str.clone()],
)
.unwrap();
assert!(!result.get(0).is_null());
assert!(!result.get(1).is_null());
assert!(!result.get(2).is_null());
assert!(result.get(3).is_null());
assert!(!result.is_null(0));
assert!(!result.is_null(1));
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}
@@ -345,15 +252,16 @@ mod tests {
];
for func in funcs {
let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef;
let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef;
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
let vec1: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0]"]));
let vec2: ArrayRef = Arc::new(StringViewArray::from(vec!["[1.0, 1.0]"]));
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
assert!(result.is_err());
let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef;
let vec2 =
Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef;
let result = func.eval(&FunctionContext::default(), &[vec1, vec2]);
let vec1: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![0, 0, 128, 63]]));
let vec2: ArrayRef = Arc::new(BinaryArray::from_iter_values(vec![vec![
0, 0, 128, 63, 0, 0, 0, 64,
]]));
let result = test_invoke(func.as_ref(), &[vec1, vec2]);
assert!(result.is_err());
}
}

View File

@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};
const NAME: &str = "vec_elem_product";
@@ -64,43 +62,21 @@ impl Function for ElemProductFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = impl_conv::as_veclit(v0)?
.map(|v0| DVectorView::from_slice(&v0, v0.len()).product());
Ok(ScalarValue::Float32(v0))
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).product()));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -114,27 +90,39 @@ impl Display for ElemProductFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_elem_product() {
let func = ElemProductFunction;
let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();
let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(120.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 120.0);
assert!(result.is_null(2));
}
}


@@ -12,20 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::{VectorCalculator, impl_conv};
const NAME: &str = "vec_elem_sum";
@@ -51,43 +49,21 @@ impl Function for ElemSumFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 =
impl_conv::as_veclit(v0)?.map(|v0| DVectorView::from_slice(&v0, v0.len()).sum());
Ok(ScalarValue::Float32(v0))
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(DVectorView::from_slice(&arg0, arg0.len()).sum()));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -101,27 +77,40 @@ impl Display for ElemSumFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow::array::StringViewArray;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_elem_sum() {
let func = ElemSumFunction;
let input0 = Arc::new(StringVector::from(vec![
let input = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let result = func
.invoke_with_args(ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input.clone())],
arg_fields: vec![],
number_rows: input.len(),
return_field: Arc::new(Field::new("x", DataType::Float32, true)),
config_options: Arc::new(ConfigOptions::new()),
})
.and_then(|v| ColumnarValue::values_to_arrays(&[v]))
.map(|mut a| a.remove(0))
.unwrap();
let result = result.as_primitive::<Float32Type>();
let result = result.as_ref();
assert_eq!(result.len(), 3);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(6.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(15.0));
assert_eq!(result.get_ref(2).as_f32().unwrap(), None);
assert_eq!(result.value(0), 6.0);
assert_eq!(result.value(1), 15.0);
assert!(result.is_null(2));
}
}


@@ -13,40 +13,18 @@
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::ValueRef;
use datatypes::vectors::Vector;
/// Convert a constant string or binary literal to a vector literal.
pub fn as_veclit_if_const(arg: &Arc<dyn Vector>) -> Result<Option<Cow<'_, [f32]>>> {
if !arg.is_const() {
return Ok(None);
}
if arg.data_type() != ConcreteDataType::string_datatype()
&& arg.data_type() != ConcreteDataType::binary_datatype()
{
return Ok(None);
}
as_veclit(arg.get_ref(0))
}
use datafusion_common::ScalarValue;
/// Convert a string or binary literal to a vector literal.
pub fn as_veclit(arg: ValueRef<'_>) -> Result<Option<Cow<'_, [f32]>>> {
match arg.data_type() {
ConcreteDataType::Binary(_) => arg
.as_binary()
.unwrap() // Safe: checked if it is a binary
.map(binlit_as_veclit)
pub fn as_veclit(arg: &ScalarValue) -> Result<Option<Cow<'_, [f32]>>> {
match arg {
ScalarValue::Binary(b) => b.as_ref().map(|x| binlit_as_veclit(x)).transpose(),
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) => s
.as_ref()
.map(|x| parse_veclit_from_strlit(x).map(Cow::Owned))
.transpose(),
ConcreteDataType::String(_) => arg
.as_string()
.unwrap() // Safe: checked if it is a string
.map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?)))
.transpose(),
ConcreteDataType::Null(_) => Ok(None),
_ => InvalidFuncArgsSnafu {
err_msg: format!("Unsupported data type: {:?}", arg.data_type()),
}
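`as_veclit` delegates string inputs to `parse_veclit_from_strlit`, which is defined elsewhere in this file. As a rough standalone illustration, assuming the bracketed comma-separated literal format seen in the tests, a parser of that shape could look like the hypothetical `parse_veclit` below; it is not the crate's actual implementation.

// Hypothetical standalone parser illustrating the assumed literal format; the
// crate's real parse_veclit_from_strlit may differ in error handling.
fn parse_veclit(s: &str) -> Option<Vec<f32>> {
    let inner = s.trim().strip_prefix('[')?.strip_suffix(']')?;
    if inner.trim().is_empty() {
        return Some(Vec::new());
    }
    inner
        .split(',')
        .map(|t| t.trim().parse::<f32>().ok())
        .collect()
}

fn main() {
    assert_eq!(parse_veclit("[1.0, 2.0, 3.0]"), Some(vec![1.0, 2.0, 3.0]));
    assert_eq!(parse_veclit("not a vector"), None);
}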


@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_scalar_add";
@@ -60,7 +59,7 @@ impl Function for ScalarAddFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarAddFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 = as_veclit(v1)?
.map(|v1| DVectorView::from_slice(&v1, v1.len()).add_scalar(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};
let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.add_scalar(arg0 as _);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -129,7 +102,9 @@ impl Display for ScalarAddFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_add() {
let func = ScalarAddFunction;
let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(1.0),
Some(-1.0),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice())
result.value(1),
veclit_to_binlit(&[3.0, 4.0, 5.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}


@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_scalar_mul";
@@ -60,7 +59,7 @@ impl Function for ScalarMulFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -70,52 +69,26 @@ impl Function for ScalarMulFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = arg0.get(i).as_f64_lossy();
let Some(arg0) = arg0 else {
result.push_null();
continue;
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let ScalarValue::Float64(Some(v0)) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
let v1 =
as_veclit(v1)?.map(|v1| DVectorView::from_slice(&v1, v1.len()).scale(*v0 as f32));
let result = v1.map(|v1| veclit_to_binlit(v1.as_slice()));
Ok(ScalarValue::BinaryView(result))
};
let vec = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec.scale(arg0 as _);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -129,7 +102,9 @@ impl Display for ScalarMulFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::{Float32Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, Float64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -137,34 +112,42 @@ mod tests {
fn test_scalar_mul() {
let func = ScalarMulFunction;
let input0 = Arc::new(Float32Vector::from(vec![
let input0 = Arc::new(Float64Array::from(vec![
Some(2.0),
Some(-0.5),
None,
Some(3.0),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 4.0, 6.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-4.0, -5.0, -6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}


@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_add";
@@ -51,7 +51,7 @@ impl Function for VectorAddFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorAddFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 + v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};
let vec_res = vec0 + vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -134,8 +104,9 @@ impl Display for VectorAddFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorAddFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice())
result.value(0),
veclit_to_binlit(&[2.0, 3.0, 4.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice())
result.value(1),
veclit_to_binlit(&[10.0, 10.0, 10.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
#[test]
fn test_sub_error() {
let func = VectorAddFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}


@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{MutableVector, UInt64VectorBuilder, VectorRef};
use snafu::ensure;
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;
const NAME: &str = "vec_dim";
@@ -63,43 +62,20 @@ impl Function for VectorDimFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v = as_veclit(v0)?.map(|v0| v0.len() as u64);
Ok(ScalarValue::UInt64(v))
};
let len = arg0.len();
let mut result = UInt64VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
result.push(Some(arg0.len() as u64));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -113,8 +89,10 @@ impl Display for VectorDimFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion::arrow::datatypes::UInt64Type;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -122,49 +100,60 @@ mod tests {
fn test_vec_dim() {
let func = VectorDimFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0,4.0]".to_string()),
None,
Some("[5.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_primitive::<UInt64Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_u64().unwrap(), Some(3));
assert_eq!(result.get_ref(1).as_u64().unwrap(), Some(4));
assert!(result.get_ref(2).is_null());
assert_eq!(result.get_ref(3).as_u64().unwrap(), Some(1));
assert_eq!(result.value(0), 3);
assert_eq!(result.value(1), 4);
assert!(result.is_null(2));
assert_eq!(result.value(3), 1);
}
#[test]
fn test_dim_error() {
let func = VectorDimFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The length of the args is not correct, expect exactly one, have: 2"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::UInt64, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vec_dim function requires 1 argument, got 2")
)
}
}


@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_div";
@@ -52,7 +52,7 @@ impl Function for VectorDivFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorDivFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for division, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec0.component_div(&vec1);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_div(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -133,8 +105,9 @@ impl Display for VectorDivFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -144,69 +117,80 @@ mod tests {
let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert_eq!(
e.to_string(),
"Execution error: vectors length not match: vec_div"
);
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for division, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice())
result.value(1),
veclit_to_binlit(&[4.0, 5.0, 6.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringVector::from(vec![Some("[0.0,0.0]".to_string())]));
let input0 = Arc::new(StringViewArray::from(vec![Some("[1.0,-2.0]".to_string())]));
let input1 = Arc::new(StringViewArray::from(vec![Some("[0.0,0.0]".to_string())]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 2,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(2))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice())
result.value(0),
veclit_to_binlit(&[f64::INFINITY as f32, f64::NEG_INFINITY as f32]).as_slice()
);
}
}


@@ -12,19 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use common_query::error::Result;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::as_veclit;
const NAME: &str = "vec_kth_elem";
@@ -63,72 +62,44 @@ impl Function for VectorKthElemFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue, v1: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;
let arg0 = &columns[0];
let arg1 = &columns[1];
let v1 = match v1 {
ScalarValue::Int64(None) => return Ok(ScalarValue::Float32(None)),
ScalarValue::Int64(Some(v1)) if *v1 >= 0 => *v1 as usize,
_ => {
return Err(DataFusionError::Execution(format!(
"2nd argument not a valid index or expected datatype: {}",
self.name()
)));
}
};
let len = arg0.len();
let mut result = Float32VectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
let result = v0
.map(|v0| {
if v1 >= v0.len() {
Err(DataFusionError::Execution(format!(
"index out of bound: {}",
self.name()
)))
} else {
Ok(v0[v1])
}
})
.transpose()?;
Ok(ScalarValue::Float32(result))
};
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
};
let arg1 = arg1.get(i).as_f64_lossy();
let Some(arg1) = arg1 else {
result.push_null();
continue;
};
ensure!(
arg1 >= 0.0 && arg1.fract() == 0.0,
InvalidFuncArgsSnafu {
err_msg: format!(
"Invalid argument: k must be a non-negative integer, but got k = {}.",
arg1
),
}
);
let k = arg1 as usize;
ensure!(
k < arg0.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"Out of range: k must be in the range [0, {}], but got k = {}.",
arg0.len() - 1,
k
),
}
);
let value = arg0[k];
result.push(Some(value));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_args(args)
}
}
@@ -142,8 +113,10 @@ impl Display for VectorKthElemFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, Int64Array, StringViewArray};
use datafusion::arrow::datatypes::Float32Type;
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -151,55 +124,66 @@ mod tests {
fn test_vec_kth_elem() {
let func = VectorKthElemFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(0), Some(2), None, Some(1)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(2), None, Some(1)]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_primitive::<Float32Type>();
assert_eq!(result.len(), 4);
assert_eq!(result.get_ref(0).as_f32().unwrap(), Some(1.0));
assert_eq!(result.get_ref(1).as_f32().unwrap(), Some(6.0));
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert_eq!(result.value(0), 1.0);
assert_eq!(result.value(1), 6.0);
assert!(result.is_null(2));
assert!(result.is_null(3));
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Out of range: k must be in the range [0, 2], but got k = 3.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: index out of bound: vec_kth_elem")
);
let input0 = Arc::new(StringVector::from(vec![Some("[1.0,2.0,3.0]".to_string())]));
let input1 = Arc::new(Int64Vector::from(vec![Some(-1)]));
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(-1)]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!("Invalid argument: k must be a non-negative integer, but got k = -1.")
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::Float32, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Execution error: 2nd argument not a valid index or expected datatype: vec_kth_elem"
));
}
}


@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_mul";
@@ -52,7 +52,7 @@ impl Function for VectorMulFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -62,64 +62,36 @@ impl Function for VectorMulFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
),
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
let arg0 = &columns[0];
let arg1 = &columns[1];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
if let (Some(arg0), Some(arg1)) = (arg0, arg1) {
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
arg0.len(),
arg1.len()
),
}
);
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
let vec_res = vec1.component_mul(&vec0);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
let result = veclit_to_binlit((v0.component_mul(&v1)).as_slice());
Some(result)
} else {
result.push_null();
}
}
None
};
Ok(ScalarValue::BinaryView(result))
};
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -133,8 +105,9 @@ impl Display for VectorMulFunction {
mod tests {
use std::sync::Arc;
use common_query::error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -144,56 +117,59 @@ mod tests {
let vec0 = vec![1.0, 2.0, 3.0];
let vec1 = vec![1.0, 1.0];
let (len0, len1) = (vec0.len(), vec1.len());
let input0 = Arc::new(StringVector::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringVector::from(vec![Some(format!("{vec1:?}"))]));
let input0 = Arc::new(StringViewArray::from(vec![Some(format!("{vec0:?}"))]));
let input1 = Arc::new(StringViewArray::from(vec![Some(format!("{vec1:?}"))]));
let err = func
.eval(&FunctionContext::default(), &[input0, input1])
.unwrap_err();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(
e.to_string()
.starts_with("Execution error: vectors length not match: vec_mul")
);
match err {
error::Error::InvalidFuncArgs { err_msg, .. } => {
assert_eq!(
err_msg,
format!(
"The length of the vectors must match for multiplying, have: {} vs {}",
len0, len1
)
)
}
_ => unreachable!(),
}
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[8.0,10.0,12.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
None,
]));
let input1 = Arc::new(StringVector::from(vec![
let input1 = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[2.0,2.0,2.0]".to_string()),
None,
Some("[3.0,3.0,3.0]".to_string()),
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice())
result.value(0),
veclit_to_binlit(&[1.0, 2.0, 3.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice())
result.value(1),
veclit_to_binlit(&[16.0, 20.0, 24.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
}


@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion::logical_expr_common::type_coercion::aggregates::{BINARYS, STRINGS};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use datafusion_common::ScalarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_norm";
@@ -53,7 +52,7 @@ impl Function for VectorNormFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -66,55 +65,27 @@ impl Function for VectorNormFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let Some(arg0) = arg0 else {
result.push_null();
continue;
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &ScalarValue| -> datafusion_common::Result<ScalarValue> {
let v0 = as_veclit(v0)?;
let Some(v0) = v0 else {
return Ok(ScalarValue::BinaryView(None));
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg0, arg0.len());
let vec2scalar = vec1.component_mul(&vec0);
let scalar_var = vec2scalar.sum().sqrt();
let v0 = DVectorView::from_slice(&v0, v0.len());
let result =
veclit_to_binlit(v0.unscale(v0.component_mul(&v0).sum().sqrt()).as_slice());
Ok(ScalarValue::BinaryView(Some(result)))
};
let vec = DVectorView::from_slice(&arg0, arg0.len());
// Use unscale to avoid division by zero and keep more precision as possible
let vec_res = vec.unscale(scalar_var);
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_single_argument(args)
}
}
@@ -128,7 +99,9 @@ impl Display for VectorNormFunction {
mod tests {
use std::sync::Arc;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -136,7 +109,7 @@ mod tests {
fn test_vec_norm() {
let func = VectorNormFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[0.0,2.0,3.0]".to_string()),
Some("[1.0,2.0,3.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
@@ -144,26 +117,36 @@ mod tests {
None,
]));
let result = func.eval(&FunctionContext::default(), &[input0]).unwrap();
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0)],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 5);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 0.5547002, 0.8320503]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice())
result.value(1),
veclit_to_binlit(&[0.26726124, 0.5345225, 0.8017837]).as_slice()
);
assert_eq!(
result.get_ref(2).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice())
result.value(2),
veclit_to_binlit(&[0.5025707, 0.5743665, 0.64616233]).as_slice()
);
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice())
result.value(3),
veclit_to_binlit(&[0.5025707, -0.5743665, 0.64616233]).as_slice()
);
assert!(result.get_ref(4).is_null());
assert!(result.is_null(4));
}
}
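The expected values in the vec_norm test above correspond to plain L2 normalization, that is v / sqrt(sum(v_i^2)). A standalone check with a hypothetical `l2_normalize` helper reproduces the first expected row, [0.0, 2.0, 3.0] -> [0.0, 0.5547002, 0.8320503]:

// Hypothetical helper reproducing the test's expected values under the
// assumption that vec_norm performs plain L2 normalization.
fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    v.iter().map(|x| x / norm).collect()
}

fn main() {
    let out = l2_normalize(&[0.0, 2.0, 3.0]);
    assert!((out[0] - 0.0).abs() < 1e-6);
    assert!((out[1] - 0.5547002).abs() < 1e-5);
    assert!((out[2] - 0.8320503).abs() < 1e-5);
}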


@@ -15,17 +15,17 @@
use std::borrow::Cow;
use std::fmt::Display;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::Signature;
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use common_query::error::Result;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_expr::{ScalarFunctionArgs, Signature};
use nalgebra::DVectorView;
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::function::Function;
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::scalars::vector::VectorCalculator;
use crate::scalars::vector::impl_conv::veclit_to_binlit;
const NAME: &str = "vec_sub";
@@ -51,7 +51,7 @@ impl Function for VectorSubFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -61,66 +61,36 @@ impl Function for VectorSubFunction {
)
}
fn eval(
fn invoke_with_args(
&self,
_func_ctx: &FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let body = |v0: &Option<Cow<[f32]>>,
v1: &Option<Cow<[f32]>>|
-> datafusion_common::Result<ScalarValue> {
let result = if let (Some(v0), Some(v1)) = (v0, v1) {
let v0 = DVectorView::from_slice(v0, v0.len());
let v1 = DVectorView::from_slice(v1, v1.len());
if v0.len() != v1.len() {
return Err(DataFusionError::Execution(format!(
"vectors length not match: {}",
self.name()
)));
}
ensure!(
arg0.len() == arg1.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}",
arg0.len(),
arg1.len(),
)
}
);
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}
let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
let result = veclit_to_binlit((v0 - v1).as_slice());
Some(result)
} else {
None
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());
Ok(ScalarValue::BinaryView(result))
};
let vec_res = vec0 - vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}
Ok(result.to_vector())
let calculator = VectorCalculator {
name: self.name(),
func: body,
};
calculator.invoke_with_vectors(args)
}
}
@@ -134,8 +104,9 @@ impl Display for VectorSubFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::StringVector;
use arrow_schema::Field;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
@@ -143,63 +114,71 @@ mod tests {
fn test_sub() {
let func = VectorSubFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1])
.invoke_with_args(args)
.and_then(|x| x.to_array(4))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice())
result.value(0),
veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice()
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice())
result.value(1),
veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice()
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
assert!(result.is_null(2));
assert!(result.is_null(3));
}
#[test]
fn test_sub_error() {
let func = VectorSubFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
let input1: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
]));
let result = func.eval(&FunctionContext::default(), &[input0, input1]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 4, args 1: 3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![ColumnarValue::Array(input0), ColumnarValue::Array(input1)],
arg_fields: vec![],
number_rows: 4,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 4, found length: 3."
));
}
}


@@ -12,18 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::fmt::Display;
use std::sync::Arc;
use common_query::error::{InvalidFuncArgsSnafu, Result};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datafusion::arrow::array::{Array, AsArray, BinaryViewBuilder};
use datafusion::arrow::datatypes::Int64Type;
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{ScalarValue, utils};
use datafusion_expr::{ScalarFunctionArgs, Signature, TypeSignature, Volatility};
use datatypes::arrow::datatypes::DataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use snafu::ensure;
use crate::function::{Function, FunctionContext};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
use crate::function::Function;
use crate::scalars::vector::impl_conv::{as_veclit, veclit_to_binlit};
const NAME: &str = "vec_subvector";
@@ -52,7 +54,7 @@ impl Function for VectorSubvectorFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Binary)
Ok(DataType::BinaryView)
}
fn signature(&self) -> Signature {
@@ -65,50 +67,28 @@ impl Function for VectorSubvectorFunction {
)
}
fn eval(&self, _func_ctx: &FunctionContext, columns: &[VectorRef]) -> Result<VectorRef> {
ensure!(
columns.len() == 3,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly three, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];
let arg2 = &columns[2];
ensure!(
arg0.len() == arg1.len() && arg1.len() == arg2.len(),
InvalidFuncArgsSnafu {
err_msg: format!(
"The lengths of the vector are not aligned, args 0: {}, args 1: {}, args 2: {}",
arg0.len(),
arg1.len(),
arg2.len()
)
}
);
fn invoke_with_args(
&self,
args: ScalarFunctionArgs,
) -> datafusion_common::Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;
let [arg0, arg1, arg2] = utils::take_function_args(self.name(), args)?;
let arg1 = arg1.as_primitive::<Int64Type>();
let arg2 = arg2.as_primitive::<Int64Type>();
let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
let mut builder = BinaryViewBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
}
let arg0_const = as_veclit_if_const(arg0)?;
for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = arg1.get(i).as_i64();
let arg2 = arg2.get(i).as_i64();
let v = ScalarValue::try_from_array(&arg0, i)?;
let arg0 = as_veclit(&v)?;
let arg1 = arg1.is_valid(i).then(|| arg1.value(i));
let arg2 = arg2.is_valid(i).then(|| arg2.value(i));
let (Some(arg0), Some(arg1), Some(arg2)) = (arg0, arg1, arg2) else {
result.push_null();
builder.append_null();
continue;
};
@@ -126,10 +106,10 @@ impl Function for VectorSubvectorFunction {
let subvector = &arg0[arg1 as usize..arg2 as usize];
let binlit = veclit_to_binlit(subvector);
result.push(Some(&binlit));
builder.append_value(&binlit);
}
Ok(result.to_vector())
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}
}
@@ -143,89 +123,102 @@ impl Display for VectorSubvectorFunction {
mod tests {
use std::sync::Arc;
use common_query::error::Error;
use datatypes::vectors::{Int64Vector, StringVector};
use arrow_schema::Field;
use datafusion::arrow::array::{ArrayRef, Int64Array, StringViewArray};
use datafusion_common::config::ConfigOptions;
use super::*;
use crate::function::FunctionContext;
#[test]
fn test_subvector() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0, 4.0, 5.0]".to_string()),
Some("[6.0, 7.0, 8.0, 9.0, 10.0]".to_string()),
None,
Some("[11.0, 12.0, 13.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(5), Some(2), Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(0), Some(0), Some(1)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(5), Some(2), Some(3)]));
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 5,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let result = func
.eval(&FunctionContext::default(), &[input0, input1, input2])
.invoke_with_args(args)
.and_then(|x| x.to_array(5))
.unwrap();
let result = result.as_ref();
let result = result.as_binary_view();
assert_eq!(result.len(), 4);
assert_eq!(result.value(0), veclit_to_binlit(&[2.0, 3.0]).as_slice());
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[2.0, 3.0]).as_slice())
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice())
);
assert!(result.get_ref(2).is_null());
assert_eq!(
result.get_ref(3).as_binary().unwrap(),
Some(veclit_to_binlit(&[12.0, 13.0]).as_slice())
result.value(1),
veclit_to_binlit(&[6.0, 7.0, 8.0, 9.0, 10.0]).as_slice()
);
assert!(result.is_null(2));
assert_eq!(result.value(3), veclit_to_binlit(&[12.0, 13.0]).as_slice());
}
#[test]
fn test_subvector_error() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0: ArrayRef = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(2)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3)]));
let input1: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2)]));
let input2: ArrayRef = Arc::new(Int64Array::from(vec![Some(3)]));
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"The lengths of the vector are not aligned, args 0: 2, args 1: 2, args 2: 1"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with(
"Internal error: Arguments has mixed length. Expected length: 2, found length: 1."
));
}
#[test]
fn test_subvector_invalid_indices() {
let func = VectorSubvectorFunction;
let input0 = Arc::new(StringVector::from(vec![
let input0 = Arc::new(StringViewArray::from(vec![
Some("[1.0, 2.0, 3.0]".to_string()),
Some("[4.0, 5.0, 6.0]".to_string()),
]));
let input1 = Arc::new(Int64Vector::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Vector::from(vec![Some(3), Some(4)]));
let input1 = Arc::new(Int64Array::from(vec![Some(1), Some(3)]));
let input2 = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));
let result = func.eval(&FunctionContext::default(), &[input0, input1, input2]);
match result {
Err(Error::InvalidFuncArgs { err_msg, .. }) => {
assert_eq!(
err_msg,
"Invalid start and end indices: start=3, end=4, vec_len=3"
)
}
_ => unreachable!(),
}
let args = ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(input0),
ColumnarValue::Array(input1),
ColumnarValue::Array(input2),
],
arg_fields: vec![],
number_rows: 3,
return_field: Arc::new(Field::new("x", DataType::BinaryView, false)),
config_options: Arc::new(ConfigOptions::new()),
};
let e = func.invoke_with_args(args).unwrap_err();
assert!(e.to_string().starts_with("External error: Invalid function args: Invalid start and end indices: start=3, end=4, vec_len=3"));
}
}

View File

@@ -37,8 +37,7 @@ impl FunctionState {
use catalog::CatalogManagerRef;
use common_base::AffectedRows;
use common_meta::rpc::procedure::{
AddRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
RemoveRegionFollowerRequest,
ManageRegionFollowerRequest, MigrateRegionRequest, ProcedureStateResponse,
};
use common_query::Output;
use common_query::error::Result;
@@ -75,13 +74,9 @@ impl FunctionState {
})
}
async fn add_region_follower(&self, _request: AddRegionFollowerRequest) -> Result<()> {
Ok(())
}
async fn remove_region_follower(
async fn manage_region_follower(
&self,
_request: RemoveRegionFollowerRequest,
_request: ManageRegionFollowerRequest,
) -> Result<()> {
Ok(())
}

View File

@@ -16,11 +16,12 @@ use std::fmt;
use std::sync::Arc;
use common_query::error::Result;
use datafusion::arrow::array::StringViewArray;
use datafusion::arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};
use datatypes::vectors::{StringVector, VectorRef};
use datafusion::logical_expr::ColumnarValue;
use datafusion_expr::{ScalarFunctionArgs, Signature, Volatility};
use crate::function::{Function, FunctionContext};
use crate::function::Function;
/// Generates build information
#[derive(Clone, Debug, Default)]
@@ -38,17 +39,18 @@ impl Function for BuildFunction {
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
Ok(DataType::Utf8View)
}
fn signature(&self) -> Signature {
Signature::nullary(Volatility::Immutable)
}
fn eval(&self, _func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef> {
fn invoke_with_args(&self, _: ScalarFunctionArgs) -> datafusion_common::Result<ColumnarValue> {
let build_info = common_version::build_info().to_string();
let v = Arc::new(StringVector::from(vec![build_info]));
Ok(v)
Ok(ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
build_info,
]))))
}
}
@@ -56,16 +58,29 @@ impl Function for BuildFunction {
mod tests {
use std::sync::Arc;
use arrow_schema::Field;
use datafusion::arrow::array::ArrayRef;
use datafusion_common::config::ConfigOptions;
use super::*;
#[test]
fn test_build_function() {
let build = BuildFunction;
assert_eq!("build", build.name());
assert_eq!(DataType::Utf8, build.return_type(&[]).unwrap());
assert_eq!(DataType::Utf8View, build.return_type(&[]).unwrap());
assert_eq!(build.signature(), Signature::nullary(Volatility::Immutable));
let build_info = common_version::build_info().to_string();
let vector = build.eval(&FunctionContext::default(), &[]).unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![build_info]));
assert_eq!(expect, vector);
let actual = build
.invoke_with_args(ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
number_rows: 0,
return_field: Arc::new(Field::new("x", DataType::Utf8View, false)),
config_options: Arc::new(ConfigOptions::new()),
})
.unwrap();
let actual = ColumnarValue::values_to_arrays(&[actual]).unwrap();
let expect = vec![Arc::new(StringViewArray::from(vec![build_info])) as ArrayRef];
assert_eq!(actual, expect);
}
}

View File

@@ -280,6 +280,8 @@ fn build_struct(
&self,
args: datafusion::logical_expr::ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
use common_error::ext::ErrorExt;
let columns = args.args
.iter()
.map(|arg| {
@@ -293,7 +295,7 @@ fn build_struct(
})
})
.collect::<common_query::error::Result<Vec<_>>>()
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e.output_msg())))?;
// Safety check: Ensure under the `greptime` catalog for security
#user_path::ensure_greptime!(self.func_ctx);
@@ -314,14 +316,14 @@ fn build_struct(
.#handler
.as_ref()
.context(#snafu_type)
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e.output_msg())))?;
let mut builder = store_api::storage::ConcreteDataType::#ret()
.create_mutable_vector(rows_num);
if columns_num == 0 {
let result = #fn_name(handler, query_ctx, &[]).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e.output_msg())))?;
builder.push_value_ref(result.as_value_ref());
} else {
@@ -331,7 +333,7 @@ fn build_struct(
.collect();
let result = #fn_name(handler, query_ctx, &args).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e.output_msg())))?;
builder.push_value_ref(result.as_value_ref());
}

View File

@@ -752,7 +752,6 @@ pub enum Error {
location: Location,
},
#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to load TLS certificate from path: {}", path))]
LoadTlsCertificate {
path: String,
@@ -1181,13 +1180,14 @@ impl ErrorExt for Error {
| InvalidRole { .. }
| EmptyDdlTasks { .. } => StatusCode::InvalidArguments,
LoadTlsCertificate { .. } => StatusCode::Internal,
#[cfg(feature = "pg_kvbackend")]
PostgresExecution { .. }
| CreatePostgresPool { .. }
| GetPostgresConnection { .. }
| PostgresTransaction { .. }
| PostgresTlsConfig { .. }
| LoadTlsCertificate { .. }
| InvalidTlsConfig { .. } => StatusCode::Internal,
#[cfg(feature = "mysql_kvbackend")]
MySqlExecution { .. } | CreateMySqlPool { .. } | MySqlTransaction { .. } => {

View File

@@ -13,15 +13,17 @@
// limitations under the License.
use std::any::Any;
use std::fs;
use std::sync::Arc;
use common_telemetry::info;
use common_telemetry::{debug, info};
use etcd_client::{
Client, DeleteOptions, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse, TxnResponse,
Certificate, Client, DeleteOptions, GetOptions, Identity, PutOptions, TlsOptions, Txn, TxnOp,
TxnOpResponse, TxnResponse,
};
use snafu::{ResultExt, ensure};
use crate::error::{self, Error, Result};
use crate::error::{self, Error, LoadTlsCertificateSnafu, Result};
use crate::kv_backend::txn::{Txn as KvTxn, TxnResponse as KvTxnResponse};
use crate::kv_backend::{KvBackend, KvBackendRef, TxnService};
use crate::metrics::METRIC_META_TXN_REQUEST;
@@ -451,8 +453,76 @@ impl TryFrom<DeleteRangeRequest> for Delete {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TlsMode {
#[default]
Disable,
Require,
}
/// TLS configuration for Etcd connections.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsOption {
pub mode: TlsMode,
pub cert_path: String,
pub key_path: String,
pub ca_cert_path: String,
}
/// Creates an Etcd [`TlsOptions`] from a [`TlsOption`].
///
/// This function builds the TLS options for etcd client connections based on the provided
/// [`TlsOption`]. It supports disabling TLS, setting a custom CA certificate, and configuring
/// client identity for mutual TLS authentication.
///
/// Note: All TlsMode variants except [`TlsMode::Disable`] will be treated as enabling TLS.
pub fn create_etcd_tls_options(tls_config: &TlsOption) -> Result<Option<TlsOptions>> {
// If TLS mode is disabled, return None to indicate no TLS configuration.
if matches!(tls_config.mode, TlsMode::Disable) {
return Ok(None);
}
info!("Creating etcd TLS with mode: {:?}", tls_config.mode);
// Start with default TLS options.
let mut etcd_tls_opts = TlsOptions::new();
// If a CA certificate path is provided, load the CA certificate and add it to the options.
if !tls_config.ca_cert_path.is_empty() {
debug!("Using CA certificate from {}", tls_config.ca_cert_path);
let ca_cert_pem = fs::read(&tls_config.ca_cert_path).context(LoadTlsCertificateSnafu {
path: &tls_config.ca_cert_path,
})?;
let ca_cert = Certificate::from_pem(ca_cert_pem);
etcd_tls_opts = etcd_tls_opts.ca_certificate(ca_cert);
}
// If both client certificate and key paths are provided, load them and set the client identity.
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
info!("Loading client certificate for mutual TLS");
debug!(
"Using client certificate from {} and key from {}",
tls_config.cert_path, tls_config.key_path
);
let cert_pem = fs::read(&tls_config.cert_path).context(LoadTlsCertificateSnafu {
path: &tls_config.cert_path,
})?;
let key_pem = fs::read(&tls_config.key_path).context(LoadTlsCertificateSnafu {
path: &tls_config.key_path,
})?;
let identity = Identity::from_pem(cert_pem, key_pem);
etcd_tls_opts = etcd_tls_opts.identity(identity);
}
// Always enable native TLS roots for additional trust anchors.
etcd_tls_opts = etcd_tls_opts.with_native_roots();
Ok(Some(etcd_tls_opts))
}
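A minimal usage sketch of the helper above (endpoints and certificate paths are placeholders, error handling is elided), mirroring the `create_etcd_client_with_tls` test helper further down:

```rust
// Sketch: build TLS options from a TlsOption and connect an etcd client with them.
// `create_etcd_tls_options` returns `None` only when `TlsMode::Disable` is used.
async fn connect_etcd_with_tls() -> Client {
    let tls = TlsOption {
        mode: TlsMode::Require,
        cert_path: "/etc/greptime/certs/client.crt".to_string(),
        key_path: "/etc/greptime/certs/client-key.pem".to_string(),
        ca_cert_path: "/etc/greptime/certs/ca.crt".to_string(),
    };
    let tls_options = create_etcd_tls_options(&tls).unwrap().unwrap();
    let connect_options = etcd_client::ConnectOptions::new().with_tls(tls_options);
    Client::connect(&["localhost:2379"], Some(connect_options))
        .await
        .unwrap()
}
```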
#[cfg(test)]
mod tests {
use etcd_client::ConnectOptions;
use super::*;
#[test]
@@ -555,6 +625,8 @@ mod tests {
test_txn_compare_not_equal, test_txn_one_compare_op, text_txn_multi_compare_op,
unprepare_kv,
};
use crate::maybe_skip_etcd_tls_integration_test;
use crate::test_util::etcd_certs_dir;
async fn build_kv_backend() -> Option<EtcdStore> {
let endpoints = std::env::var("GT_ETCD_ENDPOINTS").unwrap_or_default();
@@ -654,4 +726,41 @@ mod tests {
test_txn_compare_not_equal(&kv_backend).await;
}
}
async fn create_etcd_client_with_tls(endpoints: &[String], tls_config: &TlsOption) -> Client {
let endpoints = endpoints
.iter()
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect::<Vec<_>>();
let connect_options =
ConnectOptions::new().with_tls(create_etcd_tls_options(tls_config).unwrap().unwrap());
Client::connect(&endpoints, Some(connect_options))
.await
.unwrap()
}
#[tokio::test]
async fn test_create_etcd_client_with_mtls_and_ca() {
maybe_skip_etcd_tls_integration_test!();
let endpoints = std::env::var("GT_ETCD_TLS_ENDPOINTS")
.unwrap()
.split(',')
.map(|s| s.to_string())
.collect::<Vec<_>>();
let cert_dir = etcd_certs_dir();
let tls_config = TlsOption {
mode: TlsMode::Require,
ca_cert_path: cert_dir.join("ca.crt").to_string_lossy().to_string(),
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
key_path: cert_dir
.join("client-key.pem")
.to_string_lossy()
.to_string(),
};
let mut client = create_etcd_client_with_tls(&endpoints, &tls_config).await;
let _ = client.get(b"hello", None).await.unwrap();
}
}

View File

@@ -573,6 +573,7 @@ impl MySqlStore {
#[cfg(test)]
mod tests {
use common_telemetry::init_default_ut_logging;
use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode};
use super::*;
use crate::kv_backend::test::{
@@ -584,6 +585,7 @@ mod tests {
text_txn_multi_compare_op, unprepare_kv,
};
use crate::maybe_skip_mysql_integration_test;
use crate::test_util::test_certs_dir;
async fn build_mysql_kv_backend(table_name: &str) -> Option<MySqlStore> {
init_default_ut_logging();
@@ -711,4 +713,71 @@ mod tests {
test_txn_compare_less(&kv_backend).await;
test_txn_compare_not_equal(&kv_backend).await;
}
#[tokio::test]
async fn test_mysql_with_tls() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::Required);
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}
#[tokio::test]
async fn test_mysql_with_mtls() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();
let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::Required)
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}
#[tokio::test]
async fn test_mysql_with_tls_verify_ca() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();
let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::VerifyCa)
.ssl_ca(certs_dir.join("root.crt").to_string_lossy().to_string())
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}
#[tokio::test]
async fn test_mysql_with_tls_verify_ident() {
common_telemetry::init_default_ut_logging();
maybe_skip_mysql_integration_test!();
let endpoint = std::env::var("GT_MYSQL_ENDPOINTS").unwrap();
let certs_dir = test_certs_dir();
let opts = endpoint
.parse::<MySqlConnectOptions>()
.unwrap()
.ssl_mode(MySqlSslMode::VerifyIdentity)
.ssl_ca(certs_dir.join("root.crt").to_string_lossy().to_string())
.ssl_client_cert(certs_dir.join("client.crt").to_string_lossy().to_string())
.ssl_client_key(certs_dir.join("client.key").to_string_lossy().to_string());
let pool = MySqlPool::connect_with(opts).await.unwrap();
sqlx::query("SELECT 1").execute(&pool).await.unwrap();
}
}

View File

@@ -903,6 +903,7 @@ mod tests {
test_txn_compare_less, test_txn_compare_not_equal, test_txn_one_compare_op,
text_txn_multi_compare_op, unprepare_kv,
};
use crate::test_util::test_certs_dir;
use crate::{maybe_skip_postgres_integration_test, maybe_skip_postgres15_integration_test};
async fn build_pg_kv_backend(table_name: &str) -> Option<PgStore> {
@@ -993,6 +994,97 @@ mod tests {
unprepare_kv(&kv, prefix).await;
}
#[tokio::test]
async fn test_pg_with_tls() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::Require,
cert_path: String::new(),
key_path: String::new(),
ca_cert_path: String::new(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}
#[tokio::test]
async fn test_pg_with_mtls() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::Require,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: String::new(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}
#[tokio::test]
async fn test_pg_verify_ca() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::VerifyCa,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: certs_dir.join("root.crt").display().to_string(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}
#[tokio::test]
async fn test_pg_verify_full() {
common_telemetry::init_default_ut_logging();
maybe_skip_postgres_integration_test!();
let certs_dir = test_certs_dir();
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
let tls_connector = create_postgres_tls_connector(&TlsOption {
mode: TlsMode::VerifyFull,
cert_path: certs_dir.join("client.crt").display().to_string(),
key_path: certs_dir.join("client.key").display().to_string(),
ca_cert_path: certs_dir.join("root.crt").display().to_string(),
watch: false,
})
.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoints);
let pool = cfg
.create_pool(Some(Runtime::Tokio1), tls_connector)
.unwrap();
let client = pool.get().await.unwrap();
client.execute("SELECT 1", &[]).await.unwrap();
}
#[tokio::test]
async fn test_pg_put() {
maybe_skip_postgres_integration_test!();

View File

@@ -25,8 +25,8 @@ use crate::error::{
};
use crate::rpc::ddl::{SubmitDdlTaskRequest, SubmitDdlTaskResponse};
use crate::rpc::procedure::{
self, AddRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse,
ProcedureStateResponse, RemoveRegionFollowerRequest,
self, ManageRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse,
ProcedureStateResponse,
};
/// The context of procedure executor.
@@ -45,26 +45,14 @@ pub trait ProcedureExecutor: Send + Sync {
request: SubmitDdlTaskRequest,
) -> Result<SubmitDdlTaskResponse>;
/// Add a region follower
async fn add_region_follower(
/// Submit a manage region follower task
async fn manage_region_follower(
&self,
_ctx: &ExecutorContext,
_request: AddRegionFollowerRequest,
_request: ManageRegionFollowerRequest,
) -> Result<()> {
UnsupportedSnafu {
operation: "add_region_follower",
}
.fail()
}
/// Remove a region follower
async fn remove_region_follower(
&self,
_ctx: &ExecutorContext,
_request: RemoveRegionFollowerRequest,
) -> Result<()> {
UnsupportedSnafu {
operation: "remove_region_follower",
operation: "manage_region_follower",
}
.fail()
}

View File

@@ -23,6 +23,7 @@ use api::v1::meta::{
use common_error::ext::ErrorExt;
use common_procedure::{ProcedureId, ProcedureInfo, ProcedureState};
use snafu::ResultExt;
use table::metadata::TableId;
use crate::error::{ParseProcedureIdSnafu, Result};
@@ -44,6 +45,30 @@ pub struct AddRegionFollowerRequest {
pub peer_id: u64,
}
#[derive(Debug, Clone)]
pub struct AddTableFollowerRequest {
pub catalog_name: String,
pub schema_name: String,
pub table_name: String,
pub table_id: TableId,
}
#[derive(Debug, Clone)]
pub struct RemoveTableFollowerRequest {
pub catalog_name: String,
pub schema_name: String,
pub table_name: String,
pub table_id: TableId,
}
#[derive(Debug, Clone)]
pub enum ManageRegionFollowerRequest {
AddRegionFollower(AddRegionFollowerRequest),
RemoveRegionFollower(RemoveRegionFollowerRequest),
AddTableFollower(AddTableFollowerRequest),
RemoveTableFollower(RemoveTableFollowerRequest),
}
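A brief sketch of constructing the new unified request (catalog, schema, table, and id values are placeholders):

```rust
// Sketch: wrap a table-level follower operation in the unified enum. The single
// `manage_region_follower` entry point dispatches on the variant, replacing the
// former `add_region_follower` / `remove_region_follower` pair.
fn example_request() -> ManageRegionFollowerRequest {
    ManageRegionFollowerRequest::AddTableFollower(AddTableFollowerRequest {
        catalog_name: "greptime".to_string(),
        schema_name: "public".to_string(),
        table_name: "my_table".to_string(),
        table_id: 1024,
    })
}
```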
/// A request to remove region follower.
#[derive(Debug, Clone)]
pub struct RemoveRegionFollowerRequest {

View File

@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::path::PathBuf;
use std::sync::Arc;
use api::region::RegionResponse;
@@ -299,3 +300,39 @@ macro_rules! maybe_skip_postgres15_integration_test {
}
};
}
#[macro_export]
/// Skip the test if the environment variable `GT_ETCD_TLS_ENDPOINTS` is not set.
///
/// The format of the environment variable is:
/// ```text
/// GT_ETCD_TLS_ENDPOINTS=localhost:9092,localhost:9093
/// ```
macro_rules! maybe_skip_etcd_tls_integration_test {
() => {
if std::env::var("GT_ETCD_TLS_ENDPOINTS").is_err() {
common_telemetry::warn!("The etcd with tls endpoints is empty, skipping the test");
return;
}
};
}
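As a quick illustration (a sketch only; endpoints are whatever the environment provides), a TLS-dependent test would invoke the macro first and return early when the variable is absent:

```rust
#[tokio::test]
async fn test_requires_tls_etcd() {
    // Skips the test unless GT_ETCD_TLS_ENDPOINTS is set,
    // e.g. GT_ETCD_TLS_ENDPOINTS=localhost:9092,localhost:9093.
    maybe_skip_etcd_tls_integration_test!();
    let endpoints: Vec<String> = std::env::var("GT_ETCD_TLS_ENDPOINTS")
        .unwrap()
        .split(',')
        .map(|s| s.to_string())
        .collect();
    assert!(!endpoints.is_empty());
}
```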
/// Returns the directory of the etcd TLS certs.
pub fn etcd_certs_dir() -> PathBuf {
let project_path = env!("CARGO_MANIFEST_DIR");
let project_path = PathBuf::from(project_path);
let base = project_path.ancestors().nth(3).unwrap();
base.join("tests-integration")
.join("fixtures")
.join("etcd-tls-certs")
}
/// Returns the directory of the test certs.
pub fn test_certs_dir() -> PathBuf {
let project_path = env!("CARGO_MANIFEST_DIR");
let project_path = PathBuf::from(project_path);
let base = project_path.ancestors().nth(3).unwrap();
base.join("tests-integration")
.join("fixtures")
.join("certs")
}

View File

@@ -17,7 +17,7 @@ use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{Helper, VectorRef};
use snafu::ResultExt;
use crate::error::{self, GeneralDataFusionSnafu, IntoVectorSnafu, Result};
use crate::error::{self, IntoVectorSnafu, Result};
use crate::prelude::ScalarValue;
/// Represents the result from an expression
@@ -43,9 +43,7 @@ impl ColumnarValue {
Ok(match self {
ColumnarValue::Vector(v) => v,
ColumnarValue::Scalar(s) => {
let v = s
.to_array_of_size(num_rows)
.context(GeneralDataFusionSnafu)?;
let v = s.to_array_of_size(num_rows)?;
let data_type = v.data_type().clone();
Helper::try_into_vector(v).context(IntoVectorSnafu { data_type })?
}

View File

@@ -78,7 +78,7 @@ pub enum Error {
location: Location,
},
#[snafu(display("General DataFusion error"))]
#[snafu(transparent)]
GeneralDataFusion {
#[snafu(source)]
error: DataFusionError,

View File

@@ -24,9 +24,8 @@ use datafusion_common::{Column, TableReference};
use datafusion_expr::dml::InsertOp;
use datafusion_expr::{DmlStatement, TableSource, WriteOp, col};
pub use expr::{build_filter_from_timestamp, build_same_type_ts_filter};
use snafu::ResultExt;
use crate::error::{GeneralDataFusionSnafu, Result};
use crate::error::Result;
/// Rename columns by applying a new projection. Returns an error if the column to be
/// renamed does not exist. The `renames` parameter is a `Vector` with elements
@@ -122,7 +121,7 @@ pub fn add_insert_to_logical_plan(
WriteOp::Insert(InsertOp::Append),
Arc::new(input),
));
let plan = plan.recompute_schema().context(GeneralDataFusionSnafu)?;
let plan = plan.recompute_schema()?;
Ok(plan)
}

View File

@@ -173,6 +173,9 @@ mod tests {
#[test]
fn test_from_tz_string() {
unsafe {
std::env::remove_var("TZ");
}
assert_eq!(
Timezone::Named(Tz::UTC),
Timezone::from_tz_string("SYSTEM").unwrap()

View File

@@ -72,7 +72,7 @@ impl RegionServer {
})?
};
let entries = mito.all_ssts_from_manifest().collect::<Vec<_>>();
let entries = mito.all_ssts_from_manifest().await;
let schema = ManifestSstEntry::schema().arrow_schema().clone();
let batch = ManifestSstEntry::to_record_batch(&entries)
.map_err(DataFusionError::from)

View File

@@ -252,7 +252,7 @@ impl RegionEngine for MockRegionEngine {
unimplemented!()
}
async fn get_last_seq_num(&self, _: RegionId) -> Result<Option<SequenceNumber>, BoxedError> {
async fn get_committed_sequence(&self, _: RegionId) -> Result<SequenceNumber, BoxedError> {
unimplemented!()
}

View File

@@ -115,8 +115,8 @@ impl RegionEngine for FileRegionEngine {
None
}
async fn get_last_seq_num(&self, _: RegionId) -> Result<Option<SequenceNumber>, BoxedError> {
Ok(None)
async fn get_committed_sequence(&self, _: RegionId) -> Result<SequenceNumber, BoxedError> {
Ok(Default::default())
}
fn set_region_role(&self, region_id: RegionId, role: RegionRole) -> Result<(), BoxedError> {

View File

@@ -376,34 +376,16 @@ impl Instance {
ctx: QueryContextRef,
) -> server_error::Result<bool> {
let db_string = ctx.get_db_string();
// fast cache check
let cache = self
.otlp_metrics_table_legacy_cache
.entry(db_string)
.entry(db_string.clone())
.or_default();
// check cache
let hit_cache = names
.iter()
.filter_map(|name| cache.get(*name))
.collect::<Vec<_>>();
if !hit_cache.is_empty() {
let hit_legacy = hit_cache.iter().any(|en| *en.value());
let hit_prom = hit_cache.iter().any(|en| !*en.value());
// hit but have true and false, means both legacy and new mode are used
// we cannot handle this case, so return error
// add doc links in err msg later
ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu);
let flag = hit_legacy;
// set cache for all names
names.iter().for_each(|name| {
if !cache.contains_key(*name) {
cache.insert(name.to_string(), flag);
}
});
if let Some(flag) = fast_legacy_check(&cache, names)? {
return Ok(flag);
}
// release cache reference to avoid lock contention
drop(cache);
let catalog = ctx.current_catalog();
let schema = ctx.current_schema();
@@ -430,7 +412,10 @@ impl Instance {
// means no existing table is found, use new mode
if table_ids.is_empty() {
// set cache
let cache = self
.otlp_metrics_table_legacy_cache
.entry(db_string)
.or_default();
names.iter().for_each(|name| {
cache.insert(name.to_string(), false);
});
@@ -455,6 +440,10 @@ impl Instance {
.unwrap_or(&OTLP_LEGACY_DEFAULT_VALUE)
})
.collect::<Vec<_>>();
let cache = self
.otlp_metrics_table_legacy_cache
.entry(db_string)
.or_default();
if !options.is_empty() {
// check value consistency
let has_prom = options.iter().any(|opt| *opt == OTLP_METRIC_COMPAT_PROM);
@@ -477,6 +466,39 @@ impl Instance {
}
}
fn fast_legacy_check(
cache: &DashMap<String, bool>,
names: &[&String],
) -> server_error::Result<Option<bool>> {
let hit_cache = names
.iter()
.filter_map(|name| cache.get(*name))
.collect::<Vec<_>>();
if !hit_cache.is_empty() {
let hit_legacy = hit_cache.iter().any(|en| *en.value());
let hit_prom = hit_cache.iter().any(|en| !*en.value());
// hits contain both true and false, meaning both legacy and new modes are in use;
// we cannot handle this case, so return an error
// TODO: add doc links in err msg later
ensure!(!(hit_legacy && hit_prom), OtlpMetricModeIncompatibleSnafu);
let flag = hit_legacy;
// drop hit_cache to release references before inserting to avoid deadlock
drop(hit_cache);
// set cache for all names
names.iter().for_each(|name| {
if !cache.contains_key(*name) {
cache.insert(name.to_string(), flag);
}
});
Ok(Some(flag))
} else {
Ok(None)
}
}
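The guard handling above is the crux of the refactor: `DashMap::get` returns `Ref` guards that hold shard locks, so every hit must be dropped before inserting into the same map. A minimal, self-contained sketch of the pattern (assumes only the `dashmap` crate; names are illustrative):

```rust
use dashmap::DashMap;

// Sketch: copy a cached flag from one key to another without deadlocking.
// The `Ref` returned by `get` is released at the end of the first statement,
// so the subsequent insert cannot block on a shard lock we still hold.
fn copy_flag(cache: &DashMap<String, bool>, from: &str, to: &str) {
    let flag = cache.get(from).map(|entry| *entry.value());
    if let Some(flag) = flag {
        cache.entry(to.to_string()).or_insert(flag);
    }
}
```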
/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
/// For MySQL, it applies only to read-only statements.
fn derive_timeout(stmt: &Statement, query_ctx: &QueryContextRef) -> Option<Duration> {
@@ -1041,6 +1063,10 @@ fn should_capture_statement(stmt: Option<&Statement>) -> bool {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::{Duration, Instant};
use common_base::Plugins;
use query::query_engine::options::QueryOptions;
@@ -1050,6 +1076,122 @@ mod tests {
use super::*;
#[test]
fn test_fast_legacy_check_deadlock_prevention() {
// Create a DashMap to simulate the cache
let cache = DashMap::new();
// Pre-populate cache with some entries
cache.insert("metric1".to_string(), true); // legacy mode
cache.insert("metric2".to_string(), false); // prom mode
cache.insert("metric3".to_string(), true); // legacy mode
// Test case 1: Normal operation with cache hits
let metric1 = "metric1".to_string();
let metric4 = "metric4".to_string();
let names1 = vec![&metric1, &metric4];
let result = fast_legacy_check(&cache, &names1);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Some(true)); // should return legacy mode
// Verify that metric4 was added to cache
assert!(cache.contains_key("metric4"));
assert!(*cache.get("metric4").unwrap().value());
// Test case 2: No cache hits
let metric5 = "metric5".to_string();
let metric6 = "metric6".to_string();
let names2 = vec![&metric5, &metric6];
let result = fast_legacy_check(&cache, &names2);
assert!(result.is_ok());
assert_eq!(result.unwrap(), None); // should return None as no cache hits
// Test case 3: Incompatible modes should return error
let cache_incompatible = DashMap::new();
cache_incompatible.insert("metric1".to_string(), true); // legacy
cache_incompatible.insert("metric2".to_string(), false); // prom
let metric1_test = "metric1".to_string();
let metric2_test = "metric2".to_string();
let names3 = vec![&metric1_test, &metric2_test];
let result = fast_legacy_check(&cache_incompatible, &names3);
assert!(result.is_err()); // should error due to incompatible modes
// Test case 4: Intensive concurrent access to test deadlock prevention
// This test specifically targets the scenario where multiple threads
// access the same cache entries simultaneously
let cache_concurrent = Arc::new(DashMap::new());
cache_concurrent.insert("shared_metric".to_string(), true);
let num_threads = 8;
let operations_per_thread = 100;
let barrier = Arc::new(Barrier::new(num_threads));
let success_flag = Arc::new(AtomicBool::new(true));
let handles: Vec<_> = (0..num_threads)
.map(|thread_id| {
let cache_clone = Arc::clone(&cache_concurrent);
let barrier_clone = Arc::clone(&barrier);
let success_flag_clone = Arc::clone(&success_flag);
thread::spawn(move || {
// Wait for all threads to be ready
barrier_clone.wait();
let start_time = Instant::now();
for i in 0..operations_per_thread {
// Each operation references existing cache entry and adds new ones
let shared_metric = "shared_metric".to_string();
let new_metric = format!("thread_{}_metric_{}", thread_id, i);
let names = vec![&shared_metric, &new_metric];
match fast_legacy_check(&cache_clone, &names) {
Ok(_) => {}
Err(_) => {
success_flag_clone.store(false, Ordering::Relaxed);
return;
}
}
// If the test takes too long, it likely means deadlock
if start_time.elapsed() > Duration::from_secs(10) {
success_flag_clone.store(false, Ordering::Relaxed);
return;
}
}
})
})
.collect();
// Join all threads with timeout
let start_time = Instant::now();
for (i, handle) in handles.into_iter().enumerate() {
let join_result = handle.join();
// Check if we're taking too long (potential deadlock)
if start_time.elapsed() > Duration::from_secs(30) {
panic!("Test timed out - possible deadlock detected!");
}
if join_result.is_err() {
panic!("Thread {} panicked during execution", i);
}
}
// Verify all operations completed successfully
assert!(
success_flag.load(Ordering::Relaxed),
"Some operations failed"
);
// Verify that many new entries were added (proving operations completed)
let final_count = cache_concurrent.len();
assert!(
final_count > 1 + num_threads * operations_per_thread / 2,
"Expected more cache entries, got {}",
final_count
);
}
#[test]
fn test_exec_validation() {
let query_ctx = QueryContext::arc();

View File

@@ -43,9 +43,9 @@ use table::table::adapter::DfTableProviderAdapter;
use table::table_name::TableName;
use crate::error::{
CatalogSnafu, Error, ExternalSnafu, IncompleteGrpcRequestSnafu, NotSupportedSnafu,
PermissionSnafu, PlanStatementSnafu, Result, SubstraitDecodeLogicalPlanSnafu,
TableNotFoundSnafu, TableOperationSnafu,
CatalogSnafu, DataFusionSnafu, Error, ExternalSnafu, IncompleteGrpcRequestSnafu,
NotSupportedSnafu, PermissionSnafu, PlanStatementSnafu, Result,
SubstraitDecodeLogicalPlanSnafu, TableNotFoundSnafu, TableOperationSnafu,
};
use crate::instance::{Instance, attach_timer};
use crate::metrics::{
@@ -395,14 +395,10 @@ impl Instance {
let analyzed_plan = state
.analyzer()
.execute_and_check(insert_into, state.config_options(), |_, _| {})
.context(common_query::error::GeneralDataFusionSnafu)
.context(SubstraitDecodeLogicalPlanSnafu)?;
.context(DataFusionSnafu)?;
// Optimize the plan
let optimized_plan = state
.optimize(&analyzed_plan)
.context(common_query::error::GeneralDataFusionSnafu)
.context(SubstraitDecodeLogicalPlanSnafu)?;
let optimized_plan = state.optimize(&analyzed_plan).context(DataFusionSnafu)?;
let output = SqlQueryHandler::do_exec_plan(self, None, optimized_plan, ctx.clone()).await?;

View File

@@ -91,6 +91,21 @@ impl Filters {
Filters::Single(filter)
}
}
/// Aggregation function with arguments and an optional alias.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggFunc {
/// Function name, e.g., "count", "sum", etc.
pub name: String,
/// Arguments to the function, e.g. column references or literals such as `LogExpr::NamedIdent("column1".to_string())`.
pub args: Vec<LogExpr>,
pub alias: Option<String>,
}
impl AggFunc {
pub fn new(name: String, args: Vec<LogExpr>, alias: Option<String>) -> Self {
Self { name, args, alias }
}
}
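For illustration, a hedged sketch of building one aggregation in the new shape (function, column, and alias names are placeholders); the resulting value would be supplied through the `expr` field of `LogExpr::AggrFunc`:

```rust
// Sketch: `count(column1) AS cnt` expressed as an AggFunc.
fn example_agg() -> AggFunc {
    AggFunc::new(
        "count".to_string(),
        vec![LogExpr::NamedIdent("column1".to_string())],
        Some("cnt".to_string()),
    )
}
```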
/// Expression to calculate on log after filtering.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -103,13 +118,11 @@ pub enum LogExpr {
args: Vec<LogExpr>,
alias: Option<String>,
},
/// Aggregation function with optional grouping.
AggrFunc {
name: String,
args: Vec<LogExpr>,
/// Optional range function parameter. Stands for the time range for both step and align.
range: Option<String>,
/// Function name, arguments, and optional alias.
expr: Vec<AggFunc>,
by: Vec<LogExpr>,
alias: Option<String>,
},
Decompose {
expr: Box<LogExpr>,

View File

@@ -44,8 +44,9 @@ use common_meta::range_stream::PaginationStream;
use common_meta::rpc::KeyValue;
use common_meta::rpc::ddl::{SubmitDdlTaskRequest, SubmitDdlTaskResponse};
use common_meta::rpc::procedure::{
AddRegionFollowerRequest, MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse,
RemoveRegionFollowerRequest,
AddRegionFollowerRequest, AddTableFollowerRequest, ManageRegionFollowerRequest,
MigrateRegionRequest, MigrateRegionResponse, ProcedureStateResponse,
RemoveRegionFollowerRequest, RemoveTableFollowerRequest,
};
use common_meta::rpc::store::{
BatchDeleteRequest, BatchDeleteResponse, BatchGetRequest, BatchGetResponse, BatchPutRequest,
@@ -246,6 +247,10 @@ pub trait RegionFollowerClient: Sync + Send + Debug {
async fn remove_region_follower(&self, request: RemoveRegionFollowerRequest) -> Result<()>;
async fn add_table_follower(&self, request: AddTableFollowerRequest) -> Result<()>;
async fn remove_table_follower(&self, request: RemoveTableFollowerRequest) -> Result<()>;
async fn start(&self, urls: &[&str]) -> Result<()>;
async fn start_with(&self, leader_provider: LeaderProviderRef) -> Result<()>;
@@ -286,39 +291,41 @@ impl ProcedureExecutor for MetaClient {
.context(meta_error::ExternalSnafu)
}
async fn add_region_follower(
async fn manage_region_follower(
&self,
_ctx: &ExecutorContext,
request: AddRegionFollowerRequest,
request: ManageRegionFollowerRequest,
) -> MetaResult<()> {
if let Some(region_follower) = &self.region_follower {
region_follower
.add_region_follower(request)
.await
.map_err(BoxedError::new)
.context(meta_error::ExternalSnafu)
} else {
UnsupportedSnafu {
operation: "add_region_follower",
match request {
ManageRegionFollowerRequest::AddRegionFollower(add_region_follower_request) => {
region_follower
.add_region_follower(add_region_follower_request)
.await
}
ManageRegionFollowerRequest::RemoveRegionFollower(
remove_region_follower_request,
) => {
region_follower
.remove_region_follower(remove_region_follower_request)
.await
}
ManageRegionFollowerRequest::AddTableFollower(add_table_follower_request) => {
region_follower
.add_table_follower(add_table_follower_request)
.await
}
ManageRegionFollowerRequest::RemoveTableFollower(remove_table_follower_request) => {
region_follower
.remove_table_follower(remove_table_follower_request)
.await
}
}
.fail()
}
}
async fn remove_region_follower(
&self,
_ctx: &ExecutorContext,
request: RemoveRegionFollowerRequest,
) -> MetaResult<()> {
if let Some(region_follower) = &self.region_follower {
region_follower
.remove_region_follower(request)
.await
.map_err(BoxedError::new)
.context(meta_error::ExternalSnafu)
.map_err(BoxedError::new)
.context(meta_error::ExternalSnafu)
} else {
UnsupportedSnafu {
operation: "remove_region_follower",
operation: "manage_region_follower",
}
.fail()
}

View File

@@ -21,45 +21,23 @@ use api::v1::meta::procedure_service_server::ProcedureServiceServer;
use api::v1::meta::store_server::StoreServer;
use common_base::Plugins;
use common_config::Configurable;
#[cfg(feature = "pg_kvbackend")]
use common_error::ext::BoxedError;
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
use common_meta::distributed_time_constants::META_LEASE_SECS;
use common_meta::kv_backend::chroot::ChrootKvBackend;
use common_meta::kv_backend::etcd::EtcdStore;
use common_meta::kv_backend::memory::MemoryKvBackend;
#[cfg(feature = "mysql_kvbackend")]
use common_meta::kv_backend::rds::MySqlStore;
#[cfg(feature = "pg_kvbackend")]
use common_meta::kv_backend::rds::PgStore;
#[cfg(feature = "pg_kvbackend")]
use common_meta::kv_backend::rds::postgres::create_postgres_tls_connector;
#[cfg(feature = "pg_kvbackend")]
use common_meta::kv_backend::rds::postgres::{TlsMode as PgTlsMode, TlsOption as PgTlsOption};
use common_meta::kv_backend::{KvBackendRef, ResettableKvBackendRef};
use common_telemetry::info;
#[cfg(feature = "pg_kvbackend")]
use deadpool_postgres::{Config, Runtime};
use either::Either;
use etcd_client::{Client, ConnectOptions};
use servers::configurator::ConfiguratorRef;
use servers::export_metrics::ExportMetricsTask;
use servers::http::{HttpServer, HttpServerBuilder};
use servers::metrics_handler::MetricsHandler;
use servers::server::Server;
use servers::tls::TlsOption;
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
use snafu::OptionExt;
use snafu::ResultExt;
#[cfg(feature = "mysql_kvbackend")]
use sqlx::mysql::MySqlConnectOptions;
#[cfg(feature = "mysql_kvbackend")]
use sqlx::mysql::MySqlPool;
use tokio::net::TcpListener;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::{Mutex, oneshot};
#[cfg(feature = "pg_kvbackend")]
use tokio_postgres::NoTls;
use tonic::codec::CompressionEncoding;
use tonic::transport::server::{Router, TcpIncoming};
@@ -67,10 +45,6 @@ use crate::cluster::{MetaPeerClientBuilder, MetaPeerClientRef};
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
use crate::election::CANDIDATE_LEASE_SECS;
use crate::election::etcd::EtcdElection;
#[cfg(feature = "mysql_kvbackend")]
use crate::election::rds::mysql::MySqlElection;
#[cfg(feature = "pg_kvbackend")]
use crate::election::rds::postgres::PgElection;
use crate::metasrv::builder::MetasrvBuilder;
use crate::metasrv::{
BackendImpl, ElectionRef, Metasrv, MetasrvOptions, SelectTarget, SelectorRef,
@@ -82,6 +56,7 @@ use crate::selector::round_robin::RoundRobinSelector;
use crate::selector::weight_compute::RegionNumsBasedWeightCompute;
use crate::service::admin;
use crate::service::admin::admin_axum_router;
use crate::utils::etcd::create_etcd_client_with_tls;
use crate::{Result, error};
pub struct MetasrvInstance {
@@ -306,8 +281,11 @@ pub async fn metasrv_builder(
use std::time::Duration;
use common_meta::distributed_time_constants::POSTGRES_KEEP_ALIVE_SECS;
use common_meta::kv_backend::rds::PgStore;
use deadpool_postgres::Config;
use crate::election::rds::postgres::ElectionPgClient;
use crate::election::rds::postgres::{ElectionPgClient, PgElection};
use crate::utils::postgres::create_postgres_pool;
let candidate_lease_ttl = Duration::from_secs(CANDIDATE_LEASE_SECS);
let execution_timeout = Duration::from_secs(META_LEASE_SECS);
@@ -319,8 +297,8 @@ pub async fn metasrv_builder(
cfg.keepalives = Some(true);
cfg.keepalives_idle = Some(Duration::from_secs(POSTGRES_KEEP_ALIVE_SECS));
// We use a separate pool for election since we need a different session keep-alive idle time.
let pool =
create_postgres_pool_with(&opts.store_addrs, cfg, opts.backend_tls.clone()).await?;
let pool = create_postgres_pool(&opts.store_addrs, Some(cfg), opts.backend_tls.clone())
.await?;
let election_client = ElectionPgClient::new(
pool,
@@ -340,7 +318,8 @@ pub async fn metasrv_builder(
)
.await?;
let pool = create_postgres_pool(&opts.store_addrs, opts.backend_tls.clone()).await?;
let pool =
create_postgres_pool(&opts.store_addrs, None, opts.backend_tls.clone()).await?;
let kv_backend = PgStore::with_pg_pool(
pool,
opts.meta_schema_name.as_deref(),
@@ -356,9 +335,12 @@ pub async fn metasrv_builder(
(None, BackendImpl::MysqlStore) => {
use std::time::Duration;
use crate::election::rds::mysql::ElectionMysqlClient;
use common_meta::kv_backend::rds::MySqlStore;
let pool = create_mysql_pool(&opts.store_addrs).await?;
use crate::election::rds::mysql::{ElectionMysqlClient, MySqlElection};
use crate::utils::mysql::create_mysql_pool;
let pool = create_mysql_pool(&opts.store_addrs, opts.backend_tls.as_ref()).await?;
let kv_backend =
MySqlStore::with_mysql_pool(pool, &opts.meta_table_name, opts.max_txn_ops)
.await
@@ -366,7 +348,7 @@ pub async fn metasrv_builder(
// Since election will acquire a lock of the table, we need a separate table for election.
let election_table_name = opts.meta_table_name.clone() + "_election";
// We use a separate pool for election since we need a different session keep-alive idle time.
let pool = create_mysql_pool(&opts.store_addrs).await?;
let pool = create_mysql_pool(&opts.store_addrs, opts.backend_tls.as_ref()).await?;
let execution_timeout = Duration::from_secs(META_LEASE_SECS);
let statement_timeout = Duration::from_secs(META_LEASE_SECS);
let idle_session_timeout = Duration::from_secs(META_LEASE_SECS);
@@ -452,259 +434,3 @@ pub(crate) fn build_default_meta_peer_client(
// Safety: all required fields set at initialization
.unwrap()
}
pub async fn create_etcd_client(store_addrs: &[String]) -> Result<Client> {
create_etcd_client_with_tls(store_addrs, None).await
}
fn build_connection_options(tls_config: Option<&TlsOption>) -> Result<Option<ConnectOptions>> {
use std::fs;
use common_telemetry::debug;
use etcd_client::{Certificate, ConnectOptions, Identity, TlsOptions};
use servers::tls::TlsMode;
// If TLS options are not provided, return None
let Some(tls_config) = tls_config else {
return Ok(None);
};
// If TLS is disabled, return None
if matches!(tls_config.mode, TlsMode::Disable) {
return Ok(None);
}
let mut etcd_tls_opts = TlsOptions::new();
// Set CA certificate if provided
if !tls_config.ca_cert_path.is_empty() {
debug!("Using CA certificate from {}", tls_config.ca_cert_path);
let ca_cert_pem = fs::read(&tls_config.ca_cert_path).context(error::FileIoSnafu {
path: &tls_config.ca_cert_path,
})?;
let ca_cert = Certificate::from_pem(ca_cert_pem);
etcd_tls_opts = etcd_tls_opts.ca_certificate(ca_cert);
}
// Set client identity (cert + key) if both are provided
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
debug!(
"Using client certificate from {} and key from {}",
tls_config.cert_path, tls_config.key_path
);
let cert_pem = fs::read(&tls_config.cert_path).context(error::FileIoSnafu {
path: &tls_config.cert_path,
})?;
let key_pem = fs::read(&tls_config.key_path).context(error::FileIoSnafu {
path: &tls_config.key_path,
})?;
let identity = Identity::from_pem(cert_pem, key_pem);
etcd_tls_opts = etcd_tls_opts.identity(identity);
}
// Enable native TLS roots for additional trust anchors
etcd_tls_opts = etcd_tls_opts.with_native_roots();
Ok(Some(ConnectOptions::new().with_tls(etcd_tls_opts)))
}
pub async fn create_etcd_client_with_tls(
store_addrs: &[String],
tls_config: Option<&TlsOption>,
) -> Result<Client> {
let etcd_endpoints = store_addrs
.iter()
.map(|x| x.trim())
.filter(|x| !x.is_empty())
.collect::<Vec<_>>();
let connect_options = build_connection_options(tls_config)?;
Client::connect(&etcd_endpoints, connect_options)
.await
.context(error::ConnectEtcdSnafu)
}
#[cfg(feature = "pg_kvbackend")]
/// Converts servers::tls::TlsOption to postgres::TlsOption to avoid circular dependencies
fn convert_tls_option(tls_option: &TlsOption) -> PgTlsOption {
let mode = match tls_option.mode {
servers::tls::TlsMode::Disable => PgTlsMode::Disable,
servers::tls::TlsMode::Prefer => PgTlsMode::Prefer,
servers::tls::TlsMode::Require => PgTlsMode::Require,
servers::tls::TlsMode::VerifyCa => PgTlsMode::VerifyCa,
servers::tls::TlsMode::VerifyFull => PgTlsMode::VerifyFull,
};
PgTlsOption {
mode,
cert_path: tls_option.cert_path.clone(),
key_path: tls_option.key_path.clone(),
ca_cert_path: tls_option.ca_cert_path.clone(),
watch: tls_option.watch,
}
}
#[cfg(feature = "pg_kvbackend")]
/// Creates a pool for the Postgres backend with optional TLS.
///
/// It only use first store addr to create a pool.
pub async fn create_postgres_pool(
store_addrs: &[String],
tls_config: Option<TlsOption>,
) -> Result<deadpool_postgres::Pool> {
create_postgres_pool_with(store_addrs, Config::new(), tls_config).await
}
#[cfg(feature = "pg_kvbackend")]
/// Creates a pool for the Postgres backend with config and optional TLS.
///
/// It only use first store addr to create a pool, and use the given config to create a pool.
pub async fn create_postgres_pool_with(
store_addrs: &[String],
mut cfg: Config,
tls_config: Option<TlsOption>,
) -> Result<deadpool_postgres::Pool> {
let postgres_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
err_msg: "empty store addrs",
})?;
cfg.url = Some(postgres_url.to_string());
let pool = if let Some(tls_config) = tls_config {
let pg_tls_config = convert_tls_option(&tls_config);
let tls_connector =
create_postgres_tls_connector(&pg_tls_config).map_err(|e| error::Error::Other {
source: BoxedError::new(e),
location: snafu::Location::new(file!(), line!(), 0),
})?;
cfg.create_pool(Some(Runtime::Tokio1), tls_connector)
.context(error::CreatePostgresPoolSnafu)?
} else {
cfg.create_pool(Some(Runtime::Tokio1), NoTls)
.context(error::CreatePostgresPoolSnafu)?
};
Ok(pool)
}
#[cfg(feature = "mysql_kvbackend")]
async fn setup_mysql_options(store_addrs: &[String]) -> Result<MySqlConnectOptions> {
let mysql_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
err_msg: "empty store addrs",
})?;
// Avoid `SET` commands in sqlx
let opts: MySqlConnectOptions = mysql_url
.parse()
.context(error::ParseMySqlUrlSnafu { mysql_url })?;
let opts = opts
.no_engine_substitution(false)
.pipes_as_concat(false)
.timezone(None)
.set_names(false);
Ok(opts)
}
#[cfg(feature = "mysql_kvbackend")]
pub async fn create_mysql_pool(store_addrs: &[String]) -> Result<MySqlPool> {
let opts = setup_mysql_options(store_addrs).await?;
let pool = MySqlPool::connect_with(opts)
.await
.context(error::CreateMySqlPoolSnafu)?;
Ok(pool)
}
#[cfg(test)]
mod tests {
use servers::tls::TlsMode;
use super::*;
#[tokio::test]
async fn test_create_etcd_client_tls_without_certs() {
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
Ok(endpoints_str) => endpoints_str
.split(',')
.map(|s| s.trim().to_string())
.collect(),
Err(_) => return,
};
let tls_config = TlsOption {
mode: TlsMode::Require,
ca_cert_path: String::new(),
cert_path: String::new(),
key_path: String::new(),
watch: false,
};
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
.await
.unwrap();
}
#[tokio::test]
async fn test_create_etcd_client_tls_with_client_certs() {
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
Ok(endpoints_str) => endpoints_str
.split(',')
.map(|s| s.trim().to_string())
.collect(),
Err(_) => return,
};
let cert_dir = std::env::current_dir()
.unwrap()
.join("tests-integration")
.join("fixtures")
.join("etcd-tls-certs");
if cert_dir.join("client.crt").exists() && cert_dir.join("client-key.pem").exists() {
let tls_config = TlsOption {
mode: TlsMode::Require,
ca_cert_path: String::new(),
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
key_path: cert_dir
.join("client-key.pem")
.to_string_lossy()
.to_string(),
watch: false,
};
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
.await
.unwrap();
}
}
#[tokio::test]
async fn test_create_etcd_client_tls_with_full_certs() {
let endpoints: Vec<String> = match std::env::var("GT_ETCD_TLS_ENDPOINTS") {
Ok(endpoints_str) => endpoints_str
.split(',')
.map(|s| s.trim().to_string())
.collect(),
Err(_) => return,
};
let cert_dir = std::env::current_dir()
.unwrap()
.join("tests-integration")
.join("fixtures")
.join("etcd-tls-certs");
if cert_dir.join("ca.crt").exists()
&& cert_dir.join("client.crt").exists()
&& cert_dir.join("client-key.pem").exists()
{
let tls_config = TlsOption {
mode: TlsMode::Require,
ca_cert_path: cert_dir.join("ca.crt").to_string_lossy().to_string(),
cert_path: cert_dir.join("client.crt").to_string_lossy().to_string(),
key_path: cert_dir
.join("client-key.pem")
.to_string_lossy()
.to_string(),
watch: false,
};
let _client = create_etcd_client_with_tls(&endpoints, Some(&tls_config))
.await
.unwrap();
}
}
}

View File

@@ -107,7 +107,7 @@ mod tests {
use common_time::util::current_time_millis;
use common_workload::DatanodeWorkloadType;
use crate::discovery::utils::{self, is_datanode_accept_ingest_workload};
use crate::discovery::utils::{self, accept_ingest_workload};
use crate::key::{DatanodeLeaseKey, LeaseValue};
use crate::test_util::create_meta_peer_client;
@@ -219,7 +219,7 @@ mod tests {
let peers = utils::alive_datanodes(
client.as_ref(),
Duration::from_secs(lease_secs),
Some(is_datanode_accept_ingest_workload),
Some(accept_ingest_workload),
)
.await
.unwrap();

View File

@@ -144,19 +144,22 @@ pub async fn alive_datanode(
Ok(v)
}
/// Returns true if the datanode can accept ingest workload based on its workload types.
/// Determines if a datanode is capable of accepting ingest workloads.
/// Returns `true` if the datanode's workload types include ingest capability,
/// or if the node is not of type [NodeWorkloads::Datanode].
///
/// A datanode is considered to accept ingest workload if it supports either:
/// - Hybrid workload (both ingest and query workloads)
/// - Ingest workload (only ingest workload)
pub fn is_datanode_accept_ingest_workload(datanode_workloads: &NodeWorkloads) -> bool {
pub fn accept_ingest_workload(datanode_workloads: &NodeWorkloads) -> bool {
match &datanode_workloads {
NodeWorkloads::Datanode(workloads) => workloads
.types
.iter()
.filter_map(|w| DatanodeWorkloadType::from_i32(*w))
.any(|w| w.accept_ingest()),
_ => false,
// If the [NodeWorkloads] type is not [NodeWorkloads::Datanode], returns true.
_ => true,
}
}

View File

@@ -984,8 +984,8 @@ mod tests {
use common_telemetry::init_default_ut_logging;
use super::*;
use crate::bootstrap::create_mysql_pool;
use crate::error;
use crate::utils::mysql::create_mysql_pool;
async fn create_mysql_client(
table_name: Option<&str>,
@@ -1000,7 +1000,7 @@ mod tests {
}
.fail();
}
let pool = create_mysql_pool(&[endpoint]).await.unwrap();
let pool = create_mysql_pool(&[endpoint], None).await.unwrap();
let mut client = ElectionMysqlClient::new(
pool,
execution_timeout,

View File

@@ -826,8 +826,8 @@ mod tests {
use common_meta::maybe_skip_postgres_integration_test;
use super::*;
use crate::bootstrap::create_postgres_pool;
use crate::error;
use crate::utils::postgres::create_postgres_pool;
async fn create_postgres_client(
table_name: Option<&str>,
@@ -842,7 +842,7 @@ mod tests {
}
.fail();
}
let pool = create_postgres_pool(&[endpoint], None).await.unwrap();
let pool = create_postgres_pool(&[endpoint], None, None).await.unwrap();
let mut pg_client = ElectionPgClient::new(
pool,
execution_timeout,

View File

@@ -981,6 +981,14 @@ pub enum Error {
#[snafu(source)]
source: common_meta::error::Error,
},
#[snafu(display("Failed to build tls options"))]
BuildTlsOptions {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
source: common_meta::error::Error,
},
}
impl Error {
@@ -1116,6 +1124,7 @@ impl ErrorExt for Error {
| Error::InitDdlManager { source, .. }
| Error::InitReconciliationManager { source, .. } => source.status_code(),
Error::BuildTlsOptions { source, .. } => source.status_code(),
Error::Other { source, .. } => source.status_code(),
Error::NoEnoughAvailableNode { .. } => StatusCode::RuntimeResourcesExhausted,

View File

@@ -20,6 +20,7 @@ use common_meta::error::{ExternalSnafu, Result as MetaResult};
use common_meta::peer::{Peer, PeerAllocator};
use snafu::{ResultExt, ensure};
use crate::discovery::utils::accept_ingest_workload;
use crate::error::{Result, TooManyPartitionsSnafu};
use crate::metasrv::{SelectorContext, SelectorRef};
use crate::selector::SelectorOptions;
@@ -69,6 +70,7 @@ impl MetasrvPeerAllocator {
min_required_items,
allow_duplication: true,
exclude_peer_ids: HashSet::new(),
workload_filter: Some(accept_ingest_workload),
},
)
.await

View File

@@ -261,12 +261,8 @@ impl WalPruneManager {
Err(error::Error::PruneTaskAlreadyRunning { topic, .. }) => {
warn!("Prune task for topic {} is already running", topic);
}
Err(e) => {
error!(
"Failed to submit prune task for topic {}: {}",
topic_name.clone(),
e
);
Err(err) => {
error!(err; "Failed to prune remote WAL for topic {}", topic_name.as_str());
}
}
});

View File

@@ -40,6 +40,7 @@ use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::oneshot;
use tokio::time::{MissedTickBehavior, interval, interval_at};
use crate::discovery::utils::accept_ingest_workload;
use crate::error::{self, Result};
use crate::failure_detector::PhiAccrualFailureDetectorOptions;
use crate::metasrv::{RegionStatAwareSelectorRef, SelectTarget, SelectorContext, SelectorRef};
@@ -584,6 +585,7 @@ impl RegionSupervisor {
min_required_items: regions.len(),
allow_duplication: true,
exclude_peer_ids,
workload_filter: Some(accept_ingest_workload),
};
let peers = selector.select(&self.selector_context, opt).await?;
ensure!(

View File

@@ -22,6 +22,7 @@ pub mod weight_compute;
pub mod weighted_choose;
use std::collections::HashSet;
use api::v1::meta::heartbeat_request::NodeWorkloads;
use serde::{Deserialize, Serialize};
use store_api::storage::RegionId;
use strum::AsRefStr;
@@ -63,6 +64,8 @@ pub struct SelectorOptions {
pub allow_duplication: bool,
/// The peers to exclude from the selection.
pub exclude_peer_ids: HashSet<u64>,
/// The filter to select the peers based on their workloads.
pub workload_filter: Option<fn(&NodeWorkloads) -> bool>,
}
impl Default for SelectorOptions {
@@ -71,6 +74,7 @@ impl Default for SelectorOptions {
min_required_items: 1,
allow_duplication: false,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
}
}
}
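
A self-contained sketch of the pattern chosen here: the filter is a plain `fn` pointer held in an `Option`, so the options struct stays cheap to clone and `None` preserves the previous unfiltered behavior. `Options<T>` and `select` below are generic stand-ins, not the crate's `SelectorOptions` or selectors.

```rust
// Stand-in types illustrating an optional function-pointer filter on selector options.
struct Options<T> {
    min_required_items: usize,
    // `None` means "no workload filtering", matching the new default.
    workload_filter: Option<fn(&T) -> bool>,
}

fn select<'a, T>(candidates: &'a [T], opts: &Options<T>) -> Vec<&'a T> {
    candidates
        .iter()
        .filter(|&c| opts.workload_filter.map_or(true, |f| f(c)))
        .take(opts.min_required_items)
        .collect()
}

fn main() {
    let nodes = [1, 2, 3, 4];
    let opts = Options {
        min_required_items: 2,
        workload_filter: Some(|n: &i32| n % 2 == 0),
    };
    // Only "workloads" passing the filter are eligible for selection.
    assert_eq!(select(&nodes, &opts), vec![&2, &4]);
}
```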

View File

@@ -139,6 +139,7 @@ mod tests {
min_required_items: i,
allow_duplication: false,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
};
let selected_peers: HashSet<_> =
@@ -154,6 +155,7 @@ mod tests {
min_required_items: 6,
allow_duplication: false,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
};
let selected_result =
@@ -165,6 +167,7 @@ mod tests {
min_required_items: i,
allow_duplication: true,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
};
let selected_peers =

View File

@@ -15,7 +15,6 @@
use common_meta::peer::Peer;
use snafu::ResultExt;
use crate::discovery::utils::is_datanode_accept_ingest_workload;
use crate::error::{ListActiveDatanodesSnafu, Result};
use crate::metasrv::SelectorContext;
use crate::selector::common::{choose_items, filter_out_excluded_peers};
@@ -35,7 +34,7 @@ impl Selector for LeaseBasedSelector {
// 1. get alive datanodes.
let alive_datanodes = ctx
.peer_discovery
.active_datanodes(Some(is_datanode_accept_ingest_workload))
.active_datanodes(opts.workload_filter)
.await
.context(ListActiveDatanodesSnafu)?;

View File

@@ -20,7 +20,6 @@ use common_telemetry::debug;
use snafu::ResultExt;
use crate::cluster::MetaPeerClientRef;
use crate::discovery::utils::is_datanode_accept_ingest_workload;
use crate::error::{ListActiveDatanodesSnafu, Result};
use crate::metasrv::SelectorContext;
use crate::selector::common::{choose_items, filter_out_excluded_peers};
@@ -54,7 +53,7 @@ where
// 1. get alive datanodes.
let alive_datanodes = ctx
.peer_discovery
.active_datanodes(Some(is_datanode_accept_ingest_workload))
.active_datanodes(opts.workload_filter)
.await
.context(ListActiveDatanodesSnafu)?;

View File

@@ -17,7 +17,6 @@ use std::sync::atomic::AtomicUsize;
use common_meta::peer::Peer;
use snafu::{ResultExt, ensure};
use crate::discovery::utils::is_datanode_accept_ingest_workload;
use crate::error::{
ListActiveDatanodesSnafu, ListActiveFlownodesSnafu, NoEnoughAvailableNodeSnafu, Result,
};
@@ -59,7 +58,7 @@ impl RoundRobinSelector {
// 1. get alive datanodes.
let alive_datanodes = ctx
.peer_discovery
.active_datanodes(Some(is_datanode_accept_ingest_workload))
.active_datanodes(opts.workload_filter)
.await
.context(ListActiveDatanodesSnafu)?;
@@ -71,7 +70,7 @@ impl RoundRobinSelector {
}
SelectTarget::Flownode => ctx
.peer_discovery
.active_flownodes(None)
.active_flownodes(opts.workload_filter)
.await
.context(ListActiveFlownodesSnafu)?,
};
@@ -150,6 +149,7 @@ mod test {
min_required_items: 4,
allow_duplication: true,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
},
)
.await
@@ -167,6 +167,7 @@ mod test {
min_required_items: 2,
allow_duplication: true,
exclude_peer_ids: HashSet::new(),
workload_filter: None,
},
)
.await
@@ -208,6 +209,7 @@ mod test {
min_required_items: 1,
allow_duplication: true,
exclude_peer_ids: HashSet::from([2, 5]),
workload_filter: None,
},
)
.await

View File

@@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod etcd;
pub mod insert_forwarder;
#[cfg(feature = "mysql_kvbackend")]
pub mod mysql;
#[cfg(feature = "pg_kvbackend")]
pub mod postgres;
#[macro_export]
macro_rules! define_ticker {

View File

@@ -0,0 +1,56 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_meta::kv_backend::etcd::create_etcd_tls_options;
use etcd_client::{Client, ConnectOptions};
use servers::tls::{TlsMode, TlsOption};
use snafu::ResultExt;
use crate::error::{self, BuildTlsOptionsSnafu, Result};
/// Creates an etcd client with TLS configuration.
pub async fn create_etcd_client_with_tls(
store_addrs: &[String],
tls_config: Option<&TlsOption>,
) -> Result<Client> {
let etcd_endpoints = store_addrs
.iter()
.map(|x| x.trim())
.filter(|x| !x.is_empty())
.collect::<Vec<_>>();
let connect_options = tls_config
.map(|c| create_etcd_tls_options(&convert_tls_option(c)))
.transpose()
.context(BuildTlsOptionsSnafu)?
.flatten()
.map(|tls_options| ConnectOptions::new().with_tls(tls_options));
Client::connect(&etcd_endpoints, connect_options)
.await
.context(error::ConnectEtcdSnafu)
}
fn convert_tls_option(tls_option: &TlsOption) -> common_meta::kv_backend::etcd::TlsOption {
let mode = match tls_option.mode {
TlsMode::Disable => common_meta::kv_backend::etcd::TlsMode::Disable,
_ => common_meta::kv_backend::etcd::TlsMode::Require,
};
common_meta::kv_backend::etcd::TlsOption {
mode,
cert_path: tls_option.cert_path.clone(),
key_path: tls_option.key_path.clone(),
ca_cert_path: tls_option.ca_cert_path.clone(),
}
}
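
A hedged usage sketch of the new helper. The module paths follow the `crate::utils::*` layout used by the mysql/postgres helpers in this change, the endpoint and certificate paths are placeholders, and the construction assumes `TlsOption` exposes these public fields plus a `Default` impl.

```rust
use crate::error::Result;
use crate::utils::etcd::create_etcd_client_with_tls;
use servers::tls::{TlsMode, TlsOption};

async fn connect_with_tls() -> Result<etcd_client::Client> {
    // Placeholder certificate paths; assumes `TlsOption` implements `Default`.
    let tls_config = TlsOption {
        mode: TlsMode::Require,
        ca_cert_path: "/etc/greptimedb/tls/ca.crt".to_string(),
        cert_path: "/etc/greptimedb/tls/client.crt".to_string(),
        key_path: "/etc/greptimedb/tls/client.key".to_string(),
        ..TlsOption::default()
    };
    let endpoints = ["https://127.0.0.1:2379".to_string()];
    // `None` instead of `Some(&tls_config)` keeps a plain, non-TLS connection.
    create_etcd_client_with_tls(&endpoints, Some(&tls_config)).await
}
```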

View File

@@ -0,0 +1,85 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_telemetry::info;
use servers::tls::{TlsMode, TlsOption};
use snafu::{OptionExt, ResultExt};
use sqlx::mysql::{MySqlConnectOptions, MySqlPool, MySqlSslMode};
use crate::error::{self, Result};
async fn setup_mysql_options(
store_addrs: &[String],
tls_config: Option<&TlsOption>,
) -> Result<MySqlConnectOptions> {
let mysql_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
err_msg: "empty store addrs",
})?;
// Avoid `SET` commands in sqlx
let opts: MySqlConnectOptions = mysql_url
.parse()
.context(error::ParseMySqlUrlSnafu { mysql_url })?;
let mut opts = opts
.no_engine_substitution(false)
.pipes_as_concat(false)
.timezone(None)
.set_names(false);
let Some(tls_config) = tls_config else {
return Ok(opts);
};
match tls_config.mode {
TlsMode::Disable => return Ok(opts),
TlsMode::Prefer => {
opts = opts.ssl_mode(MySqlSslMode::Preferred);
}
TlsMode::Require => {
opts = opts.ssl_mode(MySqlSslMode::Required);
}
TlsMode::VerifyCa => {
opts = opts.ssl_mode(MySqlSslMode::VerifyCa);
opts = opts.ssl_ca(&tls_config.ca_cert_path);
}
TlsMode::VerifyFull => {
opts = opts.ssl_mode(MySqlSslMode::VerifyIdentity);
opts = opts.ssl_ca(&tls_config.ca_cert_path);
}
}
info!(
"Setting up MySQL options with TLS mode: {:?}",
tls_config.mode
);
if !tls_config.cert_path.is_empty() && !tls_config.key_path.is_empty() {
info!("Loading client certificate for mutual TLS");
opts = opts.ssl_client_cert(&tls_config.cert_path);
opts = opts.ssl_client_key(&tls_config.key_path);
}
Ok(opts)
}
/// Creates a MySQL pool.
pub async fn create_mysql_pool(
store_addrs: &[String],
tls_config: Option<&TlsOption>,
) -> Result<MySqlPool> {
let opts = setup_mysql_options(store_addrs, tls_config).await?;
let pool = MySqlPool::connect_with(opts)
.await
.context(error::CreateMySqlPoolSnafu)?;
Ok(pool)
}
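
A hedged usage sketch of `create_mysql_pool` with CA verification; the URL and CA path are placeholders and, as above, the `TlsOption` construction assumes public fields plus `Default`. Passing `None` for the TLS argument matches the previous `create_mysql_pool(&addrs)` behavior.

```rust
use crate::error::Result;
use crate::utils::mysql::create_mysql_pool;
use servers::tls::{TlsMode, TlsOption};

async fn connect_meta_mysql() -> Result<sqlx::MySqlPool> {
    let addrs = ["mysql://user:password@127.0.0.1:3306/metasrv".to_string()];
    // `VerifyCa` maps to `MySqlSslMode::VerifyCa` and loads the CA certificate.
    let tls = TlsOption {
        mode: TlsMode::VerifyCa,
        ca_cert_path: "/etc/greptimedb/tls/ca.crt".to_string(),
        ..TlsOption::default()
    };
    create_mysql_pool(&addrs, Some(&tls)).await
}
```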

View File

@@ -0,0 +1,74 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use common_error::ext::BoxedError;
use common_meta::kv_backend::rds::postgres::{
TlsMode as PgTlsMode, TlsOption as PgTlsOption, create_postgres_tls_connector,
};
use deadpool_postgres::{Config, Runtime};
use servers::tls::TlsOption;
use snafu::{OptionExt, ResultExt};
use tokio_postgres::NoTls;
use crate::error::{self, Result};
/// Converts [`TlsOption`] to [`PgTlsOption`] to avoid circular dependencies
fn convert_tls_option(tls_option: &TlsOption) -> PgTlsOption {
let mode = match tls_option.mode {
servers::tls::TlsMode::Disable => PgTlsMode::Disable,
servers::tls::TlsMode::Prefer => PgTlsMode::Prefer,
servers::tls::TlsMode::Require => PgTlsMode::Require,
servers::tls::TlsMode::VerifyCa => PgTlsMode::VerifyCa,
servers::tls::TlsMode::VerifyFull => PgTlsMode::VerifyFull,
};
PgTlsOption {
mode,
cert_path: tls_option.cert_path.clone(),
key_path: tls_option.key_path.clone(),
ca_cert_path: tls_option.ca_cert_path.clone(),
watch: tls_option.watch,
}
}
/// Creates a pool for the Postgres backend with config and optional TLS.
///
/// It only uses the first store addr and the given config to create the pool.
pub async fn create_postgres_pool(
store_addrs: &[String],
cfg: Option<Config>,
tls_config: Option<TlsOption>,
) -> Result<deadpool_postgres::Pool> {
let mut cfg = cfg.unwrap_or_default();
let postgres_url = store_addrs.first().context(error::InvalidArgumentsSnafu {
err_msg: "empty store addrs",
})?;
cfg.url = Some(postgres_url.to_string());
let pool = if let Some(tls_config) = tls_config {
let pg_tls_config = convert_tls_option(&tls_config);
let tls_connector =
create_postgres_tls_connector(&pg_tls_config).map_err(|e| error::Error::Other {
source: BoxedError::new(e),
location: snafu::Location::new(file!(), line!(), 0),
})?;
cfg.create_pool(Some(Runtime::Tokio1), tls_connector)
.context(error::CreatePostgresPoolSnafu)?
} else {
cfg.create_pool(Some(Runtime::Tokio1), NoTls)
.context(error::CreatePostgresPoolSnafu)?
};
Ok(pool)
}
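
Likewise for the Postgres backend: the second argument carries an optional `deadpool_postgres::Config`, the third the optional TLS settings, and `create_postgres_pool(&addrs, None, None)` reproduces the plain `NoTls` pool. The URL and CA path in this hedged sketch are placeholders.

```rust
use crate::error::Result;
use crate::utils::postgres::create_postgres_pool;
use servers::tls::{TlsMode, TlsOption};

async fn connect_meta_postgres() -> Result<deadpool_postgres::Pool> {
    let addrs = ["postgresql://user:password@127.0.0.1:5432/metasrv".to_string()];
    // Placeholder CA path; assumes `TlsOption` implements `Default`.
    let tls = TlsOption {
        mode: TlsMode::Require,
        ca_cert_path: "/etc/greptimedb/tls/ca.crt".to_string(),
        ..TlsOption::default()
    };
    create_postgres_pool(&addrs, None, Some(tls)).await
}
```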

View File

@@ -257,10 +257,10 @@ impl RegionEngine for MetricEngine {
self.handle_query(region_id, request).await
}
async fn get_last_seq_num(
async fn get_committed_sequence(
&self,
region_id: RegionId,
) -> Result<Option<SequenceNumber>, BoxedError> {
) -> Result<SequenceNumber, BoxedError> {
self.inner
.get_last_seq_num(region_id)
.await

View File

@@ -111,6 +111,8 @@ mod tests {
let mito = env.mito();
let debug_format = mito
.all_ssts_from_manifest()
.await
.into_iter()
.map(|mut e| {
e.file_path = e.file_path.replace(&e.file_id, "<file_id>");
e.index_file_path = e
@@ -125,12 +127,12 @@ mod tests {
assert_eq!(
debug_format,
r#"
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640257(11, 1), table_id: 11, region_number: 1, region_group: 0, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000001/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(20), origin_region_id: 47244640257(11, 1), node_id: None }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640258(11, 2), table_id: 11, region_number: 2, region_group: 0, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000002/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 47244640258(11, 2), node_id: None }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417473(11, 16777217), table_id: 11, region_number: 16777217, region_group: 1, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/metadata/<file_id>.parquet", file_size: 3201, index_file_path: None, index_file_size: None, num_rows: 8, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(8), origin_region_id: 47261417473(11, 16777217), node_id: None }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417474(11, 16777218), table_id: 11, region_number: 16777218, region_group: 1, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/metadata/<file_id>.parquet", file_size: 3185, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 47261417474(11, 16777218), node_id: None }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94489280554(22, 42), table_id: 22, region_number: 42, region_group: 0, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/22_0000000042/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 94489280554(22, 42), node_id: None }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94506057770(22, 16777258), table_id: 22, region_number: 16777258, region_group: 1, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/metadata/<file_id>.parquet", file_size: 3185, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 94506057770(22, 16777258), node_id: None }"#
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640257(11, 1), table_id: 11, region_number: 1, region_group: 0, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000001/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(20), origin_region_id: 47244640257(11, 1), node_id: None, visible: true }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47244640258(11, 2), table_id: 11, region_number: 2, region_group: 0, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/11_0000000002/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 47244640258(11, 2), node_id: None, visible: true }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417473(11, 16777217), table_id: 11, region_number: 16777217, region_group: 1, region_sequence: 1, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000001/metadata/<file_id>.parquet", file_size: 3429, index_file_path: None, index_file_size: None, num_rows: 8, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(8), origin_region_id: 47261417473(11, 16777217), node_id: None, visible: true }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 47261417474(11, 16777218), table_id: 11, region_number: 16777218, region_group: 1, region_sequence: 2, file_id: "<file_id>", level: 0, file_path: "test_metric_region/11_0000000002/metadata/<file_id>.parquet", file_size: 3413, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 47261417474(11, 16777218), node_id: None, visible: true }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94489280554(22, 42), table_id: 22, region_number: 42, region_group: 0, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/data/<file_id>.parquet", file_size: 3157, index_file_path: Some("test_metric_region/22_0000000042/data/index/<file_id>.puffin"), index_file_size: Some(235), num_rows: 10, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 9::Millisecond, sequence: Some(10), origin_region_id: 94489280554(22, 42), node_id: None, visible: true }
ManifestSstEntry { table_dir: "test_metric_region/", region_id: 94506057770(22, 16777258), table_id: 22, region_number: 16777258, region_group: 1, region_sequence: 42, file_id: "<file_id>", level: 0, file_path: "test_metric_region/22_0000000042/metadata/<file_id>.parquet", file_size: 3413, index_file_path: None, index_file_size: None, num_rows: 4, num_row_groups: 1, min_ts: 0::Millisecond, max_ts: 0::Millisecond, sequence: Some(4), origin_region_id: 94506057770(22, 16777258), node_id: None, visible: true }"#
);
// list from storage

View File

@@ -89,7 +89,7 @@ impl MetricEngineInner {
Ok(scanner)
}
pub async fn get_last_seq_num(&self, region_id: RegionId) -> Result<Option<SequenceNumber>> {
pub async fn get_last_seq_num(&self, region_id: RegionId) -> Result<SequenceNumber> {
let region_id = if self.is_physical_region(region_id) {
region_id
} else {
@@ -97,7 +97,7 @@ impl MetricEngineInner {
utils::to_data_region_id(physical_region_id)
};
self.mito
.get_last_seq_num(region_id)
.get_committed_sequence(region_id)
.await
.context(MitoReadOperationSnafu)
}

View File

@@ -19,7 +19,7 @@ common-recordbatch.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
datatypes.workspace = true
memcomparable = "0.2"
memcomparable = { git = "https://github.com/v0y4g3r/memcomparable.git", rev = "a07122dc03556bbd88ad66234cbea7efd3b23efb" }
paste.workspace = true
serde.workspace = true
snafu.workspace = true

View File

@@ -51,7 +51,6 @@ index.workspace = true
itertools.workspace = true
lazy_static = "1.4"
log-store = { workspace = true }
memcomparable = "0.2"
mito-codec.workspace = true
moka = { workspace = true, features = ["sync", "future"] }
object-store.workspace = true

View File

@@ -13,10 +13,11 @@
// limitations under the License.
use std::sync::Arc;
use std::time::Duration;
use std::time::{Duration, Instant};
use async_stream::try_stream;
use common_time::Timestamp;
use either::Either;
use futures::{Stream, TryStreamExt};
use object_store::services::Fs;
use object_store::util::{join_dir, with_instrument_layers};
@@ -26,7 +27,7 @@ use snafu::ResultExt;
use store_api::metadata::RegionMetadataRef;
use store_api::region_request::PathType;
use store_api::sst_entry::StorageSstEntry;
use store_api::storage::{RegionId, SequenceNumber};
use store_api::storage::{FileId, RegionId, SequenceNumber};
use crate::cache::CacheManagerRef;
use crate::cache::file_cache::{FileCacheRef, FileType, IndexKey};
@@ -34,9 +35,9 @@ use crate::cache::write_cache::SstUploadRequest;
use crate::config::{BloomFilterConfig, FulltextIndexConfig, InvertedIndexConfig};
use crate::error::{CleanDirSnafu, DeleteIndexSnafu, DeleteSstSnafu, OpenDalSnafu, Result};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, FLUSH_ELAPSED};
use crate::read::Source;
use crate::read::{FlatSource, Source};
use crate::region::options::IndexOptions;
use crate::sst::file::{FileHandle, FileId, RegionFileId};
use crate::sst::file::{FileHandle, RegionFileId};
use crate::sst::index::IndexerBuilderImpl;
use crate::sst::index::intermediate::IntermediateManager;
use crate::sst::index::puffin_manager::PuffinManagerFactory;
@@ -44,6 +45,7 @@ use crate::sst::location::{self, region_dir_from_table_dir};
use crate::sst::parquet::reader::ParquetReaderBuilder;
use crate::sst::parquet::writer::ParquetWriter;
use crate::sst::parquet::{SstInfo, WriteOptions};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};
pub type AccessLayerRef = Arc<AccessLayer>;
/// SST write results.
@@ -66,6 +68,7 @@ pub struct Metrics {
pub(crate) update_index: Duration,
pub(crate) upload_parquet: Duration,
pub(crate) upload_puffin: Duration,
pub(crate) compact_memtable: Duration,
}
impl Metrics {
@@ -77,6 +80,7 @@ impl Metrics {
update_index: Default::default(),
upload_parquet: Default::default(),
upload_puffin: Default::default(),
compact_memtable: Default::default(),
}
}
@@ -87,6 +91,7 @@ impl Metrics {
self.update_index += other.update_index;
self.upload_parquet += other.upload_parquet;
self.upload_puffin += other.upload_puffin;
self.compact_memtable += other.compact_memtable;
self
}
@@ -108,6 +113,11 @@ impl Metrics {
FLUSH_ELAPSED
.with_label_values(&["upload_puffin"])
.observe(self.upload_puffin.as_secs_f64());
if !self.compact_memtable.is_zero() {
FLUSH_ELAPSED
.with_label_values(&["compact_memtable"])
.observe(self.compact_memtable.as_secs_f64());
}
}
WriteType::Compaction => {
COMPACTION_STAGE_ELAPSED
@@ -288,9 +298,16 @@ impl AccessLayer {
)
.await
.with_file_cleaner(cleaner);
let ssts = writer
.write_all(request.source, request.max_sequence, write_opts)
.await?;
let ssts = match request.source {
Either::Left(source) => {
writer
.write_all(source, request.max_sequence, write_opts)
.await?
}
Either::Right(flat_source) => {
writer.write_all_flat(flat_source, write_opts).await?
}
};
let metrics = writer.into_metrics();
(ssts, metrics)
};
@@ -310,6 +327,53 @@ impl AccessLayer {
Ok((sst_info, metrics))
}
/// Puts encoded SST bytes to the write cache (if enabled) and uploads it to the object store.
pub(crate) async fn put_sst(
&self,
data: &bytes::Bytes,
region_id: RegionId,
sst_info: &SstInfo,
cache_manager: &CacheManagerRef,
) -> Result<Metrics> {
if let Some(write_cache) = cache_manager.write_cache() {
// Write to cache and upload to remote store
let upload_request = SstUploadRequest {
dest_path_provider: RegionFilePathFactory::new(
self.table_dir.clone(),
self.path_type,
),
remote_store: self.object_store.clone(),
};
write_cache
.put_and_upload_sst(data, region_id, sst_info, upload_request)
.await
} else {
let start = Instant::now();
let cleaner = TempFileCleaner::new(region_id, self.object_store.clone());
let path_provider = RegionFilePathFactory::new(self.table_dir.clone(), self.path_type);
let sst_file_path =
path_provider.build_sst_file_path(RegionFileId::new(region_id, sst_info.file_id));
let mut writer = self
.object_store
.writer_with(&sst_file_path)
.chunk(DEFAULT_WRITE_BUFFER_SIZE.as_bytes() as usize)
.concurrent(DEFAULT_WRITE_CONCURRENCY)
.await
.context(OpenDalSnafu)?;
if let Err(err) = writer.write(data.clone()).await.context(OpenDalSnafu) {
cleaner.clean_by_file_id(sst_info.file_id).await;
return Err(err);
}
if let Err(err) = writer.close().await.context(OpenDalSnafu) {
cleaner.clean_by_file_id(sst_info.file_id).await;
return Err(err);
}
let mut metrics = Metrics::new(WriteType::Flush);
metrics.write_batch = start.elapsed();
Ok(metrics)
}
}
/// Lists the SST entries from the storage layer in the table directory.
pub fn storage_sst_entries(&self) -> impl Stream<Item = Result<StorageSstEntry>> + use<> {
let object_store = self.object_store.clone();
@@ -363,7 +427,7 @@ pub enum OperationType {
pub struct SstWriteRequest {
pub op_type: OperationType,
pub metadata: RegionMetadataRef,
pub source: Source,
pub source: Either<Source, FlatSource>,
pub cache_manager: CacheManagerRef,
#[allow(dead_code)]
pub storage: Option<String>,
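
Since `SstWriteRequest::source` is now an `Either<Source, FlatSource>`, writers dispatch to `write_all` or `write_all_flat` as in the hunks above. A self-contained sketch of that dispatch, using the `either` crate already relied on in this diff; the unit structs are stand-ins, not the engine's real readers.

```rust
use either::Either;

struct Source;     // stand-in for the batch-based source
struct FlatSource; // stand-in for the flat record-batch stream

fn write_all(_source: Source) -> &'static str {
    "row format path"
}

fn write_all_flat(_source: FlatSource) -> &'static str {
    "flat format path"
}

// Mirrors the match in `AccessLayer::write_sst` / `WriteCache::write_and_upload_sst`.
fn write_sst(source: Either<Source, FlatSource>) -> &'static str {
    match source {
        Either::Left(source) => write_all(source),
        Either::Right(flat_source) => write_all_flat(flat_source),
    }
}

fn main() {
    assert_eq!(write_sst(Either::Left(Source)), "row format path");
    assert_eq!(write_sst(Either::Right(FlatSource)), "flat format path");
}
```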

View File

@@ -35,7 +35,7 @@ use moka::notification::RemovalCause;
use moka::sync::Cache;
use parquet::file::metadata::ParquetMetaData;
use puffin::puffin_manager::cache::{PuffinMetadataCache, PuffinMetadataCacheRef};
use store_api::storage::{ConcreteDataType, RegionId, TimeSeriesRowSelector};
use store_api::storage::{ConcreteDataType, FileId, RegionId, TimeSeriesRowSelector};
use crate::cache::cache_size::parquet_meta_size;
use crate::cache::file_cache::{FileType, IndexKey};
@@ -43,7 +43,7 @@ use crate::cache::index::inverted_index::{InvertedIndexCache, InvertedIndexCache
use crate::cache::write_cache::WriteCacheRef;
use crate::metrics::{CACHE_BYTES, CACHE_EVICTION, CACHE_HIT, CACHE_MISS};
use crate::read::Batch;
use crate::sst::file::{FileId, RegionFileId};
use crate::sst::file::RegionFileId;
/// Metrics type key for sst meta.
const SST_META_TYPE: &str = "sst_meta";

View File

@@ -30,12 +30,11 @@ use object_store::util::join_path;
use object_store::{ErrorKind, ObjectStore, Reader};
use parquet::file::metadata::ParquetMetaData;
use snafu::ResultExt;
use store_api::storage::RegionId;
use store_api::storage::{FileId, RegionId};
use crate::cache::FILE_TYPE;
use crate::error::{OpenDalSnafu, Result};
use crate::metrics::{CACHE_BYTES, CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
use crate::sst::parquet::helper::fetch_byte_ranges;
use crate::sst::parquet::metadata::MetadataLoader;

View File

@@ -20,11 +20,10 @@ use async_trait::async_trait;
use bytes::Bytes;
use index::bloom_filter::error::Result;
use index::bloom_filter::reader::BloomFilterReader;
use store_api::storage::ColumnId;
use store_api::storage::{ColumnId, FileId};
use crate::cache::index::{INDEX_METADATA_TYPE, IndexCache, PageKey};
use crate::metrics::{CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
const INDEX_TYPE_BLOOM_FILTER_INDEX: &str = "bloom_filter_index";

View File

@@ -21,10 +21,10 @@ use bytes::Bytes;
use index::inverted_index::error::Result;
use index::inverted_index::format::reader::InvertedIndexReader;
use prost::Message;
use store_api::storage::FileId;
use crate::cache::index::{INDEX_METADATA_TYPE, IndexCache, PageKey};
use crate::metrics::{CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
const INDEX_TYPE_INVERTED_INDEX: &str = "inverted_index";

View File

@@ -19,10 +19,9 @@ use index::bloom_filter::applier::InListPredicate;
use index::inverted_index::search::predicate::{Predicate, RangePredicate};
use moka::notification::RemovalCause;
use moka::sync::Cache;
use store_api::storage::ColumnId;
use store_api::storage::{ColumnId, FileId};
use crate::metrics::{CACHE_BYTES, CACHE_EVICTION, CACHE_HIT, CACHE_MISS};
use crate::sst::file::FileId;
use crate::sst::index::fulltext_index::applier::builder::{
FulltextQuery, FulltextRequest, FulltextTerm,
};

View File

@@ -37,8 +37,8 @@ use crate::sst::file::RegionFileId;
use crate::sst::index::IndexerBuilderImpl;
use crate::sst::index::intermediate::IntermediateManager;
use crate::sst::index::puffin_manager::PuffinManagerFactory;
use crate::sst::parquet::WriteOptions;
use crate::sst::parquet::writer::ParquetWriter;
use crate::sst::parquet::{SstInfo, WriteOptions};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};
/// A cache for uploading files to remote object stores.
@@ -101,6 +101,66 @@ impl WriteCache {
self.file_cache.clone()
}
/// Puts encoded SST data to the cache and uploads it to the remote object store.
pub(crate) async fn put_and_upload_sst(
&self,
data: &bytes::Bytes,
region_id: RegionId,
sst_info: &SstInfo,
upload_request: SstUploadRequest,
) -> Result<Metrics> {
let file_id = sst_info.file_id;
let mut metrics = Metrics::new(WriteType::Flush);
// Create index key for the SST file
let parquet_key = IndexKey::new(region_id, file_id, FileType::Parquet);
// Write to cache first
let cache_start = Instant::now();
let cache_path = self.file_cache.cache_file_path(parquet_key);
let mut cache_writer = self
.file_cache
.local_store()
.writer(&cache_path)
.await
.context(crate::error::OpenDalSnafu)?;
cache_writer
.write(data.clone())
.await
.context(crate::error::OpenDalSnafu)?;
cache_writer
.close()
.await
.context(crate::error::OpenDalSnafu)?;
// Register in file cache
let index_value = IndexValue {
file_size: data.len() as u32,
};
self.file_cache.put(parquet_key, index_value).await;
metrics.write_batch = cache_start.elapsed();
// Upload to remote store
let upload_start = Instant::now();
let region_file_id = RegionFileId::new(region_id, file_id);
let remote_path = upload_request
.dest_path_provider
.build_sst_file_path(region_file_id);
if let Err(e) = self
.upload(parquet_key, &remote_path, &upload_request.remote_store)
.await
{
// Clean up cache on failure
self.remove(parquet_key).await;
return Err(e);
}
metrics.upload_parquet = upload_start.elapsed();
Ok(metrics)
}
/// Writes SST to the cache and then uploads it to the remote object store.
pub(crate) async fn write_and_upload_sst(
&self,
@@ -139,9 +199,14 @@ impl WriteCache {
.await
.with_file_cleaner(cleaner);
let sst_info = writer
.write_all(write_request.source, write_request.max_sequence, write_opts)
.await?;
let sst_info = match write_request.source {
either::Left(source) => {
writer
.write_all(source, write_request.max_sequence, write_opts)
.await?
}
either::Right(flat_source) => writer.write_all_flat(flat_source, write_opts).await?,
};
let mut metrics = writer.into_metrics();
// Upload sst file to remote object store.
@@ -469,7 +534,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: Default::default(),
@@ -567,7 +632,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: cache_manager.clone(),
@@ -646,7 +711,7 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source,
source: either::Left(source),
storage: None,
max_sequence: None,
cache_manager: cache_manager.clone(),

View File

@@ -55,10 +55,10 @@ use crate::error::{
TimeRangePredicateOverflowSnafu, TimeoutSnafu,
};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, INFLIGHT_COMPACTION_COUNT};
use crate::read::BoxedBatchReader;
use crate::read::projection::ProjectionMapper;
use crate::read::scan_region::{PredicateGroup, ScanInput};
use crate::read::seq_scan::SeqScan;
use crate::read::{BoxedBatchReader, BoxedRecordBatchStream};
use crate::region::options::MergeMode;
use crate::region::version::VersionControlRef;
use crate::region::{ManifestContextRef, RegionLeaderState, RegionRoleState};
@@ -638,9 +638,26 @@ struct CompactionSstReaderBuilder<'a> {
impl CompactionSstReaderBuilder<'_> {
/// Builds [BoxedBatchReader] that reads all SST files and yields batches in primary key order.
async fn build_sst_reader(self) -> Result<BoxedBatchReader> {
let scan_input = self.build_scan_input(false)?;
SeqScan::new(scan_input, true)
.build_reader_for_compaction()
.await
}
/// Builds [BoxedRecordBatchStream] that reads all SST files and yields batches in flat format for compaction.
async fn build_flat_sst_reader(self) -> Result<BoxedRecordBatchStream> {
let scan_input = self.build_scan_input(true)?;
SeqScan::new(scan_input, true)
.build_flat_reader_for_compaction()
.await
}
fn build_scan_input(self, flat_format: bool) -> Result<ScanInput> {
let mut scan_input = ScanInput::new(
self.sst_layer,
ProjectionMapper::all(&self.metadata, false)?,
ProjectionMapper::all(&self.metadata, flat_format)?,
)
.with_files(self.inputs.to_vec())
.with_append_mode(self.append_mode)
@@ -649,7 +666,8 @@ impl CompactionSstReaderBuilder<'_> {
.with_filter_deleted(self.filter_deleted)
// We ignore file not found error during compaction.
.with_ignore_file_not_found(true)
.with_merge_mode(self.merge_mode);
.with_merge_mode(self.merge_mode)
.with_flat_format(flat_format);
// This serves as a workaround of https://github.com/GreptimeTeam/greptimedb/issues/3944
// by converting time ranges into predicate.
@@ -658,9 +676,7 @@ impl CompactionSstReaderBuilder<'_> {
scan_input.with_predicate(time_range_to_predicate(time_range, &self.metadata)?);
}
SeqScan::new(scan_input, true)
.build_reader_for_compaction()
.await
Ok(scan_input)
}
}

View File

@@ -80,9 +80,10 @@ pub(crate) const TIME_BUCKETS: TimeBuckets = TimeBuckets([
#[cfg(test)]
mod tests {
use store_api::storage::FileId;
use super::*;
use crate::compaction::test_util::new_file_handle;
use crate::sst::file::FileId;
#[test]
fn test_time_bucket() {

View File

@@ -42,7 +42,7 @@ use crate::manifest::action::{RegionEdit, RegionMetaAction, RegionMetaActionList
use crate::manifest::manager::{RegionManifestManager, RegionManifestOptions, RemoveFileOptions};
use crate::manifest::storage::manifest_compress_type;
use crate::metrics;
use crate::read::Source;
use crate::read::{FlatSource, Source};
use crate::region::opener::new_manifest_dir;
use crate::region::options::RegionOptions;
use crate::region::version::VersionRef;
@@ -342,6 +342,9 @@ impl Compactor for DefaultCompactor {
.clone();
let append_mode = compaction_region.current_version.options.append_mode;
let merge_mode = compaction_region.current_version.options.merge_mode();
let flat_format = compaction_region
.engine_config
.enable_experimental_flat_format;
let inverted_index_config = compaction_region.engine_config.inverted_index.clone();
let fulltext_index_config = compaction_region.engine_config.fulltext_index.clone();
let bloom_filter_index_config =
@@ -359,7 +362,7 @@ impl Compactor for DefaultCompactor {
.iter()
.map(|f| f.file_id().to_string())
.join(",");
let reader = CompactionSstReaderBuilder {
let builder = CompactionSstReaderBuilder {
metadata: region_metadata.clone(),
sst_layer: sst_layer.clone(),
cache: cache_manager.clone(),
@@ -368,15 +371,20 @@ impl Compactor for DefaultCompactor {
filter_deleted: output.filter_deleted,
time_range: output.output_time_range,
merge_mode,
}
.build_sst_reader()
.await?;
};
let source = if flat_format {
let reader = builder.build_flat_sst_reader().await?;
either::Right(FlatSource::Stream(reader))
} else {
let reader = builder.build_sst_reader().await?;
either::Left(Source::Reader(reader))
};
let (sst_infos, metrics) = sst_layer
.write_sst(
SstWriteRequest {
op_type: OperationType::Compact,
metadata: region_metadata,
source: Source::Reader(reader),
source,
cache_manager,
storage,
max_sequence: max_sequence.map(NonZero::get),
@@ -475,6 +483,7 @@ impl Compactor for DefaultCompactor {
.map(|seconds| Duration::from_secs(seconds as u64)),
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
let action_list = RegionMetaActionList::with_action(RegionMetaAction::Edit(edit.clone()));

View File

@@ -147,9 +147,10 @@ pub fn new_picker(
#[cfg(test)]
mod tests {
use store_api::storage::FileId;
use super::*;
use crate::compaction::test_util::new_file_handle;
use crate::sst::file::FileId;
use crate::test_util::new_noop_file_purger;
#[test]

View File

@@ -15,8 +15,9 @@
use std::num::NonZeroU64;
use common_time::Timestamp;
use store_api::storage::FileId;
use crate::sst::file::{FileHandle, FileId, FileMeta, Level};
use crate::sst::file::{FileHandle, FileMeta, Level};
use crate::test_util::new_noop_file_purger;
/// Test util to create file handles.

View File

@@ -350,11 +350,13 @@ fn find_latest_window_in_seconds<'a>(
mod tests {
use std::collections::HashSet;
use store_api::storage::FileId;
use super::*;
use crate::compaction::test_util::{
new_file_handle, new_file_handle_with_sequence, new_file_handle_with_size_and_sequence,
};
use crate::sst::file::{FileId, Level};
use crate::sst::file::Level;
#[test]
fn test_get_latest_window_in_seconds() {

View File

@@ -206,12 +206,12 @@ mod tests {
use common_time::Timestamp;
use common_time::range::TimestampRange;
use store_api::storage::RegionId;
use store_api::storage::{FileId, RegionId};
use crate::compaction::compactor::CompactionVersion;
use crate::compaction::window::{WindowedCompactionPicker, file_time_bucket_span};
use crate::region::options::RegionOptions;
use crate::sst::file::{FileId, FileMeta, Level};
use crate::sst::file::{FileMeta, Level};
use crate::sst::file_purger::NoopFilePurger;
use crate::sst::version::SstVersion;
use crate::test_util::memtable_util::metadata_for_test;

View File

@@ -141,6 +141,10 @@ pub struct MitoConfig {
/// To align with the old behavior, the default value is 0 (no restrictions).
#[serde(with = "humantime_serde")]
pub min_compaction_interval: Duration,
/// Whether to enable experimental flat format.
/// When enabled, forces using BulkMemtable and BulkMemtableBuilder.
pub enable_experimental_flat_format: bool,
}
impl Default for MitoConfig {
@@ -177,6 +181,7 @@ impl Default for MitoConfig {
bloom_filter_index: BloomFilterConfig::default(),
memtable: MemtableConfig::default(),
min_compaction_interval: Duration::from_secs(0),
enable_experimental_flat_format: false,
};
// Adjust buffer and cache size according to system memory if we can.

View File

@@ -23,6 +23,8 @@ mod basic_test;
#[cfg(test)]
mod batch_open_test;
#[cfg(test)]
mod bump_committed_sequence_test;
#[cfg(test)]
mod catchup_test;
#[cfg(test)]
mod close_test;
@@ -53,6 +55,8 @@ mod prune_test;
#[cfg(test)]
mod row_selector_test;
#[cfg(test)]
mod scan_corrupt;
#[cfg(test)]
mod scan_test;
#[cfg(test)]
mod set_role_state_test;
@@ -414,16 +418,20 @@ impl MitoEngine {
}
/// Lists all SSTs from the manifest of all regions in the engine.
pub fn all_ssts_from_manifest(&self) -> impl Iterator<Item = ManifestSstEntry> + use<'_> {
pub async fn all_ssts_from_manifest(&self) -> Vec<ManifestSstEntry> {
let node_id = self.inner.workers.file_ref_manager().node_id();
self.inner
.workers
.all_regions()
.flat_map(|region| region.manifest_sst_entries())
.map(move |mut entry| {
entry.node_id = node_id;
entry
})
let regions = self.inner.workers.all_regions();
let mut results = Vec::new();
for region in regions {
let mut entries = region.manifest_sst_entries().await;
for e in &mut entries {
e.node_id = node_id;
}
results.extend(entries);
}
results
}
/// Lists all SSTs from the storage layer of all regions in the engine.
@@ -465,6 +473,7 @@ fn is_valid_region_edit(edit: &RegionEdit) -> bool {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
..
}
)
}
@@ -658,10 +667,11 @@ impl EngineInner {
receiver.await.context(RecvSnafu)?
}
fn get_last_seq_num(&self, region_id: RegionId) -> Result<Option<SequenceNumber>> {
/// Returns the sequence of the latest committed data.
fn get_committed_sequence(&self, region_id: RegionId) -> Result<SequenceNumber> {
// Reading a region doesn't need to go through the region worker thread.
let region = self.find_region(region_id)?;
Ok(Some(region.find_committed_sequence()))
self.find_region(region_id)
.map(|r| r.find_committed_sequence())
}
/// Handles the scan `request` and returns a [ScanRegion].
@@ -685,8 +695,7 @@ impl EngineInner {
.with_ignore_fulltext_index(self.config.fulltext_index.apply_on_query.disabled())
.with_ignore_bloom_filter(self.config.bloom_filter_index.apply_on_query.disabled())
.with_start_time(query_start)
// TODO(yingwen): Enable it after flat format is supported.
.with_flat_format(false);
.with_flat_format(self.config.enable_experimental_flat_format);
#[cfg(feature = "enterprise")]
let scan_region = self.maybe_fill_extension_range_provider(scan_region, region);
@@ -829,12 +838,12 @@ impl RegionEngine for MitoEngine {
.map_err(BoxedError::new)
}
async fn get_last_seq_num(
async fn get_committed_sequence(
&self,
region_id: RegionId,
) -> Result<Option<SequenceNumber>, BoxedError> {
) -> Result<SequenceNumber, BoxedError> {
self.inner
.get_last_seq_num(region_id)
.get_committed_sequence(region_id)
.map_err(BoxedError::new)
}
@@ -1018,6 +1027,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(is_valid_region_edit(&edit));
@@ -1029,6 +1039,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
@@ -1040,6 +1051,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
@@ -1051,6 +1063,7 @@ mod tests {
compaction_time_window: Some(Duration::from_secs(1)),
flushed_entry_id: None,
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
let edit = RegionEdit {
@@ -1060,6 +1073,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: Some(1),
flushed_sequence: None,
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
let edit = RegionEdit {
@@ -1069,6 +1083,7 @@ mod tests {
compaction_time_window: None,
flushed_entry_id: None,
flushed_sequence: Some(1),
committed_sequence: None,
};
assert!(!is_valid_region_edit(&edit));
}

Some files were not shown because too many files have changed in this diff.