Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2025-12-25 15:40:02 +00:00

Compare commits: feat/inges...main (106 commits)
Commits in this range (SHA1 only):

294f19fa1d, be530ac1de, 434b4d8183, 3ad0b60c4b, 19ae845225, 3866512cf6, d4870ee2af, aea4e9fa55, cea578244c, e1b18614ee, 4bae75ccdb, dc9f3a702e, 2d9967b981, dec0d522f8, 17e2b98132, ee86987912, 0cea58c642, fdedbb8261, 8d9afc83e3, 625fdd09ea, b3bc3c76f1, 342eb47e19, 6a6b34c709, a8b512dded, bd8ffd3db9, c0652f6dd5, fed6cb0806, 69659211f6, 6332d91884, 4d66bd96b8, 2f4a15ec40, 658332fe68, c088d361a4, a85864067e, 0df69c95aa, 72eede8b38, 95eccd6cde, 0bc5a305be, 1afcddd5a9, 62808b887b, 04ddd40e00, b4f028be5f, da964880f5, a35a39f726, e0c1566e92, f6afb10e33, 2dfcf35fee, f7d5c87ac0, 9cd57e9342, 32f9cc5286, 5232a12a8c, 913ac325e5, 0c52d5bb34, e0697790e6, 64e74916b9, b601781604, bd3ad60910, cbfdeca64c, baffed8c6a, 11a5e1618d, f5e0e94e3a, ba4eda40e5, f06a64ff90, 84b4777925, a26dee0ca1, 276f6bf026, 1d5291b06d, 564cc0c750, f1abe5d215, ab426cbf89, cb0f1afb01, a22d08f1b1, 6817a376b5, 4d1a587079, 9f1aefe98f, 2f9130a2de, fa2b4e5e63, 9197e818ec, 36d89c3baf, 0ebfd161d8, 8b26a98c3b, 7199823be9, 60f752d306, edb1f6086f, 1ebcef4794, 2147545c90, 84e4e42ee7, d5c616a9ff, f02bdf5428, f2288a86b0, 9d35b8cad4, cc99f9d65b, 11ecb7a28a, 2a760f010f, 63dd37dca3, 68fff3b1aa, 0177f244e9, 931556dbd3, 69f0249039, 1f91422bae, 377373b8fd, e107030d85, 18875eed4d, ee76d50569, 5d634aeba0, 8346acb900
@@ -51,7 +51,7 @@ runs:
       run: |
         helm upgrade \
           --install my-greptimedb \
-          --set meta.backendStorage.etcd.endpoints=${{ inputs.etcd-endpoints }} \
+          --set 'meta.backendStorage.etcd.endpoints[0]=${{ inputs.etcd-endpoints }}' \
          --set meta.enableRegionFailover=${{ inputs.enable-region-failover }} \
          --set image.registry=${{ inputs.image-registry }} \
          --set image.repository=${{ inputs.image-repository }} \
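The only change here swaps a scalar `--set` value for Helm's list-index form; the quotes keep the brackets away from shell expansion. A minimal sketch of the same idea outside CI, assuming the GreptimeTeam chart and repo names (they are not part of this diff):

# Sketch only: repo, chart, release, and namespace names are assumptions.
helm repo add greptime https://greptimeteam.github.io/helm-charts
helm upgrade --install my-greptimedb greptime/greptimedb-cluster \
  --set 'meta.backendStorage.etcd.endpoints[0]=etcd-0.etcd:2379' \
  --set 'meta.backendStorage.etcd.endpoints[1]=etcd-1.etcd:2379' \
  -n default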
.github/scripts/create-version.sh | 11 (vendored)

@@ -49,6 +49,17 @@ function create_version() {
       echo "GITHUB_REF_NAME is empty in push event" >&2
       exit 1
     fi
+
+    # For tag releases, ensure GITHUB_REF_NAME matches the version in Cargo.toml
+    CARGO_VERSION=$(grep '^version = ' Cargo.toml | cut -d '"' -f 2 | head -n 1)
+    EXPECTED_REF_NAME="v${CARGO_VERSION}"
+
+    if [ "$GITHUB_REF_NAME" != "$EXPECTED_REF_NAME" ]; then
+      echo "Error: GITHUB_REF_NAME '$GITHUB_REF_NAME' does not match Cargo.toml version 'v${CARGO_VERSION}'" >&2
+      echo "Expected tag name: '$EXPECTED_REF_NAME'" >&2
+      exit 1
+    fi
+
     echo "$GITHUB_REF_NAME"
   elif [ "$GITHUB_EVENT_NAME" = workflow_dispatch ]; then
     echo "$NEXT_RELEASE_VERSION-$(git rev-parse --short HEAD)-$(date "+%Y%m%d-%s")"
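The new guard ties release tags to the workspace version. A small sketch of running the same comparison locally before pushing a tag (assumes you are at the repository root):

# Sketch: reproduce the workflow's tag/version check by hand.
CARGO_VERSION=$(grep '^version = ' Cargo.toml | cut -d '"' -f 2 | head -n 1)
EXPECTED_TAG="v${CARGO_VERSION}"
echo "Cargo.toml version: ${CARGO_VERSION}"
echo "Tag to push: ${EXPECTED_TAG}"
[ "$(git describe --tags --exact-match 2>/dev/null)" = "${EXPECTED_TAG}" ] \
  && echo "HEAD is tagged ${EXPECTED_TAG}" \
  || echo "HEAD is not tagged ${EXPECTED_TAG} yet"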
.github/scripts/deploy-greptimedb.sh | 4 (vendored)

@@ -81,7 +81,7 @@ function deploy_greptimedb_cluster() {
     --create-namespace \
     --set image.tag="$GREPTIMEDB_IMAGE_TAG" \
     --set initializer.tag="$GREPTIMEDB_INITIALIZER_IMAGE_TAG" \
-    --set meta.backendStorage.etcd.endpoints="etcd.$install_namespace:2379" \
+    --set "meta.backendStorage.etcd.endpoints[0]=etcd.$install_namespace.svc.cluster.local:2379" \
     --set meta.backendStorage.etcd.storeKeyPrefix="$cluster_name" \
     -n "$install_namespace"

@@ -119,7 +119,7 @@ function deploy_greptimedb_cluster_with_s3_storage() {
     --create-namespace \
     --set image.tag="$GREPTIMEDB_IMAGE_TAG" \
     --set initializer.tag="$GREPTIMEDB_INITIALIZER_IMAGE_TAG" \
-    --set meta.backendStorage.etcd.endpoints="etcd.$install_namespace:2379" \
+    --set "meta.backendStorage.etcd.endpoints[0]=etcd.$install_namespace.svc.cluster.local:2379" \
     --set meta.backendStorage.etcd.storeKeyPrefix="$cluster_name" \
     --set objectStorage.s3.bucket="$AWS_CI_TEST_BUCKET" \
     --set objectStorage.s3.region="$AWS_REGION" \
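To preview what the fully qualified `svc.cluster.local` endpoint expands to without touching a cluster, the chart can be rendered locally; the chart and repo names below are assumptions, only the value key comes from this script:

# Sketch: render the chart and grep for the etcd endpoint that would be applied.
helm template my-greptimedb greptime/greptimedb-cluster \
  --set "meta.backendStorage.etcd.endpoints[0]=etcd.my-ns.svc.cluster.local:2379" \
  | grep -n "svc.cluster.local:2379"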
.github/workflows/check-git-deps.yml | 154 (new file, vendored)

name: Check Git Dependencies on Main Branch

on:
  pull_request:
    branches: [main]
    paths:
      - 'Cargo.toml'
  push:
    branches: [main]
    paths:
      - 'Cargo.toml'

jobs:
  check-git-deps:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v6

      - name: Check git dependencies
        env:
          WHITELIST_DEPS: "greptime-proto,meter-core,meter-macros"
        run: |
          #!/bin/bash
          set -e

          echo "Checking whitelisted git dependencies..."

          # Function to check if a commit is on main branch
          check_commit_on_main() {
            local repo_url="$1"
            local commit="$2"
            local repo_name=$(basename "$repo_url" .git)

            echo "Checking $repo_name"
            echo "Repo: $repo_url"
            echo "Commit: $commit"

            # Create a temporary directory for cloning
            local temp_dir=$(mktemp -d)

            # Clone the repository
            if git clone "$repo_url" "$temp_dir" 2>/dev/null; then
              cd "$temp_dir"

              # Try to determine the main branch name
              local main_branch="main"
              if ! git rev-parse --verify origin/main >/dev/null 2>&1; then
                if git rev-parse --verify origin/master >/dev/null 2>&1; then
                  main_branch="master"
                else
                  # Try to get the default branch
                  main_branch=$(git symbolic-ref refs/remotes/origin/HEAD | sed 's@^refs/remotes/origin/@@')
                fi
              fi

              echo "Main branch: $main_branch"

              # Check if commit exists
              if git cat-file -e "$commit" 2>/dev/null; then
                # Check if commit is on main branch
                if git merge-base --is-ancestor "$commit" "origin/$main_branch" 2>/dev/null; then
                  echo "PASS: Commit $commit is on $main_branch branch"
                  cd - >/dev/null
                  rm -rf "$temp_dir"
                  return 0
                else
                  echo "FAIL: Commit $commit is NOT on $main_branch branch"

                  # Try to find which branch contains this commit
                  local branch_name=$(git branch -r --contains "$commit" 2>/dev/null | head -1 | sed 's/^[[:space:]]*origin\///' | sed 's/[[:space:]]*$//')
                  if [[ -n "$branch_name" ]]; then
                    echo "Found on branch: $branch_name"
                  fi
                  cd - >/dev/null
                  rm -rf "$temp_dir"
                  return 1
                fi
              else
                echo "FAIL: Commit $commit not found in repository"
                cd - >/dev/null
                rm -rf "$temp_dir"
                return 1
              fi
            else
              echo "FAIL: Failed to clone $repo_url"
              rm -rf "$temp_dir"
              return 1
            fi
          }

          # Extract whitelisted git dependencies from Cargo.toml
          echo "Extracting git dependencies from Cargo.toml..."

          # Create temporary array to store dependencies
          declare -a deps=()

          # Build awk pattern from whitelist
          IFS=',' read -ra WHITELIST <<< "$WHITELIST_DEPS"
          awk_pattern=""
          for dep in "${WHITELIST[@]}"; do
            if [[ -n "$awk_pattern" ]]; then
              awk_pattern="$awk_pattern|"
            fi
            awk_pattern="$awk_pattern$dep"
          done

          # Extract whitelisted dependencies
          while IFS= read -r line; do
            if [[ -n "$line" ]]; then
              deps+=("$line")
            fi
          done < <(awk -v pattern="$awk_pattern" '
            $0 ~ pattern ".*git = \"https:/" {
              match($0, /git = "([^"]+)"/, arr)
              git_url = arr[1]
              if (match($0, /rev = "([^"]+)"/, rev_arr)) {
                rev = rev_arr[1]
                print git_url " " rev
              } else {
                # Check next line for rev
                getline
                if (match($0, /rev = "([^"]+)"/, rev_arr)) {
                  rev = rev_arr[1]
                  print git_url " " rev
                }
              }
            }
          ' Cargo.toml)

          echo "Found ${#deps[@]} dependencies to check:"
          for dep in "${deps[@]}"; do
            echo " $dep"
          done

          failed=0

          for dep in "${deps[@]}"; do
            read -r repo_url commit <<< "$dep"
            if ! check_commit_on_main "$repo_url" "$commit"; then
              failed=1
            fi
          done

          echo "Check completed."

          if [[ $failed -eq 1 ]]; then
            echo "ERROR: Some git dependencies are not on their main branches!"
            echo "Please update the commits to point to main branch commits."
            exit 1
          else
            echo "SUCCESS: All git dependencies are on their main branches!"
          fi
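The core of the workflow is the `git merge-base --is-ancestor` test. A hand-run sketch for a single whitelisted dependency, using the greptime-proto rev pinned later in this diff and assuming its default branch is `main`:

# Sketch: verify one pinned rev is reachable from the upstream main branch.
repo_url="https://github.com/GreptimeTeam/greptime-proto.git"
rev="520fa524f9d590752ea327683e82ffd65721b27c"
tmp=$(mktemp -d)
git clone --quiet "$repo_url" "$tmp"
if git -C "$tmp" merge-base --is-ancestor "$rev" origin/main; then
  echo "PASS: $rev is on main"
else
  echo "FAIL: $rev is not on main"
fi
rm -rf "$tmp"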
.github/workflows/release.yml | 32 (vendored)

@@ -49,14 +49,9 @@ on:
         description: Do not run integration tests during the build
         type: boolean
         default: true
-      build_linux_amd64_artifacts:
+      build_linux_artifacts:
         type: boolean
-        description: Build linux-amd64 artifacts
-        required: false
-        default: false
-      build_linux_arm64_artifacts:
-        type: boolean
-        description: Build linux-arm64 artifacts
+        description: Build linux artifacts (both amd64 and arm64)
         required: false
         default: false
       build_macos_artifacts:

@@ -144,7 +139,7 @@ jobs:
           ./.github/scripts/check-version.sh "${{ steps.create-version.outputs.version }}"

       - name: Allocate linux-amd64 runner
-        if: ${{ inputs.build_linux_amd64_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
+        if: ${{ inputs.build_linux_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
         uses: ./.github/actions/start-runner
         id: start-linux-amd64-runner
         with:

@@ -158,7 +153,7 @@ jobs:
           subnet-id: ${{ vars.EC2_RUNNER_SUBNET_ID }}

       - name: Allocate linux-arm64 runner
-        if: ${{ inputs.build_linux_arm64_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
+        if: ${{ inputs.build_linux_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
         uses: ./.github/actions/start-runner
         id: start-linux-arm64-runner
         with:

@@ -173,7 +168,7 @@ jobs:

   build-linux-amd64-artifacts:
     name: Build linux-amd64 artifacts
-    if: ${{ inputs.build_linux_amd64_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
+    if: ${{ inputs.build_linux_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
     needs: [
       allocate-runners,
     ]

@@ -195,7 +190,7 @@ jobs:

   build-linux-arm64-artifacts:
     name: Build linux-arm64 artifacts
-    if: ${{ inputs.build_linux_arm64_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
+    if: ${{ inputs.build_linux_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
     needs: [
       allocate-runners,
     ]

@@ -217,7 +212,7 @@ jobs:

   run-multi-lang-tests:
     name: Run Multi-language SDK Tests
-    if: ${{ inputs.build_linux_amd64_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
+    if: ${{ inputs.build_linux_artifacts || github.event_name == 'push' || github.event_name == 'schedule' }}
     needs: [
       allocate-runners,
       build-linux-amd64-artifacts,

@@ -386,7 +381,18 @@ jobs:

   publish-github-release:
     name: Create GitHub release and upload artifacts
-    if: ${{ inputs.publish_github_release || github.event_name == 'push' || github.event_name == 'schedule' }}
+    # Use always() to run even when optional jobs (macos, windows) are skipped.
+    # Then check that required jobs succeeded and optional jobs didn't fail.
+    if: |
+      always() &&
+      (inputs.publish_github_release || github.event_name == 'push' || github.event_name == 'schedule') &&
+      needs.allocate-runners.result == 'success' &&
+      (needs.build-linux-amd64-artifacts.result == 'success' || needs.build-linux-amd64-artifacts.result == 'skipped') &&
+      (needs.build-linux-arm64-artifacts.result == 'success' || needs.build-linux-arm64-artifacts.result == 'skipped') &&
+      (needs.build-macos-artifacts.result == 'success' || needs.build-macos-artifacts.result == 'skipped') &&
+      (needs.build-windows-artifacts.result == 'success' || needs.build-windows-artifacts.result == 'skipped') &&
+      (needs.release-images-to-dockerhub.result == 'success' || needs.release-images-to-dockerhub.result == 'skipped') &&
+      (needs.run-multi-lang-tests.result == 'success' || needs.run-multi-lang-tests.result == 'skipped')
     needs: [ # The job have to wait for all the artifacts are built.
       allocate-runners,
       build-linux-amd64-artifacts,
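With the two per-architecture inputs merged into `build_linux_artifacts`, a manual run only needs one flag. A sketch using the GitHub CLI (assuming `gh` is authenticated against this repository; the input name comes from the diff above):

# Sketch: dispatch the release workflow and build both linux targets.
gh workflow run release.yml -f build_linux_artifacts=true
gh run list --workflow=release.yml --limit 1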
.gitignore | 3 (vendored)

@@ -67,3 +67,6 @@ greptimedb_data

 # Claude code
 CLAUDE.md
+
+# AGENTS.md
+AGENTS.md
AUTHOR.md | 64

@@ -2,41 +2,41 @@

 ## Individual Committers (in alphabetical order)

-* [CookiePieWw](https://github.com/CookiePieWw)
-* [etolbakov](https://github.com/etolbakov)
-* [irenjj](https://github.com/irenjj)
-* [KKould](https://github.com/KKould)
-* [Lanqing Yang](https://github.com/lyang24)
-* [NiwakaDev](https://github.com/NiwakaDev)
-* [tisonkun](https://github.com/tisonkun)
+- [apdong2022](https://github.com/apdong2022)
+- [beryl678](https://github.com/beryl678)
+- [CookiePieWw](https://github.com/CookiePieWw)
+- [etolbakov](https://github.com/etolbakov)
+- [irenjj](https://github.com/irenjj)
+- [KKould](https://github.com/KKould)
+- [Lanqing Yang](https://github.com/lyang24)
+- [nicecui](https://github.com/nicecui)
+- [NiwakaDev](https://github.com/NiwakaDev)
+- [paomian](https://github.com/paomian)
+- [tisonkun](https://github.com/tisonkun)
+- [Wenjie0329](https://github.com/Wenjie0329)
+- [zhaoyingnan01](https://github.com/zhaoyingnan01)
+- [zhongzc](https://github.com/zhongzc)
+- [ZonaHex](https://github.com/ZonaHex)
+- [zyy17](https://github.com/zyy17)

 ## Team Members (in alphabetical order)

-* [apdong2022](https://github.com/apdong2022)
-* [beryl678](https://github.com/beryl678)
-* [daviderli614](https://github.com/daviderli614)
-* [discord9](https://github.com/discord9)
-* [evenyag](https://github.com/evenyag)
-* [fengjiachun](https://github.com/fengjiachun)
-* [fengys1996](https://github.com/fengys1996)
-* [GrepTime](https://github.com/GrepTime)
-* [holalengyu](https://github.com/holalengyu)
-* [killme2008](https://github.com/killme2008)
-* [MichaelScofield](https://github.com/MichaelScofield)
-* [nicecui](https://github.com/nicecui)
-* [paomian](https://github.com/paomian)
-* [shuiyisong](https://github.com/shuiyisong)
-* [sunchanglong](https://github.com/sunchanglong)
-* [sunng87](https://github.com/sunng87)
-* [v0y4g3r](https://github.com/v0y4g3r)
-* [waynexia](https://github.com/waynexia)
-* [Wenjie0329](https://github.com/Wenjie0329)
-* [WenyXu](https://github.com/WenyXu)
-* [xtang](https://github.com/xtang)
-* [zhaoyingnan01](https://github.com/zhaoyingnan01)
-* [zhongzc](https://github.com/zhongzc)
-* [ZonaHex](https://github.com/ZonaHex)
-* [zyy17](https://github.com/zyy17)
+- [daviderli614](https://github.com/daviderli614)
+- [discord9](https://github.com/discord9)
+- [evenyag](https://github.com/evenyag)
+- [fengjiachun](https://github.com/fengjiachun)
+- [fengys1996](https://github.com/fengys1996)
+- [GrepTime](https://github.com/GrepTime)
+- [holalengyu](https://github.com/holalengyu)
+- [killme2008](https://github.com/killme2008)
+- [MichaelScofield](https://github.com/MichaelScofield)
+- [shuiyisong](https://github.com/shuiyisong)
+- [sunchanglong](https://github.com/sunchanglong)
+- [sunng87](https://github.com/sunng87)
+- [v0y4g3r](https://github.com/v0y4g3r)
+- [waynexia](https://github.com/waynexia)
+- [WenyXu](https://github.com/WenyXu)
+- [xtang](https://github.com/xtang)

 ## All Contributors
@@ -102,6 +102,30 @@ like `feat`/`fix`/`docs`, with a concise summary of code change following. AVOID

 All commit messages SHOULD adhere to the [Conventional Commits specification](https://conventionalcommits.org/).

+## AI-Assisted contributions
+
+We have the following policy for AI-assisted PRs:
+
+- The PR author should **understand the core ideas** behind the implementation **end-to-end**, and be able to justify the design and code during review.
+- **Calls out unknowns and assumptions**. It's okay to not fully understand some bits of AI generated code. You should comment on these cases and point them out to reviewers so that they can use their knowledge of the codebase to clear up any concerns. For example, you might comment "calling this function here seems to work but I'm not familiar with how it works internally, I wonder if there's a race condition if it is called concurrently".
+
+### Why fully AI-generated PRs without understanding are not helpful
+
+Today, AI tools cannot reliably make complex changes to GreptimeDB on their own, which is why we rely on pull requests and code review.
+
+The purposes of code review are:
+
+1. Finish the intended task.
+2. Share knowledge between authors and reviewers, as a long-term investment in the project. For this reason, even if someone familiar with the codebase can finish a task quickly, we're still happy to help a new contributor work on it even if it takes longer.
+
+An AI dump for an issue doesn't meet these purposes. Maintainers could finish the task faster by using AI directly, and the submitters gain little knowledge if they act only as a pass-through AI proxy without understanding.
+
+Please understand the reviewing capacity is **very limited** for the project, so large PRs which appear to not have the requisite understanding might not get reviewed, and eventually closed or redirected.
+
+### Better ways to contribute than an "AI dump"
+
+It's recommended to write a high-quality issue with a clear problem statement and a minimal, reproducible example. This can make it easier for others to contribute.
+
 ## Getting Help

 There are many ways to get help when you're stuck. It is recommended to ask for help by opening an issue, with a detailed description
Cargo.lock | 379 (generated)

File diff suppressed because it is too large.
Cargo.toml | 16

@@ -21,6 +21,7 @@ members = [
     "src/common/grpc-expr",
     "src/common/macro",
     "src/common/mem-prof",
+    "src/common/memory-manager",
     "src/common/meta",
     "src/common/options",
     "src/common/plugins",

@@ -74,7 +75,7 @@ members = [
 resolver = "2"

 [workspace.package]
-version = "1.0.0-beta.2"
+version = "1.0.0-beta.3"
 edition = "2024"
 license = "Apache-2.0"

@@ -131,7 +132,7 @@ datafusion-functions = "50"
 datafusion-functions-aggregate-common = "50"
 datafusion-optimizer = "50"
 datafusion-orc = "0.5"
-datafusion-pg-catalog = "0.12.1"
+datafusion-pg-catalog = "0.12.3"
 datafusion-physical-expr = "50"
 datafusion-physical-plan = "50"
 datafusion-sql = "50"

@@ -139,16 +140,17 @@ datafusion-substrait = "50"
 deadpool = "0.12"
 deadpool-postgres = "0.14"
 derive_builder = "0.20"
 derive_more = { version = "2.1", features = ["full"] }
 dotenv = "0.15"
 either = "1.15"
-etcd-client = { git = "https://github.com/GreptimeTeam/etcd-client", rev = "f62df834f0cffda355eba96691fe1a9a332b75a7", features = [
+etcd-client = { version = "0.16.1", features = [
     "tls",
+    "tls-roots",
 ] }
 fst = "0.4.7"
 futures = "0.3"
 futures-util = "0.3"
-greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "0df99f09f1d6785055b2d9da96fc4ecc2bdf6803" }
+greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "520fa524f9d590752ea327683e82ffd65721b27c" }
 hex = "0.4"
 http = "1"
 humantime = "2.1"

@@ -200,7 +202,8 @@ reqwest = { version = "0.12", default-features = false, features = [
     "stream",
     "multipart",
 ] }
-rskafka = { git = "https://github.com/WenyXu/rskafka.git", rev = "7b0f31ed39db049b4ee2e5f1e95b5a30be9baf76", features = [
+# Branch: feat/request-timeout
+rskafka = { git = "https://github.com/GreptimeTeam/rskafka.git", rev = "f5688f83e7da591cda3f2674c2408b4c0ed4ed50", features = [
     "transport-tls",
 ] }
 rstest = "0.25"

@@ -264,6 +267,7 @@ common-grpc = { path = "src/common/grpc" }
 common-grpc-expr = { path = "src/common/grpc-expr" }
 common-macro = { path = "src/common/macro" }
 common-mem-prof = { path = "src/common/mem-prof" }
+common-memory-manager = { path = "src/common/memory-manager" }
 common-meta = { path = "src/common/meta" }
 common-options = { path = "src/common/options" }
 common-plugins = { path = "src/common/plugins" }

@@ -328,7 +332,7 @@ datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.g
 datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "fd4b2abcf3c3e43e94951bda452c9fd35243aab0" }
 datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "fd4b2abcf3c3e43e94951bda452c9fd35243aab0" }
 datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "fd4b2abcf3c3e43e94951bda452c9fd35243aab0" }
-sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "4b519a5caa95472cc3988f5556813a583dd35af1" } # branch = "v0.58.x"
+sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "a0ce2bc6eb3e804532932f39833c32432f5c9a39" } # branch = "v0.58.x"

 [profile.release]
 debug = 1
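Since `etcd-client` now comes from crates.io while `greptime-proto` stays pinned to a git rev, the lockfile (or `cargo tree`) is the quickest way to confirm which source each dependency resolves to; a small sketch:

# Sketch: confirm resolved sources after the dependency changes.
grep -A 3 '^name = "etcd-client"' Cargo.lock      # expect a registry source, version 0.16.x
grep -A 3 '^name = "greptime-proto"' Cargo.lock   # expect the pinned git rev
cargo tree -i etcd-client | head -n 20            # which crates pull it in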
@@ -14,11 +14,12 @@
 | --- | -----| ------- | ----------- |
 | `default_timezone` | String | Unset | The default timezone of the server. |
 | `default_column_prefix` | String | Unset | The default column prefix for auto-created time index and value columns. |
+| `max_in_flight_write_bytes` | String | Unset | Maximum total memory for all concurrent write request bodies and messages (HTTP, gRPC, Flight).<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
+| `write_bytes_exhausted_policy` | String | Unset | Policy when write bytes quota is exhausted.<br/>Options: "wait" (default, 10s timeout), "wait(<duration>)" (e.g., "wait(30s)"), "fail" |
 | `init_regions_in_background` | Bool | `false` | Initialize all regions in the background during the startup.<br/>By default, it provides services after all regions have been initialized. |
 | `init_regions_parallelism` | Integer | `16` | Parallelism of initializing regions. |
 | `max_concurrent_queries` | Integer | `0` | The maximum current queries allowed to be executed. Zero means unlimited.<br/>NOTE: This setting affects scan_memory_limit's privileged tier allocation.<br/>When set, 70% of queries get privileged memory access (full scan_memory_limit).<br/>The remaining 30% get standard tier access (70% of scan_memory_limit). |
 | `enable_telemetry` | Bool | `true` | Enable telemetry to collect anonymous usage data. Enabled by default. |
-| `max_in_flight_write_bytes` | String | Unset | The maximum in-flight write bytes. |
 | `runtime` | -- | -- | The runtime options. |
 | `runtime.global_rt_size` | Integer | `8` | The number of threads to execute the runtime for global read operations. |
 | `runtime.compact_rt_size` | Integer | `4` | The number of threads to execute the runtime for global write operations. |

@@ -26,14 +27,12 @@
 | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. |
 | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. |
 | `http.body_limit` | String | `64MB` | HTTP request body limit.<br/>The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.<br/>Set to 0 to disable limit. |
-| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
 | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default<br/>This allows browser to access http APIs without CORS restrictions |
 | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. |
 | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.<br/>Available options:<br/>- strict: deny invalid UTF-8 strings (default).<br/>- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).<br/>- unchecked: do not valid strings. |
 | `grpc` | -- | -- | The gRPC server options. |
 | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. |
 | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. |
-| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
 | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.<br/>The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.<br/>Refer to https://grpc.io/docs/guides/keepalive/ for more details. |
 | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. |
 | `grpc.tls.mode` | String | `disable` | TLS mode. |

@@ -83,6 +82,8 @@
 | `wal.sync_period` | String | `10s` | Duration for fsyncing log files.<br/>**It's only used when the provider is `raft_engine`**. |
 | `wal.recovery_parallelism` | Integer | `2` | Parallelism during WAL recovery. |
 | `wal.broker_endpoints` | Array | -- | The Kafka broker endpoints.<br/>**It's only used when the provider is `kafka`**. |
+| `wal.connect_timeout` | String | `3s` | The connect timeout for kafka client.<br/>**It's only used when the provider is `kafka`**. |
+| `wal.timeout` | String | `3s` | The timeout for kafka client.<br/>**It's only used when the provider is `kafka`**. |
 | `wal.auto_create_topics` | Bool | `true` | Automatically create topics for WAL.<br/>Set to `true` to automatically create topics for WAL.<br/>Otherwise, use topics named `topic_name_prefix_[0..num_topics)` |
 | `wal.num_topics` | Integer | `64` | Number of topics.<br/>**It's only used when the provider is `kafka`**. |
 | `wal.selector_type` | String | `round_robin` | Topic selector type.<br/>Available selector types:<br/>- `round_robin` (default)<br/>**It's only used when the provider is `kafka`**. |
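The new top-level `max_in_flight_write_bytes` and `write_bytes_exhausted_policy` replace the removed per-server `http.max_total_body_memory` and `grpc.max_total_message_memory` knobs with one shared budget. A hedged sketch of what the fragment might look like in practice (the keys and value formats come from the rows above; the file path, values, and the `greptime standalone start -c` invocation are assumptions):

# Illustrative only: write a fragment with the new shared write budget and start with it.
cat > /tmp/write-limit.toml <<'EOF'
# Shared budget for all concurrent write bodies/messages (HTTP, gRPC, Flight).
max_in_flight_write_bytes = "1GB"
# Behavior when the budget is exhausted: "wait", "wait(30s)", or "fail".
write_bytes_exhausted_policy = "wait(30s)"
EOF
greptime standalone start -c /tmp/write-limit.toml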
@@ -108,9 +109,6 @@
 | `storage` | -- | -- | The data storage options. |
 | `storage.data_home` | String | `./greptimedb_data` | The working home directory. |
 | `storage.type` | String | `File` | The storage type used to store the data.<br/>- `File`: the data is stored in the local file system.<br/>- `S3`: the data is stored in the S3 object storage.<br/>- `Gcs`: the data is stored in the Google Cloud Storage.<br/>- `Azblob`: the data is stored in the Azure Blob Storage.<br/>- `Oss`: the data is stored in the Aliyun OSS. |
-| `storage.enable_read_cache` | Bool | `true` | Whether to enable read cache. If not set, the read cache will be enabled by default when using object storage. |
-| `storage.cache_path` | String | Unset | Read cache configuration for object storage such as 'S3' etc, it's configured by default when using object storage. It is recommended to configure it when using object storage for better performance.<br/>A local file directory, defaults to `{data_home}`. An empty string means disabling. |
-| `storage.cache_capacity` | String | Unset | The local file cache capacity in bytes. If your disk space is sufficient, it is recommended to set it larger. |
 | `storage.bucket` | String | Unset | The S3 bucket name.<br/>**It's only used when the storage type is `S3`, `Oss` and `Gcs`**. |
 | `storage.root` | String | Unset | The S3 data will be stored in the specified prefix, for example, `s3://${bucket}/${root}`.<br/>**It's only used when the storage type is `S3`, `Oss` and `Azblob`**. |
 | `storage.access_key_id` | String | Unset | The access key id of the aws account.<br/>It's **highly recommended** to use AWS IAM roles instead of hardcoding the access key id and secret key.<br/>**It's only used when the storage type is `S3` and `Oss`**. |

@@ -141,6 +139,8 @@
 | `region_engine.mito.max_background_flushes` | Integer | Auto | Max number of running background flush jobs (default: 1/2 of cpu cores). |
 | `region_engine.mito.max_background_compactions` | Integer | Auto | Max number of running background compaction jobs (default: 1/4 of cpu cores). |
 | `region_engine.mito.max_background_purges` | Integer | Auto | Max number of running background purge jobs (default: number of cpu cores). |
+| `region_engine.mito.experimental_compaction_memory_limit` | String | 0 | Memory budget for compaction tasks. Setting it to 0 or "unlimited" disables the limit. |
+| `region_engine.mito.experimental_compaction_on_exhausted` | String | wait | Behavior when compaction cannot acquire memory from the budget.<br/>Options: "wait" (default, 10s), "wait(<duration>)", "fail" |
 | `region_engine.mito.auto_flush_interval` | String | `1h` | Interval to auto flush a region if it has not flushed yet. |
 | `region_engine.mito.global_write_buffer_size` | String | Auto | Global write buffer size for all regions. If not set, it's default to 1/8 of OS memory with a max limitation of 1GB. |
 | `region_engine.mito.global_write_buffer_reject_size` | String | Auto | Global write buffer size threshold to reject write requests. If not set, it's default to 2 times of `global_write_buffer_size`. |

@@ -154,6 +154,8 @@
 | `region_engine.mito.write_cache_ttl` | String | Unset | TTL for write cache. |
 | `region_engine.mito.preload_index_cache` | Bool | `true` | Preload index (puffin) files into cache on region open (default: true).<br/>When enabled, index files are loaded into the write cache during region initialization,<br/>which can improve query performance at the cost of longer startup times. |
 | `region_engine.mito.index_cache_percent` | Integer | `20` | Percentage of write cache capacity allocated for index (puffin) files (default: 20).<br/>The remaining capacity is used for data (parquet) files.<br/>Must be between 0 and 100 (exclusive). For example, with a 5GiB write cache and 20% allocation,<br/>1GiB is reserved for index files and 4GiB for data files. |
+| `region_engine.mito.enable_refill_cache_on_read` | Bool | `true` | Enable refilling cache on read operations (default: true).<br/>When disabled, cache refilling on read won't happen. |
+| `region_engine.mito.manifest_cache_size` | String | `256MB` | Capacity for manifest cache (default: 256MB). |
 | `region_engine.mito.sst_write_buffer_size` | String | `8MB` | Buffer size for SST writing. |
 | `region_engine.mito.parallel_scan_channel_size` | Integer | `32` | Capacity of the channel to send data from parallel scan tasks to the main task. |
 | `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |
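For the new compaction budget, a similarly hedged fragment under the mito engine section (the key names and defaults are from the rows above; the `[[region_engine]]`/`[region_engine.mito]` nesting and the values are assumptions based on the example configs):

# Illustrative only: cap compaction memory and choose the back-pressure behavior.
cat > /tmp/mito-compaction.toml <<'EOF'
[[region_engine]]
[region_engine.mito]
experimental_compaction_memory_limit = "2GB"
experimental_compaction_on_exhausted = "wait(30s)"
EOF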
@@ -224,7 +226,8 @@
 | --- | -----| ------- | ----------- |
 | `default_timezone` | String | Unset | The default timezone of the server. |
 | `default_column_prefix` | String | Unset | The default column prefix for auto-created time index and value columns. |
-| `max_in_flight_write_bytes` | String | Unset | The maximum in-flight write bytes. |
+| `max_in_flight_write_bytes` | String | Unset | Maximum total memory for all concurrent write request bodies and messages (HTTP, gRPC, Flight).<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
+| `write_bytes_exhausted_policy` | String | Unset | Policy when write bytes quota is exhausted.<br/>Options: "wait" (default, 10s timeout), "wait(<duration>)" (e.g., "wait(30s)"), "fail" |
 | `runtime` | -- | -- | The runtime options. |
 | `runtime.global_rt_size` | Integer | `8` | The number of threads to execute the runtime for global read operations. |
 | `runtime.compact_rt_size` | Integer | `4` | The number of threads to execute the runtime for global write operations. |

@@ -235,7 +238,6 @@
 | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. |
 | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. |
 | `http.body_limit` | String | `64MB` | HTTP request body limit.<br/>The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.<br/>Set to 0 to disable limit. |
-| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
 | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default<br/>This allows browser to access http APIs without CORS restrictions |
 | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. |
 | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.<br/>Available options:<br/>- strict: deny invalid UTF-8 strings (default).<br/>- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).<br/>- unchecked: do not valid strings. |

@@ -243,7 +245,6 @@
 | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. |
 | `grpc.server_addr` | String | `127.0.0.1:4001` | The address advertised to the metasrv, and used for connections from outside the host.<br/>If left empty or unset, the server will automatically use the IP address of the first network interface<br/>on the host, with the same port number as the one specified in `grpc.bind_addr`. |
 | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. |
-| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
 | `grpc.flight_compression` | String | `arrow_ipc` | Compression mode for frontend side Arrow IPC service. Available options:<br/>- `none`: disable all compression<br/>- `transport`: only enable gRPC transport compression (zstd)<br/>- `arrow_ipc`: only enable Arrow IPC compression (lz4)<br/>- `all`: enable all compression.<br/>Default to `none` |
 | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.<br/>The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.<br/>Refer to https://grpc.io/docs/guides/keepalive/ for more details. |
 | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. |

@@ -294,7 +295,6 @@
 | `meta_client` | -- | -- | The metasrv client options. |
 | `meta_client.metasrv_addrs` | Array | -- | The addresses of the metasrv. |
 | `meta_client.timeout` | String | `3s` | Operation timeout. |
-| `meta_client.heartbeat_timeout` | String | `500ms` | Heartbeat timeout. |
 | `meta_client.ddl_timeout` | String | `10s` | DDL timeout. |
 | `meta_client.connect_timeout` | String | `1s` | Connect server timeout. |
 | `meta_client.tcp_nodelay` | Bool | `true` | `TCP_NODELAY` option for accepted connections. |
@@ -344,14 +344,15 @@
 | `store_key_prefix` | String | `""` | If it's not empty, the metasrv will store all data with this key prefix. |
 | `backend` | String | `etcd_store` | The datastore for meta server.<br/>Available values:<br/>- `etcd_store` (default value)<br/>- `memory_store`<br/>- `postgres_store`<br/>- `mysql_store` |
 | `meta_table_name` | String | `greptime_metakv` | Table name in RDS to store metadata. Effect when using a RDS kvbackend.<br/>**Only used when backend is `postgres_store`.** |
-| `meta_schema_name` | String | `greptime_schema` | Optional PostgreSQL schema for metadata table and election table name qualification.<br/>When PostgreSQL public schema is not writable (e.g., PostgreSQL 15+ with restricted public),<br/>set this to a writable schema. GreptimeDB will use `meta_schema_name`.`meta_table_name`.<br/>GreptimeDB will NOT create the schema automatically; please ensure it exists or the user has permission.<br/>**Only used when backend is `postgres_store`.** |
+| `meta_schema_name` | String | `greptime_schema` | Optional PostgreSQL schema for metadata table and election table name qualification.<br/>When PostgreSQL public schema is not writable (e.g., PostgreSQL 15+ with restricted public),<br/>set this to a writable schema. GreptimeDB will use `meta_schema_name`.`meta_table_name`.<br/>**Only used when backend is `postgres_store`.** |
+| `auto_create_schema` | Bool | `true` | Automatically create PostgreSQL schema if it doesn't exist.<br/>When enabled, the system will execute `CREATE SCHEMA IF NOT EXISTS <schema_name>`<br/>before creating metadata tables. This is useful in production environments where<br/>manual schema creation may be restricted.<br/>Default is true.<br/>Note: The PostgreSQL user must have CREATE SCHEMA permission for this to work.<br/>**Only used when backend is `postgres_store`.** |
 | `meta_election_lock_id` | Integer | `1` | Advisory lock id in PostgreSQL for election. Effect when using PostgreSQL as kvbackend<br/>Only used when backend is `postgres_store`. |
 | `selector` | String | `round_robin` | Datanode selector type.<br/>- `round_robin` (default value)<br/>- `lease_based`<br/>- `load_based`<br/>For details, please see "https://docs.greptime.com/developer-guide/metasrv/selector". |
-| `use_memory_store` | Bool | `false` | Store data in memory. |
 | `enable_region_failover` | Bool | `false` | Whether to enable region failover.<br/>This feature is only available on GreptimeDB running on cluster mode and<br/>- Using Remote WAL<br/>- Using shared storage (e.g., s3). |
 | `region_failure_detector_initialization_delay` | String | `10m` | The delay before starting region failure detection.<br/>This delay helps prevent Metasrv from triggering unnecessary region failovers before all Datanodes are fully started.<br/>Especially useful when the cluster is not deployed with GreptimeDB Operator and maintenance mode is not enabled. |
 | `allow_region_failover_on_local_wal` | Bool | `false` | Whether to allow region failover on local WAL.<br/>**This option is not recommended to be set to true, because it may lead to data loss during failover.** |
 | `node_max_idle_time` | String | `24hours` | Max allowed idle time before removing node info from metasrv memory. |
+| `heartbeat_interval` | String | `3s` | Base heartbeat interval for calculating distributed time constants.<br/>The frontend heartbeat interval is 6 times of the base heartbeat interval.<br/>The flownode/datanode heartbeat interval is 1 times of the base heartbeat interval.<br/>e.g., If the base heartbeat interval is 3s, the frontend heartbeat interval is 18s, the flownode/datanode heartbeat interval is 3s.<br/>If you change this value, you need to change the heartbeat interval of the flownode/frontend/datanode accordingly. |
 | `enable_telemetry` | Bool | `true` | Whether to enable greptimedb telemetry. Enabled by default. |
 | `runtime` | -- | -- | The runtime options. |
 | `runtime.global_rt_size` | Integer | `8` | The number of threads to execute the runtime for global read operations. |

@@ -361,12 +362,18 @@
 | `backend_tls.cert_path` | String | `""` | Path to client certificate file (for client authentication)<br/>Like "/path/to/client.crt" |
 | `backend_tls.key_path` | String | `""` | Path to client private key file (for client authentication)<br/>Like "/path/to/client.key" |
 | `backend_tls.ca_cert_path` | String | `""` | Path to CA certificate file (for server certificate verification)<br/>Required when using custom CAs or self-signed certificates<br/>Leave empty to use system root certificates only<br/>Like "/path/to/ca.crt" |
+| `backend_client` | -- | -- | The backend client options.<br/>Currently, only applicable when using etcd as the metadata store. |
+| `backend_client.keep_alive_timeout` | String | `3s` | The keep alive timeout for backend client. |
+| `backend_client.keep_alive_interval` | String | `10s` | The keep alive interval for backend client. |
+| `backend_client.connect_timeout` | String | `3s` | The connect timeout for backend client. |
 | `grpc` | -- | -- | The gRPC server options. |
 | `grpc.bind_addr` | String | `127.0.0.1:3002` | The address to bind the gRPC server. |
 | `grpc.server_addr` | String | `127.0.0.1:3002` | The communication server address for the frontend and datanode to connect to metasrv.<br/>If left empty or unset, the server will automatically use the IP address of the first network interface<br/>on the host, with the same port number as the one specified in `bind_addr`. |
 | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. |
 | `grpc.max_recv_message_size` | String | `512MB` | The maximum receive message size for gRPC server. |
 | `grpc.max_send_message_size` | String | `512MB` | The maximum send message size for gRPC server. |
+| `grpc.http2_keep_alive_interval` | String | `10s` | The server side HTTP/2 keep-alive interval |
+| `grpc.http2_keep_alive_timeout` | String | `3s` | The server side HTTP/2 keep-alive timeout. |
 | `http` | -- | -- | The HTTP server options. |
 | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. |
 | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. |
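If `auto_create_schema` is turned off, or the metasrv's PostgreSQL user lacks `CREATE SCHEMA`, the schema named by `meta_schema_name` must exist before startup. A hedged sketch (the connection string and role name are placeholders):

# Sketch: pre-create the metadata schema and grant the metasrv user access.
psql "postgres://admin:secret@pg.example.com:5432/metadb" <<'EOF'
CREATE SCHEMA IF NOT EXISTS greptime_schema;
GRANT USAGE, CREATE ON SCHEMA greptime_schema TO greptime_meta_user;
EOF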
@@ -457,7 +464,6 @@
 | `meta_client` | -- | -- | The metasrv client options. |
 | `meta_client.metasrv_addrs` | Array | -- | The addresses of the metasrv. |
 | `meta_client.timeout` | String | `3s` | Operation timeout. |
-| `meta_client.heartbeat_timeout` | String | `500ms` | Heartbeat timeout. |
 | `meta_client.ddl_timeout` | String | `10s` | DDL timeout. |
 | `meta_client.connect_timeout` | String | `1s` | Connect server timeout. |
 | `meta_client.tcp_nodelay` | Bool | `true` | `TCP_NODELAY` option for accepted connections. |

@@ -477,6 +483,8 @@
 | `wal.sync_period` | String | `10s` | Duration for fsyncing log files.<br/>**It's only used when the provider is `raft_engine`**. |
 | `wal.recovery_parallelism` | Integer | `2` | Parallelism during WAL recovery. |
 | `wal.broker_endpoints` | Array | -- | The Kafka broker endpoints.<br/>**It's only used when the provider is `kafka`**. |
+| `wal.connect_timeout` | String | `3s` | The connect timeout for kafka client.<br/>**It's only used when the provider is `kafka`**. |
+| `wal.timeout` | String | `3s` | The timeout for kafka client.<br/>**It's only used when the provider is `kafka`**. |
 | `wal.max_batch_bytes` | String | `1MB` | The max size of a single producer batch.<br/>Warning: Kafka has a default limit of 1MB per message in a topic.<br/>**It's only used when the provider is `kafka`**. |
 | `wal.consumer_wait_timeout` | String | `100ms` | The consumer wait timeout.<br/>**It's only used when the provider is `kafka`**. |
 | `wal.create_index` | Bool | `true` | Whether to enable WAL index creation.<br/>**It's only used when the provider is `kafka`**. |

@@ -488,9 +496,6 @@
 | `storage` | -- | -- | The data storage options. |
 | `storage.data_home` | String | `./greptimedb_data` | The working home directory. |
 | `storage.type` | String | `File` | The storage type used to store the data.<br/>- `File`: the data is stored in the local file system.<br/>- `S3`: the data is stored in the S3 object storage.<br/>- `Gcs`: the data is stored in the Google Cloud Storage.<br/>- `Azblob`: the data is stored in the Azure Blob Storage.<br/>- `Oss`: the data is stored in the Aliyun OSS. |
-| `storage.cache_path` | String | Unset | Read cache configuration for object storage such as 'S3' etc, it's configured by default when using object storage. It is recommended to configure it when using object storage for better performance.<br/>A local file directory, defaults to `{data_home}`. An empty string means disabling. |
-| `storage.enable_read_cache` | Bool | `true` | Whether to enable read cache. If not set, the read cache will be enabled by default when using object storage. |
-| `storage.cache_capacity` | String | Unset | The local file cache capacity in bytes. If your disk space is sufficient, it is recommended to set it larger. |
 | `storage.bucket` | String | Unset | The S3 bucket name.<br/>**It's only used when the storage type is `S3`, `Oss` and `Gcs`**. |
 | `storage.root` | String | Unset | The S3 data will be stored in the specified prefix, for example, `s3://${bucket}/${root}`.<br/>**It's only used when the storage type is `S3`, `Oss` and `Azblob`**. |
 | `storage.access_key_id` | String | Unset | The access key id of the aws account.<br/>It's **highly recommended** to use AWS IAM roles instead of hardcoding the access key id and secret key.<br/>**It's only used when the storage type is `S3` and `Oss`**. |

@@ -523,6 +528,8 @@
 | `region_engine.mito.max_background_flushes` | Integer | Auto | Max number of running background flush jobs (default: 1/2 of cpu cores). |
 | `region_engine.mito.max_background_compactions` | Integer | Auto | Max number of running background compaction jobs (default: 1/4 of cpu cores). |
 | `region_engine.mito.max_background_purges` | Integer | Auto | Max number of running background purge jobs (default: number of cpu cores). |
+| `region_engine.mito.experimental_compaction_memory_limit` | String | 0 | Memory budget for compaction tasks. Setting it to 0 or "unlimited" disables the limit. |
+| `region_engine.mito.experimental_compaction_on_exhausted` | String | wait | Behavior when compaction cannot acquire memory from the budget.<br/>Options: "wait" (default, 10s), "wait(<duration>)", "fail" |
 | `region_engine.mito.auto_flush_interval` | String | `1h` | Interval to auto flush a region if it has not flushed yet. |
 | `region_engine.mito.global_write_buffer_size` | String | Auto | Global write buffer size for all regions. If not set, it's default to 1/8 of OS memory with a max limitation of 1GB. |
 | `region_engine.mito.global_write_buffer_reject_size` | String | Auto | Global write buffer size threshold to reject write requests. If not set, it's default to 2 times of `global_write_buffer_size` |

@@ -536,6 +543,8 @@
 | `region_engine.mito.write_cache_ttl` | String | Unset | TTL for write cache. |
 | `region_engine.mito.preload_index_cache` | Bool | `true` | Preload index (puffin) files into cache on region open (default: true).<br/>When enabled, index files are loaded into the write cache during region initialization,<br/>which can improve query performance at the cost of longer startup times. |
 | `region_engine.mito.index_cache_percent` | Integer | `20` | Percentage of write cache capacity allocated for index (puffin) files (default: 20).<br/>The remaining capacity is used for data (parquet) files.<br/>Must be between 0 and 100 (exclusive). For example, with a 5GiB write cache and 20% allocation,<br/>1GiB is reserved for index files and 4GiB for data files. |
+| `region_engine.mito.enable_refill_cache_on_read` | Bool | `true` | Enable refilling cache on read operations (default: true).<br/>When disabled, cache refilling on read won't happen. |
+| `region_engine.mito.manifest_cache_size` | String | `256MB` | Capacity for manifest cache (default: 256MB). |
 | `region_engine.mito.sst_write_buffer_size` | String | `8MB` | Buffer size for SST writing. |
 | `region_engine.mito.parallel_scan_channel_size` | Integer | `32` | Capacity of the channel to send data from parallel scan tasks to the main task. |
 | `region_engine.mito.max_concurrent_scan_files` | Integer | `384` | Maximum number of SST files to scan concurrently. |

@@ -629,7 +638,6 @@
 | `meta_client` | -- | -- | The metasrv client options. |
 | `meta_client.metasrv_addrs` | Array | -- | The addresses of the metasrv. |
 | `meta_client.timeout` | String | `3s` | Operation timeout. |
-| `meta_client.heartbeat_timeout` | String | `500ms` | Heartbeat timeout. |
 | `meta_client.ddl_timeout` | String | `10s` | DDL timeout. |
 | `meta_client.connect_timeout` | String | `1s` | Connect server timeout. |
 | `meta_client.tcp_nodelay` | Bool | `true` | `TCP_NODELAY` option for accepted connections. |
@@ -99,9 +99,6 @@ metasrv_addrs = ["127.0.0.1:3002"]
 ## Operation timeout.
 timeout = "3s"

-## Heartbeat timeout.
-heartbeat_timeout = "500ms"
-
 ## DDL timeout.
 ddl_timeout = "10s"

@@ -172,6 +169,14 @@ recovery_parallelism = 2
 ## **It's only used when the provider is `kafka`**.
 broker_endpoints = ["127.0.0.1:9092"]

+## The connect timeout for kafka client.
+## **It's only used when the provider is `kafka`**.
+#+ connect_timeout = "3s"
+
+## The timeout for kafka client.
+## **It's only used when the provider is `kafka`**.
+#+ timeout = "3s"
+
 ## The max size of a single producer batch.
 ## Warning: Kafka has a default limit of 1MB per message in a topic.
 ## **It's only used when the provider is `kafka`**.

@@ -228,6 +233,7 @@ overwrite_entry_start_id = false
 # endpoint = "https://s3.amazonaws.com"
 # region = "us-west-2"
 # enable_virtual_host_style = false
+# disable_ec2_metadata = false

 # Example of using Oss as the storage.
 # [storage]
@@ -284,18 +290,6 @@ data_home = "./greptimedb_data"
 ## - `Oss`: the data is stored in the Aliyun OSS.
 type = "File"

-## Read cache configuration for object storage such as 'S3' etc, it's configured by default when using object storage. It is recommended to configure it when using object storage for better performance.
-## A local file directory, defaults to `{data_home}`. An empty string means disabling.
-## @toml2docs:none-default
-#+ cache_path = ""
-
-## Whether to enable read cache. If not set, the read cache will be enabled by default when using object storage.
-#+ enable_read_cache = true
-
-## The local file cache capacity in bytes. If your disk space is sufficient, it is recommended to set it larger.
-## @toml2docs:none-default
-cache_capacity = "5GiB"
-
 ## The S3 bucket name.
 ## **It's only used when the storage type is `S3`, `Oss` and `Gcs`**.
 ## @toml2docs:none-default

@@ -455,6 +449,15 @@ compress_manifest = false
 ## @toml2docs:none-default="Auto"
 #+ max_background_purges = 8

+## Memory budget for compaction tasks. Setting it to 0 or "unlimited" disables the limit.
+## @toml2docs:none-default="0"
+#+ experimental_compaction_memory_limit = "0"
+
+## Behavior when compaction cannot acquire memory from the budget.
+## Options: "wait" (default, 10s), "wait(<duration>)", "fail"
+## @toml2docs:none-default="wait"
+#+ experimental_compaction_on_exhausted = "wait"
+
 ## Interval to auto flush a region if it has not flushed yet.
 auto_flush_interval = "1h"

@@ -510,6 +513,13 @@ preload_index_cache = true
 ## 1GiB is reserved for index files and 4GiB for data files.
 index_cache_percent = 20

+## Enable refilling cache on read operations (default: true).
+## When disabled, cache refilling on read won't happen.
+enable_refill_cache_on_read = true
+
+## Capacity for manifest cache (default: 256MB).
+manifest_cache_size = "256MB"
+
 ## Buffer size for SST writing.
 sst_write_buffer_size = "8MB"
@@ -78,9 +78,6 @@ metasrv_addrs = ["127.0.0.1:3002"]
 ## Operation timeout.
 timeout = "3s"

-## Heartbeat timeout.
-heartbeat_timeout = "500ms"
-
 ## DDL timeout.
 ddl_timeout = "10s"

@@ -6,9 +6,15 @@ default_timezone = "UTC"
 ## @toml2docs:none-default
 default_column_prefix = "greptime"

-## The maximum in-flight write bytes.
+## Maximum total memory for all concurrent write request bodies and messages (HTTP, gRPC, Flight).
+## Set to 0 to disable the limit. Default: "0" (unlimited)
 ## @toml2docs:none-default
-#+ max_in_flight_write_bytes = "500MB"
+#+ max_in_flight_write_bytes = "1GB"
+
+## Policy when write bytes quota is exhausted.
+## Options: "wait" (default, 10s timeout), "wait(<duration>)" (e.g., "wait(30s)"), "fail"
+## @toml2docs:none-default
+#+ write_bytes_exhausted_policy = "wait"

 ## The runtime options.
 #+ [runtime]
@@ -35,10 +41,6 @@ timeout = "0s"
 ## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
 ## Set to 0 to disable limit.
 body_limit = "64MB"
-## Maximum total memory for all concurrent HTTP request bodies.
-## Set to 0 to disable the limit. Default: "0" (unlimited)
-## @toml2docs:none-default
-#+ max_total_body_memory = "1GB"
 ## HTTP CORS support, it's turned on by default
 ## This allows browser to access http APIs without CORS restrictions
 enable_cors = true

@@ -62,10 +64,6 @@ bind_addr = "127.0.0.1:4001"
 server_addr = "127.0.0.1:4001"
 ## The number of server worker threads.
 runtime_size = 8
-## Maximum total memory for all concurrent gRPC request messages.
-## Set to 0 to disable the limit. Default: "0" (unlimited)
-## @toml2docs:none-default
-#+ max_total_message_memory = "1GB"
 ## Compression mode for frontend side Arrow IPC service. Available options:
 ## - `none`: disable all compression
 ## - `transport`: only enable gRPC transport compression (zstd)

@@ -131,7 +129,6 @@ key_path = ""
 ## For now, gRPC tls config does not support auto reload.
 watch = false

-
 ## MySQL server options.
 [mysql]
 ## Whether to enable.

@@ -226,9 +223,6 @@ metasrv_addrs = ["127.0.0.1:3002"]
 ## Operation timeout.
 timeout = "3s"

-## Heartbeat timeout.
-heartbeat_timeout = "500ms"
-
 ## DDL timeout.
 ddl_timeout = "10s"
@@ -34,11 +34,18 @@ meta_table_name = "greptime_metakv"
|
||||
## Optional PostgreSQL schema for metadata table and election table name qualification.
|
||||
## When PostgreSQL public schema is not writable (e.g., PostgreSQL 15+ with restricted public),
|
||||
## set this to a writable schema. GreptimeDB will use `meta_schema_name`.`meta_table_name`.
|
||||
## GreptimeDB creates the schema automatically only when `auto_create_schema` is enabled (see below); otherwise, ensure the schema exists or that the user has permission to create it.
|
||||
## **Only used when backend is `postgres_store`.**
|
||||
|
||||
meta_schema_name = "greptime_schema"
|
||||
|
||||
## Automatically create PostgreSQL schema if it doesn't exist.
|
||||
## When enabled, the system will execute `CREATE SCHEMA IF NOT EXISTS <schema_name>`
|
||||
## before creating metadata tables. This is useful in production environments where
|
||||
## manual schema creation may be restricted.
|
||||
## Default is true.
|
||||
## Note: The PostgreSQL user must have CREATE SCHEMA permission for this to work.
|
||||
## **Only used when backend is `postgres_store`.**
|
||||
auto_create_schema = true
|
||||
|
||||
## Advisory lock id in PostgreSQL for election. Effective when using PostgreSQL as the kvbackend.
|
||||
## Only used when backend is `postgres_store`.
|
||||
meta_election_lock_id = 1
|
||||
@@ -50,9 +57,6 @@ meta_election_lock_id = 1
|
||||
## For details, please see "https://docs.greptime.com/developer-guide/metasrv/selector".
|
||||
selector = "round_robin"
|
||||
|
||||
## Store data in memory.
|
||||
use_memory_store = false
|
||||
|
||||
## Whether to enable region failover.
|
||||
## This feature is only available on GreptimeDB running on cluster mode and
|
||||
## - Using Remote WAL
|
||||
@@ -71,6 +75,13 @@ allow_region_failover_on_local_wal = false
|
||||
## Max allowed idle time before removing node info from metasrv memory.
|
||||
node_max_idle_time = "24hours"
|
||||
|
||||
## Base heartbeat interval for calculating distributed time constants.
|
||||
## The frontend heartbeat interval is 6 times the base heartbeat interval.
|
||||
## The flownode/datanode heartbeat interval is the same as the base heartbeat interval.
|
||||
## e.g., if the base heartbeat interval is 3s, the frontend heartbeat interval is 18s and the flownode/datanode heartbeat interval is 3s.
|
||||
## If you change this value, you need to change the heartbeat interval of the flownode/frontend/datanode accordingly.
|
||||
#+ heartbeat_interval = "3s"
|
||||
|
||||
## Whether to enable greptimedb telemetry. Enabled by default.
|
||||
#+ enable_telemetry = true
|
||||
|
||||
@@ -109,6 +120,16 @@ key_path = ""
|
||||
## Like "/path/to/ca.crt"
|
||||
ca_cert_path = ""
|
||||
|
||||
## The backend client options.
|
||||
## Currently, only applicable when using etcd as the metadata store.
|
||||
#+ [backend_client]
|
||||
## The keep alive timeout for backend client.
|
||||
#+ keep_alive_timeout = "3s"
|
||||
## The keep alive interval for backend client.
|
||||
#+ keep_alive_interval = "10s"
|
||||
## The connect timeout for backend client.
|
||||
#+ connect_timeout = "3s"
|
||||
|
||||
## The gRPC server options.
|
||||
[grpc]
|
||||
## The address to bind the gRPC server.
|
||||
@@ -123,6 +144,10 @@ runtime_size = 8
|
||||
max_recv_message_size = "512MB"
|
||||
## The maximum send message size for gRPC server.
|
||||
max_send_message_size = "512MB"
|
||||
## The server side HTTP/2 keep-alive interval
|
||||
#+ http2_keep_alive_interval = "10s"
|
||||
## The server side HTTP/2 keep-alive timeout.
|
||||
#+ http2_keep_alive_timeout = "3s"
|
||||
|
||||
## The HTTP server options.
|
||||
[http]
|
||||
|
||||
@@ -6,6 +6,16 @@ default_timezone = "UTC"
|
||||
## @toml2docs:none-default
|
||||
default_column_prefix = "greptime"
|
||||
|
||||
## Maximum total memory for all concurrent write request bodies and messages (HTTP, gRPC, Flight).
|
||||
## Set to 0 to disable the limit. Default: "0" (unlimited)
|
||||
## @toml2docs:none-default
|
||||
#+ max_in_flight_write_bytes = "1GB"
|
||||
|
||||
## Policy when write bytes quota is exhausted.
|
||||
## Options: "wait" (default, 10s timeout), "wait(<duration>)" (e.g., "wait(30s)"), "fail"
|
||||
## @toml2docs:none-default
|
||||
#+ write_bytes_exhausted_policy = "wait"
|
||||
|
||||
## Initialize all regions in the background during the startup.
|
||||
## By default, it provides services after all regions have been initialized.
|
||||
init_regions_in_background = false
|
||||
@@ -22,10 +32,6 @@ max_concurrent_queries = 0
|
||||
## Enable telemetry to collect anonymous usage data. Enabled by default.
|
||||
#+ enable_telemetry = true
|
||||
|
||||
## The maximum in-flight write bytes.
|
||||
## @toml2docs:none-default
|
||||
#+ max_in_flight_write_bytes = "500MB"
|
||||
|
||||
## The runtime options.
|
||||
#+ [runtime]
|
||||
## The number of threads to execute the runtime for global read operations.
|
||||
@@ -43,10 +49,6 @@ timeout = "0s"
|
||||
## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
|
||||
## Set to 0 to disable limit.
|
||||
body_limit = "64MB"
|
||||
## Maximum total memory for all concurrent HTTP request bodies.
|
||||
## Set to 0 to disable the limit. Default: "0" (unlimited)
|
||||
## @toml2docs:none-default
|
||||
#+ max_total_body_memory = "1GB"
|
||||
## HTTP CORS support, it's turned on by default
|
||||
## This allows browsers to access HTTP APIs without CORS restrictions
|
||||
enable_cors = true
|
||||
@@ -67,10 +69,6 @@ prom_validation_mode = "strict"
|
||||
bind_addr = "127.0.0.1:4001"
|
||||
## The number of server worker threads.
|
||||
runtime_size = 8
|
||||
## Maximum total memory for all concurrent gRPC request messages.
|
||||
## Set to 0 to disable the limit. Default: "0" (unlimited)
|
||||
## @toml2docs:none-default
|
||||
#+ max_total_message_memory = "1GB"
|
||||
## The maximum connection age for gRPC connection.
|
||||
## The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
|
||||
## Refer to https://grpc.io/docs/guides/keepalive/ for more details.
|
||||
@@ -230,6 +228,14 @@ recovery_parallelism = 2
|
||||
## **It's only used when the provider is `kafka`**.
|
||||
broker_endpoints = ["127.0.0.1:9092"]
|
||||
|
||||
## The connect timeout for kafka client.
|
||||
## **It's only used when the provider is `kafka`**.
|
||||
#+ connect_timeout = "3s"
|
||||
|
||||
## The timeout for kafka client.
|
||||
## **It's only used when the provider is `kafka`**.
|
||||
#+ timeout = "3s"
|
||||
|
||||
## Automatically create topics for WAL.
|
||||
## Set to `true` to automatically create topics for WAL.
|
||||
## Otherwise, use topics named `topic_name_prefix_[0..num_topics)`
|
||||
@@ -332,6 +338,7 @@ max_running_procedures = 128
|
||||
# endpoint = "https://s3.amazonaws.com"
|
||||
# region = "us-west-2"
|
||||
# enable_virtual_host_style = false
|
||||
# disable_ec2_metadata = false
|
||||
|
||||
# Example of using Oss as the storage.
|
||||
# [storage]
|
||||
@@ -388,18 +395,6 @@ data_home = "./greptimedb_data"
|
||||
## - `Oss`: the data is stored in the Aliyun OSS.
|
||||
type = "File"
|
||||
|
||||
## Whether to enable read cache. If not set, the read cache will be enabled by default when using object storage.
|
||||
#+ enable_read_cache = true
|
||||
|
||||
## Read cache configuration for object storage such as 'S3'. It is configured by default when using object storage, and tuning it is recommended for better performance.
|
||||
## A local file directory, defaults to `{data_home}`. An empty string means disabling.
|
||||
## @toml2docs:none-default
|
||||
#+ cache_path = ""
|
||||
|
||||
## The local file cache capacity in bytes. If your disk space is sufficient, it is recommended to set it larger.
|
||||
## @toml2docs:none-default
|
||||
cache_capacity = "5GiB"
|
||||
|
||||
## The S3 bucket name.
|
||||
## **It's only used when the storage type is `S3`, `Oss` or `Gcs`**.
|
||||
## @toml2docs:none-default
|
||||
@@ -546,6 +541,15 @@ compress_manifest = false
|
||||
## @toml2docs:none-default="Auto"
|
||||
#+ max_background_purges = 8
|
||||
|
||||
## Memory budget for compaction tasks. Setting it to 0 or "unlimited" disables the limit.
|
||||
## @toml2docs:none-default="0"
|
||||
#+ experimental_compaction_memory_limit = "0"
|
||||
|
||||
## Behavior when compaction cannot acquire memory from the budget.
|
||||
## Options: "wait" (default, 10s), "wait(<duration>)", "fail"
|
||||
## @toml2docs:none-default="wait"
|
||||
#+ experimental_compaction_on_exhausted = "wait"
|
||||
|
||||
## Interval to auto flush a region if it has not flushed yet.
|
||||
auto_flush_interval = "1h"
|
||||
|
||||
@@ -601,6 +605,13 @@ preload_index_cache = true
|
||||
## 1GiB is reserved for index files and 4GiB for data files.
|
||||
index_cache_percent = 20
|
||||
|
||||
## Enable refilling cache on read operations (default: true).
|
||||
## When disabled, cache refilling on read won't happen.
|
||||
enable_refill_cache_on_read = true
|
||||
|
||||
## Capacity for manifest cache (default: 256MB).
|
||||
manifest_cache_size = "256MB"
|
||||
|
||||
## Buffer size for SST writing.
|
||||
sst_write_buffer_size = "8MB"
|
||||
|
||||
|
||||
94
docs/rfcs/2025-12-05-vector-index.md
Normal file
@@ -0,0 +1,94 @@
|
||||
---
|
||||
Feature Name: Vector Index
|
||||
Tracking Issue: TBD
|
||||
Date: 2025-12-04
|
||||
Author: "TBD"
|
||||
---
|
||||
|
||||
# Summary
|
||||
Introduce a per-SST approximate nearest neighbor (ANN) index for `VECTOR(dim)` columns with a pluggable engine. USearch HNSW is the initial engine, while the design keeps VSAG (default when linked) and future engines selectable at DDL or alter time and encoded in the index metadata. The index is built alongside SST creation and accelerates `ORDER BY vec_*_distance(column, <literal vector>) LIMIT k` queries, falling back to the existing brute-force path when an index is unavailable or ineligible.
|
||||
|
||||
# Motivation
|
||||
Vector distances are currently computed with nalgebra across all rows (O(N)) before sorting, which does not scale to millions of vectors. An on-disk ANN index with sub-linear search reduces latency and compute cost for common RAG, semantic search, and recommendation workloads without changing SQL.
|
||||
|
||||
# Details
|
||||
|
||||
## Current Behavior
|
||||
`VECTOR(dim)` values are stored as binary blobs. Queries call `vec_cos_distance`/`vec_l2sq_distance`/`vec_dot_product` via nalgebra for every row and then sort; there is no indexing or caching.
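
To make the cost concrete, here is a minimal, self-contained Rust sketch of this brute-force path (plain slices and a hand-rolled cosine distance stand in for GreptimeDB's vector types and nalgebra): every row's distance is computed before the top-k rows can be chosen, so the work is O(N) regardless of `k`.

```rust
use std::cmp::Ordering;

fn cos_distance(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    1.0 - dot / (na * nb)
}

/// Returns the indices of the `k` rows closest to `query`, scanning every row.
fn brute_force_top_k(rows: &[Vec<f32>], query: &[f32], k: usize) -> Vec<usize> {
    let mut scored: Vec<(f32, usize)> = rows
        .iter()
        .enumerate()
        .map(|(i, row)| (cos_distance(row, query), i))
        .collect();
    // NaN distances (e.g. zero-norm vectors) sort last instead of panicking.
    scored.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Greater));
    scored.into_iter().take(k).map(|(_, i)| i).collect()
}
```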
|
||||
|
||||
## Index Eligibility and Configuration
|
||||
Only `VECTOR(dim)` columns can be indexed. A column metadata flag follows the existing column-option pattern with an intentionally small surface area:
|
||||
- `engine`: `vsag` (default when the binding is built) or `usearch`. If a configured engine is unavailable at runtime, the builder logs a warning and falls back to `usearch`, leaving the option intact for future rebuilds.
|
||||
- `metric`: `cosine` (default), `l2sq`, or `dot`; mismatches with query functions force brute-force execution.
|
||||
- `m`: HNSW graph connectivity (higher = denser graph, more memory, better recall), default `16`.
|
||||
- `ef_construct`: build-time expansion, default `128`.
|
||||
- `ef_search`: query-time expansion, default `64`; engines may clamp values.
|
||||
|
||||
Option semantics mirror HNSW defaults so both USearch and VSAG can honor them; engine-specific tunables stay in reserved key-value pairs inside the blob header for forward compatibility.
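
For illustration, the option set above is small enough to fit in one struct. The sketch below is not the actual GreptimeDB definition; the type names and the `vsag` cargo feature used to pick the default engine are assumptions.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum VectorIndexEngine {
    Usearch,
    Vsag,
}

#[derive(Debug, Clone, PartialEq)]
pub enum VectorMetric {
    Cosine,
    L2Sq,
    Dot,
}

#[derive(Debug, Clone)]
pub struct VectorIndexOptions {
    pub engine: VectorIndexEngine,
    pub metric: VectorMetric,
    pub m: u32,            // HNSW graph connectivity
    pub ef_construct: u32, // build-time expansion
    pub ef_search: u32,    // query-time expansion
}

impl Default for VectorIndexOptions {
    fn default() -> Self {
        Self {
            // `vsag` is preferred only when its binding is compiled in (hypothetical feature name).
            engine: if cfg!(feature = "vsag") {
                VectorIndexEngine::Vsag
            } else {
                VectorIndexEngine::Usearch
            },
            metric: VectorMetric::Cosine,
            m: 16,
            ef_construct: 128,
            ef_search: 64,
        }
    }
}
```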
|
||||
|
||||
DDL reuses column extensions similar to inverted/fulltext indexes:
|
||||
|
||||
```sql
|
||||
CREATE TABLE embeddings (
|
||||
ts TIMESTAMP TIME INDEX,
|
||||
id STRING PRIMARY KEY,
|
||||
vec VECTOR(384) VECTOR INDEX WITH (engine = 'vsag', metric = 'cosine', ef_search = 64)
|
||||
);
|
||||
```
|
||||
|
||||
Altering column options toggles the flag, can switch engines (for example `usearch` -> `vsag`), and triggers rebuilds through the existing alter/compaction flow. Engine choice stays in table metadata and each blob header; new SSTs use the configured engine while older SSTs remain readable under their recorded engine until compaction or a manual rebuild rewrites them.
|
||||
|
||||
## Storage and Format
|
||||
- One vector index per indexed column per SST, stored as a Puffin blob with type `greptime-vector-index-v1`.
|
||||
- Each blob records the engine (`usearch`, `vsag`, future values) and engine parameters in the header so readers can select the matching decoder. Mixed-engine SSTs remain readable because the engine id travels with the blob.
|
||||
- USearch uses `f32` vectors and SST row offsets (`u64`) as keys; nulls and `OpType::Delete` rows are skipped. Row ids are absolute SST ordinals, so readers can derive `RowSelection` directly from parquet row group lengths without extra side tables.
|
||||
- Blob layout (sketched in code after this list):
|
||||
- Header: version, column id, dimension, engine id, metric, `m`, `ef_construct`, `ef_search`, and reserved engine-specific key-value pairs.
|
||||
- Counts: total rows written and indexed rows.
|
||||
- Payload: USearch binary produced by `save_to_buffer`.
|
||||
- An empty index (no eligible vectors) results in no available index entry for that column.
|
||||
- `puffin_manager` registers the blob type so caches and readers discover it alongside inverted/fulltext/bloom blobs in the same index file.
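
The header portion of that layout can be sketched as follows; the field names and their Rust representation are illustrative, and the concrete byte encoding inside the `greptime-vector-index-v1` blob is left to the implementation.

```rust
/// Illustrative header fields; the engine payload (e.g. USearch's `save_to_buffer`
/// output) follows the header inside the Puffin blob.
pub struct VectorIndexBlobHeader {
    pub format_version: u32,
    pub column_id: u32,
    pub dimension: u32,
    pub engine: String,                        // "usearch", "vsag", future values
    pub metric: String,                        // "cosine", "l2sq", "dot"
    pub m: u32,
    pub ef_construct: u32,
    pub ef_search: u32,
    pub engine_options: Vec<(String, String)>, // reserved engine-specific key-values
    pub total_rows: u64,                       // rows written to the SST
    pub indexed_rows: u64,                     // rows actually inserted into the graph
}
```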
|
||||
|
||||
## Row Visibility and Duplicates
|
||||
- The indexer increments `row_offset` for every incoming row (including skipped/null/delete rows) so offsets stay aligned with parquet ordering across row groups (see the sketch after this list).
|
||||
- Only `OpType::Put` rows with the expected dimension are inserted; `OpType::Delete` and malformed rows are skipped but still advance `row_offset`, matching the data plane’s visibility rules.
|
||||
- Multiple versions of the same primary key remain in the graph; the read path intersects search hits with the standard mito2 deduplication/visibility pipeline (sequence-aware dedup, delete filtering, projection) before returning results.
|
||||
- Searches overfetch beyond `k` to compensate for rows discarded by visibility checks and to avoid reissuing index reads.
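
The offset bookkeeping can be summarized with a small sketch. `OpType`, the batch shape, and the builder fields are simplified stand-ins for the mito2 types, and a real builder would feed inserts into the ANN engine rather than a `Vec`.

```rust
enum OpType {
    Put,
    Delete,
}

struct VectorIndexBuilder {
    dim: usize,
    row_offset: u64, // absolute ordinal within the SST, across row groups
    total_rows: u64,
    indexed_rows: u64,
    inserted: Vec<(u64, Vec<f32>)>, // stand-in for calls into the ANN engine
}

impl VectorIndexBuilder {
    fn update(&mut self, rows: &[(OpType, Option<Vec<f32>>)]) {
        for (op, vector) in rows {
            match (op, vector) {
                // Only Puts with the declared dimension are indexed.
                (OpType::Put, Some(v)) if v.len() == self.dim => {
                    self.inserted.push((self.row_offset, v.clone()));
                    self.indexed_rows += 1;
                }
                // Deletes, nulls, and dimension mismatches are skipped...
                _ => {}
            }
            // ...but every row still advances the offset so it stays aligned
            // with parquet ordering.
            self.total_rows += 1;
            self.row_offset += 1;
        }
    }
}
```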
|
||||
|
||||
## Build Path (mito2 write)
|
||||
Extend `sst::index::Indexer` to optionally create a `VectorIndexer` when region metadata marks a column as vector-indexed, mirroring how inverted/fulltext/bloom filters attach to `IndexerBuilderImpl` in `mito2`.
|
||||
|
||||
The indexer consumes `Batch`/`RecordBatch` data and shares memory tracking and abort semantics with existing indexers:
|
||||
- Maintain a running `row_offset` that follows SST write order and spans row groups so the search result can be turned into `RowSelection`.
|
||||
- For each `OpType::Put`, if the vector is non-null and matches the declared dimension, insert into USearch with `row_offset` as the key; otherwise skip.
|
||||
- Track memory with existing index build metrics; on failure, abort only the vector index while keeping SST writing unaffected.
|
||||
|
||||
Engine selection is table-driven: the builder picks the configured engine (default `vsag`, fallback `usearch` if `vsag` is not compiled in) and dispatches to the matching implementation. Unknown engines skip index build with a warning.
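
A hedged sketch of that dispatch is shown below. The `AnnEngine` trait and `UsearchEngine` stub are illustrative names, and the stub performs a linear scan only so the example compiles without the USearch or VSAG bindings; the real implementations would wrap the respective libraries.

```rust
pub trait AnnEngine {
    /// Insert one vector keyed by its absolute SST row offset.
    fn add(&mut self, row_offset: u64, vector: &[f32]);
    /// Return up to `k` (row_offset, distance) pairs for the query vector.
    fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)>;
}

/// Stand-in for the USearch-backed engine; a linear scan keeps the sketch self-contained.
#[derive(Default)]
struct UsearchEngine {
    rows: Vec<(u64, Vec<f32>)>,
}

impl AnnEngine for UsearchEngine {
    fn add(&mut self, row_offset: u64, vector: &[f32]) {
        self.rows.push((row_offset, vector.to_vec()));
    }

    fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
        let mut scored: Vec<(u64, f32)> = self
            .rows
            .iter()
            .map(|(key, v)| {
                let d: f32 = v.iter().zip(query).map(|(a, b)| (a - b) * (a - b)).sum();
                (*key, d)
            })
            .collect();
        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
        scored.truncate(k);
        scored
    }
}

/// Table-driven dispatch: both names resolve to the stub here; in the real builder,
/// "vsag" would use its own binding when compiled in, and unknown names skip the build.
fn create_engine(name: &str) -> Option<Box<dyn AnnEngine>> {
    match name {
        "usearch" | "vsag" => Some(Box::new(UsearchEngine::default())),
        _ => None,
    }
}
```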
|
||||
|
||||
On `finish`, serialize the engine-tagged index into the Puffin writer and record `IndexType::Vector` metadata for the column. `IndexOutput` and `FileMeta::indexes/available_indexes` gain a vector entry so manifest updates and `RegionVersion` surface per-column presence, following patterns used by inverted/fulltext/bloom indexes. Planner/metadata validation ensures that mismatched dimensions only reduce the indexed-row count and do not break reads.
|
||||
|
||||
## Read Path (mito2 query)
|
||||
A planner rule in `query` identifies eligible plans on mito2 tables: a single `ORDER BY vec_cos_distance|vec_l2sq_distance|vec_dot_product(<vector column>, <literal vector>)` in ascending order plus a `LIMIT`/`TopK`. The rule rejects plans with multiple sort keys, non-literal query vectors, or additional projections that would change the distance expression and falls back to brute-force in those cases.
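
The rule amounts to a predicate over the plan shape. The sketch below uses simplified stand-in types (`PlanShape`, `SortKey`) rather than real DataFusion plan nodes, but captures the conditions described above.

```rust
struct SortKey {
    func: String,           // e.g. "vec_cos_distance"
    ascending: bool,
    query_is_literal: bool, // the second argument must be a literal vector
}

struct PlanShape {
    sort_keys: Vec<SortKey>,
    limit: Option<usize>,
}

fn is_vector_index_eligible(plan: &PlanShape) -> bool {
    const SUPPORTED: [&str; 3] = ["vec_cos_distance", "vec_l2sq_distance", "vec_dot_product"];
    match (plan.sort_keys.as_slice(), plan.limit) {
        // Exactly one ascending sort on a supported distance function with a literal
        // query vector, plus a LIMIT/TopK; anything else falls back to brute force.
        ([key], Some(_)) => {
            key.ascending && key.query_is_literal && SUPPORTED.contains(&key.func.as_str())
        }
        _ => false,
    }
}
```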
|
||||
|
||||
For eligible scans, build a `VectorIndexScan` execution node that:
|
||||
- Consults SST metadata for `IndexType::Vector`, loads the index via Puffin using the existing `mito2::cache::index` infrastructure, and dispatches to the engine declared in the blob header (USearch/VSAG/etc.).
|
||||
- Runs the engine’s `search` with an overfetch (for example 2×k) to tolerate rows filtered by deletes, dimension mismatches, or late-stage dedup; keys already match SST row offsets produced by the writer.
|
||||
- Converts hits to `RowSelection` using parquet row group lengths and reuses the parquet reader so visibility, projection, and deduplication logic stay unchanged; distances are recomputed with `vec_*_distance` before the final trim to k to guarantee ordering and to merge distributed partial results deterministically.
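
Turning sorted hit offsets into skip/select runs is mechanical. The sketch below mirrors the shape of parquet's row selectors with local types and assumes the hits are already sorted and deduplicated; adjacent selections could additionally be merged.

```rust
#[derive(Debug, PartialEq)]
enum Run {
    Skip(u64),
    Select(u64),
}

/// `hits` are absolute row offsets produced by the ANN search; `total_rows` is the
/// sum of the SST's row group lengths.
fn offsets_to_runs(hits: &[u64], total_rows: u64) -> Vec<Run> {
    let mut runs = Vec::new();
    let mut cursor = 0u64;
    for &hit in hits {
        if hit > cursor {
            runs.push(Run::Skip(hit - cursor));
        }
        runs.push(Run::Select(1));
        cursor = hit + 1;
    }
    if cursor < total_rows {
        runs.push(Run::Skip(total_rows - cursor));
    }
    runs
}
```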
|
||||
|
||||
Any unsupported shape, load error, or cache miss falls back to the current brute-force execution path.
|
||||
|
||||
## Lifecycle and Maintenance
|
||||
Lifecycle piggybacks on the existing SST/index flow: rebuilds run where other secondary indexes do, graphs are always rebuilt from source rows (no HNSW merge), and cleanup/versioning/caching reuse the existing Puffin and index cache paths.
|
||||
|
||||
# Implementation Plan
|
||||
1. Add the `usearch` dependency (wrapper module in `index` or `mito2`) and map minimal HNSW options; keep an engine trait that allows plugging VSAG without changing the rest of the pipeline.
|
||||
2. Introduce `IndexType::Vector` and a column metadata key for vector index options (including `engine`); add SQL parser and `SHOW CREATE TABLE` support for `VECTOR INDEX WITH (...)`.
|
||||
3. Implement `vector_index` build/read modules under `mito2` (and `index` if shared), including Puffin serialization that records engine id, blob-type registration with `puffin_manager`, and integration with the `Indexer` builder, `IndexOutput`, manifest updates, and compaction rebuild.
|
||||
4. Extend the query planner/execution to detect eligible plans and drive a `RowSelection`-based ANN scan with a fallback path, dispatching by engine at read time and using existing Puffin and index caches.
|
||||
5. Add unit tests for serialization/search correctness and an end-to-end test covering plan rewrite, cache usage, engine selection, and fallback; add a mixed-engine test to confirm old USearch blobs still serve after a VSAG switch.
|
||||
6. Follow up with an optional VSAG engine binding (feature flag), validate parity with USearch on dense vectors, exercise alternative algorithms (for example PQ), and flip the default `engine` to `vsag` when the binding is present.
|
||||
|
||||
# Alternatives
|
||||
- **VSAG (follow-up engine):** C++ library with HNSW and additional algorithms (for example SINDI for sparse vectors and PQ) targeting in-memory and disk-friendly search. Provides parameter generators and a roadmap for GPU-assisted build and graph compression. Compared to FAISS it is newer with fewer integrations but bundles sparse/dense coverage and out-of-core focus in one engine. Fits the pluggable-engine design and would become the default `engine = 'vsag'` when linked; USearch remains available for lighter dependencies.
|
||||
- **FAISS:** Broad feature set (IVF/IVFPQ/PQ/HNSW, GPU acceleration, scalar filtering, pre/post filters) and battle-tested performance across datasets, but it requires a heavier C++/GPU toolchain, has no official Rust binding, and is less disk-centric than VSAG; integrating it would add more build/distribution burden than USearch/VSAG.
|
||||
- **Do nothing:** Keep brute-force evaluation, which remains O(N) and unacceptable at scale.
|
||||
20
flake.lock
generated
@@ -8,11 +8,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1760078406,
|
||||
"narHash": "sha256-JeJK0ZA845PtkCHkfo4KjeI1mYrsr2s3cxBYKhF4BoE=",
|
||||
"lastModified": 1765252472,
|
||||
"narHash": "sha256-byMt/uMi7DJ8tRniFopDFZMO3leSjGp6GS4zWOFT+uQ=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "351277c60d104944122ee389cdf581c5ce2c6732",
|
||||
"rev": "8456b985f6652e3eef0632ee9992b439735c5544",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -41,16 +41,16 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1759994382,
|
||||
"narHash": "sha256-wSK+3UkalDZRVHGCRikZ//CyZUJWDJkBDTQX1+G77Ow=",
|
||||
"lastModified": 1764983851,
|
||||
"narHash": "sha256-y7RPKl/jJ/KAP/VKLMghMgXTlvNIJMHKskl8/Uuar7o=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5da4a26309e796daa7ffca72df93dbe53b8164c7",
|
||||
"rev": "d9bc5c7dceb30d8d6fafa10aeb6aa8a48c218454",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-25.05",
|
||||
"ref": "nixos-25.11",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
@@ -65,11 +65,11 @@
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1760014945,
|
||||
"narHash": "sha256-ySdl7F9+oeWNHVrg3QL/brazqmJvYFEdpGnF3pyoDH8=",
|
||||
"lastModified": 1765120009,
|
||||
"narHash": "sha256-nG76b87rkaDzibWbnB5bYDm6a52b78A+fpm+03pqYIw=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "90d2e1ce4dfe7dc49250a8b88a0f08ffdb9cb23f",
|
||||
"rev": "5e3e9c4e61bba8a5e72134b9ffefbef8f531d008",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description = "Development environment flake";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05";
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.11";
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
@@ -48,7 +48,7 @@
|
||||
gnuplot ## for cargo bench
|
||||
];
|
||||
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
|
||||
buildInputs = buildInputs;
|
||||
NIX_HARDENING_ENABLE = "";
|
||||
};
|
||||
});
|
||||
|
||||
@@ -708,6 +708,7 @@ fn ddl_request_type(request: &DdlRequest) -> &'static str {
|
||||
Some(Expr::CreateView(_)) => "ddl.create_view",
|
||||
Some(Expr::DropView(_)) => "ddl.drop_view",
|
||||
Some(Expr::AlterDatabase(_)) => "ddl.alter_database",
|
||||
Some(Expr::CommentOn(_)) => "ddl.comment_on",
|
||||
None => "ddl.empty",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ workspace = true
|
||||
api.workspace = true
|
||||
async-trait.workspace = true
|
||||
common-base.workspace = true
|
||||
common-config.workspace = true
|
||||
common-error.workspace = true
|
||||
common-macro.workspace = true
|
||||
common-telemetry.workspace = true
|
||||
digest = "0.10"
|
||||
notify.workspace = true
|
||||
sha1 = "0.10"
|
||||
snafu.workspace = true
|
||||
sql.workspace = true
|
||||
|
||||
@@ -75,11 +75,12 @@ pub enum Error {
|
||||
username: String,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to initialize a watcher for file {}", path))]
|
||||
#[snafu(display("Failed to initialize a file watcher"))]
|
||||
FileWatch {
|
||||
path: String,
|
||||
#[snafu(source)]
|
||||
error: notify::Error,
|
||||
source: common_config::error::Error,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("User is not authorized to perform this action"))]
|
||||
|
||||
@@ -12,16 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::Path;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_config::file_watcher::{FileWatcherBuilder, FileWatcherConfig};
|
||||
use common_telemetry::{info, warn};
|
||||
use notify::{EventKind, RecursiveMode, Watcher};
|
||||
use snafu::{ResultExt, ensure};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{FileWatchSnafu, InvalidConfigSnafu, Result};
|
||||
use crate::error::{FileWatchSnafu, Result};
|
||||
use crate::user_provider::{UserInfoMap, authenticate_with_credential, load_credential_from_file};
|
||||
use crate::{Identity, Password, UserInfoRef, UserProvider};
|
||||
|
||||
@@ -41,61 +39,36 @@ impl WatchFileUserProvider {
|
||||
pub fn new(filepath: &str) -> Result<Self> {
|
||||
let credential = load_credential_from_file(filepath)?;
|
||||
let users = Arc::new(Mutex::new(credential));
|
||||
let this = WatchFileUserProvider {
|
||||
users: users.clone(),
|
||||
};
|
||||
|
||||
let (tx, rx) = channel::<notify::Result<notify::Event>>();
|
||||
let mut debouncer =
|
||||
notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
|
||||
let mut dir = Path::new(filepath).to_path_buf();
|
||||
ensure!(
|
||||
dir.pop(),
|
||||
InvalidConfigSnafu {
|
||||
value: filepath,
|
||||
msg: "UserProvider path must be a file path",
|
||||
}
|
||||
);
|
||||
debouncer
|
||||
.watch(&dir, RecursiveMode::NonRecursive)
|
||||
.context(FileWatchSnafu { path: filepath })?;
|
||||
let users_clone = users.clone();
|
||||
let filepath_owned = filepath.to_string();
|
||||
|
||||
let filepath = filepath.to_string();
|
||||
std::thread::spawn(move || {
|
||||
let filename = Path::new(&filepath).file_name();
|
||||
let _hold = debouncer;
|
||||
while let Ok(res) = rx.recv() {
|
||||
if let Ok(event) = res {
|
||||
let is_this_file = event.paths.iter().any(|p| p.file_name() == filename);
|
||||
let is_relevant_event = matches!(
|
||||
event.kind,
|
||||
EventKind::Modify(_) | EventKind::Create(_) | EventKind::Remove(_)
|
||||
FileWatcherBuilder::new()
|
||||
.watch_path(filepath)
|
||||
.context(FileWatchSnafu)?
|
||||
.config(FileWatcherConfig::new())
|
||||
.spawn(move || match load_credential_from_file(&filepath_owned) {
|
||||
Ok(credential) => {
|
||||
let mut users = users_clone.lock().expect("users credential must be valid");
|
||||
#[cfg(not(test))]
|
||||
info!("User provider file {} reloaded", &filepath_owned);
|
||||
#[cfg(test)]
|
||||
info!(
|
||||
"User provider file {} reloaded: {:?}",
|
||||
&filepath_owned, credential
|
||||
);
|
||||
if is_this_file && is_relevant_event {
|
||||
info!(?event.kind, "User provider file {} changed", &filepath);
|
||||
match load_credential_from_file(&filepath) {
|
||||
Ok(credential) => {
|
||||
let mut users =
|
||||
users.lock().expect("users credential must be valid");
|
||||
#[cfg(not(test))]
|
||||
info!("User provider file {filepath} reloaded");
|
||||
#[cfg(test)]
|
||||
info!("User provider file {filepath} reloaded: {credential:?}");
|
||||
*users = credential;
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
?err,
|
||||
"Fail to load credential from file {filepath}; keep the old one",
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
*users = credential;
|
||||
}
|
||||
}
|
||||
});
|
||||
Err(err) => {
|
||||
warn!(
|
||||
?err,
|
||||
"Fail to load credential from file {}; keep the old one", &filepath_owned
|
||||
)
|
||||
}
|
||||
})
|
||||
.context(FileWatchSnafu)?;
|
||||
|
||||
Ok(this)
|
||||
Ok(WatchFileUserProvider { users })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -428,7 +428,7 @@ pub trait InformationExtension {
|
||||
}
|
||||
|
||||
/// The request to inspect the datanode.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct DatanodeInspectRequest {
|
||||
/// Kind to fetch from datanode.
|
||||
pub kind: DatanodeInspectKind,
|
||||
|
||||
@@ -211,6 +211,7 @@ struct InformationSchemaPartitionsBuilder {
|
||||
partition_names: StringVectorBuilder,
|
||||
partition_ordinal_positions: Int64VectorBuilder,
|
||||
partition_expressions: StringVectorBuilder,
|
||||
partition_descriptions: StringVectorBuilder,
|
||||
create_times: TimestampSecondVectorBuilder,
|
||||
partition_ids: UInt64VectorBuilder,
|
||||
}
|
||||
@@ -231,6 +232,7 @@ impl InformationSchemaPartitionsBuilder {
|
||||
partition_names: StringVectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
partition_ordinal_positions: Int64VectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
partition_expressions: StringVectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
partition_descriptions: StringVectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
create_times: TimestampSecondVectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
partition_ids: UInt64VectorBuilder::with_capacity(INIT_CAPACITY),
|
||||
}
|
||||
@@ -319,6 +321,21 @@ impl InformationSchemaPartitionsBuilder {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get partition column names (shared by all partitions)
|
||||
// In MySQL, PARTITION_EXPRESSION is the partitioning function expression (e.g., column name)
|
||||
let partition_columns: String = table_info
|
||||
.meta
|
||||
.partition_column_names()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let partition_expr_str = if partition_columns.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(partition_columns)
|
||||
};
|
||||
|
||||
for (index, partition) in partitions.iter().enumerate() {
|
||||
let partition_name = format!("p{index}");
|
||||
|
||||
@@ -328,8 +345,12 @@ impl InformationSchemaPartitionsBuilder {
|
||||
self.partition_names.push(Some(&partition_name));
|
||||
self.partition_ordinal_positions
|
||||
.push(Some((index + 1) as i64));
|
||||
let expression = partition.partition_expr.as_ref().map(|e| e.to_string());
|
||||
self.partition_expressions.push(expression.as_deref());
|
||||
// PARTITION_EXPRESSION: partition column names (same for all partitions)
|
||||
self.partition_expressions
|
||||
.push(partition_expr_str.as_deref());
|
||||
// PARTITION_DESCRIPTION: partition boundary expression (different for each partition)
|
||||
let description = partition.partition_expr.as_ref().map(|e| e.to_string());
|
||||
self.partition_descriptions.push(description.as_deref());
|
||||
self.create_times.push(Some(TimestampSecond::from(
|
||||
table_info.meta.created_on.timestamp(),
|
||||
)));
|
||||
@@ -369,7 +390,7 @@ impl InformationSchemaPartitionsBuilder {
|
||||
null_string_vector.clone(),
|
||||
Arc::new(self.partition_expressions.finish()),
|
||||
null_string_vector.clone(),
|
||||
null_string_vector.clone(),
|
||||
Arc::new(self.partition_descriptions.finish()),
|
||||
// TODO(dennis): rows and index statistics info
|
||||
null_i64_vector.clone(),
|
||||
null_i64_vector.clone(),
|
||||
|
||||
@@ -67,6 +67,7 @@ tracing-appender.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
common-meta = { workspace = true, features = ["testing"] }
|
||||
common-test-util.workspace = true
|
||||
common-version.workspace = true
|
||||
serde.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
||||
@@ -15,5 +15,8 @@
|
||||
mod object_store;
|
||||
mod store;
|
||||
|
||||
pub use object_store::{ObjectStoreConfig, new_fs_object_store};
|
||||
pub use object_store::{
|
||||
ObjectStoreConfig, PrefixedAzblobConnection, PrefixedGcsConnection, PrefixedOssConnection,
|
||||
PrefixedS3Connection, new_fs_object_store,
|
||||
};
|
||||
pub use store::StoreConfig;
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use common_base::secrets::SecretString;
|
||||
use common_base::secrets::{ExposeSecret, SecretString};
|
||||
use common_error::ext::BoxedError;
|
||||
use object_store::services::{Azblob, Fs, Gcs, Oss, S3};
|
||||
use object_store::util::{with_instrument_layers, with_retry_layers};
|
||||
@@ -22,9 +22,69 @@ use snafu::ResultExt;
|
||||
|
||||
use crate::error::{self};
|
||||
|
||||
/// Trait to convert CLI field types to target struct field types.
|
||||
/// This enables `Option<SecretString>` (CLI) -> `SecretString` (target) conversions,
|
||||
/// allowing us to distinguish "not provided" from "provided but empty".
|
||||
trait IntoField<T> {
|
||||
fn into_field(self) -> T;
|
||||
}
|
||||
|
||||
/// Identity conversion for types that are the same.
|
||||
impl<T> IntoField<T> for T {
|
||||
fn into_field(self) -> T {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert `Option<SecretString>` to `SecretString`, using default for None.
|
||||
impl IntoField<SecretString> for Option<SecretString> {
|
||||
fn into_field(self) -> SecretString {
|
||||
self.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for checking if a field is effectively empty.
|
||||
///
|
||||
/// **`is_empty()`**: Checks if the field has no meaningful value
|
||||
/// - Used when backend is enabled to validate required fields
|
||||
/// - `None`, `Some("")`, `false`, or `""` are considered empty
|
||||
trait FieldValidator {
|
||||
/// Check if the field is empty (has no meaningful value).
|
||||
fn is_empty(&self) -> bool;
|
||||
}
|
||||
|
||||
/// String fields: empty if the string is empty
|
||||
impl FieldValidator for String {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Bool fields: false is considered "empty", true is "provided"
|
||||
impl FieldValidator for bool {
|
||||
fn is_empty(&self) -> bool {
|
||||
!self
|
||||
}
|
||||
}
|
||||
|
||||
/// Option<String> fields: None or empty content is empty
|
||||
impl FieldValidator for Option<String> {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.as_ref().is_none_or(|s| s.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
/// Option<SecretString> fields: None or empty secret is empty
|
||||
/// For secrets, Some("") is treated as "not provided" for both checks
|
||||
impl FieldValidator for Option<SecretString> {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.as_ref().is_none_or(|s| s.expose_secret().is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! wrap_with_clap_prefix {
|
||||
(
|
||||
$new_name:ident, $prefix:literal, $base:ty, {
|
||||
$new_name:ident, $prefix:literal, $enable_flag:literal, $base:ty, {
|
||||
$( $( #[doc = $doc:expr] )? $( #[alias = $alias:literal] )? $field:ident : $type:ty $( = $default:expr )? ),* $(,)?
|
||||
}
|
||||
) => {
|
||||
@@ -34,15 +94,16 @@ macro_rules! wrap_with_clap_prefix {
|
||||
$(
|
||||
$( #[doc = $doc] )?
|
||||
$( #[clap(alias = $alias)] )?
|
||||
#[clap(long $(, default_value_t = $default )? )]
|
||||
[<$prefix $field>]: $type,
|
||||
#[clap(long, requires = $enable_flag $(, default_value_t = $default )? )]
|
||||
pub [<$prefix $field>]: $type,
|
||||
)*
|
||||
}
|
||||
|
||||
impl From<$new_name> for $base {
|
||||
fn from(w: $new_name) -> Self {
|
||||
Self {
|
||||
$( $field: w.[<$prefix $field>] ),*
|
||||
// Use into_field() to handle Option<SecretString> -> SecretString conversion
|
||||
$( $field: w.[<$prefix $field>].into_field() ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -50,9 +111,90 @@ macro_rules! wrap_with_clap_prefix {
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro for declarative backend validation.
|
||||
///
|
||||
/// # Validation Rules
|
||||
///
|
||||
/// For each storage backend (S3, OSS, GCS, Azblob), this function validates:
|
||||
/// **When backend is enabled** (e.g., `--s3`): All required fields must be non-empty
|
||||
///
|
||||
/// Note: When backend is disabled, clap's `requires` attribute ensures no configuration
|
||||
/// fields can be provided at parse time.
|
||||
///
|
||||
/// # Syntax
|
||||
///
|
||||
/// ```ignore
|
||||
/// validate_backend!(
|
||||
/// enable: self.enable_s3,
|
||||
/// name: "S3",
|
||||
/// required: [(field1, "name1"), (field2, "name2"), ...],
|
||||
/// custom_validator: |missing| { ... } // optional
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// - `enable`: Boolean expression indicating if backend is enabled
|
||||
/// - `name`: Human-readable backend name for error messages
|
||||
/// - `required`: Array of (field_ref, field_name) tuples for required fields
|
||||
/// - `custom_validator`: Optional closure for complex validation logic
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// validate_backend!(
|
||||
/// enable: self.enable_s3,
|
||||
/// name: "S3",
|
||||
/// required: [
|
||||
/// (&self.s3.s3_bucket, "bucket"),
|
||||
/// (&self.s3.s3_access_key_id, "access key ID"),
|
||||
/// ]
|
||||
/// )
|
||||
/// ```
|
||||
macro_rules! validate_backend {
|
||||
(
|
||||
enable: $enable:expr,
|
||||
name: $backend_name:expr,
|
||||
required: [ $( ($field:expr, $field_name:expr) ),* $(,)? ]
|
||||
$(, custom_validator: $custom_validator:expr)?
|
||||
) => {{
|
||||
if $enable {
|
||||
// Check required fields when backend is enabled
|
||||
let mut missing = Vec::new();
|
||||
$(
|
||||
if FieldValidator::is_empty($field) {
|
||||
missing.push($field_name);
|
||||
}
|
||||
)*
|
||||
|
||||
// Run custom validation if provided
|
||||
$(
|
||||
$custom_validator(&mut missing);
|
||||
)?
|
||||
|
||||
if !missing.is_empty() {
|
||||
return Err(BoxedError::new(
|
||||
error::MissingConfigSnafu {
|
||||
msg: format!(
|
||||
"{} {} must be set when --{} is enabled.",
|
||||
$backend_name,
|
||||
missing.join(", "),
|
||||
$backend_name.to_lowercase()
|
||||
),
|
||||
}
|
||||
.build(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}};
|
||||
}
|
||||
|
||||
wrap_with_clap_prefix! {
|
||||
PrefixedAzblobConnection,
|
||||
"azblob-",
|
||||
"enable_azblob",
|
||||
AzblobConnection,
|
||||
{
|
||||
#[doc = "The container of the object store."]
|
||||
@@ -60,9 +202,9 @@ wrap_with_clap_prefix! {
|
||||
#[doc = "The root of the object store."]
|
||||
root: String = Default::default(),
|
||||
#[doc = "The account name of the object store."]
|
||||
account_name: SecretString = Default::default(),
|
||||
account_name: Option<SecretString>,
|
||||
#[doc = "The account key of the object store."]
|
||||
account_key: SecretString = Default::default(),
|
||||
account_key: Option<SecretString>,
|
||||
#[doc = "The endpoint of the object store."]
|
||||
endpoint: String = Default::default(),
|
||||
#[doc = "The SAS token of the object store."]
|
||||
@@ -70,9 +212,33 @@ wrap_with_clap_prefix! {
|
||||
}
|
||||
}
|
||||
|
||||
impl PrefixedAzblobConnection {
|
||||
pub fn validate(&self) -> Result<(), BoxedError> {
|
||||
validate_backend!(
|
||||
enable: true,
|
||||
name: "AzBlob",
|
||||
required: [
|
||||
(&self.azblob_container, "container"),
|
||||
(&self.azblob_root, "root"),
|
||||
(&self.azblob_account_name, "account name"),
|
||||
(&self.azblob_endpoint, "endpoint"),
|
||||
],
|
||||
custom_validator: |missing: &mut Vec<&str>| {
|
||||
// account_key is only required if sas_token is not provided
|
||||
if self.azblob_sas_token.is_none()
|
||||
&& self.azblob_account_key.is_empty()
|
||||
{
|
||||
missing.push("account key (when sas_token is not provided)");
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
wrap_with_clap_prefix! {
|
||||
PrefixedS3Connection,
|
||||
"s3-",
|
||||
"enable_s3",
|
||||
S3Connection,
|
||||
{
|
||||
#[doc = "The bucket of the object store."]
|
||||
@@ -80,21 +246,39 @@ wrap_with_clap_prefix! {
|
||||
#[doc = "The root of the object store."]
|
||||
root: String = Default::default(),
|
||||
#[doc = "The access key ID of the object store."]
|
||||
access_key_id: SecretString = Default::default(),
|
||||
access_key_id: Option<SecretString>,
|
||||
#[doc = "The secret access key of the object store."]
|
||||
secret_access_key: SecretString = Default::default(),
|
||||
secret_access_key: Option<SecretString>,
|
||||
#[doc = "The endpoint of the object store."]
|
||||
endpoint: Option<String>,
|
||||
#[doc = "The region of the object store."]
|
||||
region: Option<String>,
|
||||
#[doc = "Enable virtual host style for the object store."]
|
||||
enable_virtual_host_style: bool = Default::default(),
|
||||
#[doc = "Disable EC2 metadata service for the object store."]
|
||||
disable_ec2_metadata: bool = Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
impl PrefixedS3Connection {
|
||||
pub fn validate(&self) -> Result<(), BoxedError> {
|
||||
validate_backend!(
|
||||
enable: true,
|
||||
name: "S3",
|
||||
required: [
|
||||
(&self.s3_bucket, "bucket"),
|
||||
(&self.s3_access_key_id, "access key ID"),
|
||||
(&self.s3_secret_access_key, "secret access key"),
|
||||
(&self.s3_region, "region"),
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
wrap_with_clap_prefix! {
|
||||
PrefixedOssConnection,
|
||||
"oss-",
|
||||
"enable_oss",
|
||||
OssConnection,
|
||||
{
|
||||
#[doc = "The bucket of the object store."]
|
||||
@@ -102,17 +286,33 @@ wrap_with_clap_prefix! {
|
||||
#[doc = "The root of the object store."]
|
||||
root: String = Default::default(),
|
||||
#[doc = "The access key ID of the object store."]
|
||||
access_key_id: SecretString = Default::default(),
|
||||
access_key_id: Option<SecretString>,
|
||||
#[doc = "The access key secret of the object store."]
|
||||
access_key_secret: SecretString = Default::default(),
|
||||
access_key_secret: Option<SecretString>,
|
||||
#[doc = "The endpoint of the object store."]
|
||||
endpoint: String = Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
impl PrefixedOssConnection {
|
||||
pub fn validate(&self) -> Result<(), BoxedError> {
|
||||
validate_backend!(
|
||||
enable: true,
|
||||
name: "OSS",
|
||||
required: [
|
||||
(&self.oss_bucket, "bucket"),
|
||||
(&self.oss_access_key_id, "access key ID"),
|
||||
(&self.oss_access_key_secret, "access key secret"),
|
||||
(&self.oss_endpoint, "endpoint"),
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
wrap_with_clap_prefix! {
|
||||
PrefixedGcsConnection,
|
||||
"gcs-",
|
||||
"enable_gcs",
|
||||
GcsConnection,
|
||||
{
|
||||
#[doc = "The root of the object store."]
|
||||
@@ -122,40 +322,72 @@ wrap_with_clap_prefix! {
|
||||
#[doc = "The scope of the object store."]
|
||||
scope: String = Default::default(),
|
||||
#[doc = "The credential path of the object store."]
|
||||
credential_path: SecretString = Default::default(),
|
||||
credential_path: Option<SecretString>,
|
||||
#[doc = "The credential of the object store."]
|
||||
credential: SecretString = Default::default(),
|
||||
credential: Option<SecretString>,
|
||||
#[doc = "The endpoint of the object store."]
|
||||
endpoint: String = Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// common config for object store.
|
||||
impl PrefixedGcsConnection {
|
||||
pub fn validate(&self) -> Result<(), BoxedError> {
|
||||
validate_backend!(
|
||||
enable: true,
|
||||
name: "GCS",
|
||||
required: [
|
||||
(&self.gcs_bucket, "bucket"),
|
||||
(&self.gcs_root, "root"),
|
||||
(&self.gcs_scope, "scope"),
|
||||
]
|
||||
// No custom_validator needed: GCS supports Application Default Credentials (ADC)
|
||||
// where neither credential_path nor credential is required.
|
||||
// Endpoint is also optional (defaults to https://storage.googleapis.com).
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Common config for object store.
|
||||
///
|
||||
/// # Dependency Enforcement
|
||||
///
|
||||
/// Each backend's configuration fields (e.g., `--s3-bucket`) requires its corresponding
|
||||
/// enable flag (e.g., `--s3`) to be present. This is enforced by `clap` at parse time
|
||||
/// using the `requires` attribute.
|
||||
///
|
||||
/// For example, attempting to use `--s3-bucket my-bucket` without `--s3` will result in:
|
||||
/// ```text
|
||||
/// error: The argument '--s3-bucket <BUCKET>' requires '--s3'
|
||||
/// ```
|
||||
///
|
||||
/// This ensures that users cannot accidentally provide backend-specific configuration
|
||||
/// without explicitly enabling that backend.
|
||||
#[derive(clap::Parser, Debug, Clone, PartialEq, Default)]
|
||||
#[clap(group(clap::ArgGroup::new("storage_backend").required(false).multiple(false)))]
|
||||
pub struct ObjectStoreConfig {
|
||||
/// Whether to use S3 object store.
|
||||
#[clap(long, alias = "s3")]
|
||||
#[clap(long = "s3", group = "storage_backend")]
|
||||
pub enable_s3: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
pub s3: PrefixedS3Connection,
|
||||
|
||||
/// Whether to use OSS.
|
||||
#[clap(long, alias = "oss")]
|
||||
#[clap(long = "oss", group = "storage_backend")]
|
||||
pub enable_oss: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
pub oss: PrefixedOssConnection,
|
||||
|
||||
/// Whether to use GCS.
|
||||
#[clap(long, alias = "gcs")]
|
||||
#[clap(long = "gcs", group = "storage_backend")]
|
||||
pub enable_gcs: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
pub gcs: PrefixedGcsConnection,
|
||||
|
||||
/// Whether to use Azure Blob.
|
||||
#[clap(long, alias = "azblob")]
|
||||
#[clap(long = "azblob", group = "storage_backend")]
|
||||
pub enable_azblob: bool,
|
||||
|
||||
#[clap(flatten)]
|
||||
@@ -173,52 +405,66 @@ pub fn new_fs_object_store(root: &str) -> std::result::Result<ObjectStore, Boxed
|
||||
Ok(with_instrument_layers(object_store, false))
|
||||
}
|
||||
|
||||
macro_rules! gen_object_store_builder {
|
||||
($method:ident, $field:ident, $conn_type:ty, $service_type:ty) => {
|
||||
pub fn $method(&self) -> Result<ObjectStore, BoxedError> {
|
||||
let config = <$conn_type>::from(self.$field.clone());
|
||||
common_telemetry::info!(
|
||||
"Building object store with {}: {:?}",
|
||||
stringify!($field),
|
||||
config
|
||||
);
|
||||
let object_store = ObjectStore::new(<$service_type>::from(&config))
|
||||
.context(error::InitBackendSnafu)
|
||||
.map_err(BoxedError::new)?
|
||||
.finish();
|
||||
Ok(with_instrument_layers(
|
||||
with_retry_layers(object_store),
|
||||
false,
|
||||
))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl ObjectStoreConfig {
|
||||
gen_object_store_builder!(build_s3, s3, S3Connection, S3);
|
||||
|
||||
gen_object_store_builder!(build_oss, oss, OssConnection, Oss);
|
||||
|
||||
gen_object_store_builder!(build_gcs, gcs, GcsConnection, Gcs);
|
||||
|
||||
gen_object_store_builder!(build_azblob, azblob, AzblobConnection, Azblob);
|
||||
|
||||
pub fn validate(&self) -> Result<(), BoxedError> {
|
||||
if self.enable_s3 {
|
||||
self.s3.validate()?;
|
||||
}
|
||||
if self.enable_oss {
|
||||
self.oss.validate()?;
|
||||
}
|
||||
if self.enable_gcs {
|
||||
self.gcs.validate()?;
|
||||
}
|
||||
if self.enable_azblob {
|
||||
self.azblob.validate()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Builds the object store from the config.
|
||||
pub fn build(&self) -> Result<Option<ObjectStore>, BoxedError> {
|
||||
let object_store = if self.enable_s3 {
|
||||
let s3 = S3Connection::from(self.s3.clone());
|
||||
common_telemetry::info!("Building object store with s3: {:?}", s3);
|
||||
Some(
|
||||
ObjectStore::new(S3::from(&s3))
|
||||
.context(error::InitBackendSnafu)
|
||||
.map_err(BoxedError::new)?
|
||||
.finish(),
|
||||
)
|
||||
self.validate()?;
|
||||
|
||||
if self.enable_s3 {
|
||||
self.build_s3().map(Some)
|
||||
} else if self.enable_oss {
|
||||
let oss = OssConnection::from(self.oss.clone());
|
||||
common_telemetry::info!("Building object store with oss: {:?}", oss);
|
||||
Some(
|
||||
ObjectStore::new(Oss::from(&oss))
|
||||
.context(error::InitBackendSnafu)
|
||||
.map_err(BoxedError::new)?
|
||||
.finish(),
|
||||
)
|
||||
self.build_oss().map(Some)
|
||||
} else if self.enable_gcs {
|
||||
let gcs = GcsConnection::from(self.gcs.clone());
|
||||
common_telemetry::info!("Building object store with gcs: {:?}", gcs);
|
||||
Some(
|
||||
ObjectStore::new(Gcs::from(&gcs))
|
||||
.context(error::InitBackendSnafu)
|
||||
.map_err(BoxedError::new)?
|
||||
.finish(),
|
||||
)
|
||||
self.build_gcs().map(Some)
|
||||
} else if self.enable_azblob {
|
||||
let azblob = AzblobConnection::from(self.azblob.clone());
|
||||
common_telemetry::info!("Building object store with azblob: {:?}", azblob);
|
||||
Some(
|
||||
ObjectStore::new(Azblob::from(&azblob))
|
||||
.context(error::InitBackendSnafu)
|
||||
.map_err(BoxedError::new)?
|
||||
.finish(),
|
||||
)
|
||||
self.build_azblob().map(Some)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let object_store = object_store
|
||||
.map(|object_store| with_instrument_layers(with_retry_layers(object_store), false));
|
||||
|
||||
Ok(object_store)
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ use common_error::ext::BoxedError;
|
||||
use common_meta::kv_backend::KvBackendRef;
|
||||
use common_meta::kv_backend::chroot::ChrootKvBackend;
|
||||
use common_meta::kv_backend::etcd::EtcdStore;
|
||||
use meta_srv::metasrv::BackendImpl;
|
||||
use meta_srv::metasrv::{BackendClientOptions, BackendImpl};
|
||||
use meta_srv::utils::etcd::create_etcd_client_with_tls;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
|
||||
@@ -61,6 +61,12 @@ pub struct StoreConfig {
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
#[clap(long)]
|
||||
pub meta_schema_name: Option<String>,
|
||||
|
||||
/// Automatically create PostgreSQL schema if it doesn't exist (default: true).
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
#[clap(long, default_value_t = true)]
|
||||
pub auto_create_schema: bool,
|
||||
|
||||
/// TLS mode for backend store connections (etcd, PostgreSQL, MySQL)
|
||||
#[clap(long = "backend-tls-mode", value_enum, default_value = "disable")]
|
||||
pub backend_tls_mode: TlsMode,
|
||||
@@ -112,9 +118,13 @@ impl StoreConfig {
|
||||
let kvbackend = match self.backend {
|
||||
BackendImpl::EtcdStore => {
|
||||
let tls_config = self.tls_config();
|
||||
let etcd_client = create_etcd_client_with_tls(store_addrs, tls_config.as_ref())
|
||||
.await
|
||||
.map_err(BoxedError::new)?;
|
||||
let etcd_client = create_etcd_client_with_tls(
|
||||
store_addrs,
|
||||
&BackendClientOptions::default(),
|
||||
tls_config.as_ref(),
|
||||
)
|
||||
.await
|
||||
.map_err(BoxedError::new)?;
|
||||
Ok(EtcdStore::with_etcd_client(etcd_client, max_txn_ops))
|
||||
}
|
||||
#[cfg(feature = "pg_kvbackend")]
|
||||
@@ -134,6 +144,7 @@ impl StoreConfig {
|
||||
schema_name,
|
||||
table_name,
|
||||
max_txn_ops,
|
||||
self.auto_create_schema,
|
||||
)
|
||||
.await
|
||||
.map_err(BoxedError::new)?)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
mod export;
|
||||
mod import;
|
||||
mod storage_export;
|
||||
|
||||
use clap::Subcommand;
|
||||
use client::DEFAULT_CATALOG_NAME;
|
||||
|
||||
File diff suppressed because it is too large
373
src/cli/src/data/storage_export.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use common_base::secrets::{ExposeSecret, SecretString};
|
||||
use common_error::ext::BoxedError;
|
||||
|
||||
use crate::common::{
|
||||
PrefixedAzblobConnection, PrefixedGcsConnection, PrefixedOssConnection, PrefixedS3Connection,
|
||||
};
|
||||
|
||||
/// Helper function to extract secret string from Option<SecretString>.
|
||||
/// Returns empty string if None.
|
||||
fn expose_optional_secret(secret: &Option<SecretString>) -> &str {
|
||||
secret
|
||||
.as_ref()
|
||||
.map(|s| s.expose_secret().as_str())
|
||||
.unwrap_or("")
|
||||
}
|
||||
|
||||
/// Helper function to format root path with leading slash if non-empty.
|
||||
fn format_root_path(root: &str) -> String {
|
||||
if root.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("/{}", root)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function to mask multiple secrets in a string.
|
||||
fn mask_secrets(mut sql: String, secrets: &[&str]) -> String {
|
||||
for secret in secrets {
|
||||
if !secret.is_empty() {
|
||||
sql = sql.replace(secret, "[REDACTED]");
|
||||
}
|
||||
}
|
||||
sql
|
||||
}
|
||||
|
||||
/// Helper function to format storage URI.
|
||||
fn format_uri(scheme: &str, bucket: &str, root: &str, path: &str) -> String {
|
||||
let root = format_root_path(root);
|
||||
format!("{}://{}{}/{}", scheme, bucket, root, path)
|
||||
}
|
||||
|
||||
/// Trait for storage backends that can be used for data export.
|
||||
pub trait StorageExport: Send + Sync {
|
||||
/// Generate the storage path for COPY DATABASE command.
|
||||
/// Returns (path, connection_string) where connection_string includes CONNECTION clause.
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String);
|
||||
|
||||
/// Format the output path for logging purposes.
|
||||
fn format_output_path(&self, file_path: &str) -> String;
|
||||
|
||||
/// Mask sensitive information in SQL commands for safe logging.
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String;
|
||||
}
|
||||
|
||||
macro_rules! define_backend {
|
||||
($name:ident, $config:ty) => {
|
||||
#[derive(Clone)]
|
||||
pub struct $name {
|
||||
config: $config,
|
||||
}
|
||||
|
||||
impl $name {
|
||||
pub fn new(config: $config) -> Result<Self, BoxedError> {
|
||||
config.validate()?;
|
||||
Ok(Self { config })
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Local file system storage backend.
|
||||
#[derive(Clone)]
|
||||
pub struct FsBackend {
|
||||
output_dir: String,
|
||||
}
|
||||
|
||||
impl FsBackend {
|
||||
pub fn new(output_dir: String) -> Self {
|
||||
Self { output_dir }
|
||||
}
|
||||
}
|
||||
|
||||
impl StorageExport for FsBackend {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
if self.output_dir.is_empty() {
|
||||
unreachable!("output_dir must be set when not using remote storage")
|
||||
}
|
||||
let path = PathBuf::from(&self.output_dir)
|
||||
.join(catalog)
|
||||
.join(format!("{schema}/"))
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
(path, String::new())
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
format!("{}/{}", self.output_dir, file_path)
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
sql.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
define_backend!(S3Backend, PrefixedS3Connection);
|
||||
|
||||
impl StorageExport for S3Backend {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
let s3_path = format_uri(
|
||||
"s3",
|
||||
&self.config.s3_bucket,
|
||||
&self.config.s3_root,
|
||||
&format!("{}/{}/", catalog, schema),
|
||||
);
|
||||
|
||||
let mut connection_options = vec![
|
||||
format!(
|
||||
"ACCESS_KEY_ID='{}'",
|
||||
expose_optional_secret(&self.config.s3_access_key_id)
|
||||
),
|
||||
format!(
|
||||
"SECRET_ACCESS_KEY='{}'",
|
||||
expose_optional_secret(&self.config.s3_secret_access_key)
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(region) = &self.config.s3_region {
|
||||
connection_options.push(format!("REGION='{}'", region));
|
||||
}
|
||||
|
||||
if let Some(endpoint) = &self.config.s3_endpoint {
|
||||
connection_options.push(format!("ENDPOINT='{}'", endpoint));
|
||||
}
|
||||
|
||||
let connection_str = format!(" CONNECTION ({})", connection_options.join(", "));
|
||||
(s3_path, connection_str)
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
format_uri(
|
||||
"s3",
|
||||
&self.config.s3_bucket,
|
||||
&self.config.s3_root,
|
||||
file_path,
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
mask_secrets(
|
||||
sql.to_string(),
|
||||
&[
|
||||
expose_optional_secret(&self.config.s3_access_key_id),
|
||||
expose_optional_secret(&self.config.s3_secret_access_key),
|
||||
],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
define_backend!(OssBackend, PrefixedOssConnection);
|
||||
|
||||
impl StorageExport for OssBackend {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
let oss_path = format_uri(
|
||||
"oss",
|
||||
&self.config.oss_bucket,
|
||||
&self.config.oss_root,
|
||||
&format!("{}/{}/", catalog, schema),
|
||||
);
|
||||
|
||||
let connection_options = [
|
||||
format!(
|
||||
"ACCESS_KEY_ID='{}'",
|
||||
expose_optional_secret(&self.config.oss_access_key_id)
|
||||
),
|
||||
format!(
|
||||
"ACCESS_KEY_SECRET='{}'",
|
||||
expose_optional_secret(&self.config.oss_access_key_secret)
|
||||
),
|
||||
];
|
||||
|
||||
let connection_str = format!(" CONNECTION ({})", connection_options.join(", "));
|
||||
(oss_path, connection_str)
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
format_uri(
|
||||
"oss",
|
||||
&self.config.oss_bucket,
|
||||
&self.config.oss_root,
|
||||
file_path,
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
mask_secrets(
|
||||
sql.to_string(),
|
||||
&[
|
||||
expose_optional_secret(&self.config.oss_access_key_id),
|
||||
expose_optional_secret(&self.config.oss_access_key_secret),
|
||||
],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
define_backend!(GcsBackend, PrefixedGcsConnection);
|
||||
|
||||
impl StorageExport for GcsBackend {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
let gcs_path = format_uri(
|
||||
"gcs",
|
||||
&self.config.gcs_bucket,
|
||||
&self.config.gcs_root,
|
||||
&format!("{}/{}/", catalog, schema),
|
||||
);
|
||||
|
||||
let mut connection_options = Vec::new();
|
||||
|
||||
let credential_path = expose_optional_secret(&self.config.gcs_credential_path);
|
||||
if !credential_path.is_empty() {
|
||||
connection_options.push(format!("CREDENTIAL_PATH='{}'", credential_path));
|
||||
}
|
||||
|
||||
let credential = expose_optional_secret(&self.config.gcs_credential);
|
||||
if !credential.is_empty() {
|
||||
connection_options.push(format!("CREDENTIAL='{}'", credential));
|
||||
}
|
||||
|
||||
if !self.config.gcs_endpoint.is_empty() {
|
||||
connection_options.push(format!("ENDPOINT='{}'", self.config.gcs_endpoint));
|
||||
}
|
||||
|
||||
let connection_str = if connection_options.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" CONNECTION ({})", connection_options.join(", "))
|
||||
};
|
||||
|
||||
(gcs_path, connection_str)
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
format_uri(
|
||||
"gcs",
|
||||
&self.config.gcs_bucket,
|
||||
&self.config.gcs_root,
|
||||
file_path,
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
mask_secrets(
|
||||
sql.to_string(),
|
||||
&[
|
||||
expose_optional_secret(&self.config.gcs_credential_path),
|
||||
expose_optional_secret(&self.config.gcs_credential),
|
||||
],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
define_backend!(AzblobBackend, PrefixedAzblobConnection);
|
||||
|
||||
impl StorageExport for AzblobBackend {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
let azblob_path = format_uri(
|
||||
"azblob",
|
||||
&self.config.azblob_container,
|
||||
&self.config.azblob_root,
|
||||
&format!("{}/{}/", catalog, schema),
|
||||
);
|
||||
|
||||
let mut connection_options = vec![
|
||||
format!(
|
||||
"ACCOUNT_NAME='{}'",
|
||||
expose_optional_secret(&self.config.azblob_account_name)
|
||||
),
|
||||
format!(
|
||||
"ACCOUNT_KEY='{}'",
|
||||
expose_optional_secret(&self.config.azblob_account_key)
|
||||
),
|
||||
];
|
||||
|
||||
if let Some(sas_token) = &self.config.azblob_sas_token {
|
||||
connection_options.push(format!("SAS_TOKEN='{}'", sas_token));
|
||||
}
|
||||
|
||||
let connection_str = format!(" CONNECTION ({})", connection_options.join(", "));
|
||||
(azblob_path, connection_str)
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
format_uri(
|
||||
"azblob",
|
||||
&self.config.azblob_container,
|
||||
&self.config.azblob_root,
|
||||
file_path,
|
||||
)
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
mask_secrets(
|
||||
sql.to_string(),
|
||||
&[
|
||||
expose_optional_secret(&self.config.azblob_account_name),
|
||||
expose_optional_secret(&self.config.azblob_account_key),
|
||||
],
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum StorageType {
|
||||
Fs(FsBackend),
|
||||
S3(S3Backend),
|
||||
Oss(OssBackend),
|
||||
Gcs(GcsBackend),
|
||||
Azblob(AzblobBackend),
|
||||
}
|
||||
|
||||
impl StorageExport for StorageType {
|
||||
fn get_storage_path(&self, catalog: &str, schema: &str) -> (String, String) {
|
||||
match self {
|
||||
StorageType::Fs(backend) => backend.get_storage_path(catalog, schema),
|
||||
StorageType::S3(backend) => backend.get_storage_path(catalog, schema),
|
||||
StorageType::Oss(backend) => backend.get_storage_path(catalog, schema),
|
||||
StorageType::Gcs(backend) => backend.get_storage_path(catalog, schema),
|
||||
StorageType::Azblob(backend) => backend.get_storage_path(catalog, schema),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_output_path(&self, file_path: &str) -> String {
|
||||
match self {
|
||||
StorageType::Fs(backend) => backend.format_output_path(file_path),
|
||||
StorageType::S3(backend) => backend.format_output_path(file_path),
|
||||
StorageType::Oss(backend) => backend.format_output_path(file_path),
|
||||
StorageType::Gcs(backend) => backend.format_output_path(file_path),
|
||||
StorageType::Azblob(backend) => backend.format_output_path(file_path),
|
||||
}
|
||||
}
|
||||
|
||||
fn mask_sensitive_info(&self, sql: &str) -> String {
|
||||
match self {
|
||||
StorageType::Fs(backend) => backend.mask_sensitive_info(sql),
|
||||
StorageType::S3(backend) => backend.mask_sensitive_info(sql),
|
||||
StorageType::Oss(backend) => backend.mask_sensitive_info(sql),
|
||||
StorageType::Gcs(backend) => backend.mask_sensitive_info(sql),
|
||||
StorageType::Azblob(backend) => backend.mask_sensitive_info(sql),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StorageType {
|
||||
/// Returns true if the storage backend is remote (not local filesystem).
|
||||
pub fn is_remote_storage(&self) -> bool {
|
||||
!matches!(self, StorageType::Fs(_))
|
||||
}
|
||||
}
|
||||
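To make the dispatch above concrete, here is a minimal sketch of how a caller might combine `get_storage_path` and `mask_sensitive_info` when exporting one schema. The `COPY DATABASE ... TO` statement shape and the helper name are illustrative assumptions, not part of this change:

// Illustrative sketch; `backend` can be any StorageExport implementor, e.g. StorageType::S3(..).
fn build_export_sql(backend: &dyn StorageExport, catalog: &str, schema: &str) -> (String, String) {
    // For S3 this yields something like
    // ("s3://bucket/root/catalog/schema/", " CONNECTION (ACCESS_KEY_ID='..', SECRET_ACCESS_KEY='..')").
    let (path, connection) = backend.get_storage_path(catalog, schema);
    let sql = format!("COPY DATABASE {catalog}.{schema} TO '{path}' WITH (format = 'parquet'){connection}");
    // Secrets are replaced before the statement is ever written to the logs.
    let safe_for_logging = backend.mask_sensitive_info(&sql);
    (sql, safe_for_logging)
}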
@@ -253,12 +253,6 @@ pub enum Error {
error: ObjectStoreError,
},

#[snafu(display("S3 config need be set"))]
S3ConfigNotSet {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Output directory not set"))]
OutputDirNotSet {
#[snafu(implicit)]
@@ -364,9 +358,9 @@ impl ErrorExt for Error {

Error::Other { source, .. } => source.status_code(),
Error::OpenDal { .. } | Error::InitBackend { .. } => StatusCode::Internal,
Error::S3ConfigNotSet { .. }
| Error::OutputDirNotSet { .. }
| Error::EmptyStoreAddrs { .. } => StatusCode::InvalidArguments,
Error::OutputDirNotSet { .. } | Error::EmptyStoreAddrs { .. } => {
StatusCode::InvalidArguments
}

Error::BuildRuntime { source, .. } => source.status_code(),

@@ -145,6 +145,17 @@ impl ObjbenchCommand {
|
||||
let region_meta = extract_region_metadata(&self.source, &parquet_meta)?;
|
||||
let num_rows = parquet_meta.file_metadata().num_rows() as u64;
|
||||
let num_row_groups = parquet_meta.num_row_groups() as u64;
|
||||
let max_row_group_uncompressed_size: u64 = parquet_meta
|
||||
.row_groups()
|
||||
.iter()
|
||||
.map(|rg| {
|
||||
rg.columns()
|
||||
.iter()
|
||||
.map(|c| c.uncompressed_size() as u64)
|
||||
.sum::<u64>()
|
||||
})
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
println!(
|
||||
"{} Metadata loaded - rows: {}, size: {} bytes",
|
||||
@@ -160,10 +171,11 @@ impl ObjbenchCommand {
|
||||
time_range: Default::default(),
|
||||
level: 0,
|
||||
file_size,
|
||||
max_row_group_uncompressed_size,
|
||||
available_indexes: Default::default(),
|
||||
indexes: Default::default(),
|
||||
index_file_size: 0,
|
||||
index_file_id: None,
|
||||
index_version: 0,
|
||||
num_rows,
|
||||
num_row_groups,
|
||||
sequence: None,
|
||||
@@ -564,7 +576,7 @@ fn new_noop_file_purger() -> FilePurgerRef {
#[derive(Debug)]
struct Noop;
impl FilePurger for Noop {
fn remove_file(&self, _file_meta: FileMeta, _is_delete: bool) {}
fn remove_file(&self, _file_meta: FileMeta, _is_delete: bool, _index_outdated: bool) {}
}
Arc::new(Noop)
}

@@ -18,7 +18,6 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use cache::{build_fundamental_cache_registry, with_default_composite_cache_registry};
|
||||
use catalog::CatalogManagerRef;
|
||||
use catalog::information_extension::DistributedInformationExtension;
|
||||
use catalog::kvbackend::{CachedKvBackendBuilder, KvBackendCatalogManagerBuilder, MetaKvBackend};
|
||||
use clap::Parser;
|
||||
@@ -26,14 +25,12 @@ use client::client_manager::NodeClients;
|
||||
use common_base::Plugins;
|
||||
use common_config::{Configurable, DEFAULT_DATA_HOME};
|
||||
use common_grpc::channel_manager::ChannelConfig;
|
||||
use common_meta::FlownodeId;
|
||||
use common_meta::cache::{CacheRegistryBuilder, LayeredCacheRegistryBuilder};
|
||||
use common_meta::heartbeat::handler::HandlerGroupExecutor;
|
||||
use common_meta::heartbeat::handler::invalidate_table_cache::InvalidateCacheHandler;
|
||||
use common_meta::heartbeat::handler::parse_mailbox_message::ParseMailboxMessageHandler;
|
||||
use common_meta::key::TableMetadataManager;
|
||||
use common_meta::key::flow::FlowMetadataManager;
|
||||
use common_meta::kv_backend::KvBackendRef;
|
||||
use common_stat::ResourceStatImpl;
|
||||
use common_telemetry::info;
|
||||
use common_telemetry::logging::{DEFAULT_LOGGING_DIR, TracingOptions};
|
||||
@@ -43,6 +40,7 @@ use flow::{
|
||||
get_flow_auth_options,
|
||||
};
|
||||
use meta_client::{MetaClientOptions, MetaClientType};
|
||||
use plugins::flownode::context::GrpcConfigureContext;
|
||||
use servers::configurator::GrpcBuilderConfiguratorRef;
|
||||
use snafu::{OptionExt, ResultExt, ensure};
|
||||
use tracing_appender::non_blocking::WorkerGuard;
|
||||
@@ -435,11 +433,3 @@ impl StartCommand {
|
||||
Ok(Instance::new(flownode, guard))
|
||||
}
|
||||
}
|
||||
|
||||
/// The context for [`GrpcBuilderConfiguratorRef`] in flownode.
|
||||
pub struct GrpcConfigureContext {
|
||||
pub kv_backend: KvBackendRef,
|
||||
pub fe_client: Arc<FrontendClient>,
|
||||
pub flownode_id: FlownodeId,
|
||||
pub catalog_manager: CatalogManagerRef,
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ use common_meta::cache::{CacheRegistryBuilder, LayeredCacheRegistryBuilder};
|
||||
use common_meta::heartbeat::handler::HandlerGroupExecutor;
|
||||
use common_meta::heartbeat::handler::invalidate_table_cache::InvalidateCacheHandler;
|
||||
use common_meta::heartbeat::handler::parse_mailbox_message::ParseMailboxMessageHandler;
|
||||
use common_meta::heartbeat::handler::suspend::SuspendHandler;
|
||||
use common_query::prelude::set_default_prefix;
|
||||
use common_stat::ResourceStatImpl;
|
||||
use common_telemetry::info;
|
||||
@@ -46,9 +47,12 @@ use frontend::heartbeat::HeartbeatTask;
|
||||
use frontend::instance::builder::FrontendBuilder;
|
||||
use frontend::server::Services;
|
||||
use meta_client::{MetaClientOptions, MetaClientRef, MetaClientType};
|
||||
use plugins::frontend::context::{
|
||||
CatalogManagerConfigureContext, DistributedCatalogManagerConfigureContext,
|
||||
};
|
||||
use servers::addrs;
|
||||
use servers::grpc::GrpcOptions;
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use servers::tls::{TlsMode, TlsOption, merge_tls_option};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tracing_appender::non_blocking::WorkerGuard;
|
||||
|
||||
@@ -252,7 +256,7 @@ impl StartCommand {
|
||||
|
||||
if let Some(addr) = &self.rpc_bind_addr {
|
||||
opts.grpc.bind_addr.clone_from(addr);
|
||||
opts.grpc.tls = tls_opts.clone();
|
||||
opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if let Some(addr) = &self.rpc_server_addr {
|
||||
@@ -287,13 +291,13 @@ impl StartCommand {
|
||||
if let Some(addr) = &self.mysql_addr {
|
||||
opts.mysql.enable = true;
|
||||
opts.mysql.addr.clone_from(addr);
|
||||
opts.mysql.tls = tls_opts.clone();
|
||||
opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if let Some(addr) = &self.postgres_addr {
|
||||
opts.postgres.enable = true;
|
||||
opts.postgres.addr.clone_from(addr);
|
||||
opts.postgres.tls = tls_opts;
|
||||
opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if let Some(enable) = self.influxdb_enable {
|
||||
@@ -423,9 +427,11 @@ impl StartCommand {
|
||||
let builder = if let Some(configurator) =
|
||||
plugins.get::<CatalogManagerConfiguratorRef<CatalogManagerConfigureContext>>()
|
||||
{
|
||||
let ctx = CatalogManagerConfigureContext {
|
||||
let ctx = DistributedCatalogManagerConfigureContext {
|
||||
meta_client: meta_client.clone(),
|
||||
};
|
||||
let ctx = CatalogManagerConfigureContext::Distributed(ctx);
|
||||
|
||||
configurator
|
||||
.configure(builder, ctx)
|
||||
.await
|
||||
@@ -435,30 +441,13 @@ impl StartCommand {
|
||||
};
|
||||
let catalog_manager = builder.build();
|
||||
|
||||
let executor = HandlerGroupExecutor::new(vec![
|
||||
Arc::new(ParseMailboxMessageHandler),
|
||||
Arc::new(InvalidateCacheHandler::new(layered_cache_registry.clone())),
|
||||
]);
|
||||
|
||||
let mut resource_stat = ResourceStatImpl::default();
|
||||
resource_stat.start_collect_cpu_usage();
|
||||
|
||||
let heartbeat_task = HeartbeatTask::new(
|
||||
&opts,
|
||||
meta_client.clone(),
|
||||
opts.heartbeat.clone(),
|
||||
Arc::new(executor),
|
||||
Arc::new(resource_stat),
|
||||
);
|
||||
let heartbeat_task = Some(heartbeat_task);
|
||||
|
||||
let instance = FrontendBuilder::new(
|
||||
opts.clone(),
|
||||
cached_meta_backend.clone(),
|
||||
layered_cache_registry.clone(),
|
||||
catalog_manager,
|
||||
client,
|
||||
meta_client,
|
||||
meta_client.clone(),
|
||||
process_manager,
|
||||
)
|
||||
.with_plugin(plugins.clone())
|
||||
@@ -466,6 +455,9 @@ impl StartCommand {
|
||||
.try_build()
|
||||
.await
|
||||
.context(error::StartFrontendSnafu)?;
|
||||
|
||||
let heartbeat_task = Some(create_heartbeat_task(&opts, meta_client, &instance));
|
||||
|
||||
let instance = Arc::new(instance);
|
||||
|
||||
let servers = Services::new(opts, instance.clone(), plugins)
|
||||
@@ -482,9 +474,26 @@ impl StartCommand {
|
||||
}
|
||||
}
|
||||
|
||||
/// The context for [`CatalogManagerConfigratorRef`] in frontend.
|
||||
pub struct CatalogManagerConfigureContext {
|
||||
pub meta_client: MetaClientRef,
|
||||
pub fn create_heartbeat_task(
|
||||
options: &frontend::frontend::FrontendOptions,
|
||||
meta_client: MetaClientRef,
|
||||
instance: &frontend::instance::Instance,
|
||||
) -> HeartbeatTask {
|
||||
let executor = Arc::new(HandlerGroupExecutor::new(vec![
|
||||
Arc::new(ParseMailboxMessageHandler),
|
||||
Arc::new(SuspendHandler::new(instance.suspend_state())),
|
||||
Arc::new(InvalidateCacheHandler::new(
|
||||
instance.cache_invalidator().clone(),
|
||||
)),
|
||||
]));
|
||||
|
||||
let stat = {
|
||||
let mut stat = ResourceStatImpl::default();
|
||||
stat.start_collect_cpu_usage();
|
||||
Arc::new(stat)
|
||||
};
|
||||
|
||||
HeartbeatTask::new(options, meta_client, executor, stat)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -20,6 +20,7 @@ use async_trait::async_trait;
|
||||
use clap::Parser;
|
||||
use common_base::Plugins;
|
||||
use common_config::Configurable;
|
||||
use common_meta::distributed_time_constants::init_distributed_time_constants;
|
||||
use common_telemetry::info;
|
||||
use common_telemetry::logging::{DEFAULT_LOGGING_DIR, TracingOptions};
|
||||
use common_version::{short_version, verbose_version};
|
||||
@@ -154,8 +155,6 @@ pub struct StartCommand {
|
||||
#[clap(short, long)]
|
||||
selector: Option<String>,
|
||||
#[clap(long)]
|
||||
use_memory_store: Option<bool>,
|
||||
#[clap(long)]
|
||||
enable_region_failover: Option<bool>,
|
||||
#[clap(long)]
|
||||
http_addr: Option<String>,
|
||||
@@ -185,7 +184,6 @@ impl Debug for StartCommand {
|
||||
.field("store_addrs", &self.sanitize_store_addrs())
|
||||
.field("config_file", &self.config_file)
|
||||
.field("selector", &self.selector)
|
||||
.field("use_memory_store", &self.use_memory_store)
|
||||
.field("enable_region_failover", &self.enable_region_failover)
|
||||
.field("http_addr", &self.http_addr)
|
||||
.field("http_timeout", &self.http_timeout)
|
||||
@@ -267,10 +265,6 @@ impl StartCommand {
|
||||
.context(error::UnsupportedSelectorTypeSnafu { selector_type })?;
|
||||
}
|
||||
|
||||
if let Some(use_memory_store) = self.use_memory_store {
|
||||
opts.use_memory_store = use_memory_store;
|
||||
}
|
||||
|
||||
if let Some(enable_region_failover) = self.enable_region_failover {
|
||||
opts.enable_region_failover = enable_region_failover;
|
||||
}
|
||||
@@ -327,6 +321,7 @@ impl StartCommand {
|
||||
log_versions(verbose_version(), short_version(), APP_NAME);
|
||||
maybe_activate_heap_profile(&opts.component.memory);
|
||||
create_resource_limit_metrics(APP_NAME);
|
||||
init_distributed_time_constants(opts.component.heartbeat_interval);
|
||||
|
||||
info!("Metasrv start command: {:#?}", self);
|
||||
|
||||
@@ -389,7 +384,6 @@ mod tests {
|
||||
server_addr = "127.0.0.1:3002"
|
||||
store_addr = "127.0.0.1:2379"
|
||||
selector = "LeaseBased"
|
||||
use_memory_store = false
|
||||
|
||||
[logging]
|
||||
level = "debug"
|
||||
@@ -468,7 +462,6 @@ mod tests {
|
||||
server_addr = "127.0.0.1:3002"
|
||||
datanode_lease_secs = 15
|
||||
selector = "LeaseBased"
|
||||
use_memory_store = false
|
||||
|
||||
[http]
|
||||
addr = "127.0.0.1:4000"
|
||||
|
||||
@@ -32,7 +32,7 @@ use common_meta::cache::LayeredCacheRegistryBuilder;
|
||||
use common_meta::ddl::flow_meta::FlowMetadataAllocator;
|
||||
use common_meta::ddl::table_meta::TableMetadataAllocator;
|
||||
use common_meta::ddl::{DdlContext, NoopRegionFailureDetectorControl};
|
||||
use common_meta::ddl_manager::{DdlManager, DdlManagerConfiguratorRef, DdlManagerConfigureContext};
|
||||
use common_meta::ddl_manager::{DdlManager, DdlManagerConfiguratorRef};
|
||||
use common_meta::key::flow::FlowMetadataManager;
|
||||
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
|
||||
use common_meta::kv_backend::KvBackendRef;
|
||||
@@ -58,7 +58,11 @@ use frontend::instance::StandaloneDatanodeManager;
|
||||
use frontend::instance::builder::FrontendBuilder;
|
||||
use frontend::server::Services;
|
||||
use meta_srv::metasrv::{FLOW_ID_SEQ, TABLE_ID_SEQ};
|
||||
use servers::tls::{TlsMode, TlsOption};
|
||||
use plugins::frontend::context::{
|
||||
CatalogManagerConfigureContext, StandaloneCatalogManagerConfigureContext,
|
||||
};
|
||||
use plugins::standalone::context::DdlManagerConfigureContext;
|
||||
use servers::tls::{TlsMode, TlsOption, merge_tls_option};
|
||||
use snafu::ResultExt;
|
||||
use standalone::StandaloneInformationExtension;
|
||||
use standalone::options::StandaloneOptions;
|
||||
@@ -289,19 +293,20 @@ impl StartCommand {
|
||||
),
|
||||
}.fail();
|
||||
}
|
||||
opts.grpc.bind_addr.clone_from(addr)
|
||||
opts.grpc.bind_addr.clone_from(addr);
|
||||
opts.grpc.tls = merge_tls_option(&opts.grpc.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if let Some(addr) = &self.mysql_addr {
|
||||
opts.mysql.enable = true;
|
||||
opts.mysql.addr.clone_from(addr);
|
||||
opts.mysql.tls = tls_opts.clone();
|
||||
opts.mysql.tls = merge_tls_option(&opts.mysql.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if let Some(addr) = &self.postgres_addr {
|
||||
opts.postgres.enable = true;
|
||||
opts.postgres.addr.clone_from(addr);
|
||||
opts.postgres.tls = tls_opts;
|
||||
opts.postgres.tls = merge_tls_option(&opts.postgres.tls, tls_opts.clone());
|
||||
}
|
||||
|
||||
if self.influxdb_enable {
|
||||
@@ -414,9 +419,10 @@ impl StartCommand {
|
||||
let builder = if let Some(configurator) =
|
||||
plugins.get::<CatalogManagerConfiguratorRef<CatalogManagerConfigureContext>>()
|
||||
{
|
||||
let ctx = CatalogManagerConfigureContext {
|
||||
let ctx = StandaloneCatalogManagerConfigureContext {
|
||||
fe_client: frontend_client.clone(),
|
||||
};
|
||||
let ctx = CatalogManagerConfigureContext::Standalone(ctx);
|
||||
configurator
|
||||
.configure(builder, ctx)
|
||||
.await
|
||||
@@ -506,9 +512,13 @@ impl StartCommand {
|
||||
let ddl_manager = DdlManager::try_new(ddl_context, procedure_manager.clone(), true)
|
||||
.context(error::InitDdlManagerSnafu)?;
|
||||
|
||||
let ddl_manager = if let Some(configurator) = plugins.get::<DdlManagerConfiguratorRef>() {
|
||||
let ddl_manager = if let Some(configurator) =
|
||||
plugins.get::<DdlManagerConfiguratorRef<DdlManagerConfigureContext>>()
|
||||
{
|
||||
let ctx = DdlManagerConfigureContext {
|
||||
kv_backend: kv_backend.clone(),
|
||||
fe_client: frontend_client.clone(),
|
||||
catalog_manager: catalog_manager.clone(),
|
||||
};
|
||||
configurator
|
||||
.configure(ddl_manager, ctx)
|
||||
@@ -542,9 +552,8 @@ impl StartCommand {
|
||||
let grpc_handler = fe_instance.clone() as Arc<dyn GrpcQueryHandlerWithBoxedError>;
|
||||
let weak_grpc_handler = Arc::downgrade(&grpc_handler);
|
||||
frontend_instance_handler
|
||||
.lock()
|
||||
.unwrap()
|
||||
.replace(weak_grpc_handler);
|
||||
.set_handler(weak_grpc_handler)
|
||||
.await;
|
||||
|
||||
// set the frontend invoker for flownode
|
||||
let flow_streaming_engine = flownode.flow_engine().streaming_engine();
|
||||
@@ -595,11 +604,6 @@ impl StartCommand {
|
||||
}
|
||||
}
|
||||
|
||||
/// The context for [`CatalogManagerConfigratorRef`] in standalone.
|
||||
pub struct CatalogManagerConfigureContext {
|
||||
pub fe_client: Arc<FrontendClient>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::default::Default;
|
||||
@@ -761,7 +765,6 @@ mod tests {
|
||||
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
|
||||
mysql_addr: Some("127.0.0.1:4002".to_string()),
|
||||
postgres_addr: Some("127.0.0.1:4003".to_string()),
|
||||
tls_watch: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -778,8 +781,6 @@ mod tests {
|
||||
|
||||
assert_eq!("./greptimedb_data/test/logs", opts.logging.dir);
|
||||
assert_eq!("debug", opts.logging.level.unwrap());
|
||||
assert!(opts.mysql.tls.watch);
|
||||
assert!(opts.postgres.tls.watch);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -52,7 +52,6 @@ fn test_load_datanode_example_config() {
|
||||
meta_client: Some(MetaClientOptions {
|
||||
metasrv_addrs: vec!["127.0.0.1:3002".to_string()],
|
||||
timeout: Duration::from_secs(3),
|
||||
heartbeat_timeout: Duration::from_millis(500),
|
||||
ddl_timeout: Duration::from_secs(10),
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
tcp_nodelay: true,
|
||||
@@ -118,7 +117,6 @@ fn test_load_frontend_example_config() {
|
||||
meta_client: Some(MetaClientOptions {
|
||||
metasrv_addrs: vec!["127.0.0.1:3002".to_string()],
|
||||
timeout: Duration::from_secs(3),
|
||||
heartbeat_timeout: Duration::from_millis(500),
|
||||
ddl_timeout: Duration::from_secs(10),
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
tcp_nodelay: true,
|
||||
@@ -241,7 +239,6 @@ fn test_load_flownode_example_config() {
|
||||
meta_client: Some(MetaClientOptions {
|
||||
metasrv_addrs: vec!["127.0.0.1:3002".to_string()],
|
||||
timeout: Duration::from_secs(3),
|
||||
heartbeat_timeout: Duration::from_millis(500),
|
||||
ddl_timeout: Duration::from_secs(10),
|
||||
connect_timeout: Duration::from_secs(1),
|
||||
tcp_nodelay: true,
|
||||
|
||||
@@ -32,7 +32,12 @@ impl Plugins {

pub fn insert<T: 'static + Send + Sync>(&self, value: T) {
let last = self.write().insert(value);
assert!(last.is_none(), "each type of plugins must be one and only");
if last.is_some() {
panic!(
"Plugin of type {} already exists",
std::any::type_name::<T>()
);
}
}

pub fn get<T: 'static + Send + Sync + Clone>(&self) -> Option<T> {
@@ -140,7 +145,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "each type of plugins must be one and only")]
#[should_panic(expected = "Plugin of type i32 already exists")]
fn test_plugin_uniqueness() {
let plugins = Plugins::new();
plugins.insert(1i32);

@@ -11,8 +11,10 @@ workspace = true
common-base.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-telemetry.workspace = true
config.workspace = true
humantime-serde.workspace = true
notify.workspace = true
object-store.workspace = true
serde.workspace = true
serde_json.workspace = true

@@ -49,14 +49,31 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Failed to watch file: {}", path))]
FileWatch {
path: String,
#[snafu(source)]
error: notify::Error,
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Invalid path '{}': expected a file, not a directory", path))]
InvalidPath {
path: String,
#[snafu(implicit)]
location: Location,
},
}

impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::TomlFormat { .. } | Error::LoadLayeredConfig { .. } => {
StatusCode::InvalidArguments
}
Error::TomlFormat { .. }
| Error::LoadLayeredConfig { .. }
| Error::FileWatch { .. }
| Error::InvalidPath { .. } => StatusCode::InvalidArguments,
Error::SerdeJson { .. } => StatusCode::Unexpected,
}
}

src/common/config/src/file_watcher.rs (new file, 277 lines)
@@ -0,0 +1,277 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Common file watching utilities for configuration hot-reloading.
|
||||
//!
|
||||
//! This module provides a generic file watcher that can be used to watch
|
||||
//! files for changes and trigger callbacks when changes occur.
|
||||
//!
|
||||
//! The watcher monitors the parent directory of each file rather than the
|
||||
//! file itself. This ensures that file deletions and recreations are properly
|
||||
//! tracked, which is common with editors that use atomic saves or when
|
||||
//! configuration files are replaced.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::mpsc::channel;
|
||||
|
||||
use common_telemetry::{error, info, warn};
|
||||
use notify::{EventKind, RecursiveMode, Watcher};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{FileWatchSnafu, InvalidPathSnafu, Result};
|
||||
|
||||
/// Configuration for the file watcher behavior.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FileWatcherConfig {
|
||||
/// Whether to include Remove events in addition to Modify and Create.
|
||||
pub include_remove_events: bool,
|
||||
}
|
||||
|
||||
impl FileWatcherConfig {
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
pub fn include_remove_events(mut self) -> Self {
|
||||
self.include_remove_events = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A builder for creating file watchers with flexible configuration.
|
||||
///
|
||||
/// The watcher monitors the parent directory of each file to handle file
|
||||
/// deletion and recreation properly. Events are filtered to only trigger
|
||||
/// callbacks for the specific files being watched.
|
||||
pub struct FileWatcherBuilder {
|
||||
config: FileWatcherConfig,
|
||||
/// Canonicalized paths of files to watch.
|
||||
file_paths: Vec<PathBuf>,
|
||||
}
|
||||
|
||||
impl FileWatcherBuilder {
|
||||
/// Create a new builder with default configuration.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: FileWatcherConfig::default(),
|
||||
file_paths: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the watcher configuration.
|
||||
pub fn config(mut self, config: FileWatcherConfig) -> Self {
|
||||
self.config = config;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a file path to watch.
|
||||
///
|
||||
/// Returns an error if the path is a directory.
|
||||
/// The path is canonicalized for reliable comparison with events.
|
||||
pub fn watch_path<P: AsRef<Path>>(mut self, path: P) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
snafu::ensure!(
|
||||
path.is_file(),
|
||||
InvalidPathSnafu {
|
||||
path: path.display().to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
self.file_paths.push(path.to_path_buf());
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Add multiple file paths to watch.
|
||||
///
|
||||
/// Returns an error if any path is a directory.
|
||||
pub fn watch_paths<P: AsRef<Path>, I: IntoIterator<Item = P>>(
|
||||
mut self,
|
||||
paths: I,
|
||||
) -> Result<Self> {
|
||||
for path in paths {
|
||||
self = self.watch_path(path)?;
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Build and spawn the file watcher with the given callback.
|
||||
///
|
||||
/// The callback is invoked when relevant file events are detected for
|
||||
/// the watched files. The watcher monitors the parent directories to
|
||||
/// handle file deletion and recreation properly.
|
||||
///
|
||||
/// The spawned watcher thread runs for the lifetime of the process.
|
||||
pub fn spawn<F>(self, callback: F) -> Result<()>
|
||||
where
|
||||
F: Fn() + Send + 'static,
|
||||
{
|
||||
let (tx, rx) = channel::<notify::Result<notify::Event>>();
|
||||
let mut watcher =
|
||||
notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;
|
||||
|
||||
// Collect unique parent directories to watch
|
||||
let mut watched_dirs: HashSet<PathBuf> = HashSet::new();
|
||||
for file_path in &self.file_paths {
|
||||
if let Some(parent) = file_path.parent()
|
||||
&& watched_dirs.insert(parent.to_path_buf())
|
||||
{
|
||||
watcher
|
||||
.watch(parent, RecursiveMode::NonRecursive)
|
||||
.context(FileWatchSnafu {
|
||||
path: parent.display().to_string(),
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let config = self.config;
|
||||
|
||||
info!(
|
||||
"Spawning file watcher for paths: {:?} (watching parent directories)",
|
||||
self.file_paths
|
||||
.iter()
|
||||
.map(|p| p.display().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
std::thread::spawn(move || {
|
||||
// Keep watcher alive in the thread
|
||||
let _watcher = watcher;
|
||||
|
||||
while let Ok(res) = rx.recv() {
|
||||
match res {
|
||||
Ok(event) => {
|
||||
if !is_relevant_event(&event.kind, &config) {
|
||||
continue;
|
||||
}
|
||||
|
||||
info!(?event.kind, ?event.paths, "Detected folder change");
|
||||
callback();
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("File watcher error: {}", err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
error!("File watcher channel closed unexpectedly");
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FileWatcherBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an event kind is relevant based on the configuration.
|
||||
fn is_relevant_event(kind: &EventKind, config: &FileWatcherConfig) -> bool {
|
||||
match kind {
|
||||
EventKind::Modify(_) | EventKind::Create(_) => true,
|
||||
EventKind::Remove(_) => config.include_remove_events,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
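Below is a minimal usage sketch of the builder described above; the file names and the reload closure are hypothetical, and error handling is elided:

// Hypothetical example: trigger a reload whenever the watched certificate files change.
fn watch_tls_files(cert_path: &str, key_path: &str) -> Result<()> {
    FileWatcherBuilder::new()
        .watch_paths([cert_path, key_path])?
        .config(FileWatcherConfig::new().include_remove_events())
        .spawn(|| {
            // Invoked on Modify/Create events (and Remove, since it is enabled above).
            info!("watched TLS files changed, reloading");
        })
}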
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use common_test_util::temp_dir::create_temp_dir;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_file_watcher_detects_changes() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let dir = create_temp_dir("test_file_watcher");
|
||||
let file_path = dir.path().join("test_file.txt");
|
||||
|
||||
// Create initial file
|
||||
std::fs::write(&file_path, "initial content").unwrap();
|
||||
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let counter_clone = counter.clone();
|
||||
|
||||
FileWatcherBuilder::new()
|
||||
.watch_path(&file_path)
|
||||
.unwrap()
|
||||
.config(FileWatcherConfig::new())
|
||||
.spawn(move || {
|
||||
counter_clone.fetch_add(1, Ordering::SeqCst);
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Give watcher time to start
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
|
||||
// Modify the file
|
||||
std::fs::write(&file_path, "modified content").unwrap();
|
||||
|
||||
// Wait for the event to be processed
|
||||
std::thread::sleep(Duration::from_millis(500));
|
||||
|
||||
assert!(
|
||||
counter.load(Ordering::SeqCst) >= 1,
|
||||
"Watcher should have detected at least one change"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_file_watcher_detects_delete_and_recreate() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let dir = create_temp_dir("test_file_watcher_recreate");
|
||||
let file_path = dir.path().join("test_file.txt");
|
||||
|
||||
// Create initial file
|
||||
std::fs::write(&file_path, "initial content").unwrap();
|
||||
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let counter_clone = counter.clone();
|
||||
|
||||
FileWatcherBuilder::new()
|
||||
.watch_path(&file_path)
|
||||
.unwrap()
|
||||
.config(FileWatcherConfig::new())
|
||||
.spawn(move || {
|
||||
counter_clone.fetch_add(1, Ordering::SeqCst);
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// Give watcher time to start
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
|
||||
// Delete the file
|
||||
std::fs::remove_file(&file_path).unwrap();
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
|
||||
// Recreate the file - this should still be detected because we watch the directory
|
||||
std::fs::write(&file_path, "recreated content").unwrap();
|
||||
|
||||
// Wait for the event to be processed
|
||||
std::thread::sleep(Duration::from_millis(500));
|
||||
|
||||
assert!(
|
||||
counter.load(Ordering::SeqCst) >= 1,
|
||||
"Watcher should have detected file recreation"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod file_watcher;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ const SECRET_ACCESS_KEY: &str = "secret_access_key";
const SESSION_TOKEN: &str = "session_token";
const REGION: &str = "region";
const ENABLE_VIRTUAL_HOST_STYLE: &str = "enable_virtual_host_style";
const DISABLE_EC2_METADATA: &str = "disable_ec2_metadata";

pub fn is_supported_in_s3(key: &str) -> bool {
[
@@ -36,6 +37,7 @@ pub fn is_supported_in_s3(key: &str) -> bool {
SESSION_TOKEN,
REGION,
ENABLE_VIRTUAL_HOST_STYLE,
DISABLE_EC2_METADATA,
]
.contains(&key)
}
@@ -82,6 +84,21 @@ pub fn build_s3_backend(
}
}

if let Some(disable_str) = connection.get(DISABLE_EC2_METADATA) {
let disable = disable_str.as_str().parse::<bool>().map_err(|e| {
error::InvalidConnectionSnafu {
msg: format!(
"failed to parse the option {}={}, {}",
DISABLE_EC2_METADATA, disable_str, e
),
}
.build()
})?;
if disable {
builder = builder.disable_ec2_metadata();
}
}

// TODO(weny): Consider finding a better way to eliminate duplicate code.
Ok(ObjectStore::new(builder)
.context(error::BuildBackendSnafu)?
@@ -109,6 +126,7 @@ mod tests {
assert!(is_supported_in_s3(SESSION_TOKEN));
assert!(is_supported_in_s3(REGION));
assert!(is_supported_in_s3(ENABLE_VIRTUAL_HOST_STYLE));
assert!(is_supported_in_s3(DISABLE_EC2_METADATA));
assert!(!is_supported_in_s3("foo"))
}
}

@@ -21,6 +21,8 @@ pub mod status_code;
use http::{HeaderMap, HeaderValue};
pub use snafu;

use crate::status_code::StatusCode;

// HACK - these headers are here for shared in gRPC services. For common HTTP headers,
// please define in `src/servers/src/http/header.rs`.
pub const GREPTIME_DB_HEADER_ERROR_CODE: &str = "x-greptime-err-code";
@@ -46,6 +48,29 @@ pub fn from_err_code_msg_to_header(code: u32, msg: &str) -> HeaderMap {
header
}

/// Extract [StatusCode] and error message from [HeaderMap], if any.
///
/// Note that if the [StatusCode] is illegal, for example, a random number that is not pre-defined
/// as a [StatusCode], the result is still `None`.
pub fn from_header_to_err_code_msg(headers: &HeaderMap) -> Option<(StatusCode, &str)> {
let code = headers
.get(GREPTIME_DB_HEADER_ERROR_CODE)
.and_then(|value| {
value
.to_str()
.ok()
.and_then(|x| x.parse::<u32>().ok())
.and_then(StatusCode::from_u32)
});
let msg = headers
.get(GREPTIME_DB_HEADER_ERROR_MSG)
.and_then(|x| x.to_str().ok());
match (code, msg) {
(Some(code), Some(msg)) => Some((code, msg)),
_ => None,
}
}

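A short sketch of how the new helper might be used on the receiving side; the function name is illustrative only:

// Hypothetical caller: turn the error headers back into a readable message, if present.
fn describe_remote_error(headers: &HeaderMap) -> Option<String> {
    // Yields `None` when either header is missing or the code is not a known StatusCode.
    from_header_to_err_code_msg(headers).map(|(code, msg)| format!("{code:?}: {msg}"))
}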
/// Returns the external root cause of the source error (exclude the current error).
|
||||
pub fn root_source(err: &dyn std::error::Error) -> Option<&dyn std::error::Error> {
|
||||
// There are some divergence about the behavior of the `sources()` API
|
||||
|
||||
@@ -42,6 +42,8 @@ pub enum StatusCode {
External = 1007,
/// The request is deadline exceeded (typically server-side).
DeadlineExceeded = 1008,
/// Service got suspended for various reason. For example, resources exceed limit.
Suspended = 1009,
// ====== End of common status code ================

// ====== Begin of SQL related status code =========
@@ -175,7 +177,8 @@ impl StatusCode {
| StatusCode::AccessDenied
| StatusCode::PermissionDenied
| StatusCode::RequestOutdated
| StatusCode::External => false,
| StatusCode::External
| StatusCode::Suspended => false,
}
}

@@ -223,7 +226,8 @@ impl StatusCode {
| StatusCode::InvalidAuthHeader
| StatusCode::AccessDenied
| StatusCode::PermissionDenied
| StatusCode::RequestOutdated => false,
| StatusCode::RequestOutdated
| StatusCode::Suspended => false,
}
}

@@ -347,7 +351,8 @@ pub fn status_to_tonic_code(status_code: StatusCode) -> Code {
| StatusCode::RegionNotReady => Code::Unavailable,
StatusCode::RuntimeResourcesExhausted
| StatusCode::RateLimited
| StatusCode::RegionBusy => Code::ResourceExhausted,
| StatusCode::RegionBusy
| StatusCode::Suspended => Code::ResourceExhausted,
StatusCode::UnsupportedPasswordType
| StatusCode::UserPasswordMismatch
| StatusCode::AuthHeaderNotFound

@@ -19,7 +19,7 @@ arc-swap = "1.0"
arrow.workspace = true
arrow-schema.workspace = true
async-trait.workspace = true
bincode = "1.3"
bincode = "=1.3.3"
catalog.workspace = true
chrono.workspace = true
common-base.workspace = true
@@ -39,7 +39,7 @@ datafusion-functions-aggregate-common.workspace = true
datafusion-pg-catalog.workspace = true
datafusion-physical-expr.workspace = true
datatypes.workspace = true
derive_more = { version = "1", default-features = false, features = ["display"] }
derive_more.workspace = true
geo = { version = "0.29", optional = true }
geo-types = { version = "0.7", optional = true }
geohash = { version = "0.13", optional = true }

@@ -14,6 +14,7 @@
|
||||
|
||||
mod binary;
|
||||
mod ctx;
|
||||
mod if_func;
|
||||
mod is_null;
|
||||
mod unary;
|
||||
|
||||
@@ -22,6 +23,7 @@ pub use ctx::EvalContext;
|
||||
pub use unary::scalar_unary_op;
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
use crate::scalars::expression::if_func::IfFunction;
|
||||
use crate::scalars::expression::is_null::IsNullFunction;
|
||||
|
||||
pub(crate) struct ExpressionFunction;
|
||||
@@ -29,5 +31,6 @@ pub(crate) struct ExpressionFunction;
|
||||
impl ExpressionFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(IsNullFunction::default());
|
||||
registry.register_scalar(IfFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
src/common/function/src/scalars/expression/if_func.rs (new file, 404 lines)
@@ -0,0 +1,404 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt;
|
||||
use std::fmt::Display;
|
||||
|
||||
use arrow::array::ArrowNativeTypeOp;
|
||||
use arrow::datatypes::ArrowPrimitiveType;
|
||||
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray, PrimitiveArray};
|
||||
use datafusion::arrow::compute::kernels::zip::zip;
|
||||
use datafusion::arrow::datatypes::DataType;
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_expr::type_coercion::binary::comparison_coercion;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
|
||||
const NAME: &str = "if";
|
||||
|
||||
/// MySQL-compatible IF function: IF(condition, true_value, false_value)
|
||||
///
|
||||
/// Returns true_value if condition is TRUE (not NULL and not 0),
|
||||
/// otherwise returns false_value.
|
||||
///
|
||||
/// MySQL truthy rules:
|
||||
/// - NULL -> false
|
||||
/// - 0 (numeric zero) -> false
|
||||
/// - Any non-zero numeric -> true
|
||||
/// - Boolean true/false -> use directly
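///
/// For example, IF(1, 'yes', 'no') returns 'yes', IF(0, 'yes', 'no') returns 'no',
/// and IF(NULL, 'yes', 'no') returns 'no' because NULL is treated as false.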
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct IfFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl Default for IfFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
signature: Signature::any(3, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for IfFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for IfFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, input_types: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
// Return the common type of true_value and false_value (args[1] and args[2])
|
||||
if input_types.len() < 3 {
|
||||
return Err(DataFusionError::Plan(format!(
|
||||
"{} requires 3 arguments, got {}",
|
||||
NAME,
|
||||
input_types.len()
|
||||
)));
|
||||
}
|
||||
let true_type = &input_types[1];
|
||||
let false_type = &input_types[2];
|
||||
|
||||
// Use comparison_coercion to find common type
|
||||
comparison_coercion(true_type, false_type).ok_or_else(|| {
|
||||
DataFusionError::Plan(format!(
|
||||
"Cannot find common type for IF function between {:?} and {:?}",
|
||||
true_type, false_type
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 3 {
|
||||
return Err(DataFusionError::Plan(format!(
|
||||
"{} requires exactly 3 arguments, got {}",
|
||||
NAME,
|
||||
args.args.len()
|
||||
)));
|
||||
}
|
||||
|
||||
let condition = &args.args[0];
|
||||
let true_value = &args.args[1];
|
||||
let false_value = &args.args[2];
|
||||
|
||||
// Convert condition to boolean array using MySQL truthy rules
|
||||
let bool_array = to_boolean_array(condition, args.number_rows)?;
|
||||
|
||||
// Convert true and false values to arrays
|
||||
let true_array = true_value.to_array(args.number_rows)?;
|
||||
let false_array = false_value.to_array(args.number_rows)?;
|
||||
|
||||
// Use zip to select values based on condition
|
||||
// zip expects &dyn Datum, and ArrayRef (Arc<dyn Array>) implements Datum
|
||||
let result = zip(&bool_array, &true_array, &false_array)?;
|
||||
Ok(ColumnarValue::Array(result))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a ColumnarValue to a BooleanArray using MySQL truthy rules:
|
||||
/// - NULL -> false
|
||||
/// - 0 (any numeric zero) -> false
|
||||
/// - Non-zero numeric -> true
|
||||
/// - Boolean -> use directly
|
||||
fn to_boolean_array(
|
||||
value: &ColumnarValue,
|
||||
num_rows: usize,
|
||||
) -> datafusion_common::Result<BooleanArray> {
|
||||
let array = value.to_array(num_rows)?;
|
||||
array_to_bool(array)
|
||||
}
|
||||
|
||||
/// Convert an integer PrimitiveArray to BooleanArray using MySQL truthy rules:
|
||||
/// NULL -> false, 0 -> false, non-zero -> true
|
||||
fn int_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
|
||||
where
|
||||
T: ArrowPrimitiveType,
|
||||
T::Native: ArrowNativeTypeOp,
|
||||
{
|
||||
BooleanArray::from_iter(
|
||||
array
|
||||
.iter()
|
||||
.map(|opt| Some(opt.is_some_and(|v| !v.is_zero()))),
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert a float PrimitiveArray to BooleanArray using MySQL truthy rules:
|
||||
/// NULL -> false, 0 (including -0.0) -> false, NaN -> true, other non-zero -> true
|
||||
fn float_array_to_bool<T>(array: &PrimitiveArray<T>) -> BooleanArray
|
||||
where
|
||||
T: ArrowPrimitiveType,
|
||||
T::Native: ArrowNativeTypeOp + num_traits::Float,
|
||||
{
|
||||
use num_traits::Float;
|
||||
BooleanArray::from_iter(
|
||||
array
|
||||
.iter()
|
||||
.map(|opt| Some(opt.is_some_and(|v| v.is_nan() || !v.is_zero()))),
|
||||
)
|
||||
}
|
||||
|
||||
/// Convert an Array to BooleanArray using MySQL truthy rules
|
||||
fn array_to_bool(array: ArrayRef) -> datafusion_common::Result<BooleanArray> {
|
||||
use arrow::datatypes::*;
|
||||
|
||||
match array.data_type() {
|
||||
DataType::Boolean => {
|
||||
let bool_array = array.as_boolean();
|
||||
Ok(BooleanArray::from_iter(
|
||||
bool_array.iter().map(|opt| Some(opt.unwrap_or(false))),
|
||||
))
|
||||
}
|
||||
DataType::Int8 => Ok(int_array_to_bool(array.as_primitive::<Int8Type>())),
|
||||
DataType::Int16 => Ok(int_array_to_bool(array.as_primitive::<Int16Type>())),
|
||||
DataType::Int32 => Ok(int_array_to_bool(array.as_primitive::<Int32Type>())),
|
||||
DataType::Int64 => Ok(int_array_to_bool(array.as_primitive::<Int64Type>())),
|
||||
DataType::UInt8 => Ok(int_array_to_bool(array.as_primitive::<UInt8Type>())),
|
||||
DataType::UInt16 => Ok(int_array_to_bool(array.as_primitive::<UInt16Type>())),
|
||||
DataType::UInt32 => Ok(int_array_to_bool(array.as_primitive::<UInt32Type>())),
|
||||
DataType::UInt64 => Ok(int_array_to_bool(array.as_primitive::<UInt64Type>())),
|
||||
// Float16 needs special handling since half::f16 doesn't implement num_traits::Float
|
||||
DataType::Float16 => {
|
||||
let typed_array = array.as_primitive::<Float16Type>();
|
||||
Ok(BooleanArray::from_iter(typed_array.iter().map(|opt| {
|
||||
Some(opt.is_some_and(|v| {
|
||||
let f = v.to_f32();
|
||||
f.is_nan() || !f.is_zero()
|
||||
}))
|
||||
})))
|
||||
}
|
||||
DataType::Float32 => Ok(float_array_to_bool(array.as_primitive::<Float32Type>())),
|
||||
DataType::Float64 => Ok(float_array_to_bool(array.as_primitive::<Float64Type>())),
|
||||
// Null type is always false.
|
||||
// Note: NullArray::is_null() returns false (physical null), so we must handle it explicitly.
|
||||
// See: https://github.com/apache/arrow-rs/issues/4840
|
||||
DataType::Null => Ok(BooleanArray::from(vec![false; array.len()])),
|
||||
// For other types, treat non-null as true
|
||||
_ => {
|
||||
let len = array.len();
|
||||
Ok(BooleanArray::from_iter(
|
||||
(0..len).map(|i| Some(!array.is_null(i))),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Field;
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::arrow::array::{AsArray, Int32Array, StringArray};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_if_function_basic() {
|
||||
let if_func = IfFunction::default();
|
||||
assert_eq!("if", if_func.name());
|
||||
|
||||
// Test IF(true, 'yes', 'no') -> 'yes'
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "yes");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_function_false() {
|
||||
let if_func = IfFunction::default();
|
||||
|
||||
// Test IF(false, 'yes', 'no') -> 'no'
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "no");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_function_null_is_false() {
|
||||
let if_func = IfFunction::default();
|
||||
|
||||
// Test IF(NULL, 'yes', 'no') -> 'no' (NULL is treated as false)
|
||||
// Using Boolean(None) - typed null
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Boolean(None)),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "no");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
|
||||
// Test IF(NULL, 'yes', 'no') -> 'no' using ScalarValue::Null (untyped null from SQL NULL literal)
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Null),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "no");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_function_numeric_truthy() {
|
||||
let if_func = IfFunction::default();
|
||||
|
||||
// Test IF(1, 'yes', 'no') -> 'yes' (non-zero is true)
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "yes");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
|
||||
// Test IF(0, 'yes', 'no') -> 'no' (zero is false)
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Scalar(ScalarValue::Int32(Some(0))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("yes".to_string()))),
|
||||
ColumnarValue::Scalar(ScalarValue::Utf8(Some("no".to_string()))),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 1,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "no");
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_function_with_arrays() {
|
||||
let if_func = IfFunction::default();
|
||||
|
||||
// Test with array condition
|
||||
let condition = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
|
||||
let true_val = StringArray::from(vec!["yes", "yes", "yes", "yes"]);
|
||||
let false_val = StringArray::from(vec!["no", "no", "no", "no"]);
|
||||
|
||||
let result = if_func
|
||||
.invoke_with_args(ScalarFunctionArgs {
|
||||
args: vec![
|
||||
ColumnarValue::Array(Arc::new(condition)),
|
||||
ColumnarValue::Array(Arc::new(true_val)),
|
||||
ColumnarValue::Array(Arc::new(false_val)),
|
||||
],
|
||||
arg_fields: vec![],
|
||||
number_rows: 4,
|
||||
return_field: Arc::new(Field::new("", DataType::Utf8, true)),
|
||||
config_options: Arc::new(Default::default()),
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
if let ColumnarValue::Array(arr) = result {
|
||||
let str_arr = arr.as_string::<i32>();
|
||||
assert_eq!(str_arr.value(0), "yes"); // 1 is true
|
||||
assert_eq!(str_arr.value(1), "no"); // 0 is false
|
||||
assert_eq!(str_arr.value(2), "no"); // NULL is false
|
||||
assert_eq!(str_arr.value(3), "yes"); // 5 is true
|
||||
} else {
|
||||
panic!("Expected Array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fmt::Display;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Array, AsArray, BooleanBuilder};
|
||||
|
||||
@@ -14,13 +14,31 @@
|
||||
|
||||
//! String scalar functions
|
||||
|
||||
mod elt;
|
||||
mod field;
|
||||
mod format;
|
||||
mod insert;
|
||||
mod locate;
|
||||
mod regexp_extract;
|
||||
mod space;
|
||||
|
||||
pub(crate) use elt::EltFunction;
|
||||
pub(crate) use field::FieldFunction;
|
||||
pub(crate) use format::FormatFunction;
|
||||
pub(crate) use insert::InsertFunction;
|
||||
pub(crate) use locate::LocateFunction;
|
||||
pub(crate) use regexp_extract::RegexpExtractFunction;
|
||||
pub(crate) use space::SpaceFunction;
|
||||
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
/// Register all string functions
|
||||
pub fn register_string_functions(registry: &FunctionRegistry) {
|
||||
EltFunction::register(registry);
|
||||
FieldFunction::register(registry);
|
||||
FormatFunction::register(registry);
|
||||
InsertFunction::register(registry);
|
||||
LocateFunction::register(registry);
|
||||
RegexpExtractFunction::register(registry);
|
||||
SpaceFunction::register(registry);
|
||||
}
|
||||
|
||||
252
src/common/function/src/scalars/string/elt.rs
Normal file
252
src/common/function/src/scalars/string/elt.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible ELT function implementation.
|
||||
//!
|
||||
//! ELT(N, str1, str2, str3, ...) - Returns the Nth string from the list.
|
||||
//! Returns NULL if N < 1 or N > number of strings.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "elt";
|
||||
|
||||
/// MySQL-compatible ELT function.
|
||||
///
|
||||
/// Syntax: ELT(N, str1, str2, str3, ...)
|
||||
/// Returns the Nth string argument. N is 1-based.
|
||||
/// Returns NULL if N is NULL, N < 1, or N > number of string arguments.
|
||||
#[derive(Debug)]
|
||||
pub struct EltFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl EltFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(EltFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EltFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// ELT takes a variable number of arguments: (Int64, String, String, ...)
|
||||
signature: Signature::variadic_any(Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for EltFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for EltFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() < 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"ELT requires at least 2 arguments: ELT(N, str1, ...)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
let num_strings = arrays.len() - 1;
|
||||
|
||||
// First argument is the index (N) - try to cast to Int64
|
||||
let index_array = if arrays[0].data_type() == &DataType::Null {
|
||||
// All NULLs - return all NULLs
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, 0);
|
||||
for _ in 0..len {
|
||||
builder.append_null();
|
||||
}
|
||||
return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
|
||||
} else {
|
||||
cast(arrays[0].as_ref(), &DataType::Int64).map_err(|e| {
|
||||
DataFusionError::Execution(format!("ELT: index argument cast failed: {}", e))
|
||||
})?
|
||||
};
|
||||
|
||||
// Cast string arguments to LargeUtf8
|
||||
let string_arrays: Vec<ArrayRef> = arrays[1..]
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
|
||||
DataFusionError::Execution(format!(
|
||||
"ELT: string argument {} cast failed: {}",
|
||||
i + 1,
|
||||
e
|
||||
))
|
||||
})
|
||||
})
|
||||
.collect::<datafusion_common::Result<Vec<_>>>()?;
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
|
||||
|
||||
for i in 0..len {
|
||||
if index_array.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let n = index_array
|
||||
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
|
||||
.value(i);
|
||||
|
||||
// N is 1-based, check bounds
|
||||
if n < 1 || n as usize > num_strings {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let str_idx = (n - 1) as usize;
|
||||
let str_array = string_arrays[str_idx].as_string::<i64>();
|
||||
|
||||
if str_array.is_null(i) {
|
||||
builder.append_null();
|
||||
} else {
|
||||
builder.append_value(str_array.value(i));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Int64Array, StringArray};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_basic() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![1, 2, 3]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "a");
|
||||
assert_eq!(str_array.value(1), "b");
|
||||
assert_eq!(str_array.value(2), "c");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_out_of_bounds() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
let n = Arc::new(Int64Array::from(vec![0, 4, -1]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert!(str_array.is_null(0)); // 0 is out of bounds
|
||||
assert!(str_array.is_null(1)); // 4 is out of bounds
|
||||
assert!(str_array.is_null(2)); // -1 is out of bounds
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_elt_with_nulls() {
|
||||
let function = EltFunction::default();
|
||||
|
||||
// Row 0: n=1, select s1="a" -> "a"
|
||||
// Row 1: n=NULL -> NULL
|
||||
// Row 2: n=1, select s1=NULL -> NULL
|
||||
let n = Arc::new(Int64Array::from(vec![Some(1), None, Some(1)]));
|
||||
let s1 = Arc::new(StringArray::from(vec![Some("a"), Some("a"), None]));
|
||||
let s2 = Arc::new(StringArray::from(vec![Some("b"), Some("b"), Some("b")]));
|
||||
|
||||
let args = create_args(vec![n, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "a");
|
||||
assert!(str_array.is_null(1)); // N is NULL
|
||||
assert!(str_array.is_null(2)); // Selected string is NULL
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
224
src/common/function/src/scalars/string/field.rs
Normal file
224
src/common/function/src/scalars/string/field.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible FIELD function implementation.
|
||||
//!
|
||||
//! FIELD(str, str1, str2, str3, ...) - Returns the 1-based index of str in the list.
|
||||
//! Returns 0 if str is not found or is NULL.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "field";
|
||||
|
||||
/// MySQL-compatible FIELD function.
|
||||
///
|
||||
/// Syntax: FIELD(str, str1, str2, str3, ...)
|
||||
/// Returns the 1-based index of str in the argument list (str1, str2, str3, ...).
|
||||
/// Returns 0 if str is not found or is NULL.
|
||||
#[derive(Debug)]
|
||||
pub struct FieldFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl FieldFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(FieldFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FieldFunction {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
// FIELD takes a variable number of arguments: (String, String, String, ...)
|
||||
signature: Signature::variadic_any(Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FieldFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for FieldFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::Int64)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() < 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"FIELD requires at least 2 arguments: FIELD(str, str1, ...)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
// Cast all arguments to LargeUtf8
|
||||
let string_arrays: Vec<ArrayRef> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
cast(arr.as_ref(), &DataType::LargeUtf8).map_err(|e| {
|
||||
DataFusionError::Execution(format!("FIELD: argument {} cast failed: {}", i, e))
|
||||
})
|
||||
})
|
||||
.collect::<datafusion_common::Result<Vec<_>>>()?;
|
||||
|
||||
let search_str = string_arrays[0].as_string::<i64>();
|
||||
let mut builder = Int64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
// If search string is NULL, return 0
|
||||
if search_str.is_null(i) {
|
||||
builder.append_value(0);
|
||||
continue;
|
||||
}
|
||||
|
||||
let needle = search_str.value(i);
|
||||
let mut found_idx = 0i64;
|
||||
|
||||
// Search through the list (starting from index 1 in string_arrays)
|
||||
for (j, str_arr) in string_arrays[1..].iter().enumerate() {
|
||||
let str_array = str_arr.as_string::<i64>();
|
||||
if !str_array.is_null(i) && str_array.value(i) == needle {
|
||||
found_idx = (j + 1) as i64; // 1-based index
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
builder.append_value(found_idx);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::StringArray;
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::Int64, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_basic() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec!["b", "d", "a"]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b", "b"]));
|
||||
let s3 = Arc::new(StringArray::from(vec!["c", "c", "c"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2, s3]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 2); // "b" is at index 2
|
||||
assert_eq!(int_array.value(1), 0); // "d" not found
|
||||
assert_eq!(int_array.value(2), 1); // "a" is at index 1
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_with_null_search() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec![Some("a"), None]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["b", "b"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 1); // "a" found at index 1
|
||||
assert_eq!(int_array.value(1), 0); // NULL search returns 0
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_field_case_sensitive() {
|
||||
let function = FieldFunction::default();
|
||||
|
||||
let search = Arc::new(StringArray::from(vec!["A", "a"]));
|
||||
let s1 = Arc::new(StringArray::from(vec!["a", "a"]));
|
||||
let s2 = Arc::new(StringArray::from(vec!["A", "A"]));
|
||||
|
||||
let args = create_args(vec![search, s1, s2]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
assert_eq!(int_array.value(0), 2); // "A" matches at index 2
|
||||
assert_eq!(int_array.value(1), 1); // "a" matches at index 1
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
512
src/common/function/src/scalars/string/format.rs
Normal file
512
src/common/function/src/scalars/string/format.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible FORMAT function implementation.
|
||||
//!
|
||||
//! FORMAT(X, D) - Formats the number X with D decimal places using thousand separators.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::datatypes as arrow_types;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "format";
|
||||
|
||||
/// MySQL-compatible FORMAT function.
|
||||
///
|
||||
/// Syntax: FORMAT(X, D)
|
||||
/// Formats the number X to a format like '#,###,###.##', rounded to D decimal places.
|
||||
/// D can be 0 to 30.
|
||||
///
|
||||
/// Note: This implementation uses the en_US locale (comma as thousand separator,
|
||||
/// period as decimal separator).
|
||||
#[derive(Debug)]
|
||||
pub struct FormatFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl FormatFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(FormatFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FormatFunction {
|
||||
fn default() -> Self {
|
||||
let mut signatures = Vec::new();
|
||||
|
||||
// Support various numeric types for X
|
||||
let numeric_types = [
|
||||
DataType::Float64,
|
||||
DataType::Float32,
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
// D can be various integer types
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
for x_type in &numeric_types {
|
||||
for d_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![x_type.clone(), d_type.clone()]));
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FormatFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for FormatFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 2 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"FORMAT requires exactly 2 arguments: FORMAT(X, D)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
let x_array = &arrays[0];
|
||||
let d_array = &arrays[1];
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 20);
|
||||
|
||||
for i in 0..len {
|
||||
if x_array.is_null(i) || d_array.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let decimal_places = get_decimal_places(d_array, i)?.clamp(0, 30) as usize;
|
||||
|
||||
let formatted = match x_array.data_type() {
|
||||
DataType::Float64 | DataType::Float32 => {
|
||||
format_number_float(get_float_value(x_array, i)?, decimal_places)
|
||||
}
|
||||
DataType::Int64
|
||||
| DataType::Int32
|
||||
| DataType::Int16
|
||||
| DataType::Int8
|
||||
| DataType::UInt64
|
||||
| DataType::UInt32
|
||||
| DataType::UInt16
|
||||
| DataType::UInt8 => format_number_integer(x_array, i, decimal_places)?,
|
||||
_ => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
x_array.data_type()
|
||||
)));
|
||||
}
|
||||
};
|
||||
builder.append_value(&formatted);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Get float value from various numeric types.
|
||||
fn get_float_value(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
) -> datafusion_common::Result<f64> {
|
||||
match array.data_type() {
|
||||
DataType::Float64 => Ok(array
|
||||
.as_primitive::<arrow_types::Float64Type>()
|
||||
.value(index)),
|
||||
DataType::Float32 => Ok(array
|
||||
.as_primitive::<arrow_types::Float32Type>()
|
||||
.value(index) as f64),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get decimal places from various integer types.
|
||||
///
|
||||
/// MySQL clamps decimal places to `0..=30`. This function returns an `i64` so the caller can clamp.
|
||||
fn get_decimal_places(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
) -> datafusion_common::Result<i64> {
|
||||
match array.data_type() {
|
||||
DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
|
||||
DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
|
||||
DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
|
||||
DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
|
||||
DataType::UInt64 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
|
||||
Ok(if v > i64::MAX as u64 {
|
||||
i64::MAX
|
||||
} else {
|
||||
v as i64
|
||||
})
|
||||
}
|
||||
DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
|
||||
DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
|
||||
DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
|
||||
_ => Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_number_integer(
|
||||
array: &datafusion_common::arrow::array::ArrayRef,
|
||||
index: usize,
|
||||
decimal_places: usize,
|
||||
) -> datafusion_common::Result<String> {
|
||||
let (is_negative, abs_digits) = match array.data_type() {
|
||||
DataType::Int64 => {
|
||||
let v = array.as_primitive::<arrow_types::Int64Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int32 => {
|
||||
let v = array.as_primitive::<arrow_types::Int32Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int16 => {
|
||||
let v = array.as_primitive::<arrow_types::Int16Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::Int8 => {
|
||||
let v = array.as_primitive::<arrow_types::Int8Type>().value(index) as i128;
|
||||
(v.is_negative(), v.unsigned_abs().to_string())
|
||||
}
|
||||
DataType::UInt64 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt64Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt32 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt32Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt16 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt16Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
DataType::UInt8 => {
|
||||
let v = array.as_primitive::<arrow_types::UInt8Type>().value(index) as u128;
|
||||
(false, v.to_string())
|
||||
}
|
||||
_ => {
|
||||
return Err(DataFusionError::Execution(format!(
|
||||
"FORMAT: unsupported type {:?}",
|
||||
array.data_type()
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = String::new();
|
||||
if is_negative {
|
||||
result.push('-');
|
||||
}
|
||||
result.push_str(&add_thousand_separators(&abs_digits));
|
||||
|
||||
if decimal_places > 0 {
|
||||
result.push('.');
|
||||
result.push_str(&"0".repeat(decimal_places));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Format a float with thousand separators and `decimal_places` digits after decimal point.
|
||||
fn format_number_float(x: f64, decimal_places: usize) -> String {
|
||||
// Handle special cases
|
||||
if x.is_nan() {
|
||||
return "NaN".to_string();
|
||||
}
|
||||
if x.is_infinite() {
|
||||
return if x.is_sign_positive() {
|
||||
"Infinity".to_string()
|
||||
} else {
|
||||
"-Infinity".to_string()
|
||||
};
|
||||
}
|
||||
|
||||
// Round to decimal_places
|
||||
let multiplier = 10f64.powi(decimal_places as i32);
|
||||
let rounded = (x * multiplier).round() / multiplier;
|
||||
|
||||
// Split into integer and fractional parts
|
||||
let is_negative = rounded < 0.0;
|
||||
let abs_value = rounded.abs();
|
||||
|
||||
// Format with the specified decimal places
|
||||
let formatted = if decimal_places == 0 {
|
||||
format!("{:.0}", abs_value)
|
||||
} else {
|
||||
format!("{:.prec$}", abs_value, prec = decimal_places)
|
||||
};
|
||||
|
||||
// Split at decimal point
|
||||
let parts: Vec<&str> = formatted.split('.').collect();
|
||||
let int_part = parts[0];
|
||||
let dec_part = parts.get(1).copied();
|
||||
|
||||
// Add thousand separators to integer part
|
||||
let int_with_sep = add_thousand_separators(int_part);
|
||||
|
||||
// Build result
|
||||
let mut result = String::new();
|
||||
if is_negative {
|
||||
result.push('-');
|
||||
}
|
||||
result.push_str(&int_with_sep);
|
||||
if let Some(dec) = dec_part {
|
||||
result.push('.');
|
||||
result.push_str(dec);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Add thousand separators (commas) to an integer string.
|
||||
fn add_thousand_separators(s: &str) -> String {
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
let len = chars.len();
|
||||
|
||||
if len <= 3 {
|
||||
return s.to_string();
|
||||
}
|
||||
|
||||
let mut result = String::with_capacity(len + len / 3);
|
||||
let first_group_len = len % 3;
|
||||
let first_group_len = if first_group_len == 0 {
|
||||
3
|
||||
} else {
|
||||
first_group_len
|
||||
};
|
||||
|
||||
for (i, ch) in chars.iter().enumerate() {
|
||||
if i > 0 && i >= first_group_len && (i - first_group_len) % 3 == 0 {
|
||||
result.push(',');
|
||||
}
|
||||
result.push(*ch);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Float64Array, Int64Array};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_basic() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![1234567.891, 1234.5, 1234567.0]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 0, 3]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "1,234,567.89");
|
||||
assert_eq!(str_array.value(1), "1,235"); // rounded
|
||||
assert_eq!(str_array.value(2), "1,234,567.000");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_negative() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![-1234567.891]));
|
||||
let d = Arc::new(Int64Array::from(vec![2]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "-1,234,567.89");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_small_numbers() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![0.5, 12.345, 123.0]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 2, 0]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "0.50");
|
||||
assert_eq!(str_array.value(1), "12.35"); // rounded
|
||||
assert_eq!(str_array.value(2), "123");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_with_nulls() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Float64Array::from(vec![Some(1234.5), None]));
|
||||
let d = Arc::new(Int64Array::from(vec![2, 2]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "1,234.50");
|
||||
assert!(str_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_thousand_separators() {
|
||||
assert_eq!(add_thousand_separators("1"), "1");
|
||||
assert_eq!(add_thousand_separators("12"), "12");
|
||||
assert_eq!(add_thousand_separators("123"), "123");
|
||||
assert_eq!(add_thousand_separators("1234"), "1,234");
|
||||
assert_eq!(add_thousand_separators("12345"), "12,345");
|
||||
assert_eq!(add_thousand_separators("123456"), "123,456");
|
||||
assert_eq!(add_thousand_separators("1234567"), "1,234,567");
|
||||
assert_eq!(add_thousand_separators("12345678"), "12,345,678");
|
||||
assert_eq!(add_thousand_separators("123456789"), "123,456,789");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_large_int_no_float_precision_loss() {
|
||||
let function = FormatFunction::default();
|
||||
|
||||
// 2^53 + 1 cannot be represented exactly as f64.
|
||||
let x = Arc::new(Int64Array::from(vec![9_007_199_254_740_993i64]));
|
||||
let d = Arc::new(Int64Array::from(vec![0]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "9,007,199,254,740,993");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_decimal_places_u64_overflow_clamps() {
|
||||
use datafusion_common::arrow::array::UInt64Array;
|
||||
|
||||
let function = FormatFunction::default();
|
||||
|
||||
let x = Arc::new(Int64Array::from(vec![1]));
|
||||
let d = Arc::new(UInt64Array::from(vec![u64::MAX]));
|
||||
|
||||
let args = create_args(vec![x, d]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), format!("1.{}", "0".repeat(30)));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
345
src/common/function/src/scalars/string/insert.rs
Normal file
345
src/common/function/src/scalars/string/insert.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible INSERT function implementation.
|
||||
//!
|
||||
//! INSERT(str, pos, len, newstr) - Inserts newstr into str at position pos,
|
||||
//! replacing len characters.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, LargeStringBuilder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "insert";
|
||||
|
||||
/// MySQL-compatible INSERT function.
|
||||
///
|
||||
/// Syntax: INSERT(str, pos, len, newstr)
|
||||
/// Returns str with the substring beginning at position pos and len characters long
|
||||
/// replaced by newstr.
|
||||
///
|
||||
/// - pos is 1-based
|
||||
/// - If pos is out of range, returns the original string
|
||||
/// - If len is out of range, replaces from pos to end of string
|
||||
#[derive(Debug)]
|
||||
pub struct InsertFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl InsertFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(InsertFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for InsertFunction {
|
||||
fn default() -> Self {
|
||||
let mut signatures = Vec::new();
|
||||
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
for str_type in &string_types {
|
||||
for newstr_type in &string_types {
|
||||
for pos_type in &int_types {
|
||||
for len_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
str_type.clone(),
|
||||
pos_type.clone(),
|
||||
len_type.clone(),
|
||||
newstr_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for InsertFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for InsertFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::LargeUtf8)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
if args.args.len() != 4 {
|
||||
return Err(DataFusionError::Execution(
|
||||
"INSERT requires exactly 4 arguments: INSERT(str, pos, len, newstr)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
let len = arrays[0].len();
|
||||
|
||||
// Cast string arguments to LargeUtf8
|
||||
let str_array = cast_to_large_utf8(&arrays[0], "str")?;
|
||||
let newstr_array = cast_to_large_utf8(&arrays[3], "newstr")?;
|
||||
let pos_array = cast_to_int64(&arrays[1], "pos")?;
|
||||
let replace_len_array = cast_to_int64(&arrays[2], "len")?;
|
||||
|
||||
let str_arr = str_array.as_string::<i64>();
|
||||
let pos_arr = pos_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
let len_arr =
|
||||
replace_len_array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
|
||||
let newstr_arr = newstr_array.as_string::<i64>();
|
||||
|
||||
let mut builder = LargeStringBuilder::with_capacity(len, len * 32);
|
||||
|
||||
for i in 0..len {
|
||||
// Check for NULLs
|
||||
if str_arr.is_null(i)
|
||||
|| pos_array.is_null(i)
|
||||
|| replace_len_array.is_null(i)
|
||||
|| newstr_arr.is_null(i)
|
||||
{
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let original = str_arr.value(i);
|
||||
let pos = pos_arr.value(i);
|
||||
let replace_len = len_arr.value(i);
|
||||
let new_str = newstr_arr.value(i);
|
||||
|
||||
let result = insert_string(original, pos, replace_len, new_str);
|
||||
builder.append_value(&result);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast array to LargeUtf8 for uniform string access.
|
||||
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::LargeUtf8)
|
||||
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::Int64)
|
||||
.map_err(|e| DataFusionError::Execution(format!("INSERT: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
/// Perform the INSERT string operation.
|
||||
/// pos is 1-based. If pos < 1 or pos > len(str) + 1, returns original string.
|
||||
fn insert_string(original: &str, pos: i64, replace_len: i64, new_str: &str) -> String {
|
||||
let char_count = original.chars().count();
|
||||
|
||||
// MySQL behavior: if pos < 1 or pos > string length + 1, return original
|
||||
if pos < 1 || pos as usize > char_count + 1 {
|
||||
return original.to_string();
|
||||
}
|
||||
|
||||
let start_idx = (pos - 1) as usize; // Convert to 0-based
|
||||
|
||||
// Calculate end index for replacement
|
||||
let replace_len = if replace_len < 0 {
|
||||
0
|
||||
} else {
|
||||
replace_len as usize
|
||||
};
|
||||
let end_idx = (start_idx + replace_len).min(char_count);
|
||||
|
||||
let start_byte = char_to_byte_idx(original, start_idx);
|
||||
let end_byte = char_to_byte_idx(original, end_idx);
|
||||
|
||||
let mut result = String::with_capacity(original.len() + new_str.len());
|
||||
result.push_str(&original[..start_byte]);
|
||||
result.push_str(new_str);
|
||||
result.push_str(&original[end_byte..]);
|
||||
result
|
||||
}
|
||||
|
||||
fn char_to_byte_idx(s: &str, char_idx: usize) -> usize {
|
||||
s.char_indices()
|
||||
.nth(char_idx)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(s.len())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::arrow::array::{Int64Array, StringArray};
|
||||
use datafusion_common::arrow::datatypes::Field;
|
||||
use datafusion_expr::ScalarFunctionArgs;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
|
||||
let arg_fields: Vec<_> = arrays
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, arr)| {
|
||||
Arc::new(Field::new(
|
||||
format!("arg_{}", i),
|
||||
arr.data_type().clone(),
|
||||
true,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
ScalarFunctionArgs {
|
||||
args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
|
||||
arg_fields,
|
||||
return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
|
||||
number_rows: arrays[0].len(),
|
||||
config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_basic() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 3, 4, 'What') => 'QuWhattic'
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![3]));
|
||||
let len = Arc::new(Int64Array::from(vec![4]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "QuWhattic");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_out_of_range_pos() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 0, 4, 'What') => 'Quadratic' (pos < 1)
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic", "Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![0, 100]));
|
||||
let len = Arc::new(Int64Array::from(vec![4, 4]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What", "What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "Quadratic"); // pos < 1
|
||||
assert_eq!(str_array.value(1), "Quadratic"); // pos > length
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_replace_to_end() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('Quadratic', 3, 100, 'What') => 'QuWhat' (len exceeds remaining)
|
||||
let str_arr = Arc::new(StringArray::from(vec!["Quadratic"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![3]));
|
||||
let len = Arc::new(Int64Array::from(vec![100]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["What"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "QuWhat");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_unicode() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
// INSERT('hello世界', 6, 1, 'の') => 'helloの界'
|
||||
let str_arr = Arc::new(StringArray::from(vec!["hello世界"]));
|
||||
let pos = Arc::new(Int64Array::from(vec![6]));
|
||||
let len = Arc::new(Int64Array::from(vec![1]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["の"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "helloの界");
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_with_nulls() {
|
||||
let function = InsertFunction::default();
|
||||
|
||||
let str_arr = Arc::new(StringArray::from(vec![Some("hello"), None]));
|
||||
let pos = Arc::new(Int64Array::from(vec![1, 1]));
|
||||
let len = Arc::new(Int64Array::from(vec![1, 1]));
|
||||
let newstr = Arc::new(StringArray::from(vec!["X", "X"]));
|
||||
|
||||
let args = create_args(vec![str_arr, pos, len, newstr]);
|
||||
let result = function.invoke_with_args(args).unwrap();
|
||||
|
||||
if let ColumnarValue::Array(array) = result {
|
||||
let str_array = array.as_string::<i64>();
|
||||
assert_eq!(str_array.value(0), "Xello");
|
||||
assert!(str_array.is_null(1));
|
||||
} else {
|
||||
panic!("Expected array result");
|
||||
}
|
||||
}
|
||||
}
|
||||
373
src/common/function/src/scalars/string/locate.rs
Normal file
373
src/common/function/src/scalars/string/locate.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! MySQL-compatible LOCATE function implementation.
|
||||
//!
|
||||
//! LOCATE(substr, str) - Returns the position of the first occurrence of substr in str (1-based).
|
||||
//! LOCATE(substr, str, pos) - Returns the position of the first occurrence of substr in str,
|
||||
//! starting from position pos.
|
||||
//! Returns 0 if substr is not found.
|
||||
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datafusion_common::DataFusionError;
|
||||
use datafusion_common::arrow::array::{Array, ArrayRef, AsArray, Int64Builder};
|
||||
use datafusion_common::arrow::compute::cast;
|
||||
use datafusion_common::arrow::datatypes::DataType;
|
||||
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};
|
||||
|
||||
use crate::function::Function;
|
||||
use crate::function_registry::FunctionRegistry;
|
||||
|
||||
const NAME: &str = "locate";
|
||||
|
||||
/// MySQL-compatible LOCATE function.
|
||||
///
|
||||
/// Syntax:
|
||||
/// - LOCATE(substr, str) - Returns 1-based position of substr in str, or 0 if not found.
|
||||
/// - LOCATE(substr, str, pos) - Same, but starts searching from position pos.
|
||||
#[derive(Debug)]
|
||||
pub struct LocateFunction {
|
||||
signature: Signature,
|
||||
}
|
||||
|
||||
impl LocateFunction {
|
||||
pub fn register(registry: &FunctionRegistry) {
|
||||
registry.register_scalar(LocateFunction::default());
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LocateFunction {
|
||||
fn default() -> Self {
|
||||
// Support 2 or 3 arguments with various string types
|
||||
let mut signatures = Vec::new();
|
||||
let string_types = [DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
|
||||
let int_types = [
|
||||
DataType::Int64,
|
||||
DataType::Int32,
|
||||
DataType::Int16,
|
||||
DataType::Int8,
|
||||
DataType::UInt64,
|
||||
DataType::UInt32,
|
||||
DataType::UInt16,
|
||||
DataType::UInt8,
|
||||
];
|
||||
|
||||
// 2-argument form: LOCATE(substr, str)
|
||||
for substr_type in &string_types {
|
||||
for str_type in &string_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
substr_type.clone(),
|
||||
str_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
// 3-argument form: LOCATE(substr, str, pos)
|
||||
for substr_type in &string_types {
|
||||
for str_type in &string_types {
|
||||
for pos_type in &int_types {
|
||||
signatures.push(TypeSignature::Exact(vec![
|
||||
substr_type.clone(),
|
||||
str_type.clone(),
|
||||
pos_type.clone(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
signature: Signature::one_of(signatures, Volatility::Immutable),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for LocateFunction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", NAME.to_ascii_uppercase())
|
||||
}
|
||||
}
|
||||
|
||||
impl Function for LocateFunction {
|
||||
fn name(&self) -> &str {
|
||||
NAME
|
||||
}
|
||||
|
||||
fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
|
||||
Ok(DataType::Int64)
|
||||
}
|
||||
|
||||
fn signature(&self) -> &Signature {
|
||||
&self.signature
|
||||
}
|
||||
|
||||
fn invoke_with_args(
|
||||
&self,
|
||||
args: ScalarFunctionArgs,
|
||||
) -> datafusion_common::Result<ColumnarValue> {
|
||||
let arg_count = args.args.len();
|
||||
if !(2..=3).contains(&arg_count) {
|
||||
return Err(DataFusionError::Execution(
|
||||
"LOCATE requires 2 or 3 arguments: LOCATE(substr, str) or LOCATE(substr, str, pos)"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
|
||||
|
||||
// Cast string arguments to LargeUtf8 for uniform access
|
||||
let substr_array = cast_to_large_utf8(&arrays[0], "substr")?;
|
||||
let str_array = cast_to_large_utf8(&arrays[1], "str")?;
|
||||
|
||||
let substr = substr_array.as_string::<i64>();
|
||||
let str_arr = str_array.as_string::<i64>();
|
||||
let len = substr.len();
|
||||
|
||||
// Handle optional pos argument
|
||||
let pos_array: Option<ArrayRef> = if arg_count == 3 {
|
||||
Some(cast_to_int64(&arrays[2], "pos")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut builder = Int64Builder::with_capacity(len);
|
||||
|
||||
for i in 0..len {
|
||||
if substr.is_null(i) || str_arr.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
|
||||
let needle = substr.value(i);
|
||||
let haystack = str_arr.value(i);
|
||||
|
||||
// Get starting position (1-based in MySQL, convert to 0-based)
|
||||
let start_pos = if let Some(ref pos_arr) = pos_array {
|
||||
if pos_arr.is_null(i) {
|
||||
builder.append_null();
|
||||
continue;
|
||||
}
|
||||
let pos = pos_arr
|
||||
.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>()
|
||||
.value(i);
|
||||
if pos < 1 {
|
||||
// MySQL returns 0 for pos < 1
|
||||
builder.append_value(0);
|
||||
continue;
|
||||
}
|
||||
(pos - 1) as usize
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Find position using character-based indexing (for Unicode support)
|
||||
let result = locate_substr(haystack, needle, start_pos);
|
||||
builder.append_value(result);
|
||||
}
|
||||
|
||||
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast array to LargeUtf8 for uniform string access.
|
||||
fn cast_to_large_utf8(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::LargeUtf8)
|
||||
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
fn cast_to_int64(array: &ArrayRef, name: &str) -> datafusion_common::Result<ArrayRef> {
|
||||
cast(array.as_ref(), &DataType::Int64)
|
||||
.map_err(|e| DataFusionError::Execution(format!("LOCATE: {} cast failed: {}", name, e)))
|
||||
}
|
||||
|
||||
/// Find the 1-based position of needle in haystack, starting from start_pos (0-based character index).
|
||||
/// Returns 0 if not found.
|
||||
fn locate_substr(haystack: &str, needle: &str, start_pos: usize) -> i64 {
|
||||
// Handle empty needle - MySQL returns start_pos + 1
|
||||
if needle.is_empty() {
|
||||
let char_count = haystack.chars().count();
|
||||
return if start_pos <= char_count {
|
||||
(start_pos + 1) as i64
|
||||
} else {
|
||||
0
|
||||
};
|
||||
}
|
||||
|
||||
// Convert start_pos (character index) to byte index
|
||||
let byte_start = haystack
|
||||
.char_indices()
|
||||
.nth(start_pos)
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(haystack.len());
|
||||
|
||||
if byte_start >= haystack.len() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Search in the substring
|
||||
let search_str = &haystack[byte_start..];
|
||||
if let Some(byte_pos) = search_str.find(needle) {
|
||||
// Convert byte position back to character position
|
||||
let char_pos = search_str[..byte_pos].chars().count();
|
||||
// Return 1-based position relative to original string
|
||||
(start_pos + char_pos + 1) as i64
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::StringArray;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    fn create_args(arrays: Vec<ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::Int64, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_locate_basic() {
        let function = LocateFunction::default();

        let substr = Arc::new(StringArray::from(vec!["world", "xyz", "hello"]));
        let str_arr = Arc::new(StringArray::from(vec![
            "hello world",
            "hello world",
            "hello world",
        ]));

        let args = create_args(vec![substr, str_arr]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
            assert_eq!(int_array.value(0), 7); // "world" at position 7
            assert_eq!(int_array.value(1), 0); // "xyz" not found
            assert_eq!(int_array.value(2), 1); // "hello" at position 1
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_locate_with_position() {
        let function = LocateFunction::default();

        let substr = Arc::new(StringArray::from(vec!["o", "o", "o"]));
        let str_arr = Arc::new(StringArray::from(vec![
            "hello world",
            "hello world",
            "hello world",
        ]));
        let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
            1, 5, 8,
        ]));

        let args = create_args(vec![substr, str_arr, pos]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
            assert_eq!(int_array.value(0), 5); // first 'o' at position 5
            assert_eq!(int_array.value(1), 5); // 'o' at position 5 (start from 5)
            assert_eq!(int_array.value(2), 8); // 'o' in "world" at position 8
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_locate_unicode() {
        let function = LocateFunction::default();

        let substr = Arc::new(StringArray::from(vec!["世", "界"]));
        let str_arr = Arc::new(StringArray::from(vec!["hello世界", "hello世界"]));

        let args = create_args(vec![substr, str_arr]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
            assert_eq!(int_array.value(0), 6); // "世" at position 6
            assert_eq!(int_array.value(1), 7); // "界" at position 7
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_locate_empty_needle() {
        let function = LocateFunction::default();

        let substr = Arc::new(StringArray::from(vec!["", ""]));
        let str_arr = Arc::new(StringArray::from(vec!["hello", "hello"]));
        let pos = Arc::new(datafusion_common::arrow::array::Int64Array::from(vec![
            1, 3,
        ]));

        let args = create_args(vec![substr, str_arr, pos]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
            assert_eq!(int_array.value(0), 1); // empty string at pos 1
            assert_eq!(int_array.value(1), 3); // empty string at pos 3
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_locate_with_nulls() {
        let function = LocateFunction::default();

        let substr = Arc::new(StringArray::from(vec![Some("o"), None]));
        let str_arr = Arc::new(StringArray::from(vec![Some("hello"), Some("hello")]));

        let args = create_args(vec![substr, str_arr]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let int_array = array.as_primitive::<datafusion_common::arrow::datatypes::Int64Type>();
            assert_eq!(int_array.value(0), 5);
            assert!(int_array.is_null(1));
        } else {
            panic!("Expected array result");
        }
    }
}
252 src/common/function/src/scalars/string/space.rs Normal file
@@ -0,0 +1,252 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! MySQL-compatible SPACE function implementation.
//!
//! SPACE(N) - Returns a string consisting of N space characters.

use std::fmt;
use std::sync::Arc;

use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, LargeStringBuilder};
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, TypeSignature, Volatility};

use crate::function::Function;
use crate::function_registry::FunctionRegistry;

const NAME: &str = "space";

// Safety limit for maximum number of spaces
const MAX_SPACE_COUNT: i64 = 1024 * 1024; // 1MB of spaces

/// MySQL-compatible SPACE function.
///
/// Syntax: SPACE(N)
/// Returns a string consisting of N space characters.
/// Returns NULL if N is NULL.
/// Returns empty string if N < 0.
#[derive(Debug)]
pub struct SpaceFunction {
    signature: Signature,
}

impl SpaceFunction {
    pub fn register(registry: &FunctionRegistry) {
        registry.register_scalar(SpaceFunction::default());
    }
}

impl Default for SpaceFunction {
    fn default() -> Self {
        Self {
            signature: Signature::one_of(
                vec![
                    TypeSignature::Exact(vec![DataType::Int64]),
                    TypeSignature::Exact(vec![DataType::Int32]),
                    TypeSignature::Exact(vec![DataType::Int16]),
                    TypeSignature::Exact(vec![DataType::Int8]),
                    TypeSignature::Exact(vec![DataType::UInt64]),
                    TypeSignature::Exact(vec![DataType::UInt32]),
                    TypeSignature::Exact(vec![DataType::UInt16]),
                    TypeSignature::Exact(vec![DataType::UInt8]),
                ],
                Volatility::Immutable,
            ),
        }
    }
}

impl fmt::Display for SpaceFunction {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", NAME.to_ascii_uppercase())
    }
}

impl Function for SpaceFunction {
    fn name(&self) -> &str {
        NAME
    }

    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(DataType::LargeUtf8)
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        if args.args.len() != 1 {
            return Err(DataFusionError::Execution(
                "SPACE requires exactly 1 argument: SPACE(N)".to_string(),
            ));
        }

        let arrays = ColumnarValue::values_to_arrays(&args.args)?;
        let len = arrays[0].len();
        let n_array = &arrays[0];

        let mut builder = LargeStringBuilder::with_capacity(len, len * 10);

        for i in 0..len {
            if n_array.is_null(i) {
                builder.append_null();
                continue;
            }

            let n = get_int_value(n_array, i)?;

            if n < 0 {
                // MySQL returns empty string for negative values
                builder.append_value("");
            } else if n > MAX_SPACE_COUNT {
                return Err(DataFusionError::Execution(format!(
                    "SPACE: requested {} spaces exceeds maximum allowed ({})",
                    n, MAX_SPACE_COUNT
                )));
            } else {
                let spaces = " ".repeat(n as usize);
                builder.append_value(&spaces);
            }
        }

        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
    }
}

/// Extract integer value from various integer types.
fn get_int_value(
    array: &datafusion_common::arrow::array::ArrayRef,
    index: usize,
) -> datafusion_common::Result<i64> {
    use datafusion_common::arrow::datatypes as arrow_types;

    match array.data_type() {
        DataType::Int64 => Ok(array.as_primitive::<arrow_types::Int64Type>().value(index)),
        DataType::Int32 => Ok(array.as_primitive::<arrow_types::Int32Type>().value(index) as i64),
        DataType::Int16 => Ok(array.as_primitive::<arrow_types::Int16Type>().value(index) as i64),
        DataType::Int8 => Ok(array.as_primitive::<arrow_types::Int8Type>().value(index) as i64),
        DataType::UInt64 => {
            let v = array.as_primitive::<arrow_types::UInt64Type>().value(index);
            if v > i64::MAX as u64 {
                Err(DataFusionError::Execution(format!(
                    "SPACE: value {} exceeds maximum",
                    v
                )))
            } else {
                Ok(v as i64)
            }
        }
        DataType::UInt32 => Ok(array.as_primitive::<arrow_types::UInt32Type>().value(index) as i64),
        DataType::UInt16 => Ok(array.as_primitive::<arrow_types::UInt16Type>().value(index) as i64),
        DataType::UInt8 => Ok(array.as_primitive::<arrow_types::UInt8Type>().value(index) as i64),
        _ => Err(DataFusionError::Execution(format!(
            "SPACE: unsupported type {:?}",
            array.data_type()
        ))),
    }
}
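The tests below cover the NULL and negative cases; the 1 MiB safety cap is only exercised in invoke_with_args. A minimal sketch of that path, reusing the Int64Array import and the create_args helper from the test module below (values are illustrative only):

// Sketch only: a request above MAX_SPACE_COUNT is rejected with an Execution error.
let function = SpaceFunction::default();
let n = Arc::new(Int64Array::from(vec![MAX_SPACE_COUNT + 1]));
let err = function.invoke_with_args(create_args(vec![n])).unwrap_err();
assert!(err.to_string().contains("exceeds maximum allowed"));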
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion_common::arrow::array::Int64Array;
    use datafusion_common::arrow::datatypes::Field;
    use datafusion_expr::ScalarFunctionArgs;

    use super::*;

    fn create_args(arrays: Vec<datafusion_common::arrow::array::ArrayRef>) -> ScalarFunctionArgs {
        let arg_fields: Vec<_> = arrays
            .iter()
            .enumerate()
            .map(|(i, arr)| {
                Arc::new(Field::new(
                    format!("arg_{}", i),
                    arr.data_type().clone(),
                    true,
                ))
            })
            .collect();

        ScalarFunctionArgs {
            args: arrays.iter().cloned().map(ColumnarValue::Array).collect(),
            arg_fields,
            return_field: Arc::new(Field::new("result", DataType::LargeUtf8, true)),
            number_rows: arrays[0].len(),
            config_options: Arc::new(datafusion_common::config::ConfigOptions::default()),
        }
    }

    #[test]
    fn test_space_basic() {
        let function = SpaceFunction::default();

        let n = Arc::new(Int64Array::from(vec![0, 1, 5]));

        let args = create_args(vec![n]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.value(0), "");
            assert_eq!(str_array.value(1), " ");
            assert_eq!(str_array.value(2), "     ");
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_space_negative() {
        let function = SpaceFunction::default();

        let n = Arc::new(Int64Array::from(vec![-1, -100]));

        let args = create_args(vec![n]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.value(0), "");
            assert_eq!(str_array.value(1), "");
        } else {
            panic!("Expected array result");
        }
    }

    #[test]
    fn test_space_with_nulls() {
        let function = SpaceFunction::default();

        let n = Arc::new(Int64Array::from(vec![Some(3), None]));

        let args = create_args(vec![n]);
        let result = function.invoke_with_args(args).unwrap();

        if let ColumnarValue::Array(array) = result {
            let str_array = array.as_string::<i64>();
            assert_eq!(str_array.value(0), "   ");
            assert!(str_array.is_null(1));
        } else {
            panic!("Expected array result");
        }
    }
}
@@ -17,7 +17,7 @@ use std::sync::Arc;
use common_catalog::consts::{
    DEFAULT_PRIVATE_SCHEMA_NAME, INFORMATION_SCHEMA_NAME, PG_CATALOG_NAME,
};
-use datafusion::arrow::array::{ArrayRef, StringArray, as_boolean_array};
+use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder, as_boolean_array};
use datafusion::catalog::TableFunction;
use datafusion::common::ScalarValue;
use datafusion::common::utils::SingleRowListArrayBuilder;
@@ -34,10 +34,15 @@ const CURRENT_SCHEMA_FUNCTION_NAME: &str = "current_schema";
const CURRENT_SCHEMAS_FUNCTION_NAME: &str = "current_schemas";
const SESSION_USER_FUNCTION_NAME: &str = "session_user";
const CURRENT_DATABASE_FUNCTION_NAME: &str = "current_database";
const OBJ_DESCRIPTION_FUNCTION_NAME: &str = "obj_description";
const COL_DESCRIPTION_FUNCTION_NAME: &str = "col_description";
const SHOBJ_DESCRIPTION_FUNCTION_NAME: &str = "shobj_description";
const PG_MY_TEMP_SCHEMA_FUNCTION_NAME: &str = "pg_my_temp_schema";

define_nullary_udf!(CurrentSchemaFunction);
define_nullary_udf!(SessionUserFunction);
define_nullary_udf!(CurrentDatabaseFunction);
define_nullary_udf!(PgMyTempSchemaFunction);

impl Function for CurrentDatabaseFunction {
    fn name(&self) -> &str {
@@ -173,6 +178,175 @@ impl Function for CurrentSchemasFunction {
    }
}

/// PostgreSQL obj_description - returns NULL for compatibility
#[derive(Display, Debug, Clone)]
#[display("{}", self.name())]
pub(super) struct ObjDescriptionFunction {
    signature: Signature,
}

impl ObjDescriptionFunction {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![
                    TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),
                    TypeSignature::Exact(vec![DataType::UInt32, DataType::Utf8]),
                    TypeSignature::Exact(vec![DataType::Int64]),
                    TypeSignature::Exact(vec![DataType::UInt32]),
                ],
                Volatility::Stable,
            ),
        }
    }
}

impl Function for ObjDescriptionFunction {
    fn name(&self) -> &str {
        OBJ_DESCRIPTION_FUNCTION_NAME
    }

    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        let mut builder = StringBuilder::with_capacity(num_rows, 0);
        for _ in 0..num_rows {
            builder.append_null();
        }
        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
    }
}

/// PostgreSQL col_description - returns NULL for compatibility
#[derive(Display, Debug, Clone)]
#[display("{}", self.name())]
pub(super) struct ColDescriptionFunction {
    signature: Signature,
}

impl ColDescriptionFunction {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![
                    TypeSignature::Exact(vec![DataType::Int64, DataType::Int32]),
                    TypeSignature::Exact(vec![DataType::UInt32, DataType::Int32]),
                    TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
                    TypeSignature::Exact(vec![DataType::UInt32, DataType::Int64]),
                ],
                Volatility::Stable,
            ),
        }
    }
}

impl Function for ColDescriptionFunction {
    fn name(&self) -> &str {
        COL_DESCRIPTION_FUNCTION_NAME
    }

    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        let mut builder = StringBuilder::with_capacity(num_rows, 0);
        for _ in 0..num_rows {
            builder.append_null();
        }
        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
    }
}

/// PostgreSQL shobj_description - returns NULL for compatibility
#[derive(Display, Debug, Clone)]
#[display("{}", self.name())]
pub(super) struct ShobjDescriptionFunction {
    signature: Signature,
}

impl ShobjDescriptionFunction {
    pub fn new() -> Self {
        Self {
            signature: Signature::one_of(
                vec![
                    TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),
                    TypeSignature::Exact(vec![DataType::UInt64, DataType::Utf8]),
                    TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]),
                    TypeSignature::Exact(vec![DataType::UInt32, DataType::Utf8]),
                ],
                Volatility::Stable,
            ),
        }
    }
}

impl Function for ShobjDescriptionFunction {
    fn name(&self) -> &str {
        SHOBJ_DESCRIPTION_FUNCTION_NAME
    }

    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        let mut builder = StringBuilder::with_capacity(num_rows, 0);
        for _ in 0..num_rows {
            builder.append_null();
        }
        Ok(ColumnarValue::Array(Arc::new(builder.finish())))
    }
}
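The three description functions above share an identical invoke_with_args body that builds an all-NULL Utf8 column. A small helper along these lines could remove the duplication; this is a sketch only, not part of the change, and the helper name is hypothetical:

/// Hypothetical helper: build a Utf8 array of `num_rows` NULLs, as the
/// obj_description / col_description / shobj_description stubs all do.
fn all_null_utf8(num_rows: usize) -> datafusion_common::Result<ColumnarValue> {
    let mut builder = StringBuilder::with_capacity(num_rows, 0);
    for _ in 0..num_rows {
        builder.append_null();
    }
    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
}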
/// PostgreSQL pg_my_temp_schema - returns 0 (no temp schema) for compatibility
impl Function for PgMyTempSchemaFunction {
    fn name(&self) -> &str {
        PG_MY_TEMP_SCHEMA_FUNCTION_NAME
    }

    fn return_type(&self, _: &[DataType]) -> datafusion_common::Result<DataType> {
        Ok(DataType::UInt32)
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn invoke_with_args(
        &self,
        _args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        Ok(ColumnarValue::Scalar(ScalarValue::UInt32(Some(0))))
    }
}

pub(super) struct PGCatalogFunction;

impl PGCatalogFunction {
@@ -212,5 +386,100 @@ impl PGCatalogFunction {
        registry.register(pg_catalog::create_pg_total_relation_size_udf());
        registry.register(pg_catalog::create_pg_stat_get_numscans());
        registry.register(pg_catalog::create_pg_get_constraintdef());
        registry.register(pg_catalog::create_pg_get_partition_ancestors_udf());
        registry.register(pg_catalog::quote_ident_udf::create_quote_ident_udf());
        registry.register(pg_catalog::quote_ident_udf::create_parse_ident_udf());
        registry.register_scalar(ObjDescriptionFunction::new());
        registry.register_scalar(ColDescriptionFunction::new());
        registry.register_scalar(ShobjDescriptionFunction::new());
        registry.register_scalar(PgMyTempSchemaFunction::default());
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::Field;
    use datafusion::arrow::array::Array;
    use datafusion_common::ScalarValue;
    use datafusion_expr::ColumnarValue;

    use super::*;

    fn create_test_args(args: Vec<ColumnarValue>, number_rows: usize) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields: vec![],
            number_rows,
            return_field: Arc::new(Field::new("result", DataType::Utf8, true)),
            config_options: Arc::new(Default::default()),
        }
    }

    #[test]
    fn test_obj_description_function() {
        let func = ObjDescriptionFunction::new();
        assert_eq!("obj_description", func.name());
        assert_eq!(DataType::Utf8, func.return_type(&[]).unwrap());

        let args = create_test_args(
            vec![
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1234))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("pg_class".to_string()))),
            ],
            1,
        );
        let result = func.invoke_with_args(args).unwrap();
        if let ColumnarValue::Array(arr) = result {
            assert_eq!(1, arr.len());
            assert!(arr.is_null(0));
        } else {
            panic!("Expected Array result");
        }
    }

    #[test]
    fn test_col_description_function() {
        let func = ColDescriptionFunction::new();
        assert_eq!("col_description", func.name());
        assert_eq!(DataType::Utf8, func.return_type(&[]).unwrap());

        let args = create_test_args(
            vec![
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1234))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
            ],
            1,
        );
        let result = func.invoke_with_args(args).unwrap();
        if let ColumnarValue::Array(arr) = result {
            assert_eq!(1, arr.len());
            assert!(arr.is_null(0));
        } else {
            panic!("Expected Array result");
        }
    }

    #[test]
    fn test_shobj_description_function() {
        let func = ShobjDescriptionFunction::new();
        assert_eq!("shobj_description", func.name());
        assert_eq!(DataType::Utf8, func.return_type(&[]).unwrap());

        let args = create_test_args(
            vec![
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("pg_database".to_string()))),
            ],
            1,
        );
        let result = func.invoke_with_args(args).unwrap();
        if let ColumnarValue::Array(arr) = result {
            assert_eq!(1, arr.len());
            assert!(arr.is_null(0));
        } else {
            panic!("Expected Array result");
        }
    }
}
@@ -12,6 +12,7 @@ api.workspace = true
arrow-flight.workspace = true
bytes.workspace = true
common-base.workspace = true
+common-config.workspace = true
common-error.workspace = true
common-macro.workspace = true
common-recordbatch.workspace = true
@@ -23,7 +24,6 @@ datatypes.workspace = true
flatbuffers = "25.2"
hyper.workspace = true
lazy_static.workspace = true
-notify.workspace = true
prost.workspace = true
serde.workspace = true
serde_json.workspace = true

@@ -38,11 +38,10 @@ pub enum Error {
        location: Location,
    },

-    #[snafu(display("Failed to watch config file path: {}", path))]
+    #[snafu(display("Failed to watch config file"))]
    FileWatch {
        path: String,
-        #[snafu(source)]
-        error: notify::Error,
+        source: common_config::error::Error,
        #[snafu(implicit)]
        location: Location,
    },

@@ -46,13 +46,16 @@ pub struct DoPutResponse {
    request_id: i64,
    /// The successfully ingested rows number.
    affected_rows: AffectedRows,
+    /// The elapsed time in seconds for handling the bulk insert.
+    elapsed_secs: f64,
}

impl DoPutResponse {
-    pub fn new(request_id: i64, affected_rows: AffectedRows) -> Self {
+    pub fn new(request_id: i64, affected_rows: AffectedRows, elapsed_secs: f64) -> Self {
        Self {
            request_id,
            affected_rows,
+            elapsed_secs,
        }
    }

@@ -63,6 +66,10 @@ impl DoPutResponse {
    pub fn affected_rows(&self) -> AffectedRows {
        self.affected_rows
    }

+    pub fn elapsed_secs(&self) -> f64 {
+        self.elapsed_secs
+    }
}
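With the new field, the server side is expected to measure the handling time itself. A minimal sketch of the producing side; handle_bulk_insert and the surrounding async handler are hypothetical, only DoPutResponse::new comes from this diff:

// Sketch: measure the bulk-insert handling time and report it in the response.
// `handle_bulk_insert` is a placeholder, not an API introduced here.
let start = std::time::Instant::now();
let affected_rows = handle_bulk_insert(request).await?;
let response = DoPutResponse::new(request_id, affected_rows, start.elapsed().as_secs_f64());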
impl TryFrom<PutResult> for DoPutResponse {
@@ -86,8 +93,11 @@ mod tests {

    #[test]
    fn test_serde_do_put_response() {
-        let x = DoPutResponse::new(42, 88);
+        let x = DoPutResponse::new(42, 88, 0.123);
        let serialized = serde_json::to_string(&x).unwrap();
-        assert_eq!(serialized, r#"{"request_id":42,"affected_rows":88}"#);
+        assert_eq!(
+            serialized,
+            r#"{"request_id":42,"affected_rows":88,"elapsed_secs":0.123}"#
+        );
    }
}

@@ -15,11 +15,10 @@
use std::path::Path;
use std::result::Result as StdResult;
use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::mpsc::channel;
use std::sync::{Arc, RwLock};

+use common_config::file_watcher::{FileWatcherBuilder, FileWatcherConfig};
use common_telemetry::{error, info};
-use notify::{EventKind, RecursiveMode, Watcher};
use snafu::ResultExt;

use crate::error::{FileWatchSnafu, Result};
@@ -119,45 +118,28 @@ where
        return Ok(());
    }

    let watch_paths: Vec<_> = tls_config
        .get_tls_option()
        .watch_paths()
        .iter()
        .map(|p| p.to_path_buf())
        .collect();

    let tls_config_for_watcher = tls_config.clone();

    let (tx, rx) = channel::<notify::Result<notify::Event>>();
    let mut watcher = notify::recommended_watcher(tx).context(FileWatchSnafu { path: "<none>" })?;

    // Watch all paths returned by the TlsConfigLoader
    for path in tls_config.get_tls_option().watch_paths() {
        watcher
            .watch(path, RecursiveMode::NonRecursive)
            .with_context(|_| FileWatchSnafu {
                path: path.display().to_string(),
            })?;
    }

    info!("Spawning background task for watching TLS cert/key file changes");
    std::thread::spawn(move || {
        let _watcher = watcher;
        loop {
            match rx.recv() {
                Ok(Ok(event)) => {
                    if let EventKind::Modify(_) | EventKind::Create(_) = event.kind {
                        info!("Detected TLS cert/key file change: {:?}", event);
                        if let Err(err) = tls_config_for_watcher.reload() {
                            error!("Failed to reload TLS config: {}", err);
                        } else {
                            info!("Reloaded TLS cert/key file successfully.");
                            on_reload();
                        }
                    }
                }
                Ok(Err(err)) => {
                    error!("Failed to watch TLS cert/key file: {}", err);
                }
                Err(err) => {
                    error!("TLS cert/key file watcher channel closed: {}", err);
                }
    FileWatcherBuilder::new()
        .watch_paths(&watch_paths)
        .context(FileWatchSnafu)?
        .config(FileWatcherConfig::new())
        .spawn(move || {
            if let Err(err) = tls_config_for_watcher.reload() {
                error!("Failed to reload TLS config: {}", err);
            } else {
                info!("Reloaded TLS cert/key file successfully.");
                on_reload();
            }
            }
        }
    });
        })
        .context(FileWatchSnafu)?;

    Ok(())
}
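For reference, the refactor replaces the hand-rolled notify watcher with the shared file watcher in common-config. Its call shape, as used above, boils down to the following sketch; the builder methods and error handling are inferred from this call site only, and the paths and callback are placeholders:

// Sketch of the FileWatcherBuilder call shape used in the hunk above.
let paths = vec![std::path::PathBuf::from("/path/to/server.crt")];
FileWatcherBuilder::new()
    .watch_paths(&paths)
    .context(FileWatchSnafu)? // watch_paths validates the paths and may fail
    .config(FileWatcherConfig::new())
    .spawn(move || {
        // invoked whenever one of the watched files changes
    })
    .context(FileWatchSnafu)?;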
20 src/common/memory-manager/Cargo.toml Normal file
@@ -0,0 +1,20 @@
[package]
name = "common-memory-manager"
version.workspace = true
edition.workspace = true
license.workspace = true

[lints]
workspace = true

[dependencies]
common-error = { workspace = true }
common-macro = { workspace = true }
common-telemetry = { workspace = true }
humantime = { workspace = true }
serde = { workspace = true }
snafu = { workspace = true }
tokio = { workspace = true, features = ["sync"] }

[dev-dependencies]
tokio = { workspace = true, features = ["rt", "macros"] }

63 src/common/memory-manager/src/error.rs Normal file
@@ -0,0 +1,63 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::time::Duration;

use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use snafu::Snafu;

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
    #[snafu(display(
        "Memory limit exceeded: requested {requested_bytes} bytes, limit {limit_bytes} bytes"
    ))]
    MemoryLimitExceeded {
        requested_bytes: u64,
        limit_bytes: u64,
    },

    #[snafu(display("Memory semaphore unexpectedly closed"))]
    MemorySemaphoreClosed,

    #[snafu(display(
        "Timeout waiting for memory quota: requested {requested_bytes} bytes, waited {waited:?}"
    ))]
    MemoryAcquireTimeout {
        requested_bytes: u64,
        waited: Duration,
    },
}

impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        use Error::*;

        match self {
            MemoryLimitExceeded { .. } => StatusCode::RuntimeResourcesExhausted,
            MemorySemaphoreClosed => StatusCode::Unexpected,
            MemoryAcquireTimeout { .. } => StatusCode::RuntimeResourcesExhausted,
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

168 src/common/memory-manager/src/granularity.rs Normal file
@@ -0,0 +1,168 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;

/// Memory permit granularity for different use cases.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PermitGranularity {
    /// 1 KB per permit
    ///
    /// Use for:
    /// - HTTP/gRPC request limiting (small, high-concurrency operations)
    /// - Small batch operations
    /// - Scenarios requiring fine-grained fairness
    Kilobyte,

    /// 1 MB per permit (default)
    ///
    /// Use for:
    /// - Query execution memory management
    /// - Compaction memory control
    /// - Large, long-running operations
    #[default]
    Megabyte,
}

impl PermitGranularity {
    /// Returns the number of bytes per permit.
    #[inline]
    pub const fn bytes(self) -> u64 {
        match self {
            Self::Kilobyte => 1024,
            Self::Megabyte => 1024 * 1024,
        }
    }

    /// Returns a human-readable string representation.
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Kilobyte => "1KB",
            Self::Megabyte => "1MB",
        }
    }

    /// Converts bytes to permits based on this granularity.
    ///
    /// Rounds up to ensure the requested bytes are fully covered.
    /// Clamped to Semaphore::MAX_PERMITS.
    #[inline]
    pub fn bytes_to_permits(self, bytes: u64) -> u32 {
        use tokio::sync::Semaphore;

        let granularity_bytes = self.bytes();
        bytes
            .saturating_add(granularity_bytes - 1)
            .saturating_div(granularity_bytes)
            .min(Semaphore::MAX_PERMITS as u64)
            .min(u32::MAX as u64) as u32
    }

    /// Converts permits to bytes based on this granularity.
    #[inline]
    pub fn permits_to_bytes(self, permits: u32) -> u64 {
        (permits as u64).saturating_mul(self.bytes())
    }
}
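The conversion is a ceiling division followed by clamping, so any non-zero request costs at least one full permit. A small sketch with values taken from the unit tests below:

// ceil(bytes / permit_size), clamped to the semaphore's maximum.
let mb = PermitGranularity::Megabyte;
assert_eq!(mb.bytes_to_permits(1500), 1);        // rounds up to 1 permit
assert_eq!(mb.permits_to_bytes(1), 1024 * 1024); // 1 permit covers a full MiB
assert_eq!(PermitGranularity::Kilobyte.bytes_to_permits(2047), 2);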
impl fmt::Display for PermitGranularity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.as_str())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bytes_to_permits_kilobyte() {
        let granularity = PermitGranularity::Kilobyte;

        // Exact multiples
        assert_eq!(granularity.bytes_to_permits(1024), 1);
        assert_eq!(granularity.bytes_to_permits(2048), 2);
        assert_eq!(granularity.bytes_to_permits(10 * 1024), 10);

        // Rounds up
        assert_eq!(granularity.bytes_to_permits(1), 1);
        assert_eq!(granularity.bytes_to_permits(1025), 2);
        assert_eq!(granularity.bytes_to_permits(2047), 2);
    }

    #[test]
    fn test_bytes_to_permits_megabyte() {
        let granularity = PermitGranularity::Megabyte;

        // Exact multiples
        assert_eq!(granularity.bytes_to_permits(1024 * 1024), 1);
        assert_eq!(granularity.bytes_to_permits(2 * 1024 * 1024), 2);

        // Rounds up
        assert_eq!(granularity.bytes_to_permits(1), 1);
        assert_eq!(granularity.bytes_to_permits(1024), 1);
        assert_eq!(granularity.bytes_to_permits(1024 * 1024 + 1), 2);
    }

    #[test]
    fn test_bytes_to_permits_zero_bytes() {
        assert_eq!(PermitGranularity::Kilobyte.bytes_to_permits(0), 0);
        assert_eq!(PermitGranularity::Megabyte.bytes_to_permits(0), 0);
    }

    #[test]
    fn test_bytes_to_permits_clamps_to_maximum() {
        use tokio::sync::Semaphore;

        let max_permits = (Semaphore::MAX_PERMITS as u64).min(u32::MAX as u64) as u32;

        assert_eq!(
            PermitGranularity::Kilobyte.bytes_to_permits(u64::MAX),
            max_permits
        );
        assert_eq!(
            PermitGranularity::Megabyte.bytes_to_permits(u64::MAX),
            max_permits
        );
    }

    #[test]
    fn test_permits_to_bytes() {
        assert_eq!(PermitGranularity::Kilobyte.permits_to_bytes(1), 1024);
        assert_eq!(PermitGranularity::Kilobyte.permits_to_bytes(10), 10 * 1024);

        assert_eq!(PermitGranularity::Megabyte.permits_to_bytes(1), 1024 * 1024);
        assert_eq!(
            PermitGranularity::Megabyte.permits_to_bytes(10),
            10 * 1024 * 1024
        );
    }

    #[test]
    fn test_round_trip_conversion() {
        // Kilobyte: bytes -> permits -> bytes (should round up)
        let kb = PermitGranularity::Kilobyte;
        let permits = kb.bytes_to_permits(1500);
        let bytes = kb.permits_to_bytes(permits);
        assert!(bytes >= 1500); // Must cover original request
        assert_eq!(bytes, 2048); // 2KB

        // Megabyte: bytes -> permits -> bytes (should round up)
        let mb = PermitGranularity::Megabyte;
        let permits = mb.bytes_to_permits(1500);
        let bytes = mb.permits_to_bytes(permits);
        assert!(bytes >= 1500);
        assert_eq!(bytes, 1024 * 1024); // 1MB
    }
}

231 src/common/memory-manager/src/guard.rs Normal file
@@ -0,0 +1,231 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{fmt, mem};

use common_telemetry::debug;
use snafu::ensure;
use tokio::sync::{OwnedSemaphorePermit, TryAcquireError};

use crate::error::{
    MemoryAcquireTimeoutSnafu, MemoryLimitExceededSnafu, MemorySemaphoreClosedSnafu, Result,
};
use crate::manager::{MemoryMetrics, MemoryQuota};
use crate::policy::OnExhaustedPolicy;

/// Guard representing a slice of reserved memory.
pub struct MemoryGuard<M: MemoryMetrics> {
    pub(crate) state: GuardState<M>,
}

pub(crate) enum GuardState<M: MemoryMetrics> {
    Unlimited,
    Limited {
        permit: OwnedSemaphorePermit,
        quota: MemoryQuota<M>,
    },
}

impl<M: MemoryMetrics> MemoryGuard<M> {
    pub(crate) fn unlimited() -> Self {
        Self {
            state: GuardState::Unlimited,
        }
    }

    pub(crate) fn limited(permit: OwnedSemaphorePermit, quota: MemoryQuota<M>) -> Self {
        Self {
            state: GuardState::Limited { permit, quota },
        }
    }

    /// Returns granted quota in bytes.
    pub fn granted_bytes(&self) -> u64 {
        match &self.state {
            GuardState::Unlimited => 0,
            GuardState::Limited { permit, quota } => {
                quota.permits_to_bytes(permit.num_permits() as u32)
            }
        }
    }

    /// Acquires additional memory, waiting if necessary until enough is available.
    ///
    /// On success, merges the new memory into this guard.
    ///
    /// # Errors
    /// - Returns error if requested bytes would exceed the manager's total limit
    /// - Returns error if the semaphore is unexpectedly closed
    pub async fn acquire_additional(&mut self, bytes: u64) -> Result<()> {
        match &mut self.state {
            GuardState::Unlimited => Ok(()),
            GuardState::Limited { permit, quota } => {
                if bytes == 0 {
                    return Ok(());
                }

                let additional_permits = quota.bytes_to_permits(bytes);
                let current_permits = permit.num_permits() as u32;

                ensure!(
                    current_permits.saturating_add(additional_permits) <= quota.limit_permits,
                    MemoryLimitExceededSnafu {
                        requested_bytes: bytes,
                        limit_bytes: quota.permits_to_bytes(quota.limit_permits)
                    }
                );

                let additional_permit = quota
                    .semaphore
                    .clone()
                    .acquire_many_owned(additional_permits)
                    .await
                    .map_err(|_| MemorySemaphoreClosedSnafu.build())?;

                permit.merge(additional_permit);
                quota.update_in_use_metric();
                debug!("Acquired additional {} bytes", bytes);
                Ok(())
            }
        }
    }

    /// Tries to acquire additional memory without waiting.
    ///
    /// On success, merges the new memory into this guard and returns true.
    /// On failure, returns false and leaves this guard unchanged.
    pub fn try_acquire_additional(&mut self, bytes: u64) -> bool {
        match &mut self.state {
            GuardState::Unlimited => true,
            GuardState::Limited { permit, quota } => {
                if bytes == 0 {
                    return true;
                }

                let additional_permits = quota.bytes_to_permits(bytes);

                match quota
                    .semaphore
                    .clone()
                    .try_acquire_many_owned(additional_permits)
                {
                    Ok(additional_permit) => {
                        permit.merge(additional_permit);
                        quota.update_in_use_metric();
                        debug!("Acquired additional {} bytes", bytes);
                        true
                    }
                    Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => {
                        quota.metrics.inc_rejected("try_acquire_additional");
                        false
                    }
                }
            }
        }
    }

    /// Acquires additional memory based on the given policy.
    ///
    /// - For `OnExhaustedPolicy::Wait`: Waits up to the timeout duration for memory to become available
    /// - For `OnExhaustedPolicy::Fail`: Returns immediately if memory is not available
    ///
    /// # Errors
    /// - `MemoryLimitExceeded`: Requested bytes would exceed the total limit (both policies), or memory is currently exhausted (Fail policy only)
    /// - `MemoryAcquireTimeout`: Timeout elapsed while waiting for memory (Wait policy only)
    /// - `MemorySemaphoreClosed`: The internal semaphore is unexpectedly closed (rare, indicates system issue)
    pub async fn acquire_additional_with_policy(
        &mut self,
        bytes: u64,
        policy: OnExhaustedPolicy,
    ) -> Result<()> {
        match policy {
            OnExhaustedPolicy::Wait { timeout } => {
                match tokio::time::timeout(timeout, self.acquire_additional(bytes)).await {
                    Ok(Ok(())) => Ok(()),
                    Ok(Err(e)) => Err(e),
                    Err(_elapsed) => MemoryAcquireTimeoutSnafu {
                        requested_bytes: bytes,
                        waited: timeout,
                    }
                    .fail(),
                }
            }
            OnExhaustedPolicy::Fail => {
                if self.try_acquire_additional(bytes) {
                    Ok(())
                } else {
                    MemoryLimitExceededSnafu {
                        requested_bytes: bytes,
                        limit_bytes: match &self.state {
                            GuardState::Unlimited => 0, // unreachable: unlimited mode always succeeds
                            GuardState::Limited { quota, .. } => {
                                quota.permits_to_bytes(quota.limit_permits)
                            }
                        },
                    }
                    .fail()
                }
            }
        }
    }

    /// Releases a portion of granted memory back to the pool before the guard is dropped.
    ///
    /// Returns true if the release succeeds or is a no-op; false if the request exceeds granted.
    pub fn release_partial(&mut self, bytes: u64) -> bool {
        match &mut self.state {
            GuardState::Unlimited => true,
            GuardState::Limited { permit, quota } => {
                if bytes == 0 {
                    return true;
                }

                let release_permits = quota.bytes_to_permits(bytes);

                match permit.split(release_permits as usize) {
                    Some(released_permit) => {
                        let released_bytes =
                            quota.permits_to_bytes(released_permit.num_permits() as u32);
                        drop(released_permit);
                        quota.update_in_use_metric();
                        debug!("Released {} bytes from memory guard", released_bytes);
                        true
                    }
                    None => false,
                }
            }
        }
    }
}

impl<M: MemoryMetrics> Drop for MemoryGuard<M> {
    fn drop(&mut self) {
        if let GuardState::Limited { permit, quota } =
            mem::replace(&mut self.state, GuardState::Unlimited)
        {
            let bytes = quota.permits_to_bytes(permit.num_permits() as u32);
            drop(permit);
            quota.update_in_use_metric();
            debug!("Released memory: {} bytes", bytes);
        }
    }
}

impl<M: MemoryMetrics> fmt::Debug for MemoryGuard<M> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MemoryGuard")
            .field("granted_bytes", &self.granted_bytes())
            .finish()
    }
}
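A short sketch of how a guard is typically grown and shrunk over its lifetime; the manager, the sizes, and the policy here are placeholders, and NoOpMetrics is the test helper exported from this crate:

// Sketch: reserve memory up front, grow it mid-operation, return part of it early.
let manager = MemoryManager::new(64 * 1024 * 1024, NoOpMetrics); // hypothetical 64 MiB budget
let mut guard = manager.acquire(8 * 1024 * 1024).await?;         // initial reservation

// Grow the reservation, waiting up to the policy's timeout if memory is tight.
guard
    .acquire_additional_with_policy(4 * 1024 * 1024, OnExhaustedPolicy::default())
    .await?;

// Hand part of it back once a phase finishes; the rest is released on drop.
guard.release_partial(4 * 1024 * 1024);
drop(guard);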
49 src/common/memory-manager/src/lib.rs Normal file
@@ -0,0 +1,49 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Generic memory management for resource-constrained operations.
//!
//! This crate provides a reusable memory quota system based on semaphores,
//! allowing different subsystems (compaction, flush, index build, etc.) to
//! share the same allocation logic while using their own metrics.

mod error;
mod granularity;
mod guard;
mod manager;
mod policy;

#[cfg(test)]
mod tests;

pub use error::{Error, Result};
pub use granularity::PermitGranularity;
pub use guard::MemoryGuard;
pub use manager::{MemoryManager, MemoryMetrics};
pub use policy::{DEFAULT_MEMORY_WAIT_TIMEOUT, OnExhaustedPolicy};

/// No-op metrics implementation for testing.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoOpMetrics;

impl MemoryMetrics for NoOpMetrics {
    #[inline(always)]
    fn set_limit(&self, _: i64) {}

    #[inline(always)]
    fn set_in_use(&self, _: i64) {}

    #[inline(always)]
    fn inc_rejected(&self, _: &str) {}
}
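Taken together, the public surface is small. A minimal end-to-end sketch of how a subsystem might use the crate; the limit, granularity, and policy values are illustrative only:

// Sketch: a subsystem with a 512 MiB budget that fails fast when memory is exhausted.
use common_memory_manager::{MemoryManager, NoOpMetrics, OnExhaustedPolicy, PermitGranularity};

async fn run_job() -> common_memory_manager::Result<()> {
    let manager = MemoryManager::with_granularity(
        512 * 1024 * 1024,
        PermitGranularity::Megabyte,
        NoOpMetrics, // a real subsystem would plug in its own MemoryMetrics impl
    );

    // Reserve 32 MiB for this job, or fail immediately if the pool is exhausted.
    let _guard = manager
        .acquire_with_policy(32 * 1024 * 1024, OnExhaustedPolicy::Fail)
        .await?;

    // ... do the memory-intensive work while the guard is alive ...
    Ok(()) // the reservation is returned to the pool when `_guard` drops
}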
222 src/common/memory-manager/src/manager.rs Normal file
@@ -0,0 +1,222 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use snafu::ensure;
use tokio::sync::{Semaphore, TryAcquireError};

use crate::error::{
    MemoryAcquireTimeoutSnafu, MemoryLimitExceededSnafu, MemorySemaphoreClosedSnafu, Result,
};
use crate::granularity::PermitGranularity;
use crate::guard::MemoryGuard;
use crate::policy::OnExhaustedPolicy;

/// Trait for recording memory usage metrics.
pub trait MemoryMetrics: Clone + Send + Sync + 'static {
    fn set_limit(&self, bytes: i64);
    fn set_in_use(&self, bytes: i64);
    fn inc_rejected(&self, reason: &str);
}
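Each subsystem supplies its own MemoryMetrics implementation. A sketch of what a Prometheus-backed one might look like; the gauge and counter names in the comments are hypothetical, not existing GreptimeDB metrics:

// Sketch only: hypothetical Prometheus-backed metrics for one subsystem.
#[derive(Clone)]
struct CompactionMemoryMetrics;

impl MemoryMetrics for CompactionMemoryMetrics {
    fn set_limit(&self, bytes: i64) {
        // e.g. COMPACTION_MEMORY_LIMIT_BYTES.set(bytes);
        let _ = bytes;
    }

    fn set_in_use(&self, bytes: i64) {
        // e.g. COMPACTION_MEMORY_IN_USE_BYTES.set(bytes);
        let _ = bytes;
    }

    fn inc_rejected(&self, reason: &str) {
        // e.g. COMPACTION_MEMORY_REJECTED_TOTAL.with_label_values(&[reason]).inc();
        let _ = reason;
    }
}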
/// Generic memory manager for quota-controlled operations.
#[derive(Clone)]
pub struct MemoryManager<M: MemoryMetrics> {
    quota: Option<MemoryQuota<M>>,
}

impl<M: MemoryMetrics + Default> Default for MemoryManager<M> {
    fn default() -> Self {
        Self::new(0, M::default())
    }
}

#[derive(Clone)]
pub(crate) struct MemoryQuota<M: MemoryMetrics> {
    pub(crate) semaphore: Arc<Semaphore>,
    pub(crate) limit_permits: u32,
    pub(crate) granularity: PermitGranularity,
    pub(crate) metrics: M,
}

impl<M: MemoryMetrics> MemoryManager<M> {
    /// Creates a new memory manager with the given limit in bytes.
    /// `limit_bytes = 0` disables the limit.
    pub fn new(limit_bytes: u64, metrics: M) -> Self {
        Self::with_granularity(limit_bytes, PermitGranularity::default(), metrics)
    }

    /// Creates a new memory manager with specified granularity.
    pub fn with_granularity(limit_bytes: u64, granularity: PermitGranularity, metrics: M) -> Self {
        if limit_bytes == 0 {
            metrics.set_limit(0);
            return Self { quota: None };
        }

        let limit_permits = granularity.bytes_to_permits(limit_bytes);
        let limit_aligned_bytes = granularity.permits_to_bytes(limit_permits);
        metrics.set_limit(limit_aligned_bytes as i64);

        Self {
            quota: Some(MemoryQuota {
                semaphore: Arc::new(Semaphore::new(limit_permits as usize)),
                limit_permits,
                granularity,
                metrics,
            }),
        }
    }

    /// Returns the configured limit in bytes (0 if unlimited).
    pub fn limit_bytes(&self) -> u64 {
        self.quota
            .as_ref()
            .map(|quota| quota.permits_to_bytes(quota.limit_permits))
            .unwrap_or(0)
    }

    /// Returns currently used bytes.
    pub fn used_bytes(&self) -> u64 {
        self.quota
            .as_ref()
            .map(|quota| quota.permits_to_bytes(quota.used_permits()))
            .unwrap_or(0)
    }

    /// Returns available bytes.
    pub fn available_bytes(&self) -> u64 {
        self.quota
            .as_ref()
            .map(|quota| quota.permits_to_bytes(quota.available_permits_clamped()))
            .unwrap_or(0)
    }

    /// Acquires memory, waiting if necessary until enough is available.
    ///
    /// # Errors
    /// - Returns error if requested bytes exceed the total limit
    /// - Returns error if the semaphore is unexpectedly closed
    pub async fn acquire(&self, bytes: u64) -> Result<MemoryGuard<M>> {
        match &self.quota {
            None => Ok(MemoryGuard::unlimited()),
            Some(quota) => {
                let permits = quota.bytes_to_permits(bytes);

                ensure!(
                    permits <= quota.limit_permits,
                    MemoryLimitExceededSnafu {
                        requested_bytes: bytes,
                        limit_bytes: self.limit_bytes()
                    }
                );

                let permit = quota
                    .semaphore
                    .clone()
                    .acquire_many_owned(permits)
                    .await
                    .map_err(|_| MemorySemaphoreClosedSnafu.build())?;
                quota.update_in_use_metric();
                Ok(MemoryGuard::limited(permit, quota.clone()))
            }
        }
    }

    /// Tries to acquire memory. Returns Some(guard) on success, None if insufficient.
    pub fn try_acquire(&self, bytes: u64) -> Option<MemoryGuard<M>> {
        match &self.quota {
            None => Some(MemoryGuard::unlimited()),
            Some(quota) => {
                let permits = quota.bytes_to_permits(bytes);

                match quota.semaphore.clone().try_acquire_many_owned(permits) {
                    Ok(permit) => {
                        quota.update_in_use_metric();
                        Some(MemoryGuard::limited(permit, quota.clone()))
                    }
                    Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => {
                        quota.metrics.inc_rejected("try_acquire");
                        None
                    }
                }
            }
        }
    }

    /// Acquires memory based on the given policy.
    ///
    /// - For `OnExhaustedPolicy::Wait`: Waits up to the timeout duration for memory to become available
    /// - For `OnExhaustedPolicy::Fail`: Returns immediately if memory is not available
    ///
    /// # Errors
    /// - `MemoryLimitExceeded`: Requested bytes exceed the total limit (both policies), or memory is currently exhausted (Fail policy only)
    /// - `MemoryAcquireTimeout`: Timeout elapsed while waiting for memory (Wait policy only)
    /// - `MemorySemaphoreClosed`: The internal semaphore is unexpectedly closed (rare, indicates system issue)
    pub async fn acquire_with_policy(
        &self,
        bytes: u64,
        policy: OnExhaustedPolicy,
    ) -> Result<MemoryGuard<M>> {
        match policy {
            OnExhaustedPolicy::Wait { timeout } => {
                match tokio::time::timeout(timeout, self.acquire(bytes)).await {
                    Ok(Ok(guard)) => Ok(guard),
                    Ok(Err(e)) => Err(e),
                    Err(_elapsed) => {
                        // Timeout elapsed while waiting
                        MemoryAcquireTimeoutSnafu {
                            requested_bytes: bytes,
                            waited: timeout,
                        }
                        .fail()
                    }
                }
            }
            OnExhaustedPolicy::Fail => self.try_acquire(bytes).ok_or_else(|| {
                MemoryLimitExceededSnafu {
                    requested_bytes: bytes,
                    limit_bytes: self.limit_bytes(),
                }
                .build()
            }),
        }
    }
}

impl<M: MemoryMetrics> MemoryQuota<M> {
    pub(crate) fn bytes_to_permits(&self, bytes: u64) -> u32 {
        self.granularity.bytes_to_permits(bytes)
    }

    pub(crate) fn permits_to_bytes(&self, permits: u32) -> u64 {
        self.granularity.permits_to_bytes(permits)
    }

    pub(crate) fn used_permits(&self) -> u32 {
        self.limit_permits
            .saturating_sub(self.available_permits_clamped())
    }

    pub(crate) fn available_permits_clamped(&self) -> u32 {
        self.semaphore
            .available_permits()
            .min(self.limit_permits as usize) as u32
    }

    pub(crate) fn update_in_use_metric(&self) {
        let bytes = self.permits_to_bytes(self.used_permits());
        self.metrics.set_in_use(bytes as i64);
    }
}
83 src/common/memory-manager/src/policy.rs Normal file
@@ -0,0 +1,83 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use humantime::{format_duration, parse_duration};
use serde::{Deserialize, Serialize};

/// Default wait timeout for memory acquisition.
pub const DEFAULT_MEMORY_WAIT_TIMEOUT: Duration = Duration::from_secs(10);

/// Defines how to react when memory cannot be acquired immediately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnExhaustedPolicy {
    /// Wait until enough memory is released, bounded by timeout.
    Wait { timeout: Duration },

    /// Fail immediately if memory is not available.
    Fail,
}

impl Default for OnExhaustedPolicy {
    fn default() -> Self {
        OnExhaustedPolicy::Wait {
            timeout: DEFAULT_MEMORY_WAIT_TIMEOUT,
        }
    }
}

impl Serialize for OnExhaustedPolicy {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = match self {
            OnExhaustedPolicy::Fail => "fail".to_string(),
            OnExhaustedPolicy::Wait { timeout } if *timeout == DEFAULT_MEMORY_WAIT_TIMEOUT => {
                "wait".to_string()
            }
            OnExhaustedPolicy::Wait { timeout } => format!("wait({})", format_duration(*timeout)),
        };
        serializer.serialize_str(&text)
    }
}

impl<'de> Deserialize<'de> for OnExhaustedPolicy {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = String::deserialize(deserializer)?;
        let lower = raw.to_ascii_lowercase();

        // Accept both "skip" (legacy) and "fail".
        if lower == "skip" || lower == "fail" {
            return Ok(OnExhaustedPolicy::Fail);
        }
        if lower == "wait" {
            return Ok(OnExhaustedPolicy::default());
        }
        if lower.starts_with("wait(") && lower.ends_with(')') {
            let inner = &raw[5..raw.len() - 1];
            let timeout = parse_duration(inner).map_err(serde::de::Error::custom)?;
            return Ok(OnExhaustedPolicy::Wait { timeout });
        }

        Err(serde::de::Error::custom(format!(
            "invalid memory policy: {}, expected wait, wait(<duration>), fail",
            raw
        )))
    }
}
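The textual forms accepted in configuration are therefore "wait", "wait(<duration>)", and "fail", plus the legacy "skip". A quick round-trip sketch; this assumes a serde_json dev-dependency, which the crate's Cargo.toml above does not declare, and any self-describing serde format would work the same way:

// Sketch: parsing and re-serializing the policy strings.
let p: OnExhaustedPolicy = serde_json::from_str(r#""wait(30s)""#).unwrap();
assert_eq!(p, OnExhaustedPolicy::Wait { timeout: Duration::from_secs(30) });

let legacy: OnExhaustedPolicy = serde_json::from_str(r#""skip""#).unwrap();
assert_eq!(legacy, OnExhaustedPolicy::Fail);

// The default 10s wait serializes back to the short form.
assert_eq!(serde_json::to_string(&OnExhaustedPolicy::default()).unwrap(), r#""wait""#);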
411 src/common/memory-manager/src/tests.rs Normal file
@@ -0,0 +1,411 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tokio::time::{Duration, sleep};

use crate::{MemoryManager, NoOpMetrics, PermitGranularity};

// Helper constant for tests - use default Megabyte granularity
const PERMIT_GRANULARITY_BYTES: u64 = PermitGranularity::Megabyte.bytes();

#[test]
fn test_try_acquire_unlimited() {
    let manager = MemoryManager::new(0, NoOpMetrics);
    let guard = manager.try_acquire(10 * PERMIT_GRANULARITY_BYTES).unwrap();
    assert_eq!(manager.limit_bytes(), 0);
    assert_eq!(guard.granted_bytes(), 0);
}

#[test]
fn test_try_acquire_limited_success_and_release() {
    let bytes = 2 * PERMIT_GRANULARITY_BYTES;
    let manager = MemoryManager::new(bytes, NoOpMetrics);
    {
        let guard = manager.try_acquire(PERMIT_GRANULARITY_BYTES).unwrap();
        assert_eq!(guard.granted_bytes(), PERMIT_GRANULARITY_BYTES);
        assert_eq!(manager.used_bytes(), PERMIT_GRANULARITY_BYTES);
        drop(guard);
    }
    assert_eq!(manager.used_bytes(), 0);
}

#[test]
fn test_try_acquire_exceeds_limit() {
    let limit = PERMIT_GRANULARITY_BYTES;
    let manager = MemoryManager::new(limit, NoOpMetrics);
    let result = manager.try_acquire(limit + PERMIT_GRANULARITY_BYTES);
    assert!(result.is_none());
}

#[tokio::test(flavor = "current_thread")]
async fn test_acquire_blocks_and_unblocks() {
    let bytes = 2 * PERMIT_GRANULARITY_BYTES;
    let manager = MemoryManager::new(bytes, NoOpMetrics);
    let guard = manager.try_acquire(bytes).unwrap();

    // Spawn a task that will block on acquire()
    let waiter = {
        let manager = manager.clone();
        tokio::spawn(async move {
            // This will block until memory is available
            let _guard = manager.acquire(bytes).await.unwrap();
        })
    };

    sleep(Duration::from_millis(10)).await;
    // Release memory - this should unblock the waiter
    drop(guard);

    // Waiter should complete now
    waiter.await.unwrap();
}

#[test]
fn test_request_additional_success() {
    let limit = 10 * PERMIT_GRANULARITY_BYTES; // 10MB limit
    let manager = MemoryManager::new(limit, NoOpMetrics);

    // Acquire base quota (5MB)
    let base = 5 * PERMIT_GRANULARITY_BYTES;
    let mut guard = manager.try_acquire(base).unwrap();
    assert_eq!(guard.granted_bytes(), base);
    assert_eq!(manager.used_bytes(), base);

    // Request additional memory (3MB) - should succeed and merge
    assert!(guard.try_acquire_additional(3 * PERMIT_GRANULARITY_BYTES));
    assert_eq!(guard.granted_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
    assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
}

#[test]
fn test_request_additional_exceeds_limit() {
    let limit = 10 * PERMIT_GRANULARITY_BYTES; // 10MB limit
    let manager = MemoryManager::new(limit, NoOpMetrics);

    // Acquire base quota (5MB)
    let base = 5 * PERMIT_GRANULARITY_BYTES;
    let mut guard = manager.try_acquire(base).unwrap();

    // Request additional memory (3MB) - should succeed
    assert!(guard.try_acquire_additional(3 * PERMIT_GRANULARITY_BYTES));
    assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);

    // Request more (3MB) - should fail (would exceed 10MB limit)
    let result = guard.try_acquire_additional(3 * PERMIT_GRANULARITY_BYTES);
    assert!(!result);

    // Still at 8MB
    assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
    assert_eq!(guard.granted_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
}

#[test]
fn test_request_additional_auto_release_on_guard_drop() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
{
|
||||
let mut guard = manager.try_acquire(5 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Request additional - memory is merged into guard
|
||||
assert!(guard.try_acquire_additional(3 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// When guard drops, all memory (base + additional) is released together
|
||||
}
|
||||
|
||||
// After scope, all memory should be released
|
||||
assert_eq!(manager.used_bytes(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_additional_unlimited() {
|
||||
let manager = MemoryManager::new(0, NoOpMetrics); // Unlimited
|
||||
let mut guard = manager.try_acquire(5 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Should always succeed with unlimited manager
|
||||
assert!(guard.try_acquire_additional(100 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 0);
|
||||
assert_eq!(manager.used_bytes(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_additional_zero_bytes() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(5 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Request 0 bytes should succeed without affecting anything
|
||||
assert!(guard.try_acquire_additional(0));
|
||||
assert_eq!(guard.granted_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_release_partial_success() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(8 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Release half
|
||||
assert!(guard.release_partial(4 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 4 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 4 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Released memory should be available to others
|
||||
let _guard2 = manager.try_acquire(4 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_release_partial_exceeds_granted() {
|
||||
let manager = MemoryManager::new(10 * PERMIT_GRANULARITY_BYTES, NoOpMetrics);
|
||||
let mut guard = manager.try_acquire(5 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Try to release more than granted - should fail
|
||||
assert!(!guard.release_partial(10 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_early_release_partial_unlimited() {
|
||||
let manager = MemoryManager::new(0, NoOpMetrics);
|
||||
let mut guard = manager.try_acquire(100 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Unlimited guard - release should succeed (no-op)
|
||||
assert!(guard.release_partial(50 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_and_early_release_symmetry() {
|
||||
let limit = 20 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(5 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Request additional
|
||||
assert!(guard.try_acquire_additional(5 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 10 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 10 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Early release some
|
||||
assert!(guard.release_partial(3 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 7 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 7 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Request again
|
||||
assert!(guard.try_acquire_additional(2 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 9 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 9 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Early release again
|
||||
assert!(guard.release_partial(4 * PERMIT_GRANULARITY_BYTES));
|
||||
assert_eq!(guard.granted_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
drop(guard);
|
||||
assert_eq!(manager.used_bytes(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_small_allocation_rounds_up() {
|
||||
// Test that allocations smaller than PERMIT_GRANULARITY_BYTES
|
||||
// round up to 1 permit and can use try_acquire_additional()
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(512 * 1024).unwrap(); // 512KB
|
||||
assert_eq!(guard.granted_bytes(), PERMIT_GRANULARITY_BYTES); // Rounds up to 1MB
|
||||
assert!(guard.try_acquire_additional(2 * PERMIT_GRANULARITY_BYTES)); // Can request more
|
||||
assert_eq!(guard.granted_bytes(), 3 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_acquire_zero_bytes_lazy_allocation() {
|
||||
// Test that acquire(0) returns 0 permits but can try_acquire_additional() later
|
||||
let manager = MemoryManager::new(10 * PERMIT_GRANULARITY_BYTES, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(0).unwrap();
|
||||
assert_eq!(guard.granted_bytes(), 0); // No permits consumed
|
||||
assert_eq!(manager.used_bytes(), 0);
|
||||
|
||||
assert!(guard.try_acquire_additional(3 * PERMIT_GRANULARITY_BYTES)); // Lazy allocation
|
||||
assert_eq!(guard.granted_bytes(), 3 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_blocks_and_unblocks() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
// First guard takes 9MB, leaving only 1MB available
|
||||
let mut guard1 = manager.try_acquire(9 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
assert_eq!(manager.used_bytes(), 9 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Spawn a task that will block trying to acquire additional 5MB (needs total 10MB available)
|
||||
let manager_clone = manager.clone();
|
||||
let waiter = tokio::spawn(async move {
|
||||
let mut guard2 = manager_clone.try_acquire(0).unwrap();
|
||||
// This will block until enough memory is available
|
||||
guard2
|
||||
.acquire_additional(5 * PERMIT_GRANULARITY_BYTES)
|
||||
.await
|
||||
.unwrap();
|
||||
guard2
|
||||
});
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
// Release 5MB from guard1 - this should unblock the waiter
|
||||
assert!(guard1.release_partial(5 * PERMIT_GRANULARITY_BYTES));
|
||||
|
||||
// Waiter should complete now
|
||||
let guard2 = waiter.await.unwrap();
|
||||
assert_eq!(guard2.granted_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Total: guard1 has 4MB, guard2 has 5MB = 9MB
|
||||
assert_eq!(manager.used_bytes(), 9 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_exceeds_total_limit() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(8 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
// Try to acquire additional 5MB - would exceed total limit of 10MB
|
||||
let result = guard.acquire_additional(5 * PERMIT_GRANULARITY_BYTES).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
// Guard should remain unchanged
|
||||
assert_eq!(guard.granted_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 8 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_success() {
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard = manager.try_acquire(3 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
assert_eq!(manager.used_bytes(), 3 * PERMIT_GRANULARITY_BYTES);
|
||||
|
||||
// Acquire additional 4MB - should succeed
|
||||
guard
|
||||
.acquire_additional(4 * PERMIT_GRANULARITY_BYTES)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(guard.granted_bytes(), 7 * PERMIT_GRANULARITY_BYTES);
|
||||
assert_eq!(manager.used_bytes(), 7 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_with_policy_wait_success() {
|
||||
use crate::policy::OnExhaustedPolicy;
|
||||
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let mut guard1 = manager.try_acquire(8 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
let manager_clone = manager.clone();
|
||||
let waiter = tokio::spawn(async move {
|
||||
let mut guard2 = manager_clone.try_acquire(0).unwrap();
|
||||
// Wait policy with 1 second timeout
|
||||
guard2
|
||||
.acquire_additional_with_policy(
|
||||
5 * PERMIT_GRANULARITY_BYTES,
|
||||
OnExhaustedPolicy::Wait {
|
||||
timeout: Duration::from_secs(1),
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
guard2
|
||||
});
|
||||
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
|
||||
// Release memory to unblock waiter
|
||||
assert!(guard1.release_partial(5 * PERMIT_GRANULARITY_BYTES));
|
||||
|
||||
let guard2 = waiter.await.unwrap();
|
||||
assert_eq!(guard2.granted_bytes(), 5 * PERMIT_GRANULARITY_BYTES);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_with_policy_wait_timeout() {
|
||||
use crate::policy::OnExhaustedPolicy;
|
||||
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
// Take all memory
|
||||
let _guard1 = manager.try_acquire(10 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
let mut guard2 = manager.try_acquire(0).unwrap();
|
||||
|
||||
// Try to acquire with short timeout - should timeout
|
||||
let result = guard2
|
||||
.acquire_additional_with_policy(
|
||||
5 * PERMIT_GRANULARITY_BYTES,
|
||||
OnExhaustedPolicy::Wait {
|
||||
timeout: Duration::from_millis(50),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(guard2.granted_bytes(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_with_policy_fail() {
|
||||
use crate::policy::OnExhaustedPolicy;
|
||||
|
||||
let limit = 10 * PERMIT_GRANULARITY_BYTES;
|
||||
let manager = MemoryManager::new(limit, NoOpMetrics);
|
||||
|
||||
let _guard1 = manager.try_acquire(8 * PERMIT_GRANULARITY_BYTES).unwrap();
|
||||
|
||||
let mut guard2 = manager.try_acquire(0).unwrap();
|
||||
|
||||
// Fail policy - should return error immediately
|
||||
let result = guard2
|
||||
.acquire_additional_with_policy(5 * PERMIT_GRANULARITY_BYTES, OnExhaustedPolicy::Fail)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(guard2.granted_bytes(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_acquire_additional_unlimited() {
|
||||
let manager = MemoryManager::new(0, NoOpMetrics); // Unlimited
|
||||
let mut guard = manager.try_acquire(0).unwrap();
|
||||
|
||||
// Should always succeed with unlimited manager
|
||||
guard
|
||||
.acquire_additional(1000 * PERMIT_GRANULARITY_BYTES)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(guard.granted_bytes(), 0);
|
||||
assert_eq!(manager.used_bytes(), 0);
|
||||
}
|
||||
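Taken together, these tests outline the intended call pattern: reserve a base quota, grow or shrink it as the real footprint becomes known, and let the guard return everything on drop. A condensed call-site sketch using only names exercised above; the function shape and the manager's exact type parameters are assumptions, not part of the diff:

async fn run_with_quota(manager: &MemoryManager /* metrics type elided */) -> Option<()> {
    let mb = PermitGranularity::Megabyte.bytes();

    // Reserve a conservative base quota; bail out immediately if it is unavailable.
    let mut guard = manager.try_acquire(4 * mb)?;

    // Grow the reservation once the real working set is known, waiting up to 5s.
    guard
        .acquire_additional_with_policy(
            8 * mb,
            OnExhaustedPolicy::Wait { timeout: std::time::Duration::from_secs(5) },
        )
        .await
        .ok()?;

    // Hand back what is no longer needed before the long tail of the work.
    guard.release_partial(6 * mb);

    // Whatever remains is released when `guard` drops.
    Some(())
}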
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::{Display, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::str::FromStr;

@@ -60,7 +61,7 @@ pub trait ClusterInfo {
}

/// The key of [NodeInfo] in the storage. The format is `__meta_cluster_node_info-0-{role}-{node_id}`.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize, Deserialize, PartialOrd, Ord)]
pub struct NodeInfoKey {
    /// The role of the node. It can be `[Role::Datanode]` or `[Role::Frontend]`.
    pub role: Role,
@@ -135,7 +136,7 @@ pub struct NodeInfo {
    pub hostname: String,
}

#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, Serialize, Deserialize, PartialOrd, Ord)]
pub enum Role {
    Datanode,
    Frontend,
@@ -241,6 +242,12 @@ impl From<&NodeInfoKey> for Vec<u8> {
    }
}

impl Display for NodeInfoKey {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}-{}", self.role, self.node_id)
    }
}

impl FromStr for NodeInfo {
    type Err = Error;

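Deriving PartialOrd/Ord and adding Display makes node keys sortable and printable. A small sketch of what that enables; the field values are hypothetical, node_id is assumed to be an integer, and any NodeInfoKey fields beyond role and node_id are omitted here:

let mut keys = vec![
    NodeInfoKey { role: Role::Frontend, node_id: 2 },
    NodeInfoKey { role: Role::Datanode, node_id: 7 },
    NodeInfoKey { role: Role::Datanode, node_id: 3 },
];

// With Ord derived, keys sort by role first, then by node_id.
keys.sort();

for key in &keys {
    // Display renders "{role:?}-{node_id}", e.g. "Datanode-3".
    println!("{key}");
}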
@@ -31,6 +31,7 @@ use crate::region_registry::LeaderRegionRegistryRef;
pub mod alter_database;
pub mod alter_logical_tables;
pub mod alter_table;
pub mod comment_on;
pub mod create_database;
pub mod create_flow;
pub mod create_logical_tables;

@@ -301,8 +301,8 @@ fn build_new_table_info(
        | AlterKind::UnsetTableOptions { .. }
        | AlterKind::SetIndexes { .. }
        | AlterKind::UnsetIndexes { .. }
        | AlterKind::DropDefaults { .. } => {}
        AlterKind::SetDefaults { .. } => {}
        | AlterKind::DropDefaults { .. }
        | AlterKind::SetDefaults { .. } => {}
    }

    info!(

src/common/meta/src/ddl/comment_on.rs (new file, 509 lines)
@@ -0,0 +1,509 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use common_catalog::format_full_table_name;
|
||||
use common_procedure::error::{FromJsonSnafu, Result as ProcedureResult, ToJsonSnafu};
|
||||
use common_procedure::{Context as ProcedureContext, LockKey, Procedure, Status};
|
||||
use common_telemetry::tracing::info;
|
||||
use datatypes::schema::COMMENT_KEY as COLUMN_COMMENT_KEY;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::{OptionExt, ResultExt, ensure};
|
||||
use store_api::storage::TableId;
|
||||
use strum::AsRefStr;
|
||||
use table::metadata::RawTableInfo;
|
||||
use table::requests::COMMENT_KEY as TABLE_COMMENT_KEY;
|
||||
use table::table_name::TableName;
|
||||
|
||||
use crate::cache_invalidator::Context;
|
||||
use crate::ddl::DdlContext;
|
||||
use crate::ddl::utils::map_to_procedure_error;
|
||||
use crate::error::{ColumnNotFoundSnafu, FlowNotFoundSnafu, Result, TableNotFoundSnafu};
|
||||
use crate::instruction::CacheIdent;
|
||||
use crate::key::flow::flow_info::{FlowInfoKey, FlowInfoValue};
|
||||
use crate::key::table_info::{TableInfoKey, TableInfoValue};
|
||||
use crate::key::table_name::TableNameKey;
|
||||
use crate::key::{DeserializedValueWithBytes, FlowId, MetadataKey, MetadataValue};
|
||||
use crate::lock_key::{CatalogLock, FlowNameLock, SchemaLock, TableNameLock};
|
||||
use crate::rpc::ddl::{CommentObjectType, CommentOnTask};
|
||||
use crate::rpc::store::PutRequest;
|
||||
|
||||
pub struct CommentOnProcedure {
|
||||
pub context: DdlContext,
|
||||
pub data: CommentOnData,
|
||||
}
|
||||
|
||||
impl CommentOnProcedure {
|
||||
pub const TYPE_NAME: &'static str = "metasrv-procedure::CommentOn";
|
||||
|
||||
pub fn new(task: CommentOnTask, context: DdlContext) -> Self {
|
||||
Self {
|
||||
context,
|
||||
data: CommentOnData::new(task),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_json(json: &str, context: DdlContext) -> ProcedureResult<Self> {
|
||||
let data = serde_json::from_str(json).context(FromJsonSnafu)?;
|
||||
|
||||
Ok(Self { context, data })
|
||||
}
|
||||
|
||||
pub async fn on_prepare(&mut self) -> Result<Status> {
|
||||
match self.data.object_type {
|
||||
CommentObjectType::Table | CommentObjectType::Column => {
|
||||
self.prepare_table_or_column().await?;
|
||||
}
|
||||
CommentObjectType::Flow => {
|
||||
self.prepare_flow().await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path: if comment is unchanged, skip update
|
||||
if self.data.is_unchanged {
|
||||
let object_desc = match self.data.object_type {
|
||||
CommentObjectType::Table => format!(
|
||||
"table {}",
|
||||
format_full_table_name(
|
||||
&self.data.catalog_name,
|
||||
&self.data.schema_name,
|
||||
&self.data.object_name,
|
||||
)
|
||||
),
|
||||
CommentObjectType::Column => format!(
|
||||
"column {}.{}",
|
||||
format_full_table_name(
|
||||
&self.data.catalog_name,
|
||||
&self.data.schema_name,
|
||||
&self.data.object_name,
|
||||
),
|
||||
self.data.column_name.as_ref().unwrap()
|
||||
),
|
||||
CommentObjectType::Flow => {
|
||||
format!("flow {}.{}", self.data.catalog_name, self.data.object_name)
|
||||
}
|
||||
};
|
||||
info!("Comment unchanged for {}, skipping update", object_desc);
|
||||
return Ok(Status::done());
|
||||
}
|
||||
|
||||
self.data.state = CommentOnState::UpdateMetadata;
|
||||
Ok(Status::executing(true))
|
||||
}
|
||||
|
||||
async fn prepare_table_or_column(&mut self) -> Result<()> {
|
||||
let table_name_key = TableNameKey::new(
|
||||
&self.data.catalog_name,
|
||||
&self.data.schema_name,
|
||||
&self.data.object_name,
|
||||
);
|
||||
|
||||
let table_id = self
|
||||
.context
|
||||
.table_metadata_manager
|
||||
.table_name_manager()
|
||||
.get(table_name_key)
|
||||
.await?
|
||||
.with_context(|| TableNotFoundSnafu {
|
||||
table_name: format_full_table_name(
|
||||
&self.data.catalog_name,
|
||||
&self.data.schema_name,
|
||||
&self.data.object_name,
|
||||
),
|
||||
})?
|
||||
.table_id();
|
||||
|
||||
let table_info = self
|
||||
.context
|
||||
.table_metadata_manager
|
||||
.table_info_manager()
|
||||
.get(table_id)
|
||||
.await?
|
||||
.with_context(|| TableNotFoundSnafu {
|
||||
table_name: format_full_table_name(
|
||||
&self.data.catalog_name,
|
||||
&self.data.schema_name,
|
||||
&self.data.object_name,
|
||||
),
|
||||
})?;
|
||||
|
||||
// For column comments, validate the column exists
|
||||
if self.data.object_type == CommentObjectType::Column {
|
||||
let column_name = self.data.column_name.as_ref().unwrap();
|
||||
let column_exists = table_info
|
||||
.table_info
|
||||
.meta
|
||||
.schema
|
||||
.column_schemas
|
||||
.iter()
|
||||
.any(|col| &col.name == column_name);
|
||||
|
||||
ensure!(
|
||||
column_exists,
|
||||
ColumnNotFoundSnafu {
|
||||
column_name,
|
||||
column_id: 0u32, // column_id is not known here
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
self.data.table_id = Some(table_id);
|
||||
|
||||
// Check if comment is unchanged for early exit optimization
|
||||
match self.data.object_type {
|
||||
CommentObjectType::Table => {
|
||||
let current_comment = &table_info.table_info.desc;
|
||||
if &self.data.comment == current_comment {
|
||||
self.data.is_unchanged = true;
|
||||
}
|
||||
}
|
||||
CommentObjectType::Column => {
|
||||
let column_name = self.data.column_name.as_ref().unwrap();
|
||||
let column_schema = table_info
|
||||
.table_info
|
||||
.meta
|
||||
.schema
|
||||
.column_schemas
|
||||
.iter()
|
||||
.find(|col| &col.name == column_name)
|
||||
.unwrap(); // Safe: validated above
|
||||
|
||||
let current_comment = column_schema.metadata().get(COLUMN_COMMENT_KEY);
|
||||
if self.data.comment.as_deref() == current_comment.map(String::as_str) {
|
||||
self.data.is_unchanged = true;
|
||||
}
|
||||
}
|
||||
CommentObjectType::Flow => {
|
||||
// this branch is handled in `prepare_flow`
|
||||
}
|
||||
}
|
||||
|
||||
self.data.table_info = Some(table_info);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn prepare_flow(&mut self) -> Result<()> {
|
||||
let flow_name_value = self
|
||||
.context
|
||||
.flow_metadata_manager
|
||||
.flow_name_manager()
|
||||
.get(&self.data.catalog_name, &self.data.object_name)
|
||||
.await?
|
||||
.with_context(|| FlowNotFoundSnafu {
|
||||
flow_name: &self.data.object_name,
|
||||
})?;
|
||||
|
||||
let flow_id = flow_name_value.flow_id();
|
||||
let flow_info = self
|
||||
.context
|
||||
.flow_metadata_manager
|
||||
.flow_info_manager()
|
||||
.get_raw(flow_id)
|
||||
.await?
|
||||
.with_context(|| FlowNotFoundSnafu {
|
||||
flow_name: &self.data.object_name,
|
||||
})?;
|
||||
|
||||
self.data.flow_id = Some(flow_id);
|
||||
|
||||
// Check if comment is unchanged for early exit optimization
|
||||
let current_comment = &flow_info.get_inner_ref().comment;
|
||||
let new_comment = self.data.comment.as_deref().unwrap_or("");
|
||||
if new_comment == current_comment.as_str() {
|
||||
self.data.is_unchanged = true;
|
||||
}
|
||||
|
||||
self.data.flow_info = Some(flow_info);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn on_update_metadata(&mut self) -> Result<Status> {
|
||||
match self.data.object_type {
|
||||
CommentObjectType::Table => {
|
||||
self.update_table_comment().await?;
|
||||
}
|
||||
CommentObjectType::Column => {
|
||||
self.update_column_comment().await?;
|
||||
}
|
||||
CommentObjectType::Flow => {
|
||||
self.update_flow_comment().await?;
|
||||
}
|
||||
}
|
||||
|
||||
self.data.state = CommentOnState::InvalidateCache;
|
||||
Ok(Status::executing(true))
|
||||
}
|
||||
|
||||
async fn update_table_comment(&mut self) -> Result<()> {
|
||||
let table_info_value = self.data.table_info.as_ref().unwrap();
|
||||
let mut new_table_info = table_info_value.table_info.clone();
|
||||
|
||||
new_table_info.desc = self.data.comment.clone();
|
||||
|
||||
// Sync comment to table options
|
||||
sync_table_comment_option(
|
||||
&mut new_table_info.meta.options,
|
||||
new_table_info.desc.as_deref(),
|
||||
);
|
||||
|
||||
self.update_table_info(table_info_value, new_table_info)
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Updated comment for table {}.{}.{}",
|
||||
self.data.catalog_name, self.data.schema_name, self.data.object_name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_column_comment(&mut self) -> Result<()> {
|
||||
let table_info_value = self.data.table_info.as_ref().unwrap();
|
||||
let mut new_table_info = table_info_value.table_info.clone();
|
||||
|
||||
let column_name = self.data.column_name.as_ref().unwrap();
|
||||
let column_schema = new_table_info
|
||||
.meta
|
||||
.schema
|
||||
.column_schemas
|
||||
.iter_mut()
|
||||
.find(|col| &col.name == column_name)
|
||||
.unwrap(); // Safe: validated in prepare
|
||||
|
||||
update_column_comment_metadata(column_schema, self.data.comment.clone());
|
||||
|
||||
self.update_table_info(table_info_value, new_table_info)
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Updated comment for column {}.{}.{}.{}",
|
||||
self.data.catalog_name, self.data.schema_name, self.data.object_name, column_name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_flow_comment(&mut self) -> Result<()> {
|
||||
let flow_id = self.data.flow_id.unwrap();
|
||||
let flow_info_value = self.data.flow_info.as_ref().unwrap();
|
||||
|
||||
let mut new_flow_info = flow_info_value.get_inner_ref().clone();
|
||||
new_flow_info.comment = self.data.comment.clone().unwrap_or_default();
|
||||
new_flow_info.updated_time = Utc::now();
|
||||
|
||||
let raw_value = new_flow_info.try_as_raw_value()?;
|
||||
|
||||
self.context
|
||||
.table_metadata_manager
|
||||
.kv_backend()
|
||||
.put(
|
||||
PutRequest::new()
|
||||
.with_key(FlowInfoKey::new(flow_id).to_bytes())
|
||||
.with_value(raw_value),
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"Updated comment for flow {}.{}",
|
||||
self.data.catalog_name, self.data.object_name
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_table_info(
|
||||
&self,
|
||||
current_table_info: &DeserializedValueWithBytes<TableInfoValue>,
|
||||
new_table_info: RawTableInfo,
|
||||
) -> Result<()> {
|
||||
let table_id = current_table_info.table_info.ident.table_id;
|
||||
let new_table_info_value = current_table_info.update(new_table_info);
|
||||
let raw_value = new_table_info_value.try_as_raw_value()?;
|
||||
|
||||
self.context
|
||||
.table_metadata_manager
|
||||
.kv_backend()
|
||||
.put(
|
||||
PutRequest::new()
|
||||
.with_key(TableInfoKey::new(table_id).to_bytes())
|
||||
.with_value(raw_value),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn on_invalidate_cache(&mut self) -> Result<Status> {
|
||||
let cache_invalidator = &self.context.cache_invalidator;
|
||||
|
||||
match self.data.object_type {
|
||||
CommentObjectType::Table | CommentObjectType::Column => {
|
||||
let table_id = self.data.table_id.unwrap();
|
||||
let table_name = TableName::new(
|
||||
self.data.catalog_name.clone(),
|
||||
self.data.schema_name.clone(),
|
||||
self.data.object_name.clone(),
|
||||
);
|
||||
|
||||
let cache_ident = vec![
|
||||
CacheIdent::TableId(table_id),
|
||||
CacheIdent::TableName(table_name),
|
||||
];
|
||||
|
||||
cache_invalidator
|
||||
.invalidate(&Context::default(), &cache_ident)
|
||||
.await?;
|
||||
}
|
||||
CommentObjectType::Flow => {
|
||||
let flow_id = self.data.flow_id.unwrap();
|
||||
let cache_ident = vec![CacheIdent::FlowId(flow_id)];
|
||||
|
||||
cache_invalidator
|
||||
.invalidate(&Context::default(), &cache_ident)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Status::done())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Procedure for CommentOnProcedure {
|
||||
fn type_name(&self) -> &str {
|
||||
Self::TYPE_NAME
|
||||
}
|
||||
|
||||
async fn execute(&mut self, _ctx: &ProcedureContext) -> ProcedureResult<Status> {
|
||||
match self.data.state {
|
||||
CommentOnState::Prepare => self.on_prepare().await,
|
||||
CommentOnState::UpdateMetadata => self.on_update_metadata().await,
|
||||
CommentOnState::InvalidateCache => self.on_invalidate_cache().await,
|
||||
}
|
||||
.map_err(map_to_procedure_error)
|
||||
}
|
||||
|
||||
fn dump(&self) -> ProcedureResult<String> {
|
||||
serde_json::to_string(&self.data).context(ToJsonSnafu)
|
||||
}
|
||||
|
||||
fn lock_key(&self) -> LockKey {
|
||||
let catalog = &self.data.catalog_name;
|
||||
let schema = &self.data.schema_name;
|
||||
|
||||
let lock_key = match self.data.object_type {
|
||||
CommentObjectType::Table | CommentObjectType::Column => {
|
||||
vec![
|
||||
CatalogLock::Read(catalog).into(),
|
||||
SchemaLock::read(catalog, schema).into(),
|
||||
TableNameLock::new(catalog, schema, &self.data.object_name).into(),
|
||||
]
|
||||
}
|
||||
CommentObjectType::Flow => {
|
||||
vec![
|
||||
CatalogLock::Read(catalog).into(),
|
||||
FlowNameLock::new(catalog, &self.data.object_name).into(),
|
||||
]
|
||||
}
|
||||
};
|
||||
|
||||
LockKey::new(lock_key)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, AsRefStr)]
|
||||
enum CommentOnState {
|
||||
Prepare,
|
||||
UpdateMetadata,
|
||||
InvalidateCache,
|
||||
}
|
||||
|
||||
/// The data of comment on procedure.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CommentOnData {
|
||||
state: CommentOnState,
|
||||
catalog_name: String,
|
||||
schema_name: String,
|
||||
object_type: CommentObjectType,
|
||||
object_name: String,
|
||||
/// Column name (only for Column comments)
|
||||
column_name: Option<String>,
|
||||
comment: Option<String>,
|
||||
/// Cached table ID (for Table/Column)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
table_id: Option<TableId>,
|
||||
/// Cached table info (for Table/Column)
|
||||
#[serde(skip)]
|
||||
table_info: Option<DeserializedValueWithBytes<TableInfoValue>>,
|
||||
/// Cached flow ID (for Flow)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
flow_id: Option<FlowId>,
|
||||
/// Cached flow info (for Flow)
|
||||
#[serde(skip)]
|
||||
flow_info: Option<DeserializedValueWithBytes<FlowInfoValue>>,
|
||||
/// Whether the comment is unchanged (optimization for early exit)
|
||||
#[serde(skip)]
|
||||
is_unchanged: bool,
|
||||
}
|
||||
|
||||
impl CommentOnData {
|
||||
pub fn new(task: CommentOnTask) -> Self {
|
||||
Self {
|
||||
state: CommentOnState::Prepare,
|
||||
catalog_name: task.catalog_name,
|
||||
schema_name: task.schema_name,
|
||||
object_type: task.object_type,
|
||||
object_name: task.object_name,
|
||||
column_name: task.column_name,
|
||||
comment: task.comment,
|
||||
table_id: None,
|
||||
table_info: None,
|
||||
flow_id: None,
|
||||
flow_info: None,
|
||||
is_unchanged: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_column_comment_metadata(
|
||||
column_schema: &mut datatypes::schema::ColumnSchema,
|
||||
comment: Option<String>,
|
||||
) {
|
||||
match comment {
|
||||
Some(value) => {
|
||||
column_schema
|
||||
.mut_metadata()
|
||||
.insert(COLUMN_COMMENT_KEY.to_string(), value);
|
||||
}
|
||||
None => {
|
||||
column_schema.mut_metadata().remove(COLUMN_COMMENT_KEY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_table_comment_option(options: &mut table::requests::TableOptions, comment: Option<&str>) {
|
||||
match comment {
|
||||
Some(value) => {
|
||||
options
|
||||
.extra_options
|
||||
.insert(TABLE_COMMENT_KEY.to_string(), value.to_string());
|
||||
}
|
||||
None => {
|
||||
options.extra_options.remove(TABLE_COMMENT_KEY);
|
||||
}
|
||||
}
|
||||
}
|
||||
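The procedure is a three-step state machine (Prepare → UpdateMetadata → InvalidateCache) driven entirely by the CommentOnTask it receives. A sketch of the three task shapes; the struct literals mirror the fields consumed by CommentOnData::new above, the values are placeholders, and any CommentOnTask field not visible in this diff is assumed absent:

// COMMENT ON TABLE greptime.public.metrics IS 'cpu metrics'
let table_task = CommentOnTask {
    catalog_name: "greptime".to_string(),
    schema_name: "public".to_string(),
    object_type: CommentObjectType::Table,
    object_name: "metrics".to_string(),
    column_name: None,
    comment: Some("cpu metrics".to_string()),
};

// COMMENT ON COLUMN greptime.public.metrics.host IS NULL clears the column comment.
let column_task = CommentOnTask {
    object_type: CommentObjectType::Column,
    column_name: Some("host".to_string()),
    comment: None,
    ..table_task.clone()
};

// COMMENT ON FLOW only needs the catalog and the flow name.
let flow_task = CommentOnTask {
    object_type: CommentObjectType::Flow,
    object_name: "my_flow".to_string(),
    column_name: None,
    ..table_task
};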
@@ -27,6 +27,7 @@ use store_api::storage::TableId;
|
||||
use crate::ddl::alter_database::AlterDatabaseProcedure;
|
||||
use crate::ddl::alter_logical_tables::AlterLogicalTablesProcedure;
|
||||
use crate::ddl::alter_table::AlterTableProcedure;
|
||||
use crate::ddl::comment_on::CommentOnProcedure;
|
||||
use crate::ddl::create_database::CreateDatabaseProcedure;
|
||||
use crate::ddl::create_flow::CreateFlowProcedure;
|
||||
use crate::ddl::create_logical_tables::CreateLogicalTablesProcedure;
|
||||
@@ -46,44 +47,39 @@ use crate::error::{
|
||||
use crate::key::table_info::TableInfoValue;
|
||||
use crate::key::table_name::TableNameKey;
|
||||
use crate::key::{DeserializedValueWithBytes, TableMetadataManagerRef};
|
||||
use crate::kv_backend::KvBackendRef;
|
||||
use crate::procedure_executor::ExecutorContext;
|
||||
#[cfg(feature = "enterprise")]
|
||||
use crate::rpc::ddl::DdlTask::CreateTrigger;
|
||||
#[cfg(feature = "enterprise")]
|
||||
use crate::rpc::ddl::DdlTask::DropTrigger;
|
||||
use crate::rpc::ddl::DdlTask::{
|
||||
AlterDatabase, AlterLogicalTables, AlterTable, CreateDatabase, CreateFlow, CreateLogicalTables,
|
||||
CreateTable, CreateView, DropDatabase, DropFlow, DropLogicalTables, DropTable, DropView,
|
||||
TruncateTable,
|
||||
AlterDatabase, AlterLogicalTables, AlterTable, CommentOn, CreateDatabase, CreateFlow,
|
||||
CreateLogicalTables, CreateTable, CreateView, DropDatabase, DropFlow, DropLogicalTables,
|
||||
DropTable, DropView, TruncateTable,
|
||||
};
|
||||
#[cfg(feature = "enterprise")]
|
||||
use crate::rpc::ddl::trigger::CreateTriggerTask;
|
||||
#[cfg(feature = "enterprise")]
|
||||
use crate::rpc::ddl::trigger::DropTriggerTask;
|
||||
use crate::rpc::ddl::{
|
||||
AlterDatabaseTask, AlterTableTask, CreateDatabaseTask, CreateFlowTask, CreateTableTask,
|
||||
CreateViewTask, DropDatabaseTask, DropFlowTask, DropTableTask, DropViewTask, QueryContext,
|
||||
SubmitDdlTaskRequest, SubmitDdlTaskResponse, TruncateTableTask,
|
||||
AlterDatabaseTask, AlterTableTask, CommentOnTask, CreateDatabaseTask, CreateFlowTask,
|
||||
CreateTableTask, CreateViewTask, DropDatabaseTask, DropFlowTask, DropTableTask, DropViewTask,
|
||||
QueryContext, SubmitDdlTaskRequest, SubmitDdlTaskResponse, TruncateTableTask,
|
||||
};
|
||||
use crate::rpc::router::RegionRoute;
|
||||
|
||||
/// A configurator that customizes or enhances a [`DdlManager`].
|
||||
#[async_trait::async_trait]
|
||||
pub trait DdlManagerConfigurator: Send + Sync {
|
||||
pub trait DdlManagerConfigurator<C>: Send + Sync {
|
||||
/// Configures the given [`DdlManager`] using the provided [`DdlManagerConfigureContext`].
|
||||
async fn configure(
|
||||
&self,
|
||||
ddl_manager: DdlManager,
|
||||
ctx: DdlManagerConfigureContext,
|
||||
ctx: C,
|
||||
) -> std::result::Result<DdlManager, BoxedError>;
|
||||
}
|
||||
|
||||
pub type DdlManagerConfiguratorRef = Arc<dyn DdlManagerConfigurator>;
|
||||
|
||||
pub struct DdlManagerConfigureContext {
|
||||
pub kv_backend: KvBackendRef,
|
||||
}
|
||||
pub type DdlManagerConfiguratorRef<C> = Arc<dyn DdlManagerConfigurator<C>>;
|
||||
|
||||
pub type DdlManagerRef = Arc<DdlManager>;
|
||||
|
||||
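DdlManagerConfigurator is now generic over the context type it receives, rather than always taking the removed DdlManagerConfigureContext. A minimal sketch of an implementor under the new signature; the context struct and its field are made up for illustration:

/// Hypothetical context supplied by the embedding service.
pub struct MyConfigureContext {
    pub kv_backend: KvBackendRef,
}

pub struct NoopConfigurator;

#[async_trait::async_trait]
impl DdlManagerConfigurator<MyConfigureContext> for NoopConfigurator {
    async fn configure(
        &self,
        ddl_manager: DdlManager,
        _ctx: MyConfigureContext,
    ) -> std::result::Result<DdlManager, BoxedError> {
        // Pass the manager through unchanged; a real implementor could register
        // extra procedure loaders or wrap the manager here.
        Ok(ddl_manager)
    }
}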
@@ -197,7 +193,8 @@ impl DdlManager {
|
||||
TruncateTableProcedure,
|
||||
CreateDatabaseProcedure,
|
||||
DropDatabaseProcedure,
|
||||
DropViewProcedure
|
||||
DropViewProcedure,
|
||||
CommentOnProcedure
|
||||
);
|
||||
|
||||
for (type_name, loader_factory) in loaders {
|
||||
@@ -413,6 +410,19 @@ impl DdlManager {
|
||||
self.submit_procedure(procedure_with_id).await
|
||||
}
|
||||
|
||||
/// Submits and executes a comment on task.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn submit_comment_on_task(
|
||||
&self,
|
||||
comment_on_task: CommentOnTask,
|
||||
) -> Result<(ProcedureId, Option<Output>)> {
|
||||
let context = self.create_context();
|
||||
let procedure = CommentOnProcedure::new(comment_on_task, context);
|
||||
let procedure_with_id = ProcedureWithId::with_random_id(Box::new(procedure));
|
||||
|
||||
self.submit_procedure(procedure_with_id).await
|
||||
}
|
||||
|
||||
async fn submit_procedure(
|
||||
&self,
|
||||
procedure_with_id: ProcedureWithId,
|
||||
@@ -481,6 +491,7 @@ impl DdlManager {
|
||||
handle_create_view_task(self, create_view_task).await
|
||||
}
|
||||
DropView(drop_view_task) => handle_drop_view_task(self, drop_view_task).await,
|
||||
CommentOn(comment_on_task) => handle_comment_on_task(self, comment_on_task).await,
|
||||
#[cfg(feature = "enterprise")]
|
||||
CreateTrigger(create_trigger_task) => {
|
||||
handle_create_trigger_task(
|
||||
@@ -912,6 +923,26 @@ async fn handle_create_view_task(
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_comment_on_task(
|
||||
ddl_manager: &DdlManager,
|
||||
comment_on_task: CommentOnTask,
|
||||
) -> Result<SubmitDdlTaskResponse> {
|
||||
let (id, _) = ddl_manager
|
||||
.submit_comment_on_task(comment_on_task.clone())
|
||||
.await?;
|
||||
|
||||
let procedure_id = id.to_string();
|
||||
info!(
|
||||
"Comment on {}.{}.{} is updated via procedure_id {id:?}",
|
||||
comment_on_task.catalog_name, comment_on_task.schema_name, comment_on_task.object_name
|
||||
);
|
||||
|
||||
Ok(SubmitDdlTaskResponse {
|
||||
key: procedure_id.into(),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -12,25 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::OnceLock;
use std::time::Duration;

/// Heartbeat interval time (is the basic unit of various time).
pub const HEARTBEAT_INTERVAL_MILLIS: u64 = 3000;

/// The frontend will also send heartbeats to Metasrv, sending an empty
/// heartbeat every HEARTBEAT_INTERVAL_MILLIS * 6 seconds.
pub const FRONTEND_HEARTBEAT_INTERVAL_MILLIS: u64 = HEARTBEAT_INTERVAL_MILLIS * 6;

/// The lease seconds of a region. It's set by 3 heartbeat intervals
/// (HEARTBEAT_INTERVAL_MILLIS × 3), plus some extra buffer (1 second).
pub const REGION_LEASE_SECS: u64 =
    Duration::from_millis(HEARTBEAT_INTERVAL_MILLIS * 3).as_secs() + 1;

/// When creating table or region failover, a target node needs to be selected.
/// If the node's lease has expired, the `Selector` will not select it.
pub const DATANODE_LEASE_SECS: u64 = REGION_LEASE_SECS;

pub const FLOWNODE_LEASE_SECS: u64 = DATANODE_LEASE_SECS;
pub const BASE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3);

/// The lease seconds of metasrv leader.
pub const META_LEASE_SECS: u64 = 5;
@@ -41,6 +26,15 @@ pub const POSTGRES_KEEP_ALIVE_SECS: u64 = 30;
/// In a lease, there are two opportunities for renewal.
pub const META_KEEP_ALIVE_INTERVAL_SECS: u64 = META_LEASE_SECS / 2;

/// The timeout of the heartbeat request.
pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(META_KEEP_ALIVE_INTERVAL_SECS + 1);

/// The keep-alive interval of the heartbeat channel.
pub const HEARTBEAT_CHANNEL_KEEP_ALIVE_INTERVAL_SECS: Duration = Duration::from_secs(15);

/// The keep-alive timeout of the heartbeat channel.
pub const HEARTBEAT_CHANNEL_KEEP_ALIVE_TIMEOUT_SECS: Duration = Duration::from_secs(5);

/// The default mailbox round-trip timeout.
pub const MAILBOX_RTT_SECS: u64 = 1;

@@ -49,3 +43,60 @@ pub const TOPIC_STATS_REPORT_INTERVAL_SECS: u64 = 15;

/// The retention seconds of topic stats.
pub const TOPIC_STATS_RETENTION_SECS: u64 = TOPIC_STATS_REPORT_INTERVAL_SECS * 100;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// The distributed time constants.
pub struct DistributedTimeConstants {
    pub heartbeat_interval: Duration,
    pub frontend_heartbeat_interval: Duration,
    pub region_lease: Duration,
    pub datanode_lease: Duration,
    pub flownode_lease: Duration,
}

/// The frontend heartbeat interval is 6 times of the base heartbeat interval.
pub fn frontend_heartbeat_interval(base_heartbeat_interval: Duration) -> Duration {
    base_heartbeat_interval * 6
}

impl DistributedTimeConstants {
    /// Create a new DistributedTimeConstants from the heartbeat interval.
    pub fn from_heartbeat_interval(heartbeat_interval: Duration) -> Self {
        let region_lease = heartbeat_interval * 3 + Duration::from_secs(1);
        let datanode_lease = region_lease;
        let flownode_lease = datanode_lease;
        Self {
            heartbeat_interval,
            frontend_heartbeat_interval: frontend_heartbeat_interval(heartbeat_interval),
            region_lease,
            datanode_lease,
            flownode_lease,
        }
    }
}

impl Default for DistributedTimeConstants {
    fn default() -> Self {
        Self::from_heartbeat_interval(BASE_HEARTBEAT_INTERVAL)
    }
}

static DEFAULT_DISTRIBUTED_TIME_CONSTANTS: OnceLock<DistributedTimeConstants> = OnceLock::new();

/// Get the default distributed time constants.
pub fn default_distributed_time_constants() -> &'static DistributedTimeConstants {
    DEFAULT_DISTRIBUTED_TIME_CONSTANTS.get_or_init(Default::default)
}

/// Initialize the default distributed time constants.
pub fn init_distributed_time_constants(base_heartbeat_interval: Duration) {
    let distributed_time_constants =
        DistributedTimeConstants::from_heartbeat_interval(base_heartbeat_interval);
    DEFAULT_DISTRIBUTED_TIME_CONSTANTS
        .set(distributed_time_constants)
        .expect("Failed to set default distributed time constants");
    common_telemetry::info!(
        "Initialized default distributed time constants: {:#?}",
        distributed_time_constants
    );
}

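from_heartbeat_interval derives every lease from the single base interval, replacing the removed *_SECS constants. A worked check of the relationships, following directly from the code above:

use std::time::Duration;

let c = DistributedTimeConstants::from_heartbeat_interval(Duration::from_secs(5));

// Region/datanode/flownode leases are three heartbeats plus a 1s buffer: 3 * 5s + 1s = 16s.
assert_eq!(c.region_lease, Duration::from_secs(16));
assert_eq!(c.datanode_lease, c.region_lease);
assert_eq!(c.flownode_lease, c.datanode_lease);

// The frontend heartbeat interval is six times the base interval: 6 * 5s = 30s.
assert_eq!(c.frontend_heartbeat_interval, Duration::from_secs(30));

// The default 3s base reproduces the old REGION_LEASE_SECS value of 10 seconds.
assert_eq!(
    DistributedTimeConstants::default().region_lease,
    Duration::from_secs(10)
);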
@@ -272,13 +272,6 @@ pub enum Error {
        location: Location,
    },

    #[snafu(display("Failed to send message: {err_msg}"))]
    SendMessage {
        err_msg: String,
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Failed to serde json"))]
    SerdeJson {
        #[snafu(source)]
@@ -1118,7 +1111,7 @@ impl ErrorExt for Error {
            | DeserializeFlexbuffers { .. }
            | ConvertTimeRanges { .. } => StatusCode::Unexpected,

            SendMessage { .. } | GetKvCache { .. } | CacheNotGet { .. } => StatusCode::Internal,
            GetKvCache { .. } | CacheNotGet { .. } => StatusCode::Internal,

            SchemaAlreadyExists { .. } => StatusCode::DatabaseAlreadyExists,

@@ -23,6 +23,7 @@ use crate::heartbeat::mailbox::{IncomingMessage, MailboxRef};

pub mod invalidate_table_cache;
pub mod parse_mailbox_message;
pub mod suspend;
#[cfg(test)]
mod tests;

src/common/meta/src/heartbeat/handler/suspend.rs (new file, 69 lines)
@@ -0,0 +1,69 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use async_trait::async_trait;
use common_telemetry::{info, warn};

use crate::error::Result;
use crate::heartbeat::handler::{
    HandleControl, HeartbeatResponseHandler, HeartbeatResponseHandlerContext,
};
use crate::instruction::Instruction;

/// A heartbeat response handler that handles the special "suspend" instruction.
/// It simply sets, or clears (if previously set), the inner suspend atomic state.
pub struct SuspendHandler {
    suspend: Arc<AtomicBool>,
}

impl SuspendHandler {
    pub fn new(suspend: Arc<AtomicBool>) -> Self {
        Self { suspend }
    }
}

#[async_trait]
impl HeartbeatResponseHandler for SuspendHandler {
    fn is_acceptable(&self, context: &HeartbeatResponseHandlerContext) -> bool {
        matches!(
            context.incoming_message,
            Some((_, Instruction::Suspend)) | None
        )
    }

    async fn handle(&self, context: &mut HeartbeatResponseHandlerContext) -> Result<HandleControl> {
        let flip_state = |expect: bool| {
            self.suspend
                .compare_exchange(expect, !expect, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
        };

        if let Some((_, Instruction::Suspend)) = context.incoming_message.take() {
            if flip_state(false) {
                warn!("Suspend instruction received from meta, entering suspension state");
            }
        } else {
            // Suspended components should always try to leave this state on their own; relying on
            // an explicit "un-suspend" instruction to resume them would be error-prone.
            // So whenever a heartbeat carries no "suspend" instruction, just clear the state.
            if flip_state(true) {
                info!("clear suspend state");
            }
        }
        Ok(HandleControl::Continue)
    }
}
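A sketch of how a component might share and consult the flag the handler flips; the call that registers the handler with the heartbeat response pipeline is omitted because that builder API is not shown in this diff:

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

fn wire_up() -> Arc<AtomicBool> {
    // One flag shared between the heartbeat handler and the serving path.
    let suspend = Arc::new(AtomicBool::new(false));
    let _handler = SuspendHandler::new(suspend.clone());
    // ... register `_handler` with the heartbeat response handlers here ...
    suspend
}

// On the serving path, refuse new work while the flag is set.
fn check_suspended(suspend: &AtomicBool) -> Result<(), String> {
    if suspend.load(Ordering::Relaxed) {
        return Err("node is suspended by metasrv".to_string());
    }
    Ok(())
}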
@@ -15,8 +15,8 @@
use std::sync::Arc;

use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::error::SendError;

use crate::error::{self, Result};
use crate::instruction::{Instruction, InstructionReply};

pub type IncomingMessage = (MessageMeta, Instruction);
@@ -51,13 +51,8 @@ impl HeartbeatMailbox {
        Self { sender }
    }

    pub async fn send(&self, message: OutgoingMessage) -> Result<()> {
        self.sender.send(message).await.map_err(|e| {
            error::SendMessageSnafu {
                err_msg: e.to_string(),
            }
            .build()
        })
    pub async fn send(&self, message: OutgoingMessage) -> Result<(), SendError<OutgoingMessage>> {
        self.sender.send(message).await
    }
}

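With the SendMessage error variant removed (see the error.rs hunk above), callers now receive tokio's SendError directly and decide how loudly to report it. One possible call-site adaptation, sketched under the assumption that dropping the reply and logging is acceptable:

use tokio::sync::mpsc::error::SendError;

async fn reply_best_effort(mailbox: &HeartbeatMailbox, message: OutgoingMessage) {
    // `send` now hands the unsent message back on failure instead of a crate error.
    if let Err(SendError(_unsent)) = mailbox.send(message).await {
        common_telemetry::warn!("heartbeat mailbox closed; an outgoing reply was dropped");
    }
}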
@@ -339,6 +339,16 @@ pub struct FlushRegions {
|
||||
pub error_strategy: FlushErrorStrategy,
|
||||
}
|
||||
|
||||
impl Display for FlushRegions {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"FlushRegions(region_ids={:?}, strategy={:?}, error_strategy={:?})",
|
||||
self.region_ids, self.strategy, self.error_strategy
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FlushRegions {
|
||||
/// Create synchronous single-region flush
|
||||
pub fn sync_single(region_id: RegionId) -> Self {
|
||||
@@ -504,6 +514,22 @@ impl Display for GcRegionsReply {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct EnterStagingRegion {
|
||||
pub region_id: RegionId,
|
||||
pub partition_expr: String,
|
||||
}
|
||||
|
||||
impl Display for EnterStagingRegion {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"EnterStagingRegion(region_id={}, partition_expr={})",
|
||||
self.region_id, self.partition_expr
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Display, PartialEq)]
|
||||
pub enum Instruction {
|
||||
/// Opens regions.
|
||||
@@ -529,6 +555,10 @@ pub enum Instruction {
|
||||
GetFileRefs(GetFileRefs),
|
||||
/// Triggers garbage collection for a region.
|
||||
GcRegions(GcRegions),
|
||||
/// Temporary suspend serving reads or writes
|
||||
Suspend,
|
||||
/// Makes regions enter staging state.
|
||||
EnterStagingRegions(Vec<EnterStagingRegion>),
|
||||
}
|
||||
|
||||
impl Instruction {
|
||||
@@ -585,6 +615,13 @@ impl Instruction {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_enter_staging_regions(self) -> Option<Vec<EnterStagingRegion>> {
|
||||
match self {
|
||||
Self::EnterStagingRegions(enter_staging) => Some(enter_staging),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The reply of [UpgradeRegion].
|
||||
@@ -678,6 +715,28 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
pub struct EnterStagingRegionReply {
|
||||
pub region_id: RegionId,
|
||||
/// Returns true if the region is under the new region rule.
|
||||
pub ready: bool,
|
||||
/// Indicates whether the region exists.
|
||||
pub exists: bool,
|
||||
/// Return error if any during the operation.
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
pub struct EnterStagingRegionsReply {
|
||||
pub replies: Vec<EnterStagingRegionReply>,
|
||||
}
|
||||
|
||||
impl EnterStagingRegionsReply {
|
||||
pub fn new(replies: Vec<EnterStagingRegionReply>) -> Self {
|
||||
Self { replies }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum InstructionReply {
|
||||
@@ -698,6 +757,7 @@ pub enum InstructionReply {
|
||||
FlushRegions(FlushRegionReply),
|
||||
GetFileRefs(GetFileRefsReply),
|
||||
GcRegions(GcRegionsReply),
|
||||
EnterStagingRegions(EnterStagingRegionsReply),
|
||||
}
|
||||
|
||||
impl Display for InstructionReply {
|
||||
@@ -714,6 +774,13 @@ impl Display for InstructionReply {
|
||||
Self::FlushRegions(reply) => write!(f, "InstructionReply::FlushRegions({})", reply),
|
||||
Self::GetFileRefs(reply) => write!(f, "InstructionReply::GetFileRefs({})", reply),
|
||||
Self::GcRegions(reply) => write!(f, "InstructionReply::GcRegion({})", reply),
|
||||
Self::EnterStagingRegions(reply) => {
|
||||
write!(
|
||||
f,
|
||||
"InstructionReply::EnterStagingRegions({:?})",
|
||||
reply.replies
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -754,13 +821,20 @@ impl InstructionReply {
|
||||
_ => panic!("Expected FlushRegions reply"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expect_enter_staging_regions_reply(self) -> Vec<EnterStagingRegionReply> {
|
||||
match self {
|
||||
Self::EnterStagingRegions(reply) => reply.replies,
|
||||
_ => panic!("Expected EnterStagingRegion reply"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
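EnterStagingRegion and its replies are plain data carriers, so building the instruction and unpacking the answer is mechanical. A small illustrative sketch; the region id and partition expression are made-up values:

use store_api::storage::RegionId;

// Ask one region to enter the staging state under a new partition rule.
let instruction = Instruction::EnterStagingRegions(vec![EnterStagingRegion {
    region_id: RegionId::new(1024, 1),
    partition_expr: "host < 'h100'".to_string(),
}]);

// A datanode answers with one reply per region.
let reply = InstructionReply::EnterStagingRegions(EnterStagingRegionsReply::new(vec![
    EnterStagingRegionReply {
        region_id: RegionId::new(1024, 1),
        ready: true,
        exists: true,
        error: None,
    },
]));

assert_eq!(reply.expect_enter_staging_regions_reply().len(), 1);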
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashSet;
|
||||
|
||||
use store_api::storage::FileId;
|
||||
use store_api::storage::{FileId, FileRef};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1135,12 +1209,14 @@ mod tests {
|
||||
let mut manifest = FileRefsManifest::default();
|
||||
let r0 = RegionId::new(1024, 1);
|
||||
let r1 = RegionId::new(1024, 2);
|
||||
manifest
|
||||
.file_refs
|
||||
.insert(r0, HashSet::from([FileId::random()]));
|
||||
manifest
|
||||
.file_refs
|
||||
.insert(r1, HashSet::from([FileId::random()]));
|
||||
manifest.file_refs.insert(
|
||||
r0,
|
||||
HashSet::from([FileRef::new(r0, FileId::random(), None)]),
|
||||
);
|
||||
manifest.file_refs.insert(
|
||||
r1,
|
||||
HashSet::from([FileRef::new(r1, FileId::random(), None)]),
|
||||
);
|
||||
manifest.manifest_version.insert(r0, 10);
|
||||
manifest.manifest_version.insert(r1, 20);
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ impl TableInfoValue {
        }
    }

    pub(crate) fn update(&self, new_table_info: RawTableInfo) -> Self {
    pub fn update(&self, new_table_info: RawTableInfo) -> Self {
        Self {
            table_info: new_table_info,
            version: self.version + 1,

@@ -848,7 +848,7 @@ impl PgStore {
                .context(CreatePostgresPoolSnafu)?,
        };

        Self::with_pg_pool(pool, None, table_name, max_txn_ops).await
        Self::with_pg_pool(pool, None, table_name, max_txn_ops, false).await
    }

    /// Create [PgStore] impl of [KvBackendRef] from url (backward compatibility).
@@ -862,6 +862,7 @@ impl PgStore {
        schema_name: Option<&str>,
        table_name: &str,
        max_txn_ops: usize,
        auto_create_schema: bool,
    ) -> Result<KvBackendRef> {
        // Ensure the postgres metadata backend is ready to use.
        let client = match pool.get().await {
@@ -873,9 +874,23 @@ impl PgStore {
                .fail();
            }
        };

        // Automatically create schema if enabled and schema_name is provided.
        if auto_create_schema
            && let Some(schema) = schema_name
            && !schema.is_empty()
        {
            let create_schema_sql = format!("CREATE SCHEMA IF NOT EXISTS \"{}\"", schema);
            client
                .execute(&create_schema_sql, &[])
                .await
                .with_context(|_| PostgresExecutionSnafu {
                    sql: create_schema_sql.clone(),
                })?;
        }

        let template_factory = PgSqlTemplateFactory::new(schema_name, table_name);
        let sql_template_set = template_factory.build();
        // Do not attempt to create schema implicitly.
        client
            .execute(&sql_template_set.create_table_statement, &[])
            .await
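The new auto_create_schema flag only issues CREATE SCHEMA IF NOT EXISTS when a schema name is supplied; with it off, the previous behaviour of requiring a pre-created schema is unchanged. A hedged call sketch using the extended signature; pool construction and the surrounding config plumbing are omitted, and the schema and table names are placeholders:

// Opt in to creating the metadata schema on startup.
let kv_backend = PgStore::with_pg_pool(
    pool.clone(),
    Some("greptime_meta"), // schema to create if missing
    "greptime_metakv",     // metadata table name
    128,                   // max_txn_ops
    true,                  // auto_create_schema
)
.await?;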
@@ -959,7 +974,7 @@ mod tests {
        let Some(pool) = build_pg15_pool().await else {
            return;
        };
        let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128).await;
        let res = PgStore::with_pg_pool(pool, None, "pg15_public_should_fail", 128, false).await;
        assert!(
            res.is_err(),
            "creating table in public should fail for test_user"
@@ -1214,4 +1229,249 @@ mod tests {
|
||||
let t = PgSqlTemplateFactory::format_table_ident(Some(""), "test_table");
|
||||
assert_eq!(t, "\"test_table\"");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_schema_enabled() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
maybe_skip_postgres_integration_test!();
|
||||
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
|
||||
let mut cfg = Config::new();
|
||||
cfg.url = Some(endpoints);
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(CreatePostgresPoolSnafu)
|
||||
.unwrap();
|
||||
|
||||
let schema_name = "test_auto_create_enabled";
|
||||
let table_name = "test_table";
|
||||
|
||||
// Drop the schema if it exists to start clean
|
||||
let client = pool.get().await.unwrap();
|
||||
let _ = client
|
||||
.execute(
|
||||
&format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
|
||||
// Create store with auto_create_schema enabled
|
||||
let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify schema was created
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
|
||||
&[&schema_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_schema: String = row.get(0);
|
||||
assert_eq!(created_schema, schema_name);
|
||||
|
||||
// Verify table was created in the schema
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
|
||||
&[&schema_name, &table_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_table_schema: String = row.get(0);
|
||||
let created_table_name: String = row.get(1);
|
||||
assert_eq!(created_table_schema, schema_name);
|
||||
assert_eq!(created_table_name, table_name);
|
||||
|
||||
// Cleanup
|
||||
let _ = client
|
||||
.execute(
|
||||
&format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_schema_disabled() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
maybe_skip_postgres_integration_test!();
|
||||
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
|
||||
let mut cfg = Config::new();
|
||||
cfg.url = Some(endpoints);
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(CreatePostgresPoolSnafu)
|
||||
.unwrap();
|
||||
|
||||
let schema_name = "test_auto_create_disabled";
|
||||
let table_name = "test_table";
|
||||
|
||||
// Drop the schema if it exists to start clean
|
||||
let client = pool.get().await.unwrap();
|
||||
let _ = client
|
||||
.execute(
|
||||
&format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
|
||||
// Try to create store with auto_create_schema disabled (should fail)
|
||||
let result =
|
||||
PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, false).await;
|
||||
|
||||
// Verify it failed because schema doesn't exist
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Expected error when schema doesn't exist and auto_create_schema is disabled"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_schema_already_exists() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
maybe_skip_postgres_integration_test!();
|
||||
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
|
||||
let mut cfg = Config::new();
|
||||
cfg.url = Some(endpoints);
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(CreatePostgresPoolSnafu)
|
||||
.unwrap();
|
||||
|
||||
let schema_name = "test_auto_create_existing";
|
||||
let table_name = "test_table";
|
||||
|
||||
// Manually create the schema first
|
||||
let client = pool.get().await.unwrap();
|
||||
let _ = client
|
||||
.execute(
|
||||
&format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
client
|
||||
.execute(&format!("CREATE SCHEMA \"{}\"", schema_name), &[])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Create store with auto_create_schema enabled (should succeed idempotently)
|
||||
let _ = PgStore::with_pg_pool(pool.clone(), Some(schema_name), table_name, 128, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify schema still exists
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = $1",
|
||||
&[&schema_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_schema: String = row.get(0);
|
||||
assert_eq!(created_schema, schema_name);
|
||||
|
||||
// Verify table was created in the schema
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2",
|
||||
&[&schema_name, &table_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_table_schema: String = row.get(0);
|
||||
let created_table_name: String = row.get(1);
|
||||
assert_eq!(created_table_schema, schema_name);
|
||||
assert_eq!(created_table_name, table_name);
|
||||
|
||||
// Cleanup
|
||||
let _ = client
|
||||
.execute(
|
||||
&format!("DROP SCHEMA IF EXISTS \"{}\" CASCADE", schema_name),
|
||||
&[],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_schema_no_schema_name() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
maybe_skip_postgres_integration_test!();
|
||||
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
|
||||
let mut cfg = Config::new();
|
||||
cfg.url = Some(endpoints);
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(CreatePostgresPoolSnafu)
|
||||
.unwrap();
|
||||
|
||||
let table_name = "test_table_no_schema";
|
||||
|
||||
// Create store with auto_create_schema enabled but no schema name (should succeed)
|
||||
// This should create the table in the default schema (public)
|
||||
let _ = PgStore::with_pg_pool(pool.clone(), None, table_name, 128, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify table was created in public schema
|
||||
let client = pool.get().await.unwrap();
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
|
||||
&[&table_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_table_schema: String = row.get(0);
|
||||
let created_table_name: String = row.get(1);
|
||||
assert_eq!(created_table_name, table_name);
|
||||
// Verify it's in public schema (or whichever is the default)
|
||||
assert!(created_table_schema == "public" || !created_table_schema.is_empty());
|
||||
|
||||
// Cleanup
|
||||
let _ = client
|
||||
.execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auto_create_schema_with_empty_schema_name() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
maybe_skip_postgres_integration_test!();
|
||||
let endpoints = std::env::var("GT_POSTGRES_ENDPOINTS").unwrap();
|
||||
let mut cfg = Config::new();
|
||||
cfg.url = Some(endpoints);
|
||||
let pool = cfg
|
||||
.create_pool(Some(Runtime::Tokio1), NoTls)
|
||||
.context(CreatePostgresPoolSnafu)
|
||||
.unwrap();
|
||||
|
||||
let table_name = "test_table_empty_schema";
|
||||
|
||||
// Create store with auto_create_schema enabled but empty schema name (should succeed)
|
||||
// This should create the table in the default schema (public)
|
||||
let _ = PgStore::with_pg_pool(pool.clone(), Some(""), table_name, 128, true)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify table was created in public schema
|
||||
let client = pool.get().await.unwrap();
|
||||
let row = client
|
||||
.query_one(
|
||||
"SELECT table_schema, table_name FROM information_schema.tables WHERE table_name = $1",
|
||||
&[&table_name],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let created_table_schema: String = row.get(0);
|
||||
let created_table_name: String = row.get(1);
|
||||
assert_eq!(created_table_name, table_name);
|
||||
// Verify it's in public schema (or whichever is the default)
|
||||
assert!(created_table_schema == "public" || !created_table_schema.is_empty());
|
||||
|
||||
// Cleanup
|
||||
let _ = client
|
||||
.execute(&format!("DROP TABLE IF EXISTS \"{}\"", table_name), &[])
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,19 +23,20 @@ use api::v1::alter_database_expr::Kind as PbAlterDatabaseKind;
 use api::v1::meta::ddl_task_request::Task;
 use api::v1::meta::{
     AlterDatabaseTask as PbAlterDatabaseTask, AlterTableTask as PbAlterTableTask,
-    AlterTableTasks as PbAlterTableTasks, CreateDatabaseTask as PbCreateDatabaseTask,
-    CreateFlowTask as PbCreateFlowTask, CreateTableTask as PbCreateTableTask,
-    CreateTableTasks as PbCreateTableTasks, CreateViewTask as PbCreateViewTask,
-    DdlTaskRequest as PbDdlTaskRequest, DdlTaskResponse as PbDdlTaskResponse,
-    DropDatabaseTask as PbDropDatabaseTask, DropFlowTask as PbDropFlowTask,
-    DropTableTask as PbDropTableTask, DropTableTasks as PbDropTableTasks,
-    DropViewTask as PbDropViewTask, Partition, ProcedureId,
+    AlterTableTasks as PbAlterTableTasks, CommentOnTask as PbCommentOnTask,
+    CreateDatabaseTask as PbCreateDatabaseTask, CreateFlowTask as PbCreateFlowTask,
+    CreateTableTask as PbCreateTableTask, CreateTableTasks as PbCreateTableTasks,
+    CreateViewTask as PbCreateViewTask, DdlTaskRequest as PbDdlTaskRequest,
+    DdlTaskResponse as PbDdlTaskResponse, DropDatabaseTask as PbDropDatabaseTask,
+    DropFlowTask as PbDropFlowTask, DropTableTask as PbDropTableTask,
+    DropTableTasks as PbDropTableTasks, DropViewTask as PbDropViewTask, Partition, ProcedureId,
     TruncateTableTask as PbTruncateTableTask,
 };
 use api::v1::{
-    AlterDatabaseExpr, AlterTableExpr, CreateDatabaseExpr, CreateFlowExpr, CreateTableExpr,
-    CreateViewExpr, DropDatabaseExpr, DropFlowExpr, DropTableExpr, DropViewExpr, EvalInterval,
-    ExpireAfter, Option as PbOption, QueryContext as PbQueryContext, TruncateTableExpr,
+    AlterDatabaseExpr, AlterTableExpr, CommentObjectType as PbCommentObjectType, CommentOnExpr,
+    CreateDatabaseExpr, CreateFlowExpr, CreateTableExpr, CreateViewExpr, DropDatabaseExpr,
+    DropFlowExpr, DropTableExpr, DropViewExpr, EvalInterval, ExpireAfter, Option as PbOption,
+    QueryContext as PbQueryContext, TruncateTableExpr,
 };
 use base64::Engine as _;
 use base64::engine::general_purpose;
@@ -78,6 +79,7 @@ pub enum DdlTask {
     DropView(DropViewTask),
     #[cfg(feature = "enterprise")]
     CreateTrigger(trigger::CreateTriggerTask),
+    CommentOn(CommentOnTask),
 }

 impl DdlTask {
@@ -200,6 +202,11 @@ impl DdlTask {
             view_info,
         })
     }
+
+    /// Creates a [`DdlTask`] to comment on a table, column, or flow.
+    pub fn new_comment_on(task: CommentOnTask) -> Self {
+        DdlTask::CommentOn(task)
+    }
 }

 impl TryFrom<Task> for DdlTask {
@@ -278,6 +285,7 @@ impl TryFrom<Task> for DdlTask {
                .fail()
            }
        }
+            Task::CommentOnTask(comment_on) => Ok(DdlTask::CommentOn(comment_on.try_into()?)),
        }
    }
 }
@@ -332,6 +340,7 @@ impl TryFrom<SubmitDdlTaskRequest> for PbDdlTaskRequest {
            DdlTask::CreateTrigger(task) => Task::CreateTriggerTask(task.try_into()?),
            #[cfg(feature = "enterprise")]
            DdlTask::DropTrigger(task) => Task::DropTriggerTask(task.into()),
+            DdlTask::CommentOn(task) => Task::CommentOnTask(task.into()),
        };

        Ok(Self {
@@ -1277,6 +1286,119 @@ impl From<DropFlowTask> for PbDropFlowTask {
     }
 }

+/// Represents the ID of the object being commented on (Table or Flow).
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum CommentObjectId {
+    Table(TableId),
+    Flow(FlowId),
+}
+
+/// Comment on table, column, or flow
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct CommentOnTask {
+    pub catalog_name: String,
+    pub schema_name: String,
+    pub object_type: CommentObjectType,
+    pub object_name: String,
+    /// Column name (only for Column comments)
+    pub column_name: Option<String>,
+    /// Object ID (Table or Flow) for validation and cache invalidation
+    pub object_id: Option<CommentObjectId>,
+    pub comment: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum CommentObjectType {
+    Table,
+    Column,
+    Flow,
+}
+
+impl CommentOnTask {
+    pub fn table_ref(&self) -> TableReference<'_> {
+        TableReference {
+            catalog: &self.catalog_name,
+            schema: &self.schema_name,
+            table: &self.object_name,
+        }
+    }
+}
+
+// Proto conversions for CommentObjectType
+impl From<CommentObjectType> for PbCommentObjectType {
+    fn from(object_type: CommentObjectType) -> Self {
+        match object_type {
+            CommentObjectType::Table => PbCommentObjectType::Table,
+            CommentObjectType::Column => PbCommentObjectType::Column,
+            CommentObjectType::Flow => PbCommentObjectType::Flow,
+        }
+    }
+}
+
+impl TryFrom<i32> for CommentObjectType {
+    type Error = error::Error;
+
+    fn try_from(value: i32) -> Result<Self> {
+        match value {
+            0 => Ok(CommentObjectType::Table),
+            1 => Ok(CommentObjectType::Column),
+            2 => Ok(CommentObjectType::Flow),
+            _ => error::InvalidProtoMsgSnafu {
+                err_msg: format!(
+                    "Invalid CommentObjectType value: {}. Valid values are: 0 (Table), 1 (Column), 2 (Flow)",
+                    value
+                ),
+            }
+            .fail(),
+        }
+    }
+}
+
+// Proto conversions for CommentOnTask
+impl TryFrom<PbCommentOnTask> for CommentOnTask {
+    type Error = error::Error;
+
+    fn try_from(pb: PbCommentOnTask) -> Result<Self> {
+        let comment_on = pb.comment_on.context(error::InvalidProtoMsgSnafu {
+            err_msg: "expected comment_on",
+        })?;
+
+        Ok(CommentOnTask {
+            catalog_name: comment_on.catalog_name,
+            schema_name: comment_on.schema_name,
+            object_type: comment_on.object_type.try_into()?,
+            object_name: comment_on.object_name,
+            column_name: if comment_on.column_name.is_empty() {
+                None
+            } else {
+                Some(comment_on.column_name)
+            },
+            comment: if comment_on.comment.is_empty() {
+                None
+            } else {
+                Some(comment_on.comment)
+            },
+            object_id: None,
+        })
+    }
+}
+
+impl From<CommentOnTask> for PbCommentOnTask {
+    fn from(task: CommentOnTask) -> Self {
+        let pb_object_type: PbCommentObjectType = task.object_type.into();
+        PbCommentOnTask {
+            comment_on: Some(CommentOnExpr {
+                catalog_name: task.catalog_name,
+                schema_name: task.schema_name,
+                object_type: pb_object_type as i32,
+                object_name: task.object_name,
+                column_name: task.column_name.unwrap_or_default(),
+                comment: task.comment.unwrap_or_default(),
+            }),
+        }
+    }
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 pub struct QueryContext {
     pub(crate) current_catalog: String,
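A short sketch of the round trip these conversions enable (the field values are illustrative only, not taken from the source):

    // Build a comment task for a table column and convert it into the protobuf task.
    let task = CommentOnTask {
        catalog_name: "greptime".to_string(),
        schema_name: "public".to_string(),
        object_type: CommentObjectType::Column,
        object_name: "monitor".to_string(),
        column_name: Some("host".to_string()),
        object_id: None,
        comment: Some("host identifier".to_string()),
    };
    let pb: PbCommentOnTask = task.clone().into();

    // Converting back maps empty strings to `None` and leaves `object_id` unset,
    // since the proto message does not carry it.
    let round_trip = CommentOnTask::try_from(pb).unwrap();
    assert_eq!(round_trip.column_name.as_deref(), Some("host"));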
@@ -205,11 +205,14 @@ impl KafkaTopicCreator {
        self.partition_client(topic).await.unwrap()
    }
 }

 /// Builds a kafka [Client](rskafka::client::Client).
 pub async fn build_kafka_client(connection: &KafkaConnectionConfig) -> Result<Client> {
    // Builds an kafka controller client for creating topics.
    let mut builder = ClientBuilder::new(connection.broker_endpoints.clone())
-        .backoff_config(DEFAULT_BACKOFF_CONFIG);
+        .backoff_config(DEFAULT_BACKOFF_CONFIG)
+        .connect_timeout(Some(connection.connect_timeout))
+        .timeout(Some(connection.timeout));
    if let Some(sasl) = &connection.sasl {
        builder = builder.sasl_config(sasl.config.clone().into_sasl_config());
    };
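For illustration, a sketch of calling the builder with the fields it reads above. It assumes KafkaConnectionConfig provides a Default implementation and that the timeout fields are std::time::Duration; both are assumptions, not confirmed by the diff:

    use std::time::Duration;

    // Only the fields used by build_kafka_client are set explicitly here.
    let connection = KafkaConnectionConfig {
        broker_endpoints: vec!["127.0.0.1:9092".to_string()],
        connect_timeout: Duration::from_secs(3),
        timeout: Duration::from_secs(5),
        ..Default::default()
    };
    // The returned client uses the configured backoff, connect timeout and request timeout.
    let client = build_kafka_client(&connection).await.unwrap();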
@@ -246,14 +246,6 @@ pub enum Error {
        #[snafu(implicit)]
        location: Location,
    },
-
-    #[snafu(display("Loader for {type_name} is not implemented: {reason}"))]
-    ProcedureLoaderNotImplemented {
-        #[snafu(implicit)]
-        location: Location,
-        type_name: String,
-        reason: String,
-    },
 }

 pub type Result<T> = std::result::Result<T, Error>;
@@ -274,8 +266,7 @@ impl ErrorExt for Error {
            Error::ToJson { .. }
            | Error::DeleteState { .. }
            | Error::FromJson { .. }
-            | Error::WaitWatcher { .. }
-            | Error::ProcedureLoaderNotImplemented { .. } => StatusCode::Internal,
+            | Error::WaitWatcher { .. } => StatusCode::Internal,

            Error::RetryTimesExceeded { .. }
            | Error::RollbackTimesExceeded { .. }
@@ -331,8 +331,29 @@ impl Runner {
        }

        match status {
-            Status::Executing { .. } => {}
+            Status::Executing { .. } => {
+                let prev_state = self.meta.state();
+                if !matches!(prev_state, ProcedureState::Running) {
+                    info!(
+                        "Set Procedure {}-{} state to running, prev_state: {:?}",
+                        self.procedure.type_name(),
+                        self.meta.id,
+                        prev_state
+                    );
+                    self.meta.set_state(ProcedureState::Running);
+                }
+            }
            Status::Suspended { subprocedures, .. } => {
+                let prev_state = self.meta.state();
+                if !matches!(prev_state, ProcedureState::Running) {
+                    info!(
+                        "Set Procedure {}-{} state to running, prev_state: {:?}",
+                        self.procedure.type_name(),
+                        self.meta.id,
+                        prev_state
+                    );
+                    self.meta.set_state(ProcedureState::Running);
+                }
                self.on_suspended(subprocedures).await;
            }
            Status::Done { output } => {
@@ -393,8 +414,12 @@ impl Runner {
            return;
        }

-        self.meta
-            .set_state(ProcedureState::prepare_rollback(Arc::new(e)));
+        if self.procedure.rollback_supported() {
+            self.meta
+                .set_state(ProcedureState::prepare_rollback(Arc::new(e)));
+        } else {
+            self.meta.set_state(ProcedureState::failed(Arc::new(e)));
+        }
        }
    }
 }
@@ -1080,20 +1105,10 @@ mod tests {
        let mut runner = new_runner(meta.clone(), Box::new(fail), procedure_store.clone());
        runner.manager_ctx.start();

-        runner.execute_once(&ctx).await;
-        let state = runner.meta.state();
-        assert!(state.is_prepare_rollback(), "{state:?}");
-
        runner.execute_once(&ctx).await;
        let state = runner.meta.state();
        assert!(state.is_failed(), "{state:?}");
-        check_files(
-            &object_store,
-            &procedure_store,
-            ctx.procedure_id,
-            &["0000000000.rollback"],
-        )
-        .await;
+        check_files(&object_store, &procedure_store, ctx.procedure_id, &[]).await;
    }

    #[tokio::test]
@@ -1146,6 +1161,8 @@ mod tests {
            async move {
                if times == 1 {
                    Err(Error::retry_later(MockError::new(StatusCode::Unexpected)))
+                } else if times == 2 {
+                    Ok(Status::executing(false))
                } else {
                    Ok(Status::done())
                }
@@ -1172,6 +1189,10 @@ mod tests {
        let state = runner.meta.state();
        assert!(state.is_retrying(), "{state:?}");

+        runner.execute_once(&ctx).await;
+        let state = runner.meta.state();
+        assert!(state.is_running(), "{state:?}");
+
        runner.execute_once(&ctx).await;
        let state = runner.meta.state();
        assert!(state.is_done(), "{state:?}");
@@ -1185,6 +1206,86 @@ mod tests {
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_execute_on_retry_later_error_with_child() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let mut times = 0;
|
||||
let child_id = ProcedureId::random();
|
||||
|
||||
let exec_fn = move |_| {
|
||||
times += 1;
|
||||
async move {
|
||||
debug!("times: {}", times);
|
||||
if times == 1 {
|
||||
Err(Error::retry_later(MockError::new(StatusCode::Unexpected)))
|
||||
} else if times == 2 {
|
||||
let exec_fn = |_| {
|
||||
async { Err(Error::external(MockError::new(StatusCode::Unexpected))) }
|
||||
.boxed()
|
||||
};
|
||||
let fail = ProcedureAdapter {
|
||||
data: "fail".to_string(),
|
||||
lock_key: LockKey::single_exclusive("catalog.schema.table.region-0"),
|
||||
poison_keys: PoisonKeys::default(),
|
||||
exec_fn,
|
||||
rollback_fn: None,
|
||||
};
|
||||
|
||||
Ok(Status::Suspended {
|
||||
subprocedures: vec![ProcedureWithId {
|
||||
id: child_id,
|
||||
procedure: Box::new(fail),
|
||||
}],
|
||||
persist: true,
|
||||
})
|
||||
} else {
|
||||
Ok(Status::done())
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
};
|
||||
|
||||
let retry_later = ProcedureAdapter {
|
||||
data: "retry_later".to_string(),
|
||||
lock_key: LockKey::single_exclusive("catalog.schema.table"),
|
||||
poison_keys: PoisonKeys::default(),
|
||||
exec_fn,
|
||||
rollback_fn: None,
|
||||
};
|
||||
|
||||
let dir = create_temp_dir("retry_later");
|
||||
let meta = retry_later.new_meta(ROOT_ID);
|
||||
let ctx = context_without_provider(meta.id);
|
||||
let object_store = test_util::new_object_store(&dir);
|
||||
let procedure_store = Arc::new(ProcedureStore::from_object_store(object_store.clone()));
|
||||
let mut runner = new_runner(meta.clone(), Box::new(retry_later), procedure_store.clone());
|
||||
runner.manager_ctx.start();
|
||||
debug!("execute_once 1");
|
||||
runner.execute_once(&ctx).await;
|
||||
let state = runner.meta.state();
|
||||
assert!(state.is_retrying(), "{state:?}");
|
||||
|
||||
let moved_meta = meta.clone();
|
||||
tokio::spawn(async move {
|
||||
moved_meta.child_notify.notify_one();
|
||||
});
|
||||
runner.execute_once(&ctx).await;
|
||||
let state = runner.meta.state();
|
||||
assert!(state.is_running(), "{state:?}");
|
||||
|
||||
runner.execute_once(&ctx).await;
|
||||
let state = runner.meta.state();
|
||||
assert!(state.is_done(), "{state:?}");
|
||||
assert!(meta.state().is_done());
|
||||
check_files(
|
||||
&object_store,
|
||||
&procedure_store,
|
||||
ctx.procedure_id,
|
||||
&["0000000000.step", "0000000001.commit"],
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_exceed_max_retry_later() {
|
||||
let exec_fn =
|
||||
@@ -1304,7 +1405,7 @@ mod tests {
    async fn test_child_error() {
        let mut times = 0;
        let child_id = ProcedureId::random();
-
+        common_telemetry::init_default_ut_logging();
        let exec_fn = move |ctx: Context| {
            times += 1;
            async move {
@@ -1529,7 +1630,7 @@ mod tests {

        runner.execute_once(&ctx).await;
        let state = runner.meta.state();
-        assert!(state.is_prepare_rollback(), "{state:?}");
+        assert!(state.is_failed(), "{state:?}");

        let procedure_id = runner
            .manager_ctx
@@ -1596,11 +1697,6 @@ mod tests {
        let state = runner.meta.state();
        assert!(state.is_running(), "{state:?}");

-        runner.execute_once(&ctx).await;
-        let state = runner.meta.state();
-        assert!(state.is_prepare_rollback(), "{state:?}");
-        assert!(meta.state().is_prepare_rollback());
-
        runner.execute_once(&ctx).await;
        let state = runner.meta.state();
        assert!(state.is_failed(), "{state:?}");
@@ -46,6 +46,22 @@ pub enum OutputData {
     Stream(SendableRecordBatchStream),
 }

+impl OutputData {
+    /// Consume the data to pretty printed string.
+    pub async fn pretty_print(self) -> String {
+        match self {
+            OutputData::AffectedRows(x) => {
+                format!("Affected Rows: {x}")
+            }
+            OutputData::RecordBatches(x) => x.pretty_print().unwrap_or_else(|e| e.to_string()),
+            OutputData::Stream(x) => common_recordbatch::util::collect_batches(x)
+                .await
+                .and_then(|x| x.pretty_print())
+                .unwrap_or_else(|e| e.to_string()),
+        }
+    }
+}
+
 /// OutputMeta stores meta information produced/generated during the execution
 #[derive(Debug, Default)]
 pub struct OutputMeta {
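A minimal usage sketch of the new helper (the AffectedRows variant keeps it self-contained; record batches and streams are collected and pretty-printed the same way):

    let output = OutputData::AffectedRows(3);
    let text = output.pretty_print().await;
    assert_eq!(text, "Affected Rows: 3");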
@@ -5,10 +5,12 @@ edition.workspace = true
 license.workspace = true

 [dependencies]
+arrow-schema.workspace = true
 common-base.workspace = true
 common-decimal.workspace = true
 common-error.workspace = true
 common-macro.workspace = true
+common-telemetry.workspace = true
 common-time.workspace = true
 datafusion-sql.workspace = true
 datatypes.workspace = true
@@ -14,11 +14,12 @@

 use std::str::FromStr;

+use arrow_schema::extension::ExtensionType;
 use common_time::Timestamp;
 use common_time::timezone::Timezone;
-use datatypes::json::JsonStructureSettings;
+use datatypes::extension::json::JsonExtensionType;
 use datatypes::prelude::ConcreteDataType;
-use datatypes::schema::ColumnDefaultConstraint;
+use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema};
 use datatypes::types::{JsonFormat, parse_string_to_jsonb, parse_string_to_vector_type_value};
 use datatypes::value::{OrderedF32, OrderedF64, Value};
 use snafu::{OptionExt, ResultExt, ensure};
@@ -124,13 +125,14 @@ pub(crate) fn sql_number_to_value(data_type: &ConcreteDataType, n: &str) -> Resu
 /// If `auto_string_to_numeric` is true, tries to cast the string value to numeric values,
 /// and returns error if the cast fails.
 pub fn sql_value_to_value(
-    column_name: &str,
-    data_type: &ConcreteDataType,
+    column_schema: &ColumnSchema,
     sql_val: &SqlValue,
     timezone: Option<&Timezone>,
    unary_op: Option<UnaryOperator>,
    auto_string_to_numeric: bool,
 ) -> Result<Value> {
+    let column_name = &column_schema.name;
+    let data_type = &column_schema.data_type;
    let mut value = match sql_val {
        SqlValue::Number(n, _) => sql_number_to_value(data_type, n)?,
        SqlValue::Null => Value::Null,
@@ -146,13 +148,9 @@ pub fn sql_value_to_value(

            (*b).into()
        }
-        SqlValue::DoubleQuotedString(s) | SqlValue::SingleQuotedString(s) => parse_string_to_value(
-            column_name,
-            s.clone(),
-            data_type,
-            timezone,
-            auto_string_to_numeric,
-        )?,
+        SqlValue::DoubleQuotedString(s) | SqlValue::SingleQuotedString(s) => {
+            parse_string_to_value(column_schema, s.clone(), timezone, auto_string_to_numeric)?
+        }
        SqlValue::HexStringLiteral(s) => {
            // Should not directly write binary into json column
            ensure!(
@@ -244,12 +242,12 @@ pub fn sql_value_to_value(
 }

 pub(crate) fn parse_string_to_value(
-    column_name: &str,
+    column_schema: &ColumnSchema,
    s: String,
-    data_type: &ConcreteDataType,
    timezone: Option<&Timezone>,
    auto_string_to_numeric: bool,
 ) -> Result<Value> {
+    let data_type = &column_schema.data_type;
    if auto_string_to_numeric && let Some(value) = auto_cast_to_numeric(&s, data_type)? {
        return Ok(value);
    }
@@ -257,7 +255,7 @@ pub(crate) fn parse_string_to_value(
    ensure!(
        data_type.is_stringifiable(),
        ColumnTypeMismatchSnafu {
-            column_name,
+            column_name: column_schema.name.clone(),
            expect: data_type.clone(),
            actual: ConcreteDataType::string_datatype(),
        }
@@ -303,23 +301,21 @@ pub(crate) fn parse_string_to_value(
            }
        }
        ConcreteDataType::Binary(_) => Ok(Value::Binary(s.as_bytes().into())),
-        ConcreteDataType::Json(j) => {
-            match &j.format {
-                JsonFormat::Jsonb => {
-                    let v = parse_string_to_jsonb(&s).context(DatatypeSnafu)?;
-                    Ok(Value::Binary(v.into()))
-                }
-                JsonFormat::Native(_inner) => {
-                    // Always use the structured version at this level.
-                    let serde_json_value =
-                        serde_json::from_str(&s).context(DeserializeSnafu { json: s })?;
-                    let json_structure_settings = JsonStructureSettings::Structured(None);
-                    json_structure_settings
-                        .encode(serde_json_value)
-                        .context(DatatypeSnafu)
-                }
-            }
-        }
+        ConcreteDataType::Json(j) => match &j.format {
+            JsonFormat::Jsonb => {
+                let v = parse_string_to_jsonb(&s).context(DatatypeSnafu)?;
+                Ok(Value::Binary(v.into()))
+            }
+            JsonFormat::Native(_) => {
+                let extension_type: Option<JsonExtensionType> =
+                    column_schema.extension_type().context(DatatypeSnafu)?;
+                let json_structure_settings = extension_type
+                    .and_then(|x| x.metadata().json_structure_settings.clone())
+                    .unwrap_or_default();
+                let v = serde_json::from_str(&s).context(DeserializeSnafu { json: s })?;
+                json_structure_settings.encode(v).context(DatatypeSnafu)
+            }
+        },
        ConcreteDataType::Vector(d) => {
            let v = parse_string_to_vector_type_value(&s, Some(d.dim)).context(DatatypeSnafu)?;
            Ok(Value::Binary(v.into()))
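With this change, callers hand over the whole ColumnSchema instead of a separate name/type pair. A small within-crate sketch, using the same string-to-boolean case exercised by the tests below (column name and nullability are arbitrary here):

    use datatypes::prelude::ConcreteDataType;
    use datatypes::schema::ColumnSchema;

    let column_schema = ColumnSchema::new("col", ConcreteDataType::boolean_datatype(), true);
    // The schema carries the column name (used in error messages), the data type,
    // and any extension metadata needed for native JSON columns.
    let value = parse_string_to_value(&column_schema, "true".to_string(), None, true).unwrap();
    assert_eq!(value, Value::Boolean(true));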
@@ -417,305 +413,265 @@ mod test {
|
||||
|
||||
use super::*;
|
||||
|
||||
macro_rules! call_parse_string_to_value {
|
||||
($column_name: expr, $input: expr, $data_type: expr) => {
|
||||
call_parse_string_to_value!($column_name, $input, $data_type, None)
|
||||
};
|
||||
($column_name: expr, $input: expr, $data_type: expr, timezone = $timezone: expr) => {
|
||||
call_parse_string_to_value!($column_name, $input, $data_type, Some($timezone))
|
||||
};
|
||||
($column_name: expr, $input: expr, $data_type: expr, $timezone: expr) => {{
|
||||
let column_schema = ColumnSchema::new($column_name, $data_type, true);
|
||||
parse_string_to_value(&column_schema, $input, $timezone, true)
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_string_to_value_auto_numeric() {
|
||||
fn test_string_to_value_auto_numeric() -> Result<()> {
|
||||
// Test string to boolean with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"true".to_string(),
|
||||
&ConcreteDataType::boolean_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::boolean_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Boolean(true), result);
|
||||
|
||||
// Test invalid string to boolean with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_boolean".to_string(),
|
||||
&ConcreteDataType::boolean_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::boolean_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to int8
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"42".to_string(),
|
||||
&ConcreteDataType::int8_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::int8_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Int8(42), result);
|
||||
|
||||
// Test invalid string to int8 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_an_int8".to_string(),
|
||||
&ConcreteDataType::int8_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::int8_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to int16
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"1000".to_string(),
|
||||
&ConcreteDataType::int16_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::int16_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Int16(1000), result);
|
||||
|
||||
// Test invalid string to int16 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_an_int16".to_string(),
|
||||
&ConcreteDataType::int16_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::int16_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to int32
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"100000".to_string(),
|
||||
&ConcreteDataType::int32_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::int32_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Int32(100000), result);
|
||||
|
||||
// Test invalid string to int32 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_an_int32".to_string(),
|
||||
&ConcreteDataType::int32_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::int32_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to int64
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"1000000".to_string(),
|
||||
&ConcreteDataType::int64_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::int64_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Int64(1000000), result);
|
||||
|
||||
// Test invalid string to int64 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_an_int64".to_string(),
|
||||
&ConcreteDataType::int64_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::int64_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to uint8
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"200".to_string(),
|
||||
&ConcreteDataType::uint8_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::uint8_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::UInt8(200), result);
|
||||
|
||||
// Test invalid string to uint8 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_uint8".to_string(),
|
||||
&ConcreteDataType::uint8_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::uint8_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to uint16
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"60000".to_string(),
|
||||
&ConcreteDataType::uint16_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::uint16_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::UInt16(60000), result);
|
||||
|
||||
// Test invalid string to uint16 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_uint16".to_string(),
|
||||
&ConcreteDataType::uint16_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::uint16_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to uint32
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"4000000000".to_string(),
|
||||
&ConcreteDataType::uint32_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::uint32_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::UInt32(4000000000), result);
|
||||
|
||||
// Test invalid string to uint32 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_uint32".to_string(),
|
||||
&ConcreteDataType::uint32_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::uint32_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to uint64
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"18446744073709551615".to_string(),
|
||||
&ConcreteDataType::uint64_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::uint64_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::UInt64(18446744073709551615), result);
|
||||
|
||||
// Test invalid string to uint64 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_uint64".to_string(),
|
||||
&ConcreteDataType::uint64_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::uint64_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to float32
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"3.5".to_string(),
|
||||
&ConcreteDataType::float32_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::float32_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Float32(OrderedF32::from(3.5)), result);
|
||||
|
||||
// Test invalid string to float32 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_float32".to_string(),
|
||||
&ConcreteDataType::float32_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::float32_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
|
||||
// Test string to float64
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"3.5".to_string(),
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
None,
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
ConcreteDataType::float64_datatype()
|
||||
)?;
|
||||
assert_eq!(Value::Float64(OrderedF64::from(3.5)), result);
|
||||
|
||||
// Test invalid string to float64 with auto cast
|
||||
let result = parse_string_to_value(
|
||||
let result = call_parse_string_to_value!(
|
||||
"col",
|
||||
"not_a_float64".to_string(),
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
None,
|
||||
true,
|
||||
ConcreteDataType::float64_datatype()
|
||||
);
|
||||
assert!(result.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sql_value_to_value() {
|
||||
let sql_val = SqlValue::Null;
|
||||
assert_eq!(
|
||||
Value::Null,
|
||||
sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
macro_rules! call_sql_value_to_value {
|
||||
($column_name: expr, $data_type: expr, $sql_value: expr) => {
|
||||
call_sql_value_to_value!($column_name, $data_type, $sql_value, None, None, false)
|
||||
};
|
||||
($column_name: expr, $data_type: expr, $sql_value: expr, timezone = $timezone: expr) => {
|
||||
call_sql_value_to_value!(
|
||||
$column_name,
|
||||
$data_type,
|
||||
$sql_value,
|
||||
Some($timezone),
|
||||
None,
|
||||
false
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
($column_name: expr, $data_type: expr, $sql_value: expr, unary_op = $unary_op: expr) => {
|
||||
call_sql_value_to_value!(
|
||||
$column_name,
|
||||
$data_type,
|
||||
$sql_value,
|
||||
None,
|
||||
Some($unary_op),
|
||||
false
|
||||
)
|
||||
};
|
||||
($column_name: expr, $data_type: expr, $sql_value: expr, auto_string_to_numeric) => {
|
||||
call_sql_value_to_value!($column_name, $data_type, $sql_value, None, None, true)
|
||||
};
|
||||
($column_name: expr, $data_type: expr, $sql_value: expr, $timezone: expr, $unary_op: expr, $auto_string_to_numeric: expr) => {{
|
||||
let column_schema = ColumnSchema::new($column_name, $data_type, true);
|
||||
sql_value_to_value(
|
||||
&column_schema,
|
||||
$sql_value,
|
||||
$timezone,
|
||||
$unary_op,
|
||||
$auto_string_to_numeric,
|
||||
)
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sql_value_to_value() -> Result<()> {
|
||||
let sql_val = SqlValue::Null;
|
||||
assert_eq!(
|
||||
Value::Null,
|
||||
call_sql_value_to_value!("a", ConcreteDataType::float64_datatype(), &sql_val)?
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::Boolean(true);
|
||||
assert_eq!(
|
||||
Value::Boolean(true),
|
||||
sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::boolean_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false
|
||||
)
|
||||
.unwrap()
|
||||
call_sql_value_to_value!("a", ConcreteDataType::boolean_datatype(), &sql_val)?
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::Number("3.0".to_string(), false);
|
||||
assert_eq!(
|
||||
Value::Float64(OrderedFloat(3.0)),
|
||||
sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false
|
||||
)
|
||||
.unwrap()
|
||||
call_sql_value_to_value!("a", ConcreteDataType::float64_datatype(), &sql_val)?
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::Number("3.0".to_string(), false);
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::boolean_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::boolean_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
assert!(format!("{v:?}").contains("Failed to parse number '3.0' to boolean column type"));
|
||||
|
||||
let sql_val = SqlValue::Boolean(true);
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::float64_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
assert!(
|
||||
format!("{v:?}").contains(
|
||||
@@ -725,41 +681,18 @@ mod test {
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::HexStringLiteral("48656c6c6f20776f726c6421".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::binary_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::binary_datatype(), &sql_val)?;
|
||||
assert_eq!(Value::Binary(Bytes::from(b"Hello world!".as_slice())), v);
|
||||
|
||||
let sql_val = SqlValue::DoubleQuotedString("MorningMyFriends".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::binary_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::binary_datatype(), &sql_val)?;
|
||||
assert_eq!(
|
||||
Value::Binary(Bytes::from(b"MorningMyFriends".as_slice())),
|
||||
v
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::HexStringLiteral("9AF".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::binary_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::binary_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
assert!(
|
||||
format!("{v:?}").contains("odd number of digits"),
|
||||
@@ -767,38 +700,16 @@ mod test {
|
||||
);
|
||||
|
||||
let sql_val = SqlValue::HexStringLiteral("AG".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::binary_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::binary_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
assert!(format!("{v:?}").contains("invalid character"), "v is {v:?}",);
|
||||
|
||||
let sql_val = SqlValue::DoubleQuotedString("MorningMyFriends".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::json_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::json_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
|
||||
let sql_val = SqlValue::DoubleQuotedString(r#"{"a":"b"}"#.to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::json_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::json_datatype(), &sql_val)?;
|
||||
assert_eq!(
|
||||
Value::Binary(Bytes::from(
|
||||
jsonb::parse_value(r#"{"a":"b"}"#.as_bytes())
|
||||
@@ -808,16 +719,15 @@ mod test {
|
||||
)),
|
||||
v
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_json_to_jsonb() {
|
||||
match parse_string_to_value(
|
||||
match call_parse_string_to_value!(
|
||||
"json_col",
|
||||
r#"{"a": "b"}"#.to_string(),
|
||||
&ConcreteDataType::json_datatype(),
|
||||
None,
|
||||
false,
|
||||
ConcreteDataType::json_datatype()
|
||||
) {
|
||||
Ok(Value::Binary(b)) => {
|
||||
assert_eq!(
|
||||
@@ -833,12 +743,10 @@ mod test {
|
||||
}
|
||||
|
||||
assert!(
|
||||
parse_string_to_value(
|
||||
call_parse_string_to_value!(
|
||||
"json_col",
|
||||
r#"Nicola Kovac is the best rifler in the world"#.to_string(),
|
||||
&ConcreteDataType::json_datatype(),
|
||||
None,
|
||||
false,
|
||||
ConcreteDataType::json_datatype()
|
||||
)
|
||||
.is_err()
|
||||
)
|
||||
@@ -878,13 +786,10 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_parse_date_literal() {
|
||||
let value = sql_value_to_value(
|
||||
let value = call_sql_value_to_value!(
|
||||
"date",
|
||||
&ConcreteDataType::date_datatype(),
|
||||
&SqlValue::DoubleQuotedString("2022-02-22".to_string()),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
ConcreteDataType::date_datatype(),
|
||||
&SqlValue::DoubleQuotedString("2022-02-22".to_string())
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(ConcreteDataType::date_datatype(), value.data_type());
|
||||
@@ -895,13 +800,11 @@ mod test {
|
||||
}
|
||||
|
||||
// with timezone
|
||||
let value = sql_value_to_value(
|
||||
let value = call_sql_value_to_value!(
|
||||
"date",
|
||||
&ConcreteDataType::date_datatype(),
|
||||
ConcreteDataType::date_datatype(),
|
||||
&SqlValue::DoubleQuotedString("2022-02-22".to_string()),
|
||||
Some(&Timezone::from_tz_string("+07:00").unwrap()),
|
||||
None,
|
||||
false,
|
||||
timezone = &Timezone::from_tz_string("+07:00").unwrap()
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(ConcreteDataType::date_datatype(), value.data_type());
|
||||
@@ -913,16 +816,12 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_timestamp_literal() {
|
||||
match parse_string_to_value(
|
||||
fn test_parse_timestamp_literal() -> Result<()> {
|
||||
match call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01+08:00".to_string(),
|
||||
&ConcreteDataType::timestamp_millisecond_datatype(),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
ConcreteDataType::timestamp_millisecond_datatype()
|
||||
)? {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(1645459261000, ts.value());
|
||||
assert_eq!(TimeUnit::Millisecond, ts.unit());
|
||||
@@ -932,15 +831,11 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
match parse_string_to_value(
|
||||
match call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01+08:00".to_string(),
|
||||
&ConcreteDataType::timestamp_datatype(TimeUnit::Second),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
ConcreteDataType::timestamp_datatype(TimeUnit::Second)
|
||||
)? {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(1645459261, ts.value());
|
||||
assert_eq!(TimeUnit::Second, ts.unit());
|
||||
@@ -950,15 +845,11 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
match parse_string_to_value(
|
||||
match call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01+08:00".to_string(),
|
||||
&ConcreteDataType::timestamp_datatype(TimeUnit::Microsecond),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
ConcreteDataType::timestamp_datatype(TimeUnit::Microsecond)
|
||||
)? {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(1645459261000000, ts.value());
|
||||
assert_eq!(TimeUnit::Microsecond, ts.unit());
|
||||
@@ -968,15 +859,11 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
match parse_string_to_value(
|
||||
match call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01+08:00".to_string(),
|
||||
&ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond)
|
||||
)? {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(1645459261000000000, ts.value());
|
||||
assert_eq!(TimeUnit::Nanosecond, ts.unit());
|
||||
@@ -987,26 +874,21 @@ mod test {
|
||||
}
|
||||
|
||||
assert!(
|
||||
parse_string_to_value(
|
||||
call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01+08".to_string(),
|
||||
&ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
|
||||
None,
|
||||
false,
|
||||
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond)
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
|
||||
// with timezone
|
||||
match parse_string_to_value(
|
||||
match call_parse_string_to_value!(
|
||||
"timestamp_col",
|
||||
"2022-02-22T00:01:01".to_string(),
|
||||
&ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
|
||||
Some(&Timezone::from_tz_string("Asia/Shanghai").unwrap()),
|
||||
false,
|
||||
)
|
||||
.unwrap()
|
||||
{
|
||||
ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond),
|
||||
timezone = &Timezone::from_tz_string("Asia/Shanghai").unwrap()
|
||||
)? {
|
||||
Value::Timestamp(ts) => {
|
||||
assert_eq!(1645459261000000000, ts.value());
|
||||
assert_eq!("2022-02-21 16:01:01+0000", ts.to_iso8601_string());
|
||||
@@ -1016,51 +898,42 @@ mod test {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_placeholder_value() {
|
||||
assert!(
|
||||
sql_value_to_value(
|
||||
call_sql_value_to_value!(
|
||||
"test",
|
||||
&ConcreteDataType::string_datatype(),
|
||||
ConcreteDataType::string_datatype(),
|
||||
&SqlValue::Placeholder("default".into())
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
call_sql_value_to_value!(
|
||||
"test",
|
||||
ConcreteDataType::string_datatype(),
|
||||
&SqlValue::Placeholder("default".into()),
|
||||
None,
|
||||
None,
|
||||
false
|
||||
unary_op = UnaryOperator::Minus
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
sql_value_to_value(
|
||||
call_sql_value_to_value!(
|
||||
"test",
|
||||
&ConcreteDataType::string_datatype(),
|
||||
&SqlValue::Placeholder("default".into()),
|
||||
None,
|
||||
Some(UnaryOperator::Minus),
|
||||
false
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
sql_value_to_value(
|
||||
"test",
|
||||
&ConcreteDataType::uint16_datatype(),
|
||||
ConcreteDataType::uint16_datatype(),
|
||||
&SqlValue::Number("3".into(), false),
|
||||
None,
|
||||
Some(UnaryOperator::Minus),
|
||||
false
|
||||
unary_op = UnaryOperator::Minus
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
sql_value_to_value(
|
||||
call_sql_value_to_value!(
|
||||
"test",
|
||||
&ConcreteDataType::uint16_datatype(),
|
||||
&SqlValue::Number("3".into(), false),
|
||||
None,
|
||||
None,
|
||||
false
|
||||
ConcreteDataType::uint16_datatype(),
|
||||
&SqlValue::Number("3".into(), false)
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1070,77 +943,60 @@ mod test {
|
||||
fn test_auto_string_to_numeric() {
|
||||
// Test with auto_string_to_numeric=true
|
||||
let sql_val = SqlValue::SingleQuotedString("123".to_string());
|
||||
let v = sql_value_to_value(
|
||||
let v = call_sql_value_to_value!(
|
||||
"a",
|
||||
&ConcreteDataType::int32_datatype(),
|
||||
ConcreteDataType::int32_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
auto_string_to_numeric
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(Value::Int32(123), v);
|
||||
|
||||
// Test with a float string
|
||||
let sql_val = SqlValue::SingleQuotedString("3.5".to_string());
|
||||
let v = sql_value_to_value(
|
||||
let v = call_sql_value_to_value!(
|
||||
"a",
|
||||
&ConcreteDataType::float64_datatype(),
|
||||
ConcreteDataType::float64_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
auto_string_to_numeric
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(Value::Float64(OrderedFloat(3.5)), v);
|
||||
|
||||
// Test with auto_string_to_numeric=false
|
||||
let sql_val = SqlValue::SingleQuotedString("123".to_string());
|
||||
let v = sql_value_to_value(
|
||||
"a",
|
||||
&ConcreteDataType::int32_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
);
|
||||
let v = call_sql_value_to_value!("a", ConcreteDataType::int32_datatype(), &sql_val);
|
||||
assert!(v.is_err());
|
||||
|
||||
// Test with an invalid numeric string but auto_string_to_numeric=true
|
||||
// Should return an error now with the new auto_cast_to_numeric behavior
|
||||
let sql_val = SqlValue::SingleQuotedString("not_a_number".to_string());
|
||||
let v = sql_value_to_value(
|
||||
let v = call_sql_value_to_value!(
|
||||
"a",
|
||||
&ConcreteDataType::int32_datatype(),
|
||||
ConcreteDataType::int32_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
auto_string_to_numeric
|
||||
);
|
||||
assert!(v.is_err());
|
||||
|
||||
// Test with boolean type
|
||||
let sql_val = SqlValue::SingleQuotedString("true".to_string());
|
||||
let v = sql_value_to_value(
|
||||
let v = call_sql_value_to_value!(
|
||||
"a",
|
||||
&ConcreteDataType::boolean_datatype(),
|
||||
ConcreteDataType::boolean_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
auto_string_to_numeric
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(Value::Boolean(true), v);
|
||||
|
||||
// Non-numeric types should still be handled normally
|
||||
let sql_val = SqlValue::SingleQuotedString("hello".to_string());
|
||||
let v = sql_value_to_value(
|
||||
let v = call_sql_value_to_value!(
|
||||
"a",
|
||||
&ConcreteDataType::string_datatype(),
|
||||
ConcreteDataType::string_datatype(),
|
||||
&sql_val,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
auto_string_to_numeric
|
||||
);
|
||||
assert!(v.is_ok());
|
||||
}
|
||||
|
||||
@@ -14,8 +14,8 @@

 use common_time::timezone::Timezone;
 use datatypes::prelude::ConcreteDataType;
-use datatypes::schema::ColumnDefaultConstraint;
 use datatypes::schema::constraint::{CURRENT_TIMESTAMP, CURRENT_TIMESTAMP_FN};
+use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema};
 use snafu::ensure;
 use sqlparser::ast::ValueWithSpan;
 pub use sqlparser::ast::{
@@ -47,9 +47,12 @@ pub fn parse_column_default_constraint(
    );

    let default_constraint = match &opt.option {
-        ColumnOption::Default(Expr::Value(v)) => ColumnDefaultConstraint::Value(
-            sql_value_to_value(column_name, data_type, &v.value, timezone, None, false)?,
-        ),
+        ColumnOption::Default(Expr::Value(v)) => {
+            let schema = ColumnSchema::new(column_name, data_type.clone(), true);
+            ColumnDefaultConstraint::Value(sql_value_to_value(
+                &schema, &v.value, timezone, None, false,
+            )?)
+        }
        ColumnOption::Default(Expr::Function(func)) => {
            let mut func = format!("{func}").to_lowercase();
            // normalize CURRENT_TIMESTAMP to CURRENT_TIMESTAMP()
@@ -80,8 +83,7 @@ pub fn parse_column_default_constraint(

        if let Expr::Value(v) = &**expr {
            let value = sql_value_to_value(
-                column_name,
-                data_type,
+                &ColumnSchema::new(column_name, data_type.clone(), true),
                &v.value,
                timezone,
                Some(*op),
@@ -58,10 +58,14 @@ pub fn get_total_memory_bytes() -> i64 {
    }
 }

-/// Get the total CPU cores. The result will be rounded to the nearest integer.
-/// For example, if the total CPU is 1.5 cores(1500 millicores), the result will be 2.
+/// Get the total CPU cores. The result will be rounded up to the next integer (ceiling).
+/// For example, if the total CPU is 1.1 cores (1100 millicores) or 1.5 cores (1500 millicores), the result will be 2.
 pub fn get_total_cpu_cores() -> usize {
-    ((get_total_cpu_millicores() as f64) / 1000.0).round() as usize
+    cpu_cores(get_total_cpu_millicores())
+}
+
+fn cpu_cores(cpu_millicores: i64) -> usize {
+    ((cpu_millicores as f64) / 1_000.0).ceil() as usize
 }

 /// Get the total memory in readable size.
@@ -178,6 +182,13 @@ mod tests {
    #[test]
    fn test_get_total_cpu_cores() {
        assert!(get_total_cpu_cores() > 0);
+        assert_eq!(cpu_cores(1), 1);
+        assert_eq!(cpu_cores(100), 1);
+        assert_eq!(cpu_cores(500), 1);
+        assert_eq!(cpu_cores(1000), 1);
+        assert_eq!(cpu_cores(1100), 2);
+        assert_eq!(cpu_cores(1900), 2);
+        assert_eq!(cpu_cores(10_000), 10);
    }

    #[test]
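The difference between the old rounding and the new ceiling behaviour, shown as plain arithmetic on the values used in the test above:

    // 1100 millicores: round() reports 1 core, ceil() reports 2.
    assert_eq!(((1100_f64) / 1000.0).round() as usize, 1);
    assert_eq!(((1100_f64) / 1000.0).ceil() as usize, 2);
    // 1900 millicores: both report 2, so only fractions below one half behave differently.
    assert_eq!(((1900_f64) / 1000.0).round() as usize, 2);
    assert_eq!(((1900_f64) / 1000.0).ceil() as usize, 2);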
@@ -71,6 +71,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
}),
MetricType::GAUGE => timeseries.push(TimeSeries {
labels: convert_label(m.get_label(), mf_name, None),
@@ -79,6 +80,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
}),
MetricType::HISTOGRAM => {
let h = m.get_histogram();
@@ -97,6 +99,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
if upper_bound.is_sign_positive() && upper_bound.is_infinite() {
inf_seen = true;
@@ -114,6 +117,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
}
timeseries.push(TimeSeries {
@@ -127,6 +131,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
timeseries.push(TimeSeries {
labels: convert_label(
@@ -139,6 +144,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
}
MetricType::SUMMARY => {
@@ -155,6 +161,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
}
timeseries.push(TimeSeries {
@@ -168,6 +175,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
timeseries.push(TimeSeries {
labels: convert_label(
@@ -180,6 +188,7 @@ pub fn convert_metric_to_write_request(
timestamp,
}],
exemplars: vec![],
histograms: vec![],
});
}
MetricType::UNTYPED => {
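Note: every TimeSeries literal in this function gains an explicit histograms: vec![], which suggests the generated remote-write TimeSeries struct now carries a histograms field (presumably for native-histogram support in the protocol) and can no longer be built without it. A minimal sketch of the shape each push now takes, assuming only the field names visible in this diff (labels, samples, exemplars, histograms) and using a stand-in for the sampled value:

// Sketch only: push one sample with the now-required empty exemplars/histograms
// vectors. `counter_value` stands in for the f64 read from the metric; the struct
// types come from the remote-write protobuf definitions, not from this hunk.
timeseries.push(TimeSeries {
    labels: convert_label(m.get_label(), mf_name, None),
    samples: vec![Sample {
        value: counter_value,
        timestamp,
    }],
    exemplars: vec![],
    histograms: vec![],
});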
@@ -274,7 +283,7 @@ mod test {

assert_eq!(
format!("{:?}", write_quest.timeseries),
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }]"#
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }]"#
);

let gauge_opts = Opts::new("test_gauge", "test help")
@@ -288,7 +297,7 @@ mod test {
let write_quest = convert_metric_to_write_request(mf, None, 0);
assert_eq!(
format!("{:?}", write_quest.timeseries),
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_gauge" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 42.0, timestamp: 0 }], exemplars: [] }]"#
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_gauge" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 42.0, timestamp: 0 }], exemplars: [], histograms: [] }]"#
);
}

@@ -305,20 +314,20 @@ mod test {
.iter()
.map(|x| format!("{:?}", x))
.collect();
let ans = r#"TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.005" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.01" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.025" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.05" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.1" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.25" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "1" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "2.5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "10" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "+Inf" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_sum" }, Label { name: "a", value: "1" }], samples: [Sample { value: 0.25, timestamp: 0 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_count" }, Label { name: "a", value: "1" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }"#;
let ans = r#"TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.005" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.01" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.025" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.05" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.1" }], samples: [Sample { value: 0.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.25" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "0.5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "1" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "2.5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "5" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "10" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_bucket" }, Label { name: "a", value: "1" }, Label { name: "le", value: "+Inf" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_sum" }, Label { name: "a", value: "1" }], samples: [Sample { value: 0.25, timestamp: 0 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_count" }, Label { name: "a", value: "1" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }"#;
assert_eq!(write_quest_str.join("\n"), ans);
}

@@ -355,10 +364,10 @@ TimeSeries { labels: [Label { name: "__name__", value: "test_histogram_count" },
.iter()
.map(|x| format!("{:?}", x))
.collect();
let ans = r#"TimeSeries { labels: [Label { name: "__name__", value: "test_summary" }, Label { name: "quantile", value: "50" }], samples: [Sample { value: 3.0, timestamp: 20 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary" }, Label { name: "quantile", value: "100" }], samples: [Sample { value: 5.0, timestamp: 20 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary_sum" }], samples: [Sample { value: 15.0, timestamp: 20 }], exemplars: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary_count" }], samples: [Sample { value: 5.0, timestamp: 20 }], exemplars: [] }"#;
let ans = r#"TimeSeries { labels: [Label { name: "__name__", value: "test_summary" }, Label { name: "quantile", value: "50" }], samples: [Sample { value: 3.0, timestamp: 20 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary" }, Label { name: "quantile", value: "100" }], samples: [Sample { value: 5.0, timestamp: 20 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary_sum" }], samples: [Sample { value: 15.0, timestamp: 20 }], exemplars: [], histograms: [] }
TimeSeries { labels: [Label { name: "__name__", value: "test_summary_count" }], samples: [Sample { value: 5.0, timestamp: 20 }], exemplars: [], histograms: [] }"#;
assert_eq!(write_quest_str.join("\n"), ans);
}

@@ -385,11 +394,11 @@ TimeSeries { labels: [Label { name: "__name__", value: "test_summary_count" }],
let write_quest2 = convert_metric_to_write_request(mf, Some(&filter), 0);
assert_eq!(
format!("{:?}", write_quest1.timeseries),
r#"[TimeSeries { labels: [Label { name: "__name__", value: "filter_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [] }, TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 2.0, timestamp: 0 }], exemplars: [] }]"#
r#"[TimeSeries { labels: [Label { name: "__name__", value: "filter_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 1.0, timestamp: 0 }], exemplars: [], histograms: [] }, TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 2.0, timestamp: 0 }], exemplars: [], histograms: [] }]"#
);
assert_eq!(
format!("{:?}", write_quest2.timeseries),
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 2.0, timestamp: 0 }], exemplars: [] }]"#
r#"[TimeSeries { labels: [Label { name: "__name__", value: "test_counter" }, Label { name: "a", value: "1" }, Label { name: "b", value: "2" }], samples: [Sample { value: 2.0, timestamp: 0 }], exemplars: [], histograms: [] }]"#
);
}
}

@@ -206,6 +206,8 @@ mod tests {
client_cert_path: None,
client_key_path: None,
}),
connect_timeout: Duration::from_secs(3),
timeout: Duration::from_secs(3),
},
kafka_topic: KafkaTopicConfig {
num_topics: 32,
@@ -239,6 +241,8 @@ mod tests {
client_cert_path: None,
client_key_path: None,
}),
connect_timeout: Duration::from_secs(3),
timeout: Duration::from_secs(3),
},
max_batch_bytes: ReadableSize::mb(1),
consumer_wait_timeout: Duration::from_millis(100),

@@ -164,6 +164,12 @@ pub struct KafkaConnectionConfig {
pub sasl: Option<KafkaClientSasl>,
/// Client TLS config
pub tls: Option<KafkaClientTls>,
/// The connect timeout for kafka client.
#[serde(with = "humantime_serde")]
pub connect_timeout: Duration,
/// The timeout for kafka client.
#[serde(with = "humantime_serde")]
pub timeout: Duration,
}

impl Default for KafkaConnectionConfig {
@@ -172,6 +178,8 @@ impl Default for KafkaConnectionConfig {
broker_endpoints: vec![BROKER_ENDPOINT.to_string()],
sasl: None,
tls: None,
connect_timeout: Duration::from_secs(3),
timeout: Duration::from_secs(3),
}
}
}
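Note: the new connect_timeout and timeout fields are (de)serialized with humantime_serde, so in configuration files they take humantime strings such as "3s" or "500ms". In code, a caller that only wants to change one of them can lean on the Default impl shown above; a small sketch (override value illustrative):

use std::time::Duration;

// Start from the defaults (3s/3s per the Default impl above) and override only
// the connection timeout; every other field keeps its default.
let conn = KafkaConnectionConfig {
    connect_timeout: Duration::from_secs(10),
    ..Default::default()
};
assert_eq!(conn.timeout, Duration::from_secs(3));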
@@ -22,6 +22,7 @@ use common_base::Plugins;
use common_error::ext::BoxedError;
use common_greptimedb_telemetry::GreptimeDBTelemetryTask;
use common_meta::cache::{LayeredCacheRegistry, SchemaCacheRef, TableSchemaCacheRef};
use common_meta::cache_invalidator::CacheInvalidatorRef;
use common_meta::datanode::TopicStatsReporter;
use common_meta::key::runtime_switch::RuntimeSwitchManager;
use common_meta::key::{SchemaMetadataManager, SchemaMetadataManagerRef};
@@ -281,21 +282,11 @@ impl DatanodeBuilder {
open_all_regions.await?;
}

let mut resource_stat = ResourceStatImpl::default();
resource_stat.start_collect_cpu_usage();

let heartbeat_task = if let Some(meta_client) = meta_client {
Some(
HeartbeatTask::try_new(
&self.opts,
region_server.clone(),
meta_client,
cache_registry,
self.plugins.clone(),
Arc::new(resource_stat),
)
.await?,
)
let task = self
.create_heartbeat_task(&region_server, meta_client, cache_registry)
.await?;
Some(task)
} else {
None
};
@@ -324,6 +315,29 @@ impl DatanodeBuilder {
})
}

async fn create_heartbeat_task(
&self,
region_server: &RegionServer,
meta_client: MetaClientRef,
cache_invalidator: CacheInvalidatorRef,
) -> Result<HeartbeatTask> {
let stat = {
let mut stat = ResourceStatImpl::default();
stat.start_collect_cpu_usage();
Arc::new(stat)
};

HeartbeatTask::try_new(
&self.opts,
region_server.clone(),
meta_client,
cache_invalidator,
self.plugins.clone(),
stat,
)
.await
}

/// Builds [ObjectStoreManager] from [StorageConfig].
pub async fn build_object_store_manager(cfg: &StorageConfig) -> Result<ObjectStoreManagerRef> {
let object_store = store::new_object_store(cfg.store.clone(), &cfg.data_home).await?;

@@ -410,14 +410,6 @@ pub enum Error {
location: Location,
},

#[snafu(display("Failed to build cache store"))]
BuildCacheStore {
#[snafu(source)]
error: object_store::Error,
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Not yet implemented: {what}"))]
NotYetImplemented { what: String },
}
@@ -493,7 +485,6 @@ impl ErrorExt for Error {
SerializeJson { .. } => StatusCode::Internal,

ObjectStore { source, .. } => source.status_code(),
BuildCacheStore { .. } => StatusCode::StorageUnavailable,
}
}

@@ -25,6 +25,7 @@ use common_meta::datanode::REGION_STATISTIC_KEY;
use common_meta::distributed_time_constants::META_KEEP_ALIVE_INTERVAL_SECS;
use common_meta::heartbeat::handler::invalidate_table_cache::InvalidateCacheHandler;
use common_meta::heartbeat::handler::parse_mailbox_message::ParseMailboxMessageHandler;
use common_meta::heartbeat::handler::suspend::SuspendHandler;
use common_meta::heartbeat::handler::{
HandlerGroupExecutor, HeartbeatResponseHandlerContext, HeartbeatResponseHandlerExecutorRef,
};
@@ -91,6 +92,7 @@ impl HeartbeatTask {
let resp_handler_executor = Arc::new(HandlerGroupExecutor::new(vec![
region_alive_keeper.clone(),
Arc::new(ParseMailboxMessageHandler),
Arc::new(SuspendHandler::new(region_server.suspend_state())),
Arc::new(
RegionHeartbeatResponseHandler::new(region_server.clone())
.with_open_region_parallelism(opts.init_regions_parallelism),

@@ -24,6 +24,7 @@ use store_api::storage::GcReport;

mod close_region;
mod downgrade_region;
mod enter_staging;
mod file_ref;
mod flush_region;
mod gc_worker;
@@ -32,6 +33,7 @@ mod upgrade_region;

use crate::heartbeat::handler::close_region::CloseRegionsHandler;
use crate::heartbeat::handler::downgrade_region::DowngradeRegionsHandler;
use crate::heartbeat::handler::enter_staging::EnterStagingRegionsHandler;
use crate::heartbeat::handler::file_ref::GetFileRefsHandler;
use crate::heartbeat::handler::flush_region::FlushRegionsHandler;
use crate::heartbeat::handler::gc_worker::GcRegionsHandler;
@@ -99,26 +101,33 @@ impl RegionHeartbeatResponseHandler {
self
}

fn build_handler(&self, instruction: &Instruction) -> MetaResult<Box<InstructionHandlers>> {
fn build_handler(
&self,
instruction: &Instruction,
) -> MetaResult<Option<Box<InstructionHandlers>>> {
match instruction {
Instruction::CloseRegions(_) => Ok(Box::new(CloseRegionsHandler.into())),
Instruction::OpenRegions(_) => Ok(Box::new(
Instruction::CloseRegions(_) => Ok(Some(Box::new(CloseRegionsHandler.into()))),
Instruction::OpenRegions(_) => Ok(Some(Box::new(
OpenRegionsHandler {
open_region_parallelism: self.open_region_parallelism,
}
.into(),
)),
Instruction::FlushRegions(_) => Ok(Box::new(FlushRegionsHandler.into())),
Instruction::DowngradeRegions(_) => Ok(Box::new(DowngradeRegionsHandler.into())),
Instruction::UpgradeRegions(_) => Ok(Box::new(
))),
Instruction::FlushRegions(_) => Ok(Some(Box::new(FlushRegionsHandler.into()))),
Instruction::DowngradeRegions(_) => Ok(Some(Box::new(DowngradeRegionsHandler.into()))),
Instruction::UpgradeRegions(_) => Ok(Some(Box::new(
UpgradeRegionsHandler {
upgrade_region_parallelism: self.open_region_parallelism,
}
.into(),
)),
Instruction::GetFileRefs(_) => Ok(Box::new(GetFileRefsHandler.into())),
Instruction::GcRegions(_) => Ok(Box::new(GcRegionsHandler.into())),
))),
Instruction::GetFileRefs(_) => Ok(Some(Box::new(GetFileRefsHandler.into()))),
Instruction::GcRegions(_) => Ok(Some(Box::new(GcRegionsHandler.into()))),
Instruction::InvalidateCaches(_) => InvalidHeartbeatResponseSnafu.fail(),
Instruction::Suspend => Ok(None),
Instruction::EnterStagingRegions(_) => {
Ok(Some(Box::new(EnterStagingRegionsHandler.into())))
}
}
}
}
@@ -132,6 +141,7 @@ pub enum InstructionHandlers {
UpgradeRegions(UpgradeRegionsHandler),
GetFileRefs(GetFileRefsHandler),
GcRegions(GcRegionsHandler),
EnterStagingRegions(EnterStagingRegionsHandler),
}

macro_rules! impl_from_handler {
@@ -153,7 +163,8 @@ impl_from_handler!(
DowngradeRegionsHandler => DowngradeRegions,
UpgradeRegionsHandler => UpgradeRegions,
GetFileRefsHandler => GetFileRefs,
GcRegionsHandler => GcRegions
GcRegionsHandler => GcRegions,
EnterStagingRegionsHandler => EnterStagingRegions
);

macro_rules! dispatch_instr {
@@ -198,6 +209,7 @@ dispatch_instr!(
UpgradeRegions => UpgradeRegions,
GetFileRefs => GetFileRefs,
GcRegions => GcRegions,
EnterStagingRegions => EnterStagingRegions
);

#[async_trait]
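Note: both macro invocations above gain the new EnterStagingRegions pairing. The macro bodies are not part of this hunk; presumably impl_from_handler! generates a From<Handler> impl per enum variant and dispatch_instr! matches on the enum to forward handle calls. A hand-written sketch of what one such pairing would expand to, under that assumption:

// Hypothetical expansion of one `Handler => Variant` pair in impl_from_handler!.
impl From<EnterStagingRegionsHandler> for InstructionHandlers {
    fn from(handler: EnterStagingRegionsHandler) -> Self {
        InstructionHandlers::EnterStagingRegions(handler)
    }
}

// The dispatch side would then contain a matching arm, roughly:
// InstructionHandlers::EnterStagingRegions(h) => h.handle(ctx, instruction).await,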
@@ -216,30 +228,24 @@ impl HeartbeatResponseHandler for RegionHeartbeatResponseHandler {
.context(InvalidHeartbeatResponseSnafu)?;

let mailbox = ctx.mailbox.clone();
let region_server = self.region_server.clone();
let downgrade_tasks = self.downgrade_tasks.clone();
let flush_tasks = self.flush_tasks.clone();
let gc_tasks = self.gc_tasks.clone();
let handler = self.build_handler(&instruction)?;
let _handle = common_runtime::spawn_global(async move {
let reply = handler
.handle(
&HandlerContext {
region_server,
downgrade_tasks,
flush_tasks,
gc_tasks,
},
instruction,
)
.await;

if let Some(reply) = reply
&& let Err(e) = mailbox.send((meta, reply)).await
{
error!(e; "Failed to send reply to mailbox");
}
});
if let Some(handler) = self.build_handler(&instruction)? {
let context = HandlerContext {
region_server: self.region_server.clone(),
downgrade_tasks: self.downgrade_tasks.clone(),
flush_tasks: self.flush_tasks.clone(),
gc_tasks: self.gc_tasks.clone(),
};
let _handle = common_runtime::spawn_global(async move {
let reply = handler.handle(&context, instruction).await;
if let Some(reply) = reply
&& let Err(e) = mailbox.send((meta, reply)).await
{
let error = e.to_string();
let (meta, reply) = e.0;
error!("Failed to send reply {reply} to {meta:?}: {error}");
}
});
}

Ok(HandleControl::Continue)
}
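Note: the reworked error branch recovers the unsent (meta, reply) pair from the send error via e.0 instead of only logging the error itself; that works because a channel send error typically hands the rejected message back to the caller. A minimal standalone illustration of that property, assuming a tokio mpsc channel (not necessarily the project's mailbox type):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::channel::<(u32, &'static str)>(1);
    drop(rx); // receiver gone, so the next send fails
    if let Err(e) = tx.send((1, "reply")).await {
        // SendError is a tuple struct: field 0 is the message that was not delivered.
        let (meta, reply) = e.0;
        eprintln!("failed to send reply {reply} to {meta}: channel closed");
    }
}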
@@ -256,7 +262,9 @@ mod tests {
use common_meta::heartbeat::mailbox::{
HeartbeatMailbox, IncomingMessage, MailboxRef, MessageMeta,
};
use common_meta::instruction::{DowngradeRegion, OpenRegion, UpgradeRegion};
use common_meta::instruction::{
DowngradeRegion, EnterStagingRegion, OpenRegion, UpgradeRegion,
};
use mito2::config::MitoConfig;
use mito2::engine::MITO_ENGINE_NAME;
use mito2::test_util::{CreateRequestBuilder, TestEnv};
@@ -337,6 +345,16 @@ mod tests {
region_id,
..Default::default()
}]);
assert!(
heartbeat_handler
.is_acceptable(&heartbeat_env.create_handler_ctx((meta.clone(), instruction)))
);

// Enter staging region
let instruction = Instruction::EnterStagingRegions(vec![EnterStagingRegion {
region_id,
partition_expr: "".to_string(),
}]);
assert!(
heartbeat_handler.is_acceptable(&heartbeat_env.create_handler_ctx((meta, instruction)))
);

Some files were not shown because too many files have changed in this diff.