Compare commits


3 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| discord9 | 2f8e8be042 | test: build a failing example with LogicalPlanBuilder<br/>Signed-off-by: discord9 <discord9@163.com> | 2025-08-12 15:19:47 +08:00 |
| discord9 | 9d30459a58 | test: reproduce the panic, still no clue why<br/>Signed-off-by: discord9 <discord9@163.com> | 2025-08-11 19:37:48 +08:00 |
| discord9 | f1650a78f7 | fix?: optimize projection after join<br/>Signed-off-by: discord9 <discord9@163.com> | 2025-08-07 19:55:32 +08:00 |
519 changed files with 9463 additions and 25185 deletions

View File

@@ -35,8 +35,8 @@ HIGHER_VERSION=$(printf "%s\n%s" "$CLEAN_CURRENT" "$CLEAN_LATEST" | sort -V | ta
if [ "$HIGHER_VERSION" = "$CLEAN_CURRENT" ]; then
echo "Current version ($CLEAN_CURRENT) is NEWER than or EQUAL to latest ($CLEAN_LATEST)"
echo "is-current-version-latest=true" >> $GITHUB_OUTPUT
echo "should-push-latest-tag=true" >> $GITHUB_OUTPUT
else
echo "Current version ($CLEAN_CURRENT) is OLDER than latest ($CLEAN_LATEST)"
echo "is-current-version-latest=false" >> $GITHUB_OUTPUT
echo "should-push-latest-tag=false" >> $GITHUB_OUTPUT
fi

View File

@@ -21,7 +21,7 @@ update_dev_builder_version() {
# Commit the changes.
git add Makefile
git commit -s -m "ci: update dev-builder image tag"
git commit -m "ci: update dev-builder image tag"
git push origin $BRANCH_NAME
# Create a Pull Request.

View File

@@ -12,7 +12,6 @@ on:
- 'docker/**'
- '.gitignore'
- 'grafana/**'
- 'Makefile'
workflow_dispatch:
name: CI
@@ -710,7 +709,7 @@ jobs:
- name: Install toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
cache: false
cache: false
- name: Rust Cache
uses: Swatinem/rust-cache@v2
with:

View File

@@ -10,7 +10,6 @@ on:
- 'docker/**'
- '.gitignore'
- 'grafana/**'
- 'Makefile'
push:
branches:
- main
@@ -22,7 +21,6 @@ on:
- 'docker/**'
- '.gitignore'
- 'grafana/**'
- 'Makefile'
workflow_dispatch:
name: CI

View File

@@ -111,8 +111,7 @@ jobs:
# The 'version' use as the global tag name of the release workflow.
version: ${{ steps.create-version.outputs.version }}
# The 'is-current-version-latest' determines whether to update 'latest' Docker tags and downstream repositories.
is-current-version-latest: ${{ steps.check-version.outputs.is-current-version-latest }}
should-push-latest-tag: ${{ steps.check-version.outputs.should-push-latest-tag }}
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -322,7 +321,7 @@ jobs:
image-registry-username: ${{ secrets.DOCKERHUB_USERNAME }}
image-registry-password: ${{ secrets.DOCKERHUB_TOKEN }}
version: ${{ needs.allocate-runners.outputs.version }}
push-latest-tag: ${{ needs.allocate-runners.outputs.is-current-version-latest == 'true' && github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
push-latest-tag: ${{ needs.allocate-runners.outputs.should-push-latest-tag == 'true' && github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
- name: Set build image result
id: set-build-image-result
@@ -369,7 +368,7 @@ jobs:
dev-mode: false
upload-to-s3: true
update-version-info: true
push-latest-tag: ${{ needs.allocate-runners.outputs.is-current-version-latest == 'true' && github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
push-latest-tag: ${{ needs.allocate-runners.outputs.should-push-latest-tag == 'true' && github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
publish-github-release:
name: Create GitHub release and upload artifacts
@@ -477,7 +476,7 @@ jobs:
bump-helm-charts-version:
name: Bump helm charts version
if: ${{ github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' && needs.allocate-runners.outputs.is-current-version-latest == 'true' }}
if: ${{ github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
needs: [allocate-runners, publish-github-release]
runs-on: ubuntu-latest
permissions:
@@ -498,7 +497,7 @@ jobs:
bump-homebrew-greptime-version:
name: Bump homebrew greptime version
if: ${{ github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' && needs.allocate-runners.outputs.is-current-version-latest == 'true' }}
if: ${{ github.ref_type == 'tag' && !contains(github.ref_name, 'nightly') && github.event_name != 'schedule' }}
needs: [allocate-runners, publish-github-release]
runs-on: ubuntu-latest
permissions:

View File

@@ -55,9 +55,8 @@ GreptimeDB uses the [Apache 2.0 license](https://github.com/GreptimeTeam/greptim
- To ensure that the community is free and confident in its ability to use your contributions, please sign the Contributor License Agreement (CLA), which will be incorporated in the pull request process.
- Make sure all files have proper license header (running `docker run --rm -v $(pwd):/github/workspace ghcr.io/korandoru/hawkeye-native:v3 format` from the project root).
- Make sure all your code is formatted and follows the [coding style](https://pingcap.github.io/style-guide/rust/) and [style guide](docs/style-guide.md).
- Make sure all unit tests pass using [nextest](https://nexte.st/index.html): `cargo nextest run --workspace --features pg_kvbackend,mysql_kvbackend` or `make test`.
- Make sure all clippy warnings are fixed (you can check it locally by running `cargo clippy --workspace --all-targets -- -D warnings` or `make clippy`).
- When modifying sample configuration files in `config/`, run `make config-docs` (which requires Docker to be installed) to update the configuration documentation and include it in your commit.
- Make sure all unit tests pass using [nextest](https://nexte.st/index.html): `cargo nextest run`.
- Make sure all clippy warnings are fixed (you can check it locally by running `cargo clippy --workspace --all-targets -- -D warnings`).
#### `pre-commit` Hooks

Cargo.lock (generated, 4745 changed lines)

File diff suppressed because it is too large.

View File

@@ -98,12 +98,11 @@ rust.unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tokio_unstable)'] }
# See for more details: https://github.com/rust-lang/cargo/issues/11329
ahash = { version = "0.8", features = ["compile-time-rng"] }
aquamarine = "0.6"
arrow = { version = "56.0", features = ["prettyprint"] }
arrow-array = { version = "56.0", default-features = false, features = ["chrono-tz"] }
arrow-buffer = "56.0"
arrow-flight = "56.0"
arrow-ipc = { version = "56.0", default-features = false, features = ["lz4", "zstd"] }
arrow-schema = { version = "56.0", features = ["serde"] }
arrow = { version = "54.2", features = ["prettyprint"] }
arrow-array = { version = "54.2", default-features = false, features = ["chrono-tz"] }
arrow-flight = "54.2"
arrow-ipc = { version = "54.2", default-features = false, features = ["lz4", "zstd"] }
arrow-schema = { version = "54.2", features = ["serde"] }
async-stream = "0.3"
async-trait = "0.1"
# Remember to update axum-extra, axum-macros when updating axum
@@ -122,27 +121,26 @@ clap = { version = "4.4", features = ["derive"] }
config = "0.13.0"
crossbeam-utils = "0.8"
dashmap = "6.1"
datafusion = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-functions = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-functions-aggregate-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-orc = { git = "https://github.com/GreptimeTeam/datafusion-orc", rev = "a0a5f902158f153119316eaeec868cff3fc8a99d" }
datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7d5214512740b4dfb742b6b3d91ed9affcc2c9d0" }
datafusion = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-common = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-functions = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-functions-aggregate-common = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-optimizer = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-physical-expr = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-physical-plan = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-sql = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
datafusion-substrait = { git = "https://github.com/waynexia/arrow-datafusion.git", rev = "12c0381babd52c681043957e9d6ee083a03f7646" }
deadpool = "0.12"
deadpool-postgres = "0.14"
derive_builder = "0.20"
dotenv = "0.15"
either = "1.15"
etcd-client = { git = "https://github.com/GreptimeTeam/etcd-client", rev = "f62df834f0cffda355eba96691fe1a9a332b75a7" }
etcd-client = "0.14"
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "df2bb74b5990c159dfd5b7a344eecf8f4307af64" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "ccfd4da48bc0254ed865e479cd981a3581b02d84" }
hex = "0.4"
http = "1"
humantime = "2.1"
@@ -153,7 +151,7 @@ itertools = "0.14"
jsonb = { git = "https://github.com/databendlabs/jsonb.git", rev = "8c8d2fc294a39f3ff08909d60f718639cfba3875", default-features = false }
lazy_static = "1.4"
local-ip-address = "0.6"
loki-proto = { git = "https://github.com/GreptimeTeam/loki-proto.git", rev = "3b7cd33234358b18ece977bf689dc6fb760f29ab" }
loki-proto = { git = "https://github.com/GreptimeTeam/loki-proto.git", rev = "1434ecf23a2654025d86188fb5205e7a74b225d3" }
meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "5618e779cf2bb4755b499c630fba4c35e91898cb" }
mockall = "0.13"
moka = "0.12"
@@ -161,9 +159,9 @@ nalgebra = "0.33"
nix = { version = "0.30.1", default-features = false, features = ["event", "fs", "process"] }
notify = "8.0"
num_cpus = "1.16"
object_store_opendal = "0.54"
object_store_opendal = "0.50"
once_cell = "1.18"
opentelemetry-proto = { version = "0.30", features = [
opentelemetry-proto = { version = "0.27", features = [
"gen-tonic",
"metrics",
"trace",
@@ -172,14 +170,13 @@ opentelemetry-proto = { version = "0.30", features = [
] }
ordered-float = { version = "4.3", features = ["serde"] }
parking_lot = "0.12"
parquet = { version = "56.0", default-features = false, features = ["arrow", "async", "object_store"] }
parquet = { version = "54.2", default-features = false, features = ["arrow", "async", "object_store"] }
paste = "1.0"
pin-project = "1.0"
pretty_assertions = "1.4.0"
prometheus = { version = "0.13.3", features = ["process"] }
promql-parser = { version = "0.6", features = ["ser"] }
prost = { version = "0.13", features = ["no-recursion-limit"] }
prost-types = "0.13"
raft-engine = { version = "0.4.1", default-features = false }
rand = "0.9"
ratelimit = "0.10"
@@ -191,7 +188,7 @@ reqwest = { version = "0.12", default-features = false, features = [
"stream",
"multipart",
] }
rskafka = { git = "https://github.com/influxdata/rskafka.git", rev = "a62120b6c74d68953464b256f858dc1c41a903b4", features = [
rskafka = { git = "https://github.com/influxdata/rskafka.git", rev = "8dbd01ed809f5a791833a594e85b144e36e45820", features = [
"transport-tls",
] }
rstest = "0.25"
@@ -204,14 +201,15 @@ sea-query = "0.32"
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", features = ["float_roundtrip"] }
serde_with = "3"
shadow-rs = "1.1"
simd-json = "0.15"
similar-asserts = "1.6.0"
smallvec = { version = "1", features = ["serde"] }
snafu = "0.8"
sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "39e4fc94c3c741981f77e9d63b5ce8c02e0a27ea", features = [
sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "df6fcca80ce903f5beef7002cd2c1b062e7024f8", features = [
"visitor",
"serde",
] } # branch = "v0.55.x"
] } # branch = "v0.54.x"
sqlx = { version = "0.8", features = [
"runtime-tokio-rustls",
"mysql",
@@ -221,20 +219,20 @@ sqlx = { version = "0.8", features = [
strum = { version = "0.27", features = ["derive"] }
sysinfo = "0.33"
tempfile = "3"
tokio = { version = "1.47", features = ["full"] }
tokio = { version = "1.40", features = ["full"] }
tokio-postgres = "0.7"
tokio-rustls = { version = "0.26.2", default-features = false }
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["io-util", "compat"] }
toml = "0.8.8"
tonic = { version = "0.13", features = ["tls-ring", "gzip", "zstd"] }
tonic = { version = "0.12", features = ["tls", "gzip", "zstd"] }
tower = "0.5"
tower-http = "0.6"
tracing = "0.1"
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "fmt"] }
typetag = "0.2"
uuid = { version = "1.17", features = ["serde", "v4", "fast-rng"] }
uuid = { version = "1.7", features = ["serde", "v4", "fast-rng"] }
vrl = "0.25"
zstd = "0.13"
# DO_NOT_REMOVE_THIS: END_OF_EXTERNAL_DEPENDENCIES
@@ -293,7 +291,7 @@ mito-codec = { path = "src/mito-codec" }
mito2 = { path = "src/mito2" }
object-store = { path = "src/object-store" }
operator = { path = "src/operator" }
otel-arrow-rust = { git = "https://github.com/GreptimeTeam/otel-arrow", rev = "2d64b7c0fa95642028a8205b36fe9ea0b023ec59", features = [
otel-arrow-rust = { git = "https://github.com/open-telemetry/otel-arrow", rev = "5d551412d2a12e689cde4d84c14ef29e36784e51", features = [
"server",
] }
partition = { path = "src/partition" }

View File

@@ -8,7 +8,7 @@ CARGO_BUILD_OPTS := --locked
IMAGE_REGISTRY ?= docker.io
IMAGE_NAMESPACE ?= greptime
IMAGE_TAG ?= latest
DEV_BUILDER_IMAGE_TAG ?= 2025-05-19-32619816-20250818043248
DEV_BUILDER_IMAGE_TAG ?= 2025-05-19-b2377d4b-20250520045554
BUILDX_MULTI_PLATFORM_BUILD ?= false
BUILDX_BUILDER_NAME ?= gtbuilder
BASE_IMAGE ?= ubuntu
@@ -22,7 +22,7 @@ SQLNESS_OPTS ?=
ETCD_VERSION ?= v3.5.9
ETCD_IMAGE ?= quay.io/coreos/etcd:${ETCD_VERSION}
RETRY_COUNT ?= 3
NEXTEST_OPTS := --retries ${RETRY_COUNT} --features pg_kvbackend,mysql_kvbackend
NEXTEST_OPTS := --retries ${RETRY_COUNT}
BUILD_JOBS ?= $(shell which nproc 1>/dev/null && expr $$(nproc) / 2) # If nproc is not available, we don't set the build jobs.
ifeq ($(BUILD_JOBS), 0) # If the number of cores is less than 2, set the build jobs to 1.
BUILD_JOBS := 1

View File

@@ -41,7 +41,6 @@
| `mysql.addr` | String | `127.0.0.1:4002` | The addr to bind the MySQL server. |
| `mysql.runtime_size` | Integer | `2` | The number of server worker threads. |
| `mysql.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `mysql.prepared_stmt_cache_size` | Integer | `10000` | Maximum entries in the MySQL prepared statement cache; default is 10,000. |
| `mysql.tls` | -- | -- | -- |
| `mysql.tls.mode` | String | `disable` | TLS mode, refer to https://www.postgresql.org/docs/current/libpq-ssl.html<br/>- `disable` (default value)<br/>- `prefer`<br/>- `require`<br/>- `verify-ca`<br/>- `verify-full` |
| `mysql.tls.cert_path` | String | Unset | Certificate file path. |
@@ -187,13 +186,12 @@
| `logging.dir` | String | `./greptimedb_data/logs` | The directory to store the log files. If set to empty, logs will not be written to files. |
| `logging.level` | String | Unset | The log level. Can be `info`/`debug`/`warn`/`error`. |
| `logging.enable_otlp_tracing` | Bool | `false` | Enable OTLP tracing. |
| `logging.otlp_endpoint` | String | `http://localhost:4318/v1/traces` | The OTLP tracing endpoint. |
| `logging.otlp_endpoint` | String | `http://localhost:4318` | The OTLP tracing endpoint. |
| `logging.append_stdout` | Bool | `true` | Whether to append logs to stdout. |
| `logging.log_format` | String | `text` | The log format. Can be `text`/`json`. |
| `logging.max_log_files` | Integer | `720` | The maximum amount of log files. |
| `logging.otlp_export_protocol` | String | `http` | The OTLP tracing export protocol. Can be `grpc`/`http`. |
| `logging.otlp_headers` | -- | -- | Additional OTLP headers, only valid when using OTLP http |
| `logging.tracing_sample_ratio` | -- | Unset | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio` | -- | -- | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio.default_ratio` | Float | `1.0` | -- |
| `slow_query` | -- | -- | The slow query log options. |
| `slow_query.enable` | Bool | `false` | Whether to enable slow query log. |
@@ -250,7 +248,6 @@
| `mysql.addr` | String | `127.0.0.1:4002` | The addr to bind the MySQL server. |
| `mysql.runtime_size` | Integer | `2` | The number of server worker threads. |
| `mysql.keep_alive` | String | `0s` | Server-side keep-alive time.<br/>Set to 0 (default) to disable. |
| `mysql.prepared_stmt_cache_size` | Integer | `10000` | Maximum entries in the MySQL prepared statement cache; default is 10,000. |
| `mysql.tls` | -- | -- | -- |
| `mysql.tls.mode` | String | `disable` | TLS mode, refer to https://www.postgresql.org/docs/current/libpq-ssl.html<br/>- `disable` (default value)<br/>- `prefer`<br/>- `require`<br/>- `verify-ca`<br/>- `verify-full` |
| `mysql.tls.cert_path` | String | Unset | Certificate file path. |
@@ -296,20 +293,19 @@
| `logging.dir` | String | `./greptimedb_data/logs` | The directory to store the log files. If set to empty, logs will not be written to files. |
| `logging.level` | String | Unset | The log level. Can be `info`/`debug`/`warn`/`error`. |
| `logging.enable_otlp_tracing` | Bool | `false` | Enable OTLP tracing. |
| `logging.otlp_endpoint` | String | `http://localhost:4318/v1/traces` | The OTLP tracing endpoint. |
| `logging.otlp_endpoint` | String | `http://localhost:4318` | The OTLP tracing endpoint. |
| `logging.append_stdout` | Bool | `true` | Whether to append logs to stdout. |
| `logging.log_format` | String | `text` | The log format. Can be `text`/`json`. |
| `logging.max_log_files` | Integer | `720` | The maximum amount of log files. |
| `logging.otlp_export_protocol` | String | `http` | The OTLP tracing export protocol. Can be `grpc`/`http`. |
| `logging.otlp_headers` | -- | -- | Additional OTLP headers, only valid when using OTLP http |
| `logging.tracing_sample_ratio` | -- | Unset | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio` | -- | -- | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio.default_ratio` | Float | `1.0` | -- |
| `slow_query` | -- | -- | The slow query log options. |
| `slow_query.enable` | Bool | `true` | Whether to enable slow query log. |
| `slow_query.record_type` | String | `system_table` | The record type of slow queries. It can be `system_table` or `log`.<br/>If `system_table` is selected, the slow queries will be recorded in a system table `greptime_private.slow_queries`.<br/>If `log` is selected, the slow queries will be logged in a log file `greptimedb-slow-queries.*`. |
| `slow_query.threshold` | String | `30s` | The threshold of slow query. It can be human readable time string, for example: `10s`, `100ms`, `1s`. |
| `slow_query.sample_ratio` | Float | `1.0` | The sampling ratio of slow query log. The value should be in the range of (0, 1]. For example, `0.1` means 10% of the slow queries will be logged and `1.0` means all slow queries will be logged. |
| `slow_query.ttl` | String | `90d` | The TTL of the `slow_queries` system table. Default is `90d` when `record_type` is `system_table`. |
| `slow_query.ttl` | String | `30d` | The TTL of the `slow_queries` system table. Default is `30d` when `record_type` is `system_table`. |
| `export_metrics` | -- | -- | The frontend can export its metrics and send to Prometheus compatible service (e.g. `greptimedb` itself) from remote-write API.<br/>This is only used for `greptimedb` to export its own metrics internally. It's different from prometheus scrape. |
| `export_metrics.enable` | Bool | `false` | whether enable export metrics. |
| `export_metrics.write_interval` | String | `30s` | The interval of export metrics. |
@@ -320,8 +316,6 @@
| `tracing.tokio_console_addr` | String | Unset | The tokio console address. |
| `memory` | -- | -- | The memory options. |
| `memory.enable_heap_profiling` | Bool | `true` | Whether to enable heap profiling activation during startup.<br/>When enabled, heap profiling will be activated if the `MALLOC_CONF` environment variable<br/>is set to "prof:true,prof_active:false". The official image adds this env variable.<br/>Default is true. |
| `event_recorder` | -- | -- | Configuration options for the event recorder. |
| `event_recorder.ttl` | String | `90d` | TTL for the events table that will be used to store the events. Default is `90d`. |
### Metasrv
@@ -377,32 +371,28 @@
| `datanode.client.tcp_nodelay` | Bool | `true` | `TCP_NODELAY` option for accepted connections. |
| `wal` | -- | -- | -- |
| `wal.provider` | String | `raft_engine` | -- |
| `wal.broker_endpoints` | Array | -- | The broker endpoints of the Kafka cluster.<br/><br/>**It's only used when the provider is `kafka`**. |
| `wal.auto_create_topics` | Bool | `true` | Automatically create topics for WAL.<br/>Set to `true` to automatically create topics for WAL.<br/>Otherwise, use topics named `topic_name_prefix_[0..num_topics)`<br/>**It's only used when the provider is `kafka`**. |
| `wal.auto_prune_interval` | String | `10m` | Interval of automatically WAL pruning.<br/>Set to `0s` to disable automatically WAL pruning which delete unused remote WAL entries periodically.<br/>**It's only used when the provider is `kafka`**. |
| `wal.flush_trigger_size` | String | `512MB` | Estimated size threshold to trigger a flush when using Kafka remote WAL.<br/>Since multiple regions may share a Kafka topic, the estimated size is calculated as:<br/> (latest_entry_id - flushed_entry_id) * avg_record_size<br/>MetaSrv triggers a flush for a region when this estimated size exceeds `flush_trigger_size`.<br/>- `latest_entry_id`: The latest entry ID in the topic.<br/>- `flushed_entry_id`: The last flushed entry ID for the region.<br/>Set to "0" to let the system decide the flush trigger size.<br/>**It's only used when the provider is `kafka`**. |
| `wal.auto_prune_parallelism` | Integer | `10` | Concurrent task limit for automatically WAL pruning.<br/>**It's only used when the provider is `kafka`**. |
| `wal.num_topics` | Integer | `64` | Number of topics used for remote WAL.<br/>**It's only used when the provider is `kafka`**. |
| `wal.selector_type` | String | `round_robin` | Topic selector type.<br/>Available selector types:<br/>- `round_robin` (default)<br/>**It's only used when the provider is `kafka`**. |
| `wal.topic_name_prefix` | String | `greptimedb_wal_topic` | A Kafka topic is constructed by concatenating `topic_name_prefix` and `topic_id`.<br/>Only accepts strings that match the following regular expression pattern:<br/>[a-zA-Z_:-][a-zA-Z0-9_:\-\.@#]*<br/>i.g., greptimedb_wal_topic_0, greptimedb_wal_topic_1.<br/>**It's only used when the provider is `kafka`**. |
| `wal.replication_factor` | Integer | `1` | Expected number of replicas of each partition.<br/>**It's only used when the provider is `kafka`**. |
| `wal.create_topic_timeout` | String | `30s` | The timeout for creating a Kafka topic.<br/>**It's only used when the provider is `kafka`**. |
| `wal.broker_endpoints` | Array | -- | The broker endpoints of the Kafka cluster. |
| `wal.auto_create_topics` | Bool | `true` | Automatically create topics for WAL.<br/>Set to `true` to automatically create topics for WAL.<br/>Otherwise, use topics named `topic_name_prefix_[0..num_topics)` |
| `wal.auto_prune_interval` | String | `0s` | Interval of automatically WAL pruning.<br/>Set to `0s` to disable automatically WAL pruning which delete unused remote WAL entries periodically. |
| `wal.trigger_flush_threshold` | Integer | `0` | The threshold to trigger a flush operation of a region in automatically WAL pruning.<br/>Metasrv will send a flush request to flush the region when:<br/>`trigger_flush_threshold` + `prunable_entry_id` < `max_prunable_entry_id`<br/>where:<br/>- `prunable_entry_id` is the maximum entry id that can be pruned of the region.<br/>- `max_prunable_entry_id` is the maximum prunable entry id among all regions in the same topic.<br/>Set to `0` to disable the flush operation. |
| `wal.auto_prune_parallelism` | Integer | `10` | Concurrent task limit for automatically WAL pruning. |
| `wal.num_topics` | Integer | `64` | Number of topics. |
| `wal.selector_type` | String | `round_robin` | Topic selector type.<br/>Available selector types:<br/>- `round_robin` (default) |
| `wal.topic_name_prefix` | String | `greptimedb_wal_topic` | A Kafka topic is constructed by concatenating `topic_name_prefix` and `topic_id`.<br/>Only accepts strings that match the following regular expression pattern:<br/>[a-zA-Z_:-][a-zA-Z0-9_:\-\.@#]*<br/>i.g., greptimedb_wal_topic_0, greptimedb_wal_topic_1. |
| `wal.replication_factor` | Integer | `1` | Expected number of replicas of each partition. |
| `wal.create_topic_timeout` | String | `30s` | Above which a topic creation operation will be cancelled. |
| `event_recorder` | -- | -- | Configuration options for the event recorder. |
| `event_recorder.ttl` | String | `90d` | TTL for the events table that will be used to store the events. Default is `90d`. |
| `stats_persistence` | -- | -- | Configuration options for the stats persistence. |
| `stats_persistence.ttl` | String | `30d` | TTL for the stats table that will be used to store the stats. Default is `30d`.<br/>Set to `0s` to disable stats persistence. |
| `stats_persistence.interval` | String | `60s` | The interval to persist the stats. Default is `60s`.<br/>The minimum value is `60s`, if the value is less than `60s`, it will be overridden to `60s`. |
| `event_recorder.ttl` | String | `30d` | TTL for the events table that will be used to store the events. |
| `logging` | -- | -- | The logging options. |
| `logging.dir` | String | `./greptimedb_data/logs` | The directory to store the log files. If set to empty, logs will not be written to files. |
| `logging.level` | String | Unset | The log level. Can be `info`/`debug`/`warn`/`error`. |
| `logging.enable_otlp_tracing` | Bool | `false` | Enable OTLP tracing. |
| `logging.otlp_endpoint` | String | `http://localhost:4318/v1/traces` | The OTLP tracing endpoint. |
| `logging.otlp_endpoint` | String | `http://localhost:4318` | The OTLP tracing endpoint. |
| `logging.append_stdout` | Bool | `true` | Whether to append logs to stdout. |
| `logging.log_format` | String | `text` | The log format. Can be `text`/`json`. |
| `logging.max_log_files` | Integer | `720` | The maximum amount of log files. |
| `logging.otlp_export_protocol` | String | `http` | The OTLP tracing export protocol. Can be `grpc`/`http`. |
| `logging.otlp_headers` | -- | -- | Additional OTLP headers, only valid when using OTLP http |
| `logging.tracing_sample_ratio` | -- | Unset | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio` | -- | -- | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio.default_ratio` | Float | `1.0` | -- |
| `export_metrics` | -- | -- | The metasrv can export its metrics and send to Prometheus compatible service (e.g. `greptimedb` itself) from remote-write API.<br/>This is only used for `greptimedb` to export its own metrics internally. It's different from prometheus scrape. |
| `export_metrics.enable` | Bool | `false` | whether enable export metrics. |
@@ -565,13 +555,12 @@
| `logging.dir` | String | `./greptimedb_data/logs` | The directory to store the log files. If set to empty, logs will not be written to files. |
| `logging.level` | String | Unset | The log level. Can be `info`/`debug`/`warn`/`error`. |
| `logging.enable_otlp_tracing` | Bool | `false` | Enable OTLP tracing. |
| `logging.otlp_endpoint` | String | `http://localhost:4318/v1/traces` | The OTLP tracing endpoint. |
| `logging.otlp_endpoint` | String | `http://localhost:4318` | The OTLP tracing endpoint. |
| `logging.append_stdout` | Bool | `true` | Whether to append logs to stdout. |
| `logging.log_format` | String | `text` | The log format. Can be `text`/`json`. |
| `logging.max_log_files` | Integer | `720` | The maximum amount of log files. |
| `logging.otlp_export_protocol` | String | `http` | The OTLP tracing export protocol. Can be `grpc`/`http`. |
| `logging.otlp_headers` | -- | -- | Additional OTLP headers, only valid when using OTLP http |
| `logging.tracing_sample_ratio` | -- | Unset | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio` | -- | -- | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio.default_ratio` | Float | `1.0` | -- |
| `export_metrics` | -- | -- | The datanode can export its metrics and send to Prometheus compatible service (e.g. `greptimedb` itself) from remote-write API.<br/>This is only used for `greptimedb` to export its own metrics internally. It's different from prometheus scrape. |
| `export_metrics.enable` | Bool | `false` | whether enable export metrics. |
@@ -602,12 +591,6 @@
| `flow.batching_mode.experimental_frontend_activity_timeout` | String | `60s` | Frontend activity timeout<br/>if frontend is down(not sending heartbeat) for more than frontend_activity_timeout,<br/>it will be removed from the list that flownode use to connect |
| `flow.batching_mode.experimental_max_filter_num_per_query` | Integer | `20` | Maximum number of filters allowed in a single query |
| `flow.batching_mode.experimental_time_window_merge_threshold` | Integer | `3` | Time window merge distance |
| `flow.batching_mode.read_preference` | String | `Leader` | Read preference of the Frontend client. |
| `flow.batching_mode.frontend_tls` | -- | -- | -- |
| `flow.batching_mode.frontend_tls.enabled` | Bool | `false` | Whether to enable TLS for client. |
| `flow.batching_mode.frontend_tls.server_ca_cert_path` | String | Unset | Server Certificate file path. |
| `flow.batching_mode.frontend_tls.client_cert_path` | String | Unset | Client Certificate file path. |
| `flow.batching_mode.frontend_tls.client_key_path` | String | Unset | Client Private key file path. |
| `grpc` | -- | -- | The gRPC server options. |
| `grpc.bind_addr` | String | `127.0.0.1:6800` | The address to bind the gRPC server. |
| `grpc.server_addr` | String | `127.0.0.1:6800` | The address advertised to the metasrv,<br/>and used for connections from outside the host |
@@ -635,13 +618,12 @@
| `logging.dir` | String | `./greptimedb_data/logs` | The directory to store the log files. If set to empty, logs will not be written to files. |
| `logging.level` | String | Unset | The log level. Can be `info`/`debug`/`warn`/`error`. |
| `logging.enable_otlp_tracing` | Bool | `false` | Enable OTLP tracing. |
| `logging.otlp_endpoint` | String | `http://localhost:4318/v1/traces` | The OTLP tracing endpoint. |
| `logging.otlp_endpoint` | String | `http://localhost:4318` | The OTLP tracing endpoint. |
| `logging.append_stdout` | Bool | `true` | Whether to append logs to stdout. |
| `logging.log_format` | String | `text` | The log format. Can be `text`/`json`. |
| `logging.max_log_files` | Integer | `720` | The maximum amount of log files. |
| `logging.otlp_export_protocol` | String | `http` | The OTLP tracing export protocol. Can be `grpc`/`http`. |
| `logging.otlp_headers` | -- | -- | Additional OTLP headers, only valid when using OTLP http |
| `logging.tracing_sample_ratio` | -- | Unset | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio` | -- | -- | The percentage of tracing will be sampled and exported.<br/>Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.<br/>ratio > 1 are treated as 1. Fractions < 0 are treated as 0 |
| `logging.tracing_sample_ratio.default_ratio` | Float | `1.0` | -- |
| `tracing` | -- | -- | The tracing options. Only effect when compiled with `tokio-console` feature. |
| `tracing.tokio_console_addr` | String | Unset | The tokio console address. |

View File

@@ -632,7 +632,7 @@ level = "info"
enable_otlp_tracing = false
## The OTLP tracing endpoint.
otlp_endpoint = "http://localhost:4318/v1/traces"
otlp_endpoint = "http://localhost:4318"
## Whether to append logs to stdout.
append_stdout = true
@@ -646,13 +646,6 @@ max_log_files = 720
## The OTLP tracing export protocol. Can be `grpc`/`http`.
otlp_export_protocol = "http"
## Additional OTLP headers, only valid when using OTLP http
[logging.otlp_headers]
## @toml2docs:none-default
#Authorization = "Bearer my-token"
## @toml2docs:none-default
#Database = "My database"
## The percentage of tracing will be sampled and exported.
## Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.
## ratio > 1 are treated as 1. Fractions < 0 are treated as 0

View File

@@ -30,20 +30,6 @@ node_id = 14
#+experimental_max_filter_num_per_query=20
## Time window merge distance
#+experimental_time_window_merge_threshold=3
## Read preference of the Frontend client.
#+read_preference="Leader"
[flow.batching_mode.frontend_tls]
## Whether to enable TLS for client.
#+enabled=false
## Server Certificate file path.
## @toml2docs:none-default
#+server_ca_cert_path=""
## Client Certificate file path.
## @toml2docs:none-default
#+client_cert_path=""
## Client Private key file path.
## @toml2docs:none-default
#+client_key_path=""
## The gRPC server options.
[grpc]
@@ -120,7 +106,7 @@ level = "info"
enable_otlp_tracing = false
## The OTLP tracing endpoint.
otlp_endpoint = "http://localhost:4318/v1/traces"
otlp_endpoint = "http://localhost:4318"
## Whether to append logs to stdout.
append_stdout = true
@@ -134,13 +120,6 @@ max_log_files = 720
## The OTLP tracing export protocol. Can be `grpc`/`http`.
otlp_export_protocol = "http"
## Additional OTLP headers, only valid when using OTLP http
[logging.otlp_headers]
## @toml2docs:none-default
#Authorization = "Bearer my-token"
## @toml2docs:none-default
#Database = "My database"
## The percentage of tracing will be sampled and exported.
## Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.
## ratio > 1 are treated as 1. Fractions < 0 are treated as 0

View File

@@ -90,8 +90,6 @@ runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"
## Maximum entries in the MySQL prepared statement cache; default is 10,000.
prepared_stmt_cache_size = 10000
# MySQL server TLS options.
[mysql.tls]
@@ -223,7 +221,7 @@ level = "info"
enable_otlp_tracing = false
## The OTLP tracing endpoint.
otlp_endpoint = "http://localhost:4318/v1/traces"
otlp_endpoint = "http://localhost:4318"
## Whether to append logs to stdout.
append_stdout = true
@@ -237,13 +235,6 @@ max_log_files = 720
## The OTLP tracing export protocol. Can be `grpc`/`http`.
otlp_export_protocol = "http"
## Additional OTLP headers, only valid when using OTLP http
[logging.otlp_headers]
## @toml2docs:none-default
#Authorization = "Bearer my-token"
## @toml2docs:none-default
#Database = "My database"
## The percentage of tracing will be sampled and exported.
## Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.
## ratio > 1 are treated as 1. Fractions < 0 are treated as 0
@@ -266,8 +257,8 @@ threshold = "30s"
## The sampling ratio of slow query log. The value should be in the range of (0, 1]. For example, `0.1` means 10% of the slow queries will be logged and `1.0` means all slow queries will be logged.
sample_ratio = 1.0
## The TTL of the `slow_queries` system table. Default is `90d` when `record_type` is `system_table`.
ttl = "90d"
## The TTL of the `slow_queries` system table. Default is `30d` when `record_type` is `system_table`.
ttl = "30d"
## The frontend can export its metrics and send to Prometheus compatible service (e.g. `greptimedb` itself) from remote-write API.
## This is only used for `greptimedb` to export its own metrics internally. It's different from prometheus scrape.
@@ -297,8 +288,3 @@ headers = { }
## is set to "prof:true,prof_active:false". The official image adds this env variable.
## Default is true.
enable_heap_profiling = true
## Configuration options for the event recorder.
[event_recorder]
## TTL for the events table that will be used to store the events. Default is `90d`.
ttl = "90d"

View File

@@ -176,61 +176,50 @@ tcp_nodelay = true
# - `kafka`: metasrv **have to be** configured with kafka wal config when using kafka wal provider in datanode.
provider = "raft_engine"
# Kafka wal config.
## The broker endpoints of the Kafka cluster.
##
## **It's only used when the provider is `kafka`**.
broker_endpoints = ["127.0.0.1:9092"]
## Automatically create topics for WAL.
## Set to `true` to automatically create topics for WAL.
## Otherwise, use topics named `topic_name_prefix_[0..num_topics)`
## **It's only used when the provider is `kafka`**.
auto_create_topics = true
## Interval of automatically WAL pruning.
## Set to `0s` to disable automatically WAL pruning which delete unused remote WAL entries periodically.
## **It's only used when the provider is `kafka`**.
auto_prune_interval = "10m"
auto_prune_interval = "0s"
## Estimated size threshold to trigger a flush when using Kafka remote WAL.
## Since multiple regions may share a Kafka topic, the estimated size is calculated as:
## (latest_entry_id - flushed_entry_id) * avg_record_size
## MetaSrv triggers a flush for a region when this estimated size exceeds `flush_trigger_size`.
## - `latest_entry_id`: The latest entry ID in the topic.
## - `flushed_entry_id`: The last flushed entry ID for the region.
## Set to "0" to let the system decide the flush trigger size.
## **It's only used when the provider is `kafka`**.
flush_trigger_size = "512MB"
## The threshold to trigger a flush operation of a region in automatically WAL pruning.
## Metasrv will send a flush request to flush the region when:
## `trigger_flush_threshold` + `prunable_entry_id` < `max_prunable_entry_id`
## where:
## - `prunable_entry_id` is the maximum entry id that can be pruned of the region.
## - `max_prunable_entry_id` is the maximum prunable entry id among all regions in the same topic.
## Set to `0` to disable the flush operation.
trigger_flush_threshold = 0
## Concurrent task limit for automatically WAL pruning.
## **It's only used when the provider is `kafka`**.
auto_prune_parallelism = 10
## Number of topics used for remote WAL.
## **It's only used when the provider is `kafka`**.
## Number of topics.
num_topics = 64
## Topic selector type.
## Available selector types:
## - `round_robin` (default)
## **It's only used when the provider is `kafka`**.
selector_type = "round_robin"
## A Kafka topic is constructed by concatenating `topic_name_prefix` and `topic_id`.
## Only accepts strings that match the following regular expression pattern:
## [a-zA-Z_:-][a-zA-Z0-9_:\-\.@#]*
## i.g., greptimedb_wal_topic_0, greptimedb_wal_topic_1.
## **It's only used when the provider is `kafka`**.
topic_name_prefix = "greptimedb_wal_topic"
## Expected number of replicas of each partition.
## **It's only used when the provider is `kafka`**.
replication_factor = 1
## The timeout for creating a Kafka topic.
## **It's only used when the provider is `kafka`**.
## Above which a topic creation operation will be cancelled.
create_topic_timeout = "30s"
# The Kafka SASL configuration.
@@ -253,17 +242,8 @@ create_topic_timeout = "30s"
## Configuration options for the event recorder.
[event_recorder]
## TTL for the events table that will be used to store the events. Default is `90d`.
ttl = "90d"
## Configuration options for the stats persistence.
[stats_persistence]
## TTL for the stats table that will be used to store the stats. Default is `30d`.
## Set to `0s` to disable stats persistence.
## TTL for the events table that will be used to store the events.
ttl = "30d"
## The interval to persist the stats. Default is `60s`.
## The minimum value is `60s`, if the value is less than `60s`, it will be overridden to `60s`.
interval = "60s"
## The logging options.
[logging]
@@ -278,7 +258,7 @@ level = "info"
enable_otlp_tracing = false
## The OTLP tracing endpoint.
otlp_endpoint = "http://localhost:4318/v1/traces"
otlp_endpoint = "http://localhost:4318"
## Whether to append logs to stdout.
append_stdout = true
@@ -292,14 +272,6 @@ max_log_files = 720
## The OTLP tracing export protocol. Can be `grpc`/`http`.
otlp_export_protocol = "http"
## Additional OTLP headers, only valid when using OTLP http
[logging.otlp_headers]
## @toml2docs:none-default
#Authorization = "Bearer my-token"
## @toml2docs:none-default
#Database = "My database"
## The percentage of tracing will be sampled and exported.
## Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.
## ratio > 1 are treated as 1. Fractions < 0 are treated as 0

View File

@@ -85,8 +85,7 @@ runtime_size = 2
## Server-side keep-alive time.
## Set to 0 (default) to disable.
keep_alive = "0s"
## Maximum entries in the MySQL prepared statement cache; default is 10,000.
prepared_stmt_cache_size= 10000
# MySQL server TLS options.
[mysql.tls]
@@ -724,7 +723,7 @@ level = "info"
enable_otlp_tracing = false
## The OTLP tracing endpoint.
otlp_endpoint = "http://localhost:4318/v1/traces"
otlp_endpoint = "http://localhost:4318"
## Whether to append logs to stdout.
append_stdout = true
@@ -738,13 +737,6 @@ max_log_files = 720
## The OTLP tracing export protocol. Can be `grpc`/`http`.
otlp_export_protocol = "http"
## Additional OTLP headers, only valid when using OTLP http
[logging.otlp_headers]
## @toml2docs:none-default
#Authorization = "Bearer my-token"
## @toml2docs:none-default
#Database = "My database"
## The percentage of tracing will be sampled and exported.
## Valid range `[0, 1]`, 1 means all traces are sampled, 0 means all traces are not sampled, the default value is 1.
## ratio > 1 are treated as 1. Fractions < 0 are treated as 0

View File

@@ -19,7 +19,7 @@ ARG PROTOBUF_VERSION=29.3
RUN curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOBUF_VERSION}/protoc-${PROTOBUF_VERSION}-linux-x86_64.zip && \
unzip protoc-${PROTOBUF_VERSION}-linux-x86_64.zip -d protoc3;
RUN mv protoc3/bin/* /usr/local/bin/
RUN mv protoc3/include/* /usr/local/include/

View File

@@ -0,0 +1,72 @@
Currently, our query engine is based on DataFusion, so all aggregate functions are executed by DataFusion through its UDAF interface. You can find DataFusion's UDAF example [here](https://github.com/apache/datafusion/tree/main/datafusion-examples/examples/simple_udaf.rs). Basically, we provide the same way as DataFusion to write aggregate functions: both are centered on a struct called "Accumulator" that accumulates state along the way during aggregation.
However, DataFusion's UDAF implementation has a significant restriction: it requires the user to provide a concrete "Accumulator". Take the `Median` aggregate function, for example: to aggregate a `u32` column, you have to write a `MedianU32` and use `SELECT MEDIANU32(x)` in SQL, and that `MedianU32` cannot be used to aggregate an `i32` column. Alternatively, you can use a special type that can hold all kinds of data (like our `Value` enum or Arrow's `ScalarValue`) and `match` your way through the aggregate calculations. It works, though it is rather tedious. (It also seems to be DataFusion's preferred way to write a UDAF.)
So is there a way to make an aggregate function that automatically matches the input data's type? For example, a `Median` aggregator that works on both `u32` and `i32` columns? The answer is yes, once we find a way to bypass DataFusion's restriction that it simply doesn't pass the input data's type when creating an Accumulator.
> There's an example in `my_sum_udaf_example.rs`; take that as a quick start.
# 1. Impl `AggregateFunctionCreator` trait for your accumulator creator.
You must first define a struct that will be used to create your accumulator. For example,
```Rust
#[as_aggr_func_creator]
#[derive(Debug, AggrFuncTypeStore)]
struct MySumAccumulatorCreator {}
```
Attribute macro `#[as_aggr_func_creator]` and derive macro `#[derive(Debug, AggrFuncTypeStore)]` must both be annotated on the struct. They work together to provide storage for the aggregate function's input data types, which are needed later to create a generic accumulator.
> Note that the `as_aggr_func_creator` macro adds fields to the struct, so the struct cannot be defined as a fieldless struct like `struct Foo;`, nor as a newtype like `struct Foo(bar)`.
Then impl `AggregateFunctionCreator` trait on it. The definition of the trait is:
```Rust
pub trait AggregateFunctionCreator: Send + Sync + Debug {
fn creator(&self) -> AccumulatorCreatorFunction;
fn output_type(&self) -> ConcreteDataType;
fn state_types(&self) -> Vec<ConcreteDataType>;
}
```
You can use the input data's types in the methods that return the output type and state types (just invoke `input_types()`).
The output type is the data type of the aggregate function's output. For example, the `SUM` aggregate function's output type is `u64` for a `u32` column. The state types are the types of the accumulator's internal states. Take the `AVG` aggregate function on an `i32` column as an example: its state types are `i64` (for the sum) and `u64` (for the count).
The `creator` function is where you define how an accumulator (the one that will be used in DataFusion) is created. You define "how" to create the accumulator (instead of "what" to create), using the input data's types as arguments. With the input data type known, you can create the accumulator generically.
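To make that type-driven creation concrete, here is a minimal, self-contained sketch. It deliberately uses simplified stand-in types (a toy `DataType` enum and a boxed closure) instead of GreptimeDB's real `ConcreteDataType` and `AccumulatorCreatorFunction`, so treat the names and signatures as illustrative assumptions rather than the actual API; the point is only that the closure receives the input types and decides, at runtime, which concrete accumulator to build.
```Rust
use std::fmt::Debug;

/// Toy stand-in for GreptimeDB's `ConcreteDataType` (assumption for this sketch).
#[derive(Debug, Clone, Copy, PartialEq)]
enum DataType {
    Int32,
    UInt32,
}

/// Toy stand-in for the `Accumulator` trait described in step 2.
trait Accumulator: Debug {
    fn update(&mut self, values: &[i64]);
    fn evaluate(&self) -> i64;
}

#[derive(Debug, Default)]
struct SumAccumulator {
    sum: i64,
}

impl Accumulator for SumAccumulator {
    fn update(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    fn evaluate(&self) -> i64 {
        self.sum
    }
}

/// Stand-in for `AccumulatorCreatorFunction`: given the input types, build a
/// matching accumulator. This captures the "how" rather than the "what".
type CreatorFn = Box<dyn Fn(&[DataType]) -> Box<dyn Accumulator>>;

fn sum_creator() -> CreatorFn {
    Box::new(|input_types: &[DataType]| -> Box<dyn Accumulator> {
        // With the input type only known here, one creator serves both
        // `i32` and `u32` columns generically.
        match input_types {
            [DataType::Int32] | [DataType::UInt32] => Box::new(SumAccumulator::default()),
            other => panic!("unsupported input types: {other:?}"),
        }
    })
}

fn main() {
    let create = sum_creator();
    let mut acc = create(&[DataType::UInt32]);
    acc.update(&[1, 2, 3]);
    println!("sum = {}", acc.evaluate()); // prints: sum = 6
}
```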
# 2. Impl `Accumulator` trait for your accumulator.
The accumulator is where you store the aggregate calculation states and evaluate a result. You must impl `Accumulator` trait for it. The trait's definition is:
```Rust
pub trait Accumulator: Send + Sync + Debug {
fn state(&self) -> Result<Vec<Value>>;
fn update_batch(&mut self, values: &[VectorRef]) -> Result<()>;
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()>;
fn evaluate(&self) -> Result<Value>;
}
```
DataFusion basically executes an aggregation like this:
1. Partition all input data for the aggregation, and create an accumulator for each partition.
2. Call `update_batch` on each accumulator with its partition's data, to let you update your aggregate calculation.
3. Call `state` to get each accumulator's internal state, the intermediate calculation result.
4. Call `merge_batch` to merge all accumulators' internal states into one.
5. Execute `evaluate` on the merged accumulator to get the final calculation result.
Once you know the meaning of each method, you can easily write your accumulator; a simplified, runnable walkthrough of this flow is sketched below. You can also refer to the `Median` accumulator or the `SUM` accumulator defined in `my_sum_udaf_example.rs` for more details.
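As a rough mental model of that five-step flow (not GreptimeDB's real `Accumulator` trait, which works over `VectorRef` and `Value` and returns `Result`; the toy types below are assumptions made so the sketch is runnable on its own), here is a simplified sum accumulator driven by hand:
```Rust
/// Toy sum accumulator; `VectorRef`/`Value` are simplified to plain `i64`s here.
#[derive(Debug, Default)]
struct SumAccumulator {
    sum: i64,
}

impl SumAccumulator {
    /// Step 2: fold a batch of partitioned input into the running state.
    fn update_batch(&mut self, values: &[i64]) {
        self.sum += values.iter().sum::<i64>();
    }
    /// Step 3: expose the internal state, i.e. the intermediate result.
    fn state(&self) -> Vec<i64> {
        vec![self.sum]
    }
    /// Step 4: merge another accumulator's state into this one.
    fn merge_batch(&mut self, states: &[i64]) {
        self.sum += states.iter().sum::<i64>();
    }
    /// Step 5: produce the final result.
    fn evaluate(&self) -> i64 {
        self.sum
    }
}

fn main() {
    // Step 1: partition the input and create one accumulator per partition.
    let partitions = vec![vec![1_i64, 2, 3], vec![10, 20]];
    let mut accumulators: Vec<SumAccumulator> =
        partitions.iter().map(|_| SumAccumulator::default()).collect();

    // Step 2: update each accumulator with its own partition's data.
    for (acc, part) in accumulators.iter_mut().zip(&partitions) {
        acc.update_batch(part);
    }

    // Steps 3 and 4: collect every accumulator's state and merge into one.
    let mut merged = SumAccumulator::default();
    for acc in &accumulators {
        merged.merge_batch(&acc.state());
    }

    // Step 5: evaluate the merged accumulator to get the final result.
    assert_eq!(merged.evaluate(), 36);
    println!("sum = {}", merged.evaluate());
}
```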
# 3. Register your aggregate function to our query engine.
You can call the `register_aggregate_function` method of the query engine to register your aggregate function. To do that, you have to create an instance of the struct `AggregateFunctionMeta`. The struct has three fields. The first is your aggregate function's name. The name is case-sensitive due to DataFusion's restriction, so we strongly recommend using a lowercase name. If you have to use an uppercase name, wrap your aggregate function in quotation marks. For example, if you define an aggregate function named "my_aggr", you can use "`SELECT MY_AGGR(x)`"; if you define "my_AGGR", you have to use "`SELECT "my_AGGR"(x)`".
The second field is `arg_counts`, the count of the arguments. Take the `percentile` accumulator, which calculates the p-th percentile of a column, as an example: it needs both the column values and the value of p as input, so its argument count is two.
The third field is a function that creates the accumulator creator you defined in step 1 above. Creating a creator is a bit convoluted, but it is how we make DataFusion use a newly created aggregate function each time it executes a SQL statement, preventing the stored input types from interfering with each other. A good starting point for the key details is the `get_aggregate_meta` method of our `DfContextProviderAdapter` struct.
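Here is a tiny, self-contained sketch of the "function that creates the creator" idea. The `AggregateFunctionMeta` struct below is a hypothetical stand-in with assumed fields, not the real one from the query engine; it only illustrates why a factory closure gives every SQL execution its own creator instance, so the input types stored by one query cannot leak into another.
```Rust
use std::sync::Arc;

/// Hypothetical stand-in for the creator defined in step 1; the real
/// `#[as_aggr_func_creator]` macro adds the input-type storage for you.
#[derive(Debug, Default)]
struct MySumAccumulatorCreator {
    input_types: Vec<String>,
}

/// Hypothetical stand-in for `AggregateFunctionMeta`: a name, the argument
/// count, and a factory that builds a *fresh* creator per SQL execution.
struct AggregateFunctionMeta {
    name: String,
    arg_counts: usize,
    creator_factory: Box<dyn Fn() -> Arc<MySumAccumulatorCreator> + Send + Sync>,
}

fn main() {
    let meta = AggregateFunctionMeta {
        name: "my_sum".to_string(), // lowercase, so no quoting is needed in SQL
        arg_counts: 1,
        creator_factory: Box::new(|| Arc::new(MySumAccumulatorCreator::default())),
    };

    // Each query execution gets its own creator, so stored input types
    // from one query never affect another.
    let creator_a = (meta.creator_factory)();
    let creator_b = (meta.creator_factory)();
    assert!(!Arc::ptr_eq(&creator_a, &creator_b));
    println!(
        "registered `{}` with {} argument(s); creator starts with {} stored input types",
        meta.name,
        meta.arg_counts,
        creator_a.input_types.len()
    );
}
```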
# (Optional) 4. Make your aggregate function automatically registered.
If you've written a great aggregate function and want to let everyone use it, you can have it automatically registered to our query engine at start time. It's quick and simple; just refer to the `AggregateFunctions::register` function in `common/function/src/scalars/aggregate/mod.rs`.

View File

@@ -15,6 +15,8 @@
let
pkgs = nixpkgs.legacyPackages.${system};
buildInputs = with pkgs; [
libgit2
libz
];
lib = nixpkgs.lib;
rustToolchain = fenix.packages.${system}.fromToolchainName {

View File

@@ -19,3 +19,6 @@ paste.workspace = true
prost.workspace = true
serde_json.workspace = true
snafu.workspace = true
[build-dependencies]
tonic-build = "0.11"

View File

@@ -14,8 +14,6 @@
pub mod column_def;
pub mod helper;
pub mod meta {
pub use greptime_proto::v1::meta::*;
}

View File

@@ -1,65 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use greptime_proto::v1::value::ValueData;
use greptime_proto::v1::{ColumnDataType, ColumnSchema, Row, SemanticType, Value};
/// Create a time index [ColumnSchema] with column's name and datatype.
/// Other fields are left default.
/// Useful when you just want to create a simple [ColumnSchema] without providing much struct fields.
pub fn time_index_column_schema(name: &str, datatype: ColumnDataType) -> ColumnSchema {
ColumnSchema {
column_name: name.to_string(),
datatype: datatype as i32,
semantic_type: SemanticType::Timestamp as i32,
..Default::default()
}
}
/// Create a tag [ColumnSchema] with column's name and datatype.
/// Other fields are left default.
/// Useful when you just want to create a simple [ColumnSchema] without providing much struct fields.
pub fn tag_column_schema(name: &str, datatype: ColumnDataType) -> ColumnSchema {
ColumnSchema {
column_name: name.to_string(),
datatype: datatype as i32,
semantic_type: SemanticType::Tag as i32,
..Default::default()
}
}
/// Create a field [ColumnSchema] with column's name and datatype.
/// Other fields are left default.
/// Useful when you just want to create a simple [ColumnSchema] without providing much struct fields.
pub fn field_column_schema(name: &str, datatype: ColumnDataType) -> ColumnSchema {
ColumnSchema {
column_name: name.to_string(),
datatype: datatype as i32,
semantic_type: SemanticType::Field as i32,
..Default::default()
}
}
/// Create a [Row] from [ValueData]s.
/// Useful when you don't want to write much verbose codes.
pub fn row(values: Vec<ValueData>) -> Row {
Row {
values: values
.into_iter()
.map(|x| Value {
value_data: Some(x),
})
.collect::<Vec<_>>(),
}
}

View File

@@ -21,7 +21,6 @@ bytes.workspace = true
common-base.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-event-recorder.workspace = true
common-frontend.workspace = true
common-macro.workspace = true
common-meta.workspace = true

View File

@@ -44,7 +44,6 @@ use store_api::metric_engine_consts::METRIC_ENGINE_NAME;
use table::dist_table::DistTable;
use table::metadata::{TableId, TableInfoRef};
use table::table::numbers::{NumbersTable, NUMBERS_TABLE_NAME};
use table::table::PartitionRules;
use table::table_name::TableName;
use table::TableRef;
use tokio::sync::Semaphore;
@@ -133,8 +132,6 @@ impl KvBackendCatalogManager {
{
let mut new_table_info = (*table.table_info()).clone();
let mut phy_part_cols_not_in_logical_table = vec![];
// Remap partition key indices from physical table to logical table
new_table_info.meta.partition_key_indices = physical_table_info_value
.table_info
@@ -151,30 +148,15 @@ impl KvBackendCatalogManager {
.get(physical_index)
.and_then(|physical_column| {
// Find the corresponding index in the logical table schema
let idx = new_table_info
new_table_info
.meta
.schema
.column_index_by_name(physical_column.name.as_str());
if idx.is_none() {
// not all part columns in physical table that are also in logical table
phy_part_cols_not_in_logical_table
.push(physical_column.name.clone());
}
idx
.column_index_by_name(physical_column.name.as_str())
})
})
.collect();
let partition_rules = if !phy_part_cols_not_in_logical_table.is_empty() {
Some(PartitionRules {
extra_phy_cols_not_in_logical_table: phy_part_cols_not_in_logical_table,
})
} else {
None
};
let new_table = DistTable::table_partitioned(Arc::new(new_table_info), partition_rules);
let new_table = DistTable::table(Arc::new(new_table_info));
return Ok(new_table);
}

View File

@@ -38,7 +38,7 @@ use crate::{CatalogManager, DeregisterTableRequest, RegisterSchemaRequest, Regis
type SchemaEntries = HashMap<String, HashMap<String, TableRef>>;
/// Simple in-memory list of catalogs used for tests.
/// Simple in-memory list of catalogs
#[derive(Clone)]
pub struct MemoryCatalogManager {
/// Collection of catalogs containing schemas and ultimately Tables

View File

@@ -21,17 +21,17 @@ use std::time::{Duration, Instant, UNIX_EPOCH};
use api::v1::frontend::{KillProcessRequest, ListProcessRequest, ProcessInfo};
use common_base::cancellation::CancellationHandle;
use common_event_recorder::EventRecorderRef;
use common_frontend::selector::{FrontendSelector, MetaClientSelector};
use common_frontend::slow_query_event::SlowQueryEvent;
use common_telemetry::logging::SlowQueriesRecordType;
use common_telemetry::{debug, info, slow, warn};
use common_telemetry::{debug, error, info, warn};
use common_time::util::current_time_millis;
use meta_client::MetaClientRef;
use promql_parser::parser::EvalStmt;
use rand::random;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::statement::Statement;
use tokio::sync::mpsc::Sender;
use crate::error;
use crate::metrics::{PROCESS_KILL_COUNT, PROCESS_LIST_COUNT};
@@ -249,8 +249,6 @@ pub struct Ticket {
pub(crate) manager: ProcessManagerRef,
pub(crate) id: ProcessId,
pub cancellation_handle: Arc<CancellationHandle>,
// Keep the handle of the slow query timer to ensure it will trigger the event recording when dropped.
_slow_query_timer: Option<SlowQueryTimer>,
}
@@ -297,37 +295,38 @@ impl Debug for CancellableProcess {
pub struct SlowQueryTimer {
start: Instant,
stmt: QueryStatement,
threshold: Duration,
sample_ratio: f64,
record_type: SlowQueriesRecordType,
recorder: EventRecorderRef,
query_ctx: QueryContextRef,
threshold: Option<Duration>,
sample_ratio: Option<f64>,
tx: Sender<SlowQueryEvent>,
}
impl SlowQueryTimer {
pub fn new(
stmt: QueryStatement,
threshold: Duration,
sample_ratio: f64,
record_type: SlowQueriesRecordType,
recorder: EventRecorderRef,
query_ctx: QueryContextRef,
threshold: Option<Duration>,
sample_ratio: Option<f64>,
tx: Sender<SlowQueryEvent>,
) -> Self {
Self {
start: Instant::now(),
stmt,
query_ctx,
threshold,
sample_ratio,
record_type,
recorder,
tx,
}
}
}
impl SlowQueryTimer {
fn send_slow_query_event(&self, elapsed: Duration) {
fn send_slow_query_event(&self, elapsed: Duration, threshold: Duration) {
let mut slow_query_event = SlowQueryEvent {
cost: elapsed.as_millis() as u64,
threshold: self.threshold.as_millis() as u64,
threshold: threshold.as_millis() as u64,
query: "".to_string(),
query_ctx: self.query_ctx.clone(),
// The following fields are only used for PromQL queries.
is_promql: false,
@@ -364,37 +363,29 @@ impl SlowQueryTimer {
}
}
match self.record_type {
// Send the slow query event to the event recorder to persist it in the system table.
SlowQueriesRecordType::SystemTable => {
self.recorder.record(Box::new(slow_query_event));
}
// Record the slow query in a dedicated log file.
SlowQueriesRecordType::Log => {
slow!(
cost = slow_query_event.cost,
threshold = slow_query_event.threshold,
query = slow_query_event.query,
is_promql = slow_query_event.is_promql,
promql_range = slow_query_event.promql_range,
promql_step = slow_query_event.promql_step,
promql_start = slow_query_event.promql_start,
promql_end = slow_query_event.promql_end,
);
}
// Send SlowQueryEvent to the handler.
if let Err(e) = self.tx.try_send(slow_query_event) {
error!(e; "Failed to send slow query event");
}
}
}
impl Drop for SlowQueryTimer {
fn drop(&mut self) {
// Calculate the elapsed duration since the timer was created.
let elapsed = self.start.elapsed();
if elapsed > self.threshold {
// Only capture a portion of slow queries based on sample_ratio.
// Generate a random number in [0, 1) and compare it with sample_ratio.
if self.sample_ratio >= 1.0 || random::<f64>() <= self.sample_ratio {
self.send_slow_query_event(elapsed);
if let Some(threshold) = self.threshold {
// Calculate the elapsed duration since the timer was created.
let elapsed = self.start.elapsed();
if elapsed > threshold {
if let Some(ratio) = self.sample_ratio {
// Only capture a portion of slow queries based on sample_ratio.
// Generate a random number in [0, 1) and compare it with sample_ratio.
if ratio >= 1.0 || random::<f64>() <= ratio {
self.send_slow_query_event(elapsed, threshold);
}
} else {
// Captures all slow queries if sample_ratio is not set.
self.send_slow_query_event(elapsed, threshold);
}
}
}
}
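The sampling branch above keeps roughly `sample_ratio` of slow queries: it draws a uniform random number in [0, 1) and records the query when the draw is at or below the ratio, while a ratio of 1.0 or more, or no ratio at all, records everything. A standalone sketch of just that decision, separate from the timer itself:

use rand::random;

// Hedged sketch: `None` means "record every slow query"; otherwise keep
// roughly `ratio` of them by comparing against a uniform draw in [0, 1).
fn should_record(sample_ratio: Option<f64>) -> bool {
    match sample_ratio {
        Some(ratio) => ratio >= 1.0 || random::<f64>() <= ratio,
        None => true,
    }
}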

View File

@@ -133,7 +133,7 @@ impl Predicate {
let Expr::Column(c) = *expr else {
unreachable!();
};
let Expr::Literal(ScalarValue::Utf8(Some(pattern)), _) = *pattern else {
let Expr::Literal(ScalarValue::Utf8(Some(pattern))) = *pattern else {
unreachable!();
};
@@ -148,8 +148,8 @@ impl Predicate {
// left OP right
Expr::BinaryExpr(bin) => match (*bin.left, bin.op, *bin.right) {
// left == right
(Expr::Literal(scalar, _), Operator::Eq, Expr::Column(c))
| (Expr::Column(c), Operator::Eq, Expr::Literal(scalar, _)) => {
(Expr::Literal(scalar), Operator::Eq, Expr::Column(c))
| (Expr::Column(c), Operator::Eq, Expr::Literal(scalar)) => {
let Ok(v) = Value::try_from(scalar) else {
return None;
};
@@ -157,8 +157,8 @@ impl Predicate {
Some(Predicate::Eq(c.name, v))
}
// left != right
(Expr::Literal(scalar, _), Operator::NotEq, Expr::Column(c))
| (Expr::Column(c), Operator::NotEq, Expr::Literal(scalar, _)) => {
(Expr::Literal(scalar), Operator::NotEq, Expr::Column(c))
| (Expr::Column(c), Operator::NotEq, Expr::Literal(scalar)) => {
let Ok(v) = Value::try_from(scalar) else {
return None;
};
@@ -189,7 +189,7 @@ impl Predicate {
let mut values = Vec::with_capacity(list.len());
for scalar in list {
// Safety: checked by `is_all_scalars`
let Expr::Literal(scalar, _) = scalar else {
let Expr::Literal(scalar) = scalar else {
unreachable!();
};
@@ -237,7 +237,7 @@ fn like_utf8(s: &str, pattern: &str, case_insensitive: &bool) -> Option<bool> {
}
fn is_string_literal(expr: &Expr) -> bool {
matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(_)), _))
matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(_))))
}
fn is_column(expr: &Expr) -> bool {
@@ -286,14 +286,14 @@ impl Predicates {
/// Returns true when the values are all [`DfExpr::Literal`].
fn is_all_scalars(list: &[Expr]) -> bool {
list.iter().all(|v| matches!(v, Expr::Literal(_, _)))
list.iter().all(|v| matches!(v, Expr::Literal(_)))
}
#[cfg(test)]
mod tests {
use datafusion::common::Column;
use datafusion::common::{Column, ScalarValue};
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{BinaryExpr, Literal};
use datafusion::logical_expr::BinaryExpr;
use super::*;
@@ -378,7 +378,7 @@ mod tests {
let expr = Expr::Like(Like {
negated: false,
expr: Box::new(column("a")),
pattern: Box::new("%abc".lit()),
pattern: Box::new(string_literal("%abc")),
case_insensitive: true,
escape_char: None,
});
@@ -405,7 +405,7 @@ mod tests {
let expr = Expr::Like(Like {
negated: false,
expr: Box::new(column("a")),
pattern: Box::new("%abc".lit()),
pattern: Box::new(string_literal("%abc")),
case_insensitive: false,
escape_char: None,
});
@@ -425,7 +425,7 @@ mod tests {
let expr = Expr::Like(Like {
negated: true,
expr: Box::new(column("a")),
pattern: Box::new("%abc".lit()),
pattern: Box::new(string_literal("%abc")),
case_insensitive: true,
escape_char: None,
});
@@ -440,6 +440,10 @@ mod tests {
Expr::Column(Column::from_name(name))
}
fn string_literal(v: &str) -> Expr {
Expr::Literal(ScalarValue::Utf8(Some(v.to_string())))
}
fn match_string_value(v: &Value, expected: &str) -> bool {
matches!(v, Value::String(bs) if bs.as_utf8() == expected)
}
@@ -459,13 +463,13 @@ mod tests {
let expr1 = Expr::BinaryExpr(BinaryExpr {
left: Box::new(column("a")),
op: Operator::Eq,
right: Box::new("a_value".lit()),
right: Box::new(string_literal("a_value")),
});
let expr2 = Expr::BinaryExpr(BinaryExpr {
left: Box::new(column("b")),
op: Operator::NotEq,
right: Box::new("b_value".lit()),
right: Box::new(string_literal("b_value")),
});
(expr1, expr2)
@@ -504,7 +508,7 @@ mod tests {
let inlist_expr = Expr::InList(InList {
expr: Box::new(column("a")),
list: vec!["a1".lit(), "a2".lit()],
list: vec![string_literal("a1"), string_literal("a2")],
negated: false,
});
@@ -514,7 +518,7 @@ mod tests {
let inlist_expr = Expr::InList(InList {
expr: Box::new(column("a")),
list: vec!["a1".lit(), "a2".lit()],
list: vec![string_literal("a1"), string_literal("a2")],
negated: true,
});
let inlist_p = Predicate::from_expr(inlist_expr).unwrap();
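For reference, a minimal sketch of how an equality expression maps to a predicate, reusing the `column` and `string_literal` helpers from the test module above (the column name and value are made up, and the single-field `Expr::Literal` variant is assumed, matching the updated test helpers):

#[test]
fn example_eq_predicate() {
    // Hypothetical: `host = 'host-1'` should become Predicate::Eq("host", "host-1").
    let expr = Expr::BinaryExpr(BinaryExpr {
        left: Box::new(column("host")),
        op: Operator::Eq,
        right: Box::new(string_literal("host-1")),
    });
    let p = Predicate::from_expr(expr).unwrap();
    assert!(matches!(p, Predicate::Eq(c, _) if c == "host"));
}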

View File

@@ -32,7 +32,7 @@ use dummy_catalog::DummyCatalogList;
use table::TableRef;
use crate::error::{
CastManagerSnafu, DecodePlanSnafu, GetViewCacheSnafu, ProjectViewColumnsSnafu,
CastManagerSnafu, DatafusionSnafu, DecodePlanSnafu, GetViewCacheSnafu, ProjectViewColumnsSnafu,
QueryAccessDeniedSnafu, Result, TableNotExistSnafu, ViewInfoNotFoundSnafu,
ViewPlanColumnsChangedSnafu,
};
@@ -199,10 +199,10 @@ impl DfTableSourceProvider {
logical_plan
};
Ok(Arc::new(ViewTable::new(
logical_plan,
Some(view_info.definition.to_string()),
)))
Ok(Arc::new(
ViewTable::try_new(logical_plan, Some(view_info.definition.to_string()))
.context(DatafusionSnafu)?,
))
}
}

View File

@@ -74,19 +74,11 @@ pub fn make_create_region_request_for_peer(
let catalog = &create_table_expr.catalog_name;
let schema = &create_table_expr.schema_name;
let storage_path = region_storage_path(catalog, schema);
let partition_exprs = region_routes
.iter()
.map(|r| (r.region.id.region_number(), r.region.partition_expr()))
.collect::<HashMap<_, _>>();
for region_number in &regions_on_this_peer {
let region_id = RegionId::new(logical_table_id, *region_number);
let region_request = request_builder.build_one(
region_id,
storage_path.clone(),
&HashMap::new(),
&partition_exprs,
);
let region_request =
request_builder.build_one(region_id, storage_path.clone(), &HashMap::new());
requests.push(region_request);
}

View File

@@ -29,7 +29,6 @@ datatypes.workspace = true
enum_dispatch = "0.3"
futures.workspace = true
futures-util.workspace = true
humantime.workspace = true
lazy_static.workspace = true
moka = { workspace = true, features = ["future"] }
parking_lot.workspace = true
@@ -39,7 +38,6 @@ query.workspace = true
rand.workspace = true
serde_json.workspace = true
snafu.workspace = true
store-api.workspace = true
substrait.workspace = true
tokio.workspace = true
tokio-stream = { workspace = true, features = ["net"] }

View File

@@ -17,7 +17,7 @@ use std::sync::Arc;
use std::time::Duration;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
use common_meta::node_manager::{DatanodeManager, DatanodeRef, FlownodeManager, FlownodeRef};
use common_meta::node_manager::{DatanodeRef, FlownodeRef, NodeManager};
use common_meta::peer::Peer;
use moka::future::{Cache, CacheBuilder};
@@ -45,7 +45,7 @@ impl Debug for NodeClients {
}
#[async_trait::async_trait]
impl DatanodeManager for NodeClients {
impl NodeManager for NodeClients {
async fn datanode(&self, datanode: &Peer) -> DatanodeRef {
let client = self.get_client(datanode).await;
@@ -60,10 +60,7 @@ impl DatanodeManager for NodeClients {
*accept_compression,
))
}
}
#[async_trait::async_trait]
impl FlownodeManager for NodeClients {
async fn flownode(&self, flownode: &Peer) -> FlownodeRef {
let client = self.get_client(flownode).await;

View File

@@ -75,24 +75,12 @@ pub struct Database {
}
pub struct DatabaseClient {
pub addr: String,
pub inner: GreptimeDatabaseClient<Channel>,
}
impl DatabaseClient {
/// Returns a closure that logs the error when the request fails.
pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
let addr = &self.addr;
move |status| {
error!("Failed to {context} request, peer: {addr}, status: {status:?}");
}
}
}
fn make_database_client(client: &Client) -> Result<DatabaseClient> {
let (addr, channel) = client.find_channel()?;
let (_, channel) = client.find_channel()?;
Ok(DatabaseClient {
addr,
inner: GreptimeDatabaseClient::new(channel)
.max_decoding_message_size(client.max_grpc_recv_message_size())
.max_encoding_message_size(client.max_grpc_send_message_size()),
@@ -179,19 +167,14 @@ impl Database {
requests: InsertRequests,
hints: &[(&str, &str)],
) -> Result<u32> {
let mut client = make_database_client(&self.client)?;
let mut client = make_database_client(&self.client)?.inner;
let request = self.to_rpc_request(Request::Inserts(requests));
let mut request = tonic::Request::new(request);
let metadata = request.metadata_mut();
Self::put_hints(metadata, hints)?;
let response = client
.inner
.handle(request)
.await
.inspect_err(client.inspect_err("insert_with_hints"))?
.into_inner();
let response = client.handle(request).await?.into_inner();
from_grpc_response(response)
}
@@ -206,19 +189,14 @@ impl Database {
requests: RowInsertRequests,
hints: &[(&str, &str)],
) -> Result<u32> {
let mut client = make_database_client(&self.client)?;
let mut client = make_database_client(&self.client)?.inner;
let request = self.to_rpc_request(Request::RowInserts(requests));
let mut request = tonic::Request::new(request);
let metadata = request.metadata_mut();
Self::put_hints(metadata, hints)?;
let response = client
.inner
.handle(request)
.await
.inspect_err(client.inspect_err("row_inserts_with_hints"))?
.into_inner();
let response = client.handle(request).await?.into_inner();
from_grpc_response(response)
}
@@ -239,14 +217,9 @@ impl Database {
/// Make a request to the database.
pub async fn handle(&self, request: Request) -> Result<u32> {
let mut client = make_database_client(&self.client)?;
let mut client = make_database_client(&self.client)?.inner;
let request = self.to_rpc_request(request);
let response = client
.inner
.handle(request)
.await
.inspect_err(client.inspect_err("handle"))?
.into_inner();
let response = client.handle(request).await?.into_inner();
from_grpc_response(response)
}
@@ -258,7 +231,7 @@ impl Database {
max_retries: u32,
hints: &[(&str, &str)],
) -> Result<u32> {
let mut client = make_database_client(&self.client)?;
let mut client = make_database_client(&self.client)?.inner;
let mut retries = 0;
let request = self.to_rpc_request(request);
@@ -267,11 +240,7 @@ impl Database {
let mut tonic_request = tonic::Request::new(request.clone());
let metadata = tonic_request.metadata_mut();
Self::put_hints(metadata, hints)?;
let raw_response = client
.inner
.handle(tonic_request)
.await
.inspect_err(client.inspect_err("handle"));
let raw_response = client.handle(tonic_request).await;
match (raw_response, retries < max_retries) {
(Ok(resp), _) => return from_grpc_response(resp.into_inner()),
(Err(err), true) => {

View File

@@ -133,13 +133,6 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("External error"))]
External {
#[snafu(implicit)]
location: Location,
source: BoxedError,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -161,7 +154,6 @@ impl ErrorExt for Error {
Error::IllegalGrpcClientState { .. } => StatusCode::Unexpected,
Error::InvalidTonicMetadataValue { .. } => StatusCode::InvalidArguments,
Error::ConvertSchema { source, .. } => source.status_code(),
Error::External { source, .. } => source.status_code(),
}
}

View File

@@ -1,60 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
use api::v1::RowInsertRequests;
use humantime::format_duration;
use store_api::mito_engine_options::{APPEND_MODE_KEY, TTL_KEY};
use crate::error::Result;
/// Context holds the catalog and schema information.
pub struct Context<'a> {
/// The catalog name.
pub catalog: &'a str,
/// The schema name.
pub schema: &'a str,
}
/// Options for insert operations.
#[derive(Debug, Clone, Copy)]
pub struct InsertOptions {
/// Time-to-live for the inserted data.
pub ttl: Duration,
/// Whether to use append mode for the insert.
pub append_mode: bool,
}
impl InsertOptions {
/// Converts the insert options to a list of key-value string hints.
pub fn to_hints(&self) -> Vec<(&'static str, String)> {
vec![
(TTL_KEY, format_duration(self.ttl).to_string()),
(APPEND_MODE_KEY, self.append_mode.to_string()),
]
}
}
/// [`Inserter`] allows different components to share a unified mechanism for inserting data.
///
/// An implementation may perform the insert locally (e.g., via a direct procedure call) or
/// delegate/forward it to another node for processing (e.g., MetaSrv forwarding to an
/// available Frontend).
#[async_trait::async_trait]
pub trait Inserter: Send + Sync {
async fn insert_rows(&self, context: &Context<'_>, requests: RowInsertRequests) -> Result<()>;
fn set_options(&mut self, options: &InsertOptions);
}

View File

@@ -19,7 +19,6 @@ pub mod client_manager;
pub mod database;
pub mod error;
pub mod flow;
pub mod inserter;
pub mod load_balance;
mod metrics;
pub mod region;

View File

@@ -376,8 +376,7 @@ impl StartCommand {
flow_auth_header,
opts.query.clone(),
opts.flow.batching_mode.clone(),
)
.context(StartFlownodeSnafu)?;
);
let frontend_client = Arc::new(frontend_client);
let flownode_builder = FlownodeBuilder::new(
opts.clone(),

View File

@@ -279,7 +279,7 @@ impl StartCommand {
&opts.component.logging,
&opts.component.tracing,
opts.component.node_id.clone(),
Some(&opts.component.slow_query),
opts.component.slow_query.as_ref(),
);
log_versions(verbose_version(), short_version(), APP_NAME);

View File

@@ -157,7 +157,7 @@ pub struct StandaloneOptions {
pub init_regions_in_background: bool,
pub init_regions_parallelism: usize,
pub max_in_flight_write_bytes: Option<ReadableSize>,
pub slow_query: SlowQueryOptions,
pub slow_query: Option<SlowQueryOptions>,
pub query: QueryOptions,
pub memory: MemoryOptions,
}
@@ -191,7 +191,7 @@ impl Default for StandaloneOptions {
init_regions_in_background: false,
init_regions_parallelism: 16,
max_in_flight_write_bytes: None,
slow_query: SlowQueryOptions::default(),
slow_query: Some(SlowQueryOptions::default()),
query: QueryOptions::default(),
memory: MemoryOptions::default(),
}
@@ -486,7 +486,7 @@ impl StartCommand {
&opts.component.logging,
&opts.component.tracing,
None,
Some(&opts.component.slow_query),
opts.component.slow_query.as_ref(),
);
log_versions(verbose_version(), short_version(), APP_NAME);
@@ -834,7 +834,6 @@ impl InformationExtension for StandaloneInformationExtension {
region_manifest: region_stat.manifest.into(),
data_topic_latest_entry_id: region_stat.data_topic_latest_entry_id,
metadata_topic_latest_entry_id: region_stat.metadata_topic_latest_entry_id,
write_bytes: 0,
}
})
.collect::<Vec<_>>();

View File

@@ -20,7 +20,6 @@ pub mod range_read;
#[allow(clippy::all)]
pub mod readable_size;
pub mod secrets;
pub mod serde;
pub type AffectedRows = usize;

View File

@@ -1,31 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Deserializer};
/// Deserialize an empty string as the default value.
pub fn empty_string_as_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: Deserializer<'de>,
T: Default + Deserialize<'de>,
{
let s = String::deserialize(deserializer)?;
if s.is_empty() {
Ok(T::default())
} else {
T::deserialize(serde::de::value::StringDeserializer::<D::Error>::new(s))
.map_err(serde::de::Error::custom)
}
}
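A minimal sketch of how this helper is typically wired into a config struct through `deserialize_with` (the struct, enum, and field names are hypothetical):

use serde::Deserialize;

// Hypothetical config type: an empty string in the source input falls back
// to the enum's Default variant instead of failing to parse.
#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum ExampleMode {
    #[default]
    Standalone,
    Distributed,
}

#[derive(Debug, Deserialize)]
struct ExampleOptions {
    #[serde(deserialize_with = "empty_string_as_default")]
    mode: ExampleMode,
}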

View File

@@ -25,17 +25,19 @@ common-error.workspace = true
common-macro.workspace = true
common-recordbatch.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
datafusion.workspace = true
datafusion-orc.workspace = true
datatypes.workspace = true
derive_builder.workspace = true
futures.workspace = true
lazy_static.workspace = true
object-store.workspace = true
object_store_opendal.workspace = true
orc-rust = { version = "0.6.3", default-features = false, features = ["async"] }
orc-rust = { git = "https://github.com/datafusion-contrib/orc-rust", rev = "3134cab581a8e91b942d6a23aca2916ea965f6bb", default-features = false, features = [
"async",
] }
parquet.workspace = true
paste.workspace = true
rand.workspace = true
regex = "1.7"
serde.workspace = true
snafu.workspace = true
@@ -45,4 +47,6 @@ tokio-util.workspace = true
url = "2.3"
[dev-dependencies]
common-telemetry.workspace = true
common-test-util.workspace = true
uuid.workspace = true

View File

@@ -12,11 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow_schema::Schema;
use std::sync::Arc;
use arrow_schema::{ArrowError, Schema, SchemaRef};
use async_trait::async_trait;
use bytes::Bytes;
use common_recordbatch::adapter::RecordBatchStreamTypeAdapter;
use datafusion::datasource::physical_plan::{FileMeta, FileOpenFuture, FileOpener};
use datafusion::error::{DataFusionError, Result as DfResult};
use futures::future::BoxFuture;
use futures::FutureExt;
use futures::{FutureExt, StreamExt, TryStreamExt};
use object_store::ObjectStore;
use orc_rust::arrow_reader::ArrowReaderBuilder;
use orc_rust::async_arrow_reader::ArrowStreamReader;
@@ -92,6 +97,67 @@ impl FileFormat for OrcFormat {
}
}
#[derive(Debug, Clone)]
pub struct OrcOpener {
object_store: Arc<ObjectStore>,
output_schema: SchemaRef,
projection: Option<Vec<usize>>,
}
impl OrcOpener {
pub fn new(
object_store: ObjectStore,
output_schema: SchemaRef,
projection: Option<Vec<usize>>,
) -> Self {
Self {
object_store: Arc::from(object_store),
output_schema,
projection,
}
}
}
impl FileOpener for OrcOpener {
fn open(&self, meta: FileMeta) -> DfResult<FileOpenFuture> {
let object_store = self.object_store.clone();
let projected_schema = if let Some(projection) = &self.projection {
let projected_schema = self
.output_schema
.project(projection)
.map_err(|e| DataFusionError::External(Box::new(e)))?;
Arc::new(projected_schema)
} else {
self.output_schema.clone()
};
let projection = self.projection.clone();
Ok(Box::pin(async move {
let path = meta.location().to_string();
let meta = object_store
.stat(&path)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let reader = object_store
.reader(&path)
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let stream_reader =
new_orc_stream_reader(ReaderAdapter::new(reader, meta.content_length()))
.await
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let stream =
RecordBatchStreamTypeAdapter::new(projected_schema, stream_reader, projection);
let adopted = stream.map_err(|e| ArrowError::ExternalError(Box::new(e)));
Ok(adopted.boxed())
}))
}
}
#[cfg(test)]
mod tests {
use common_test_util::find_workspace_path;

View File

@@ -31,7 +31,6 @@ use datatypes::schema::SchemaRef;
use futures::future::BoxFuture;
use futures::StreamExt;
use object_store::{FuturesAsyncReader, ObjectStore};
use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::arrow::AsyncArrowWriter;
use parquet::basic::{Compression, Encoding, ZstdLevel};
use parquet::file::properties::{WriterProperties, WriterPropertiesBuilder};
@@ -66,7 +65,7 @@ impl FileFormat for ParquetFormat {
.compat();
let metadata = reader
.get_metadata(None)
.get_metadata()
.await
.context(error::ReadParquetSnafuSnafu)?;
@@ -147,7 +146,7 @@ impl LazyParquetFileReader {
impl AsyncFileReader for LazyParquetFileReader {
fn get_bytes(
&mut self,
range: std::ops::Range<u64>,
range: std::ops::Range<usize>,
) -> BoxFuture<'_, ParquetResult<bytes::Bytes>> {
Box::pin(async move {
self.maybe_initialize()
@@ -158,16 +157,13 @@ impl AsyncFileReader for LazyParquetFileReader {
})
}
fn get_metadata<'a>(
&'a mut self,
options: Option<&'a ArrowReaderOptions>,
) -> BoxFuture<'a, parquet::errors::Result<Arc<ParquetMetaData>>> {
fn get_metadata(&mut self) -> BoxFuture<'_, ParquetResult<Arc<ParquetMetaData>>> {
Box::pin(async move {
self.maybe_initialize()
.await
.map_err(|e| ParquetError::External(Box::new(e)))?;
// Safety: Must be initialized
self.reader.as_mut().unwrap().get_metadata(options).await
self.reader.as_mut().unwrap().get_metadata().await
})
}
}

View File

@@ -19,39 +19,35 @@ use std::vec;
use common_test_util::find_workspace_path;
use datafusion::assert_batches_eq;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::physical_plan::{
CsvSource, FileScanConfig, FileSource, FileStream, JsonSource, ParquetSource,
CsvConfig, CsvOpener, FileOpener, FileScanConfig, FileStream, JsonOpener, ParquetExec,
};
use datafusion::datasource::source::DataSourceExec;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_orc::OrcSource;
use futures::StreamExt;
use object_store::ObjectStore;
use super::FORMAT_TYPE;
use crate::file_format::orc::{OrcFormat, OrcOpener};
use crate::file_format::parquet::DefaultParquetFileReaderFactory;
use crate::file_format::{FileFormat, Format, OrcFormat};
use crate::file_format::{FileFormat, Format};
use crate::test_util::{scan_config, test_basic_schema, test_store};
use crate::{error, test_util};
struct Test<'a> {
struct Test<'a, T: FileOpener> {
config: FileScanConfig,
file_source: Arc<dyn FileSource>,
opener: T,
expected: Vec<&'a str>,
}
impl Test<'_> {
async fn run(self, store: &ObjectStore) {
let store = Arc::new(object_store_opendal::OpendalStore::new(store.clone()));
let file_opener = self.file_source.create_file_opener(store, &self.config, 0);
impl<T: FileOpener> Test<'_, T> {
pub async fn run(self) {
let result = FileStream::new(
&self.config,
0,
file_opener,
self.opener,
&ExecutionPlanMetricsSet::new(),
)
.unwrap()
@@ -66,16 +62,26 @@ impl Test<'_> {
#[tokio::test]
async fn test_json_opener() {
let store = test_store("/");
let store = Arc::new(object_store_opendal::OpendalStore::new(store));
let schema = test_basic_schema();
let file_source = Arc::new(JsonSource::new()).with_batch_size(test_util::TEST_BATCH_SIZE);
let json_opener = || {
JsonOpener::new(
test_util::TEST_BATCH_SIZE,
schema.clone(),
FileCompressionType::UNCOMPRESSED,
store.clone(),
)
};
let path = &find_workspace_path("/src/common/datasource/tests/json/basic.json")
.display()
.to_string();
let tests = [
Test {
config: scan_config(schema.clone(), None, path, file_source.clone()),
file_source: file_source.clone(),
config: scan_config(schema.clone(), None, path),
opener: json_opener(),
expected: vec![
"+-----+-------+",
"| num | str |",
@@ -87,8 +93,8 @@ async fn test_json_opener() {
],
},
Test {
config: scan_config(schema, Some(1), path, file_source.clone()),
file_source,
config: scan_config(schema.clone(), Some(1), path),
opener: json_opener(),
expected: vec![
"+-----+------+",
"| num | str |",
@@ -100,26 +106,37 @@ async fn test_json_opener() {
];
for test in tests {
test.run(&store).await;
test.run().await;
}
}
#[tokio::test]
async fn test_csv_opener() {
let store = test_store("/");
let store = Arc::new(object_store_opendal::OpendalStore::new(store));
let schema = test_basic_schema();
let path = &find_workspace_path("/src/common/datasource/tests/csv/basic.csv")
.display()
.to_string();
let csv_config = Arc::new(CsvConfig::new(
test_util::TEST_BATCH_SIZE,
schema.clone(),
None,
true,
b',',
b'"',
None,
store,
None,
));
let file_source = CsvSource::new(true, b',', b'"')
.with_batch_size(test_util::TEST_BATCH_SIZE)
.with_schema(schema.clone());
let csv_opener = || CsvOpener::new(csv_config.clone(), FileCompressionType::UNCOMPRESSED);
let tests = [
Test {
config: scan_config(schema.clone(), None, path, file_source.clone()),
file_source: file_source.clone(),
config: scan_config(schema.clone(), None, path),
opener: csv_opener(),
expected: vec![
"+-----+-------+",
"| num | str |",
@@ -131,8 +148,8 @@ async fn test_csv_opener() {
],
},
Test {
config: scan_config(schema, Some(1), path, file_source.clone()),
file_source,
config: scan_config(schema.clone(), Some(1), path),
opener: csv_opener(),
expected: vec![
"+-----+------+",
"| num | str |",
@@ -144,7 +161,7 @@ async fn test_csv_opener() {
];
for test in tests {
test.run(&store).await;
test.run().await;
}
}
@@ -157,12 +174,12 @@ async fn test_parquet_exec() {
let path = &find_workspace_path("/src/common/datasource/tests/parquet/basic.parquet")
.display()
.to_string();
let base_config = scan_config(schema.clone(), None, path);
let parquet_source = ParquetSource::default()
.with_parquet_file_reader_factory(Arc::new(DefaultParquetFileReaderFactory::new(store)));
let exec = ParquetExec::builder(base_config)
.with_parquet_file_reader_factory(Arc::new(DefaultParquetFileReaderFactory::new(store)))
.build();
let config = scan_config(schema, None, path, Arc::new(parquet_source));
let exec = DataSourceExec::from_data_source(config);
let ctx = SessionContext::new();
let context = Arc::new(TaskContext::from(&ctx));
@@ -191,18 +208,20 @@ async fn test_parquet_exec() {
#[tokio::test]
async fn test_orc_opener() {
let path = &find_workspace_path("/src/common/datasource/tests/orc/test.orc")
let root = find_workspace_path("/src/common/datasource/tests/orc")
.display()
.to_string();
let store = test_store(&root);
let schema = OrcFormat.infer_schema(&store, "test.orc").await.unwrap();
let schema = Arc::new(schema);
let store = test_store("/");
let schema = Arc::new(OrcFormat.infer_schema(&store, path).await.unwrap());
let file_source = Arc::new(OrcSource::default());
let orc_opener = OrcOpener::new(store.clone(), schema.clone(), None);
let path = "test.orc";
let tests = [
Test {
config: scan_config(schema.clone(), None, path, file_source.clone()),
file_source: file_source.clone(),
config: scan_config(schema.clone(), None, path),
opener: orc_opener.clone(),
expected: vec![
"+----------+-----+-------+------------+-----+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+----------------------------+-------------+",
"| double_a | a | b | str_direct | d | e | f | int_short_repeated | int_neg_short_repeated | int_delta | int_neg_delta | int_direct | int_neg_direct | bigint_direct | bigint_neg_direct | bigint_other | utf8_increase | utf8_decrease | timestamp_simple | date_simple |",
@@ -216,8 +235,8 @@ async fn test_orc_opener() {
],
},
Test {
config: scan_config(schema.clone(), Some(1), path, file_source.clone()),
file_source,
config: scan_config(schema.clone(), Some(1), path),
opener: orc_opener.clone(),
expected: vec![
"+----------+-----+------+------------+---+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+-------------------------+-------------+",
"| double_a | a | b | str_direct | d | e | f | int_short_repeated | int_neg_short_repeated | int_delta | int_neg_delta | int_direct | int_neg_direct | bigint_direct | bigint_neg_direct | bigint_other | utf8_increase | utf8_decrease | timestamp_simple | date_simple |",
@@ -229,7 +248,7 @@ async fn test_orc_opener() {
];
for test in tests {
test.run(&store).await;
test.run().await;
}
}

View File

@@ -16,12 +16,12 @@ use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use common_test_util::temp_dir::{create_temp_dir, TempDir};
use datafusion::common::{Constraints, Statistics};
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::object_store::ObjectStoreUrl;
use datafusion::datasource::physical_plan::{
CsvSource, FileGroup, FileScanConfig, FileScanConfigBuilder, FileSource, FileStream,
JsonOpener, JsonSource,
CsvConfig, CsvOpener, FileScanConfig, FileStream, JsonOpener,
};
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use object_store::services::Fs;
@@ -68,20 +68,21 @@ pub fn test_basic_schema() -> SchemaRef {
Arc::new(schema)
}
pub(crate) fn scan_config(
file_schema: SchemaRef,
limit: Option<usize>,
filename: &str,
file_source: Arc<dyn FileSource>,
) -> FileScanConfig {
pub fn scan_config(file_schema: SchemaRef, limit: Option<usize>, filename: &str) -> FileScanConfig {
// object_store only recognizes Unix-style paths, so make it happy.
let filename = &filename.replace('\\', "/");
let file_group = FileGroup::new(vec![PartitionedFile::new(filename.to_string(), 4096)]);
FileScanConfigBuilder::new(ObjectStoreUrl::local_filesystem(), file_schema, file_source)
.with_file_group(file_group)
.with_limit(limit)
.build()
let statistics = Statistics::new_unknown(file_schema.as_ref());
FileScanConfig {
object_store_url: ObjectStoreUrl::parse("empty://").unwrap(), // won't be used
file_schema,
file_groups: vec![vec![PartitionedFile::new(filename.to_string(), 10)]],
constraints: Constraints::empty(),
statistics,
projection: None,
limit,
table_partition_cols: vec![],
output_ordering: vec![],
}
}
pub async fn setup_stream_to_json_test(origin_path: &str, threshold: impl Fn(usize) -> usize) {
@@ -98,14 +99,9 @@ pub async fn setup_stream_to_json_test(origin_path: &str, threshold: impl Fn(usi
let size = store.read(origin_path).await.unwrap().len();
let config = scan_config(schema, None, origin_path, Arc::new(JsonSource::new()));
let stream = FileStream::new(
&config,
0,
Arc::new(json_opener),
&ExecutionPlanMetricsSet::new(),
)
.unwrap();
let config = scan_config(schema.clone(), None, origin_path);
let stream = FileStream::new(&config, 0, json_opener, &ExecutionPlanMetricsSet::new()).unwrap();
let (tmp_store, dir) = test_tmp_store("test_stream_to_json");
@@ -131,17 +127,24 @@ pub async fn setup_stream_to_csv_test(origin_path: &str, threshold: impl Fn(usiz
let schema = test_basic_schema();
let csv_source = CsvSource::new(true, b',', b'"')
.with_schema(schema.clone())
.with_batch_size(TEST_BATCH_SIZE);
let config = scan_config(schema, None, origin_path, csv_source.clone());
let csv_config = Arc::new(CsvConfig::new(
TEST_BATCH_SIZE,
schema.clone(),
None,
true,
b',',
b'"',
None,
Arc::new(object_store_opendal::OpendalStore::new(store.clone())),
None,
));
let csv_opener = CsvOpener::new(csv_config, FileCompressionType::UNCOMPRESSED);
let size = store.read(origin_path).await.unwrap().len();
let csv_opener = csv_source.create_file_opener(
Arc::new(object_store_opendal::OpendalStore::new(store.clone())),
&config,
0,
);
let config = scan_config(schema.clone(), None, origin_path);
let stream = FileStream::new(&config, 0, csv_opener, &ExecutionPlanMetricsSet::new()).unwrap();
let (tmp_store, dir) = test_tmp_store("test_stream_to_csv");

View File

@@ -12,9 +12,6 @@ common-error.workspace = true
common-macro.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true

View File

@@ -22,6 +22,12 @@ use snafu::{Location, Snafu};
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
#[snafu(display("No available frontend"))]
NoAvailableFrontend {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Mismatched schema, expected: {:?}, actual: {:?}", expected, actual))]
MismatchedSchema {
#[snafu(implicit)]
@@ -63,7 +69,9 @@ impl ErrorExt for Error {
Error::MismatchedSchema { .. } | Error::SerializeEvent { .. } => {
StatusCode::InvalidArguments
}
Error::InsertEvents { .. } | Error::KvBackend { .. } => StatusCode::Internal,
Error::NoAvailableFrontend { .. }
| Error::InsertEvents { .. }
| Error::KvBackend { .. } => StatusCode::Internal,
}
}

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(duration_constructors)]
pub mod error;
pub mod recorder;

View File

@@ -15,7 +15,7 @@
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use api::v1::column_data_type_extension::TypeExt;
@@ -28,8 +28,6 @@ use async_trait::async_trait;
use backon::{BackoffBuilder, ExponentialBuilder};
use common_telemetry::{debug, error, info, warn};
use common_time::timestamp::{TimeUnit, Timestamp};
use humantime::format_duration;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use store_api::mito_engine_options::{APPEND_MODE_KEY, TTL_KEY};
use tokio::sync::mpsc::{channel, Receiver, Sender};
@@ -52,10 +50,12 @@ pub const EVENTS_TABLE_TIMESTAMP_COLUMN_NAME: &str = "timestamp";
/// EventRecorderRef is the reference to the event recorder.
pub type EventRecorderRef = Arc<dyn EventRecorder>;
static EVENTS_TABLE_TTL: OnceLock<String> = OnceLock::new();
/// The time interval for flushing batched events to the event handler.
pub const DEFAULT_FLUSH_INTERVAL_SECONDS: Duration = Duration::from_secs(5);
/// The default TTL (90 days) for the events table.
const DEFAULT_EVENTS_TABLE_TTL: Duration = Duration::from_days(90);
// The default TTL for the events table.
const DEFAULT_EVENTS_TABLE_TTL: &str = "30d";
// The capacity of the tokio channel for transmitting events to the background processor.
const DEFAULT_CHANNEL_SIZE: usize = 2048;
// The size of the buffer for batching events before flushing to the event handler.
@@ -72,11 +72,6 @@ const DEFAULT_MAX_RETRY_TIMES: u64 = 3;
///
/// An event can also attach an extra schema and extra row by overriding the `extra_schema` and `extra_row` methods.
pub trait Event: Send + Sync + Debug {
/// Returns the table name of the event.
fn table_name(&self) -> &str {
DEFAULT_EVENTS_TABLE_NAME
}
/// Returns the type of the event.
fn event_type(&self) -> &str;
@@ -112,68 +107,88 @@ pub trait Eventable: Send + Sync + Debug {
}
}
/// Groups events by their `event_type`.
#[allow(clippy::borrowed_box)]
pub fn group_events_by_type(events: &[Box<dyn Event>]) -> HashMap<&str, Vec<&Box<dyn Event>>> {
events
.iter()
.into_grouping_map_by(|event| event.event_type())
.collect()
/// Returns the hints for the insert operation.
pub fn insert_hints() -> Vec<(&'static str, &'static str)> {
vec![
(
TTL_KEY,
EVENTS_TABLE_TTL
.get()
.map(|s| s.as_str())
.unwrap_or(DEFAULT_EVENTS_TABLE_TTL),
),
(APPEND_MODE_KEY, "true"),
]
}
/// Builds the row inserts request for the events that will be persisted to the events table. The `events` should have the same event type, or it will return an error.
#[allow(clippy::borrowed_box)]
pub fn build_row_inserts_request(events: &[&Box<dyn Event>]) -> Result<RowInsertRequests> {
// Ensure all the events are the same type.
validate_events(events)?;
/// Builds the row inserts request for the events that will be persisted to the events table.
pub fn build_row_inserts_request(events: &[Box<dyn Event>]) -> Result<RowInsertRequests> {
// Aggregate the events by the event type.
let mut event_groups: HashMap<&str, Vec<&Box<dyn Event>>> = HashMap::new();
// We already validated the events, so it's safe to get the first event to build the schema for the RowInsertRequest.
let event = &events[0];
let mut schema: Vec<ColumnSchema> = Vec::with_capacity(3 + event.extra_schema().len());
schema.extend(vec![
ColumnSchema {
column_name: EVENTS_TABLE_TYPE_COLUMN_NAME.to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
..Default::default()
},
ColumnSchema {
column_name: EVENTS_TABLE_PAYLOAD_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Binary as i32,
semantic_type: SemanticType::Field as i32,
datatype_extension: Some(ColumnDataTypeExtension {
type_ext: Some(TypeExt::JsonType(JsonTypeExtension::JsonBinary.into())),
}),
..Default::default()
},
ColumnSchema {
column_name: EVENTS_TABLE_TIMESTAMP_COLUMN_NAME.to_string(),
datatype: ColumnDataType::TimestampNanosecond.into(),
semantic_type: SemanticType::Timestamp.into(),
..Default::default()
},
]);
schema.extend(event.extra_schema());
let mut rows: Vec<Row> = Vec::with_capacity(events.len());
for event in events {
let extra_row = event.extra_row()?;
let mut values = Vec::with_capacity(3 + extra_row.values.len());
values.extend([
ValueData::StringValue(event.event_type().to_string()).into(),
ValueData::BinaryValue(event.json_payload()?.into_bytes()).into(),
ValueData::TimestampNanosecondValue(event.timestamp().value()).into(),
]);
values.extend(extra_row.values);
rows.push(Row { values });
event_groups
.entry(event.event_type())
.or_default()
.push(event);
}
Ok(RowInsertRequests {
inserts: vec![RowInsertRequest {
table_name: event.table_name().to_string(),
let mut row_insert_requests = RowInsertRequests {
inserts: Vec::with_capacity(event_groups.len()),
};
for (_, events) in event_groups {
validate_events(&events)?;
// We already validated the events, so it's safe to get the first event to build the schema for the RowInsertRequest.
let event = &events[0];
let mut schema = vec![
ColumnSchema {
column_name: EVENTS_TABLE_TYPE_COLUMN_NAME.to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
..Default::default()
},
ColumnSchema {
column_name: EVENTS_TABLE_PAYLOAD_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Binary as i32,
semantic_type: SemanticType::Field as i32,
datatype_extension: Some(ColumnDataTypeExtension {
type_ext: Some(TypeExt::JsonType(JsonTypeExtension::JsonBinary.into())),
}),
..Default::default()
},
ColumnSchema {
column_name: EVENTS_TABLE_TIMESTAMP_COLUMN_NAME.to_string(),
datatype: ColumnDataType::TimestampNanosecond.into(),
semantic_type: SemanticType::Timestamp.into(),
..Default::default()
},
];
schema.extend(event.extra_schema());
let rows = events
.iter()
.map(|event| {
let mut row = Row {
values: vec![
ValueData::StringValue(event.event_type().to_string()).into(),
ValueData::BinaryValue(event.json_payload()?.as_bytes().to_vec()).into(),
ValueData::TimestampNanosecondValue(event.timestamp().value()).into(),
],
};
row.values.extend(event.extra_row()?.values);
Ok(row)
})
.collect::<Result<Vec<_>>>()?;
row_insert_requests.inserts.push(RowInsertRequest {
table_name: DEFAULT_EVENTS_TABLE_NAME.to_string(),
rows: Some(Rows { schema, rows }),
}],
})
});
}
Ok(row_insert_requests)
}
// Ensure the events with the same event type have the same extra schema.
@@ -202,34 +217,6 @@ pub trait EventRecorder: Send + Sync + Debug + 'static {
fn close(&self);
}
/// EventHandlerOptions is the options for the event handler.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EventHandlerOptions {
/// TTL for the events table that will be used to store the events.
pub ttl: Duration,
/// Append mode for the events table that will be used to store the events.
pub append_mode: bool,
}
impl Default for EventHandlerOptions {
fn default() -> Self {
Self {
ttl: DEFAULT_EVENTS_TABLE_TTL,
append_mode: true,
}
}
}
impl EventHandlerOptions {
/// Converts the options to the hints for the insert operation.
pub fn to_hints(&self) -> Vec<(&str, String)> {
vec![
(TTL_KEY, format_duration(self.ttl).to_string()),
(APPEND_MODE_KEY, self.append_mode.to_string()),
]
}
}
/// EventHandler trait defines the interface for how to handle the event.
#[async_trait]
pub trait EventHandler: Send + Sync + 'static {
@@ -242,14 +229,13 @@ pub trait EventHandler: Send + Sync + 'static {
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EventRecorderOptions {
/// TTL for the events table that will be used to store the events.
#[serde(with = "humantime_serde")]
pub ttl: Duration,
pub ttl: String,
}
impl Default for EventRecorderOptions {
fn default() -> Self {
Self {
ttl: DEFAULT_EVENTS_TABLE_TTL,
ttl: DEFAULT_EVENTS_TABLE_TTL.to_string(),
}
}
}
@@ -266,7 +252,9 @@ pub struct EventRecorderImpl {
}
impl EventRecorderImpl {
pub fn new(event_handler: Box<dyn EventHandler>) -> Self {
pub fn new(event_handler: Box<dyn EventHandler>, opts: EventRecorderOptions) -> Self {
info!("Creating event recorder with options: {:?}", opts);
let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
let cancel_token = CancellationToken::new();
@@ -291,6 +279,14 @@ impl EventRecorderImpl {
recorder.handle = Some(handle);
// It only sets the ttl once, so it's safe to skip the error.
if EVENTS_TABLE_TTL.set(opts.ttl.clone()).is_err() {
info!(
"Events table ttl already set to {}, skip setting it",
opts.ttl
);
}
recorder
}
}
@@ -475,7 +471,10 @@ mod tests {
#[tokio::test]
async fn test_event_recorder() {
let mut event_recorder = EventRecorderImpl::new(Box::new(TestEventHandlerImpl {}));
let mut event_recorder = EventRecorderImpl::new(
Box::new(TestEventHandlerImpl {}),
EventRecorderOptions::default(),
);
event_recorder.record(Box::new(TestEvent {}));
// Sleep for a while to let the event be sent to the event handler.
@@ -516,8 +515,10 @@ mod tests {
#[tokio::test]
async fn test_event_recorder_should_panic() {
let mut event_recorder =
EventRecorderImpl::new(Box::new(TestEventHandlerImplShouldPanic {}));
let mut event_recorder = EventRecorderImpl::new(
Box::new(TestEventHandlerImplShouldPanic {}),
EventRecorderOptions::default(),
);
event_recorder.record(Box::new(TestEvent {}));
@@ -534,135 +535,4 @@ mod tests {
assert!(handle.await.unwrap_err().is_panic());
}
}
#[derive(Debug)]
struct TestEventA {}
impl Event for TestEventA {
fn event_type(&self) -> &str {
"A"
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[derive(Debug)]
struct TestEventB {}
impl Event for TestEventB {
fn table_name(&self) -> &str {
"table_B"
}
fn event_type(&self) -> &str {
"B"
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[derive(Debug)]
struct TestEventC {}
impl Event for TestEventC {
fn table_name(&self) -> &str {
"table_C"
}
fn event_type(&self) -> &str {
"C"
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[test]
fn test_group_events_by_type() {
let events: Vec<Box<dyn Event>> = vec![
Box::new(TestEventA {}),
Box::new(TestEventB {}),
Box::new(TestEventA {}),
Box::new(TestEventC {}),
Box::new(TestEventB {}),
Box::new(TestEventC {}),
Box::new(TestEventA {}),
];
let event_groups = group_events_by_type(&events);
assert_eq!(event_groups.len(), 3);
assert_eq!(event_groups.get("A").unwrap().len(), 3);
assert_eq!(event_groups.get("B").unwrap().len(), 2);
assert_eq!(event_groups.get("C").unwrap().len(), 2);
}
#[test]
fn test_build_row_inserts_request() {
let events: Vec<Box<dyn Event>> = vec![
Box::new(TestEventA {}),
Box::new(TestEventB {}),
Box::new(TestEventA {}),
Box::new(TestEventC {}),
Box::new(TestEventB {}),
Box::new(TestEventC {}),
Box::new(TestEventA {}),
];
let event_groups = group_events_by_type(&events);
assert_eq!(event_groups.len(), 3);
assert_eq!(event_groups.get("A").unwrap().len(), 3);
assert_eq!(event_groups.get("B").unwrap().len(), 2);
assert_eq!(event_groups.get("C").unwrap().len(), 2);
for (event_type, events) in event_groups {
let row_inserts_request = build_row_inserts_request(&events).unwrap();
if event_type == "A" {
assert_eq!(row_inserts_request.inserts.len(), 1);
assert_eq!(
row_inserts_request.inserts[0].table_name,
DEFAULT_EVENTS_TABLE_NAME
);
assert_eq!(
row_inserts_request.inserts[0]
.rows
.as_ref()
.unwrap()
.rows
.len(),
3
);
} else if event_type == "B" {
assert_eq!(row_inserts_request.inserts.len(), 1);
assert_eq!(row_inserts_request.inserts[0].table_name, "table_B");
assert_eq!(
row_inserts_request.inserts[0]
.rows
.as_ref()
.unwrap()
.rows
.len(),
2
);
} else if event_type == "C" {
assert_eq!(row_inserts_request.inserts.len(), 1);
assert_eq!(row_inserts_request.inserts[0].table_name, "table_C");
assert_eq!(
row_inserts_request.inserts[0]
.rows
.as_ref()
.unwrap()
.rows
.len(),
2
);
} else {
panic!("Unexpected event type: {}", event_type);
}
}
}
}
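For orientation, a minimal sketch of a custom event against the `Event` trait as it stands after this change; only the methods visible in the hunks above are assumed, with everything else relying on the trait's defaults, as the test events in this module do:

use std::any::Any;

// Hypothetical event type; payload, schema, and timestamp fall back to the
// trait defaults, just like the lightweight test events in this file.
#[derive(Debug)]
struct CheckpointEvent;

impl Event for CheckpointEvent {
    fn event_type(&self) -> &str {
        "checkpoint"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

// A recorder would then pick it up via its usual entry point:
// recorder.record(Box::new(CheckpointEvent));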

View File

@@ -5,18 +5,13 @@ edition.workspace = true
license.workspace = true
[dependencies]
api.workspace = true
async-trait.workspace = true
common-error.workspace = true
common-event-recorder.workspace = true
common-grpc.workspace = true
common-macro.workspace = true
common-meta.workspace = true
common-time.workspace = true
greptime-proto.workspace = true
humantime.workspace = true
meta-client.workspace = true
serde.workspace = true
session.workspace = true
snafu.workspace = true
tonic.workspace = true

View File

@@ -12,121 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use session::context::QueryContextRef;
use api::v1::value::ValueData;
use api::v1::{ColumnDataType, ColumnSchema, Row, SemanticType};
use common_event_recorder::error::Result;
use common_event_recorder::Event;
use serde::Serialize;
pub const SLOW_QUERY_TABLE_NAME: &str = "slow_queries";
pub const SLOW_QUERY_TABLE_COST_COLUMN_NAME: &str = "cost";
pub const SLOW_QUERY_TABLE_THRESHOLD_COLUMN_NAME: &str = "threshold";
pub const SLOW_QUERY_TABLE_QUERY_COLUMN_NAME: &str = "query";
pub const SLOW_QUERY_TABLE_TIMESTAMP_COLUMN_NAME: &str = "timestamp";
pub const SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME: &str = "is_promql";
pub const SLOW_QUERY_TABLE_PROMQL_START_COLUMN_NAME: &str = "promql_start";
pub const SLOW_QUERY_TABLE_PROMQL_END_COLUMN_NAME: &str = "promql_end";
pub const SLOW_QUERY_TABLE_PROMQL_RANGE_COLUMN_NAME: &str = "promql_range";
pub const SLOW_QUERY_TABLE_PROMQL_STEP_COLUMN_NAME: &str = "promql_step";
pub const SLOW_QUERY_EVENT_TYPE: &str = "slow_query";
/// SlowQueryEvent is the event of slow query.
#[derive(Debug, Serialize)]
#[derive(Debug)]
pub struct SlowQueryEvent {
pub cost: u64,
pub threshold: u64,
pub query: String,
pub is_promql: bool,
pub query_ctx: QueryContextRef,
pub promql_range: Option<u64>,
pub promql_step: Option<u64>,
pub promql_start: Option<i64>,
pub promql_end: Option<i64>,
}
impl Event for SlowQueryEvent {
fn table_name(&self) -> &str {
SLOW_QUERY_TABLE_NAME
}
fn event_type(&self) -> &str {
SLOW_QUERY_EVENT_TYPE
}
fn extra_schema(&self) -> Vec<ColumnSchema> {
vec![
ColumnSchema {
column_name: SLOW_QUERY_TABLE_COST_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Uint64.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_THRESHOLD_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Uint64.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_QUERY_COLUMN_NAME.to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_IS_PROMQL_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Boolean.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_PROMQL_RANGE_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Uint64.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_PROMQL_STEP_COLUMN_NAME.to_string(),
datatype: ColumnDataType::Uint64.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_PROMQL_START_COLUMN_NAME.to_string(),
datatype: ColumnDataType::TimestampMillisecond.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
ColumnSchema {
column_name: SLOW_QUERY_TABLE_PROMQL_END_COLUMN_NAME.to_string(),
datatype: ColumnDataType::TimestampMillisecond.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
]
}
fn extra_row(&self) -> Result<Row> {
Ok(Row {
values: vec![
ValueData::U64Value(self.cost).into(),
ValueData::U64Value(self.threshold).into(),
ValueData::StringValue(self.query.to_string()).into(),
ValueData::BoolValue(self.is_promql).into(),
ValueData::U64Value(self.promql_range.unwrap_or(0)).into(),
ValueData::U64Value(self.promql_step.unwrap_or(0)).into(),
ValueData::TimestampMillisecondValue(self.promql_start.unwrap_or(0)).into(),
ValueData::TimestampMillisecondValue(self.promql_end.unwrap_or(0)).into(),
],
})
}
fn json_payload(&self) -> Result<String> {
Ok("".to_string())
}
fn as_any(&self) -> &dyn Any {
self
}
}

View File

@@ -21,6 +21,8 @@ mod reconcile_database;
mod reconcile_table;
mod remove_region_follower;
use std::sync::Arc;
use add_region_follower::AddRegionFollowerFunction;
use flush_compact_region::{CompactRegionFunction, FlushRegionFunction};
use flush_compact_table::{CompactTableFunction, FlushTableFunction};
@@ -33,22 +35,22 @@ use remove_region_follower::RemoveRegionFollowerFunction;
use crate::flush_flow::FlushFlowFunction;
use crate::function_registry::FunctionRegistry;
/// Administration functions
/// Table functions
pub(crate) struct AdminFunction;
impl AdminFunction {
/// Register all admin functions to [`FunctionRegistry`].
/// Register all table functions to [`FunctionRegistry`].
pub fn register(registry: &FunctionRegistry) {
registry.register(MigrateRegionFunction::factory());
registry.register(AddRegionFollowerFunction::factory());
registry.register(RemoveRegionFollowerFunction::factory());
registry.register(FlushRegionFunction::factory());
registry.register(CompactRegionFunction::factory());
registry.register(FlushTableFunction::factory());
registry.register(CompactTableFunction::factory());
registry.register(FlushFlowFunction::factory());
registry.register(ReconcileCatalogFunction::factory());
registry.register(ReconcileDatabaseFunction::factory());
registry.register(ReconcileTableFunction::factory());
registry.register_async(Arc::new(MigrateRegionFunction));
registry.register_async(Arc::new(AddRegionFollowerFunction));
registry.register_async(Arc::new(RemoveRegionFollowerFunction));
registry.register_async(Arc::new(FlushRegionFunction));
registry.register_async(Arc::new(CompactRegionFunction));
registry.register_async(Arc::new(FlushTableFunction));
registry.register_async(Arc::new(CompactTableFunction));
registry.register_async(Arc::new(FlushFlowFunction));
registry.register_async(Arc::new(ReconcileCatalogFunction));
registry.register_async(Arc::new(ReconcileDatabaseFunction));
registry.register_async(Arc::new(ReconcileTableFunction));
}
}

View File

@@ -18,8 +18,7 @@ use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -83,13 +82,7 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// add_region_follower(region_id, peer)
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
TypeSignature::Uniform(2, ConcreteDataType::numerics()),
],
Volatility::Immutable,
)
@@ -99,57 +92,38 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{UInt64Vector, VectorRef};
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[test]
fn test_add_region_follower_misc() {
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = AddRegionFollowerFunction;
assert_eq!("add_region_follower", f.name());
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
} if sigs.len() == 1));
}
#[tokio::test]
async fn test_add_region_follower() {
let factory: ScalarFunctionFactory = AddRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = AddRegionFollowerFunction;
let args = vec![1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![2]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([0u64]));
assert_eq!(result, expect);
}
}
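The new tests drive the function through DataFusion's ScalarFunctionArgs / invoke_async_with_args path and then match on the returned ColumnarValue. A minimal helper sketch (not part of the patch, using only types already referenced above) that collapses such a result into a single u64, covering both branches the tests match on:

fn first_u64(result: datafusion_expr::ColumnarValue) -> Option<u64> {
    use arrow::array::UInt64Array;
    use datafusion_common::ScalarValue;
    match result {
        // Array results: downcast and read the first value, if any.
        datafusion_expr::ColumnarValue::Array(array) => array
            .as_any()
            .downcast_ref::<UInt64Array>()
            .filter(|a| !a.is_empty())
            .map(|a| a.value(0)),
        // Scalar results carry the value directly.
        datafusion_expr::ColumnarValue::Scalar(ScalarValue::UInt64(v)) => v,
        _ => None,
    }
}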

View File

@@ -16,8 +16,7 @@ use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingTableMutationHandlerSnafu, Result, UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, Volatility};
use datatypes::data_type::DataType;
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::*;
use session::context::QueryContextRef;
use snafu::ensure;
@@ -67,99 +66,71 @@ define_region_function!(FlushRegionFunction, flush_region, flush_region);
define_region_function!(CompactRegionFunction, compact_region, compact_region);
fn signature() -> Signature {
Signature::uniform(
1,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
Volatility::Immutable,
)
Signature::uniform(1, ConcreteDataType::numerics(), Volatility::Immutable)
}
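// Editorial note (not from the patch): with this signature, both generated admin
// functions take exactly one numeric argument, i.e. roughly
//   flush_region(<region_id>)   and   compact_region(<region_id>)
// where the region id may be passed in any numeric type.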
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::UInt64Vector;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
macro_rules! define_region_function_test {
($name: ident, $func: ident) => {
paste::paste! {
#[test]
fn [<test_ $name _misc>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = $func;
assert_eq!(stringify!($name), f.name());
assert_eq!(
DataType::UInt64,
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &ConcreteDataType::numerics().into_iter().map(|dt| { use datatypes::data_type::DataType; dt.as_arrow_type() }).collect::<Vec<_>>()));
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == ConcreteDataType::numerics()));
}
#[tokio::test]
async fn [<test_ $name _missing_table_mutation>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let f = $func;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![99]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
let args = vec![99];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Execution error: Handler error: Missing TableMutationHandler, not expected",
"Missing TableMutationHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn [<test_ $name>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = $func;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![99]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 42u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(42)));
}
}
let args = vec![99];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([42]));
assert_eq!(expect, result);
}
}
};

View File

@@ -15,15 +15,14 @@
use std::str::FromStr;
use api::v1::region::{compact_request, StrictWindow};
use arrow::datatypes::DataType as ArrowDataType;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingTableMutationHandlerSnafu, Result, TableMutationSnafu,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, Volatility};
use datatypes::prelude::*;
use session::context::QueryContextRef;
use session::table_name::table_name_to_full_name;
@@ -106,11 +105,18 @@ pub(crate) async fn compact_table(
}
fn flush_signature() -> Signature {
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}
fn compact_signature() -> Signature {
Signature::variadic(vec![ArrowDataType::Utf8], Volatility::Immutable)
Signature::variadic(
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}
/// Parses `compact_table` UDF parameters. This function accepts the following combinations:
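// (Illustrative, not from the patch — the rest of this doc comment is cut off by the
// hunk.) Judging by the `compact_request`/`StrictWindow` imports above, the accepted
// shapes are presumably along the lines of:
//   compact_table(<table_name>)                          -- default compaction options
//   compact_table(<table_name>, <compact_type>)          -- e.g. a strict-window compaction
//   compact_table(<table_name>, <compact_type>, <opts>)  -- e.g. the window size in seconds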
@@ -198,87 +204,66 @@ mod tests {
use std::sync::Arc;
use api::v1::region::compact_request::Options;
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{StringVector, UInt64Vector};
use session::context::QueryContext;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
macro_rules! define_table_function_test {
($name: ident, $func: ident) => {
paste::paste!{
#[test]
fn [<test_ $name _misc>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = $func;
assert_eq!(stringify!($name), f.name());
assert_eq!(
DataType::UInt64,
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &vec![ArrowDataType::Utf8]));
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == vec![ConcreteDataType::string_datatype()]));
}
#[tokio::test]
async fn [<test_ $name _missing_table_mutation>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let f = $func;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
let args = vec!["test"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from(vec![arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Execution error: Handler error: Missing TableMutationHandler, not expected",
"Missing TableMutationHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn [<test_ $name>]() {
let factory: ScalarFunctionFactory = $func::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = $func;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<arrow::array::UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 42u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(42)));
}
}
let args = vec!["test"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from(vec![arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([42]));
assert_eq!(expect, result);
}
}
}

View File

@@ -17,8 +17,7 @@ use std::time::Duration;
use common_macro::admin_fn;
use common_meta::rpc::procedure::MigrateRegionRequest;
use common_query::error::{InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -104,21 +103,9 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// migrate_region(region_id, from_peer, to_peer)
TypeSignature::Uniform(
3,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
TypeSignature::Uniform(3, ConcreteDataType::numerics()),
// migrate_region(region_id, from_peer, to_peer, timeout(secs))
TypeSignature::Uniform(
4,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
TypeSignature::Uniform(4, ConcreteDataType::numerics()),
],
Volatility::Immutable,
)
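// Illustrative call shapes admitted by the signature above (argument names are
// descriptive only, not from the patch):
//   migrate_region(region_id, from_peer, to_peer)
//   migrate_region(region_id, from_peer, to_peer, timeout_secs)
// Every position accepts any numeric type; the optional fourth argument is the
// migration timeout in seconds, per the inline comments above.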
@@ -128,89 +115,59 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use arrow::array::{StringArray, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[test]
fn test_migrate_region_misc() {
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = MigrateRegionFunction;
assert_eq!("migrate_region", f.name());
assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
assert_eq!(
ConcreteDataType::string_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
} if sigs.len() == 2));
}
#[tokio::test]
async fn test_missing_procedure_service() {
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let provider = factory.provide(FunctionContext::default());
let f = provider.as_async().unwrap();
let f = MigrateRegionFunction;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
let args = vec![1, 1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Execution error: Handler error: Missing ProcedureServiceHandler, not expected",
"Missing ProcedureServiceHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn test_migrate_region() {
let factory: ScalarFunctionFactory = MigrateRegionFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = MigrateRegionFunction;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = vec![1, 1, 1];
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
}
}

View File

@@ -14,15 +14,13 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileCatalog, ReconcileRequest};
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use session::context::QueryContextRef;
@@ -106,15 +104,15 @@ fn signature() -> Signature {
let mut signs = Vec::with_capacity(2 + nums.len());
signs.extend([
// reconcile_catalog()
TypeSignature::Nullary,
TypeSignature::NullAry,
// reconcile_catalog(resolve_strategy)
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
]);
for sign in nums {
// reconcile_catalog(resolve_strategy, parallelism)
signs.push(TypeSignature::Exact(vec![
ArrowDataType::Utf8,
sign.as_arrow_type(),
ConcreteDataType::string_datatype(),
sign,
]));
}
Signature::one_of(signs, Volatility::Immutable)
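// Illustrative call shapes produced by the signature construction above (the values
// are hypothetical; `UseLatest` and `10` are taken from the tests below):
//   reconcile_catalog()
//   reconcile_catalog('UseLatest')        -- resolve_strategy only
//   reconcile_catalog('UseLatest', 10)    -- resolve_strategy plus parallelism, where
//                                            the parallelism may be any numeric type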
@@ -122,149 +120,60 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use arrow::array::{StringArray, UInt64Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::error::Error;
use datatypes::vectors::{StringVector, UInt64Vector, VectorRef};
use crate::admin::reconcile_catalog::ReconcileCatalogFunction;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[tokio::test]
async fn test_reconcile_catalog() {
common_telemetry::init_default_ut_logging();
// reconcile_catalog()
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![],
arg_fields: vec![],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileCatalogFunction;
let args = vec![];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// reconcile_catalog(resolve_strategy)
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"UseMetasrv",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileCatalogFunction;
let args = vec![Arc::new(StringVector::from(vec!["UseMetasrv"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// reconcile_catalog(resolve_strategy, parallelism)
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![10]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt64Vector::from_slice([10])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// unsupported input data type
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(StringVector::from(vec!["test"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
// invalid function args
let factory: ScalarFunctionFactory = ReconcileCatalogFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["10"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
let f = ReconcileCatalogFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt64Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["10"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::InvalidFuncArgs { .. });
}
}

View File

@@ -14,15 +14,13 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileDatabase, ReconcileRequest};
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use datatypes::prelude::*;
use session::context::QueryContextRef;
@@ -115,16 +113,19 @@ fn signature() -> Signature {
let mut signs = Vec::with_capacity(2 + nums.len());
signs.extend([
// reconcile_database(database_name)
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
// reconcile_database(database_name, resolve_strategy)
TypeSignature::Exact(vec![ArrowDataType::Utf8, ArrowDataType::Utf8]),
TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
]),
]);
for sign in nums {
// reconcile_database(database_name, resolve_strategy, parallelism)
signs.push(TypeSignature::Exact(vec![
ArrowDataType::Utf8,
ArrowDataType::Utf8,
sign.as_arrow_type(),
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
sign,
]));
}
Signature::one_of(signs, Volatility::Immutable)
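// Same pattern as reconcile_catalog above, one string argument wider (values are
// hypothetical): reconcile_database('db'), reconcile_database('db', 'UseLatest'),
// reconcile_database('db', 'UseLatest', 10), with the trailing parallelism again
// accepted in any numeric type.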
@@ -132,160 +133,66 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use arrow::array::{StringArray, UInt32Array};
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::error::Error;
use datatypes::vectors::{StringVector, UInt32Vector, VectorRef};
use crate::admin::reconcile_database::ReconcileDatabaseFunction;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[tokio::test]
async fn test_reconcile_catalog() {
common_telemetry::init_default_ut_logging();
// reconcile_database(database_name)
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"test",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileDatabaseFunction;
let args = vec![Arc::new(StringVector::from(vec!["test"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// reconcile_database(database_name, resolve_strategy)
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// reconcile_database(database_name, resolve_strategy, parallelism)
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
Arc::new(Field::new("arg_2", DataType::UInt32, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// invalid function args
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v1"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v2"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt32, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
Arc::new(Field::new("arg_3", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["v1"])) as _,
Arc::new(StringVector::from(vec!["v2"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::InvalidFuncArgs { .. });
// unsupported input data type
let factory: ScalarFunctionFactory = ReconcileDatabaseFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseLatest"]))),
ColumnarValue::Array(Arc::new(UInt32Array::from(vec![10]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["v1"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::UInt32, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
let f = ReconcileDatabaseFunction;
let args = vec![
Arc::new(StringVector::from(vec!["UseLatest"])) as _,
Arc::new(UInt32Vector::from_slice([10])) as _,
Arc::new(StringVector::from(vec!["v1"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
}
}

View File

@@ -14,15 +14,14 @@
use api::v1::meta::reconcile_request::Target;
use api::v1::meta::{ReconcileRequest, ReconcileTable, ResolveStrategy};
use arrow::datatypes::DataType as ArrowDataType;
use common_catalog::format_full_table_name;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
MissingProcedureServiceHandlerSnafu, Result, TableMutationSnafu, UnsupportedInputDataTypeSnafu,
};
use common_query::prelude::{Signature, TypeSignature, Volatility};
use common_telemetry::info;
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::prelude::*;
use session::context::QueryContextRef;
use session::table_name::table_name_to_full_name;
@@ -94,9 +93,12 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// reconcile_table(table_name)
TypeSignature::Exact(vec![ArrowDataType::Utf8]),
TypeSignature::Exact(vec![ConcreteDataType::string_datatype()]),
// reconcile_table(table_name, resolve_strategy)
TypeSignature::Exact(vec![ArrowDataType::Utf8, ArrowDataType::Utf8]),
TypeSignature::Exact(vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::string_datatype(),
]),
],
Volatility::Immutable,
)
@@ -104,101 +106,44 @@ fn signature() -> Signature {
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::error::Error;
use datatypes::vectors::{StringVector, VectorRef};
use crate::admin::reconcile_table::ReconcileTableFunction;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[tokio::test]
async fn test_reconcile_table() {
common_telemetry::init_default_ut_logging();
// reconcile_table(table_name)
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"test",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileTableFunction;
let args = vec![Arc::new(StringVector::from(vec!["test"])) as _];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// reconcile_table(table_name, resolve_strategy)
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseMetasrv"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(result_array.value(0), "test_pid");
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some("test_pid".to_string()))
);
}
}
let f = ReconcileTableFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseMetasrv"])) as _,
];
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec!["test_pid"]));
assert_eq!(expect, result);
// unsupported input data type
let factory: ScalarFunctionFactory = ReconcileTableFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(StringArray::from(vec!["test"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["UseMetasrv"]))),
ColumnarValue::Array(Arc::new(StringArray::from(vec!["10"]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::Utf8, false)),
Arc::new(Field::new("arg_1", DataType::Utf8, false)),
Arc::new(Field::new("arg_2", DataType::Utf8, false)),
],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let _err = f.invoke_async_with_args(func_args).await.unwrap_err();
// Note: Error type is DataFusionError at this level, not common_query::Error
let f = ReconcileTableFunction;
let args = vec![
Arc::new(StringVector::from(vec!["test"])) as _,
Arc::new(StringVector::from(vec!["UseMetasrv"])) as _,
Arc::new(StringVector::from(vec!["10"])) as _,
];
let err = f.eval(FunctionContext::mock(), &args).await.unwrap_err();
assert_matches!(err, Error::UnsupportedInputDataType { .. });
}
}

View File

@@ -18,8 +18,7 @@ use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, TypeSignature, Volatility};
use datatypes::data_type::DataType;
use common_query::prelude::{Signature, TypeSignature, Volatility};
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
@@ -83,13 +82,7 @@ fn signature() -> Signature {
Signature::one_of(
vec![
// remove_region_follower(region_id, peer_id)
TypeSignature::Uniform(
2,
ConcreteDataType::numerics()
.into_iter()
.map(|dt| dt.as_arrow_type())
.collect(),
),
TypeSignature::Uniform(2, ConcreteDataType::numerics()),
],
Volatility::Immutable,
)
@@ -99,57 +92,38 @@ fn signature() -> Signature {
mod tests {
use std::sync::Arc;
use arrow::array::UInt64Array;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::{UInt64Vector, VectorRef};
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[test]
fn test_remove_region_follower_misc() {
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = RemoveRegionFollowerFunction;
assert_eq!("remove_region_follower", f.name());
assert_eq!(DataType::UInt64, f.return_type(&[]).unwrap());
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::OneOf(sigs),
volatility: datafusion_expr::Volatility::Immutable
Signature {
type_signature: TypeSignature::OneOf(sigs),
volatility: Volatility::Immutable
} if sigs.len() == 1));
}
#[tokio::test]
async fn test_remove_region_follower() {
let factory: ScalarFunctionFactory = RemoveRegionFollowerFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = RemoveRegionFollowerFunction;
let args = vec![1, 1];
let args = args
.into_iter()
.map(|arg| Arc::new(UInt64Vector::from_slice([arg])) as _)
.collect::<Vec<_>>();
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
ColumnarValue::Array(Arc::new(UInt64Array::from(vec![1]))),
],
arg_fields: vec![
Arc::new(Field::new("arg_0", DataType::UInt64, false)),
Arc::new(Field::new("arg_1", DataType::UInt64, false)),
],
return_field: Arc::new(Field::new("result", DataType::UInt64, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
assert_eq!(result_array.value(0), 0u64);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(scalar, datafusion_common::ScalarValue::UInt64(Some(0)));
}
}
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(UInt64Vector::from_slice([0u64]));
assert_eq!(result, expect);
}
}

View File

@@ -25,14 +25,14 @@
use std::sync::Arc;
use arrow::array::StructArray;
use arrow_schema::{FieldRef, Fields};
use arrow_schema::Fields;
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::AnalyzerRule;
use datafusion::physical_planner::create_aggregate_expr_and_maybe_filter;
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
@@ -146,7 +146,6 @@ impl StateMergeHelper {
};
let original_input_types = aggr_func
.params
.args
.iter()
.map(|e| e.get_type(&aggr.input.schema()))
@@ -157,7 +156,11 @@ impl StateMergeHelper {
let expr = AggregateFunction {
func: Arc::new(state_func.into()),
params: aggr_func.params.clone(),
args: aggr_func.args.clone(),
distinct: aggr_func.distinct,
filter: aggr_func.filter.clone(),
order_by: aggr_func.order_by.clone(),
null_treatment: aggr_func.null_treatment,
};
let expr = Expr::AggregateFunction(expr);
let lower_state_output_col_name = expr.schema_name().to_string();
@@ -179,10 +182,11 @@ impl StateMergeHelper {
let arg = Expr::Column(Column::new_unqualified(lower_state_output_col_name));
let expr = AggregateFunction {
func: Arc::new(merge_func.into()),
params: AggregateFunctionParams {
args: vec![arg],
..aggr_func.params.clone()
},
args: vec![arg],
distinct: aggr_func.distinct,
filter: aggr_func.filter.clone(),
order_by: aggr_func.order_by.clone(),
null_treatment: aggr_func.null_treatment,
};
// alias to the original aggregate expr's schema name, so the parent plan can refer to it
@@ -243,8 +247,15 @@ impl StateWrapper {
pub fn deduce_aggr_return_type(
&self,
acc_args: &datafusion_expr::function::AccumulatorArgs,
) -> datafusion_common::Result<FieldRef> {
self.inner.return_field(acc_args.schema.fields())
) -> datafusion_common::Result<DataType> {
let input_exprs = acc_args.exprs;
let input_schema = acc_args.schema;
let input_types = input_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>, _>>()?;
let return_type = self.inner.return_type(&input_types)?;
Ok(return_type)
}
}
@@ -254,13 +265,14 @@ impl AggregateUDFImpl for StateWrapper {
acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<Box<dyn Accumulator>> {
// fix and recover proper acc args for the original aggregate function.
let state_type = acc_args.return_type().clone();
let state_type = acc_args.return_type.clone();
let inner = {
let old_return_type = self.deduce_aggr_return_type(&acc_args)?;
let acc_args = datafusion_expr::function::AccumulatorArgs {
return_field: self.deduce_aggr_return_type(&acc_args)?,
return_type: &old_return_type,
schema: acc_args.schema,
ignore_nulls: acc_args.ignore_nulls,
order_bys: acc_args.order_bys,
ordering_req: acc_args.ordering_req,
is_reversed: acc_args.is_reversed,
name: acc_args.name,
is_distinct: acc_args.is_distinct,
@@ -285,15 +297,11 @@ impl AggregateUDFImpl for StateWrapper {
/// Return state_fields as the output struct type.
///
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
let input_fields = &arg_types
.iter()
.map(|x| Arc::new(Field::new("x", x.clone(), false)))
.collect::<Vec<_>>();
let old_return_type = self.inner.return_type(arg_types)?;
let state_fields_args = StateFieldsArgs {
name: self.inner().name(),
input_fields,
return_field: self.inner.return_field(input_fields)?,
input_types: arg_types,
return_type: &old_return_type,
// TODO(discord9): how to get this?, probably ok?
ordering_fields: &[],
is_distinct: false,
@@ -307,11 +315,12 @@ impl AggregateUDFImpl for StateWrapper {
fn state_fields(
&self,
args: datafusion_expr::function::StateFieldsArgs,
) -> datafusion_common::Result<Vec<FieldRef>> {
) -> datafusion_common::Result<Vec<Field>> {
let old_return_type = self.inner.return_type(args.input_types)?;
let state_fields_args = StateFieldsArgs {
name: args.name,
input_fields: args.input_fields,
return_field: self.inner.return_field(args.input_fields)?,
input_types: args.input_types,
return_type: &old_return_type,
ordering_fields: args.ordering_fields,
is_distinct: args.is_distinct,
};
@@ -493,7 +502,7 @@ impl AggregateUDFImpl for MergeWrapper {
fn state_fields(
&self,
_args: datafusion_expr::function::StateFieldsArgs,
) -> datafusion_common::Result<Vec<FieldRef>> {
) -> datafusion_common::Result<Vec<Field>> {
self.original_phy_expr.state_fields()
}
}
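For orientation, an editorial sketch (not part of the patch) of the plan rewrite these wrappers implement, using the `__sum_state(number)` / `sum(number)` names that appear in the tests below; the merge function's exact name is not shown in this diff and is assumed here:

// Original plan:
//   Aggregate: sum(number)
//
// After the StateMergeHelper rewrite:
//   Aggregate: __sum_merge(__sum_state(number)) AS "sum(number)"
//     Aggregate: __sum_state(number)
//
// The lower node emits the accumulator state under the schema name
// "__sum_state(number)"; the upper node feeds that column into the merge wrapper and
// is aliased back to the original name so parent plans keep resolving it.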

View File

@@ -35,7 +35,7 @@ use datafusion::prelude::SessionContext;
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::sqlparser::ast::NullTreatment;
use datafusion_expr::{lit, Aggregate, Expr, LogicalPlan, SortExpr, TableScan};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, SortExpr, TableScan};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
@@ -234,7 +234,7 @@ async fn test_sum_udaf() {
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
vec![],
None,
None,
))],
)
@@ -250,7 +250,7 @@ async fn test_sum_udaf() {
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
vec![],
None,
None,
))],
)
@@ -290,7 +290,7 @@ async fn test_sum_udaf() {
vec![Expr::Column(Column::new_unqualified("__sum_state(number)"))],
false,
None,
vec![],
None,
None,
))
.alias("sum(number)")],
@@ -378,7 +378,7 @@ async fn test_avg_udaf() {
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
vec![],
None,
None,
))],
)
@@ -395,7 +395,7 @@ async fn test_avg_udaf() {
vec![Expr::Column(Column::new_unqualified("number"))],
false,
None,
vec![],
None,
None,
))],
)
@@ -449,7 +449,7 @@ async fn test_avg_udaf() {
vec![Expr::Column(Column::new_unqualified("__avg_state(number)"))],
false,
None,
vec![],
None,
None,
))
.alias("avg(number)")],
@@ -551,7 +551,7 @@ async fn test_udaf_correct_eval_result() {
expected_fn: Option<ExpectedFn>,
distinct: bool,
filter: Option<Box<Expr>>,
order_by: Vec<SortExpr>,
order_by: Option<Vec<SortExpr>>,
null_treatment: Option<NullTreatment>,
}
type ExpectedFn = fn(ArrayRef) -> bool;
@@ -575,7 +575,7 @@ async fn test_udaf_correct_eval_result() {
expected_fn: None,
distinct: false,
filter: None,
order_by: vec![],
order_by: None,
null_treatment: None,
},
TestCase {
@@ -596,7 +596,7 @@ async fn test_udaf_correct_eval_result() {
expected_fn: None,
distinct: false,
filter: None,
order_by: vec![],
order_by: None,
null_treatment: None,
},
TestCase {
@@ -619,7 +619,7 @@ async fn test_udaf_correct_eval_result() {
expected_fn: None,
distinct: false,
filter: None,
order_by: vec![],
order_by: None,
null_treatment: None,
},
TestCase {
@@ -630,8 +630,8 @@ async fn test_udaf_correct_eval_result() {
true,
)])),
args: vec![
lit(128i64),
lit(0.05f64),
Expr::Literal(ScalarValue::Int64(Some(128))),
Expr::Literal(ScalarValue::Float64(Some(0.05))),
Expr::Column(Column::new_unqualified("number")),
],
input: vec![Arc::new(Float64Array::from(vec![
@@ -659,7 +659,7 @@ async fn test_udaf_correct_eval_result() {
}),
distinct: false,
filter: None,
order_by: vec![],
order_by: None,
null_treatment: None,
},
TestCase {
@@ -690,7 +690,7 @@ async fn test_udaf_correct_eval_result() {
}),
distinct: false,
filter: None,
order_by: vec![],
order_by: None,
null_treatment: None,
},
// TODO(discord9): udd_merge/hll_merge/geo_path/quantile_aggr tests

View File

@@ -41,7 +41,7 @@ use datatypes::arrow::array::{
Array, ArrayRef, AsArray, BooleanArray, Int64Array, ListArray, UInt64Array,
};
use datatypes::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datatypes::arrow::datatypes::{DataType, Field, FieldRef};
use datatypes::arrow::datatypes::{DataType, Field};
use crate::function_registry::FunctionRegistry;
@@ -94,14 +94,14 @@ impl AggregateUDFImpl for CountHash {
false
}
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Ok(vec![Arc::new(Field::new_list(
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(args.name, "count_hash"),
Field::new_list_field(DataType::UInt64, true),
// For the count_hash accumulator, a null list item stands for an
// empty value set (i.e., all NULL values so far for that group).
true,
))])
)])
}
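// Editorial note (not part of the patch): each group's intermediate state is thus a
// single List<UInt64> column of accumulated hash values, and a null list entry marks a
// group that has only seen NULL inputs so far, as the comment above spells out.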
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {

View File

@@ -21,7 +21,7 @@ pub(crate) struct GeoFunction;
impl GeoFunction {
pub fn register(registry: &FunctionRegistry) {
registry.register_aggr(encoding::JsonEncodePathAccumulator::uadf_impl());
registry.register_aggr(geo_path::GeoPathAccumulator::uadf_impl());
registry.register_aggr(encoding::JsonPathAccumulator::uadf_impl());
}
}

View File

@@ -14,332 +14,223 @@
use std::sync::Arc;
use arrow::array::AsArray;
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::common::cast::as_primitive_array;
use datafusion::error::{DataFusionError, Result as DfResult};
use datafusion::logical_expr::{Accumulator as DfAccumulator, AggregateUDF, Volatility};
use datafusion::prelude::create_udaf;
use datafusion_common::cast::{as_list_array, as_struct_array};
use datafusion_common::ScalarValue;
use datatypes::arrow::array::{Float64Array, Int64Array, ListArray, StructArray};
use datatypes::arrow::datatypes::{
DataType, Field, Float64Type, Int64Type, TimeUnit, TimestampNanosecondType,
use common_error::ext::{BoxedError, PlainError};
use common_error::status_code::StatusCode;
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{self, InvalidInputStateSnafu, Result};
use common_query::logical_plan::accumulator::AggrFuncTypeStore;
use common_query::logical_plan::{
create_aggregate_function, Accumulator, AggregateFunctionCreator,
};
use datatypes::compute::{self, sort_to_indices};
use common_query::prelude::AccumulatorCreatorFunction;
use common_time::Timestamp;
use datafusion_expr::AggregateUDF;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{ListValue, Value};
use datatypes::vectors::VectorRef;
use snafu::{ensure, ResultExt};
pub const JSON_ENCODE_PATH_NAME: &str = "json_encode_path";
use crate::scalars::geo::helpers::{ensure_columns_len, ensure_columns_n};
const LATITUDE_FIELD: &str = "lat";
const LONGITUDE_FIELD: &str = "lng";
const TIMESTAMP_FIELD: &str = "timestamp";
const DEFAULT_LIST_FIELD_NAME: &str = "item";
#[derive(Debug, Default)]
pub struct JsonEncodePathAccumulator {
/// Accumulator of lat, lng, timestamp tuples
#[derive(Debug)]
pub struct JsonPathAccumulator {
timestamp_type: ConcreteDataType,
lat: Vec<Option<f64>>,
lng: Vec<Option<f64>>,
timestamp: Vec<Option<i64>>,
timestamp: Vec<Option<Timestamp>>,
}
impl JsonEncodePathAccumulator {
pub fn new() -> Self {
Self::default()
}
pub fn uadf_impl() -> AggregateUDF {
create_udaf(
JSON_ENCODE_PATH_NAME,
// Input types: lat, lng, timestamp
vec![
DataType::Float64,
DataType::Float64,
DataType::Timestamp(TimeUnit::Nanosecond, None),
],
// Output type: geojson compatible linestring
Arc::new(DataType::Utf8),
Volatility::Immutable,
// Create the accumulator
Arc::new(|_| Ok(Box::new(Self::new()))),
// Intermediate state types
Arc::new(vec![DataType::Struct(
vec![
Field::new(
LATITUDE_FIELD,
DataType::List(Arc::new(Field::new(
DEFAULT_LIST_FIELD_NAME,
DataType::Float64,
true,
))),
false,
),
Field::new(
LONGITUDE_FIELD,
DataType::List(Arc::new(Field::new(
DEFAULT_LIST_FIELD_NAME,
DataType::Float64,
true,
))),
false,
),
Field::new(
TIMESTAMP_FIELD,
DataType::List(Arc::new(Field::new(
DEFAULT_LIST_FIELD_NAME,
DataType::Int64,
true,
))),
false,
),
]
.into(),
)]),
)
}
}
impl DfAccumulator for JsonEncodePathAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion::error::Result<()> {
if values.len() != 3 {
return Err(DataFusionError::Internal(format!(
"Expected 3 columns for json_encode_path, got {}",
values.len()
)));
impl JsonPathAccumulator {
fn new(timestamp_type: ConcreteDataType) -> Self {
Self {
lat: Vec::default(),
lng: Vec::default(),
timestamp: Vec::default(),
timestamp_type,
}
}
let lat_array = as_primitive_array::<Float64Type>(&values[0])?;
let lng_array = as_primitive_array::<Float64Type>(&values[1])?;
let ts_array = as_primitive_array::<TimestampNanosecondType>(&values[2])?;
/// Create a new `AggregateUDF` for the `json_encode_path` aggregate function.
pub fn uadf_impl() -> AggregateUDF {
create_aggregate_function(
"json_encode_path".to_string(),
3,
Arc::new(JsonPathEncodeFunctionCreator::default()),
)
.into()
}
}
let size = lat_array.len();
self.lat.reserve(size);
self.lng.reserve(size);
impl Accumulator for JsonPathAccumulator {
fn state(&self) -> Result<Vec<Value>> {
Ok(vec![
Value::List(ListValue::new(
self.lat.iter().map(|i| Value::from(*i)).collect(),
ConcreteDataType::float64_datatype(),
)),
Value::List(ListValue::new(
self.lng.iter().map(|i| Value::from(*i)).collect(),
ConcreteDataType::float64_datatype(),
)),
Value::List(ListValue::new(
self.timestamp.iter().map(|i| Value::from(*i)).collect(),
self.timestamp_type.clone(),
)),
])
}
fn update_batch(&mut self, columns: &[VectorRef]) -> Result<()> {
// update_batch, as in datafusion, just provides the accumulator's original
// input.
//
// columns is a vec of [`lat`, `lng`, `timestamp`]
// where
// - `lat` is a vector of `Value::Float64` or similar type. Each item in
// the vector is a row in the given dataset.
// - so on and so forth for `lng` and `timestamp`
ensure_columns_n!(columns, 3);
let lat = &columns[0];
let lng = &columns[1];
let ts = &columns[2];
let size = lat.len();
for idx in 0..size {
self.lat.push(if lat_array.is_null(idx) {
None
} else {
Some(lat_array.value(idx))
});
self.lng.push(if lng_array.is_null(idx) {
None
} else {
Some(lng_array.value(idx))
});
self.timestamp.push(if ts_array.is_null(idx) {
None
} else {
Some(ts_array.value(idx))
});
self.lat.push(lat.get(idx).as_f64_lossy());
self.lng.push(lng.get(idx).as_f64_lossy());
self.timestamp.push(ts.get(idx).as_timestamp());
}
Ok(())
}
fn evaluate(&mut self) -> DfResult<ScalarValue> {
let unordered_lng_array = Float64Array::from(self.lng.clone());
let unordered_lat_array = Float64Array::from(self.lat.clone());
let ts_array = Int64Array::from(self.timestamp.clone());
fn merge_batch(&mut self, states: &[VectorRef]) -> Result<()> {
// merge_batch, as in datafusion, receives the state accumulated from the data
// returned by child accumulators' state() calls.
// In our particular implementation, the data structure looks like this:
//
// states is a vec of [`lat`, `lng`, `timestamp`]
// where
// - `lat` is a vector of `Value::List`. Each item in the list holds all
// coordinates from a child accumulator.
// - so on and so forth for `lng` and `timestamp`
let ordered_indices = sort_to_indices(&ts_array, None, None)?;
let lat_array = compute::take(&unordered_lat_array, &ordered_indices, None)?;
let lng_array = compute::take(&unordered_lng_array, &ordered_indices, None)?;
ensure_columns_n!(states, 3);
let len = ts_array.len();
let lat_array = lat_array.as_primitive::<Float64Type>();
let lng_array = lng_array.as_primitive::<Float64Type>();
let lat_lists = &states[0];
let lng_lists = &states[1];
let ts_lists = &states[2];
let mut coords = Vec::with_capacity(len);
for i in 0..len {
let lng = lng_array.value(i);
let lat = lat_array.value(i);
coords.push(vec![lng, lat]);
}
let len = lat_lists.len();
let result = serde_json::to_string(&coords)
.map_err(|e| DataFusionError::Execution(format!("Failed to encode json, {}", e)))?;
for idx in 0..len {
if let Some(lat_list) = lat_lists
.get(idx)
.as_list()
.map_err(BoxedError::new)
.context(error::ExecuteSnafu)?
{
for v in lat_list.items() {
self.lat.push(v.as_f64_lossy());
}
}
Ok(ScalarValue::Utf8(Some(result)))
}
if let Some(lng_list) = lng_lists
.get(idx)
.as_list()
.map_err(BoxedError::new)
.context(error::ExecuteSnafu)?
{
for v in lng_list.items() {
self.lng.push(v.as_f64_lossy());
}
}
fn size(&self) -> usize {
// Base size of JsonEncodePathAccumulator struct fields
let mut total_size = std::mem::size_of::<Self>();
// Size of vectors (approximation)
total_size += self.lat.capacity() * std::mem::size_of::<Option<f64>>();
total_size += self.lng.capacity() * std::mem::size_of::<Option<f64>>();
total_size += self.timestamp.capacity() * std::mem::size_of::<Option<i64>>();
total_size
}
fn state(&mut self) -> datafusion::error::Result<Vec<ScalarValue>> {
let lat_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(self.lat.clone()),
]));
let lng_array = Arc::new(ListArray::from_iter_primitive::<Float64Type, _, _>(vec![
Some(self.lng.clone()),
]));
let ts_array = Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(self.timestamp.clone()),
]));
let state_struct = StructArray::new(
vec![
Field::new(
LATITUDE_FIELD,
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
Field::new(
LONGITUDE_FIELD,
DataType::List(Arc::new(Field::new("item", DataType::Float64, true))),
false,
),
Field::new(
TIMESTAMP_FIELD,
DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
false,
),
]
.into(),
vec![lat_array, lng_array, ts_array],
None,
);
Ok(vec![ScalarValue::Struct(Arc::new(state_struct))])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion::error::Result<()> {
if states.len() != 1 {
return Err(DataFusionError::Internal(format!(
"Expected 1 states for json_encode_path, got {}",
states.len()
)));
}
for state in states {
let state = as_struct_array(state)?;
let lat_list = as_list_array(state.column(0))?.value(0);
let lat_array = as_primitive_array::<Float64Type>(&lat_list)?;
let lng_list = as_list_array(state.column(1))?.value(0);
let lng_array = as_primitive_array::<Float64Type>(&lng_list)?;
let ts_list = as_list_array(state.column(2))?.value(0);
let ts_array = as_primitive_array::<Int64Type>(&ts_list)?;
self.lat.extend(lat_array);
self.lng.extend(lng_array);
self.timestamp.extend(ts_array);
if let Some(ts_list) = ts_lists
.get(idx)
.as_list()
.map_err(BoxedError::new)
.context(error::ExecuteSnafu)?
{
for v in ts_list.items() {
self.timestamp.push(v.as_timestamp());
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use datafusion::arrow::array::{Float64Array, TimestampNanosecondArray};
use datafusion::scalar::ScalarValue;
fn evaluate(&self) -> Result<Value> {
let mut work_vec: Vec<(&Option<f64>, &Option<f64>, &Option<Timestamp>)> = self
.lat
.iter()
.zip(self.lng.iter())
.zip(self.timestamp.iter())
.map(|((a, b), c)| (a, b, c))
.collect();
use super::*;
// sort by timestamp, we treat null timestamp as 0
work_vec.sort_unstable_by_key(|tuple| tuple.2.unwrap_or_else(|| Timestamp::new_second(0)));
#[test]
fn test_json_encode_path_basic() {
let mut accumulator = JsonEncodePathAccumulator::new();
let result = serde_json::to_string(
&work_vec
.into_iter()
// note that we transform to lng,lat for geojson compatibility
.map(|(lat, lng, _)| vec![lng, lat])
.collect::<Vec<Vec<&Option<f64>>>>(),
)
.map_err(|e| {
BoxedError::new(PlainError::new(
format!("Serialization failure: {}", e),
StatusCode::EngineExecuteQuery,
))
})
.context(error::ExecuteSnafu)?;
// Create test data
let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
let ts_array = Arc::new(TimestampNanosecondArray::from(vec![100, 200, 300]));
// Update batch
accumulator
.update_batch(&[lat_array, lng_array, ts_array])
.unwrap();
// Evaluate
let result = accumulator.evaluate().unwrap();
assert_eq!(
result,
ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0],[6.0,3.0]]".to_string()))
);
}
#[test]
fn test_json_encode_path_sort_by_timestamp() {
let mut accumulator = JsonEncodePathAccumulator::new();
// Create test data with unordered timestamps
let lat_array = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
let lng_array = Arc::new(Float64Array::from(vec![4.0, 5.0, 6.0]));
let ts_array = Arc::new(TimestampNanosecondArray::from(vec![300, 100, 200]));
// Update batch
accumulator
.update_batch(&[lat_array, lng_array, ts_array])
.unwrap();
// Evaluate
let result = accumulator.evaluate().unwrap();
assert_eq!(
result,
ScalarValue::Utf8(Some("[[5.0,2.0],[6.0,3.0],[4.0,1.0]]".to_string()))
);
}
#[test]
fn test_json_encode_path_merge() {
let mut accumulator1 = JsonEncodePathAccumulator::new();
let mut accumulator2 = JsonEncodePathAccumulator::new();
// Create test data for first accumulator
let lat_array1 = Arc::new(Float64Array::from(vec![1.0]));
let lng_array1 = Arc::new(Float64Array::from(vec![4.0]));
let ts_array1 = Arc::new(TimestampNanosecondArray::from(vec![100]));
// Create test data for second accumulator
let lat_array2 = Arc::new(Float64Array::from(vec![2.0]));
let lng_array2 = Arc::new(Float64Array::from(vec![5.0]));
let ts_array2 = Arc::new(TimestampNanosecondArray::from(vec![200]));
// Update batches
accumulator1
.update_batch(&[lat_array1, lng_array1, ts_array1])
.unwrap();
accumulator2
.update_batch(&[lat_array2, lng_array2, ts_array2])
.unwrap();
// Get states
let state1 = accumulator1.state().unwrap();
let state2 = accumulator2.state().unwrap();
// Create a merged accumulator
let mut merged = JsonEncodePathAccumulator::new();
// Extract the struct arrays from the states
let state_array1 = match &state1[0] {
ScalarValue::Struct(array) => array.clone(),
_ => panic!("Expected Struct scalar value"),
};
let state_array2 = match &state2[0] {
ScalarValue::Struct(array) => array.clone(),
_ => panic!("Expected Struct scalar value"),
};
// Merge state arrays
merged.merge_batch(&[state_array1]).unwrap();
merged.merge_batch(&[state_array2]).unwrap();
// Evaluate merged result
let result = merged.evaluate().unwrap();
assert_eq!(
result,
ScalarValue::Utf8(Some("[[4.0,1.0],[5.0,2.0]]".to_string()))
);
Ok(Value::String(result.into()))
}
}
/// This function accepts rows of lat, lng and timestamp, sorts them by timestamp and
/// encodes them into a geojson-like path.
///
/// Example:
///
/// ```sql
/// SELECT json_encode_path(lat, lon, timestamp) FROM table [group by ...];
/// ```
///
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct JsonPathEncodeFunctionCreator {}
impl AggregateFunctionCreator for JsonPathEncodeFunctionCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
let ts_type = types[2].clone();
Ok(Box::new(JsonPathAccumulator::new(ts_type)))
});
creator
}
fn output_type(&self) -> Result<ConcreteDataType> {
Ok(ConcreteDataType::string_datatype())
}
fn state_types(&self) -> Result<Vec<ConcreteDataType>> {
let input_types = self.input_types()?;
ensure!(input_types.len() == 3, InvalidInputStateSnafu);
let timestamp_type = input_types[2].clone();
Ok(vec![
ConcreteDataType::list_datatype(ConcreteDataType::float64_datatype()),
ConcreteDataType::list_datatype(ConcreteDataType::float64_datatype()),
ConcreteDataType::list_datatype(timestamp_type),
])
}
}
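For reference, a minimal standalone sketch of the behaviour both versions of the accumulator above implement: rows are sorted by timestamp and the coordinates are serialized as [lng, lat] pairs. It uses only serde_json; the function and variable names are illustrative and not part of the diff.

// Sketch only: same sort-then-encode behaviour as `json_encode_path`.
fn encode_path(mut rows: Vec<(f64, f64, i64)>) -> serde_json::Result<String> {
    // Sort rows (lat, lng, timestamp) by the timestamp component.
    rows.sort_unstable_by_key(|&(_, _, ts)| ts);
    // GeoJSON orders coordinates as [longitude, latitude].
    let coords: Vec<[f64; 2]> = rows.iter().map(|&(lat, lng, _)| [lng, lat]).collect();
    serde_json::to_string(&coords)
}

fn main() {
    let rows = vec![(1.0, 4.0, 300), (2.0, 5.0, 100), (3.0, 6.0, 200)];
    // Prints [[5.0,2.0],[6.0,3.0],[4.0,1.0]], matching the sort-order test above.
    println!("{}", encode_path(rows).unwrap());
}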

View File

@@ -12,20 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
use arrow_schema::{DataType, Field};
use datafusion::logical_expr::{Signature, TypeSignature, Volatility};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateUDF, SimpleAggregateUDF};
use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
use nalgebra::{Const, DVectorView, Dyn, OVector};
use crate::scalars::vector::impl_conv::{
binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
use common_query::logical_plan::{
create_aggregate_function, Accumulator, AggregateFunctionCreator,
};
use common_query::prelude::AccumulatorCreatorFunction;
use datafusion_expr::AggregateUDF;
use datatypes::prelude::{ConcreteDataType, Value, *};
use datatypes::vectors::VectorRef;
use nalgebra::{Const, DVectorView, Dyn, OVector};
use snafu::ensure;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
/// Aggregates by multiplying elements across the same dimension, returns a vector.
#[derive(Debug, Default)]
@@ -34,42 +35,57 @@ pub struct VectorProduct {
has_null: bool,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct VectorProductCreator {}
impl AggregateFunctionCreator for VectorProductCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
ensure!(
types.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
types.len()
)
}
);
let input_type = &types[0];
match input_type {
ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
Ok(Box::new(VectorProduct::default()))
}
_ => {
let err_msg = format!(
"\"VEC_PRODUCT\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
}
});
creator
}
fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
Ok(ConcreteDataType::binary_datatype())
}
fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
Ok(vec![self.output_type()?])
}
}
impl VectorProduct {
/// Create a new `AggregateUDF` for the `vec_product` aggregate function.
pub fn uadf_impl() -> AggregateUDF {
let signature = Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Binary]),
],
Volatility::Immutable,
);
let udaf = SimpleAggregateUDF::new_with_signature(
"vec_product",
signature,
DataType::Binary,
Arc::new(Self::accumulator),
vec![Arc::new(Field::new("x", DataType::Binary, true))],
);
AggregateUDF::from(udaf)
}
fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if args.schema.fields().len() != 1 {
return Err(datafusion_common::DataFusionError::Internal(format!(
"expect creating `VEC_PRODUCT` with only one input field, actual {}",
args.schema.fields().len()
)));
}
let t = args.schema.field(0).data_type();
if !matches!(t, DataType::Utf8 | DataType::Binary) {
return Err(datafusion_common::DataFusionError::Internal(format!(
"unexpected input datatype {t} when creating `VEC_PRODUCT`"
)));
}
Ok(Box::new(VectorProduct::default()))
create_aggregate_function(
"vec_product".to_string(),
1,
Arc::new(VectorProductCreator::default()),
)
.into()
}
fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
@@ -78,82 +94,67 @@ impl VectorProduct {
})
}
fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
if values.is_empty() || self.has_null {
return Ok(());
};
let column = &values[0];
let len = column.len();
let vectors = match values[0].data_type() {
DataType::Utf8 => {
let arr: &StringArray = values[0].as_string();
arr.iter()
.filter_map(|x| x.map(|s| parse_veclit_from_strlit(s).map_err(Into::into)))
.map(|x| x.map(Cow::Owned))
.collect::<Result<Vec<_>>>()?
match as_veclit_if_const(column)? {
Some(column) => {
let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
*self.inner(vec_column.len()) =
(*self.inner(vec_column.len())).component_mul(&vec_column);
}
DataType::Binary => {
let arr: &BinaryArray = values[0].as_binary();
arr.iter()
.filter_map(|x| x.map(|b| binlit_as_veclit(b).map_err(Into::into)))
.collect::<Result<Vec<_>>>()?
None => {
for i in 0..len {
let Some(arg0) = as_veclit(column.get_ref(i))? else {
if is_update {
self.has_null = true;
self.product = None;
}
return Ok(());
};
let vec_column = DVectorView::from_slice(&arg0, arg0.len());
*self.inner(vec_column.len()) =
(*self.inner(vec_column.len())).component_mul(&vec_column);
}
}
_ => {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"unsupported data type {} for `VEC_PRODUCT`",
values[0].data_type()
)))
}
};
if vectors.len() != values[0].len() {
if is_update {
self.has_null = true;
self.product = None;
}
return Ok(());
}
vectors.iter().for_each(|v| {
let v = DVectorView::from_slice(v, v.len());
let inner = self.inner(v.len());
*inner = inner.component_mul(&v);
});
Ok(())
}
}
impl Accumulator for VectorProduct {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
fn state(&self) -> common_query::error::Result<Vec<Value>> {
self.evaluate().map(|v| vec![v])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
self.update(values, true)
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
self.update(states, false)
}
fn evaluate(&mut self) -> Result<ScalarValue> {
fn evaluate(&self) -> common_query::error::Result<Value> {
match &self.product {
None => Ok(ScalarValue::Binary(None)),
Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
vector.as_slice(),
)))),
None => Ok(Value::Null),
Some(vector) => {
let v = vector.as_slice();
Ok(Value::from(veclit_to_binlit(v)))
}
}
}
fn size(&self) -> usize {
size_of_val(self)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{ConstantVector, StringVector, Vector};
use datatypes::vectors::{ConstantVector, StringVector};
use super::*;
@@ -164,60 +165,59 @@ mod tests {
vec_product.update_batch(&[]).unwrap();
assert!(vec_product.product.is_none());
assert!(!vec_product.has_null);
assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
// test update one not-null value
let mut vec_product = VectorProduct::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]))];
vec_product.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
vec_product.evaluate().unwrap()
);
// test update one null value
let mut vec_product = VectorProduct::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
vec_product.update_batch(&v).unwrap();
assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
// test update no null-value batch
let mut vec_product = VectorProduct::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_product.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[28.0, 80.0, 162.0]))),
Value::from(veclit_to_binlit(&[28.0, 80.0, 162.0])),
vec_product.evaluate().unwrap()
);
// test update null-value batch
let mut vec_product = VectorProduct::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
None,
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_product.update_batch(&v).unwrap();
assert_eq!(ScalarValue::Binary(None), vec_product.evaluate().unwrap());
assert_eq!(Value::Null, vec_product.evaluate().unwrap());
// test update with constant vector
let mut vec_product = VectorProduct::default();
let v: Vec<ArrayRef> = vec![Arc::new(ConstantVector::new(
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
4,
))
.to_arrow_array()];
))];
vec_product.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 16.0, 81.0]))),
Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
vec_product.evaluate().unwrap()
);
}
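Below is a small self-contained sketch of the component-wise product that `vec_product` folds over its rows, using nalgebra directly (the same crate the accumulator uses); the input values mirror the "no null-value batch" test above.

use nalgebra::DVector;

fn main() {
    let rows = vec![
        DVector::from_vec(vec![1.0_f32, 2.0, 3.0]),
        DVector::from_vec(vec![4.0, 5.0, 6.0]),
        DVector::from_vec(vec![7.0, 8.0, 9.0]),
    ];
    // Start from a vector of ones and multiply component-wise, row by row.
    let acc = rows
        .iter()
        .fold(DVector::from_element(3, 1.0_f32), |acc, v| acc.component_mul(v));
    // Prints [28.0, 80.0, 162.0], the same values the test expects.
    println!("{:?}", acc.as_slice());
}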

View File

@@ -14,18 +14,19 @@
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, AsArray, BinaryArray, StringArray};
use arrow_schema::{DataType, Field};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
Accumulator, AggregateUDF, Signature, SimpleAggregateUDF, TypeSignature, Volatility,
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use common_query::error::{CreateAccumulatorSnafu, Error, InvalidFuncArgsSnafu};
use common_query::logical_plan::{
create_aggregate_function, Accumulator, AggregateFunctionCreator,
};
use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs;
use common_query::prelude::AccumulatorCreatorFunction;
use datafusion_expr::AggregateUDF;
use datatypes::prelude::{ConcreteDataType, Value, *};
use datatypes::vectors::VectorRef;
use nalgebra::{Const, DVectorView, Dyn, OVector};
use snafu::ensure;
use crate::scalars::vector::impl_conv::{
binlit_as_veclit, parse_veclit_from_strlit, veclit_to_binlit,
};
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};
/// The accumulator for the `vec_sum` aggregate function.
#[derive(Debug, Default)]
@@ -34,42 +35,57 @@ pub struct VectorSum {
has_null: bool,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
pub struct VectorSumCreator {}
impl AggregateFunctionCreator for VectorSumCreator {
fn creator(&self) -> AccumulatorCreatorFunction {
let creator: AccumulatorCreatorFunction = Arc::new(move |types: &[ConcreteDataType]| {
ensure!(
types.len() == 1,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly one, have: {}",
types.len()
)
}
);
let input_type = &types[0];
match input_type {
ConcreteDataType::String(_) | ConcreteDataType::Binary(_) => {
Ok(Box::new(VectorSum::default()))
}
_ => {
let err_msg = format!(
"\"VEC_SUM\" aggregate function not support data type {:?}",
input_type.logical_type_id(),
);
CreateAccumulatorSnafu { err_msg }.fail()?
}
}
});
creator
}
fn output_type(&self) -> common_query::error::Result<ConcreteDataType> {
Ok(ConcreteDataType::binary_datatype())
}
fn state_types(&self) -> common_query::error::Result<Vec<ConcreteDataType>> {
Ok(vec![self.output_type()?])
}
}
impl VectorSum {
/// Create a new `AggregateUDF` for the `vec_sum` aggregate function.
pub fn uadf_impl() -> AggregateUDF {
let signature = Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Utf8]),
TypeSignature::Exact(vec![DataType::Binary]),
],
Volatility::Immutable,
);
let udaf = SimpleAggregateUDF::new_with_signature(
"vec_sum",
signature,
DataType::Binary,
Arc::new(Self::accumulator),
vec![Arc::new(Field::new("x", DataType::Binary, true))],
);
AggregateUDF::from(udaf)
}
fn accumulator(args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if args.schema.fields().len() != 1 {
return Err(datafusion_common::DataFusionError::Internal(format!(
"expect creating `VEC_SUM` with only one input field, actual {}",
args.schema.fields().len()
)));
}
let t = args.schema.field(0).data_type();
if !matches!(t, DataType::Utf8 | DataType::Binary) {
return Err(datafusion_common::DataFusionError::Internal(format!(
"unexpected input datatype {t} when creating `VEC_SUM`"
)));
}
Ok(Box::new(VectorSum::default()))
create_aggregate_function(
"vec_sum".to_string(),
1,
Arc::new(VectorSumCreator::default()),
)
.into()
}
fn inner(&mut self, len: usize) -> &mut OVector<f32, Dyn> {
@@ -77,87 +93,62 @@ impl VectorSum {
.get_or_insert_with(|| OVector::zeros_generic(Dyn(len), Const::<1>))
}
fn update(&mut self, values: &[ArrayRef], is_update: bool) -> Result<()> {
fn update(&mut self, values: &[VectorRef], is_update: bool) -> Result<(), Error> {
if values.is_empty() || self.has_null {
return Ok(());
};
let column = &values[0];
let len = column.len();
match values[0].data_type() {
DataType::Utf8 => {
let arr: &StringArray = values[0].as_string();
for s in arr.iter() {
let Some(s) = s else {
match as_veclit_if_const(column)? {
Some(column) => {
let vec_column = DVectorView::from_slice(&column, column.len()).scale(len as f32);
*self.inner(vec_column.len()) += vec_column;
}
None => {
for i in 0..len {
let Some(arg0) = as_veclit(column.get_ref(i))? else {
if is_update {
self.has_null = true;
self.sum = None;
}
return Ok(());
};
let values = parse_veclit_from_strlit(s)?;
let vec_column = DVectorView::from_slice(&values, values.len());
let vec_column = DVectorView::from_slice(&arg0, arg0.len());
*self.inner(vec_column.len()) += vec_column;
}
}
DataType::Binary => {
let arr: &BinaryArray = values[0].as_binary();
for b in arr.iter() {
let Some(b) = b else {
if is_update {
self.has_null = true;
self.sum = None;
}
return Ok(());
};
let values = binlit_as_veclit(b)?;
let vec_column = DVectorView::from_slice(&values, values.len());
*self.inner(vec_column.len()) += vec_column;
}
}
_ => {
return Err(datafusion_common::DataFusionError::NotImplemented(format!(
"unsupported data type {} for `VEC_SUM`",
values[0].data_type()
)))
}
}
Ok(())
}
}
impl Accumulator for VectorSum {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
fn state(&self) -> common_query::error::Result<Vec<Value>> {
self.evaluate().map(|v| vec![v])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
fn update_batch(&mut self, values: &[VectorRef]) -> common_query::error::Result<()> {
self.update(values, true)
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
fn merge_batch(&mut self, states: &[VectorRef]) -> common_query::error::Result<()> {
self.update(states, false)
}
fn evaluate(&mut self) -> Result<ScalarValue> {
fn evaluate(&self) -> common_query::error::Result<Value> {
match &self.sum {
None => Ok(ScalarValue::Binary(None)),
Some(vector) => Ok(ScalarValue::Binary(Some(veclit_to_binlit(
vector.as_slice(),
)))),
None => Ok(Value::Null),
Some(vector) => Ok(Value::from(veclit_to_binlit(vector.as_slice()))),
}
}
fn size(&self) -> usize {
size_of_val(self)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::StringArray;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::{ConstantVector, StringVector, Vector};
use datatypes::vectors::{ConstantVector, StringVector};
use super::*;
@@ -168,58 +159,57 @@ mod tests {
vec_sum.update_batch(&[]).unwrap();
assert!(vec_sum.sum.is_none());
assert!(!vec_sum.has_null);
assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());
assert_eq!(Value::Null, vec_sum.evaluate().unwrap());
// test update one not-null value
let mut vec_sum = VectorSum::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Some(
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Some(
"[1.0,2.0,3.0]".to_string(),
)]))];
vec_sum.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[1.0, 2.0, 3.0]))),
Value::from(veclit_to_binlit(&[1.0, 2.0, 3.0])),
vec_sum.evaluate().unwrap()
);
// test update one null value
let mut vec_sum = VectorSum::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![Option::<String>::None]))];
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![Option::<String>::None]))];
vec_sum.update_batch(&v).unwrap();
assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());
assert_eq!(Value::Null, vec_sum.evaluate().unwrap());
// test update no null-value batch
let mut vec_sum = VectorSum::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_sum.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[12.0, 15.0, 18.0]))),
Value::from(veclit_to_binlit(&[12.0, 15.0, 18.0])),
vec_sum.evaluate().unwrap()
);
// test update null-value batch
let mut vec_sum = VectorSum::default();
let v: Vec<ArrayRef> = vec![Arc::new(StringArray::from(vec![
let v: Vec<VectorRef> = vec![Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
None,
Some("[7.0,8.0,9.0]".to_string()),
]))];
vec_sum.update_batch(&v).unwrap();
assert_eq!(ScalarValue::Binary(None), vec_sum.evaluate().unwrap());
assert_eq!(Value::Null, vec_sum.evaluate().unwrap());
// test update with constant vector
let mut vec_sum = VectorSum::default();
let v: Vec<ArrayRef> = vec![Arc::new(ConstantVector::new(
let v: Vec<VectorRef> = vec![Arc::new(ConstantVector::new(
Arc::new(StringVector::from_vec(vec!["[1.0,2.0,3.0]".to_string()])),
4,
))
.to_arrow_array()];
))];
vec_sum.update_batch(&v).unwrap();
assert_eq!(
ScalarValue::Binary(Some(veclit_to_binlit(&[4.0, 8.0, 12.0]))),
Value::from(veclit_to_binlit(&[4.0, 8.0, 12.0])),
vec_sum.evaluate().unwrap()
);
}
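A small standalone sketch of the component-wise sum `vec_sum` computes over string vector literals; `parse_veclit` here is a simplified, illustrative stand-in for `parse_veclit_from_strlit`, not the real parser.

// Sketch only: parse "[x,y,z]" literals and add them component-wise.
fn parse_veclit(s: &str) -> Option<Vec<f32>> {
    s.trim()
        .strip_prefix('[')?
        .strip_suffix(']')?
        .split(',')
        .map(|x| x.trim().parse::<f32>().ok())
        .collect()
}

fn main() {
    let rows = ["[1.0,2.0,3.0]", "[4.0,5.0,6.0]", "[7.0,8.0,9.0]"];
    let mut sum = vec![0.0_f32; 3];
    for row in rows {
        for (acc, x) in sum.iter_mut().zip(parse_veclit(row).unwrap()) {
            *acc += x;
        }
    }
    // Prints [12.0, 15.0, 18.0], matching the batch expectation in the test above.
    println!("{:?}", sum);
}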

View File

@@ -12,24 +12,28 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::datatypes::DataType as ArrowDataType;
use common_error::ext::BoxedError;
use common_macro::admin_fn;
use common_query::error::{
ExecuteSnafu, InvalidFuncArgsSnafu, MissingFlowServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, Volatility};
use common_query::prelude::Signature;
use datafusion::logical_expr::Volatility;
use datatypes::value::{Value, ValueRef};
use session::context::QueryContextRef;
use snafu::{ensure, ResultExt};
use sql::ast::ObjectNamePartExt;
use sql::parser::ParserContext;
use store_api::storage::ConcreteDataType;
use crate::handlers::FlowServiceHandlerRef;
fn flush_signature() -> Signature {
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}
#[admin_fn(
@@ -81,9 +85,9 @@ fn parse_flush_flow(
let (catalog_name, flow_name) = match &obj_name.0[..] {
[flow_name] => (
query_ctx.current_catalog().to_string(),
flow_name.to_string_unquoted(),
flow_name.value.clone(),
),
[catalog, flow_name] => (catalog.to_string_unquoted(), flow_name.to_string_unquoted()),
[catalog, flow_name] => (catalog.value.clone(), flow_name.value.clone()),
_ => {
return InvalidFuncArgsSnafu {
err_msg: format!(
@@ -101,55 +105,44 @@ fn parse_flush_flow(
mod test {
use std::sync::Arc;
use datatypes::scalars::ScalarVector;
use datatypes::vectors::StringVector;
use session::context::QueryContext;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[test]
fn test_flush_flow_metadata() {
let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = FlushFlowFunction;
assert_eq!("flush_flow", f.name());
assert_eq!(ArrowDataType::UInt64, f.return_type(&[]).unwrap());
let expected_signature = datafusion_expr::Signature::uniform(
1,
vec![ArrowDataType::Utf8],
datafusion_expr::Volatility::Immutable,
assert_eq!(
ConcreteDataType::uint64_datatype(),
f.return_type(&[]).unwrap()
);
assert_eq!(
f.signature(),
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
);
assert_eq!(*f.signature(), expected_signature);
}
#[tokio::test]
async fn test_missing_flow_service() {
let factory: ScalarFunctionFactory = FlushFlowFunction::factory().into();
let binding = factory.provide(FunctionContext::default());
let f = binding.as_async().unwrap();
let f = FlushFlowFunction;
let flow_name_array = Arc::new(arrow::array::StringArray::from(vec!["flow_name"]));
let args = vec!["flow_name"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let columnar_args = vec![datafusion_expr::ColumnarValue::Array(flow_name_array as _)];
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: columnar_args,
arg_fields: vec![Arc::new(arrow::datatypes::Field::new(
"arg_0",
ArrowDataType::Utf8,
false,
))],
return_field: Arc::new(arrow::datatypes::Field::new(
"result",
ArrowDataType::UInt64,
true,
)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap_err();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Execution error: Handler error: Missing FlowServiceHandler, not expected",
"Missing FlowServiceHandler, not expected",
result.to_string()
);
}

View File

@@ -41,12 +41,6 @@ impl FunctionContext {
}
}
impl std::fmt::Display for FunctionContext {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "FunctionContext {{ query_ctx: {} }}", self.query_ctx)
}
}
impl Default for FunctionContext {
fn default() -> Self {
Self {
@@ -73,3 +67,22 @@ pub trait Function: fmt::Display + Sync + Send {
}
pub type FunctionRef = Arc<dyn Function>;
/// Async Scalar function trait
#[async_trait::async_trait]
pub trait AsyncFunction: fmt::Display + Sync + Send {
/// Returns the name of the function, should be unique.
fn name(&self) -> &str;
/// The returned data type of function execution.
fn return_type(&self, input_types: &[ConcreteDataType]) -> Result<ConcreteDataType>;
/// The signature of function.
fn signature(&self) -> Signature;
/// Evaluate the function, e.g. run/execute the function.
/// TODO(dennis): simplify the signature and refactor all the admin functions.
async fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result<VectorRef>;
}
pub type AsyncFunctionRef = Arc<dyn AsyncFunction>;
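For context, a compact sketch of the async-trait pattern the `AsyncFunction` trait above relies on, with `Vec<String>` and `String` standing in for the real `VectorRef`/`ConcreteDataType`/error types; it assumes the async-trait and tokio crates and is not taken from the repository.

use std::sync::Arc;

#[async_trait::async_trait]
trait AsyncFunction: Send + Sync {
    fn name(&self) -> &str;
    async fn eval(&self, columns: &[Vec<String>]) -> Result<Vec<String>, String>;
}

struct Echo;

#[async_trait::async_trait]
impl AsyncFunction for Echo {
    fn name(&self) -> &str {
        "echo"
    }

    async fn eval(&self, columns: &[Vec<String>]) -> Result<Vec<String>, String> {
        // A trivial body; a real admin function would await a handler/RPC here.
        Ok(columns.first().cloned().unwrap_or_default())
    }
}

#[tokio::main]
async fn main() {
    let f: Arc<dyn AsyncFunction> = Arc::new(Echo);
    let out = f.eval(&[vec!["hello".to_string()]]).await.unwrap();
    println!("{} -> {:?}", f.name(), out);
}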

View File

@@ -22,8 +22,8 @@ use crate::scalars::udf::create_udf;
/// A factory for creating `ScalarUDF` that require a function context.
#[derive(Clone)]
pub struct ScalarFunctionFactory {
pub name: String,
pub factory: Arc<dyn Fn(FunctionContext) -> ScalarUDF + Send + Sync>,
name: String,
factory: Arc<dyn Fn(FunctionContext) -> ScalarUDF + Send + Sync>,
}
impl ScalarFunctionFactory {

View File

@@ -24,7 +24,7 @@ use crate::aggrs::aggr_wrapper::StateMergeHelper;
use crate::aggrs::approximate::ApproximateFunction;
use crate::aggrs::count_hash::CountHash;
use crate::aggrs::vector::VectorFunction as VectorAggrFunction;
use crate::function::{Function, FunctionRef};
use crate::function::{AsyncFunctionRef, Function, FunctionRef};
use crate::function_factory::ScalarFunctionFactory;
use crate::scalars::date::DateFunction;
use crate::scalars::expression::ExpressionFunction;
@@ -42,18 +42,11 @@ use crate::system::SystemFunction;
#[derive(Default)]
pub struct FunctionRegistry {
functions: RwLock<HashMap<String, ScalarFunctionFactory>>,
async_functions: RwLock<HashMap<String, AsyncFunctionRef>>,
aggregate_functions: RwLock<HashMap<String, AggregateUDF>>,
}
impl FunctionRegistry {
/// Register a function in the registry by converting it into a `ScalarFunctionFactory`.
///
/// # Arguments
///
/// * `func` - An object that can be converted into a `ScalarFunctionFactory`.
///
/// The function is inserted into the internal function map, keyed by its name.
/// If a function with the same name already exists, it will be replaced.
pub fn register(&self, func: impl Into<ScalarFunctionFactory>) {
let func = func.into();
let _ = self
@@ -63,12 +56,18 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
/// Register a scalar function in the registry.
pub fn register_scalar(&self, func: impl Function + 'static) {
self.register(Arc::new(func) as FunctionRef);
}
/// Register an aggregate function in the registry.
pub fn register_async(&self, func: AsyncFunctionRef) {
let _ = self
.async_functions
.write()
.unwrap()
.insert(func.name().to_string(), func);
}
pub fn register_aggr(&self, func: AggregateUDF) {
let _ = self
.aggregate_functions
@@ -77,16 +76,28 @@ impl FunctionRegistry {
.insert(func.name().to_string(), func);
}
pub fn get_async_function(&self, name: &str) -> Option<AsyncFunctionRef> {
self.async_functions.read().unwrap().get(name).cloned()
}
pub fn async_functions(&self) -> Vec<AsyncFunctionRef> {
self.async_functions
.read()
.unwrap()
.values()
.cloned()
.collect()
}
#[cfg(test)]
pub fn get_function(&self, name: &str) -> Option<ScalarFunctionFactory> {
self.functions.read().unwrap().get(name).cloned()
}
/// Returns a list of all scalar functions registered in the registry.
pub fn scalar_functions(&self) -> Vec<ScalarFunctionFactory> {
self.functions.read().unwrap().values().cloned().collect()
}
/// Returns a list of all aggregate functions registered in the registry.
pub fn aggregate_functions(&self) -> Vec<AggregateUDF> {
self.aggregate_functions
.read()
@@ -96,7 +107,6 @@ impl FunctionRegistry {
.collect()
}
/// Returns true if an aggregate function with the given name exists in the registry.
pub fn is_aggr_func_exist(&self, name: &str) -> bool {
self.aggregate_functions.read().unwrap().contains_key(name)
}
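A reduced, std-only sketch of the name-keyed registry pattern used by `FunctionRegistry` above; the `Function` trait here is an illustrative stand-in for the real factory and UDF types, and registration replaces any existing entry with the same name.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

trait Function: Send + Sync {
    fn name(&self) -> &str;
}

#[derive(Default)]
struct Registry {
    functions: RwLock<HashMap<String, Arc<dyn Function>>>,
}

impl Registry {
    fn register(&self, func: Arc<dyn Function>) {
        // Last registration wins, mirroring the insert-by-name behaviour above.
        let _ = self
            .functions
            .write()
            .unwrap()
            .insert(func.name().to_string(), func);
    }

    fn get(&self, name: &str) -> Option<Arc<dyn Function>> {
        self.functions.read().unwrap().get(name).cloned()
    }
}

struct Dummy;
impl Function for Dummy {
    fn name(&self) -> &str {
        "dummy"
    }
}

fn main() {
    let registry = Registry::default();
    registry.register(Arc::new(Dummy));
    assert!(registry.get("dummy").is_some());
}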

View File

@@ -113,8 +113,6 @@ mod tests {
use common_query::prelude::ScalarValue;
use datafusion::arrow::array::BooleanArray;
use datafusion_common::config::ConfigOptions;
use datatypes::arrow::datatypes::Field;
use datatypes::data_type::ConcreteDataType;
use datatypes::prelude::VectorRef;
use datatypes::vectors::{BooleanVector, ConstantVector};
@@ -164,21 +162,10 @@ mod tests {
]))),
];
let arg_fields = vec![
Arc::new(Field::new("a", args[0].data_type(), false)),
Arc::new(Field::new("b", args[1].data_type(), false)),
];
let return_field = Arc::new(Field::new(
"x",
ConcreteDataType::boolean_datatype().as_arrow_type(),
false,
));
let args = ScalarFunctionArgs {
args,
arg_fields,
number_rows: 4,
return_field,
config_options: Arc::new(ConfigOptions::default()),
return_type: &ConcreteDataType::boolean_datatype().as_arrow_type(),
};
match udf.invoke_with_args(args).unwrap() {
datafusion_expr::ColumnarValue::Array(x) => {

View File

@@ -19,6 +19,8 @@ mod procedure_state;
mod timezone;
mod version;
use std::sync::Arc;
use build::BuildFunction;
use database::{
ConnectionIdFunction, CurrentSchemaFunction, DatabaseFunction, PgBackendPidFunction,
@@ -44,7 +46,7 @@ impl SystemFunction {
registry.register_scalar(PgBackendPidFunction);
registry.register_scalar(ConnectionIdFunction);
registry.register_scalar(TimezoneFunction);
registry.register(ProcedureStateFunction::factory());
registry.register_async(Arc::new(ProcedureStateFunction));
PGCatalogFunction::register(registry);
}
}

View File

@@ -13,14 +13,13 @@
// limitations under the License.
use api::v1::meta::ProcedureStatus;
use arrow::datatypes::DataType as ArrowDataType;
use common_macro::admin_fn;
use common_meta::rpc::procedure::ProcedureStateResponse;
use common_query::error::{
InvalidFuncArgsSnafu, MissingProcedureServiceHandlerSnafu, Result,
UnsupportedInputDataTypeSnafu,
};
use datafusion_expr::{Signature, Volatility};
use common_query::prelude::{Signature, Volatility};
use datatypes::prelude::*;
use serde::Serialize;
use session::context::QueryContextRef;
@@ -82,86 +81,73 @@ pub(crate) async fn procedure_state(
}
fn signature() -> Signature {
Signature::uniform(1, vec![ArrowDataType::Utf8], Volatility::Immutable)
Signature::uniform(
1,
vec![ConcreteDataType::string_datatype()],
Volatility::Immutable,
)
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field};
use datafusion_expr::ColumnarValue;
use common_query::prelude::TypeSignature;
use datatypes::vectors::StringVector;
use super::*;
use crate::function::FunctionContext;
use crate::function_factory::ScalarFunctionFactory;
use crate::function::{AsyncFunction, FunctionContext};
#[test]
fn test_procedure_state_misc() {
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let f = factory.provide(FunctionContext::mock());
let f = ProcedureStateFunction;
assert_eq!("procedure_state", f.name());
assert_eq!(DataType::Utf8, f.return_type(&[]).unwrap());
assert_eq!(
ConcreteDataType::string_datatype(),
f.return_type(&[]).unwrap()
);
assert!(matches!(f.signature(),
datafusion_expr::Signature {
type_signature: datafusion_expr::TypeSignature::Uniform(1, valid_types),
volatility: datafusion_expr::Volatility::Immutable
} if valid_types == &vec![ArrowDataType::Utf8]));
Signature {
type_signature: TypeSignature::Uniform(1, valid_types),
volatility: Volatility::Immutable
} if valid_types == vec![ConcreteDataType::string_datatype()]
));
}
#[tokio::test]
async fn test_missing_procedure_service() {
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let binding = factory.provide(FunctionContext::default());
let f = binding.as_async().unwrap();
let f = ProcedureStateFunction;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"pid",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await;
assert!(result.is_err());
let args = vec!["pid"];
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::default(), &args).await.unwrap_err();
assert_eq!(
"Missing ProcedureServiceHandler, not expected",
result.to_string()
);
}
#[tokio::test]
async fn test_procedure_state() {
let factory: ScalarFunctionFactory = ProcedureStateFunction::factory().into();
let provider = factory.provide(FunctionContext::mock());
let f = provider.as_async().unwrap();
let f = ProcedureStateFunction;
let func_args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![ColumnarValue::Array(Arc::new(StringArray::from(vec![
"pid",
])))],
arg_fields: vec![Arc::new(Field::new("arg_0", DataType::Utf8, false))],
return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
number_rows: 1,
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
};
let result = f.invoke_async_with_args(func_args).await.unwrap();
let args = vec!["pid"];
match result {
ColumnarValue::Array(array) => {
let result_array = array.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(
result_array.value(0),
"{\"status\":\"Done\",\"error\":\"OK\"}"
);
}
ColumnarValue::Scalar(scalar) => {
assert_eq!(
scalar,
datafusion_common::ScalarValue::Utf8(Some(
"{\"status\":\"Done\",\"error\":\"OK\"}".to_string()
))
);
}
}
let args = args
.into_iter()
.map(|arg| Arc::new(StringVector::from_slice(&[arg])) as _)
.collect::<Vec<_>>();
let result = f.eval(FunctionContext::mock(), &args).await.unwrap();
let expect: VectorRef = Arc::new(StringVector::from(vec![
"{\"status\":\"Done\",\"error\":\"OK\"}",
]));
assert_eq!(expect, result);
}
}

View File

@@ -20,7 +20,7 @@ common-telemetry.workspace = true
common-time.workspace = true
dashmap.workspace = true
datatypes.workspace = true
flatbuffers = "25.2"
flatbuffers = "24"
hyper.workspace = true
lazy_static.workspace = true
prost.workspace = true

View File

@@ -21,7 +21,6 @@ use common_telemetry::info;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt};
use tokio_util::sync::CancellationToken;
use tonic::transport::{
@@ -98,7 +97,6 @@ impl ChannelManager {
}
}
/// Read tls cert and key files and create a ChannelManager with TLS config.
pub fn with_tls_config(config: ChannelConfig) -> Result<Self> {
let mut inner = Inner::with_config(config.clone());
@@ -107,35 +105,20 @@ impl ChannelManager {
msg: "no config input",
})?;
if !path_config.enabled {
// if TLS is not enabled, ignore the rest of the TLS config
// and don't set `client_tls_config`, so TLS is not used
return Ok(Self {
inner: Arc::new(inner),
});
}
let server_root_ca_cert = std::fs::read_to_string(path_config.server_ca_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
let client_cert = std::fs::read_to_string(path_config.client_cert_path)
.context(InvalidConfigFilePathSnafu)?;
let client_key = std::fs::read_to_string(path_config.client_key_path)
.context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
let mut tls_config = ClientTlsConfig::new();
if let Some(server_ca) = path_config.server_ca_cert_path {
let server_root_ca_cert =
std::fs::read_to_string(server_ca).context(InvalidConfigFilePathSnafu)?;
let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
tls_config = tls_config.ca_certificate(server_root_ca_cert);
}
if let (Some(client_cert_path), Some(client_key_path)) =
(&path_config.client_cert_path, &path_config.client_key_path)
{
let client_cert =
std::fs::read_to_string(client_cert_path).context(InvalidConfigFilePathSnafu)?;
let client_key =
std::fs::read_to_string(client_key_path).context(InvalidConfigFilePathSnafu)?;
let client_identity = Identity::from_pem(client_cert, client_key);
tls_config = tls_config.identity(client_identity);
}
inner.client_tls_config = Some(tls_config);
inner.client_tls_config = Some(
ClientTlsConfig::new()
.ca_certificate(server_root_ca_cert)
.identity(client_identity),
);
Ok(Self {
inner: Arc::new(inner),
@@ -287,13 +270,11 @@ impl ChannelManager {
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClientTlsOption {
/// Whether to enable TLS for client.
pub enabled: bool,
pub server_ca_cert_path: Option<String>,
pub client_cert_path: Option<String>,
pub client_key_path: Option<String>,
pub server_ca_cert_path: String,
pub client_cert_path: String,
pub client_key_path: String,
}
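A sketch of how the optional-field form of `ClientTlsOption` shown above can be mapped onto tonic's `ClientTlsConfig`: the CA certificate and the client identity are each applied only when configured, so both server-auth-only TLS and mTLS work. It assumes tonic built with its TLS feature; paths and the helper name are illustrative.

use tonic::transport::{Certificate, ClientTlsConfig, Identity};

fn build_tls_config(
    server_ca_cert_path: Option<&str>,
    client_cert_path: Option<&str>,
    client_key_path: Option<&str>,
) -> std::io::Result<ClientTlsConfig> {
    let mut tls = ClientTlsConfig::new();
    if let Some(ca) = server_ca_cert_path {
        // Verify the server against a custom CA only when one is provided.
        tls = tls.ca_certificate(Certificate::from_pem(std::fs::read_to_string(ca)?));
    }
    if let (Some(cert), Some(key)) = (client_cert_path, client_key_path) {
        // Present a client identity only when both cert and key are configured.
        let identity = Identity::from_pem(
            std::fs::read_to_string(cert)?,
            std::fs::read_to_string(key)?,
        );
        tls = tls.identity(identity);
    }
    Ok(tls)
}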
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -609,10 +590,9 @@ mod tests {
.tcp_keepalive(Duration::from_secs(2))
.tcp_nodelay(false)
.client_tls_config(ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
server_ca_cert_path: "some_server_path".to_string(),
client_cert_path: "some_cert_path".to_string(),
client_key_path: "some_key_path".to_string(),
});
assert_eq!(
@@ -630,10 +610,9 @@ mod tests {
tcp_keepalive: Some(Duration::from_secs(2)),
tcp_nodelay: false,
client_tls: Some(ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("some_server_path".to_string()),
client_cert_path: Some("some_cert_path".to_string()),
client_key_path: Some("some_key_path".to_string()),
server_ca_cert_path: "some_server_path".to_string(),
client_cert_path: "some_cert_path".to_string(),
client_key_path: "some_key_path".to_string(),
}),
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,

View File

@@ -25,7 +25,7 @@ use common_recordbatch::DfRecordBatch;
use datatypes::arrow;
use datatypes::arrow::array::ArrayRef;
use datatypes::arrow::buffer::Buffer;
use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
use datatypes::arrow::datatypes::{Schema as ArrowSchema, SchemaRef};
use datatypes::arrow::error::ArrowError;
use datatypes::arrow::ipc::{convert, reader, root_as_message, writer, MessageHeader};
use flatbuffers::FlatBufferBuilder;
@@ -91,15 +91,7 @@ impl FlightEncoder {
/// be encoded to exactly one [FlightData].
pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
match flight_message {
FlightMessage::Schema(schema) => {
schema.fields().iter().for_each(|x| {
if matches!(x.data_type(), DataType::Dictionary(_, _)) {
self.dictionary_tracker.next_dict_id();
}
});
vec1![self.encode_schema(schema.as_ref())]
}
FlightMessage::Schema(schema) => vec1![self.encode_schema(schema.as_ref())],
FlightMessage::RecordBatch(record_batch) => {
let (encoded_dictionaries, encoded_batch) = self
.data_gen

View File

@@ -23,10 +23,9 @@ async fn test_mtls_config() {
// test wrong file
let config = ChannelConfig::new().client_tls_config(ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("tests/tls/wrong_ca.pem".to_string()),
client_cert_path: Some("tests/tls/wrong_client.pem".to_string()),
client_key_path: Some("tests/tls/wrong_client.key".to_string()),
server_ca_cert_path: "tests/tls/wrong_ca.pem".to_string(),
client_cert_path: "tests/tls/wrong_client.pem".to_string(),
client_key_path: "tests/tls/wrong_client.key".to_string(),
});
let re = ChannelManager::with_tls_config(config);
@@ -34,10 +33,9 @@ async fn test_mtls_config() {
// test corrupted file content
let config = ChannelConfig::new().client_tls_config(ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some("tests/tls/client.pem".to_string()),
client_key_path: Some("tests/tls/corrupted".to_string()),
server_ca_cert_path: "tests/tls/ca.pem".to_string(),
client_cert_path: "tests/tls/client.pem".to_string(),
client_key_path: "tests/tls/corrupted".to_string(),
});
let re = ChannelManager::with_tls_config(config).unwrap();
@@ -46,10 +44,9 @@ async fn test_mtls_config() {
// success
let config = ChannelConfig::new().client_tls_config(ClientTlsOption {
enabled: true,
server_ca_cert_path: Some("tests/tls/ca.pem".to_string()),
client_cert_path: Some("tests/tls/client.pem".to_string()),
client_key_path: Some("tests/tls/client.key".to_string()),
server_ca_cert_path: "tests/tls/ca.pem".to_string(),
client_cert_path: "tests/tls/client.pem".to_string(),
client_key_path: "tests/tls/client.key".to_string(),
});
let re = ChannelManager::with_tls_config(config).unwrap();

View File

@@ -11,8 +11,6 @@ proc-macro = true
workspace = true
[dependencies]
greptime-proto.workspace = true
once_cell.workspace = true
proc-macro2 = "1.0.66"
quote = "1.0"
syn = { version = "2.0", features = [

View File

@@ -187,28 +187,8 @@ fn build_struct(
quote! {
#(#attrs)*
#vis struct #name {
signature: datafusion_expr::Signature,
func_ctx: #user_path::function::FunctionContext,
}
impl #name {
/// Creates a new instance of the function with function context.
fn create(signature: datafusion_expr::Signature, func_ctx: #user_path::function::FunctionContext) -> Self {
Self {
signature,
func_ctx,
}
}
/// Returns the [`ScalarFunctionFactory`] of the function.
pub fn factory() -> impl Into< #user_path::function_factory::ScalarFunctionFactory> {
Self {
signature: #sig_fn().into(),
func_ctx: #user_path::function::FunctionContext::default(),
}
}
}
#[derive(Debug)]
#vis struct #name;
impl std::fmt::Display for #name {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
@@ -216,87 +196,24 @@ fn build_struct(
}
}
impl std::fmt::Debug for #name {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}({})", #uppcase_display_name, self.func_ctx)
}
}
// Implement DataFusion's ScalarUDFImpl trait
impl datafusion::logical_expr::ScalarUDFImpl for #name {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn name(&self) -> &str {
#[async_trait::async_trait]
impl #user_path::function::AsyncFunction for #name {
fn name(&self) -> &'static str {
#display_name
}
fn signature(&self) -> &datafusion_expr::Signature {
&self.signature
fn return_type(&self, _input_types: &[store_api::storage::ConcreteDataType]) -> common_query::error::Result<store_api::storage::ConcreteDataType> {
Ok(store_api::storage::ConcreteDataType::#ret())
}
fn return_type(&self, _arg_types: &[datafusion::arrow::datatypes::DataType]) -> datafusion_common::Result<datafusion::arrow::datatypes::DataType> {
use datatypes::data_type::DataType;
Ok(store_api::storage::ConcreteDataType::#ret().as_arrow_type())
fn signature(&self) -> Signature {
#sig_fn()
}
fn invoke_with_args(
&self,
_args: datafusion::logical_expr::ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
Err(datafusion_common::DataFusionError::NotImplemented(
format!("{} can only be called from async contexts", #display_name)
))
}
}
/// Implement From trait for ScalarFunctionFactory
impl From<#name> for #user_path::function_factory::ScalarFunctionFactory {
fn from(func: #name) -> Self {
use std::sync::Arc;
use datafusion_expr::ScalarUDFImpl;
use datafusion_expr::async_udf::AsyncScalarUDF;
let name = func.name().to_string();
let func = Arc::new(move |ctx: #user_path::function::FunctionContext| {
// create the UDF dynamically with function context
let udf_impl = #name::create(func.signature.clone(), ctx);
let async_udf = AsyncScalarUDF::new(Arc::new(udf_impl));
async_udf.into_scalar_udf()
});
Self {
name,
factory: func,
}
}
}
// Implement DataFusion's AsyncScalarUDFImpl trait
#[async_trait::async_trait]
impl datafusion_expr::async_udf::AsyncScalarUDFImpl for #name {
async fn invoke_async_with_args(
&self,
args: datafusion::logical_expr::ScalarFunctionArgs,
) -> datafusion_common::Result<datafusion_expr::ColumnarValue> {
let columns = args.args
.iter()
.map(|arg| {
common_query::prelude::ColumnarValue::try_from(arg)
.and_then(|cv| match cv {
common_query::prelude::ColumnarValue::Vector(v) => Ok(v),
common_query::prelude::ColumnarValue::Scalar(s) => {
datatypes::vectors::Helper::try_from_scalar_value(s, args.number_rows)
.context(common_query::error::FromScalarValueSnafu)
}
})
})
.collect::<common_query::error::Result<Vec<_>>>()
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Column conversion error: {}", e)))?;
// Safety check: Ensure under the `greptime` catalog for security
#user_path::ensure_greptime!(self.func_ctx);
async fn eval(&self, func_ctx: #user_path::function::FunctionContext, columns: &[datatypes::vectors::VectorRef]) -> common_query::error::Result<datatypes::vectors::VectorRef> {
// Ensure under the `greptime` catalog for security
#user_path::ensure_greptime!(func_ctx);
let columns_num = columns.len();
let rows_num = if columns.is_empty() {
@@ -304,24 +221,23 @@ fn build_struct(
} else {
columns[0].len()
};
let columns = Vec::from(columns);
use snafu::{OptionExt, ResultExt};
use snafu::OptionExt;
use datatypes::data_type::DataType;
let query_ctx = &self.func_ctx.query_ctx;
let handler = self.func_ctx
let query_ctx = &func_ctx.query_ctx;
let handler = func_ctx
.state
.#handler
.as_ref()
.context(#snafu_type)
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Handler error: {}", e)))?;
.context(#snafu_type)?;
let mut builder = store_api::storage::ConcreteDataType::#ret()
.create_mutable_vector(rows_num);
if columns_num == 0 {
let result = #fn_name(handler, query_ctx, &[]).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
let result = #fn_name(handler, query_ctx, &[]).await?;
builder.push_value_ref(result.as_value_ref());
} else {
@@ -330,18 +246,15 @@ fn build_struct(
.map(|vector| vector.get_ref(i))
.collect();
let result = #fn_name(handler, query_ctx, &args).await
.map_err(|e| datafusion_common::DataFusionError::Execution(format!("Function execution error: {}", e)))?;
let result = #fn_name(handler, query_ctx, &args).await?;
builder.push_value_ref(result.as_value_ref());
}
}
let result_vector = builder.to_vector();
// Convert result back to DataFusion ColumnarValue
Ok(datafusion_expr::ColumnarValue::Array(result_vector.to_arrow_array()))
Ok(builder.to_vector())
}
}
}
.into()

View File

@@ -16,7 +16,6 @@ mod admin_fn;
mod aggr_func;
mod print_caller;
mod range_fn;
mod row;
mod stack_trace_debug;
mod utils;
@@ -28,9 +27,6 @@ use range_fn::process_range_fn;
use syn::{parse_macro_input, Data, DeriveInput, Fields};
use crate::admin_fn::process_admin_fn;
use crate::row::into_row::derive_into_row_impl;
use crate::row::schema::derive_schema_impl;
use crate::row::to_row::derive_to_row_impl;
/// Make struct implemented trait [AggrFuncTypeStore], which is necessary when writing UDAF.
/// This derive macro is expect to be used along with attribute macro [macro@as_aggr_func_creator].
@@ -190,117 +186,3 @@ pub fn derive_meta_builder(input: TokenStream) -> TokenStream {
gen.into()
}
/// Derive macro to convert a struct to a row.
///
/// # Example
/// ```rust, ignore
/// use api::v1::Row;
/// use api::v1::value::ValueData;
/// use api::v1::Value;
///
/// #[derive(ToRow)]
/// struct ToRowTest {
/// my_value: i32,
/// #[col(name = "string_value", datatype = "string", semantic = "tag")]
/// my_string: String,
/// my_bool: bool,
/// my_float: f32,
/// #[col(
/// name = "timestamp_value",
/// semantic = "Timestamp",
/// datatype = "TimestampMillisecond"
/// )]
/// my_timestamp: i64,
/// #[col(skip)]
/// my_skip: i32,
/// }
///
/// let row = ToRowTest {
/// my_value: 1,
/// my_string: "test".to_string(),
/// my_bool: true,
/// my_float: 1.0,
/// my_timestamp: 1718563200000,
/// my_skip: 1,
/// }.to_row();
/// ```
#[proc_macro_derive(ToRow, attributes(col))]
pub fn derive_to_row(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let output = derive_to_row_impl(input);
output.unwrap_or_else(|e| e.to_compile_error()).into()
}
/// Derive macro to convert a struct to a row with move semantics.
///
/// # Example
/// ```rust, ignore
/// use api::v1::Row;
/// use api::v1::value::ValueData;
/// use api::v1::Value;
///
/// #[derive(IntoRow)]
/// struct IntoRowTest {
/// my_value: i32,
/// #[col(name = "string_value", datatype = "string", semantic = "tag")]
/// my_string: String,
/// my_bool: bool,
/// my_float: f32,
/// #[col(
/// name = "timestamp_value",
/// semantic = "Timestamp",
/// datatype = "TimestampMillisecond"
/// )]
/// my_timestamp: i64,
/// #[col(skip)]
/// my_skip: i32,
/// }
///
/// let row = IntoRowTest {
/// my_value: 1,
/// my_string: "test".to_string(),
/// my_bool: true,
/// my_float: 1.0,
/// my_timestamp: 1718563200000,
/// my_skip: 1,
/// }.into_row();
/// ```
#[proc_macro_derive(IntoRow, attributes(col))]
pub fn derive_into_row(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let output = derive_into_row_impl(input);
output.unwrap_or_else(|e| e.to_compile_error()).into()
}
/// Derive macro to convert a struct to a schema.
///
/// # Example
/// ```rust, ignore
/// use api::v1::ColumnSchema;
///
/// #[derive(Schema)]
/// struct SchemaTest {
/// my_value: i32,
/// #[col(name = "string_value", datatype = "string", semantic = "tag")]
/// my_string: String,
/// my_bool: bool,
/// my_float: f32,
/// #[col(
/// name = "timestamp_value",
/// semantic = "Timestamp",
/// datatype = "TimestampMillisecond"
/// )]
/// my_timestamp: i64,
/// #[col(skip)]
/// my_skip: i32,
/// }
///
/// let schema = SchemaTest::schema();
/// ```
#[proc_macro_derive(Schema, attributes(col))]
pub fn derive_schema(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let output = derive_schema_impl(input);
output.unwrap_or_else(|e| e.to_compile_error()).into()
}

View File

@@ -1,25 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub(crate) mod attribute;
pub(crate) mod into_row;
pub(crate) mod schema;
pub(crate) mod to_row;
pub(crate) mod utils;
pub(crate) const META_KEY_COL: &str = "col";
pub(crate) const META_KEY_NAME: &str = "name";
pub(crate) const META_KEY_DATATYPE: &str = "datatype";
pub(crate) const META_KEY_SEMANTIC: &str = "semantic";
pub(crate) const META_KEY_SKIP: &str = "skip";

View File

@@ -1,128 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use once_cell::sync::Lazy;
use syn::meta::ParseNestedMeta;
use syn::spanned::Spanned;
use syn::{Attribute, LitStr, Meta, Result};
use crate::row::utils::{
column_data_type_from_str, semantic_type_from_str, ColumnDataTypeWithExtension, SemanticType,
};
use crate::row::{
META_KEY_COL, META_KEY_DATATYPE, META_KEY_NAME, META_KEY_SEMANTIC, META_KEY_SKIP,
};
/// Column attribute.
#[derive(Default)]
pub(crate) struct ColumnAttribute {
/// User-defined name of the column.
pub(crate) name: Option<String>,
/// Data type of the column.
pub(crate) datatype: Option<ColumnDataTypeWithExtension>,
/// Semantic type of the column.
pub(crate) semantic_type: SemanticType,
/// Whether to skip the column.
pub(crate) skip: bool,
}
/// Find the column attribute in the attributes.
pub(crate) fn find_column_attribute(attrs: &[Attribute]) -> Option<&Attribute> {
attrs
.iter()
.find(|attr| matches!(&attr.meta, Meta::List(list) if list.path.is_ident(META_KEY_COL)))
}
/// Parse the column attribute.
pub(crate) fn parse_column_attribute(attr: &Attribute) -> Result<ColumnAttribute> {
match &attr.meta {
Meta::List(list) if list.path.is_ident(META_KEY_COL) => {
let mut attribute = ColumnAttribute::default();
list.parse_nested_meta(|meta| {
parse_column_attribute_field(&meta, &mut attribute)
})?;
Ok(attribute)
}
_ => Err(syn::Error::new(
attr.span(),
format!(
"expected `{META_KEY_COL}({META_KEY_NAME} = \"...\", {META_KEY_DATATYPE} = \"...\", {META_KEY_SEMANTIC} = \"...\")`"
),
)),
}
}
type ParseColumnAttributeField = fn(&ParseNestedMeta, &mut ColumnAttribute) -> Result<()>;
static PARSE_COLUMN_ATTRIBUTE_FIELDS: Lazy<HashMap<&str, ParseColumnAttributeField>> =
Lazy::new(|| {
HashMap::from([
(META_KEY_NAME, parse_name_field as _),
(META_KEY_DATATYPE, parse_datatype_field as _),
(META_KEY_SEMANTIC, parse_semantic_field as _),
(META_KEY_SKIP, parse_skip_field as _),
])
});
fn parse_name_field(meta: &ParseNestedMeta<'_>, attribute: &mut ColumnAttribute) -> Result<()> {
let value = meta.value()?;
let s: LitStr = value.parse()?;
attribute.name = Some(s.value());
Ok(())
}
fn parse_datatype_field(meta: &ParseNestedMeta<'_>, attribute: &mut ColumnAttribute) -> Result<()> {
let value = meta.value()?;
let s: LitStr = value.parse()?;
let ident = s.value();
let Some(value) = column_data_type_from_str(&ident) else {
return Err(meta.error(format!("unexpected {META_KEY_DATATYPE}: {ident}")));
};
attribute.datatype = Some(value);
Ok(())
}
fn parse_semantic_field(meta: &ParseNestedMeta<'_>, attribute: &mut ColumnAttribute) -> Result<()> {
let value = meta.value()?;
let s: LitStr = value.parse()?;
let ident = s.value();
let Some(value) = semantic_type_from_str(&ident) else {
return Err(meta.error(format!("unexpected {META_KEY_SEMANTIC}: {ident}")));
};
attribute.semantic_type = value;
Ok(())
}
fn parse_skip_field(_: &ParseNestedMeta<'_>, attribute: &mut ColumnAttribute) -> Result<()> {
attribute.skip = true;
Ok(())
}
fn parse_column_attribute_field(
meta: &ParseNestedMeta<'_>,
attribute: &mut ColumnAttribute,
) -> Result<()> {
let Some(ident) = meta.path.get_ident() else {
return Err(meta.error(format!("expected `{META_KEY_COL}({META_KEY_NAME} = \"...\", {META_KEY_DATATYPE} = \"...\", {META_KEY_SEMANTIC} = \"...\")`")));
};
let Some(parse_column_attribute) =
PARSE_COLUMN_ATTRIBUTE_FIELDS.get(ident.to_string().as_str())
else {
return Err(meta.error(format!("unexpected attribute: {ident}")));
};
parse_column_attribute(meta, attribute)
}
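The parser above is built on syn 2.x's nested-meta API. Below is a minimal, self-contained sketch of the same mechanism, assuming the `syn` crate with default features; the attribute literal and the single handled key are illustrative only.

use syn::{parse_quote, Attribute, LitStr};

fn main() -> syn::Result<()> {
    // Parse an attribute token tree the same way the derive sees it.
    let attr: Attribute =
        parse_quote!(#[col(name = "string_value", datatype = "string", semantic = "tag")]);

    let mut name = None;
    attr.parse_nested_meta(|meta| {
        // Every key in this example carries a string value.
        let value: LitStr = meta.value()?.parse()?;
        if meta.path.is_ident("name") {
            name = Some(value.value());
        }
        Ok(())
    })?;

    assert_eq!(name.as_deref(), Some("string_value"));
    Ok(())
}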

View File

@@ -1,88 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::spanned::Spanned;
use syn::{DeriveInput, Result};
use crate::row::utils::{
convert_column_data_type_to_value_data_ident, extract_struct_fields, get_column_data_type,
parse_fields_from_fields_named, ParsedField,
};
use crate::row::{META_KEY_COL, META_KEY_DATATYPE};
pub(crate) fn derive_into_row_impl(input: DeriveInput) -> Result<TokenStream2> {
let Some(fields) = extract_struct_fields(&input.data) else {
return Err(syn::Error::new(
input.span(),
"IntoRow can only be derived for structs",
));
};
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fields = parse_fields_from_fields_named(fields)?;
    // Implements the `into_row` method.
let impl_to_row_method = impl_into_row_method_combined(&fields)?;
Ok(quote! {
impl #impl_generics #ident #ty_generics #where_clause {
#impl_to_row_method
}
})
}
fn impl_into_row_method_combined(fields: &[ParsedField<'_>]) -> Result<TokenStream2> {
let value_exprs = fields
.iter()
.map(|field| {
let ParsedField {ident, field_type, column_data_type, column_attribute} = field;
let Some(column_data_type) = get_column_data_type(column_data_type, column_attribute)
else {
return Err(syn::Error::new(
ident.span(),
format!(
"expected to set data type explicitly via [({META_KEY_COL}({META_KEY_DATATYPE} = \"...\"))]"
),
));
};
let value_data = convert_column_data_type_to_value_data_ident(&column_data_type.data_type);
let expr = if field_type.is_optional() {
quote! {
match self.#ident {
Some(v) => Value {
value_data: Some(ValueData::#value_data(v.into())),
},
None => Value { value_data: None },
}
}
} else {
quote! {
Value {
value_data: Some(ValueData::#value_data(self.#ident.into())),
}
}
};
Ok(expr)
})
.collect::<Result<Vec<_>>>()?;
Ok(quote! {
pub fn into_row(self) -> Row {
Row {
values: vec![ #( #value_exprs ),* ]
}
}
})
}
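For a sense of what the generated method looks like, here is a hand-written equivalent of `into_row` for one optional `i32` column followed by a required millisecond-timestamp column; the struct and field names are hypothetical, while `Row`, `Value`, and `ValueData` are the `greptime_proto::v1` types used throughout this diff.

use greptime_proto::v1::value::ValueData;
use greptime_proto::v1::{Row, Value};

struct ExampleRow {
    my_value: Option<i32>,
    my_timestamp: i64,
}

impl ExampleRow {
    // Roughly what `#[derive(IntoRow)]` expands to for these two fields.
    pub fn into_row(self) -> Row {
        Row {
            values: vec![
                match self.my_value {
                    Some(v) => Value {
                        value_data: Some(ValueData::I32Value(v.into())),
                    },
                    None => Value { value_data: None },
                },
                Value {
                    value_data: Some(ValueData::TimestampMillisecondValue(
                        self.my_timestamp.into(),
                    )),
                },
            ],
        }
    }
}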

View File

@@ -1,118 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use greptime_proto::v1::column_data_type_extension::TypeExt;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::spanned::Spanned;
use syn::{DeriveInput, Result};
use crate::row::utils::{
convert_semantic_type_to_proto_semantic_type, extract_struct_fields, get_column_data_type,
parse_fields_from_fields_named, ColumnDataTypeWithExtension, ParsedField,
};
use crate::row::{META_KEY_COL, META_KEY_DATATYPE};
pub(crate) fn derive_schema_impl(input: DeriveInput) -> Result<TokenStream2> {
let Some(fields) = extract_struct_fields(&input.data) else {
return Err(syn::Error::new(
input.span(),
"Schema can only be derived for structs",
));
};
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fields = parse_fields_from_fields_named(fields)?;
// Implement `schema` method.
let impl_schema_method = impl_schema_method(&fields)?;
Ok(quote! {
impl #impl_generics #ident #ty_generics #where_clause {
#impl_schema_method
}
})
}
fn impl_schema_method(fields: &[ParsedField<'_>]) -> Result<TokenStream2> {
let schemas: Vec<TokenStream2> = fields
.iter()
.map(|field| {
let ParsedField{ ident, column_data_type, column_attribute, ..} = field;
let Some(ColumnDataTypeWithExtension{data_type, extension}) = get_column_data_type(column_data_type, column_attribute)
else {
return Err(syn::Error::new(
ident.span(),
format!(
"expected to set data type explicitly via [({META_KEY_COL}({META_KEY_DATATYPE} = \"...\"))]"
),
));
};
            // Uses the user-specified name or the field name as the column name.
let name = column_attribute
.name
.clone()
.unwrap_or_else(|| ident.to_string());
let name = syn::LitStr::new(&name, ident.span());
let column_data_type =
syn::LitInt::new(&(data_type as i32).to_string(), ident.span());
let semantic_type_val = convert_semantic_type_to_proto_semantic_type(column_attribute.semantic_type) as i32;
let semantic_type = syn::LitInt::new(&semantic_type_val.to_string(), ident.span());
let extension = match extension {
Some(ext) => {
match ext.type_ext {
Some(TypeExt::DecimalType(ext)) => {
let precision = syn::LitInt::new(&ext.precision.to_string(), ident.span());
let scale = syn::LitInt::new(&ext.scale.to_string(), ident.span());
quote! {
Some(ColumnDataTypeExtension { type_ext: Some(TypeExt::DecimalType(DecimalTypeExtension { precision: #precision, scale: #scale })) })
}
}
Some(TypeExt::JsonType(ext)) => {
let json_type = syn::LitInt::new(&ext.to_string(), ident.span());
quote! {
Some(ColumnDataTypeExtension { type_ext: Some(TypeExt::JsonType(#json_type)) })
}
}
Some(TypeExt::VectorType(ext)) => {
let dim = syn::LitInt::new(&ext.dim.to_string(), ident.span());
quote! {
Some(ColumnDataTypeExtension { type_ext: Some(TypeExt::VectorType(VectorTypeExtension { dim: #dim })) })
}
}
None => {
quote! { None }
}
}
}
None => quote! { None },
};
Ok(quote! {
ColumnSchema {
column_name: #name.to_string(),
datatype: #column_data_type,
datatype_extension: #extension,
options: None,
semantic_type: #semantic_type,
}
})
})
.collect::<Result<_>>()?;
Ok(quote! {
pub fn schema() -> Vec<ColumnSchema> {
vec![ #(#schemas),* ]
}
})
}
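Correspondingly, a sketch of what the derived `schema()` roughly evaluates to for an inferred `i32` field column plus an explicitly annotated millisecond-timestamp column; the column names are illustrative, and the proto types match those asserted in the tests further down in this diff.

use greptime_proto::v1::{ColumnDataType, ColumnSchema, SemanticType};

// Hand-written equivalent of what `#[derive(Schema)]` generates for these two columns.
fn schema() -> Vec<ColumnSchema> {
    vec![
        ColumnSchema {
            column_name: "my_value".to_string(),
            datatype: ColumnDataType::Int32 as i32,
            semantic_type: SemanticType::Field as i32,
            datatype_extension: None,
            options: None,
        },
        ColumnSchema {
            column_name: "timestamp_value".to_string(),
            datatype: ColumnDataType::TimestampMillisecond as i32,
            semantic_type: SemanticType::Timestamp as i32,
            datatype_extension: None,
            options: None,
        },
    ]
}

fn main() {
    assert_eq!(schema().len(), 2);
}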

View File

@@ -1,88 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::spanned::Spanned;
use syn::{DeriveInput, Result};
use crate::row::utils::{
convert_column_data_type_to_value_data_ident, extract_struct_fields, get_column_data_type,
parse_fields_from_fields_named, ParsedField,
};
use crate::row::{META_KEY_COL, META_KEY_DATATYPE};
pub(crate) fn derive_to_row_impl(input: DeriveInput) -> Result<TokenStream2> {
let Some(fields) = extract_struct_fields(&input.data) else {
return Err(syn::Error::new(
input.span(),
"ToRow can only be derived for structs",
));
};
let ident = input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
let fields = parse_fields_from_fields_named(fields)?;
// Implement `to_row` method.
let impl_to_row_method = impl_to_row_method_combined(&fields)?;
Ok(quote! {
impl #impl_generics #ident #ty_generics #where_clause {
#impl_to_row_method
}
})
}
fn impl_to_row_method_combined(fields: &[ParsedField<'_>]) -> Result<TokenStream2> {
let value_exprs = fields
.iter()
.map(|field| {
let ParsedField {ident, field_type, column_data_type, column_attribute} = field;
let Some(column_data_type) = get_column_data_type(column_data_type, column_attribute)
else {
return Err(syn::Error::new(
ident.span(),
format!(
"expected to set data type explicitly via [({META_KEY_COL}({META_KEY_DATATYPE} = \"...\"))]"
),
));
};
let value_data = convert_column_data_type_to_value_data_ident(&column_data_type.data_type);
let expr = if field_type.is_optional() {
quote! {
match &self.#ident {
Some(v) => Value {
value_data: Some(ValueData::#value_data(v.clone().into())),
},
None => Value { value_data: None },
}
}
} else {
quote! {
Value {
value_data: Some(ValueData::#value_data(self.#ident.clone().into())),
}
}
};
Ok(expr)
})
.collect::<Result<Vec<_>>>()?;
Ok(quote! {
pub fn to_row(&self) -> Row {
Row {
values: vec![ #( #value_exprs ),* ]
}
}
})
}
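The only difference from the `IntoRow` expansion is that `to_row` borrows `&self` and clones each value before conversion; a minimal hand-written equivalent for a single optional string column (hypothetical struct):

use greptime_proto::v1::value::ValueData;
use greptime_proto::v1::{Row, Value};

struct ExampleRef {
    my_string: Option<String>,
}

impl ExampleRef {
    // Roughly what `#[derive(ToRow)]` expands to for this field.
    pub fn to_row(&self) -> Row {
        Row {
            values: vec![match &self.my_string {
                Some(v) => Value {
                    value_data: Some(ValueData::StringValue(v.clone().into())),
                },
                None => Value { value_data: None },
            }],
        }
    }
}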

View File

@@ -1,313 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use greptime_proto::v1::column_data_type_extension::TypeExt;
use greptime_proto::v1::{ColumnDataType, ColumnDataTypeExtension, JsonTypeExtension};
use once_cell::sync::Lazy;
use quote::format_ident;
use syn::{
AngleBracketedGenericArguments, Data, DataStruct, Fields, FieldsNamed, GenericArgument, Ident,
Path, PathArguments, PathSegment, Result, Type, TypePath, TypeReference,
};
use crate::row::attribute::{find_column_attribute, parse_column_attribute, ColumnAttribute};
static SEMANTIC_TYPES: Lazy<HashMap<&'static str, SemanticType>> = Lazy::new(|| {
HashMap::from([
("field", SemanticType::Field),
("tag", SemanticType::Tag),
("timestamp", SemanticType::Timestamp),
])
});
static DATATYPE_TO_COLUMN_DATA_TYPE: Lazy<HashMap<&'static str, ColumnDataTypeWithExtension>> =
Lazy::new(|| {
HashMap::from([
// Timestamp
("timestampsecond", ColumnDataType::TimestampSecond.into()),
(
"timestampmillisecond",
ColumnDataType::TimestampMillisecond.into(),
),
(
"timestampmicrosecond",
ColumnDataType::TimestampMicrosecond.into(),
),
(
"timestampnanosecond",
ColumnDataType::TimestampNanosecond.into(),
),
// Date
("date", ColumnDataType::Date.into()),
("datetime", ColumnDataType::Datetime.into()),
// Time
("timesecond", ColumnDataType::TimeSecond.into()),
("timemillisecond", ColumnDataType::TimeMillisecond.into()),
("timemicrosecond", ColumnDataType::TimeMicrosecond.into()),
("timenanosecond", ColumnDataType::TimeNanosecond.into()),
// Others
("string", ColumnDataType::String.into()),
("json", ColumnDataTypeWithExtension::json()),
// TODO(weny): support vector and decimal128.
])
});
static PRIMITIVE_TYPE_TO_COLUMN_DATA_TYPE: Lazy<HashMap<&'static str, ColumnDataType>> =
Lazy::new(|| {
HashMap::from([
("i8", ColumnDataType::Int8),
("i16", ColumnDataType::Int16),
("i32", ColumnDataType::Int32),
("i64", ColumnDataType::Int64),
("u8", ColumnDataType::Uint8),
("u16", ColumnDataType::Uint16),
("u32", ColumnDataType::Uint32),
("u64", ColumnDataType::Uint64),
("f32", ColumnDataType::Float32),
("f64", ColumnDataType::Float64),
("bool", ColumnDataType::Boolean),
])
});
/// Extract the fields of a struct.
pub(crate) fn extract_struct_fields(data: &Data) -> Option<&FieldsNamed> {
let Data::Struct(DataStruct {
fields: Fields::Named(named),
..
}) = &data
else {
return None;
};
Some(named)
}
/// Convert an identifier to a semantic type.
pub(crate) fn semantic_type_from_str(ident: &str) -> Option<SemanticType> {
// Ignores the case of the identifier.
let lowercase = ident.to_lowercase();
let lowercase_str = lowercase.as_str();
SEMANTIC_TYPES.get(lowercase_str).cloned()
}
/// Convert a field type to a column data type.
pub(crate) fn column_data_type_from_str(ident: &str) -> Option<ColumnDataTypeWithExtension> {
// Ignores the case of the identifier.
let lowercase = ident.to_lowercase();
let lowercase_str = lowercase.as_str();
DATATYPE_TO_COLUMN_DATA_TYPE.get(lowercase_str).cloned()
}
#[derive(Default, Clone, Copy)]
pub(crate) enum SemanticType {
#[default]
Field,
Tag,
Timestamp,
}
pub(crate) enum FieldType<'a> {
Required(&'a Type),
Optional(&'a Type),
}
impl FieldType<'_> {
pub(crate) fn is_optional(&self) -> bool {
matches!(self, FieldType::Optional(_))
}
pub(crate) fn extract_ident(&self) -> Option<&Ident> {
match self {
FieldType::Required(ty) => extract_ident_from_type(ty),
FieldType::Optional(ty) => extract_ident_from_type(ty),
}
}
}
fn field_type(ty: &Type) -> FieldType<'_> {
if let Type::Reference(TypeReference { elem, .. }) = ty {
return field_type(elem);
}
if let Type::Path(TypePath {
qself: _,
path: Path {
leading_colon,
segments,
},
}) = ty
{
if leading_colon.is_none() && segments.len() == 1 {
if let Some(PathSegment {
ident,
arguments:
PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }),
}) = segments.first()
{
if let (1, Some(GenericArgument::Type(t))) = (args.len(), args.first()) {
if ident == "Option" {
return FieldType::Optional(t);
}
}
}
}
}
FieldType::Required(ty)
}
fn extract_ident_from_type(ty: &Type) -> Option<&Ident> {
match ty {
Type::Path(TypePath { qself: None, path }) => path.get_ident(),
Type::Reference(type_ref) => extract_ident_from_type(&type_ref.elem),
Type::Group(type_group) => extract_ident_from_type(&type_group.elem),
_ => None,
}
}
/// Convert a semantic type to a proto semantic type.
pub(crate) fn convert_semantic_type_to_proto_semantic_type(
semantic_type: SemanticType,
) -> greptime_proto::v1::SemanticType {
match semantic_type {
SemanticType::Field => greptime_proto::v1::SemanticType::Field,
SemanticType::Tag => greptime_proto::v1::SemanticType::Tag,
SemanticType::Timestamp => greptime_proto::v1::SemanticType::Timestamp,
}
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct ColumnDataTypeWithExtension {
pub(crate) data_type: ColumnDataType,
pub(crate) extension: Option<ColumnDataTypeExtension>,
}
impl ColumnDataTypeWithExtension {
pub(crate) fn json() -> Self {
Self {
data_type: ColumnDataType::Json,
extension: Some(ColumnDataTypeExtension {
type_ext: Some(TypeExt::JsonType(JsonTypeExtension::JsonBinary.into())),
}),
}
}
}
impl From<ColumnDataType> for ColumnDataTypeWithExtension {
fn from(data_type: ColumnDataType) -> Self {
Self {
data_type,
extension: None,
}
}
}
pub(crate) struct ParsedField<'a> {
pub(crate) ident: &'a Ident,
pub(crate) field_type: FieldType<'a>,
pub(crate) column_data_type: Option<ColumnDataTypeWithExtension>,
pub(crate) column_attribute: ColumnAttribute,
}
/// Parse fields from fields named.
pub(crate) fn parse_fields_from_fields_named(named: &FieldsNamed) -> Result<Vec<ParsedField<'_>>> {
Ok(named
.named
.iter()
.map(|field| {
let ident = field.ident.as_ref().expect("field must have an ident");
let field_type = field_type(&field.ty);
let column_data_type = field_type
.extract_ident()
.and_then(convert_primitive_type_to_column_data_type);
let column_attribute = find_column_attribute(&field.attrs)
.map(parse_column_attribute)
.transpose()?
.unwrap_or_default();
Ok(ParsedField {
ident,
field_type,
column_data_type,
column_attribute,
})
})
.collect::<Result<Vec<ParsedField<'_>>>>()?
.into_iter()
.filter(|field| !field.column_attribute.skip)
.collect::<Vec<_>>())
}
fn convert_primitive_type_to_column_data_type(
ident: &Ident,
) -> Option<ColumnDataTypeWithExtension> {
PRIMITIVE_TYPE_TO_COLUMN_DATA_TYPE
.get(ident.to_string().as_str())
.cloned()
.map(ColumnDataTypeWithExtension::from)
}
/// Get the column data type from the attribute or the inferred column data type.
pub(crate) fn get_column_data_type(
infer_column_data_type: &Option<ColumnDataTypeWithExtension>,
attribute: &ColumnAttribute,
) -> Option<ColumnDataTypeWithExtension> {
attribute.datatype.or(*infer_column_data_type)
}
/// Convert a column data type to a value data ident.
pub(crate) fn convert_column_data_type_to_value_data_ident(
column_data_type: &ColumnDataType,
) -> Ident {
match column_data_type {
ColumnDataType::Boolean => format_ident!("BoolValue"),
ColumnDataType::Int8 => format_ident!("I8Value"),
ColumnDataType::Int16 => format_ident!("I16Value"),
ColumnDataType::Int32 => format_ident!("I32Value"),
ColumnDataType::Int64 => format_ident!("I64Value"),
ColumnDataType::Uint8 => format_ident!("U8Value"),
ColumnDataType::Uint16 => format_ident!("U16Value"),
ColumnDataType::Uint32 => format_ident!("U32Value"),
ColumnDataType::Uint64 => format_ident!("U64Value"),
ColumnDataType::Float32 => format_ident!("F32Value"),
ColumnDataType::Float64 => format_ident!("F64Value"),
ColumnDataType::Binary => format_ident!("BinaryValue"),
ColumnDataType::String => format_ident!("StringValue"),
ColumnDataType::Date => format_ident!("DateValue"),
ColumnDataType::Datetime => format_ident!("DatetimeValue"),
ColumnDataType::TimestampSecond => format_ident!("TimestampSecondValue"),
ColumnDataType::TimestampMillisecond => {
format_ident!("TimestampMillisecondValue")
}
ColumnDataType::TimestampMicrosecond => {
format_ident!("TimestampMicrosecondValue")
}
ColumnDataType::TimestampNanosecond => format_ident!("TimestampNanosecondValue"),
ColumnDataType::TimeSecond => format_ident!("TimeSecondValue"),
ColumnDataType::TimeMillisecond => format_ident!("TimeMillisecondValue"),
ColumnDataType::TimeMicrosecond => format_ident!("TimeMicrosecondValue"),
ColumnDataType::TimeNanosecond => format_ident!("TimeNanosecondValue"),
ColumnDataType::IntervalYearMonth => format_ident!("IntervalYearMonthValue"),
ColumnDataType::IntervalDayTime => format_ident!("IntervalDayTimeValue"),
ColumnDataType::IntervalMonthDayNano => {
format_ident!("IntervalMonthDayNanoValue")
}
ColumnDataType::Decimal128 => format_ident!("Decimal128Value"),
        // Json is a special case: it is actually a string column.
ColumnDataType::Json => format_ident!("StringValue"),
ColumnDataType::Vector => format_ident!("VectorValue"),
}
}
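Note that `field_type` above only recognizes a bare, single-segment `Option<T>` path as optional; a fully-qualified `std::option::Option<T>` is treated as a required column. A standalone sketch of the same detection, assuming syn 2.x:

use syn::{parse_quote, GenericArgument, PathArguments, Type};

/// Returns the inner type if `ty` is a plain `Option<T>`, mirroring `field_type`.
fn option_inner(ty: &Type) -> Option<&Type> {
    let Type::Path(p) = ty else { return None };
    let seg = p.path.segments.first()?;
    if p.path.leading_colon.is_some() || p.path.segments.len() != 1 || seg.ident != "Option" {
        return None;
    }
    let PathArguments::AngleBracketed(args) = &seg.arguments else {
        return None;
    };
    match (args.args.len(), args.args.first()) {
        (1, Some(GenericArgument::Type(t))) => Some(t),
        _ => None,
    }
}

fn main() {
    let plain: Type = parse_quote!(Option<i64>);
    assert!(option_inner(&plain).is_some());

    // The fully-qualified form is not detected and would be treated as required.
    let qualified: Type = parse_quote!(std::option::Option<i64>);
    assert!(option_inner(&qualified).is_none());
}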

View File

@@ -12,183 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use common_macro::{IntoRow, Schema, ToRow};
use greptime_proto::v1::column_data_type_extension::TypeExt;
use greptime_proto::v1::value::ValueData;
use greptime_proto::v1::{
ColumnDataType, ColumnDataTypeExtension, ColumnSchema, JsonTypeExtension, Row, SemanticType,
Value,
};
use common_macro::{as_aggr_func_creator, AggrFuncTypeStore};
use static_assertions::{assert_fields, assert_impl_all};
#[derive(ToRow, Schema, IntoRow)]
struct ToRowOwned {
my_value: i32,
#[col(name = "string_value", datatype = "string", semantic = "tag")]
my_string: String,
my_bool: bool,
my_float: f32,
#[col(
name = "timestamp_value",
semantic = "Timestamp",
datatype = "TimestampMillisecond"
)]
my_timestamp: i64,
#[allow(dead_code)]
#[col(skip)]
my_skip: i32,
#[col(name = "json_value", datatype = "json")]
my_json: String,
}
#[as_aggr_func_creator]
#[derive(Debug, Default, AggrFuncTypeStore)]
struct Foo {}
#[test]
fn test_to_row() {
let test = ToRowOwned {
my_value: 1,
my_string: "test".to_string(),
my_bool: true,
my_float: 1.0,
my_timestamp: 1718563200000,
my_skip: 1,
my_json: r#"{"name":"John", "age":30}"#.to_string(),
};
let row = test.to_row();
assert_row(&row);
let schema = ToRowOwned::schema();
assert_schema(&schema);
let row2 = test.into_row();
assert_row(&row2);
}
#[derive(ToRow, Schema)]
struct ToRowRef<'a> {
my_value: &'a i32,
#[col(name = "string_value", datatype = "string", semantic = "tag")]
my_string: &'a String,
my_bool: &'a bool,
my_float: &'a f32,
#[col(
name = "timestamp_value",
semantic = "Timestamp",
datatype = "TimestampMillisecond"
)]
my_timestamp: &'a i64,
#[col(name = "json_value", datatype = "json")]
my_json: &'a str,
}
#[test]
fn test_to_row_ref() {
let string = "test".to_string();
let test = ToRowRef {
my_value: &1,
my_string: &string,
my_bool: &true,
my_float: &1.0,
my_timestamp: &1718563200000,
my_json: r#"{"name":"John", "age":30}"#,
};
let row = test.to_row();
assert_row(&row);
let schema = ToRowRef::schema();
assert_schema(&schema);
}
#[derive(ToRow, IntoRow)]
struct ToRowOptional {
my_value: Option<i32>,
#[col(name = "string_value", datatype = "string", semantic = "tag")]
my_string: Option<String>,
my_bool: Option<bool>,
my_float: Option<f32>,
#[col(
name = "timestamp_value",
semantic = "Timestamp",
datatype = "TimestampMillisecond"
)]
my_timestamp: i64,
}
fn assert_row_optional(row: &Row) {
assert_eq!(row.values.len(), 5);
assert_eq!(row.values[0].value_data, None);
assert_eq!(row.values[1].value_data, None);
assert_eq!(row.values[2].value_data, None);
assert_eq!(row.values[3].value_data, None);
assert_eq!(
row.values[4].value_data,
Some(greptime_proto::v1::value::ValueData::TimestampMillisecondValue(1718563200000))
);
}
#[test]
fn test_to_row_optional() {
let test = ToRowOptional {
my_value: None,
my_string: None,
my_bool: None,
my_float: None,
my_timestamp: 1718563200000,
};
let row = test.to_row();
assert_row_optional(&row);
let row2 = test.into_row();
assert_row_optional(&row2);
}
fn assert_row(row: &Row) {
assert_eq!(row.values.len(), 6);
assert_eq!(
row.values[0].value_data,
Some(greptime_proto::v1::value::ValueData::I32Value(1))
);
assert_eq!(
row.values[1].value_data,
Some(greptime_proto::v1::value::ValueData::StringValue(
"test".to_string()
))
);
assert_eq!(
row.values[2].value_data,
Some(greptime_proto::v1::value::ValueData::BoolValue(true))
);
assert_eq!(
row.values[3].value_data,
Some(greptime_proto::v1::value::ValueData::F32Value(1.0))
);
assert_eq!(
row.values[4].value_data,
Some(greptime_proto::v1::value::ValueData::TimestampMillisecondValue(1718563200000))
);
}
fn assert_schema(schema: &[ColumnSchema]) {
assert_eq!(schema.len(), 6);
assert_eq!(schema[0].column_name, "my_value");
assert_eq!(schema[0].datatype, ColumnDataType::Int32 as i32);
assert_eq!(schema[0].semantic_type, SemanticType::Field as i32);
assert_eq!(schema[1].column_name, "string_value");
assert_eq!(schema[1].datatype, ColumnDataType::String as i32);
assert_eq!(schema[1].semantic_type, SemanticType::Tag as i32);
assert_eq!(schema[2].column_name, "my_bool");
assert_eq!(schema[2].datatype, ColumnDataType::Boolean as i32);
assert_eq!(schema[2].semantic_type, SemanticType::Field as i32);
assert_eq!(schema[3].column_name, "my_float");
assert_eq!(schema[3].datatype, ColumnDataType::Float32 as i32);
assert_eq!(schema[3].semantic_type, SemanticType::Field as i32);
assert_eq!(schema[4].column_name, "timestamp_value");
assert_eq!(
schema[4].datatype,
ColumnDataType::TimestampMillisecond as i32
);
assert_eq!(schema[4].semantic_type, SemanticType::Timestamp as i32);
assert_eq!(schema[5].column_name, "json_value");
assert_eq!(schema[5].datatype, ColumnDataType::Json as i32);
assert_eq!(schema[5].semantic_type, SemanticType::Field as i32);
assert_eq!(
schema[5].datatype_extension,
Some(ColumnDataTypeExtension {
type_ext: Some(TypeExt::JsonType(JsonTypeExtension::JsonBinary as i32))
})
);
#[allow(clippy::extra_unused_type_parameters)]
fn test_derive() {
let _ = Foo::default();
assert_fields!(Foo: input_types);
assert_impl_all!(Foo: std::fmt::Debug, Default, common_query::logical_plan::accumulator::AggrFuncTypeStore);
}

View File

@@ -17,7 +17,7 @@ pg_kvbackend = [
"dep:rustls",
]
mysql_kvbackend = ["dep:sqlx", "dep:backon"]
enterprise = ["prost-types"]
enterprise = []
[lints]
workspace = true
@@ -56,7 +56,6 @@ etcd-client.workspace = true
flexbuffers = "25.2"
futures.workspace = true
futures-util.workspace = true
greptime-proto.workspace = true
hex.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
@@ -65,7 +64,6 @@ moka.workspace = true
object-store.workspace = true
prometheus.workspace = true
prost.workspace = true
prost-types = { workspace = true, optional = true }
rand.workspace = true
regex.workspace = true
rskafka.workspace = true

View File

@@ -91,12 +91,10 @@ fn init_factory(table_flow_manager: TableFlowManagerRef) -> Initializer<TableId,
.map(Arc::new)
.map(Some)
.inspect(|set| {
if set.as_ref().map(|s| !s.is_empty()).unwrap_or(false) {
info!(
"Initialized table_flownode cache for table_id: {}, set: {:?}",
table_id, set
);
};
info!(
"Initialized table_flownode cache for table_id: {}, set: {:?}",
table_id, set
);
})
})
})

View File

@@ -63,10 +63,7 @@ pub struct Stat {
pub wcus: i64,
/// How many regions on this node
pub region_num: u64,
/// The region stats of the datanode.
pub region_stats: Vec<RegionStat>,
/// The topic stats of the datanode.
pub topic_stats: Vec<TopicStat>,
/// The node epoch is used to check whether the node has restarted or been redeployed.
pub node_epoch: u64,
/// The datanode workloads.
@@ -102,8 +99,6 @@ pub struct RegionStat {
pub index_size: u64,
/// The manifest info of the region.
pub region_manifest: RegionManifestInfo,
/// The write bytes.
pub write_bytes: u64,
/// The latest entry id of topic used by data.
/// **Only used by remote WAL prune.**
pub data_topic_latest_entry_id: u64,
@@ -113,24 +108,6 @@ pub struct RegionStat {
pub metadata_topic_latest_entry_id: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopicStat {
/// The topic name.
pub topic: String,
/// The latest entry id of the topic.
pub latest_entry_id: u64,
/// The total size in bytes of records appended to the topic.
pub record_size: u64,
/// The total number of records appended to the topic.
pub record_num: u64,
}
/// Trait for reporting statistics about topics.
pub trait TopicStatsReporter: Send + Sync {
/// Returns a list of topic statistics that can be reported.
fn reportable_topics(&mut self) -> Vec<TopicStat>;
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum RegionManifestInfo {
Mito {
@@ -226,7 +203,6 @@ impl TryFrom<&HeartbeatRequest> for Stat {
region_stats,
node_epoch,
node_workloads,
topic_stats,
..
} = value;
@@ -236,7 +212,6 @@ impl TryFrom<&HeartbeatRequest> for Stat {
.iter()
.map(RegionStat::from)
.collect::<Vec<_>>();
let topic_stats = topic_stats.iter().map(TopicStat::from).collect::<Vec<_>>();
let datanode_workloads = get_datanode_workloads(node_workloads.as_ref());
Ok(Self {
@@ -249,7 +224,6 @@ impl TryFrom<&HeartbeatRequest> for Stat {
wcus: region_stats.iter().map(|s| s.wcus).sum(),
region_num: region_stats.len() as u64,
region_stats,
topic_stats,
node_epoch: *node_epoch,
datanode_workloads,
})
@@ -306,24 +280,12 @@ impl From<&api::v1::meta::RegionStat> for RegionStat {
sst_num: region_stat.sst_num,
index_size: region_stat.index_size,
region_manifest: region_stat.manifest.into(),
write_bytes: region_stat.write_bytes,
data_topic_latest_entry_id: region_stat.data_topic_latest_entry_id,
metadata_topic_latest_entry_id: region_stat.metadata_topic_latest_entry_id,
}
}
}
impl From<&api::v1::meta::TopicStat> for TopicStat {
fn from(value: &api::v1::meta::TopicStat) -> Self {
Self {
topic: value.topic_name.clone(),
latest_entry_id: value.latest_entry_id,
record_size: value.record_size,
record_num: value.record_num,
}
}
}
/// The key of the datanode stat in the memory store.
///
/// The format is `__meta_datanode_stat-0-{node_id}`.

View File

@@ -47,7 +47,7 @@ use crate::key::{DeserializedValueWithBytes, FlowId, FlowPartitionId};
use crate::lock_key::{CatalogLock, FlowNameLock, TableNameLock};
use crate::metrics;
use crate::peer::Peer;
use crate::rpc::ddl::{CreateFlowTask, FlowQueryContext, QueryContext};
use crate::rpc::ddl::{CreateFlowTask, QueryContext};
/// The procedure of flow creation.
pub struct CreateFlowProcedure {
@@ -67,7 +67,7 @@ impl CreateFlowProcedure {
flow_id: None,
peers: vec![],
source_table_ids: vec![],
flow_context: query_context.into(), // Convert to FlowQueryContext
query_context,
state: CreateFlowState::Prepare,
prev_flow_info_value: None,
did_replace: false,
@@ -204,8 +204,7 @@ impl CreateFlowProcedure {
let request = FlowRequest {
header: Some(FlowRequestHeader {
tracing_context: TracingContext::from_current_span().to_w3c(),
// Convert FlowQueryContext to QueryContext
query_context: Some(QueryContext::from(self.data.flow_context.clone()).into()),
query_context: Some(self.data.query_context.clone().into()),
}),
body: Some(PbFlowRequest::Create((&self.data).into())),
};
@@ -416,9 +415,7 @@ pub struct CreateFlowData {
pub(crate) flow_id: Option<FlowId>,
pub(crate) peers: Vec<Peer>,
pub(crate) source_table_ids: Vec<TableId>,
/// Use alias for backward compatibility with QueryContext serialized data
#[serde(alias = "query_context")]
pub(crate) flow_context: FlowQueryContext,
pub(crate) query_context: QueryContext,
/// For verify if prev value is consistent when need to update flow metadata.
/// only set when `or_replace` is true.
pub(crate) prev_flow_info_value: Option<DeserializedValueWithBytes<FlowInfoValue>>,
@@ -498,8 +495,7 @@ impl From<&CreateFlowData> for (FlowInfoValue, Vec<(FlowPartitionId, FlowRouteVa
sink_table_name,
flownode_ids,
catalog_name,
// Convert FlowQueryContext back to QueryContext for storage
query_context: Some(QueryContext::from(value.flow_context.clone())),
query_context: Some(value.query_context.clone()),
flow_name,
raw_sql: sql,
expire_after,

View File

@@ -68,7 +68,6 @@ impl CreateLogicalTablesProcedure {
physical_table_id,
physical_region_numbers: vec![],
physical_columns: vec![],
physical_partition_columns: vec![],
},
}
}
@@ -92,8 +91,6 @@ impl CreateLogicalTablesProcedure {
self.check_input_tasks()?;
// Sets physical region numbers
self.fill_physical_table_info().await?;
// Add partition columns from physical table to logical table schemas
self.merge_partition_columns_into_logical_tables()?;
// Checks if the tables exist
self.check_tables_already_exist().await?;
@@ -260,7 +257,6 @@ pub struct CreateTablesData {
physical_table_id: TableId,
physical_region_numbers: Vec<RegionNumber>,
physical_columns: Vec<ColumnMetadata>,
physical_partition_columns: Vec<String>,
}
impl CreateTablesData {

View File

@@ -12,12 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, RawSchema};
use snafu::OptionExt;
use crate::ddl::create_logical_tables::CreateLogicalTablesProcedure;
use crate::error::Result;
use crate::key::table_route::TableRouteValue;
@@ -34,89 +28,6 @@ impl CreateLogicalTablesProcedure {
self.data.physical_region_numbers = physical_region_numbers;
// Extract partition column names from the physical table
let physical_table_info = self
.context
.table_metadata_manager
.table_info_manager()
.get(self.data.physical_table_id)
.await?
.with_context(|| crate::error::TableInfoNotFoundSnafu {
table: format!("physical table {}", self.data.physical_table_id),
})?;
let physical_partition_columns: Vec<String> = physical_table_info
.table_info
.meta
.partition_key_indices
.iter()
.map(|&idx| {
physical_table_info.table_info.meta.schema.column_schemas[idx]
.name
.clone()
})
.collect();
self.data.physical_partition_columns = physical_partition_columns;
Ok(())
}
pub(crate) fn merge_partition_columns_into_logical_tables(&mut self) -> Result<()> {
let partition_columns = &self.data.physical_partition_columns;
// Skip if no partition columns to add
if partition_columns.is_empty() {
return Ok(());
}
for task in &mut self.data.tasks {
// Get existing column names in the logical table
let existing_column_names: HashSet<_> = task
.table_info
.meta
.schema
.column_schemas
.iter()
.map(|c| &c.name)
.collect();
let mut new_columns = Vec::new();
let mut new_primary_key_indices = task.table_info.meta.primary_key_indices.clone();
// Add missing partition columns
for partition_column in partition_columns {
if !existing_column_names.contains(partition_column) {
let new_column_index =
task.table_info.meta.schema.column_schemas.len() + new_columns.len();
// Create new column schema for the partition column
let column_schema = ColumnSchema::new(
partition_column.clone(),
ConcreteDataType::string_datatype(),
true,
);
new_columns.push(column_schema);
// Add to primary key indices (partition columns are part of primary key)
new_primary_key_indices.push(new_column_index);
}
}
// If we added new columns, update the table info
if !new_columns.is_empty() {
let mut updated_columns = task.table_info.meta.schema.column_schemas.clone();
updated_columns.extend(new_columns);
// Create new schema with updated columns
let new_schema = RawSchema::new(updated_columns);
// Update the table info
task.table_info.meta.schema = new_schema;
task.table_info.meta.primary_key_indices = new_primary_key_indices;
}
}
Ok(())
}

View File

@@ -19,12 +19,9 @@ use api::v1::CreateTableExpr;
use common_telemetry::debug;
use common_telemetry::tracing_context::TracingContext;
use store_api::storage::{RegionId, TableId};
use table::metadata::RawTableInfo;
use crate::ddl::create_logical_tables::CreateLogicalTablesProcedure;
use crate::ddl::create_table_template::{
build_template, build_template_from_raw_table_info, CreateRequestBuilder,
};
use crate::ddl::create_table_template::{build_template, CreateRequestBuilder};
use crate::ddl::utils::region_storage_path;
use crate::error::Result;
use crate::peer::Peer;
@@ -40,10 +37,6 @@ impl CreateLogicalTablesProcedure {
let table_ids_already_exists = &self.data.table_ids_already_exists;
let regions_on_this_peer = find_leader_regions(region_routes, peer);
let mut requests = Vec::with_capacity(tasks.len() * regions_on_this_peer.len());
let partition_exprs = region_routes
.iter()
.map(|r| (r.region.id.region_number(), r.region.partition_expr()))
.collect();
for (task, table_id_already_exists) in tasks.iter().zip(table_ids_already_exists) {
if table_id_already_exists.is_some() {
continue;
@@ -54,19 +47,13 @@ impl CreateLogicalTablesProcedure {
let logical_table_id = task.table_info.ident.table_id;
let physical_table_id = self.data.physical_table_id;
let storage_path = region_storage_path(catalog, schema);
let request_builder = create_region_request_builder_from_raw_table_info(
&task.table_info,
physical_table_id,
)?;
let request_builder =
create_region_request_builder(&task.create_table, physical_table_id)?;
for region_number in &regions_on_this_peer {
let region_id = RegionId::new(logical_table_id, *region_number);
let one_region_request = request_builder.build_one(
region_id,
storage_path.clone(),
&HashMap::new(),
&partition_exprs,
);
let one_region_request =
request_builder.build_one(region_id, storage_path.clone(), &HashMap::new());
requests.push(one_region_request);
}
}
@@ -86,7 +73,7 @@ impl CreateLogicalTablesProcedure {
}
}
/// Creates a region request builder
/// Creates a region request builder.
pub fn create_region_request_builder(
create_table_expr: &CreateTableExpr,
physical_table_id: TableId,
@@ -94,14 +81,3 @@ pub fn create_region_request_builder(
let template = build_template(create_table_expr)?;
Ok(CreateRequestBuilder::new(template, Some(physical_table_id)))
}
/// Builds a [CreateRequestBuilder] from a [RawTableInfo].
///
/// Note: **This method is only used for creating logical tables.**
pub fn create_region_request_builder_from_raw_table_info(
raw_table_info: &RawTableInfo,
physical_table_id: TableId,
) -> Result<CreateRequestBuilder> {
let template = build_template_from_raw_table_info(raw_table_info)?;
Ok(CreateRequestBuilder::new(template, Some(physical_table_id)))
}

View File

@@ -214,11 +214,6 @@ impl CreateTableProcedure {
let leaders = find_leaders(region_routes);
let mut create_region_tasks = Vec::with_capacity(leaders.len());
let partition_exprs = region_routes
.iter()
.map(|r| (r.region.id.region_number(), r.region.partition_expr()))
.collect();
for datanode in leaders {
let requester = self.context.node_manager.datanode(&datanode).await;
@@ -226,12 +221,8 @@ impl CreateTableProcedure {
let mut requests = Vec::with_capacity(regions.len());
for region_number in regions {
let region_id = RegionId::new(self.table_id(), region_number);
let create_region_request = request_builder.build_one(
region_id,
storage_path.clone(),
region_wal_options,
&partition_exprs,
);
let create_region_request =
request_builder.build_one(region_id, storage_path.clone(), region_wal_options);
requests.push(PbRegionRequest::Create(create_region_request));
}

View File

@@ -15,10 +15,8 @@
use std::collections::HashMap;
use api::v1::column_def::try_as_column_def;
use api::v1::meta::Partition;
use api::v1::region::{CreateRequest, RegionColumnDef};
use api::v1::{ColumnDef, CreateTableExpr, SemanticType};
use common_telemetry::warn;
use snafu::{OptionExt, ResultExt};
use store_api::metric_engine_consts::{LOGICAL_TABLE_METADATA_KEY, METRIC_ENGINE_NAME};
use store_api::storage::{RegionId, RegionNumber};
@@ -62,7 +60,6 @@ pub(crate) fn build_template_from_raw_table_info(
primary_key: primary_key_indices.iter().map(|i| *i as u32).collect(),
path: String::new(),
options,
partition: None,
};
Ok(template)
@@ -124,7 +121,6 @@ pub(crate) fn build_template(create_table_expr: &CreateTableExpr) -> Result<Crea
primary_key,
path: String::new(),
options: create_table_expr.table_options.clone(),
partition: None,
};
Ok(template)
@@ -154,7 +150,6 @@ impl CreateRequestBuilder {
region_id: RegionId,
storage_path: String,
region_wal_options: &HashMap<RegionNumber, String>,
partition_exprs: &HashMap<RegionNumber, String>,
) -> CreateRequest {
let mut request = self.template.clone();
@@ -162,7 +157,6 @@ impl CreateRequestBuilder {
request.path = storage_path;
// Stores the encoded wal options into the request options.
prepare_wal_options(&mut request.options, region_id, region_wal_options);
request.partition = Some(prepare_partition_expr(region_id, partition_exprs));
if let Some(physical_table_id) = self.physical_table_id {
// Logical table has the same region numbers with physical table, and they have a one-to-one mapping.
@@ -179,55 +173,3 @@ impl CreateRequestBuilder {
request
}
}
fn prepare_partition_expr(
region_id: RegionId,
partition_exprs: &HashMap<RegionNumber, String>,
) -> Partition {
let expr = partition_exprs.get(&region_id.region_number()).cloned();
if expr.is_none() {
warn!("region {} has no partition expr", region_id);
}
Partition {
expression: expr.unwrap_or_default(),
..Default::default()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use store_api::storage::{RegionId, RegionNumber};
use super::*;
#[test]
fn test_build_one_sets_partition_expr_per_region() {
// minimal template
let template = CreateRequest {
region_id: 0,
engine: "mito".to_string(),
column_defs: vec![],
primary_key: vec![],
path: String::new(),
options: Default::default(),
partition: None,
};
let builder = CreateRequestBuilder::new(template, None);
let mut partition_exprs: HashMap<RegionNumber, String> = HashMap::new();
let expr_a =
r#"{"Expr":{"lhs":{"Column":"a"},"op":"Eq","rhs":{"Value":{"UInt32":1}}}}"#.to_string();
partition_exprs.insert(0, expr_a.clone());
let r0 = builder.build_one(
RegionId::new(42, 0),
"/p".to_string(),
&Default::default(),
&partition_exprs,
);
assert_eq!(r0.partition.as_ref().unwrap().expression, expr_a);
}
}

View File

@@ -18,17 +18,17 @@ use std::sync::Arc;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_procedure_test::execute_procedure_until_done;
use session::context::QueryContext as SessionQueryContext;
use session::context::QueryContext;
use table::table_name::TableName;
use crate::ddl::create_flow::{CreateFlowData, CreateFlowProcedure, CreateFlowState, FlowType};
use crate::ddl::create_flow::CreateFlowProcedure;
use crate::ddl::test_util::create_table::test_create_table_task;
use crate::ddl::test_util::flownode_handler::NaiveFlownodeHandler;
use crate::ddl::DdlContext;
use crate::error;
use crate::key::table_route::TableRouteValue;
use crate::key::FlowId;
use crate::rpc::ddl::{CreateFlowTask, FlowQueryContext, QueryContext};
use crate::rpc::ddl::CreateFlowTask;
use crate::test_util::{new_ddl_context, MockFlownodeManager};
pub(crate) fn test_create_flow_task(
@@ -63,7 +63,7 @@ async fn test_create_flow_source_table_not_found() {
let task = test_create_flow_task("my_flow", source_table_names, sink_table_name, false);
let node_manager = Arc::new(MockFlownodeManager::new(NaiveFlownodeHandler));
let ddl_context = new_ddl_context(node_manager);
let query_ctx = SessionQueryContext::arc().into();
let query_ctx = QueryContext::arc().into();
let mut procedure = CreateFlowProcedure::new(task, query_ctx, ddl_context);
let err = procedure.on_prepare().await.unwrap_err();
assert_matches!(err, error::Error::TableNotFound { .. });
@@ -81,7 +81,7 @@ pub(crate) async fn create_test_flow(
sink_table_name.clone(),
false,
);
let query_ctx = SessionQueryContext::arc().into();
let query_ctx = QueryContext::arc().into();
let mut procedure = CreateFlowProcedure::new(task.clone(), query_ctx, ddl_context.clone());
let output = execute_procedure_until_done(&mut procedure).await.unwrap();
let flow_id = output.downcast_ref::<FlowId>().unwrap();
@@ -128,7 +128,7 @@ async fn test_create_flow() {
sink_table_name.clone(),
true,
);
let query_ctx = SessionQueryContext::arc().into();
let query_ctx = QueryContext::arc().into();
let mut procedure = CreateFlowProcedure::new(task.clone(), query_ctx, ddl_context.clone());
let output = execute_procedure_until_done(&mut procedure).await.unwrap();
let flow_id = output.downcast_ref::<FlowId>().unwrap();
@@ -136,7 +136,7 @@ async fn test_create_flow() {
// Creates again
let task = test_create_flow_task("my_flow", source_table_names, sink_table_name, false);
let query_ctx = SessionQueryContext::arc().into();
let query_ctx = QueryContext::arc().into();
let mut procedure = CreateFlowProcedure::new(task.clone(), query_ctx, ddl_context);
let err = procedure.on_prepare().await.unwrap_err();
assert_matches!(err, error::Error::FlowAlreadyExists { .. });
@@ -168,7 +168,7 @@ async fn test_create_flow_same_source_and_sink_table() {
// Try to create a flow with same source and sink table - should fail
let task = test_create_flow_task("my_flow", source_table_names, sink_table_name, false);
let query_ctx = SessionQueryContext::arc().into();
let query_ctx = QueryContext::arc().into();
let mut procedure = CreateFlowProcedure::new(task, query_ctx, ddl_context);
let err = procedure.on_prepare().await.unwrap_err();
assert_matches!(err, error::Error::Unsupported { .. });
@@ -179,165 +179,3 @@ async fn test_create_flow_same_source_and_sink_table() {
assert!(operation.contains("same_table"));
}
}
fn create_test_flow_task_for_serialization() -> CreateFlowTask {
CreateFlowTask {
catalog_name: "test_catalog".to_string(),
flow_name: "test_flow".to_string(),
source_table_names: vec![TableName::new("catalog", "schema", "source_table")],
sink_table_name: TableName::new("catalog", "schema", "sink_table"),
or_replace: false,
create_if_not_exists: false,
expire_after: None,
comment: "test comment".to_string(),
sql: "SELECT * FROM source_table".to_string(),
flow_options: HashMap::new(),
}
}
#[test]
fn test_create_flow_data_serialization_backward_compatibility() {
// Test that old serialized data with query_context can be deserialized
let old_json = r#"{
"state": "Prepare",
"task": {
"catalog_name": "test_catalog",
"flow_name": "test_flow",
"source_table_names": [{"catalog_name": "catalog", "schema_name": "schema", "table_name": "source"}],
"sink_table_name": {"catalog_name": "catalog", "schema_name": "schema", "table_name": "sink"},
"or_replace": false,
"create_if_not_exists": false,
"expire_after": null,
"comment": "test",
"sql": "SELECT * FROM source",
"flow_options": {}
},
"flow_id": null,
"peers": [],
"source_table_ids": [],
"query_context": {
"current_catalog": "old_catalog",
"current_schema": "old_schema",
"timezone": "UTC",
"extensions": {},
"channel": 0
},
"prev_flow_info_value": null,
"did_replace": false,
"flow_type": null
}"#;
let data: CreateFlowData = serde_json::from_str(old_json).unwrap();
assert_eq!(data.flow_context.catalog, "old_catalog");
assert_eq!(data.flow_context.schema, "old_schema");
assert_eq!(data.flow_context.timezone, "UTC");
}
#[test]
fn test_create_flow_data_new_format_serialization() {
// Test new format serialization/deserialization
let flow_context = FlowQueryContext {
catalog: "new_catalog".to_string(),
schema: "new_schema".to_string(),
timezone: "America/New_York".to_string(),
};
let data = CreateFlowData {
state: CreateFlowState::Prepare,
task: create_test_flow_task_for_serialization(),
flow_id: None,
peers: vec![],
source_table_ids: vec![],
flow_context,
prev_flow_info_value: None,
did_replace: false,
flow_type: None,
};
let serialized = serde_json::to_string(&data).unwrap();
let deserialized: CreateFlowData = serde_json::from_str(&serialized).unwrap();
assert_eq!(data.flow_context, deserialized.flow_context);
assert_eq!(deserialized.flow_context.catalog, "new_catalog");
assert_eq!(deserialized.flow_context.schema, "new_schema");
assert_eq!(deserialized.flow_context.timezone, "America/New_York");
}
#[test]
fn test_flow_query_context_conversion_from_query_context() {
let query_context = QueryContext {
current_catalog: "prod_catalog".to_string(),
current_schema: "public".to_string(),
timezone: "America/Los_Angeles".to_string(),
extensions: [
("unused_key".to_string(), "unused_value".to_string()),
("another_key".to_string(), "another_value".to_string()),
]
.into(),
channel: 99,
};
let flow_context: FlowQueryContext = query_context.into();
assert_eq!(flow_context.catalog, "prod_catalog");
assert_eq!(flow_context.schema, "public");
assert_eq!(flow_context.timezone, "America/Los_Angeles");
}
#[test]
fn test_flow_info_conversion_with_flow_context() {
let flow_context = FlowQueryContext {
catalog: "info_catalog".to_string(),
schema: "info_schema".to_string(),
timezone: "Europe/Berlin".to_string(),
};
let data = CreateFlowData {
state: CreateFlowState::CreateMetadata,
task: create_test_flow_task_for_serialization(),
flow_id: Some(123),
peers: vec![],
source_table_ids: vec![456, 789],
flow_context,
prev_flow_info_value: None,
did_replace: false,
flow_type: Some(FlowType::Batching),
};
let (flow_info, _routes) = (&data).into();
assert!(flow_info.query_context.is_some());
let query_context = flow_info.query_context.unwrap();
assert_eq!(query_context.current_catalog(), "info_catalog");
assert_eq!(query_context.current_schema(), "info_schema");
assert_eq!(query_context.timezone(), "Europe/Berlin");
assert_eq!(query_context.channel(), 0);
assert!(query_context.extensions().is_empty());
}
#[test]
fn test_mixed_serialization_format_support() {
// Test that we can deserialize both old and new formats
// Test new FlowQueryContext format
let new_format = r#"{"catalog": "test", "schema": "test", "timezone": "UTC"}"#;
let ctx_from_new: FlowQueryContext = serde_json::from_str(new_format).unwrap();
assert_eq!(ctx_from_new.catalog, "test");
assert_eq!(ctx_from_new.schema, "test");
assert_eq!(ctx_from_new.timezone, "UTC");
// Test old QueryContext format conversion
let old_format = r#"{"current_catalog": "old_test", "current_schema": "old_schema", "timezone": "PST", "extensions": {}, "channel": 0}"#;
let ctx_from_old: FlowQueryContext = serde_json::from_str(old_format).unwrap();
assert_eq!(ctx_from_old.catalog, "old_test");
assert_eq!(ctx_from_old.schema, "old_schema");
assert_eq!(ctx_from_old.timezone, "PST");
// Test that they can be compared
let expected_new = FlowQueryContext {
catalog: "test".to_string(),
schema: "test".to_string(),
timezone: "UTC".to_string(),
};
assert_eq!(ctx_from_new, expected_new);
}

View File

@@ -911,7 +911,7 @@ mod tests {
use crate::key::flow::FlowMetadataManager;
use crate::key::TableMetadataManager;
use crate::kv_backend::memory::MemoryKvBackend;
use crate::node_manager::{DatanodeManager, DatanodeRef, FlownodeManager, FlownodeRef};
use crate::node_manager::{DatanodeRef, FlownodeRef, NodeManager};
use crate::peer::Peer;
use crate::region_keeper::MemoryRegionKeeper;
use crate::region_registry::LeaderRegionRegistry;
@@ -923,14 +923,11 @@ mod tests {
pub struct DummyDatanodeManager;
#[async_trait::async_trait]
impl DatanodeManager for DummyDatanodeManager {
impl NodeManager for DummyDatanodeManager {
async fn datanode(&self, _datanode: &Peer) -> DatanodeRef {
unimplemented!()
}
}
#[async_trait::async_trait]
impl FlownodeManager for DummyDatanodeManager {
async fn flownode(&self, _node: &Peer) -> FlownodeRef {
unimplemented!()
}

View File

@@ -43,9 +43,3 @@ pub const META_KEEP_ALIVE_INTERVAL_SECS: u64 = META_LEASE_SECS / 2;
/// The default mailbox round-trip timeout.
pub const MAILBOX_RTT_SECS: u64 = 1;
/// The interval of reporting topic stats.
pub const TOPIC_STATS_REPORT_INTERVAL_SECS: u64 = 15;
/// The retention seconds of topic stats.
pub const TOPIC_STATS_RETENTION_SECS: u64 = TOPIC_STATS_REPORT_INTERVAL_SECS * 100;

View File

@@ -375,13 +375,6 @@ pub enum Error {
location: Location,
},
#[snafu(display("Region not found: {}", region_id))]
RegionNotFound {
region_id: RegionId,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("View not found: '{}'", view_name))]
ViewNotFound {
view_name: String,
@@ -1027,31 +1020,6 @@ pub enum Error {
actual_column_name: String,
actual_column_id: u32,
},
#[cfg(feature = "enterprise")]
#[snafu(display("Too large duration"))]
TooLargeDuration {
#[snafu(source)]
error: prost_types::DurationError,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "enterprise")]
#[snafu(display("Negative duration"))]
NegativeDuration {
#[snafu(source)]
error: prost_types::DurationError,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "enterprise")]
#[snafu(display("Missing interval field"))]
MissingInterval {
#[snafu(implicit)]
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -1141,21 +1109,14 @@ impl ErrorExt for Error {
| InvalidTimeZone { .. }
| InvalidFileExtension { .. }
| InvalidFileName { .. }
| InvalidFlowRequestBody { .. }
| InvalidFilePath { .. } => StatusCode::InvalidArguments,
#[cfg(feature = "enterprise")]
MissingInterval { .. } | NegativeDuration { .. } | TooLargeDuration { .. } => {
StatusCode::InvalidArguments
}
InvalidFlowRequestBody { .. } => StatusCode::InvalidArguments,
FlowNotFound { .. } => StatusCode::FlowNotFound,
FlowRouteNotFound { .. } => StatusCode::Unexpected,
FlowAlreadyExists { .. } => StatusCode::FlowAlreadyExists,
ViewNotFound { .. } | TableNotFound { .. } | RegionNotFound { .. } => {
StatusCode::TableNotFound
}
ViewNotFound { .. } | TableNotFound { .. } => StatusCode::TableNotFound,
ViewAlreadyExists { .. } | TableAlreadyExists { .. } => StatusCode::TableAlreadyExists,
SubmitProcedure { source, .. }

Some files were not shown because too many files have changed in this diff.