Merge branch 'main' into feat/expr-split

This commit is contained in:
Weny Xu
2026-03-26 11:02:36 +08:00
committed by GitHub
203 changed files with 15793 additions and 4445 deletions

View File

@@ -37,17 +37,14 @@ inputs:
description: Whether to push the latest tag of the image
required: false
default: 'true'
aws-cn-s3-bucket:
description: S3 bucket to store released artifacts in CN region
proxy-url:
description: The url of the S3 proxy server
required: true
aws-cn-access-key-id:
description: AWS access key id in CN region
proxy-username:
description: The username of the S3 proxy
required: true
aws-cn-secret-access-key:
description: AWS secret access key in CN region
required: true
aws-cn-region:
description: AWS region in CN
proxy-password:
description: The password of the S3 proxy
required: true
upload-to-s3:
description: Upload to S3
@@ -77,21 +74,13 @@ runs:
with:
path: ${{ inputs.artifacts-dir }}
- name: Install s5cmd
shell: bash
run: |
wget https://github.com/peak/s5cmd/releases/download/v2.3.0/s5cmd_2.3.0_Linux-64bit.tar.gz
tar -xzf s5cmd_2.3.0_Linux-64bit.tar.gz
sudo mv s5cmd /usr/local/bin/
sudo chmod +x /usr/local/bin/s5cmd
- name: Release artifacts to cn region
uses: nick-invision/retry@v2
if: ${{ inputs.upload-to-s3 == 'true' }}
env:
AWS_ACCESS_KEY_ID: ${{ inputs.aws-cn-access-key-id }}
AWS_SECRET_ACCESS_KEY: ${{ inputs.aws-cn-secret-access-key }}
AWS_REGION: ${{ inputs.aws-cn-region }}
PROXY_URL: ${{ inputs.proxy-url }}
PROXY_USERNAME: ${{ inputs.proxy-username }}
PROXY_PASSWORD: ${{ inputs.proxy-password }}
UPDATE_VERSION_INFO: ${{ inputs.update-version-info }}
with:
max_attempts: ${{ inputs.upload-max-retry-times }}
@@ -99,8 +88,7 @@ runs:
command: |
./.github/scripts/upload-artifacts-to-s3.sh \
${{ inputs.artifacts-dir }} \
${{ inputs.version }} \
${{ inputs.aws-cn-s3-bucket }}
${{ inputs.version }}
- name: Push greptimedb image from Dockerhub to ACR
shell: bash

View File

@@ -5,16 +5,15 @@ set -o pipefail
ARTIFACTS_DIR=$1
VERSION=$2
AWS_S3_BUCKET=$3
RELEASE_DIRS="releases/greptimedb"
GREPTIMEDB_REPO="GreptimeTeam/greptimedb"
# Check if necessary variables are set.
function check_vars() {
for var in AWS_S3_BUCKET VERSION ARTIFACTS_DIR; do
for var in VERSION ARTIFACTS_DIR; do
if [ -z "${!var}" ]; then
echo "$var is not set or empty."
echo "Usage: $0 <artifacts-dir> <version> <aws-s3-bucket>"
echo "Usage: $0 <artifacts-dir> <version>"
exit 1
fi
done
@@ -33,8 +32,13 @@ function upload_artifacts() {
# ├── greptime-darwin-amd64-v0.2.0.sha256sum
# └── greptime-darwin-amd64-v0.2.0.tar.gz
find "$ARTIFACTS_DIR" -type f \( -name "*.tar.gz" -o -name "*.sha256sum" \) | while IFS= read -r file; do
s5cmd cp \
"$file" "s3://$AWS_S3_BUCKET/$RELEASE_DIRS/$VERSION/$(basename "$file")"
filename=$(basename "$file")
TARGET_URL="$PROXY_URL/$RELEASE_DIRS/$VERSION"
curl -X PUT \
-u "$PROXY_USERNAME:$PROXY_PASSWORD" \
-F "file=@$file" \
"$TARGET_URL"
done
}
@@ -45,16 +49,24 @@ function update_version_info() {
if [[ "$VERSION" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "Updating latest-version.txt"
echo "$VERSION" > latest-version.txt
s5cmd cp \
latest-version.txt "s3://$AWS_S3_BUCKET/$RELEASE_DIRS/latest-version.txt"
TARGET_URL="$PROXY_URL/$RELEASE_DIRS"
curl -X PUT \
-u "$PROXY_USERNAME:$PROXY_PASSWORD" \
-F "file=@latest-version.txt" \
"$TARGET_URL"
fi
# If it's the nightly release, update latest-nightly-version.txt.
if [[ "$VERSION" == *"nightly"* ]]; then
echo "Updating latest-nightly-version.txt"
echo "$VERSION" > latest-nightly-version.txt
s5cmd cp \
latest-nightly-version.txt "s3://$AWS_S3_BUCKET/$RELEASE_DIRS/latest-nightly-version.txt"
TARGET_URL="$PROXY_URL/$RELEASE_DIRS"
curl -X PUT \
-u "$PROXY_USERNAME:$PROXY_PASSWORD" \
-F "file=@latest-nightly-version.txt" \
"$TARGET_URL"
fi
fi
}
@@ -93,10 +105,10 @@ function main() {
}
# Usage example:
# AWS_ACCESS_KEY_ID=<your_access_key_id> \
# AWS_SECRET_ACCESS_KEY=<your_secret_access_key> \
# AWS_DEFAULT_REGION=<your_region> \
# PROXY_URL=<proxy_url> \
# PROXY_USERNAME=<proxy_username> \
# PROXY_PASSWORD=<proxy_password> \
# UPDATE_VERSION_INFO=true \
# DOWNLOAD_ARTIFACTS_FROM_GITHUB=false \
# ./upload-artifacts-to-s3.sh <artifacts-dir> <version> <aws-s3-bucket>
# ./upload-artifacts-to-s3.sh <artifacts-dir> <version>
main

View File

@@ -285,10 +285,9 @@ jobs:
dst-image-registry: ${{ vars.ACR_IMAGE_REGISTRY }}
dst-image-namespace: ${{ vars.IMAGE_NAMESPACE }}
version: ${{ needs.allocate-runners.outputs.version }}
aws-cn-s3-bucket: ${{ vars.AWS_RELEASE_BUCKET }}
aws-cn-access-key-id: ${{ secrets.AWS_CN_ACCESS_KEY_ID }}
aws-cn-secret-access-key: ${{ secrets.AWS_CN_SECRET_ACCESS_KEY }}
aws-cn-region: ${{ vars.AWS_RELEASE_BUCKET_REGION }}
proxy-url: ${{ secrets.PROXY_URL }}
proxy-username: ${{ secrets.PROXY_USERNAME }}
proxy-password: ${{ secrets.PROXY_PASSWORD }}
upload-to-s3: ${{ inputs.upload_artifacts_to_s3 }}
dev-mode: true # Only build the standard images(exclude centos images).
push-latest-tag: false # Don't push the latest tag to registry.

View File

@@ -319,7 +319,13 @@ jobs:
include:
- target: "fuzz_repartition_table"
mode:
name: "Local WAL Repartition GC"
name: "Local WAL mito table repartition"
minio: true
kafka: false
values: "with-minio-repartition-gc.yaml"
- target: "fuzz_repartition_metric_table"
mode:
name: "Local WAL metric table repartition"
minio: true
kafka: false
values: "with-minio-repartition-gc.yaml"
@@ -455,6 +461,14 @@ jobs:
path: /tmp/fuzz-monitor-dumps
if-no-files-found: warn
retention-days: 3
- name: Upload CSV dumps
if: failure()
uses: actions/upload-artifact@v4
with:
name: fuzz-tests-csv-dumps-${{ matrix.mode.name }}-${{ matrix.target }}
path: /tmp/greptime-fuzz-dumps
if-no-files-found: warn
retention-days: 3
- name: Delete cluster
if: success()
shell: bash

View File

@@ -236,10 +236,9 @@ jobs:
dst-image-registry: ${{ vars.ACR_IMAGE_REGISTRY }}
dst-image-namespace: ${{ vars.IMAGE_NAMESPACE }}
version: ${{ needs.allocate-runners.outputs.version }}
aws-cn-s3-bucket: ${{ vars.AWS_RELEASE_BUCKET }}
aws-cn-access-key-id: ${{ secrets.AWS_CN_ACCESS_KEY_ID }}
aws-cn-secret-access-key: ${{ secrets.AWS_CN_SECRET_ACCESS_KEY }}
aws-cn-region: ${{ vars.AWS_RELEASE_BUCKET_REGION }}
proxy-url: ${{ secrets.PROXY_URL }}
proxy-username: ${{ secrets.PROXY_USERNAME }}
proxy-password: ${{ secrets.PROXY_PASSWORD }}
upload-to-s3: false
dev-mode: false
update-version-info: false # Don't update version info in S3.

View File

@@ -358,10 +358,9 @@ jobs:
dst-image-registry: ${{ vars.ACR_IMAGE_REGISTRY }}
dst-image-namespace: ${{ vars.IMAGE_NAMESPACE }}
version: ${{ needs.allocate-runners.outputs.version }}
aws-cn-s3-bucket: ${{ vars.AWS_RELEASE_BUCKET }}
aws-cn-access-key-id: ${{ secrets.AWS_CN_ACCESS_KEY_ID }}
aws-cn-secret-access-key: ${{ secrets.AWS_CN_SECRET_ACCESS_KEY }}
aws-cn-region: ${{ vars.AWS_RELEASE_BUCKET_REGION }}
proxy-url: ${{ secrets.PROXY_URL }}
proxy-username: ${{ secrets.PROXY_USERNAME }}
proxy-password: ${{ secrets.PROXY_PASSWORD }}
dev-mode: false
upload-to-s3: true
update-version-info: true

3
.gitignore vendored
View File

@@ -70,3 +70,6 @@ CLAUDE.md
# AGENTS.md
AGENTS.md
# local design docs
docs/specs/

21
Cargo.lock generated
View File

@@ -1946,6 +1946,7 @@ dependencies = [
"tokio",
"tracing-appender",
"url",
"uuid",
]
[[package]]
@@ -2488,7 +2489,6 @@ version = "1.0.0-rc.2"
dependencies = [
"common-error",
"common-macro",
"common-telemetry",
"humantime",
"serde",
"snafu 0.8.6",
@@ -7301,7 +7301,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
"windows-targets 0.48.5",
]
[[package]]
@@ -7887,6 +7887,7 @@ dependencies = [
"common-base",
"common-error",
"common-function",
"common-grpc",
"common-macro",
"common-meta",
"common-query",
@@ -9619,9 +9620,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.38.0"
version = "0.38.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89d5e5a60d3f6e40c91f6a2a7f8d09665e636272bd5611977253559b6651aabb"
checksum = "f2a798d130b8975a566c2cf6d8955746e1f09a9ee2c3ff2e6020a2c6528c5bd1"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -10771,9 +10772,9 @@ dependencies = [
[[package]]
name = "quinn-proto"
version = "0.11.12"
version = "0.11.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e"
checksum = "434b42fec591c96ef50e21e886936e66d3cc3f737104fdb9b737c40ffb94c098"
dependencies = [
"bytes",
"getrandom 0.3.3",
@@ -11634,9 +11635,9 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.103.3"
version = "0.103.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435"
checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef"
dependencies = [
"ring",
"rustls-pki-types",
@@ -13403,9 +13404,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tar"
version = "0.4.44"
version = "0.4.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a"
checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973"
dependencies = [
"filetime",
"libc",

View File

@@ -67,6 +67,7 @@ snapshot-20250101/
- Self-contained (all information needed for restore)
- Immutable (content never changes after creation)
- Verifiable (checksums at file, chunk, and snapshot levels)
- Schema-only snapshots contain only `manifest.json` and `schema/`; `data/` is absent, `chunks` is empty, and later data append is rejected (use `--force` to recreate)
### Chunk
@@ -116,6 +117,8 @@ greptime export create \
--schema-only \
--to s3://my-bucket/snapshots/prod-schema-only
Schema-only snapshots cannot be resumed with data; use `--force` to recreate.
# Export with specific format (default: parquet)
greptime export create \
--format csv \
@@ -173,7 +176,9 @@ The manifest is a JSON file containing snapshot metadata and chunk index:
- `snapshot_id`: Unique identifier (UUID)
- `catalog`, `schemas`: Catalog and schema list
- `time_range`: Overall time range covered
- `schema_only`: Whether the snapshot contains schema only
- `chunks[]`: Array of chunk metadata
- `format`: Data format for exported files
- `checksum`: Snapshot-level SHA256 checksum
**Chunk metadata structure**:
@@ -182,7 +187,7 @@ Each chunk entry in the manifest contains:
- `id`: Chunk identifier (sequential number)
- `time_range`: Start and end timestamps
- `status`: Export status (Pending, Completed, Failed)
- `status`: Export status (Pending, InProgress, Completed, Failed)
- `files`: List of data files in the chunk directory
- `checksum`: Chunk-level checksum for integrity verification
@@ -292,9 +297,9 @@ Checksums are verified during import before data is written to the database.
**Resume capability**:
- Manifest tracks chunk status (Pending, Completed, Failed)
- Manifest tracks chunk status (Pending, InProgress, Completed, Failed)
- Export/import automatically resumes when executed on existing snapshot
- Skips completed chunks, retries failed chunks, processes pending chunks
- Skips completed chunks, retries failed/in-progress chunks, processes pending chunks
- Works across process restarts
- Use `--force` (export only) to delete existing snapshot and start over

View File

@@ -0,0 +1,190 @@
---
Feature Name: Flow Batching Sequence-Based Incremental Query Plan (Lite)
Tracking Issue: TBD
Date: 2026-03-16
Author: @discord9
---
# Summary
This RFC proposes a correctness-first incremental query mode for Flow batching.
Flow queries can read only `seq > checkpoint` and advance checkpoints using per-region correctness watermarks.
When incremental reads are stale or correctness cannot be proven, Flow falls back to full recomputation.
# Motivation
Flow batching currently has to recompute old data in the same time window on every run, so incremental queries can improve Flow performance.
# Goals
1. Add opt-in incremental reads (`seq > given_seq`) for Flow.
2. Return per-region correctness watermarks for checkpoint advancement.
3. Keep existing query behavior unchanged unless explicitly enabled.
4. Define deterministic fallback for stale or unprovable incremental reads.
# Non-Goals
1. No business-schema changes (no synthetic watermark columns in result rows).
2. No global throughput optimization in v1 (correctness first).
3. No observational watermark output when correctness is unprovable.
# Proposal
## 1) Query options
Introduce three `QueryContext` extension keys:
- `flow.incremental_after_seqs`
- `flow.incremental_mode`
- `flow.return_region_seq`
These options are opt-in and only affect Flow incremental execution paths.
## 2) Scan mapping
When incremental mode is enabled:
- map `after_seq` to `memtable_min_sequence` (exclusive lower bound)
- keep existing snapshot upper-bound behavior (`memtable_max_sequence`)
Important limitation in v1:
- incremental filtering is correctness-proven only for memtable rows
- SST files do not preserve detailed row-level sequence metadata; they only expose coarser file-level sequence information
- therefore `seq > checkpoint` must not assume precise incremental pruning across memtable->SST flush boundaries
If required incremental parameters are missing or invalid, return an argument error.
## 3) Stale protection
Add a dedicated stale error:
- `IncrementalQueryStale { region_id, given_seq, min_readable_seq }`
Behavior:
- if `given_seq < min_readable_seq`, return stale error
- if `given_seq == min_readable_seq`, query is valid and reads `seq > given_seq`
- if `given_seq > min_readable_seq`, query is also valid and reads `seq > given_seq`
`IncrementalQueryStale` also covers the case where rows newer than the checkpoint have crossed a memtable->SST flush boundary and sequence-precise incremental exclusion can no longer be proven.
In other words, the flush-boundary case is not a separate fallback category in v1; it is one concrete way an incremental cursor becomes stale.
## 4) Watermark return
Extend query metrics with optional per-region watermark map:
- `region_latest_sequences: Vec<(region_id: u64, latest_sequence: u64)>`
Rules:
- only the terminal metrics of a successful query can advance checkpoints
- for a multi-region query, the watermark must be either a complete map or absent
- if correctness is unprovable, business rows may still be returned but the watermark is absent
## 5) Flow state machine
Checkpoint and watermark state are kept only in flownode memory in v1; they are not persisted as durable flow metadata.
Cold start or flownode restart therefore always re-enters through a full snapshot read.
Only after that full query succeeds with a complete correctness watermark may Flow switch back to incremental mode.
Flow starts in full mode, then transitions:
1. Full query succeeds with correctness watermark -> enter incremental mode
2. Incremental query succeeds with correctness watermark -> advance checkpoint
3. Incremental stale/failure -> fallback to full mode
4. Full query without correctness watermark -> remain in full mode
```mermaid
stateDiagram-v2
[*] --> FullSnapshot: Flow starts
state FullSnapshot {
[*] --> RunFull
RunFull --> RunFull: Full query succeeds but watermark is unprovable<br/>no region_latest_sequences returned
}
FullSnapshot --> Incremental: Full query succeeds and correctness watermark is returned<br/>(checkpoint updated)
state Incremental {
[*] --> RunInc
RunInc --> RunInc: Incremental succeeds<br/>(checkpoint advances)
}
Incremental --> FullSnapshot: IncrementalQueryStale<br/>(cursor too old, fallback required)
Incremental --> FullSnapshot: Incremental fails<br/> and fallback policy is triggered
FullSnapshot --> [*]: Flow stops
Incremental --> [*]: Flow stops
```
### Fallback Policy
Fallback to full mode is deterministic and is triggered by any of the following:
1. `IncrementalQueryStale` is returned.
2. Incremental query fails with execution errors.
3. Incremental query succeeds but watermark is absent or incomplete for participating regions.
Policy behavior:
1. Do not advance any checkpoint in the failed/incomplete round.
2. Switch to full mode for the affected flow/window in the next round.
3. Return to incremental mode only after a full query succeeds with a complete correctness watermark map.
### Persistence and recovery model
The v1 design is intentionally correctness-first and keeps the progress cursor lightweight:
1. Watermarks/checkpoints live only in flownode memory; v1 does not persist them separately.
2. On cold start, the flow re-establishes progress by running a successful full-query snapshot read, then resumes incremental mode only after that round returns a complete correctness watermark map.
3. Sequence-precise incremental correctness is currently limited to rows still visible in memtables.
4. Once relevant rows have been flushed into SST, the system cannot use `seq > checkpoint` alone to prove precise incremental exclusion, because SST lacks detailed row-level sequence metadata.
5. In that case the correct behavior is to fall back to full recomputation, not to continue a best-effort incremental scan.
# Distributed and Compatibility Requirements
1. Distributed path must preserve region-level snapshot/read-bound semantics end-to-end.
2. `snapshot_seqs` transport and `flow.*` options must both be carried correctly.
- `snapshot_seqs` means the per-region snapshot upper-bound map: `region_id -> sequence`.
3. New metrics fields must be backward-compatible (old clients ignore unknown fields).
# Rollout Plan
## Phase 1 (MVP, correctness first)
1. Add extension constants and parsing.
2. Add incremental scan mapping and stale detection.
3. Add watermark metrics field and terminal-watermark checkpoint update path.
4. Complete standalone and distributed passthrough.
## Phase 2 (performance and observability)
1. Improve batching key strategy with sequence/watermark context.
2. Optimize watermark serialization overhead.
3. Add metrics: incremental hit rate, fallback rate, fallback window size.
# Testing Plan
1. Unit tests for incremental bounds and stale detection.
2. Query-path tests for extension mapping and watermark semantics.
3. Flow integration tests for full->incremental->fallback transitions.
4. Distributed tests for end-to-end snapshot/watermark propagation.
5. Compatibility tests for old/new client-server combinations.
# Risks
1. Boundary semantic mismatch (`<` vs `<=`) may cause correctness bugs.
2. Incomplete distributed propagation can silently invalidate watermark safety.
3. Frequent fallback can reduce throughput before phase-2 optimizations.
4. Memtable->SST flushes may force more full recomputation than expected until finer-grained SST sequence tracking exists.
# Alternatives
1. Put watermark into business rows (rejected: schema pollution).
2. Add new dedicated Flight message type in v1 (deferred to reduce scope).
# Conclusion
This plan enables a practical, correctness-first incremental path for Flow batching.
It reuses existing sequence scan capability, adds strict stale handling, and advances checkpoints only from correctness-proven per-region watermarks.

View File

@@ -65,11 +65,13 @@ fn init_factory(
fn invalidator<'a>(
cache: &'a Cache<TableName, TableRef>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, MetaResult<()>> {
Box::pin(async move {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
for ident in idents {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
}
}
Ok(())
})

View File

@@ -267,7 +267,7 @@ impl InformationSchemaRegionPeersBuilder {
];
if !predicates.eval(&row) {
return;
continue;
}
self.table_catalogs.push(Some(table_catalog));

View File

@@ -151,7 +151,11 @@ impl DfTableSourceProvider {
let catalog_list = Arc::new(DummyCatalogList::new(self.catalog_manager.clone()));
let logical_plan = self
.plan_decoder
.decode(Bytes::from(view_info.view_info.clone()), catalog_list, true)
.decode(
Bytes::from(view_info.view_info.clone()),
catalog_list,
false,
)
.await
.context(DecodePlanSnafu {
name: &table.table_info().name,

View File

@@ -65,6 +65,8 @@ store-api.workspace = true
table.workspace = true
tokio.workspace = true
tracing-appender.workspace = true
url.workspace = true
uuid.workspace = true
[dev-dependencies]
common-meta = { workspace = true, features = ["testing"] }
@@ -72,4 +74,3 @@ common-test-util.workspace = true
common-version.workspace = true
serde.workspace = true
tempfile.workspace = true
url.workspace = true

View File

@@ -13,7 +13,12 @@
// limitations under the License.
mod export;
pub mod export_v2;
mod import;
pub mod import_v2;
pub(crate) mod path;
pub mod snapshot_storage;
pub(crate) mod sql;
mod storage_export;
use clap::Subcommand;
@@ -22,15 +27,24 @@ use common_error::ext::BoxedError;
use crate::Tool;
use crate::data::export::ExportCommand;
use crate::data::export_v2::ExportV2Command;
use crate::data::import::ImportCommand;
use crate::data::import_v2::ImportV2Command;
pub(crate) const COPY_PATH_PLACEHOLDER: &str = "<PATH/TO/FILES>";
/// Command for data operations including exporting data from and importing data into GreptimeDB.
#[derive(Subcommand)]
pub enum DataCommand {
/// Export data (V1 - legacy).
Export(ExportCommand),
/// Import data (V1 - legacy).
Import(ImportCommand),
/// Export V2 - JSON-based schema export with manifest support.
#[clap(subcommand)]
ExportV2(ExportV2Command),
/// Import V2 - Import from V2 snapshot.
ImportV2(ImportV2Command),
}
impl DataCommand {
@@ -38,6 +52,8 @@ impl DataCommand {
match self {
DataCommand::Export(cmd) => cmd.build().await,
DataCommand::Import(cmd) => cmd.build().await,
DataCommand::ExportV2(cmd) => cmd.build().await,
DataCommand::ImportV2(cmd) => cmd.build().await,
}
}
}

View File

@@ -107,13 +107,16 @@ pub struct ExportCommand {
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
/// The proxy server address to connect, if set, will override the system proxy.
/// The proxy server address to connect.
///
/// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set.
/// If set, it overrides the system proxy unless `--no-proxy` is specified.
/// If neither `--proxy` nor `--no-proxy` is set, system proxy (env) may be used.
#[clap(long)]
proxy: Option<String>,
/// Disable proxy server, if set, will not use any proxy.
/// Disable all proxy usage (ignores `--proxy` and system proxy).
///
/// When set and `--proxy` is not provided, this explicitly disables system proxy.
#[clap(long)]
no_proxy: bool,
@@ -173,6 +176,7 @@ impl ExportCommand {
// Treats `None` as `0s` to disable server-side default timeout.
self.timeout.unwrap_or_default(),
proxy,
self.no_proxy,
);
Ok(Box::new(Export {

View File

@@ -0,0 +1,49 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Export V2 module.
//!
//! This module provides the V2 implementation of database export functionality,
//! featuring:
//! - JSON-based schema export (version-agnostic)
//! - Manifest-based snapshot management
//! - Support for multiple storage backends (S3, OSS, GCS, Azure Blob, local FS)
//! - Resume capability for interrupted exports
//!
//! # Example
//!
//! ```bash
//! # Export schema only
//! greptime cli data export-v2 create \
//! --addr 127.0.0.1:4000 \
//! --to file:///tmp/snapshot \
//! --schema-only
//!
//! # Export with time range (M2)
//! greptime cli data export-v2 create \
//! --addr 127.0.0.1:4000 \
//! --to s3://bucket/snapshots/prod-20250101 \
//! --start-time 2025-01-01T00:00:00Z \
//! --end-time 2025-01-31T23:59:59Z
//! ```
mod command;
pub mod error;
pub mod extractor;
pub mod manifest;
pub mod schema;
pub use command::ExportV2Command;
#[cfg(test)]
mod tests;

View File

@@ -0,0 +1,496 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Export V2 CLI commands.
use std::collections::HashSet;
use std::time::Duration;
use async_trait::async_trait;
use clap::{Parser, Subcommand};
use common_error::ext::BoxedError;
use common_telemetry::info;
use serde_json::Value;
use snafu::{OptionExt, ResultExt};
use crate::Tool;
use crate::common::ObjectStoreConfig;
use crate::data::export_v2::error::{
CannotResumeSchemaOnlySnafu, DataExportNotImplementedSnafu, DatabaseSnafu, EmptyResultSnafu,
ManifestVersionMismatchSnafu, Result, UnexpectedValueTypeSnafu,
};
use crate::data::export_v2::extractor::SchemaExtractor;
use crate::data::export_v2::manifest::{DataFormat, MANIFEST_VERSION, Manifest};
use crate::data::path::ddl_path_for_schema;
use crate::data::snapshot_storage::{OpenDalStorage, SnapshotStorage, validate_uri};
use crate::data::sql::{escape_sql_identifier, escape_sql_literal};
use crate::database::{DatabaseClient, parse_proxy_opts};
/// Export V2 commands.
// NOTE: this enum derives clap's `Subcommand`, so the `///` comments on the
// enum and its variants are picked up as CLI help text. Treat them as
// user-facing strings when editing.
#[derive(Debug, Subcommand)]
pub enum ExportV2Command {
/// Create a new snapshot.
Create(ExportCreateCommand),
}
impl ExportV2Command {
    /// Builds the runnable [`Tool`] for whichever subcommand was selected.
    ///
    /// Each variant delegates to the `build` method of the command it wraps.
    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
        match self {
            Self::Create(create) => create.build().await,
        }
    }
}
/// Create a new snapshot.
// NOTE: this struct derives clap's `Parser`; the `///` comments below are the
// CLI help text and are therefore user-facing strings.
//
// How `build` currently consumes these arguments:
// - `to` is validated with `validate_uri` and then used to open the storage.
// - An invocation without `--schema-only` is rejected (data export is not
//   implemented yet).
// - An empty `schemas` list is turned into `None` ("all schemas").
// - `timeout` falls back to 60 seconds when omitted.
// - `format` and `parallelism` are only stored on the tool (as `_format` and
//   `_parallelism`); `start_time` and `end_time` are not read by `build`.
#[derive(Debug, Parser)]
pub struct ExportCreateCommand {
/// Server address to connect (e.g., 127.0.0.1:4000).
#[clap(long)]
addr: String,
/// Target storage location (e.g., s3://bucket/path, file:///tmp/backup).
#[clap(long)]
to: String,
/// Catalog name.
#[clap(long, default_value = "greptime")]
catalog: String,
/// Schema list to export (default: all non-system schemas).
/// Can be specified multiple times or comma-separated.
#[clap(long, value_delimiter = ',')]
schemas: Vec<String>,
/// Export schema only, no data.
#[clap(long)]
schema_only: bool,
/// Time range start (ISO 8601 format, e.g., 2024-01-01T00:00:00Z).
#[clap(long)]
start_time: Option<String>,
/// Time range end (ISO 8601 format, e.g., 2024-12-31T23:59:59Z).
#[clap(long)]
end_time: Option<String>,
/// Data format: parquet, csv, json.
#[clap(long, value_enum, default_value = "parquet")]
format: DataFormat,
/// Delete existing snapshot and recreate.
#[clap(long)]
force: bool,
/// Concurrency level (for future use).
#[clap(long, default_value = "1")]
parallelism: usize,
/// Basic authentication (user:password).
#[clap(long)]
auth_basic: Option<String>,
/// Request timeout.
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
/// Proxy server address.
///
/// If set, it overrides the system proxy unless `--no-proxy` is specified.
/// If neither `--proxy` nor `--no-proxy` is set, system proxy (env) may be used.
#[clap(long)]
proxy: Option<String>,
/// Disable all proxy usage (ignores `--proxy` and system proxy).
///
/// When set and `--proxy` is not provided, this explicitly disables system proxy.
#[clap(long)]
no_proxy: bool,
/// Object store configuration for remote storage backends.
#[clap(flatten)]
storage: ObjectStoreConfig,
}
impl ExportCreateCommand {
    /// Validates the CLI arguments and assembles the [`ExportCreate`] tool.
    ///
    /// Steps, in order:
    /// 1. Validate the `--to` URI.
    /// 2. Reject any invocation without `--schema-only`, since data export is
    ///    not implemented yet.
    /// 3. Open the snapshot storage for the target URI.
    /// 4. Build the database client; the request timeout defaults to 60 seconds
    ///    when `--timeout` is not given.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] if URI validation fails, data export is
    /// requested, the storage cannot be created, or the proxy options are
    /// rejected by `parse_proxy_opts`.
    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
        validate_uri(&self.to).map_err(BoxedError::new)?;

        // Only schema-only snapshots are supported for now.
        if !self.schema_only {
            return DataExportNotImplementedSnafu
                .fail()
                .map_err(BoxedError::new);
        }

        // An empty list means "all schemas", which is represented as `None`.
        let schemas = (!self.schemas.is_empty()).then(|| self.schemas.clone());

        let storage = OpenDalStorage::from_uri(&self.to, &self.storage).map_err(BoxedError::new)?;

        let request_timeout = self.timeout.unwrap_or(Duration::from_secs(60));
        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
        let database_client = DatabaseClient::new(
            self.addr.clone(),
            self.catalog.clone(),
            self.auth_basic.clone(),
            request_timeout,
            proxy,
            self.no_proxy,
        );

        let tool = ExportCreate {
            catalog: self.catalog.clone(),
            schemas,
            schema_only: self.schema_only,
            _format: self.format,
            force: self.force,
            _parallelism: self.parallelism,
            storage: Box::new(storage),
            database_client,
        };
        Ok(Box::new(tool))
    }
}
/// Export tool implementation.
///
/// Built by `ExportCreateCommand::build` and executed through the `Tool` trait.
pub struct ExportCreate {
/// Catalog whose schemas are exported.
catalog: String,
/// Schemas to export; `None` means no explicit list was given on the CLI.
schemas: Option<Vec<String>>,
/// Whether this run is a schema-only export; checked against the existing
/// manifest when resuming.
schema_only: bool,
/// Requested data format. Underscore-prefixed: not read in the code visible here.
_format: DataFormat,
/// When set, an existing snapshot is deleted before exporting.
force: bool,
/// Requested parallelism. Underscore-prefixed: not read in the code visible here.
_parallelism: usize,
/// Destination the snapshot (schema, DDL text files, manifest) is written to.
storage: Box<dyn SnapshotStorage>,
/// Client used to query the source database.
database_client: DatabaseClient,
}
#[async_trait]
impl Tool for ExportCreate {
    /// Runs the export and converts any export error into a [`BoxedError`].
    async fn do_work(&self) -> std::result::Result<(), BoxedError> {
        let outcome = self.run().await;
        outcome.map_err(BoxedError::new)
    }
}
impl ExportCreate {
/// Creates the snapshot, or inspects an existing one.
///
/// If a snapshot already exists:
/// - with `force`, it is deleted and the export starts from scratch;
/// - otherwise its manifest is read and validated (version, schema-only
///   compatibility), its progress is logged, and the function returns without
///   exporting anything, because chunk resume is not implemented yet.
///
/// For a fresh export the schema index is written first, then one DDL file per
/// schema, and the manifest last.
async fn run(&self) -> Result<()> {
    if self.storage.exists().await? {
        if !self.force {
            // Resume mode: validate the existing manifest, then stop.
            let existing = self.storage.read_manifest().await?;

            if existing.version != MANIFEST_VERSION {
                return ManifestVersionMismatchSnafu {
                    expected: MANIFEST_VERSION,
                    found: existing.version,
                }
                .fail();
            }

            // A schema-only snapshot cannot be continued as a data export.
            if existing.schema_only && !self.schema_only {
                return CannotResumeSchemaOnlySnafu.fail();
            }

            info!(
                "Resuming existing snapshot: {} (completed: {}/{} chunks)",
                existing.snapshot_id,
                existing.completed_count(),
                existing.chunks.len()
            );

            // M1 only handles schema-only exports; M2 will add chunk resume.
            if existing.is_complete() {
                info!("Snapshot is already complete");
            } else {
                // TODO: Resume data export in M2
                info!("Data export resume not yet implemented (M2)");
            }
            return Ok(());
        }

        info!("Deleting existing snapshot (--force)");
        self.storage.delete_snapshot().await?;
    }

    // Extract the schema definitions from the source database.
    let schema_snapshot = SchemaExtractor::new(&self.database_client, &self.catalog)
        .extract(self.schemas.as_deref())
        .await?;

    let mut schema_names: Vec<String> = Vec::with_capacity(schema_snapshot.schemas.len());
    for entry in &schema_snapshot.schemas {
        schema_names.push(entry.name.clone());
    }
    info!("Exporting schemas: {:?}", schema_names);

    // Build the manifest now, but do not persist it until everything else is written.
    let manifest = Manifest::new_schema_only(self.catalog.clone(), schema_names.clone());

    self.storage.write_schema(&schema_snapshot).await?;
    info!("Exported {} schemas", schema_snapshot.schemas.len());

    // DDL files are exported for import recovery.
    for (schema, ddl) in self.build_ddl_by_schema(&schema_names).await? {
        let ddl_path = ddl_path_for_schema(&schema);
        self.storage.write_text(&ddl_path, &ddl).await?;
        info!("Exported DDL for schema {} to {}", schema, ddl_path);
    }

    // The manifest is the snapshot commit point: it is written only after the
    // schema index and all DDL files, so a crash cannot leave a "valid"
    // snapshot that is missing required schema artifacts.
    self.storage.write_manifest(&manifest).await?;
    info!("Snapshot created: {}", manifest.snapshot_id);

    Ok(())
}
/// Produces one DDL script per schema, returned as `(schema, ddl)` pairs in
/// ascending schema-name order.
///
/// For each schema the statements are fetched in this order: `CREATE DATABASE`,
/// then the tables returned as "physical" by `get_schema_objects`, then the
/// remaining tables, then views. Each group is sorted by name before fetching.
async fn build_ddl_by_schema(&self, schema_names: &[String]) -> Result<Vec<(String, String)>> {
    let mut ordered_schemas = schema_names.to_vec();
    ordered_schemas.sort();

    let mut scripts = Vec::with_capacity(ordered_schemas.len());
    for schema in ordered_schemas {
        let create_database = self.show_create("DATABASE", &schema, None).await?;

        let (mut physical_tables, mut tables, mut views) =
            self.get_schema_objects(&schema).await?;
        // Sorting is side-effect free, so all three groups are ordered up front.
        physical_tables.sort();
        tables.sort();
        views.sort();

        let mut physical_ddls = Vec::with_capacity(physical_tables.len());
        for physical in physical_tables {
            let stmt = self.show_create("TABLE", &schema, Some(&physical)).await?;
            physical_ddls.push(stmt);
        }

        let mut table_ddls = Vec::with_capacity(tables.len());
        for table in tables {
            let stmt = self.show_create("TABLE", &schema, Some(&table)).await?;
            table_ddls.push(stmt);
        }

        let mut view_ddls = Vec::with_capacity(views.len());
        for view in views {
            let stmt = self.show_create("VIEW", &schema, Some(&view)).await?;
            view_ddls.push(stmt);
        }

        let script = build_schema_ddl(
            &schema,
            create_database,
            physical_ddls,
            table_ddls,
            view_ddls,
        );
        scripts.push((schema, script));
    }

    Ok(scripts)
}
/// Lists the objects of `schema`, split into `(physical tables, tables, views)`.
///
/// The physical tables come from `get_metric_physical_tables`. Every `BASE TABLE` or
/// `VIEW` row of `information_schema.tables` whose name is not among them goes into the
/// views list when its type is `VIEW` and into the tables list otherwise.
///
/// Fails with `UnexpectedValueType` when a row does not start with two string columns,
/// and with `Database` when a query fails.
async fn get_schema_objects(
    &self,
    schema: &str,
) -> Result<(Vec<String>, Vec<String>, Vec<String>)> {
    let physical_tables = self.get_metric_physical_tables(schema).await?;

    let sql = format!(
        "SELECT table_name, table_type FROM information_schema.tables \
         WHERE table_catalog = '{}' AND table_schema = '{}' \
         AND (table_type = 'BASE TABLE' OR table_type = 'VIEW')",
        escape_sql_literal(&self.catalog),
        escape_sql_literal(schema)
    );
    let records: Option<Vec<Vec<Value>>> = self
        .database_client
        .sql_in_public(&sql)
        .await
        .context(DatabaseSnafu)?;

    // Borrowed lookup set; it is dropped before `physical_tables` is moved into the result.
    let physical_names: HashSet<&str> = physical_tables.iter().map(String::as_str).collect();

    let mut tables = Vec::new();
    let mut views = Vec::new();
    for row in records.unwrap_or_default() {
        let (Some(Value::String(name)), Some(Value::String(kind))) = (row.first(), row.get(1))
        else {
            return UnexpectedValueTypeSnafu.fail();
        };
        if physical_names.contains(name.as_str()) {
            continue;
        }
        if kind.as_str() == "VIEW" {
            views.push(name.clone());
        } else {
            tables.push(name.clone());
        }
    }

    Ok((physical_tables, tables, views))
}
/// Returns the names of the tables in `schema` that have a `__tsid` column.
///
/// Callers in this file treat these as the metric physical tables. Names are
/// de-duplicated through a `HashSet`, so the order of the returned vector is unspecified.
///
/// Fails with `UnexpectedValueType` when the first column of a row is not a string, and
/// with `Database` when the query fails.
async fn get_metric_physical_tables(&self, schema: &str) -> Result<Vec<String>> {
    let sql = format!(
        "SELECT DISTINCT table_name FROM information_schema.columns \
         WHERE table_catalog = '{}' AND table_schema = '{}' AND column_name = '__tsid'",
        escape_sql_literal(&self.catalog),
        escape_sql_literal(schema)
    );
    let records: Option<Vec<Vec<Value>>> = self
        .database_client
        .sql_in_public(&sql)
        .await
        .context(DatabaseSnafu)?;

    let mut unique_names = HashSet::new();
    // A missing result set is treated the same as an empty one.
    for row in records.into_iter().flatten() {
        match row.into_iter().next() {
            Some(Value::String(name)) => {
                unique_names.insert(name);
            }
            _ => return UnexpectedValueTypeSnafu.fail(),
        }
    }

    Ok(unique_names.into_iter().collect())
}
/// Runs `SHOW CREATE <show_type>` for a schema or for an object inside it.
///
/// With `table` set the statement targets `"catalog"."schema"."table"`, otherwise
/// `"catalog"."schema"`. Every identifier goes through `escape_sql_identifier`.
///
/// Returns the second column of the first result row followed by `";\n"`.
/// Fails with `EmptyResult` when no rows come back and with `UnexpectedValueType` when
/// that column is not a string.
async fn show_create(
    &self,
    show_type: &str,
    schema: &str,
    table: Option<&str>,
) -> Result<String> {
    let catalog = escape_sql_identifier(&self.catalog);
    let schema = escape_sql_identifier(schema);
    let sql = if let Some(table) = table {
        format!(
            r#"SHOW CREATE {} "{}"."{}"."{}""#,
            show_type,
            catalog,
            schema,
            escape_sql_identifier(table)
        )
    } else {
        format!(r#"SHOW CREATE {} "{}"."{}""#, show_type, catalog, schema)
    };

    let records: Option<Vec<Vec<Value>>> = self
        .database_client
        .sql_in_public(&sql)
        .await
        .context(DatabaseSnafu)?;
    let rows = records.context(EmptyResultSnafu)?;
    let first_row = rows.first().context(EmptyResultSnafu)?;

    // Column 0 holds the object name; the CREATE statement is in column 1.
    match first_row.get(1) {
        Some(Value::String(statement)) => Ok(format!("{};\n", statement)),
        _ => UnexpectedValueTypeSnafu.fail(),
    }
}
}
/// Assembles the DDL script of one schema.
///
/// The script starts with a `-- Schema: <name>` comment line and `create_database`,
/// followed by the statements of `physical_tables`, `tables` and `views` in that order,
/// and ends with one extra newline. Statements are appended verbatim, so each one is
/// expected to carry its own terminator and trailing newline.
fn build_schema_ddl(
    schema: &str,
    create_database: String,
    physical_tables: Vec<String>,
    tables: Vec<String>,
    views: Vec<String>,
) -> String {
    let mut script = format!("-- Schema: {}\n", schema);
    script.push_str(&create_database);
    // Physical tables must precede the tables and views that may depend on them.
    script.extend(physical_tables.into_iter().chain(tables).chain(views));
    script.push('\n');
    script
}
// Unit tests for DDL path encoding, DDL statement ordering and command validation.
#[cfg(test)]
mod tests {
use clap::Parser;
use super::*;
use crate::data::path::ddl_path_for_schema;
// A plain schema name maps to `schema/ddl/<name>.sql`, while `.` and `/` in the name
// come back percent-encoded (`../evil` becomes `%2E%2E%2Fevil`).
#[test]
fn test_ddl_path_for_schema() {
assert_eq!(ddl_path_for_schema("public"), "schema/ddl/public.sql");
assert_eq!(
ddl_path_for_schema("../evil"),
"schema/ddl/%2E%2E%2Fevil.sql"
);
}
// The generated script must list the database first, then physical tables, then the
// remaining tables, then views.
#[test]
fn test_build_schema_ddl_order() {
let ddl = build_schema_ddl(
"public",
"CREATE DATABASE public;\n".to_string(),
vec!["PHYSICAL;\n".to_string()],
vec!["TABLE;\n".to_string()],
vec!["VIEW;\n".to_string()],
);
let db_pos = ddl.find("CREATE DATABASE").unwrap();
let physical_pos = ddl.find("PHYSICAL;").unwrap();
let table_pos = ddl.find("TABLE;").unwrap();
let view_pos = ddl.find("VIEW;").unwrap();
assert!(db_pos < physical_pos);
assert!(physical_pos < table_pos);
assert!(table_pos < view_pos);
}
// Building the command without `--schema-only` must fail with the
// "Data export is not implemented yet" error.
#[tokio::test]
async fn test_build_rejects_non_schema_only_export() {
let cmd = ExportCreateCommand::parse_from([
"export-v2-create",
"--addr",
"127.0.0.1:4000",
"--to",
"file:///tmp/export-v2-test",
]);
let result = cmd.build().await;
assert!(result.is_err());
let error = result.err().unwrap().to_string();
assert!(error.contains("Data export is not implemented yet"));
}
}

View File

@@ -0,0 +1,181 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use snafu::{Location, Snafu};
/// Errors raised by the export V2 code.
///
/// Every variant records the source `location` implicitly; variants that wrap another
/// error keep it in the `error` field.
#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
/// A URI was rejected; `reason` explains why.
#[snafu(display("Invalid URI '{}': {}", uri, reason))]
InvalidUri {
uri: String,
reason: String,
#[snafu(implicit)]
location: Location,
},
/// The URI scheme does not correspond to a supported storage backend.
#[snafu(display("Unsupported storage scheme: {}", scheme))]
UnsupportedScheme {
scheme: String,
#[snafu(implicit)]
location: Location,
},
/// An object-store call failed; `operation` names the attempted action.
#[snafu(display("Storage operation '{}' failed", operation))]
StorageOperation {
operation: String,
#[snafu(source)]
error: object_store::Error,
#[snafu(implicit)]
location: Location,
},
/// The manifest could not be parsed from JSON.
#[snafu(display("Failed to parse manifest"))]
ManifestParse {
#[snafu(source)]
error: serde_json::Error,
#[snafu(implicit)]
location: Location,
},
/// The manifest could not be serialized to JSON.
#[snafu(display("Failed to serialize manifest"))]
ManifestSerialize {
#[snafu(source)]
error: serde_json::Error,
#[snafu(implicit)]
location: Location,
},
/// A text file did not contain valid UTF-8.
#[snafu(display("Failed to decode text file as UTF-8"))]
TextDecode {
#[snafu(source)]
error: std::string::FromUtf8Error,
#[snafu(implicit)]
location: Location,
},
/// An existing schema-only snapshot was asked to resume as a data export.
#[snafu(display(
"Cannot resume schema-only snapshot with data export. Use --force to recreate."
))]
CannotResumeSchemaOnly {
#[snafu(implicit)]
location: Location,
},
/// Data export was requested, but only schema-only snapshots are implemented.
#[snafu(display(
"Data export is not implemented yet. Use --schema-only to create a schema snapshot."
))]
DataExportNotImplemented {
#[snafu(implicit)]
location: Location,
},
/// A query returned no result set or no rows where at least one was required.
#[snafu(display("Empty result from query"))]
EmptyResult {
#[snafu(implicit)]
location: Location,
},
/// A query result cell did not have the expected JSON value type.
#[snafu(display("Unexpected value type in query result"))]
UnexpectedValueType {
#[snafu(implicit)]
location: Location,
},
/// A database client call failed; its status code is forwarded by `status_code`.
#[snafu(display("Database error"))]
Database {
#[snafu(source)]
error: crate::error::Error,
#[snafu(implicit)]
location: Location,
},
/// No snapshot exists at the given URI.
#[snafu(display("Snapshot not found at '{}'", uri))]
SnapshotNotFound {
uri: String,
#[snafu(implicit)]
location: Location,
},
/// The named schema does not exist in the catalog.
#[snafu(display("Schema '{}' not found in catalog '{}'", schema, catalog))]
SchemaNotFound {
catalog: String,
schema: String,
#[snafu(implicit)]
location: Location,
},
/// A string could not be parsed as a URL.
#[snafu(display("Failed to parse URL"))]
UrlParse {
#[snafu(source)]
error: url::ParseError,
#[snafu(implicit)]
location: Location,
},
/// The object store backend could not be constructed.
#[snafu(display("Failed to build object store"))]
BuildObjectStore {
#[snafu(source)]
error: object_store::Error,
#[snafu(implicit)]
location: Location,
},
/// The manifest's `version` differs from the one this build expects.
#[snafu(display("Manifest version mismatch: expected {}, found {}", expected, found))]
ManifestVersionMismatch {
expected: u32,
found: u32,
#[snafu(implicit)]
location: Location,
},
}
/// Result alias whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidUri { .. }
| Error::UnsupportedScheme { .. }
| Error::CannotResumeSchemaOnly { .. }
| Error::DataExportNotImplemented { .. }
| Error::ManifestVersionMismatch { .. } => StatusCode::InvalidArguments,
Error::StorageOperation { .. }
| Error::ManifestParse { .. }
| Error::ManifestSerialize { .. }
| Error::TextDecode { .. }
| Error::BuildObjectStore { .. } => StatusCode::StorageUnavailable,
Error::EmptyResult { .. }
| Error::UnexpectedValueType { .. }
| Error::UrlParse { .. } => StatusCode::Internal,
Error::Database { error, .. } => error.status_code(),
Error::SnapshotNotFound { .. } => StatusCode::InvalidArguments,
Error::SchemaNotFound { .. } => StatusCode::DatabaseNotFound,
}
}
fn as_any(&self) -> &dyn Any {
self
}
}

View File

@@ -0,0 +1,254 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Schema extraction from information_schema.
//!
//! For V2 DDL-only snapshots, extractor only persists the schema index.
use std::collections::{HashMap, HashSet};
use serde_json::Value;
use snafu::ResultExt;
use crate::data::export_v2::error::{
DatabaseSnafu, EmptyResultSnafu, Result, SchemaNotFoundSnafu, UnexpectedValueTypeSnafu,
};
use crate::data::export_v2::schema::{SchemaDefinition, SchemaSnapshot};
use crate::data::sql::escape_sql_literal;
use crate::database::DatabaseClient;
/// System schemas that should be excluded from export.
///
/// Names are compared exactly (case-sensitively) when listing the catalog's schemas.
const SYSTEM_SCHEMAS: &[&str] = &["information_schema", "pg_catalog"];
/// Extracts schema definitions from information_schema.
///
/// Borrows both the database client and the catalog name for the lifetime `'a`.
pub struct SchemaExtractor<'a> {
// Client used to run every query.
client: &'a DatabaseClient,
// Catalog whose schemas are inspected.
catalog: &'a str,
}
impl<'a> SchemaExtractor<'a> {
    /// Creates an extractor that queries `client` for schemas of `catalog`.
    pub fn new(client: &'a DatabaseClient, catalog: &'a str) -> Self {
        Self { client, catalog }
    }

    /// Builds the schema index for the requested schemas.
    ///
    /// When `schemas` is `None` every non-system schema of the catalog is included;
    /// otherwise the given names are validated and canonicalized first.
    pub async fn extract(&self, schemas: Option<&[String]>) -> Result<SchemaSnapshot> {
        let schema_names = if let Some(names) = schemas {
            self.validate_schemas(names).await?
        } else {
            self.get_all_schemas().await?
        };

        let mut snapshot = SchemaSnapshot::new();
        for schema_name in &schema_names {
            snapshot.add_schema(self.extract_schema_definition(schema_name).await?);
        }
        Ok(snapshot)
    }

    /// Lists the catalog's schemas, leaving out those named in `SYSTEM_SCHEMAS`.
    async fn get_all_schemas(&self) -> Result<Vec<String>> {
        let sql = format!(
            "SELECT schema_name FROM information_schema.schemata \
             WHERE catalog_name = '{}'",
            escape_sql_literal(self.catalog)
        );

        let mut schemas = Vec::new();
        for row in self.query(&sql).await? {
            let name = extract_string(&row, 0)?;
            if SYSTEM_SCHEMAS.contains(&name.as_str()) {
                continue;
            }
            schemas.push(name);
        }
        Ok(schemas)
    }

    /// Checks that every requested schema exists and returns the canonical,
    /// de-duplicated names.
    async fn validate_schemas(&self, schemas: &[String]) -> Result<Vec<String>> {
        let available = self.get_all_schemas().await?;
        dedupe_canonicalized_schemas(schemas, &available, self.catalog)
    }

    /// Loads the name and options of a single schema (database).
    ///
    /// Fails with `SchemaNotFound` when the query returns no rows.
    async fn extract_schema_definition(&self, schema: &str) -> Result<SchemaDefinition> {
        let sql = format!(
            "SELECT schema_name, options FROM information_schema.schemata \
             WHERE catalog_name = '{}' AND schema_name = '{}'",
            escape_sql_literal(self.catalog),
            escape_sql_literal(schema)
        );
        let records = self.query(&sql).await?;

        let Some(row) = records.first() else {
            return SchemaNotFoundSnafu {
                catalog: self.catalog,
                schema,
            }
            .fail();
        };

        let name = extract_string(row, 0)?;
        // A missing or empty options column means "no options".
        let options = match extract_optional_string(row, 1) {
            Some(raw) => parse_options(&raw),
            None => HashMap::new(),
        };

        Ok(SchemaDefinition {
            catalog: self.catalog.to_string(),
            name,
            options,
        })
    }

    /// Runs `sql` and returns its rows; a missing result set is an `EmptyResult` error.
    async fn query(&self, sql: &str) -> Result<Vec<Vec<Value>>> {
        match self
            .client
            .sql_in_public(sql)
            .await
            .context(DatabaseSnafu)?
        {
            Some(rows) => Ok(rows),
            None => EmptyResultSnafu.fail(),
        }
    }
}
/// Returns a copy of the string stored at `index` in `row`.
///
/// A missing column, a JSON `null` and any other non-string value are all rejected with
/// `UnexpectedValueType`.
fn extract_string(row: &[Value], index: usize) -> Result<String> {
    if let Some(Value::String(text)) = row.get(index) {
        Ok(text.clone())
    } else {
        UnexpectedValueTypeSnafu.fail()
    }
}
/// Returns a copy of the string at `index` in `row`, if there is a non-empty one.
///
/// A missing column, a non-string value and an empty string all yield `None`.
fn extract_optional_string(row: &[Value], index: usize) -> Option<String> {
    row.get(index)
        .and_then(Value::as_str)
        .filter(|text| !text.is_empty())
        .map(str::to_owned)
}
/// Parses a schema options string into a key/value map.
///
/// Three layouts are understood, tried in this order:
/// 1. a JSON object with string values, e.g. `{"ttl": "30d"}`;
/// 2. one `'key'='value'` pair per line (values may contain spaces);
/// 3. whitespace-separated `key=value` tokens on a line.
///
/// Blank lines and tokens without `=` are ignored. A later occurrence of a key replaces
/// an earlier one.
fn parse_options(options_str: &str) -> HashMap<String, String> {
    if let Ok(map) = serde_json::from_str::<HashMap<String, String>>(options_str) {
        return map;
    }

    let mut options = HashMap::new();
    let lines = options_str
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty());
    for line in lines {
        match parse_quoted_option_line(line) {
            Some((key, value)) => {
                options.insert(key, value);
            }
            // Not a quoted pair: fall back to bare `key=value` tokens.
            None => options.extend(
                line.split_whitespace()
                    .filter_map(|token| token.split_once('='))
                    .map(|(key, value)| (key.to_string(), value.to_string())),
            ),
        }
    }
    options
}
/// Parses a single `'key'='value'` line into an owned `(key, value)` pair.
///
/// The line must start with `'`, contain the separator `'='` and end with `'`; the key
/// ends at the first separator. Returns `None` for any other shape. Quote escaping is
/// not interpreted.
fn parse_quoted_option_line(line: &str) -> Option<(String, String)> {
    line.strip_prefix('\'')
        .and_then(|rest| rest.split_once("'='"))
        .and_then(|(key, quoted_tail)| {
            quoted_tail
                .strip_suffix('\'')
                .map(|value| (key.to_string(), value.to_string()))
        })
}
/// Maps requested schema names onto their canonical spelling and drops duplicates.
///
/// Each name in `requested` is matched ASCII-case-insensitively against `available`;
/// the first match supplies the canonical name. The result keeps the order of first
/// appearance in `requested`.
///
/// Fails with `SchemaNotFound` for the first requested name that has no match.
fn dedupe_canonicalized_schemas(
    requested: &[String],
    available: &[String],
    catalog: &str,
) -> Result<Vec<String>> {
    let mut seen = HashSet::new();
    let mut canonicalized = Vec::new();

    for schema in requested {
        match available
            .iter()
            .find(|candidate| candidate.eq_ignore_ascii_case(schema))
        {
            Some(canonical) => {
                // `insert` returns false when the lowercase name was already recorded.
                if seen.insert(canonical.to_ascii_lowercase()) {
                    canonicalized.push(canonical.clone());
                }
            }
            None => return SchemaNotFoundSnafu { catalog, schema }.fail(),
        }
    }

    Ok(canonicalized)
}
// Unit tests for option parsing, row value extraction and schema name canonicalization.
#[cfg(test)]
mod tests {
use serde_json::Value;
use super::*;
// Options given as a JSON object with string values.
#[test]
fn test_parse_options_json() {
let opts = r#"{"ttl": "30d", "custom": "value"}"#;
let parsed = parse_options(opts);
assert_eq!(parsed.get("ttl"), Some(&"30d".to_string()));
assert_eq!(parsed.get("custom"), Some(&"value".to_string()));
}
// Options given as whitespace-separated `key=value` tokens.
#[test]
fn test_parse_options_key_value() {
let opts = "ttl=30d custom=value";
let parsed = parse_options(opts);
assert_eq!(parsed.get("ttl"), Some(&"30d".to_string()));
assert_eq!(parsed.get("custom"), Some(&"value".to_string()));
}
// Options given as one `'key'='value'` pair per line; values may contain spaces.
#[test]
fn test_parse_options_schema_display_format() {
let opts = "'ttl'='30d'\n'custom'='value with spaces'\n";
let parsed = parse_options(opts);
assert_eq!(parsed.get("ttl"), Some(&"30d".to_string()));
assert_eq!(parsed.get("custom"), Some(&"value with spaces".to_string()));
}
// A JSON null where a string is required is an error, not an empty string.
#[test]
fn test_extract_string_rejects_null() {
let row = vec![Value::Null];
assert!(extract_string(&row, 0).is_err());
}
// Requested names are matched case-insensitively, replaced by the available spelling
// and de-duplicated while keeping first-seen order.
#[test]
fn test_dedupe_canonicalized_schemas() {
let available = vec!["public".to_string(), "test_db".to_string()];
let requested = vec![
"PUBLIC".to_string(),
"public".to_string(),
"Test_Db".to_string(),
];
let canonicalized = dedupe_canonicalized_schemas(&requested, &available, "greptime")
.expect("schemas should be canonicalized");
assert_eq!(canonicalized, vec!["public", "test_db"]);
}
}

View File

@@ -0,0 +1,381 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Manifest data structures for Export/Import V2.
use std::{fmt, str};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Current manifest format version.
///
/// Stored in `Manifest::version` by both manifest constructors.
pub const MANIFEST_VERSION: u32 = 1;
/// Manifest file name within snapshot directory.
pub const MANIFEST_FILE: &str = "manifest.json";
/// Time range for data export (half-open interval: [start, end)).
///
/// Unset bounds are omitted when serializing, so an unbounded range serializes to `{}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TimeRange {
/// Start time (inclusive). None means earliest available data.
#[serde(skip_serializing_if = "Option::is_none")]
pub start: Option<DateTime<Utc>>,
/// End time (exclusive). None means current time.
#[serde(skip_serializing_if = "Option::is_none")]
pub end: Option<DateTime<Utc>>,
}
impl TimeRange {
/// Creates a new time range with specified bounds.
pub fn new(start: Option<DateTime<Utc>>, end: Option<DateTime<Utc>>) -> Self {
Self { start, end }
}
/// Creates an unbounded time range (all data).
pub fn unbounded() -> Self {
Self {
start: None,
end: None,
}
}
/// Returns true if this time range is unbounded.
pub fn is_unbounded(&self) -> bool {
self.start.is_none() && self.end.is_none()
}
}
impl Default for TimeRange {
fn default() -> Self {
Self::unbounded()
}
}
/// Status of a chunk during export/import.
///
/// Serialized in snake_case (for example `in_progress`). The default is `Pending`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum ChunkStatus {
/// Chunk is pending export.
#[default]
Pending,
/// Chunk export is in progress.
InProgress,
/// Chunk export completed successfully.
Completed,
/// Chunk export failed.
Failed,
}
/// Metadata for a single chunk of exported data.
///
/// `files` defaults to an empty list when absent from the input; `checksum` and `error`
/// are omitted from the output while unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkMeta {
/// Chunk identifier (sequential number starting from 1).
pub id: u32,
/// Time range covered by this chunk.
pub time_range: TimeRange,
/// Export status.
pub status: ChunkStatus,
/// List of data files in this chunk (relative paths from snapshot root).
#[serde(default)]
pub files: Vec<String>,
/// SHA256 checksum of all files in this chunk (aggregated).
#[serde(skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Error message if status is Failed.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
}
impl ChunkMeta {
/// Creates a new pending chunk with the given id and time range.
pub fn new(id: u32, time_range: TimeRange) -> Self {
Self {
id,
time_range,
status: ChunkStatus::Pending,
files: vec![],
checksum: None,
error: None,
}
}
/// Marks this chunk as in progress.
pub fn mark_in_progress(&mut self) {
self.status = ChunkStatus::InProgress;
self.error = None;
}
/// Marks this chunk as completed with the given files and checksum.
pub fn mark_completed(&mut self, files: Vec<String>, checksum: Option<String>) {
self.status = ChunkStatus::Completed;
self.files = files;
self.checksum = checksum;
self.error = None;
}
/// Marks this chunk as failed with the given error message.
pub fn mark_failed(&mut self, error: String) {
self.status = ChunkStatus::Failed;
self.error = Some(error);
}
}
/// Supported data formats for export.
///
/// Both the serde representation and the clap value names are the lowercase variant
/// names. The default is `Parquet`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default, clap::ValueEnum)]
#[serde(rename_all = "lowercase")]
#[value(rename_all = "lowercase")]
pub enum DataFormat {
/// Apache Parquet format (default, recommended for production).
#[default]
Parquet,
/// CSV format (human-readable).
Csv,
/// JSON format (structured text).
Json,
}
impl fmt::Display for DataFormat {
    /// Writes the lowercase format name: `parquet`, `csv` or `json`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DataFormat::Parquet => "parquet",
            DataFormat::Csv => "csv",
            DataFormat::Json => "json",
        };
        f.write_str(name)
    }
}
impl str::FromStr for DataFormat {
    type Err = String;

    /// Parses a format name, ignoring letter case.
    ///
    /// Accepts `parquet`, `csv` and `json`; anything else yields an error message that
    /// quotes the input as given.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let format = match normalized.as_str() {
            "parquet" => DataFormat::Parquet,
            "csv" => DataFormat::Csv,
            "json" => DataFormat::Json,
            _ => {
                return Err(format!(
                    "invalid format '{}': expected one of parquet, csv, json",
                    s
                ));
            }
        };
        Ok(format)
    }
}
/// Snapshot manifest containing all metadata.
///
/// The manifest is stored as `manifest.json` in the snapshot root directory.
/// It contains:
/// - Snapshot identification (UUID, timestamps)
/// - Scope (catalog, schemas, time range)
/// - Export configuration (format, schema_only)
/// - Chunk metadata for resume support
/// - Integrity checksums
///
/// When deserializing, a missing `chunks` field becomes an empty list. When serializing,
/// `checksum` is omitted while unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
/// Manifest format version for compatibility checking.
pub version: u32,
/// Unique snapshot identifier.
pub snapshot_id: Uuid,
/// Catalog name.
pub catalog: String,
/// List of schemas included in this snapshot.
pub schemas: Vec<String>,
/// Overall time range covered by this snapshot.
pub time_range: TimeRange,
/// Whether this is a schema-only snapshot (no data).
pub schema_only: bool,
/// Data format used for export.
pub format: DataFormat,
/// Chunk metadata (empty for schema-only snapshots).
#[serde(default)]
pub chunks: Vec<ChunkMeta>,
/// Snapshot-level SHA256 checksum (aggregated from all chunks).
#[serde(skip_serializing_if = "Option::is_none")]
pub checksum: Option<String>,
/// Creation timestamp.
pub created_at: DateTime<Utc>,
/// Last updated timestamp.
pub updated_at: DateTime<Utc>,
}
impl Manifest {
/// Creates a new manifest for schema-only export.
pub fn new_schema_only(catalog: String, schemas: Vec<String>) -> Self {
let now = Utc::now();
Self {
version: MANIFEST_VERSION,
snapshot_id: Uuid::new_v4(),
catalog,
schemas,
time_range: TimeRange::unbounded(),
schema_only: true,
format: DataFormat::Parquet,
chunks: vec![],
checksum: None,
created_at: now,
updated_at: now,
}
}
/// Creates a new manifest for full export with time range and format.
pub fn new_full(
catalog: String,
schemas: Vec<String>,
time_range: TimeRange,
format: DataFormat,
) -> Self {
let now = Utc::now();
Self {
version: MANIFEST_VERSION,
snapshot_id: Uuid::new_v4(),
catalog,
schemas,
time_range,
schema_only: false,
format,
chunks: vec![],
checksum: None,
created_at: now,
updated_at: now,
}
}
/// Returns true if all chunks are completed (or if schema-only).
pub fn is_complete(&self) -> bool {
self.schema_only
|| (!self.chunks.is_empty()
&& self
.chunks
.iter()
.all(|c| c.status == ChunkStatus::Completed))
}
/// Returns the number of pending chunks.
pub fn pending_count(&self) -> usize {
self.chunks
.iter()
.filter(|c| c.status == ChunkStatus::Pending)
.count()
}
/// Returns the number of in-progress chunks.
pub fn in_progress_count(&self) -> usize {
self.chunks
.iter()
.filter(|c| c.status == ChunkStatus::InProgress)
.count()
}
/// Returns the number of completed chunks.
pub fn completed_count(&self) -> usize {
self.chunks
.iter()
.filter(|c| c.status == ChunkStatus::Completed)
.count()
}
/// Returns the number of failed chunks.
pub fn failed_count(&self) -> usize {
self.chunks
.iter()
.filter(|c| c.status == ChunkStatus::Failed)
.count()
}
/// Updates the `updated_at` timestamp to now.
pub fn touch(&mut self) {
self.updated_at = Utc::now();
}
/// Adds a chunk to the manifest.
pub fn add_chunk(&mut self, chunk: ChunkMeta) {
self.chunks.push(chunk);
self.touch();
}
/// Updates a chunk by id.
pub fn update_chunk(&mut self, id: u32, updater: impl FnOnce(&mut ChunkMeta)) {
if let Some(chunk) = self.chunks.iter_mut().find(|c| c.id == id) {
updater(chunk);
self.touch();
}
}
}
// Unit tests for manifest construction, time range serialization, format parsing and
// chunk state changes.
#[cfg(test)]
mod tests {
use super::*;
// An unbounded range serializes to `{}` and `{}` deserializes back to an unbounded range.
#[test]
fn test_time_range_serialization() {
let range = TimeRange::unbounded();
let json = serde_json::to_string(&range).unwrap();
assert_eq!(json, "{}");
let range: TimeRange = serde_json::from_str("{}").unwrap();
assert!(range.is_unbounded());
}
// A schema-only manifest has no chunks and counts as complete right away.
#[test]
fn test_manifest_schema_only() {
let manifest =
Manifest::new_schema_only("greptime".to_string(), vec!["public".to_string()]);
assert_eq!(manifest.version, MANIFEST_VERSION);
assert!(manifest.schema_only);
assert!(manifest.chunks.is_empty());
assert!(manifest.is_complete());
}
// A full manifest without any chunks is not complete.
#[test]
fn test_manifest_full() {
let manifest = Manifest::new_full(
"greptime".to_string(),
vec!["public".to_string()],
TimeRange::unbounded(),
DataFormat::Parquet,
);
assert!(!manifest.schema_only);
assert!(manifest.chunks.is_empty());
assert!(!manifest.is_complete());
}
// Format names parse case-insensitively; unknown names are rejected.
#[test]
fn test_data_format_parsing() {
assert_eq!(
"parquet".parse::<DataFormat>().unwrap(),
DataFormat::Parquet
);
assert_eq!("CSV".parse::<DataFormat>().unwrap(), DataFormat::Csv);
assert_eq!("JSON".parse::<DataFormat>().unwrap(), DataFormat::Json);
assert!("invalid".parse::<DataFormat>().is_err());
}
// A chunk starts as Pending and moves through InProgress to Completed.
#[test]
fn test_chunk_status_transitions() {
let mut chunk = ChunkMeta::new(1, TimeRange::unbounded());
assert_eq!(chunk.status, ChunkStatus::Pending);
chunk.mark_in_progress();
assert_eq!(chunk.status, ChunkStatus::InProgress);
chunk.mark_completed(
vec!["file1.parquet".to_string()],
Some("abc123".to_string()),
);
assert_eq!(chunk.status, ChunkStatus::Completed);
assert_eq!(chunk.files.len(), 1);
}
}

View File

@@ -0,0 +1,98 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Minimal schema index structures for Export/Import V2.
//!
//! The canonical schema representation is the per-schema DDL file under
//! `schema/ddl/`. `schemas.json` only records which schemas exist in a snapshot.
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Schema directory name within snapshot.
pub const SCHEMA_DIR: &str = "schema";
/// DDL directory name within schema directory.
///
/// Per-schema DDL files, the canonical schema representation, are kept under
/// `schema/ddl/`.
pub const DDL_DIR: &str = "ddl";
/// Schema definition file name.
///
/// This file only records which schemas exist in a snapshot.
pub const SCHEMAS_FILE: &str = "schemas.json";
/// Schema (database) definition.
///
/// `options` is omitted from the serialized form when empty and defaults to an empty map
/// when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SchemaDefinition {
/// Catalog name.
pub catalog: String,
/// Schema (database) name.
pub name: String,
/// Schema options (if any).
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub options: HashMap<String, String>,
}
/// Minimal schema index stored in a snapshot.
///
/// The default value holds no schemas.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct SchemaSnapshot {
/// Schema (database) definitions.
pub schemas: Vec<SchemaDefinition>,
}
impl SchemaSnapshot {
/// Creates an empty schema snapshot.
pub fn new() -> Self {
Self::default()
}
/// Adds a schema definition.
pub fn add_schema(&mut self, schema: SchemaDefinition) {
self.schemas.push(schema);
}
/// Filters the snapshot to only include specified schemas.
pub fn filter_schemas(&self, schemas: &[String]) -> Self {
Self {
schemas: self
.schemas
.iter()
.filter(|s| schemas.contains(&s.name))
.cloned()
.collect(),
}
}
}
// Unit tests for the schema index.
#[cfg(test)]
mod tests {
use super::*;
// Filtering by name keeps only the matching schema definition.
#[test]
fn test_schema_snapshot_filter() {
let mut snapshot = SchemaSnapshot::new();
snapshot.add_schema(SchemaDefinition {
catalog: "greptime".to_string(),
name: "public".to_string(),
options: HashMap::new(),
});
snapshot.add_schema(SchemaDefinition {
catalog: "greptime".to_string(),
name: "private".to_string(),
options: HashMap::new(),
});
let filtered = snapshot.filter_schemas(&["public".to_string()]);
assert_eq!(filtered.schemas.len(), 1);
assert_eq!(filtered.schemas[0].name, "public");
}
}

View File

@@ -0,0 +1,341 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::env;
use std::time::Duration;
use clap::Parser;
use common_error::ext::BoxedError;
use snafu::ResultExt;
use tempfile::tempdir;
use url::Url;
use super::command::ExportCreateCommand;
use crate::common::ObjectStoreConfig;
use crate::data::import_v2::ImportV2Command;
use crate::data::snapshot_storage::OpenDalStorage;
use crate::database::DatabaseClient;
use crate::error::{FileIoSnafu, InvalidArgumentsSnafu, OtherSnafu, Result};
// End-to-end check that a schema survives an export -> drop -> import -> export round trip.
//
// Ignored by default: it talks to the server at `GREPTIME_ADDR` (default
// `127.0.0.1:4000`), using `GREPTIME_CATALOG` (default `greptime`) and the optional
// `GREPTIME_AUTH_BASIC` credentials.
#[tokio::test]
#[ignore]
async fn export_import_v2_schema_parity_e2e() -> Result<()> {
let addr = env::var("GREPTIME_ADDR").unwrap_or_else(|_| "127.0.0.1:4000".to_string());
let catalog = env::var("GREPTIME_CATALOG").unwrap_or_else(|_| "greptime".to_string());
let auth_basic = env::var("GREPTIME_AUTH_BASIC").ok();
let schema = "test_db_schema_parity";
let database_client = DatabaseClient::new(
addr.clone(),
catalog.clone(),
auth_basic.clone(),
Duration::from_secs(60),
None,
false,
);
// Start from a clean database so leftovers from earlier runs cannot interfere.
database_client
.sql_in_public(&format!("DROP DATABASE IF EXISTS {schema}"))
.await?;
database_client
.sql_in_public(&format!("CREATE DATABASE {schema}"))
.await?;
// Fixture objects: two mito tables, a metric physical/logical table pair and a view.
database_client
.sql(
"CREATE TABLE metrics (\
ts TIMESTAMP TIME INDEX, \
host STRING PRIMARY KEY, \
cpu DOUBLE DEFAULT 0.0, \
region_name STRING \
) ENGINE = mito WITH (ttl='7d', 'compaction.type'='twcs')",
schema,
)
.await?;
database_client
.sql(
"CREATE TABLE logs (\
ts TIMESTAMP TIME INDEX, \
app STRING PRIMARY KEY, \
msg STRING NOT NULL COMMENT 'log message' \
) ENGINE = mito",
schema,
)
.await?;
database_client
.sql(
"CREATE TABLE metrics_physical (\
ts TIMESTAMP TIME INDEX, \
host STRING, \
region_name STRING, \
cpu DOUBLE DEFAULT 0.0, \
PRIMARY KEY (host, region_name) \
) ENGINE = metric WITH (physical_metric_table='true')",
schema,
)
.await?;
database_client
.sql(
"CREATE TABLE metrics_logical (\
ts TIMESTAMP TIME INDEX, \
host STRING, \
region_name STRING, \
cpu DOUBLE DEFAULT 0.0, \
PRIMARY KEY (host, region_name) \
) ENGINE = metric WITH (on_physical_table='metrics_physical')",
schema,
)
.await?;
database_client
.sql(
"CREATE VIEW metrics_view AS SELECT * FROM metrics WHERE cpu > 0.5",
schema,
)
.await?;
// First export: schema-only snapshot of the original database into a temp directory.
let src_dir = tempdir().context(FileIoSnafu)?;
let src_uri = Url::from_directory_path(src_dir.path())
.map_err(|_| {
InvalidArgumentsSnafu {
msg: "invalid temp dir path".to_string(),
}
.build()
})?
.to_string();
let mut export_args = vec![
"export-v2-create",
"--addr",
&addr,
"--to",
&src_uri,
"--catalog",
&catalog,
"--schemas",
schema,
"--schema-only",
];
if let Some(auth) = &auth_basic {
export_args.push("--auth-basic");
export_args.push(auth);
}
let export_cmd = ExportCreateCommand::parse_from(export_args);
export_cmd
.build()
.await
.context(OtherSnafu)?
.do_work()
.await
.context(OtherSnafu)?;
// Drop the database, then recreate it from the snapshot through the import command.
database_client
.sql_in_public(&format!("DROP DATABASE {schema}"))
.await?;
let mut import_args = vec![
"import-v2",
"--addr",
&addr,
"--from",
&src_uri,
"--catalog",
&catalog,
"--schemas",
schema,
];
if let Some(auth) = &auth_basic {
import_args.push("--auth-basic");
import_args.push(auth);
}
let import_cmd = ImportV2Command::parse_from(import_args);
import_cmd
.build()
.await
.context(OtherSnafu)?
.do_work()
.await
.context(OtherSnafu)?;
// Second export: snapshot of the re-imported database into a different temp directory.
let dst_dir = tempdir().context(FileIoSnafu)?;
let dst_uri = Url::from_directory_path(dst_dir.path())
.map_err(|_| {
InvalidArgumentsSnafu {
msg: "invalid temp dir path".to_string(),
}
.build()
})?
.to_string();
let mut export_args = vec![
"export-v2-create",
"--addr",
&addr,
"--to",
&dst_uri,
"--catalog",
&catalog,
"--schemas",
schema,
"--schema-only",
];
if let Some(auth) = &auth_basic {
export_args.push("--auth-basic");
export_args.push(auth);
}
let export_cmd = ExportCreateCommand::parse_from(export_args);
export_cmd
.build()
.await
.context(OtherSnafu)?
.do_work()
.await
.context(OtherSnafu)?;
// Parity check: the schema index read back from both snapshots must be equal.
let storage_config = ObjectStoreConfig::default();
let src_storage = OpenDalStorage::from_uri(&src_uri, &storage_config)
.map_err(BoxedError::new)
.context(OtherSnafu)?;
let dst_storage = OpenDalStorage::from_uri(&dst_uri, &storage_config)
.map_err(BoxedError::new)
.context(OtherSnafu)?;
let src_schema_snapshot = src_storage
.read_schema()
.await
.map_err(BoxedError::new)
.context(OtherSnafu)?;
let dst_schema_snapshot = dst_storage
.read_schema()
.await
.map_err(BoxedError::new)
.context(OtherSnafu)?;
assert_eq!(src_schema_snapshot, dst_schema_snapshot);
// Clean up; this is skipped when an earlier step fails or the assertion panics.
database_client
.sql_in_public(&format!("DROP DATABASE IF EXISTS {schema}"))
.await?;
Ok(())
}
/// End-to-end check that `import-v2 --dry-run` completes successfully against a
/// snapshot produced by a schema-only `export-v2-create`.
///
/// The test is `#[ignore]`d, so it only runs when requested explicitly. It talks to
/// the server named by `GREPTIME_ADDR` (default `127.0.0.1:4000`), using the catalog
/// from `GREPTIME_CATALOG` (default `greptime`) and optional credentials from
/// `GREPTIME_AUTH_BASIC`.
#[tokio::test]
#[ignore]
async fn import_v2_ddl_dry_run_e2e() -> Result<()> {
// Connection settings come from the environment, with local defaults.
let addr = env::var("GREPTIME_ADDR").unwrap_or_else(|_| "127.0.0.1:4000".to_string());
let catalog = env::var("GREPTIME_CATALOG").unwrap_or_else(|_| "greptime".to_string());
let auth_basic = env::var("GREPTIME_AUTH_BASIC").ok();
let schema = "test_db_ddl_dry_run";
let database_client = DatabaseClient::new(
addr.clone(),
catalog.clone(),
auth_basic.clone(),
Duration::from_secs(60),
None,
false,
);
// Start from a clean database so leftovers from an earlier run cannot interfere.
database_client
.sql_in_public(&format!("DROP DATABASE IF EXISTS {schema}"))
.await?;
database_client
.sql_in_public(&format!("CREATE DATABASE {schema}"))
.await?;
// Two tables give the exporter some DDL to write (table options, defaults, comments).
database_client
.sql(
"CREATE TABLE metrics (\
ts TIMESTAMP TIME INDEX, \
host STRING PRIMARY KEY, \
cpu DOUBLE DEFAULT 0.0, \
region_name STRING \
) ENGINE = mito WITH (ttl='7d', 'compaction.type'='twcs')",
schema,
)
.await?;
database_client
.sql(
"CREATE TABLE logs (\
ts TIMESTAMP TIME INDEX, \
app STRING PRIMARY KEY, \
msg STRING NOT NULL COMMENT 'log message' \
) ENGINE = mito",
schema,
)
.await?;
// The snapshot is written to a temporary directory addressed through a file:// URI.
let src_dir = tempdir().context(FileIoSnafu)?;
let src_uri = Url::from_directory_path(src_dir.path())
.map_err(|_| {
InvalidArgumentsSnafu {
msg: "invalid temp dir path".to_string(),
}
.build()
})?
.to_string();
// Step 1: export only the schema of the test database.
let mut export_args = vec![
"export-v2-create",
"--addr",
&addr,
"--to",
&src_uri,
"--catalog",
&catalog,
"--schemas",
schema,
"--schema-only",
];
if let Some(auth) = &auth_basic {
export_args.push("--auth-basic");
export_args.push(auth);
}
let export_cmd = ExportCreateCommand::parse_from(export_args);
export_cmd
.build()
.await
.context(OtherSnafu)?
.do_work()
.await
.context(OtherSnafu)?;
// Step 2: run the import in dry-run mode; the test only requires that it succeeds.
let mut import_args = vec![
"import-v2",
"--addr",
&addr,
"--from",
&src_uri,
"--catalog",
&catalog,
"--schemas",
schema,
"--dry-run",
];
if let Some(auth) = &auth_basic {
import_args.push("--auth-basic");
import_args.push(auth);
}
let import_cmd = ImportV2Command::parse_from(import_args);
import_cmd
.build()
.await
.context(OtherSnafu)?
.do_work()
.await
.context(OtherSnafu)?;
// Clean up the test database.
database_client
.sql_in_public(&format!("DROP DATABASE IF EXISTS {schema}"))
.await?;
Ok(())
}

View File

@@ -81,13 +81,16 @@ pub struct ImportCommand {
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
/// The proxy server address to connect, if set, will override the system proxy.
/// The proxy server address to connect.
///
/// The default behavior will use the system proxy if neither `proxy` nor `no_proxy` is set.
/// If set, it overrides the system proxy unless `--no-proxy` is specified.
/// If neither `--proxy` nor `--no-proxy` is set, system proxy (env) may be used.
#[clap(long)]
proxy: Option<String>,
/// Disable proxy server, if set, will not use any proxy.
/// Disable all proxy usage (ignores `--proxy` and system proxy).
///
/// When set and `--proxy` is not provided, this explicitly disables system proxy.
#[clap(long, default_value = "false")]
no_proxy: bool,
}
@@ -104,6 +107,7 @@ impl ImportCommand {
// Treats `None` as `0s` to disable server-side default timeout.
self.timeout.unwrap_or_default(),
proxy,
self.no_proxy,
);
Ok(Box::new(Import {
@@ -314,6 +318,7 @@ mod tests {
None,
Duration::from_secs(0),
None,
false,
),
input_dir: input_dir.to_string(),
parallelism: 1,

View File

@@ -0,0 +1,41 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Import V2 module.
//!
//! This module provides the V2 implementation of database import functionality,
//! featuring:
//! - DDL-based schema import
//! - Dry-run mode for verification
//!
//! # Example
//!
//! ```bash
//! # Dry-run import (verify without executing)
//! greptime cli data import-v2 \
//! --addr 127.0.0.1:4000 \
//! --from file:///tmp/snapshot \
//! --dry-run
//!
//! # Actual import
//! greptime cli data import-v2 \
//! --addr 127.0.0.1:4000 \
//! --from s3://bucket/snapshots/prod-20250101
//! ```
mod command;
pub mod error;
pub mod executor;
pub use command::ImportV2Command;

View File

@@ -0,0 +1,542 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Import V2 CLI command.
use std::collections::HashSet;
use std::time::Duration;
use async_trait::async_trait;
use clap::Parser;
use common_error::ext::BoxedError;
use common_telemetry::info;
use snafu::ResultExt;
use crate::Tool;
use crate::common::ObjectStoreConfig;
use crate::data::export_v2::manifest::MANIFEST_VERSION;
use crate::data::import_v2::error::{
ManifestVersionMismatchSnafu, Result, SchemaNotInSnapshotSnafu, SnapshotStorageSnafu,
};
use crate::data::import_v2::executor::{DdlExecutor, DdlStatement};
use crate::data::path::ddl_path_for_schema;
use crate::data::snapshot_storage::{OpenDalStorage, SnapshotStorage, validate_uri};
use crate::database::{DatabaseClient, parse_proxy_opts};
// NOTE: this struct derives `clap::Parser`, so the `///` comments below double as the
// command's `--help` text. Reword them only when the user-facing help should change.
/// Import from a snapshot.
#[derive(Debug, Parser)]
pub struct ImportV2Command {
/// Server address to connect (e.g., 127.0.0.1:4000).
#[clap(long)]
addr: String,
/// Source snapshot location (e.g., s3://bucket/path, file:///tmp/backup).
#[clap(long)]
from: String,
/// Target catalog name.
#[clap(long, default_value = "greptime")]
catalog: String,
/// Schema list to import (default: all in snapshot).
/// Can be specified multiple times or comma-separated.
// An empty list is turned into "no filter" when the tool is built.
#[clap(long, value_delimiter = ',')]
schemas: Vec<String>,
/// Verify without importing (dry-run).
#[clap(long)]
dry_run: bool,
/// Concurrency level (for future use).
// Parsed and carried into the tool, but not consulted by the import logic yet.
#[clap(long, default_value = "1")]
parallelism: usize,
/// Basic authentication (user:password).
#[clap(long)]
auth_basic: Option<String>,
/// Request timeout.
// When omitted, `build` falls back to 60 seconds.
#[clap(long, value_parser = humantime::parse_duration)]
timeout: Option<Duration>,
/// Proxy server address.
///
/// If set, it overrides the system proxy unless `--no-proxy` is specified.
/// If neither `--proxy` nor `--no-proxy` is set, system proxy (env) may be used.
#[clap(long)]
proxy: Option<String>,
/// Disable all proxy usage (ignores `--proxy` and system proxy).
///
/// When set and `--proxy` is not provided, this explicitly disables system proxy.
#[clap(long)]
no_proxy: bool,
/// Object store configuration for remote storage backends.
#[clap(flatten)]
storage: ObjectStoreConfig,
}
impl ImportV2Command {
    /// Turns the parsed command-line arguments into a runnable import [`Tool`].
    ///
    /// The snapshot URI is validated, the snapshot storage is opened, and a database
    /// client is created for the target server. Any failure is returned as a
    /// [`BoxedError`].
    pub async fn build(&self) -> std::result::Result<Box<dyn Tool>, BoxedError> {
        // Reject malformed or unsupported snapshot URIs before touching any backend.
        validate_uri(&self.from)
            .context(SnapshotStorageSnafu)
            .map_err(BoxedError::new)?;

        let storage = OpenDalStorage::from_uri(&self.from, &self.storage)
            .context(SnapshotStorageSnafu)
            .map_err(BoxedError::new)?;

        let proxy = parse_proxy_opts(self.proxy.clone(), self.no_proxy)?;
        // Without an explicit `--timeout`, requests time out after 60 seconds.
        let timeout = self.timeout.unwrap_or(Duration::from_secs(60));
        let database_client = DatabaseClient::new(
            self.addr.clone(),
            self.catalog.clone(),
            self.auth_basic.clone(),
            timeout,
            proxy,
            self.no_proxy,
        );

        // An empty `--schemas` list means "import every schema in the snapshot".
        let schemas = (!self.schemas.is_empty()).then(|| self.schemas.clone());

        let tool = Import {
            schemas,
            dry_run: self.dry_run,
            _parallelism: self.parallelism,
            storage: Box::new(storage),
            database_client,
        };
        Ok(Box::new(tool))
    }
}
/// Import tool implementation.
///
/// Holds everything a single import run needs: the schema selection, the run mode,
/// the snapshot storage to read from and the client for the target database.
pub struct Import {
/// Schemas requested by the user; `None` means every schema listed in the manifest.
schemas: Option<Vec<String>>,
/// When `true`, the DDL statements are printed instead of being executed.
dry_run: bool,
/// Requested concurrency level; stored but not read by the import logic yet.
_parallelism: usize,
/// Snapshot storage the manifest and DDL files are read from.
storage: Box<dyn SnapshotStorage>,
/// Client used to execute the DDL statements on the target server.
database_client: DatabaseClient,
}
#[async_trait]
impl Tool for Import {
/// Runs the import and converts its module-specific error into a [`BoxedError`],
/// the error type the `Tool` interface expects.
async fn do_work(&self) -> std::result::Result<(), BoxedError> {
self.run().await.map_err(BoxedError::new)
}
}
impl Import {
    /// Executes one import run.
    ///
    /// Steps: read and version-check the manifest, resolve the schemas to import,
    /// load their DDL statements, then either print them (dry-run) or execute them
    /// against the database. Importing data chunks is not implemented yet; pending
    /// chunks are only reported in the log.
    async fn run(&self) -> Result<()> {
        let manifest = self
            .storage
            .read_manifest()
            .await
            .context(SnapshotStorageSnafu)?;
        info!(
            "Loading snapshot: {} (version: {}, schema_only: {})",
            manifest.snapshot_id, manifest.version, manifest.schema_only
        );

        // Only snapshots written with the manifest version this build knows are accepted.
        if manifest.version != MANIFEST_VERSION {
            return ManifestVersionMismatchSnafu {
                expected: MANIFEST_VERSION,
                found: manifest.version,
            }
            .fail();
        }
        info!("Snapshot contains {} schema(s)", manifest.schemas.len());

        // A user-supplied filter is resolved against the manifest; no filter means all.
        let selected = if let Some(filter) = &self.schemas {
            canonicalize_schema_filter(filter, &manifest.schemas)?
        } else {
            manifest.schemas.clone()
        };
        info!("Importing schemas: {:?}", selected);

        let statements = self.read_ddl_statements(&selected).await?;
        info!("Generated {} DDL statements", statements.len());

        if self.dry_run {
            Self::print_dry_run(&statements);
            return Ok(());
        }

        DdlExecutor::new(&self.database_client)
            .execute_strict(&statements)
            .await?;
        info!(
            "Import completed: {} DDL statements executed",
            statements.len()
        );

        // Data import for non-schema-only snapshots is not available yet (M2/M3).
        let nothing_pending = manifest.schema_only || manifest.chunks.is_empty();
        if !nothing_pending {
            info!(
                "Data import not yet implemented (M3). {} chunks pending.",
                manifest.chunks.len()
            );
        }
        Ok(())
    }

    /// Prints every statement to stdout, numbered from 1, without executing anything.
    fn print_dry_run(statements: &[DdlStatement]) {
        info!("Dry-run mode - DDL statements to execute:");
        println!();
        for (index, statement) in statements.iter().enumerate() {
            println!("-- Statement {}", index + 1);
            println!("{};", statement.sql);
            println!();
        }
    }

    /// Reads the DDL file of each schema, in the given order, and splits it into
    /// statements tagged with the schema they must be executed in.
    async fn read_ddl_statements(&self, schemas: &[String]) -> Result<Vec<DdlStatement>> {
        let mut statements = Vec::new();
        for schema in schemas {
            let ddl_path = ddl_path_for_schema(schema);
            let ddl_text = self
                .storage
                .read_text(&ddl_path)
                .await
                .context(SnapshotStorageSnafu)?;
            for sql in parse_ddl_statements(&ddl_text) {
                statements.push(ddl_statement_for_schema(schema, sql));
            }
        }
        Ok(statements)
    }
}
/// Splits the text of a DDL file into individual SQL statements.
///
/// Statements are separated by `;`. The splitter understands:
/// - single-quoted literals and double-quoted identifiers, including doubled-quote
///   escapes (`''`, `""`), so a `;` or comment marker inside them is kept verbatim;
/// - `--` line comments, which are dropped up to (not including) the newline;
/// - `/* ... */` block comments, which are replaced by a single space so the tokens
///   on either side stay separated (`CREATE/* c */TABLE` yields `CREATE TABLE`).
///
/// Each returned statement is trimmed, carries no trailing `;`, and empty statements
/// are skipped. Text after the last `;` is returned as a final statement.
fn parse_ddl_statements(content: &str) -> Vec<String> {
    /// Moves the trimmed buffer into `statements` (when non-empty) and resets it.
    fn flush(current: &mut String, statements: &mut Vec<String>) {
        let statement = current.trim();
        if !statement.is_empty() {
            statements.push(statement.to_string());
        }
        current.clear();
    }

    let mut statements = Vec::new();
    let mut current = String::new();
    let mut chars = content.chars().peekable();
    // `Some(q)` while inside a literal/identifier that was opened with quote `q`.
    let mut open_quote: Option<char> = None;
    let mut in_line_comment = false;
    let mut in_block_comment = false;

    while let Some(ch) = chars.next() {
        if in_line_comment {
            if ch == '\n' {
                in_line_comment = false;
                // Keep the line break so the surrounding tokens stay separated.
                current.push('\n');
            }
            continue;
        }
        if in_block_comment {
            if ch == '*' && chars.peek() == Some(&'/') {
                chars.next();
                in_block_comment = false;
                // A comment counts as whitespace; without this, the tokens before and
                // after it would be glued together.
                current.push(' ');
            }
            continue;
        }
        if let Some(quote) = open_quote {
            current.push(ch);
            if ch == quote {
                if chars.peek() == Some(&quote) {
                    // Doubled quote: an escaped quote character, still inside.
                    chars.next();
                    current.push(quote);
                } else {
                    open_quote = None;
                }
            }
            continue;
        }
        match ch {
            '-' if chars.peek() == Some(&'-') => {
                chars.next();
                in_line_comment = true;
            }
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                in_block_comment = true;
            }
            '\'' | '"' => {
                open_quote = Some(ch);
                current.push(ch);
            }
            ';' => flush(&mut current, &mut statements),
            _ => current.push(ch),
        }
    }
    // Whatever follows the last `;` is a statement of its own.
    flush(&mut current, &mut statements);
    statements
}
/// Wraps `sql` in a [`DdlStatement`], attaching `schema` as the execution schema
/// only for statements that are schema-scoped (table and view creation).
fn ddl_statement_for_schema(schema: &str, sql: String) -> DdlStatement {
    if !is_schema_scoped_statement(&sql) {
        // Everything else runs without an explicit schema context.
        return DdlStatement::new(sql);
    }
    DdlStatement::with_execution_schema(sql, schema.to_string())
}
/// Returns `true` for statements of the form
/// `CREATE [OR REPLACE] [EXTERNAL] TABLE ...` or `CREATE [OR REPLACE] [EXTERNAL] VIEW ...`.
///
/// Keywords are matched case-insensitively and as whole words; leading whitespace
/// is ignored.
fn is_schema_scoped_statement(sql: &str) -> bool {
    /// If `input` starts with `keyword` as a whole word, returns what follows it
    /// with leading whitespace removed.
    fn strip_keyword<'a>(input: &'a str, keyword: &str) -> Option<&'a str> {
        if starts_with_keyword(input, keyword) {
            input.get(keyword.len()..).map(str::trim_start)
        } else {
            None
        }
    }

    let Some(mut rest) = strip_keyword(sql.trim_start(), "CREATE") else {
        return false;
    };

    // `OR` is only acceptable as part of `OR REPLACE`.
    if let Some(after_or) = strip_keyword(rest, "OR") {
        let Some(after_replace) = strip_keyword(after_or, "REPLACE") else {
            return false;
        };
        rest = after_replace;
    }

    if let Some(after_external) = strip_keyword(rest, "EXTERNAL") {
        rest = after_external;
    }

    ["TABLE", "VIEW"]
        .iter()
        .any(|keyword| starts_with_keyword(rest, keyword))
}
fn starts_with_keyword(input: &str, keyword: &str) -> bool {
input
.get(0..keyword.len())
.map(|s| s.eq_ignore_ascii_case(keyword))
.unwrap_or(false)
&& input
.as_bytes()
.get(keyword.len())
.map(|b| !b.is_ascii_alphanumeric() && *b != b'_')
.unwrap_or(true)
}
/// Resolves a user-supplied schema filter against the schemas recorded in the
/// snapshot manifest.
///
/// For every entry in `filter`, the manifest schema with exactly the same name is
/// used when it exists; otherwise the first schema that matches ignoring ASCII case
/// is used, so the result always carries the manifest's spelling. Preferring the
/// exact match keeps two manifest schemas that differ only by case distinguishable.
///
/// The result preserves the order of `filter` and contains each resolved schema at
/// most once.
///
/// # Errors
///
/// Returns `SchemaNotInSnapshot` for the first filter entry that matches no
/// manifest schema.
fn canonicalize_schema_filter(
    filter: &[String],
    manifest_schemas: &[String],
) -> Result<Vec<String>> {
    let mut canonicalized = Vec::with_capacity(filter.len());
    // Keyed by the resolved manifest name, so distinct schemas are never merged.
    let mut seen: HashSet<&str> = HashSet::with_capacity(filter.len());
    for schema in filter {
        let canonical = manifest_schemas
            .iter()
            .find(|candidate| *candidate == schema)
            .or_else(|| {
                manifest_schemas
                    .iter()
                    .find(|candidate| candidate.eq_ignore_ascii_case(schema))
            })
            .ok_or_else(|| {
                SchemaNotInSnapshotSnafu {
                    schema: schema.clone(),
                }
                .build()
            })?;
        if seen.insert(canonical.as_str()) {
            canonicalized.push(canonical.clone());
        }
    }
    Ok(canonicalized)
}
// Unit tests for the statement splitter, the schema filter and the schema-scope
// detection used by the import command.
#[cfg(test)]
mod tests {
use super::*;
// Line comments are dropped and each `;`-terminated statement is returned once.
#[test]
fn test_parse_ddl_statements() {
let content = r#"
-- Schema: public
CREATE DATABASE public;
CREATE TABLE t (ts TIMESTAMP TIME INDEX, host STRING, PRIMARY KEY (host)) ENGINE=mito;
-- comment
CREATE VIEW v AS SELECT * FROM t;
"#;
let statements = parse_ddl_statements(content);
assert_eq!(statements.len(), 3);
assert!(statements[0].starts_with("CREATE DATABASE public"));
assert!(statements[1].starts_with("CREATE TABLE t"));
assert!(statements[2].starts_with("CREATE VIEW v"));
}
// A `;` inside a single-quoted literal must not end the statement.
#[test]
fn test_parse_ddl_statements_preserves_semicolons_in_string_literals() {
let content = r#"
CREATE TABLE t (
host STRING DEFAULT 'a;b'
);
CREATE VIEW v AS SELECT ';' AS marker;
"#;
let statements = parse_ddl_statements(content);
assert_eq!(statements.len(), 2);
assert!(statements[0].contains("'a;b'"));
assert!(statements[1].contains("';' AS marker"));
}
// A `;` inside a block comment must not produce an extra statement.
#[test]
fn test_parse_ddl_statements_handles_comments_without_splitting() {
let content = r#"
-- leading comment
CREATE TABLE t (ts TIMESTAMP TIME INDEX); /* block; comment */
CREATE VIEW v AS SELECT 1;
"#;
let statements = parse_ddl_statements(content);
assert_eq!(statements.len(), 2);
assert!(statements[0].starts_with("CREATE TABLE t"));
assert!(statements[1].starts_with("CREATE VIEW v"));
}
// Filter entries are matched ignoring case and returned in the manifest's spelling.
#[test]
fn test_canonicalize_schema_filter_uses_manifest_casing() {
let filter = vec!["TEST_DB".to_string(), "PUBLIC".to_string()];
let manifest_schemas = vec!["test_db".to_string(), "public".to_string()];
let canonicalized = canonicalize_schema_filter(&filter, &manifest_schemas).unwrap();
assert_eq!(canonicalized, vec!["test_db", "public"]);
}
// Several spellings of the same schema collapse into a single entry.
#[test]
fn test_canonicalize_schema_filter_dedupes_case_insensitive_matches() {
let filter = vec![
"TEST_DB".to_string(),
"test_db".to_string(),
"PUBLIC".to_string(),
"public".to_string(),
];
let manifest_schemas = vec!["test_db".to_string(), "public".to_string()];
let canonicalized = canonicalize_schema_filter(&filter, &manifest_schemas).unwrap();
assert_eq!(canonicalized, vec!["test_db", "public"]);
}
// A schema that is absent from the manifest is an error naming that schema.
#[test]
fn test_canonicalize_schema_filter_rejects_missing_schema() {
let filter = vec!["missing".to_string()];
let manifest_schemas = vec!["test_db".to_string()];
let error = canonicalize_schema_filter(&filter, &manifest_schemas)
.expect_err("missing schema should fail")
.to_string();
assert!(error.contains("missing"));
}
// The following tests pin which statement kinds are executed inside the schema.
#[test]
fn test_ddl_statement_for_schema_create_table_uses_execution_schema() {
let stmt = ddl_statement_for_schema(
"test_db",
"CREATE TABLE metrics (ts TIMESTAMP TIME INDEX) ENGINE=mito".to_string(),
);
assert_eq!(stmt.execution_schema.as_deref(), Some("test_db"));
}
#[test]
fn test_ddl_statement_for_schema_create_view_uses_execution_schema() {
let stmt = ddl_statement_for_schema(
"test_db",
"CREATE VIEW metrics_view AS SELECT * FROM metrics".to_string(),
);
assert_eq!(stmt.execution_schema.as_deref(), Some("test_db"));
}
#[test]
fn test_ddl_statement_for_schema_create_or_replace_view_uses_execution_schema() {
let stmt = ddl_statement_for_schema(
"test_db",
"CREATE OR REPLACE VIEW metrics_view AS SELECT * FROM metrics".to_string(),
);
assert_eq!(stmt.execution_schema.as_deref(), Some("test_db"));
}
#[test]
fn test_ddl_statement_for_schema_create_external_table_uses_execution_schema() {
let stmt = ddl_statement_for_schema(
"test_db",
"CREATE EXTERNAL TABLE IF NOT EXISTS ext_metrics (ts TIMESTAMP TIME INDEX) ENGINE=file"
.to_string(),
);
assert_eq!(stmt.execution_schema.as_deref(), Some("test_db"));
}
// `CREATE DATABASE` is not schema-scoped, so no execution schema is attached.
#[test]
fn test_ddl_statement_for_schema_create_database_uses_public_context() {
let stmt = ddl_statement_for_schema("test_db", "CREATE DATABASE test_db".to_string());
assert_eq!(stmt.execution_schema, None);
}
// A keyword only matches when it is not the prefix of a longer identifier.
#[test]
fn test_starts_with_keyword_requires_word_boundary() {
assert!(starts_with_keyword("CREATE TABLE t", "CREATE"));
assert!(!starts_with_keyword("CREATED TABLE t", "CREATE"));
assert!(!starts_with_keyword("TABLESPACE foo", "TABLE"));
}
}

View File

@@ -0,0 +1,82 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use snafu::{Location, Snafu};
// Errors produced by the import-v2 command. Every variant records the source
// location where it was created through the implicit `location` field.
#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
// No snapshot exists at the given URI.
#[snafu(display("Snapshot not found at '{}'", uri))]
SnapshotNotFound {
uri: String,
#[snafu(implicit)]
location: Location,
},
// The snapshot manifest was written with a different manifest version than expected.
#[snafu(display("Manifest version mismatch: expected {}, found {}", expected, found))]
ManifestVersionMismatch {
expected: u32,
found: u32,
#[snafu(implicit)]
location: Location,
},
// A schema requested for import is not listed in the snapshot.
#[snafu(display("Schema '{}' not found in snapshot", schema))]
SchemaNotInSnapshot {
schema: String,
#[snafu(implicit)]
location: Location,
},
// Wraps an error raised by the export-v2 error type (used for snapshot storage access).
#[snafu(display("Snapshot storage error"))]
SnapshotStorage {
#[snafu(source)]
error: crate::data::export_v2::error::Error,
#[snafu(implicit)]
location: Location,
},
// Wraps a crate-level error raised while talking to the database.
#[snafu(display("Database error"))]
Database {
#[snafu(source)]
error: crate::error::Error,
#[snafu(implicit)]
location: Location,
},
}
// Result alias whose error type is this module's `Error`.
pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
    /// Maps each variant to a status code: problems with the user's input are
    /// `InvalidArguments`, wrapped errors report the status code of their source.
    fn status_code(&self) -> StatusCode {
        match self {
            Self::SnapshotNotFound { .. }
            | Self::SchemaNotInSnapshot { .. }
            | Self::ManifestVersionMismatch { .. } => StatusCode::InvalidArguments,
            Self::SnapshotStorage { error, .. } => error.status_code(),
            Self::Database { error, .. } => error.status_code(),
        }
    }

    /// Exposes the error as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

View File

@@ -0,0 +1,122 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! DDL execution for import.
use common_telemetry::info;
use snafu::ResultExt;
use crate::data::import_v2::error::{DatabaseSnafu, Result};
use crate::database::DatabaseClient;
/// One DDL statement together with the schema it has to be executed in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DdlStatement {
    /// The SQL text of the statement.
    pub sql: String,
    /// Schema to execute the statement in; `None` when no specific schema is required.
    pub execution_schema: Option<String>,
}

impl DdlStatement {
    /// Creates a statement that carries no execution schema.
    pub fn new(sql: String) -> Self {
        Self {
            execution_schema: None,
            sql,
        }
    }

    /// Creates a statement that must be executed in `schema`.
    pub fn with_execution_schema(sql: String, schema: String) -> Self {
        Self {
            execution_schema: Some(schema),
            ..Self::new(sql)
        }
    }
}
/// Runs DDL statements through a borrowed [`DatabaseClient`].
pub struct DdlExecutor<'a> {
    client: &'a DatabaseClient,
}

impl<'a> DdlExecutor<'a> {
    /// Creates an executor that sends its statements through `client`.
    pub fn new(client: &'a DatabaseClient) -> Self {
        Self { client }
    }

    /// Executes `statements` in order and returns the first error encountered;
    /// statements after a failing one are not executed.
    ///
    /// A statement with an execution schema is run in that schema, any other
    /// statement is run through the client's public-schema entry point.
    pub async fn execute_strict(&self, statements: &[DdlStatement]) -> Result<()> {
        let total = statements.len();
        for (position, statement) in (1usize..).zip(statements) {
            // Only a shortened form of the SQL goes to the log.
            info!(
                "Executing DDL ({}/{}): {}",
                position,
                total,
                preview_sql(&statement.sql)
            );
            match statement.execution_schema.as_deref() {
                Some(schema) => {
                    self.client
                        .sql(&statement.sql, schema)
                        .await
                        .context(DatabaseSnafu)?;
                }
                None => {
                    self.client
                        .sql_in_public(&statement.sql)
                        .await
                        .context(DatabaseSnafu)?;
                }
            }
        }
        Ok(())
    }
}
/// Returns `sql` shortened for logging: at most the first 80 characters, followed by
/// `...` when anything was cut off. Counting is done in characters, so multi-byte
/// text is never split in the middle of a character.
fn preview_sql(sql: &str) -> String {
    // The byte offset of the 81st character, if there is one, is where the cut goes.
    match sql.char_indices().nth(80) {
        Some((cut, _)) => format!("{}...", &sql[..cut]),
        None => sql.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement built with `new` carries no execution schema.
    #[test]
    fn test_statement_without_execution_schema_uses_public() {
        let stmt = DdlStatement::new("CREATE DATABASE IF NOT EXISTS test_db".to_string());
        assert_eq!(stmt.execution_schema, None);
    }

    /// The execution schema is stored verbatim, including embedded quote characters.
    #[test]
    fn test_statement_with_execution_schema_preserves_context() {
        let stmt = DdlStatement::with_execution_schema(
            r#"CREATE TABLE IF NOT EXISTS "my""schema"."metrics" (ts TIMESTAMP TIME INDEX)"#
                .to_string(),
            r#"my"schema"#.to_string(),
        );
        assert_eq!(stmt.execution_schema.as_deref(), Some(r#"my"schema"#));
    }

    /// Long SQL with multi-byte characters is cut after 80 characters, never in the
    /// middle of a character.
    #[test]
    fn test_preview_sql_truncates_at_char_boundary() {
        // The identifier must be made of multi-byte characters: it pushes the SQL past
        // the 80-character limit and makes byte offsets differ from character offsets.
        let sql = format!(
            "CREATE TABLE {} (ts TIMESTAMP TIME INDEX)",
            "表".repeat(100)
        );
        let preview = preview_sql(&sql);
        assert!(preview.ends_with("..."));
        // 80 characters of SQL plus the three-character ellipsis.
        assert_eq!(preview.chars().count(), 83);
        let kept = preview
            .strip_suffix("...")
            .expect("preview must end with an ellipsis");
        assert!(sql.starts_with(kept));
    }
}

76
src/cli/src/data/path.rs Normal file
View File

@@ -0,0 +1,76 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Shared path helpers for export/import data files.
use crate::data::export_v2::schema::{DDL_DIR, SCHEMA_DIR};
/// Returns the snapshot-relative path of the DDL file for `schema`:
/// `<SCHEMA_DIR>/<DDL_DIR>/<encoded schema>.sql`.
///
/// The schema name is passed through [`encode_path_segment`], so it always forms a
/// single path segment.
pub(crate) fn ddl_path_for_schema(schema: &str) -> String {
    let file_stem = encode_path_segment(schema);
    format!("{}/{}/{}.sql", SCHEMA_DIR, DDL_DIR, file_stem)
}
/// Encodes `value` so it can be used as a single path segment.
///
/// ASCII letters, digits, `-` and `_` are kept as they are; every other byte of the
/// UTF-8 encoding is written as `%XX` with upper-case hexadecimal digits. This also
/// escapes `.`, `/` and `\`, so the result cannot contain path separators or `..`.
pub(crate) fn encode_path_segment(value: &str) -> String {
    value
        .bytes()
        .fold(String::with_capacity(value.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_') {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(hex_char(byte >> 4));
                encoded.push(hex_char(byte & 0x0F));
            }
            encoded
        })
}

/// Returns the upper-case hexadecimal digit for a value in `0..=15`.
fn hex_char(nibble: u8) -> char {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    match DIGITS.get(usize::from(nibble)) {
        Some(&digit) => char::from(digit),
        None => unreachable!("nibble must be in 0..=15"),
    }
}
// Unit tests for the path helpers.
#[cfg(test)]
mod tests {
use super::*;
// Names made only of safe ASCII characters pass through unchanged.
#[test]
fn test_encode_path_segment_preserves_safe_ascii() {
assert_eq!(encode_path_segment("test_db"), "test_db");
}
// Dots, slashes and backslashes are percent-encoded, so a schema name cannot
// introduce extra path components.
#[test]
fn test_encode_path_segment_escapes_path_traversal_chars() {
assert_eq!(encode_path_segment("../evil"), "%2E%2E%2Fevil");
assert_eq!(encode_path_segment(r"..\\evil"), "%2E%2E%5C%5Cevil");
}
// The DDL path embeds the encoded schema name as the file stem.
#[test]
fn test_ddl_path_for_schema_encodes_schema_segment() {
assert_eq!(ddl_path_for_schema("public"), "schema/ddl/public.sql");
assert_eq!(
ddl_path_for_schema("../evil"),
"schema/ddl/%2E%2E%2Fevil.sql"
);
}
}

View File

@@ -0,0 +1,669 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Storage abstraction for Export/Import V2.
//!
//! This module provides a unified interface for reading and writing snapshot data
//! to various storage backends (S3, OSS, GCS, Azure Blob, local filesystem).
use async_trait::async_trait;
use object_store::services::{Azblob, Fs, Gcs, Oss, S3};
use object_store::util::{with_instrument_layers, with_retry_layers};
use object_store::{AzblobConnection, GcsConnection, ObjectStore, OssConnection, S3Connection};
use snafu::ResultExt;
use url::Url;
use crate::common::ObjectStoreConfig;
use crate::data::export_v2::error::{
BuildObjectStoreSnafu, InvalidUriSnafu, ManifestParseSnafu, ManifestSerializeSnafu, Result,
SnapshotNotFoundSnafu, StorageOperationSnafu, TextDecodeSnafu, UnsupportedSchemeSnafu,
UrlParseSnafu,
};
use crate::data::export_v2::manifest::{MANIFEST_FILE, Manifest};
#[cfg(test)]
use crate::data::export_v2::schema::SchemaDefinition;
use crate::data::export_v2::schema::{SCHEMA_DIR, SCHEMAS_FILE, SchemaSnapshot};
/// Bucket/container and root path parsed from a remote snapshot URI.
struct RemoteLocation {
/// Host part of the URI, used as the bucket (or container) name.
bucket_or_container: String,
/// Path part of the URI with leading `/` removed, used as the root inside the bucket.
root: String,
}
/// Storage backends a snapshot URI can point to, identified by the URI scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageScheme {
    /// Amazon S3 (`s3://`).
    S3,
    /// Alibaba Cloud OSS (`oss://`).
    Oss,
    /// Google Cloud Storage (`gs://` or `gcs://`).
    Gcs,
    /// Azure Blob Storage (`azblob://`).
    Azblob,
    /// Local filesystem (`file://`).
    File,
}

impl StorageScheme {
    /// Determines the storage backend from the scheme of `uri`.
    ///
    /// Fails when `uri` cannot be parsed as a URL or uses a scheme other than
    /// `s3`, `oss`, `gs`, `gcs`, `azblob` or `file`.
    pub fn from_uri(uri: &str) -> Result<Self> {
        let url = Url::parse(uri).context(UrlParseSnafu)?;
        let backend = match url.scheme() {
            "s3" => Self::S3,
            "oss" => Self::Oss,
            "gs" | "gcs" => Self::Gcs,
            "azblob" => Self::Azblob,
            "file" => Self::File,
            scheme => return UnsupportedSchemeSnafu { scheme }.fail(),
        };
        Ok(backend)
    }
}
/// Splits a remote snapshot URI into its bucket/container (the host) and its root
/// path (the URL path without leading slashes).
///
/// Fails when the URI cannot be parsed, has no host, or has an empty path.
fn extract_remote_location(uri: &str) -> Result<RemoteLocation> {
    let url = Url::parse(uri).context(UrlParseSnafu)?;

    let bucket_or_container = match url.host_str() {
        Some(host) if !host.is_empty() => host.to_string(),
        _ => {
            return InvalidUriSnafu {
                uri,
                reason: "URI must include bucket/container in host",
            }
            .fail();
        }
    };

    let root = url.path().trim_start_matches('/');
    if root.is_empty() {
        return InvalidUriSnafu {
            uri,
            reason: "snapshot URI must include a non-empty path after the bucket/container",
        }
        .fail();
    }

    Ok(RemoteLocation {
        bucket_or_container,
        root: root.to_string(),
    })
}
/// Checks that `uri` carries a supported scheme and returns the matching
/// [`StorageScheme`].
///
/// Bare paths (e.g., `/tmp/backup`, `./backup`) are rejected because:
/// - Schema export (CLI) and data export (server) run in different processes
/// - Using bare paths would split the snapshot across machines
///
/// Supported URI schemes:
/// - `s3://bucket/path` - Amazon S3
/// - `oss://bucket/path` - Alibaba Cloud OSS
/// - `gs://bucket/path` - Google Cloud Storage
/// - `azblob://container/path` - Azure Blob Storage
/// - `file:///absolute/path` - Local filesystem
pub fn validate_uri(uri: &str) -> Result<StorageScheme> {
    // Anything without a `scheme://` prefix is treated as a bare path.
    if uri.contains("://") {
        StorageScheme::from_uri(uri)
    } else {
        InvalidUriSnafu {
            uri,
            reason: "URI must have a scheme (e.g., s3://, file://). Bare paths are not supported.",
        }
        .fail()
    }
}
/// Returns the snapshot-relative path of the schema index file,
/// `<SCHEMA_DIR>/<SCHEMAS_FILE>`.
fn schema_index_path() -> String {
format!("{}/{}", SCHEMA_DIR, SCHEMAS_FILE)
}
/// Converts a `file://` URI into the filesystem path it refers to.
///
/// A host other than the empty host or `localhost` is rejected, as is any URI that
/// cannot be converted into a file path. A path that is not valid UTF-8 is
/// converted lossily.
fn extract_file_path_from_uri(uri: &str) -> Result<String> {
    let url = Url::parse(uri).context(UrlParseSnafu)?;

    let names_other_host = matches!(
        url.host_str(),
        Some(host) if !host.is_empty() && host != "localhost"
    );
    if names_other_host {
        return InvalidUriSnafu {
            uri,
            reason: "file:// URI must use an absolute path like file:///tmp/backup",
        }
        .fail();
    }

    match url.to_file_path() {
        Ok(path) => Ok(path.to_string_lossy().into_owned()),
        Err(_) => InvalidUriSnafu {
            uri,
            reason: "file:// URI must use a valid absolute filesystem path",
        }
        .fail(),
    }
}
/// Succeeds when `storage` reports that a snapshot exists; otherwise fails with
/// `SnapshotNotFound` carrying the storage's target URI. Errors from the existence
/// check itself are passed through.
async fn ensure_snapshot_exists(storage: &OpenDalStorage) -> Result<()> {
    if !storage.exists().await? {
        return SnapshotNotFoundSnafu {
            uri: storage.target_uri.as_str(),
        }
        .fail();
    }
    Ok(())
}
/// Snapshot storage abstraction.
///
/// Provides operations for reading and writing snapshot data to various storage backends.
/// All paths taken by the text helpers are relative to the snapshot root. The trait
/// requires `Send + Sync`, and the import tool holds it as a boxed trait object.
#[async_trait]
pub trait SnapshotStorage: Send + Sync {
/// Checks if a snapshot exists at this location (manifest.json exists).
async fn exists(&self) -> Result<bool>;
/// Reads the manifest file.
async fn read_manifest(&self) -> Result<Manifest>;
/// Writes the manifest file.
async fn write_manifest(&self, manifest: &Manifest) -> Result<()>;
/// Writes the schema index to schema/schemas.json.
async fn write_schema(&self, schema: &SchemaSnapshot) -> Result<()>;
/// Writes a text file to a relative path under the snapshot root.
async fn write_text(&self, path: &str, content: &str) -> Result<()>;
/// Reads a text file from a relative path under the snapshot root.
async fn read_text(&self, path: &str) -> Result<String>;
/// Deletes the entire snapshot (for --force).
async fn delete_snapshot(&self) -> Result<()>;
}
/// OpenDAL-based implementation of SnapshotStorage.
pub struct OpenDalStorage {
/// Object store used for all reads and writes; the constructors configure it with
/// the snapshot location as its root.
object_store: ObjectStore,
/// The URI this storage was created from, kept for error messages.
target_uri: String,
}
impl OpenDalStorage {
/// Wraps an already configured object store together with the URI it was built for.
fn new_operator_rooted(object_store: ObjectStore, target_uri: &str) -> Self {
Self {
object_store,
target_uri: target_uri.to_string(),
}
}
/// Adds the instrumentation layers to a local (filesystem) object store.
// NOTE(review): the meaning of the `false` flag is defined by
// `with_instrument_layers` and is not visible here - confirm before changing it.
fn finish_local_store(object_store: ObjectStore) -> ObjectStore {
with_instrument_layers(object_store, false)
}
/// Adds the retry layers and then the instrumentation layers to a remote object store.
// NOTE(review): the meaning of the `false` flag is defined by
// `with_instrument_layers` and is not visible here - confirm before changing it.
fn finish_remote_store(object_store: ObjectStore) -> ObjectStore {
with_instrument_layers(with_retry_layers(object_store), false)
}
/// Rejects `uri` with an `InvalidUri` error (using `reason` as the explanation)
/// unless the matching backend has been enabled.
fn ensure_backend_enabled(uri: &str, enabled: bool, reason: &'static str) -> Result<()> {
    if !enabled {
        return InvalidUriSnafu { uri, reason }.fail();
    }
    Ok(())
}
fn validate_remote_config<E: std::fmt::Display>(
uri: &str,
backend: &str,
result: std::result::Result<(), E>,
) -> Result<()> {
result.map_err(|error| {
InvalidUriSnafu {
uri,
reason: format!("invalid {} config: {}", backend, error),
}
.build()
})
}
/// Creates a new storage from a `file://` URI.
///
/// The filesystem path extracted from `uri` becomes the root of the underlying
/// `Fs` object store.
///
/// # Errors
///
/// Fails if the URI cannot be turned into a filesystem path or if the object
/// store cannot be built.
pub fn from_file_uri(uri: &str) -> Result<Self> {
    let root = extract_file_path_from_uri(uri)?;
    let raw_store = ObjectStore::new(Fs::default().root(&root))
        .context(BuildObjectStoreSnafu)?
        .finish();
    let store = Self::finish_local_store(raw_store);
    Ok(Self::new_operator_rooted(store, uri))
}
/// Builds a `file://` storage, refusing to do so when any remote backend flag is set.
///
/// # Errors
///
/// Returns `InvalidUri` if S3, OSS, GCS or Azure Blob is enabled in `storage`;
/// otherwise behaves like [`Self::from_file_uri`].
fn from_file_uri_with_config(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    let remote_backend_enabled =
        storage.enable_s3 || storage.enable_oss || storage.enable_gcs || storage.enable_azblob;
    if !remote_backend_enabled {
        return Self::from_file_uri(uri);
    }
    InvalidUriSnafu {
        uri,
        reason: "file:// cannot be used with remote storage flags",
    }
    .fail()
}
/// Builds an S3-backed storage for an `s3://` URI.
///
/// The bucket and root parsed from `uri` override the corresponding fields of a
/// copy of `storage.s3` before that copy is validated and used.
///
/// # Errors
///
/// Fails if S3 is not enabled, the URI has no usable location, the resulting
/// config is invalid, or the object store cannot be built.
fn from_s3_uri(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    Self::ensure_backend_enabled(
        uri,
        storage.enable_s3,
        "s3:// requires --s3 and related options",
    )?;

    let s3_config = {
        let location = extract_remote_location(uri)?;
        let mut overridden = storage.s3.clone();
        overridden.s3_bucket = location.bucket_or_container;
        overridden.s3_root = location.root;
        overridden
    };
    Self::validate_remote_config(uri, "s3", s3_config.validate())?;

    let connection: S3Connection = s3_config.into();
    let raw_store = ObjectStore::new(S3::from(&connection))
        .context(BuildObjectStoreSnafu)?
        .finish();
    let store = Self::finish_remote_store(raw_store);
    Ok(Self::new_operator_rooted(store, uri))
}
/// Builds an OSS-backed storage for an `oss://` URI.
///
/// The bucket and root parsed from `uri` override the corresponding fields of a
/// copy of `storage.oss` before that copy is validated and used.
///
/// # Errors
///
/// Fails if OSS is not enabled, the URI has no usable location, the resulting
/// config is invalid, or the object store cannot be built.
fn from_oss_uri(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    Self::ensure_backend_enabled(
        uri,
        storage.enable_oss,
        "oss:// requires --oss and related options",
    )?;

    let oss_config = {
        let location = extract_remote_location(uri)?;
        let mut overridden = storage.oss.clone();
        overridden.oss_bucket = location.bucket_or_container;
        overridden.oss_root = location.root;
        overridden
    };
    Self::validate_remote_config(uri, "oss", oss_config.validate())?;

    let connection: OssConnection = oss_config.into();
    let raw_store = ObjectStore::new(Oss::from(&connection))
        .context(BuildObjectStoreSnafu)?
        .finish();
    let store = Self::finish_remote_store(raw_store);
    Ok(Self::new_operator_rooted(store, uri))
}
/// Builds a GCS-backed storage for a `gs://` or `gcs://` URI.
///
/// The bucket and root parsed from `uri` override the corresponding fields of a
/// copy of `storage.gcs` before that copy is validated and used.
///
/// # Errors
///
/// Fails if GCS is not enabled, the URI has no usable location, the resulting
/// config is invalid, or the object store cannot be built.
fn from_gcs_uri(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    Self::ensure_backend_enabled(
        uri,
        storage.enable_gcs,
        "gs:// or gcs:// requires --gcs and related options",
    )?;

    let gcs_config = {
        let location = extract_remote_location(uri)?;
        let mut overridden = storage.gcs.clone();
        overridden.gcs_bucket = location.bucket_or_container;
        overridden.gcs_root = location.root;
        overridden
    };
    Self::validate_remote_config(uri, "gcs", gcs_config.validate())?;

    let connection: GcsConnection = gcs_config.into();
    let raw_store = ObjectStore::new(Gcs::from(&connection))
        .context(BuildObjectStoreSnafu)?
        .finish();
    let store = Self::finish_remote_store(raw_store);
    Ok(Self::new_operator_rooted(store, uri))
}
/// Builds an Azure Blob-backed storage for an `azblob://` URI.
///
/// The container and root parsed from `uri` override the corresponding fields of
/// a copy of `storage.azblob` before that copy is validated and used.
///
/// # Errors
///
/// Fails if Azure Blob is not enabled, the URI has no usable location, the
/// resulting config is invalid, or the object store cannot be built.
fn from_azblob_uri(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    Self::ensure_backend_enabled(
        uri,
        storage.enable_azblob,
        "azblob:// requires --azblob and related options",
    )?;

    let azblob_config = {
        let location = extract_remote_location(uri)?;
        let mut overridden = storage.azblob.clone();
        overridden.azblob_container = location.bucket_or_container;
        overridden.azblob_root = location.root;
        overridden
    };
    Self::validate_remote_config(uri, "azblob", azblob_config.validate())?;

    let connection: AzblobConnection = azblob_config.into();
    let raw_store = ObjectStore::new(Azblob::from(&connection))
        .context(BuildObjectStoreSnafu)?
        .finish();
    let store = Self::finish_remote_store(raw_store);
    Ok(Self::new_operator_rooted(store, uri))
}
/// Creates a new storage from a URI and object store config.
///
/// The URI scheme selects the backend-specific constructor, which then receives
/// the same `uri` and `storage` arguments.
///
/// # Errors
///
/// Fails if the scheme cannot be determined from `uri`, or if the selected
/// constructor rejects the URI or configuration.
pub fn from_uri(uri: &str, storage: &ObjectStoreConfig) -> Result<Self> {
    let scheme = StorageScheme::from_uri(uri)?;
    // Pick the constructor first, then invoke it once below.
    let constructor: fn(&str, &ObjectStoreConfig) -> Result<Self> = match scheme {
        StorageScheme::File => Self::from_file_uri_with_config,
        StorageScheme::S3 => Self::from_s3_uri,
        StorageScheme::Oss => Self::from_oss_uri,
        StorageScheme::Gcs => Self::from_gcs_uri,
        StorageScheme::Azblob => Self::from_azblob_uri,
    };
    constructor(uri, storage)
}
/// Reads the whole file at `path` and returns its contents as bytes.
///
/// # Errors
///
/// Returns a storage-operation error labelled `read <path>` if the read fails.
async fn read_file(&self, path: &str) -> Result<Vec<u8>> {
    let buffer = self
        .object_store
        .read(path)
        .await
        .context(StorageOperationSnafu {
            operation: format!("read {}", path),
        })?;
    let bytes = buffer.to_vec();
    Ok(bytes)
}
/// Writes `data` to the file at `path`.
///
/// # Errors
///
/// Returns a storage-operation error labelled `write <path>` if the write fails.
async fn write_file(&self, path: &str, data: Vec<u8>) -> Result<()> {
    // The value returned by a successful write is not needed by callers.
    self.object_store
        .write(path, data)
        .await
        .context(StorageOperationSnafu {
            operation: format!("write {}", path),
        })?;
    Ok(())
}
/// Checks if a file exists using stat.
async fn file_exists(&self, path: &str) -> Result<bool> {
match self.object_store.stat(path).await {
Ok(_) => Ok(true),
Err(e) if e.kind() == object_store::ErrorKind::NotFound => Ok(false),
Err(e) => Err(e).context(StorageOperationSnafu {
operation: format!("check exists {}", path),
}),
}
}
/// Test-only helper that loads the schema index back from storage.
///
/// A missing index file yields a snapshot with no schemas; an index that
/// exists but cannot be parsed is an error.
#[cfg(test)]
pub async fn read_schema(&self) -> Result<SchemaSnapshot> {
    let index_path = schema_index_path();
    if !self.file_exists(&index_path).await? {
        return Ok(SchemaSnapshot { schemas: vec![] });
    }
    let raw = self.read_file(&index_path).await?;
    let schemas: Vec<SchemaDefinition> =
        serde_json::from_slice(&raw).context(ManifestParseSnafu)?;
    Ok(SchemaSnapshot { schemas })
}
}
#[async_trait]
impl SnapshotStorage for OpenDalStorage {
    /// A snapshot exists exactly when its manifest file exists.
    async fn exists(&self) -> Result<bool> {
        self.file_exists(MANIFEST_FILE).await
    }

    /// Loads and parses the manifest, failing with "snapshot not found" when
    /// no manifest is present.
    async fn read_manifest(&self) -> Result<Manifest> {
        ensure_snapshot_exists(self).await?;
        let raw = self.read_file(MANIFEST_FILE).await?;
        let manifest: Manifest = serde_json::from_slice(&raw).context(ManifestParseSnafu)?;
        Ok(manifest)
    }

    /// Serializes the manifest as pretty-printed JSON and stores it.
    async fn write_manifest(&self, manifest: &Manifest) -> Result<()> {
        let encoded = serde_json::to_vec_pretty(manifest).context(ManifestSerializeSnafu)?;
        self.write_file(MANIFEST_FILE, encoded).await
    }

    /// Serializes the list of schemas as pretty-printed JSON and stores it at
    /// the schema index path.
    async fn write_schema(&self, schema: &SchemaSnapshot) -> Result<()> {
        let encoded =
            serde_json::to_vec_pretty(&schema.schemas).context(ManifestSerializeSnafu)?;
        let index_path = schema_index_path();
        self.write_file(&index_path, encoded).await
    }

    /// Stores `content` as the bytes of the file at `path`.
    async fn write_text(&self, path: &str, content: &str) -> Result<()> {
        let bytes = content.as_bytes().to_vec();
        self.write_file(path, bytes).await
    }

    /// Reads the file at `path` and decodes it as UTF-8, failing on invalid bytes.
    async fn read_text(&self, path: &str) -> Result<String> {
        let bytes = self.read_file(path).await?;
        String::from_utf8(bytes).context(TextDecodeSnafu)
    }

    /// Removes everything under the object store's root.
    async fn delete_snapshot(&self) -> Result<()> {
        let removal = self.object_store.remove_all("/").await;
        removal.context(StorageOperationSnafu {
            operation: "delete snapshot",
        })
    }
}
// Unit tests for URI validation/parsing and for `OpenDalStorage` backed by a
// temporary local directory.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::path::Path;
use object_store::ObjectStore;
use object_store::services::Fs;
use tempfile::tempdir;
use url::Url;
use super::*;
use crate::data::export_v2::manifest::{DataFormat, TimeRange};
use crate::data::export_v2::schema::SchemaDefinition;
// Builds a storage whose `Fs` operator is rooted at `dir`, bypassing URI parsing.
// The target URI is the directory rendered as a `file://` URL.
fn make_storage_with_rooted_fs(dir: &std::path::Path) -> OpenDalStorage {
let object_store = ObjectStore::new(Fs::default().root(dir.to_str().unwrap()))
.unwrap()
.finish();
OpenDalStorage::new_operator_rooted(
OpenDalStorage::finish_local_store(object_store),
Url::from_directory_path(dir).unwrap().as_ref(),
)
}
// Every supported scheme maps to its `StorageScheme`; `gs://` and `gcs://`
// both map to `Gcs`.
#[test]
fn test_validate_uri_valid() {
assert_eq!(validate_uri("s3://bucket/path").unwrap(), StorageScheme::S3);
assert_eq!(
validate_uri("oss://bucket/path").unwrap(),
StorageScheme::Oss
);
assert_eq!(
validate_uri("gs://bucket/path").unwrap(),
StorageScheme::Gcs
);
assert_eq!(
validate_uri("gcs://bucket/path").unwrap(),
StorageScheme::Gcs
);
assert_eq!(
validate_uri("azblob://container/path").unwrap(),
StorageScheme::Azblob
);
assert_eq!(
validate_uri("file:///tmp/backup").unwrap(),
StorageScheme::File
);
}
#[test]
fn test_validate_uri_invalid() {
// Bare paths should be rejected
assert!(validate_uri("/tmp/backup").is_err());
assert!(validate_uri("./backup").is_err());
assert!(validate_uri("backup").is_err());
// Unknown schemes
assert!(validate_uri("ftp://server/path").is_err());
}
// A remote URI naming only a bucket/container (with or without a trailing
// slash) is rejected: a non-empty root path is required.
#[test]
fn test_extract_remote_location_requires_non_empty_root() {
assert!(extract_remote_location("s3://bucket").is_err());
assert!(extract_remote_location("s3://bucket/").is_err());
assert!(extract_remote_location("oss://bucket").is_err());
assert!(extract_remote_location("gs://bucket").is_err());
assert!(extract_remote_location("azblob://container").is_err());
}
// Unix-only examples: an empty host and an explicit `localhost` host both
// resolve to the same absolute path.
#[cfg(not(windows))]
#[test]
fn test_extract_path_from_uri_unix_examples() {
assert_eq!(
extract_file_path_from_uri("file:///tmp/backup").unwrap(),
"/tmp/backup"
);
assert_eq!(
extract_file_path_from_uri("file://localhost/tmp/backup").unwrap(),
"/tmp/backup"
);
}
// `file://tmp/backup` parses `tmp` as a host, which must be rejected.
#[test]
fn test_extract_file_path_from_uri_rejects_file_host() {
assert!(extract_file_path_from_uri("file://tmp/backup").is_err());
}
// A URL produced from a real directory converts back to that directory.
#[test]
fn test_extract_file_path_from_uri_round_trips_directory_url() {
let dir = tempdir().unwrap();
let uri = Url::from_directory_path(dir.path()).unwrap().to_string();
let path = extract_file_path_from_uri(&uri).unwrap();
assert_eq!(Path::new(&path), dir.path());
}
// Reading a manifest from an empty directory fails, and the error message
// contains the URI the storage was created with.
#[tokio::test]
async fn test_read_manifest_reports_requested_uri() {
let dir = tempdir().unwrap();
let uri = Url::from_directory_path(dir.path()).unwrap().to_string();
let storage = OpenDalStorage::from_file_uri(&uri).unwrap();
let error = storage.read_manifest().await.unwrap_err().to_string();
assert!(error.contains(uri.as_str()));
}
// A manifest written to storage reads back with the same key fields.
#[tokio::test]
async fn test_manifest_round_trip() {
let dir = tempdir().unwrap();
let storage = make_storage_with_rooted_fs(dir.path());
let manifest = Manifest::new_full(
"greptime".to_string(),
vec!["public".to_string()],
TimeRange::unbounded(),
DataFormat::Parquet,
);
storage.write_manifest(&manifest).await.unwrap();
let loaded = storage.read_manifest().await.unwrap();
assert_eq!(loaded.catalog, manifest.catalog);
assert_eq!(loaded.schemas, manifest.schemas);
assert_eq!(loaded.schema_only, manifest.schema_only);
assert_eq!(loaded.format, manifest.format);
assert_eq!(loaded.snapshot_id, manifest.snapshot_id);
}
// A schema snapshot written to storage reads back equal to the original.
#[tokio::test]
async fn test_schema_round_trip() {
let dir = tempdir().unwrap();
let storage = make_storage_with_rooted_fs(dir.path());
let mut snapshot = SchemaSnapshot::new();
snapshot.add_schema(SchemaDefinition {
catalog: "greptime".to_string(),
name: "test_db".to_string(),
options: HashMap::from([("ttl".to_string(), "7d".to_string())]),
});
storage.write_schema(&snapshot).await.unwrap();
let loaded = storage.read_schema().await.unwrap();
assert_eq!(loaded, snapshot);
}
// Text written to a nested relative path reads back unchanged.
#[tokio::test]
async fn test_text_round_trip() {
let dir = tempdir().unwrap();
let storage = make_storage_with_rooted_fs(dir.path());
let content = "CREATE TABLE metrics (ts TIMESTAMP TIME INDEX);";
storage
.write_text("schema/ddl/public.sql", content)
.await
.unwrap();
let loaded = storage.read_text("schema/ddl/public.sql").await.unwrap();
assert_eq!(loaded, content);
}
// Bytes that are not valid UTF-8 make `read_text` fail with an error that
// mentions "UTF-8".
#[tokio::test]
async fn test_read_text_rejects_invalid_utf8() {
let dir = tempdir().unwrap();
let storage = make_storage_with_rooted_fs(dir.path());
storage
.write_file("schema/ddl/public.sql", vec![0xff, 0xfe, 0xfd])
.await
.unwrap();
let error = storage
.read_text("schema/ddl/public.sql")
.await
.unwrap_err();
assert!(error.to_string().contains("UTF-8"));
}
// `exists` is false before a manifest is written and true afterwards.
#[tokio::test]
async fn test_exists_follows_manifest_presence() {
let dir = tempdir().unwrap();
let storage = make_storage_with_rooted_fs(dir.path());
assert!(!storage.exists().await.unwrap());
storage
.write_manifest(&Manifest::new_schema_only(
"greptime".to_string(),
vec!["public".to_string()],
))
.await
.unwrap();
assert!(storage.exists().await.unwrap());
}
// Deleting a snapshot removes files under its own root but leaves a sibling
// directory of the same parent untouched.
#[tokio::test]
async fn test_delete_snapshot_only_removes_rooted_contents() {
let parent = tempdir().unwrap();
let snapshot_root = parent.path().join("snapshot");
let sibling = parent.path().join("sibling");
std::fs::create_dir_all(&snapshot_root).unwrap();
std::fs::create_dir_all(&sibling).unwrap();
std::fs::write(snapshot_root.join("manifest.json"), b"{}").unwrap();
std::fs::write(sibling.join("keep.txt"), b"keep").unwrap();
let storage = make_storage_with_rooted_fs(&snapshot_root);
storage.delete_snapshot().await.unwrap();
assert!(!snapshot_root.join("manifest.json").exists());
assert!(sibling.join("keep.txt").exists());
}
}

40
src/cli/src/data/sql.rs Normal file
View File

@@ -0,0 +1,40 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Shared SQL escaping helpers for CLI-generated statements.
/// Escapes `value` for use inside a single-quoted SQL string literal.
///
/// Every single quote (`'`) is doubled (`''`); all other characters are copied
/// unchanged. The surrounding quotes are not added.
pub(crate) fn escape_sql_literal(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\'' {
            // Emit the extra quote that makes this one literal.
            escaped.push('\'');
        }
        escaped.push(ch);
    }
    escaped
}
/// Escapes `value` for use inside a double-quoted SQL identifier.
///
/// Every double quote (`"`) is doubled (`""`); all other characters are copied
/// unchanged. The surrounding quotes are not added.
pub(crate) fn escape_sql_identifier(value: &str) -> String {
    // Splitting on the quote and re-joining with a doubled quote replaces each
    // occurrence exactly once.
    let pieces: Vec<&str> = value.split('"').collect();
    pieces.join("\"\"")
}
// Unit tests for the SQL escaping helpers.
#[cfg(test)]
mod tests {
use super::*;
// Input without quotes is unchanged; a single quote is doubled.
#[test]
fn test_escape_sql_literal_escapes_single_quotes() {
assert_eq!(escape_sql_literal("test_db"), "test_db");
assert_eq!(escape_sql_literal("te'st"), "te''st");
}
// Input without quotes is unchanged; a double quote is doubled.
#[test]
fn test_escape_sql_identifier_escapes_double_quotes() {
assert_eq!(escape_sql_identifier("test_db"), "test_db");
assert_eq!(escape_sql_identifier(r#"te"st"#), r#"te""st"#);
}
}

View File

@@ -36,6 +36,7 @@ pub struct DatabaseClient {
auth_header: Option<String>,
timeout: Duration,
proxy: Option<reqwest::Proxy>,
no_proxy: bool,
}
pub fn parse_proxy_opts(
@@ -61,6 +62,7 @@ impl DatabaseClient {
auth_basic: Option<String>,
timeout: Duration,
proxy: Option<reqwest::Proxy>,
no_proxy: bool,
) -> Self {
let auth_header = if let Some(basic) = auth_basic {
let encoded = general_purpose::STANDARD.encode(basic);
@@ -69,7 +71,9 @@ impl DatabaseClient {
None
};
if let Some(ref proxy) = proxy {
if no_proxy {
common_telemetry::info!("Proxy disabled");
} else if let Some(ref proxy) = proxy {
common_telemetry::info!("Using proxy: {:?}", proxy);
} else {
common_telemetry::info!("Using system proxy(if any)");
@@ -81,6 +85,7 @@ impl DatabaseClient {
auth_header,
timeout,
proxy,
no_proxy,
}
}
@@ -95,12 +100,14 @@ impl DatabaseClient {
("db", format!("{}-{}", self.catalog, schema)),
("sql", sql.to_string()),
];
let client = self
.proxy
.clone()
.map(|proxy| reqwest::Client::builder().proxy(proxy).build())
.unwrap_or_else(|| Ok(reqwest::Client::new()))
.context(BuildClientSnafu)?;
let mut builder = reqwest::Client::builder();
if let Some(proxy) = self.proxy.clone() {
builder = builder.proxy(proxy);
}
if self.no_proxy {
builder = builder.no_proxy();
}
let client = builder.build().context(BuildClientSnafu)?;
let mut request = client
.post(&url)
.form(&params)

View File

@@ -29,7 +29,7 @@ pub use database::DatabaseClient;
use error::Result;
pub use crate::bench::BenchTableMetadataCommand;
pub use crate::data::DataCommand;
pub use crate::data::{DataCommand, export_v2, import_v2};
pub use crate::metadata::MetadataCommand;
#[async_trait]

View File

@@ -20,13 +20,14 @@ use clap::Parser;
use colored::Colorize;
use datanode::config::RegionEngineConfig;
use datanode::store;
use either::Either;
use futures::stream;
use mito2::access_layer::{
AccessLayer, AccessLayerRef, Metrics, OperationType, SstWriteRequest, WriteType,
};
use mito2::cache::{CacheManager, CacheManagerRef};
use mito2::config::{FulltextIndexConfig, MitoConfig, Mode};
use mito2::read::Source;
use mito2::read::FlatSource;
use mito2::sst::FormatType;
use mito2::sst::file::{FileHandle, FileMeta};
use mito2::sst::file_purger::{FilePurger, FilePurgerRef};
use mito2::sst::index::intermediate::IntermediateManager;
@@ -210,6 +211,7 @@ impl ObjbenchCommand {
object_store.clone(),
)
.expected_metadata(Some(region_meta.clone()))
.flat_format(true)
.build()
.await
.map_err(|e| {
@@ -231,6 +233,10 @@ impl ObjbenchCommand {
let reader_build_elapsed = reader_build_start.elapsed();
let total_rows = reader.parquet_metadata().file_metadata().num_rows();
println!("{} Reader built in {:?}", "".green(), reader_build_elapsed);
let reader_stream = Box::pin(stream::try_unfold(reader, |mut reader| async move {
let batch = reader.next_record_batch().await?;
Ok(batch.map(|batch| (batch, reader)))
}));
// Build write request
let fulltext_index_config = FulltextIndexConfig {
@@ -241,10 +247,11 @@ impl ObjbenchCommand {
let write_req = SstWriteRequest {
op_type: OperationType::Flush,
metadata: region_meta,
source: Either::Left(Source::Reader(Box::new(reader))),
source: FlatSource::Stream(reader_stream),
cache_manager,
storage: None,
max_sequence: None,
sst_write_format: FormatType::PrimaryKey,
index_options: Default::default(),
index_config: mito_engine_config.index.clone(),
inverted_index_config: MitoConfig::default().inverted_index,

View File

@@ -32,14 +32,15 @@ use common_meta::cache::LayeredCacheRegistryBuilder;
use common_meta::ddl::flow_meta::FlowMetadataAllocator;
use common_meta::ddl::table_meta::TableMetadataAllocator;
use common_meta::ddl::{DdlContext, NoopRegionFailureDetectorControl};
use common_meta::ddl_manager::{DdlManager, DdlManagerConfiguratorRef};
use common_meta::ddl_manager::{DdlManager, DdlManagerConfiguratorRef, DdlManagerRef};
use common_meta::key::flow::FlowMetadataManager;
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
use common_meta::kv_backend::KvBackendRef;
use common_meta::procedure_executor::LocalProcedureExecutor;
use common_meta::node_manager::{FlownodeRef, NodeManagerRef};
use common_meta::procedure_executor::{LocalProcedureExecutor, ProcedureExecutorRef};
use common_meta::region_keeper::MemoryRegionKeeper;
use common_meta::region_registry::LeaderRegionRegistry;
use common_meta::sequence::SequenceBuilder;
use common_meta::sequence::{Sequence, SequenceBuilder};
use common_meta::wal_provider::{WalProviderRef, build_wal_provider};
use common_procedure::ProcedureManagerRef;
use common_query::prelude::set_default_prefix;
@@ -49,6 +50,7 @@ use common_time::timezone::set_default_timezone;
use common_version::{short_version, verbose_version};
use datanode::config::DatanodeOptions;
use datanode::datanode::{Datanode, DatanodeBuilder};
use datanode::region_server::RegionServer;
use flow::{
FlownodeBuilder, FlownodeInstance, FlownodeOptions, FrontendClient, FrontendInvoker,
GrpcQueryHandlerWithBoxedError,
@@ -58,6 +60,7 @@ use frontend::instance::StandaloneDatanodeManager;
use frontend::instance::builder::FrontendBuilder;
use frontend::server::Services;
use meta_srv::metasrv::{FLOW_ID_SEQ, TABLE_ID_SEQ};
use plugins::PluginOptions;
use plugins::frontend::context::{
CatalogManagerConfigureContext, StandaloneCatalogManagerConfigureContext,
};
@@ -130,6 +133,18 @@ impl Instance {
pub fn server_addr(&self, name: &str) -> Option<SocketAddr> {
self.frontend.server_handlers().addr(name)
}
/// Returns a mutable reference to the Frontend component of this Standalone
/// instance, so that it can be modified externally. Callers may live outside
/// this code base, so do not delete this function.
pub fn mut_frontend(&mut self) -> &mut Frontend {
&mut self.frontend
}
/// Returns the Datanode component of this Standalone instance for external
/// use. Callers may live outside this code base, so do not delete this
/// function.
pub fn datanode(&self) -> &Datanode {
&self.datanode
}
}
#[async_trait]
@@ -342,9 +357,18 @@ impl StartCommand {
info!("Standalone start command: {:#?}", self);
info!("Standalone options: {opts:#?}");
let (mut instance, _) =
Self::build_with(opts.component, opts.plugins, InstanceCreator::default()).await?;
instance._guard.extend(guard);
Ok(instance)
}
pub async fn build_with(
mut opts: StandaloneOptions,
plugin_opts: Vec<PluginOptions>,
creator: InstanceCreator,
) -> Result<(Instance, InstanceCreatorResult)> {
let mut plugins = Plugins::new();
let plugin_opts = opts.plugins;
let mut opts = opts.component;
set_default_prefix(opts.default_column_prefix.as_deref())
.map_err(BoxedError::new)
.context(error::BuildCliSnafu)?;
@@ -462,17 +486,16 @@ impl StartCommand {
.await;
}
let node_manager = Arc::new(StandaloneDatanodeManager {
region_server: datanode.region_server(),
flow_server: flownode.flow_engine(),
});
let node_manager = creator
.node_manager_creator
.create(
&kv_backend,
datanode.region_server(),
flownode.flow_engine(),
)
.await?;
let table_id_allocator = Arc::new(
SequenceBuilder::new(TABLE_ID_SEQ, kv_backend.clone())
.initial(MIN_USER_TABLE_ID as u64)
.step(10)
.build(),
);
let table_id_allocator = creator.table_id_allocator_creator.create(&kv_backend);
let flow_id_sequence = Arc::new(
SequenceBuilder::new(FLOW_ID_SEQ, kv_backend.clone())
.initial(MIN_USER_FLOW_ID as u64)
@@ -489,7 +512,7 @@ impl StartCommand {
.context(error::BuildWalProviderSnafu)?;
let wal_provider = Arc::new(wal_provider);
let table_metadata_allocator = Arc::new(TableMetadataAllocator::new(
table_id_allocator,
table_id_allocator.clone(),
wal_provider.clone(),
));
let flow_metadata_allocator = Arc::new(FlowMetadataAllocator::with_noop_peer_allocator(
@@ -532,10 +555,10 @@ impl StartCommand {
ddl_manager
};
let procedure_executor = Arc::new(LocalProcedureExecutor::new(
Arc::new(ddl_manager),
procedure_manager.clone(),
));
let procedure_executor = creator
.procedure_executor_creator
.create(Arc::new(ddl_manager), procedure_manager.clone())
.await?;
let fe_instance = FrontendBuilder::new(
fe_opts.clone(),
@@ -568,7 +591,7 @@ impl StartCommand {
kv_backend.clone(),
layered_cache_registry.clone(),
procedure_executor,
node_manager,
node_manager.clone(),
)
.await
.context(StartFlownodeSnafu)?;
@@ -584,14 +607,20 @@ impl StartCommand {
heartbeat_task: None,
};
Ok(Instance {
let instance = Instance {
datanode,
frontend,
flownode,
procedure_manager,
wal_provider,
_guard: guard,
})
_guard: vec![],
};
let result = InstanceCreatorResult {
kv_backend,
node_manager,
table_id_allocator,
};
Ok((instance, result))
}
pub async fn create_table_metadata_manager(
@@ -608,6 +637,115 @@ impl StartCommand {
}
}
/// Factory for the node manager used when building a standalone instance.
#[async_trait]
pub trait NodeManagerCreator {
    /// Creates the node manager from the key-value backend and the region and
    /// flow servers of the instance being built.
    async fn create(
        &self,
        kv_backend: &KvBackendRef,
        region_server: RegionServer,
        flow_server: FlownodeRef,
    ) -> Result<NodeManagerRef>;
}

/// Default [`NodeManagerCreator`]: wraps the given servers in a
/// `StandaloneDatanodeManager` and ignores the key-value backend.
pub struct DefaultNodeManagerCreator;

#[async_trait]
impl NodeManagerCreator for DefaultNodeManagerCreator {
    async fn create(
        &self,
        _: &KvBackendRef,
        region_server: RegionServer,
        flow_server: FlownodeRef,
    ) -> Result<NodeManagerRef> {
        let manager = StandaloneDatanodeManager {
            region_server,
            flow_server,
        };
        Ok(Arc::new(manager))
    }
}
/// Factory for the sequence that allocates table ids in a standalone instance.
pub trait TableIdAllocatorCreator {
    /// Creates the table id sequence on top of `kv_backend`.
    fn create(&self, kv_backend: &KvBackendRef) -> Arc<Sequence>;
}

/// Default [`TableIdAllocatorCreator`]: a sequence named `TABLE_ID_SEQ` that
/// starts at `MIN_USER_TABLE_ID` and uses a step of 10.
///
/// Public like the other default creators, so that callers of
/// `InstanceCreator::new` can customize one creator while reusing this one.
pub struct DefaultTableIdAllocatorCreator;

impl TableIdAllocatorCreator for DefaultTableIdAllocatorCreator {
    fn create(&self, kv_backend: &KvBackendRef) -> Arc<Sequence> {
        Arc::new(
            SequenceBuilder::new(TABLE_ID_SEQ, kv_backend.clone())
                .initial(MIN_USER_TABLE_ID as u64)
                .step(10)
                .build(),
        )
    }
}
/// Factory for the procedure executor used when building a standalone instance.
#[async_trait]
pub trait ProcedureExecutorCreator {
    /// Creates the procedure executor from the DDL manager and the procedure manager.
    async fn create(
        &self,
        ddl_manager: DdlManagerRef,
        procedure_manager: ProcedureManagerRef,
    ) -> Result<ProcedureExecutorRef>;
}

/// Default [`ProcedureExecutorCreator`]: builds a `LocalProcedureExecutor`
/// from the two managers.
pub struct DefaultProcedureExecutorCreator;

#[async_trait]
impl ProcedureExecutorCreator for DefaultProcedureExecutorCreator {
    async fn create(
        &self,
        ddl_manager: DdlManagerRef,
        procedure_manager: ProcedureManagerRef,
    ) -> Result<ProcedureExecutorRef> {
        let executor = LocalProcedureExecutor::new(ddl_manager, procedure_manager);
        Ok(Arc::new(executor))
    }
}
/// `InstanceCreator` bundles the component creators used to build the
/// Standalone instance, making it possible to customize how the instance is
/// built.
pub struct InstanceCreator {
    node_manager_creator: Box<dyn NodeManagerCreator>,
    table_id_allocator_creator: Box<dyn TableIdAllocatorCreator>,
    procedure_executor_creator: Box<dyn ProcedureExecutorCreator>,
}

impl InstanceCreator {
    /// Assembles an `InstanceCreator` from the three component creators.
    pub fn new(
        node_manager_creator: Box<dyn NodeManagerCreator>,
        table_id_allocator_creator: Box<dyn TableIdAllocatorCreator>,
        procedure_executor_creator: Box<dyn ProcedureExecutorCreator>,
    ) -> Self {
        Self {
            node_manager_creator,
            table_id_allocator_creator,
            procedure_executor_creator,
        }
    }
}

impl Default for InstanceCreator {
    /// Uses the default creator for every component.
    fn default() -> Self {
        Self::new(
            Box::new(DefaultNodeManagerCreator),
            Box::new(DefaultTableIdAllocatorCreator),
            Box::new(DefaultProcedureExecutorCreator),
        )
    }
}
/// `InstanceCreatorResult` is expected to be used paired with [InstanceCreator].
/// It stores the created and other important components for further reusing.
pub struct InstanceCreatorResult {
// Key-value backend used while building the instance.
pub kv_backend: KvBackendRef,
// Node manager produced by the `NodeManagerCreator`.
pub node_manager: NodeManagerRef,
// Table id sequence produced by the `TableIdAllocatorCreator`.
pub table_id_allocator: Arc<Sequence>,
}
#[cfg(test)]
mod tests {
use std::default::Default;

View File

@@ -53,7 +53,7 @@ pub trait Configurable: Serialize + DeserializeOwned + Default + Sized {
env.try_parsing(true)
.separator(ENV_VAR_SEP)
.ignore_empty(true)
.ignore_empty(false)
};
// Workaround: Replacement for `Config::try_from(&default_opts)` due to
@@ -237,4 +237,31 @@ mod tests {
},
);
}
// Minimal config with two optional string fields, used to observe how an
// empty environment variable is loaded.
#[derive(Debug, Serialize, Deserialize, Default)]
struct SimpleConfig {
name: Option<String>,
prefix: Option<String>,
}
// Relies entirely on the trait's provided methods.
impl Configurable for SimpleConfig {}
// Sets only the `PREFIX` variable (to an empty string) under the test prefix
// and checks how both fields load.
#[test]
fn test_empty_env_var_is_not_ignored() {
let env_prefix = "SIMPLE_CFG_UT";
temp_env::with_vars(
[(
[env_prefix.to_string(), "PREFIX".to_string()].join(ENV_VAR_SEP),
Some(""),
)],
|| {
let opts = SimpleConfig::load_layered_options(None, env_prefix).unwrap();
// With ignore_empty(false), an empty env var should yield Some("")
// rather than None (which was the previous behavior with ignore_empty(true)).
assert_eq!(opts.prefix, Some("".to_string()));
// Unset env var should remain None.
assert_eq!(opts.name, None);
},
);
}
}

View File

@@ -25,7 +25,7 @@
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use arrow::array::StructArray;
use arrow::array::{ArrayRef, BooleanArray, StructArray};
use arrow_schema::{FieldRef, Fields};
use common_telemetry::debug;
use datafusion::functions_aggregate::all_default_aggregate_functions;
@@ -38,8 +38,8 @@ use datafusion_common::{Column, ScalarValue};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::{
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, Expr, ExprSchemable, LogicalPlan,
Signature,
Accumulator, Aggregate, AggregateUDF, AggregateUDFImpl, EmitTo, Expr, ExprSchemable,
GroupsAccumulator, LogicalPlan, Signature,
};
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
use datatypes::arrow::datatypes::{DataType, Field};
@@ -322,6 +322,14 @@ impl StateWrapper {
);
})
}
/// Rewrites accumulator args so they suit the wrapped (inner) aggregate function.
///
/// The args arrive with the state wrapper's return field; this replaces it
/// with the return field deduced for the inner aggregate and leaves every
/// other field as given.
///
/// # Errors
///
/// Propagates any error from deducing the inner aggregate's return type.
fn fix_inner_acc_args<'b>(
    &self,
    mut acc_args: datafusion_expr::function::AccumulatorArgs<'b>,
) -> datafusion_common::Result<datafusion_expr::function::AccumulatorArgs<'b>> {
    let inner_return_field = self.deduce_aggr_return_type(&acc_args)?;
    acc_args.return_field = inner_return_field;
    Ok(acc_args)
}
}
impl AggregateUDFImpl for StateWrapper {
@@ -331,15 +339,32 @@ impl AggregateUDFImpl for StateWrapper {
) -> datafusion_common::Result<Box<dyn Accumulator>> {
// fix and recover proper acc args for the original aggregate function.
let state_type = acc_args.return_type().clone();
let inner = {
let mut new_acc_args = acc_args.clone();
new_acc_args.return_field = self.deduce_aggr_return_type(&acc_args)?;
self.inner.accumulator(new_acc_args)?
};
let inner = self.inner.accumulator(self.fix_inner_acc_args(acc_args)?)?;
Ok(Box::new(StateAccum::new(inner, state_type)?))
}
/// Reports whether the wrapped aggregate supports a groups accumulator for
/// these args.
///
/// The args are first adjusted for the inner aggregate; if that adjustment
/// fails, the answer is `false`.
fn groups_accumulator_supported(
    &self,
    acc_args: datafusion_expr::function::AccumulatorArgs,
) -> bool {
    match self.fix_inner_acc_args(acc_args) {
        Ok(inner_args) => self.inner.inner().groups_accumulator_supported(inner_args),
        Err(_) => false,
    }
}
/// Creates a groups accumulator that drives the wrapped aggregate and exposes
/// its state as this wrapper's output.
///
/// # Errors
///
/// Fails if the args cannot be adjusted for the inner aggregate, if the inner
/// aggregate cannot create its groups accumulator, or if the wrapper's return
/// type is not a struct.
fn create_groups_accumulator(
    &self,
    acc_args: datafusion_expr::function::AccumulatorArgs,
) -> datafusion_common::Result<Box<dyn GroupsAccumulator>> {
    // Capture the wrapper's own return type before the args are rewritten.
    let state_type = acc_args.return_type().clone();
    let inner_args = self.fix_inner_acc_args(acc_args)?;
    let inner_accum = self.inner.inner().create_groups_accumulator(inner_args)?;
    let wrapped = StateGroupsAccum::new(inner_accum, state_type)?;
    Ok(Box::new(wrapped))
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
@@ -462,6 +487,118 @@ pub struct StateAccum {
state_fields: Fields,
}
/// Groups accumulator whose evaluated result is the inner accumulator's state
/// packed into a struct array.
pub struct StateGroupsAccum {
    inner: Box<dyn GroupsAccumulator>,
    state_fields: Fields,
}

impl StateGroupsAccum {
    /// Wraps `inner`, taking the output struct's fields from `state_type`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `state_type` is not a struct type.
    fn new(
        inner: Box<dyn GroupsAccumulator>,
        state_type: DataType,
    ) -> datafusion_common::Result<Self> {
        match state_type {
            DataType::Struct(state_fields) => Ok(Self {
                inner,
                state_fields,
            }),
            other => Err(datafusion_common::DataFusionError::Internal(format!(
                "Expected a struct type for state, got: {:?}",
                other
            ))),
        }
    }

    /// Packs the inner accumulator's state arrays into one struct array.
    ///
    /// When the array types match the expected state fields, those fields are
    /// used. Otherwise the mismatch is logged at debug level and placeholder
    /// fields named `col_<index>[mismatch_state]` are built from the arrays' own
    /// types.
    fn wrap_state_arrays(&self, arrays: Vec<ArrayRef>) -> datafusion_common::Result<ArrayRef> {
        let actual_types = arrays
            .iter()
            .map(|array| array.data_type().clone())
            .collect::<Vec<_>>();
        // Compares both element types and lengths.
        let types_match = actual_types
            .iter()
            .eq(self.state_fields.iter().map(|field| field.data_type()));

        let fields = if types_match {
            self.state_fields.clone()
        } else {
            debug!(
                "State mismatch, expected: {}, got: {} for expected fields: {:?} and given array types: {:?}",
                self.state_fields.len(),
                arrays.len(),
                self.state_fields,
                actual_types,
            );
            arrays
                .iter()
                .enumerate()
                .map(|(index, array)| {
                    Field::new(
                        format!("col_{index}[mismatch_state]").as_str(),
                        array.data_type().clone(),
                        true,
                    )
                })
                .collect::<Fields>()
        };

        let wrapped = StructArray::try_new(fields, arrays, None)?;
        Ok(Arc::new(wrapped))
    }
}
// Everything is forwarded to the inner accumulator except `evaluate`, which
// returns the inner state packed into a struct array.
impl GroupsAccumulator for StateGroupsAccum {
    /// Forwards the batch to the inner accumulator unchanged.
    fn update_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .update_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Forwards the batch to the inner accumulator's merge unchanged.
    fn merge_batch(
        &mut self,
        values: &[ArrayRef],
        group_indices: &[usize],
        opt_filter: Option<&BooleanArray>,
        total_num_groups: usize,
    ) -> datafusion_common::Result<()> {
        self.inner
            .merge_batch(values, group_indices, opt_filter, total_num_groups)
    }

    /// Emits the inner accumulator's *state* (not its evaluated value) as a
    /// single struct array.
    fn evaluate(&mut self, emit_to: EmitTo) -> datafusion_common::Result<ArrayRef> {
        self.inner
            .state(emit_to)
            .and_then(|state_arrays| self.wrap_state_arrays(state_arrays))
    }

    /// Returns the inner accumulator's state arrays as-is.
    fn state(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.state(emit_to)
    }

    /// Delegates to the inner accumulator.
    fn convert_to_state(
        &self,
        values: &[ArrayRef],
        opt_filter: Option<&BooleanArray>,
    ) -> datafusion_common::Result<Vec<ArrayRef>> {
        self.inner.convert_to_state(values, opt_filter)
    }

    /// Delegates to the inner accumulator.
    fn supports_convert_to_state(&self) -> bool {
        self.inner.supports_convert_to_state()
    }

    /// Reports the inner accumulator's size; this wrapper's own fields are not
    /// counted.
    fn size(&self) -> usize {
        self.inner.size()
    }
}
impl StateAccum {
pub fn new(
inner: Box<dyn Accumulator>,

View File

@@ -40,10 +40,13 @@ use datafusion_common::arrow::array::AsArray;
use datafusion_common::arrow::datatypes::{Float64Type, UInt64Type};
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::{AggregateFunction, NullTreatment};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Aggregate, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr, TableScan, lit,
Aggregate, AggregateUDFImpl, ColumnarValue, Expr, LogicalPlan, ScalarFunctionArgs, SortExpr,
TableScan, lit,
};
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::col;
use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
use datatypes::arrow_array::StringArray;
use futures::{Stream, StreamExt as _};
@@ -256,6 +259,38 @@ fn dummy_table_scan_with_ts() -> LogicalPlan {
)
}
// Test fixture: builds the groups accumulator of a `StateWrapper` around `avg`
// for a single nullable Float64 column named "number". Panics if the wrapper
// does not report groups-accumulator support for these args.
fn create_avg_state_groups_accumulator() -> Box<dyn GroupsAccumulator> {
let state_wrapper = StateWrapper::new((*avg_udaf()).clone()).unwrap();
let schema = Arc::new(arrow_schema::Schema::new(vec![Field::new(
"number",
DataType::Float64,
true,
)]));
let expr = col("number", &schema).unwrap();
let expr_field = expr.return_field(&schema).unwrap();
// The return field carries the wrapper's own return type for a Float64 input.
let return_field = Arc::new(Field::new(
"__avg_state(number)",
state_wrapper.return_type(&[DataType::Float64]).unwrap(),
true,
));
let exprs = [expr];
let expr_fields = [expr_field];
let acc_args = AccumulatorArgs {
return_field,
schema: &schema,
ignore_nulls: false,
order_bys: &[],
is_reversed: false,
name: "__avg_state(number)",
is_distinct: false,
exprs: &exprs,
expr_fields: &expr_fields,
};
assert!(state_wrapper.groups_accumulator_supported(acc_args.clone()));
state_wrapper.create_groups_accumulator(acc_args).unwrap()
}
#[tokio::test]
async fn test_sum_udaf() {
let ctx = SessionContext::new();
@@ -796,6 +831,95 @@ async fn test_last_value_order_by_udaf() {
assert_eq!(merge_eval_res, ScalarValue::Int64(Some(4)));
}
/// Evaluating the avg state accumulator yields a struct of per-group
/// `(count, sum)` columns, with null inputs ignored.
#[test]
fn test_avg_state_groups_accumulator_evaluate() {
    let mut accumulator = create_avg_state_groups_accumulator();

    // Row i goes to group `groups[i]`; the null row (index 2) targets group 0.
    let input: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(1.0),
        Some(2.0),
        None,
        Some(3.0),
        Some(4.0),
        Some(5.0),
    ]));
    let groups: Vec<usize> = vec![0, 1, 0, 0, 1, 2];
    accumulator
        .update_batch(&[input], &groups, None, 3)
        .unwrap();

    let evaluated = accumulator.evaluate(EmitTo::All).unwrap();
    let state = evaluated.as_any().downcast_ref::<StructArray>().unwrap();

    // Group 0: {1, 3}, group 1: {2, 4}, group 2: {5}.
    let counts = state
        .column(0)
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(counts, &UInt64Array::from(vec![2, 2, 1]));

    let sums = state
        .column(1)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert_eq!(sums, &Float64Array::from(vec![4.0, 6.0, 5.0]));
}
/// State emitted by one accumulator can be merged into another one under a
/// different group mapping, and the combined `(count, sum)` is then evaluated.
#[test]
fn test_avg_state_groups_accumulator_state_merge_evaluate() {
    // Source accumulator: group 0 = {1, 3}, group 1 = {2, 4}, group 2 = {5}.
    let mut source = create_avg_state_groups_accumulator();
    let source_input: ArrayRef = Arc::new(Float64Array::from(vec![
        Some(1.0),
        Some(2.0),
        None,
        Some(3.0),
        Some(4.0),
        Some(5.0),
    ]));
    let source_groups: Vec<usize> = vec![0, 1, 0, 0, 1, 2];
    source
        .update_batch(&[source_input], &source_groups, None, 3)
        .unwrap();
    let source_state = source.state(EmitTo::All).unwrap();

    // Target accumulator starts with one value per group.
    let mut target = create_avg_state_groups_accumulator();
    let target_input: ArrayRef =
        Arc::new(Float64Array::from(vec![Some(10.0), Some(20.0), Some(30.0)]));
    let target_groups: Vec<usize> = vec![0, 1, 2];
    target
        .update_batch(&[target_input], &target_groups, None, 3)
        .unwrap();

    // Source rows 0, 1, 2 are merged into target groups 1, 2, 0 respectively.
    let remapped_groups: Vec<usize> = vec![1, 2, 0];
    target
        .merge_batch(&source_state, &remapped_groups, None, 3)
        .unwrap();

    let evaluated = target.evaluate(EmitTo::All).unwrap();
    let state = evaluated.as_any().downcast_ref::<StructArray>().unwrap();

    // Counts: 1 + 1, 1 + 2, 1 + 2.
    let counts = state
        .column(0)
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(counts, &UInt64Array::from(vec![2, 3, 3]));

    // Sums: 10 + 5, 20 + 4, 30 + 6.
    let sums = state
        .column(1)
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert_eq!(sums, &Float64Array::from(vec![15.0, 24.0, 36.0]));
}
/// Tests whether the UDAF state fields are implemented correctly,
/// especially for our own custom UDAFs' state fields,
/// by comparing the evaluation results before and after splitting into state/merge functions.

View File

@@ -19,6 +19,7 @@ use datafusion_common::DataFusionError;
use datafusion_common::arrow::array::{Array, AsArray, StringViewBuilder};
use datafusion_common::arrow::datatypes::DataType;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, Signature, Volatility};
use datatypes::types::jsonb_to_string;
use crate::function::{Function, extract_args};
@@ -74,7 +75,7 @@ impl Function for JsonToStringFunction {
for i in 0..size {
let json = jsons.is_valid(i).then(|| jsons.value(i));
let result = json
.map(|json| jsonb::from_slice(json).map(|x| x.to_string()))
.map(jsonb_to_string)
.transpose()
.map_err(|e| DataFusionError::Execution(format!("invalid json binary: {e}")))?;

View File

@@ -10,7 +10,6 @@ workspace = true
[dependencies]
common-error = { workspace = true }
common-macro = { workspace = true }
common-telemetry = { workspace = true }
humantime = { workspace = true }
serde = { workspace = true }
snafu = { workspace = true }

View File

@@ -14,14 +14,13 @@
use std::{fmt, mem};
use common_telemetry::debug;
use snafu::ensure;
use tokio::sync::{OwnedSemaphorePermit, TryAcquireError};
use crate::error::{
MemoryAcquireTimeoutSnafu, MemoryLimitExceededSnafu, MemorySemaphoreClosedSnafu, Result,
};
use crate::manager::{MemoryMetrics, MemoryQuota};
use crate::manager::{MemoryMetrics, MemoryQuota, UnlimitedMemoryQuota};
use crate::policy::OnExhaustedPolicy;
/// Guard representing a slice of reserved memory.
@@ -30,31 +29,57 @@ pub struct MemoryGuard<M: MemoryMetrics> {
}
pub(crate) enum GuardState<M: MemoryMetrics> {
Unlimited,
Released,
Unlimited {
quota: UnlimitedMemoryQuota<M>,
granted_bytes: u64,
},
Limited {
permit: OwnedSemaphorePermit,
quota: MemoryQuota<M>,
permit: OwnedSemaphorePermit,
},
}
impl<M: MemoryMetrics> GuardState<M> {
    /// Consumes the state and gives whatever it still holds back to its quota.
    ///
    /// - `Limited`: the permit is passed to the quota's `release_permit`.
    /// - `Unlimited`: the granted byte count is subtracted from the quota's usage.
    /// - `Released`: nothing is held, so nothing happens.
    fn release(self) {
        match self {
            Self::Limited { quota, permit } => quota.release_permit(permit),
            Self::Unlimited {
                quota,
                granted_bytes,
            } => quota.sub_in_use(granted_bytes),
            Self::Released => {}
        }
    }
}
impl<M: MemoryMetrics> MemoryGuard<M> {
pub(crate) fn unlimited() -> Self {
pub(crate) fn unlimited(quota: UnlimitedMemoryQuota<M>, bytes: u64) -> Self {
quota.add_in_use(bytes);
Self {
state: GuardState::Unlimited,
state: GuardState::Unlimited {
quota,
granted_bytes: bytes,
},
}
}
pub(crate) fn limited(permit: OwnedSemaphorePermit, quota: MemoryQuota<M>) -> Self {
pub(crate) fn limited(quota: MemoryQuota<M>, permit: OwnedSemaphorePermit) -> Self {
Self {
state: GuardState::Limited { permit, quota },
state: GuardState::Limited { quota, permit },
}
}
/// Returns granted quota in bytes.
pub fn granted_bytes(&self) -> u64 {
match &self.state {
GuardState::Unlimited => 0,
GuardState::Limited { permit, quota } => {
GuardState::Released => 0,
GuardState::Unlimited { granted_bytes, .. } => *granted_bytes,
GuardState::Limited { quota, permit } => {
quota.permits_to_bytes(permit.num_permits() as u32)
}
}
@@ -68,13 +93,24 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
/// - Returns error if requested bytes would exceed the manager's total limit
/// - Returns error if the semaphore is unexpectedly closed
pub async fn acquire_additional(&mut self, bytes: u64) -> Result<()> {
match &mut self.state {
GuardState::Unlimited => Ok(()),
GuardState::Limited { permit, quota } => {
if bytes == 0 {
return Ok(());
}
if bytes == 0 {
return Ok(());
}
match &mut self.state {
GuardState::Released => {
debug_assert!(false, "released memory guard state should not be reused");
Ok(())
}
GuardState::Unlimited {
quota,
granted_bytes,
} => {
quota.add_in_use(bytes);
*granted_bytes = granted_bytes.saturating_add(bytes);
Ok(())
}
GuardState::Limited { quota, permit } => {
let additional_permits = quota.bytes_to_permits(bytes);
let current_permits = permit.num_permits() as u32;
@@ -95,7 +131,6 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
permit.merge(additional_permit);
quota.update_in_use_metric();
debug!("Acquired additional {} bytes", bytes);
Ok(())
}
}
@@ -106,13 +141,24 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
/// On success, merges the new memory into this guard and returns true.
/// On failure, returns false and leaves this guard unchanged.
pub fn try_acquire_additional(&mut self, bytes: u64) -> bool {
match &mut self.state {
GuardState::Unlimited => true,
GuardState::Limited { permit, quota } => {
if bytes == 0 {
return true;
}
if bytes == 0 {
return true;
}
match &mut self.state {
GuardState::Released => {
debug_assert!(false, "released memory guard state should not be reused");
false
}
GuardState::Unlimited {
quota,
granted_bytes,
} => {
quota.add_in_use(bytes);
*granted_bytes = granted_bytes.saturating_add(bytes);
true
}
GuardState::Limited { quota, permit } => {
let additional_permits = quota.bytes_to_permits(bytes);
match quota
@@ -123,7 +169,6 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
Ok(additional_permit) => {
permit.merge(additional_permit);
quota.update_in_use_metric();
debug!("Acquired additional {} bytes", bytes);
true
}
Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => {
@@ -168,7 +213,8 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
MemoryLimitExceededSnafu {
requested_bytes: bytes,
limit_bytes: match &self.state {
GuardState::Unlimited => 0, // unreachable: unlimited mode always succeeds
GuardState::Released => 0,
GuardState::Unlimited { .. } => 0, // unreachable: unlimited mode always succeeds
GuardState::Limited { quota, .. } => {
quota.permits_to_bytes(quota.limit_permits)
}
@@ -184,22 +230,30 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
///
/// Returns true if the release succeeds or is a no-op; false if the request exceeds granted.
pub fn release_partial(&mut self, bytes: u64) -> bool {
if bytes == 0 {
return true;
}
match &mut self.state {
GuardState::Unlimited => true,
GuardState::Limited { permit, quota } => {
if bytes == 0 {
return true;
GuardState::Released => true,
GuardState::Unlimited {
quota,
granted_bytes,
} => {
if bytes > *granted_bytes {
return false;
}
quota.sub_in_use(bytes);
*granted_bytes = granted_bytes.saturating_sub(bytes);
true
}
GuardState::Limited { quota, permit } => {
let release_permits = quota.bytes_to_permits(bytes);
match permit.split(release_permits as usize) {
Some(released_permit) => {
let released_bytes =
quota.permits_to_bytes(released_permit.num_permits() as u32);
drop(released_permit);
quota.update_in_use_metric();
debug!("Released {} bytes from memory guard", released_bytes);
quota.release_permit(released_permit);
true
}
None => false,
@@ -211,14 +265,7 @@ impl<M: MemoryMetrics> MemoryGuard<M> {
impl<M: MemoryMetrics> Drop for MemoryGuard<M> {
fn drop(&mut self) {
if let GuardState::Limited { permit, quota } =
mem::replace(&mut self.state, GuardState::Unlimited)
{
let bytes = quota.permits_to_bytes(permit.num_permits() as u32);
drop(permit);
quota.update_in_use_metric();
debug!("Released memory: {} bytes", bytes);
}
mem::replace(&mut self.state, GuardState::Released).release();
}
}

View File

@@ -13,9 +13,10 @@
// limitations under the License.
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use snafu::ensure;
use tokio::sync::{Semaphore, TryAcquireError};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use crate::error::{
MemoryAcquireTimeoutSnafu, MemoryLimitExceededSnafu, MemorySemaphoreClosedSnafu, Result,
@@ -34,7 +35,7 @@ pub trait MemoryMetrics: Clone + Send + Sync + 'static {
/// Generic memory manager for quota-controlled operations.
#[derive(Clone)]
pub struct MemoryManager<M: MemoryMetrics> {
quota: Option<MemoryQuota<M>>,
quota: MemoryQuotaState<M>,
}
impl<M: MemoryMetrics + Default> Default for MemoryManager<M> {
@@ -51,6 +52,18 @@ pub(crate) struct MemoryQuota<M: MemoryMetrics> {
pub(crate) metrics: M,
}
/// Usage tracker for a memory manager that has no limit configured.
///
/// Nothing is ever rejected in this mode; the quota only counts the bytes that
/// are currently granted and reports that figure to the metrics.
#[derive(Clone)]
pub(crate) struct UnlimitedMemoryQuota<M: MemoryMetrics> {
// Bytes currently granted; the `Arc` lets every clone of the quota update the same counter.
pub(crate) current_bytes: Arc<AtomicU64>,
// Metrics sink that receives the in-use value whenever the counter changes.
pub(crate) metrics: M,
}
/// The accounting mode a memory manager operates in.
#[derive(Clone)]
pub(crate) enum MemoryQuotaState<M: MemoryMetrics> {
// No limit configured: usage is only counted, never rejected.
Unlimited(UnlimitedMemoryQuota<M>),
// A limit is configured: grants are backed by semaphore permits.
Limited(MemoryQuota<M>),
}
impl<M: MemoryMetrics> MemoryManager<M> {
/// Creates a new memory manager with the given limit in bytes.
/// `limit_bytes = 0` disables the limit.
@@ -62,7 +75,12 @@ impl<M: MemoryMetrics> MemoryManager<M> {
pub fn with_granularity(limit_bytes: u64, granularity: PermitGranularity, metrics: M) -> Self {
if limit_bytes == 0 {
metrics.set_limit(0);
return Self { quota: None };
return Self {
quota: MemoryQuotaState::Unlimited(UnlimitedMemoryQuota {
current_bytes: Arc::new(AtomicU64::new(0)),
metrics,
}),
};
}
let limit_permits = granularity.bytes_to_permits(limit_bytes);
@@ -70,7 +88,7 @@ impl<M: MemoryMetrics> MemoryManager<M> {
metrics.set_limit(limit_aligned_bytes as i64);
Self {
quota: Some(MemoryQuota {
quota: MemoryQuotaState::Limited(MemoryQuota {
semaphore: Arc::new(Semaphore::new(limit_permits as usize)),
limit_permits,
granularity,
@@ -81,26 +99,30 @@ impl<M: MemoryMetrics> MemoryManager<M> {
/// Returns the configured limit in bytes (0 if unlimited).
pub fn limit_bytes(&self) -> u64 {
self.quota
.as_ref()
.map(|quota| quota.permits_to_bytes(quota.limit_permits))
.unwrap_or(0)
match &self.quota {
MemoryQuotaState::Unlimited(_) => 0,
MemoryQuotaState::Limited(quota) => quota.permits_to_bytes(quota.limit_permits),
}
}
/// Returns currently used bytes.
pub fn used_bytes(&self) -> u64 {
self.quota
.as_ref()
.map(|quota| quota.permits_to_bytes(quota.used_permits()))
.unwrap_or(0)
match &self.quota {
MemoryQuotaState::Unlimited(quota) => quota.current_bytes.load(Ordering::Acquire),
MemoryQuotaState::Limited(quota) => quota.permits_to_bytes(quota.used_permits()),
}
}
/// Returns available bytes.
///
/// Unlimited managers report `u64::MAX`.
pub fn available_bytes(&self) -> u64 {
self.quota
.as_ref()
.map(|quota| quota.permits_to_bytes(quota.available_permits_clamped()))
.unwrap_or(0)
match &self.quota {
MemoryQuotaState::Unlimited(_) => u64::MAX,
MemoryQuotaState::Limited(quota) => {
quota.permits_to_bytes(quota.available_permits_clamped())
}
}
}
/// Acquires memory, waiting if necessary until enough is available.
@@ -110,8 +132,8 @@ impl<M: MemoryMetrics> MemoryManager<M> {
/// - Returns error if the semaphore is unexpectedly closed
pub async fn acquire(&self, bytes: u64) -> Result<MemoryGuard<M>> {
match &self.quota {
None => Ok(MemoryGuard::unlimited()),
Some(quota) => {
MemoryQuotaState::Unlimited(quota) => Ok(MemoryGuard::unlimited(quota.clone(), bytes)),
MemoryQuotaState::Limited(quota) => {
let permits = quota.bytes_to_permits(bytes);
ensure!(
@@ -129,7 +151,7 @@ impl<M: MemoryMetrics> MemoryManager<M> {
.await
.map_err(|_| MemorySemaphoreClosedSnafu.build())?;
quota.update_in_use_metric();
Ok(MemoryGuard::limited(permit, quota.clone()))
Ok(MemoryGuard::limited(quota.clone(), permit))
}
}
}
@@ -137,14 +159,16 @@ impl<M: MemoryMetrics> MemoryManager<M> {
/// Tries to acquire memory. Returns Some(guard) on success, None if insufficient.
pub fn try_acquire(&self, bytes: u64) -> Option<MemoryGuard<M>> {
match &self.quota {
None => Some(MemoryGuard::unlimited()),
Some(quota) => {
MemoryQuotaState::Unlimited(quota) => {
Some(MemoryGuard::unlimited(quota.clone(), bytes))
}
MemoryQuotaState::Limited(quota) => {
let permits = quota.bytes_to_permits(bytes);
match quota.semaphore.clone().try_acquire_many_owned(permits) {
Ok(permit) => {
quota.update_in_use_metric();
Some(MemoryGuard::limited(permit, quota.clone()))
Some(MemoryGuard::limited(quota.clone(), permit))
}
Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => {
quota.metrics.inc_rejected("try_acquire");
@@ -219,4 +243,49 @@ impl<M: MemoryMetrics> MemoryQuota<M> {
let bytes = self.permits_to_bytes(self.used_permits());
self.metrics.set_in_use(bytes as i64);
}
/// Returns `permit` to the semaphore and refreshes the in-use metric.
pub(crate) fn release_permit(&self, permit: OwnedSemaphorePermit) {
// Drop first: the metric is recomputed from the permits in use, so the permit
// has to be released before the refresh or the old value would be reported.
drop(permit);
self.update_in_use_metric();
}
}
impl<M: MemoryMetrics> UnlimitedMemoryQuota<M> {
    /// Converts a byte total into the `i64` expected by the metrics sink.
    ///
    /// Totals above `i64::MAX` are clamped instead of wrapping to a negative value.
    fn gauge_value(total_bytes: u64) -> i64 {
        i64::try_from(total_bytes).unwrap_or(i64::MAX)
    }

    /// Adds `bytes` to the tracked usage and publishes the new total to the metrics.
    ///
    /// A zero-byte request is a no-op. The counter saturates at `u64::MAX`; in debug
    /// builds an overflow additionally trips an assertion.
    pub(crate) fn add_in_use(&self, bytes: u64) {
        if bytes == 0 {
            return;
        }

        let previous = self
            .current_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_add(bytes))
            })
            .expect("update closure always returns Some");

        // `saturating_add` can never yield less than `previous`, so overflow has to
        // be detected with a checked addition.
        debug_assert!(
            previous.checked_add(bytes).is_some(),
            "unlimited memory usage counter overflowed"
        );

        let new_total = previous.saturating_add(bytes);
        self.metrics.set_in_use(Self::gauge_value(new_total));
    }

    /// Subtracts `bytes` from the tracked usage and publishes the new total to the metrics.
    ///
    /// A zero-byte request is a no-op. The counter saturates at zero; in debug builds
    /// releasing more than is tracked additionally trips an assertion.
    pub(crate) fn sub_in_use(&self, bytes: u64) {
        if bytes == 0 {
            return;
        }

        let previous = self
            .current_bytes
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                Some(current.saturating_sub(bytes))
            })
            .expect("update closure always returns Some");

        debug_assert!(
            previous >= bytes,
            "unlimited memory usage counter underflowed: current={previous}, release={bytes}"
        );

        let new_total = previous.saturating_sub(bytes);
        self.metrics.set_in_use(Self::gauge_value(new_total));
    }
}

View File

@@ -24,7 +24,9 @@ fn test_try_acquire_unlimited() {
let manager = MemoryManager::new(0, NoOpMetrics);
let guard = manager.try_acquire(10 * PERMIT_GRANULARITY_BYTES).unwrap();
assert_eq!(manager.limit_bytes(), 0);
assert_eq!(guard.granted_bytes(), 0);
assert_eq!(manager.available_bytes(), u64::MAX);
assert_eq!(guard.granted_bytes(), 10 * PERMIT_GRANULARITY_BYTES);
assert_eq!(manager.used_bytes(), 10 * PERMIT_GRANULARITY_BYTES);
}
#[test]
@@ -136,7 +138,10 @@ fn test_request_additional_unlimited() {
// Should always succeed with unlimited manager
assert!(guard.try_acquire_additional(100 * PERMIT_GRANULARITY_BYTES));
assert_eq!(guard.granted_bytes(), 0);
assert_eq!(guard.granted_bytes(), 105 * PERMIT_GRANULARITY_BYTES);
assert_eq!(manager.used_bytes(), 105 * PERMIT_GRANULARITY_BYTES);
drop(guard);
assert_eq!(manager.used_bytes(), 0);
}
@@ -187,9 +192,10 @@ fn test_early_release_partial_unlimited() {
let manager = MemoryManager::new(0, NoOpMetrics);
let mut guard = manager.try_acquire(100 * PERMIT_GRANULARITY_BYTES).unwrap();
// Unlimited guard - release should succeed (no-op)
// Unlimited guard should track and release exact bytes.
assert!(guard.release_partial(50 * PERMIT_GRANULARITY_BYTES));
assert_eq!(guard.granted_bytes(), 0);
assert_eq!(guard.granted_bytes(), 50 * PERMIT_GRANULARITY_BYTES);
assert_eq!(manager.used_bytes(), 50 * PERMIT_GRANULARITY_BYTES);
}
#[test]
@@ -406,6 +412,6 @@ async fn test_acquire_additional_unlimited() {
.acquire_additional(1000 * PERMIT_GRANULARITY_BYTES)
.await
.unwrap();
assert_eq!(guard.granted_bytes(), 0);
assert_eq!(manager.used_bytes(), 0);
assert_eq!(guard.granted_bytes(), 1000 * PERMIT_GRANULARITY_BYTES);
assert_eq!(manager.used_bytes(), 1000 * PERMIT_GRANULARITY_BYTES);
}

View File

@@ -8,7 +8,6 @@ license.workspace = true
testing = []
pg_kvbackend = [
"dep:tokio-postgres",
"dep:backon",
"dep:deadpool-postgres",
"dep:deadpool",
"dep:tokio-postgres-rustls",
@@ -16,7 +15,7 @@ pg_kvbackend = [
"dep:rustls-native-certs",
"dep:rustls",
]
mysql_kvbackend = ["dep:sqlx", "dep:backon"]
mysql_kvbackend = ["dep:sqlx"]
enterprise = ["prost-types"]
[lints]
@@ -28,7 +27,7 @@ api.workspace = true
async-recursion = "1.0"
async-stream.workspace = true
async-trait.workspace = true
backon = { workspace = true, optional = true }
backon.workspace = true
base64.workspace = true
bytes.workspace = true
chrono.workspace = true

View File

@@ -15,10 +15,14 @@
use std::borrow::Borrow;
use std::hash::Hash;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use futures::future::{BoxFuture, join_all};
use backon::{BackoffBuilder, ExponentialBuilder};
use futures::future::BoxFuture;
use moka::future::Cache;
use snafu::{OptionExt, ResultExt};
use tokio::time::sleep;
use crate::cache_invalidator::{CacheInvalidator, Context};
use crate::error::{self, Error, Result};
@@ -29,12 +33,29 @@ use crate::metrics;
pub type TokenFilter<CacheToken> = Box<dyn Fn(&CacheToken) -> bool + Send + Sync>;
/// Invalidates cached values by [CacheToken]s.
pub type Invalidator<K, V, CacheToken> =
Box<dyn for<'a> Fn(&'a Cache<K, V>, &'a CacheToken) -> BoxFuture<'a, Result<()>> + Send + Sync>;
pub type Invalidator<K, V, CacheToken> = Box<
dyn for<'a> Fn(&'a Cache<K, V>, &'a [&CacheToken]) -> BoxFuture<'a, Result<()>> + Send + Sync,
>;
/// Initializes value (i.e., fetches from remote).
pub type Initializer<K, V> = Arc<dyn Fn(&'_ K) -> BoxFuture<'_, Result<Option<V>>> + Send + Sync>;
/// Initialization strategy for cache-miss loading.
///
/// The strategy is selected when building a [CacheContainer] and remains immutable
/// for the lifetime of the container instance.
#[derive(Debug, Clone, Copy)]
pub enum InitStrategy {
/// Fast path: load once, without retrying on a version conflict.
///
/// If an invalidation happens concurrently, callers may observe a stale value.
Unchecked,
/// Strict path: reload when the container version changes during initialization.
///
/// This avoids returning a stale value when an invalidation races with a load.
VersionChecked,
}
/// [CacheContainer] provides ability to:
/// - Cache value loaded by [Initializer].
/// - Invalidate caches by [Invalidator].
@@ -44,6 +65,16 @@ pub struct CacheContainer<K, V, CacheToken> {
invalidator: Invalidator<K, V, CacheToken>,
initializer: Initializer<K, V>,
token_filter: fn(&CacheToken) -> bool,
version: Arc<AtomicUsize>,
init_strategy: InitStrategy,
}
/// Builds the delay schedule used between version-checked reload attempts:
/// exponential backoff between 10 ms and 100 ms, yielding at most 3 delays.
fn latest_get_backoff() -> impl Iterator<Item = Duration> {
    const MIN_DELAY: Duration = Duration::from_millis(10);
    const MAX_DELAY: Duration = Duration::from_millis(100);
    const MAX_RETRIES: usize = 3;

    let builder = ExponentialBuilder::default()
        .with_min_delay(MIN_DELAY)
        .with_max_delay(MAX_DELAY)
        .with_max_times(MAX_RETRIES);
    builder.build()
}
impl<K, V, CacheToken> CacheContainer<K, V, CacheToken>
@@ -52,13 +83,37 @@ where
V: Send + Sync,
CacheToken: Send + Sync,
{
/// Constructs an [CacheContainer].
/// Constructs an [CacheContainer] with [InitStrategy::Unchecked].
///
/// This keeps the historical behavior and can return stale/dirty value under
/// concurrent invalidation.
pub fn new(
name: String,
cache: Cache<K, V>,
invalidator: Invalidator<K, V, CacheToken>,
initializer: Initializer<K, V>,
token_filter: fn(&CacheToken) -> bool,
) -> Self {
Self::with_strategy(
name,
cache,
invalidator,
initializer,
token_filter,
InitStrategy::Unchecked,
)
}
/// Constructs an [CacheContainer] with explicit [InitStrategy].
///
/// The strategy is fixed at construction time and cannot be changed later.
pub fn with_strategy(
name: String,
cache: Cache<K, V>,
invalidator: Invalidator<K, V, CacheToken>,
initializer: Initializer<K, V>,
token_filter: fn(&CacheToken) -> bool,
init_strategy: InitStrategy,
) -> Self {
Self {
name,
@@ -66,6 +121,8 @@ where
invalidator,
initializer,
token_filter,
version: Arc::new(AtomicUsize::new(0)),
init_strategy,
}
}
@@ -75,6 +132,67 @@ where
}
}
impl<K, V, CacheToken> CacheContainer<K, V, CacheToken> {
    /// Advances the container's version counter by one.
    fn inc_version(&self) {
        let counter: &AtomicUsize = &self.version;
        counter.fetch_add(1, Ordering::Relaxed);
    }
}
async fn init<'a, K, V>(init: Initializer<K, V>, key: K, cache_name: &'a str) -> Result<V>
where
K: Send + Sync + 'a,
V: Send + 'a,
{
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[cache_name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[cache_name])
.start_timer();
init(&key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
}
async fn init_with_retry<'a, K, V>(
init: Initializer<K, V>,
key: K,
mut backoff: impl Iterator<Item = Duration> + 'a,
version: Arc<AtomicUsize>,
cache_name: &'a str,
) -> Result<V>
where
K: Send + Sync + 'a,
V: Send + 'a,
{
let mut attempts = 1usize;
loop {
let pre_version = version.load(Ordering::Relaxed);
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[cache_name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[cache_name])
.start_timer();
let value = init(&key)
.await
.transpose()
.context(error::ValueNotExistSnafu)??;
if pre_version == version.load(Ordering::Relaxed) {
return Ok(value);
}
if let Some(duration) = backoff.next() {
sleep(duration).await;
attempts += 1;
} else {
return error::GetLatestCacheRetryExceededSnafu { attempts }.fail();
}
}
}
#[async_trait::async_trait]
impl<K, V> CacheInvalidator for CacheContainer<K, V, CacheIdent>
where
@@ -82,14 +200,15 @@ where
V: Send + Sync,
{
async fn invalidate(&self, _ctx: &Context, caches: &[CacheIdent]) -> Result<()> {
let tasks = caches
let idents = caches
.iter()
.filter(|token| (self.token_filter)(token))
.map(|token| (self.invalidator)(&self.cache, token));
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
if !idents.is_empty() {
self.inc_version();
(self.invalidator)(&self.cache, &idents).await?;
}
Ok(())
}
}
@@ -99,27 +218,39 @@ where
K: Copy + Hash + Eq + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
{
/// Returns a _clone_ of the value corresponding to the key.
/// Returns a value from cache for copyable keys.
///
/// With [InitStrategy::Unchecked], this method prioritizes latency and may
/// return stale/dirty value. With [InitStrategy::VersionChecked], this method
/// retries initialization on version change and avoids dirty returns.
pub async fn get(&self, key: K) -> Result<Option<V>> {
metrics::CACHE_CONTAINER_CACHE_GET
.with_label_values(&[&self.name])
.inc();
let moved_init = self.initializer.clone();
let moved_key = key;
let init = async move {
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[&self.name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[&self.name])
.start_timer();
moved_init(&moved_key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
let result = match self.init_strategy {
InitStrategy::Unchecked => {
self.cache
.try_get_with(key, init(self.initializer.clone(), key, &self.name))
.await
}
InitStrategy::VersionChecked => {
self.cache
.try_get_with(
key,
init_with_retry(
self.initializer.clone(),
key,
latest_get_backoff(),
self.version.clone(),
&self.name,
),
)
.await
}
};
match self.cache.try_get_with(key, init).await {
match result {
Ok(value) => Ok(Some(value)),
Err(err) => match err.as_ref() {
Error::ValueNotExist { .. } => Ok(None),
@@ -136,14 +267,15 @@ where
{
/// Invalidates cache by [CacheToken].
pub async fn invalidate(&self, caches: &[CacheToken]) -> Result<()> {
let tasks = caches
let idents = caches
.iter()
.filter(|token| (self.token_filter)(token))
.map(|token| (self.invalidator)(&self.cache, token));
join_all(tasks)
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
.collect::<Vec<_>>();
if !idents.is_empty() {
self.inc_version();
(self.invalidator)(&self.cache, &idents).await?;
}
Ok(())
}
@@ -156,7 +288,11 @@ where
self.cache.contains_key(key)
}
/// Returns a _clone_ of the value corresponding to the key.
/// Returns a value from cache by key reference.
///
/// With [InitStrategy::Unchecked], this method prioritizes latency and may
/// return stale/dirty value. With [InitStrategy::VersionChecked], this method
/// retries initialization on version change and avoids dirty returns.
pub async fn get_by_ref<Q>(&self, key: &Q) -> Result<Option<V>>
where
K: Borrow<Q>,
@@ -165,24 +301,32 @@ where
metrics::CACHE_CONTAINER_CACHE_GET
.with_label_values(&[&self.name])
.inc();
let moved_init = self.initializer.clone();
let moved_key = key.to_owned();
let init = async move {
metrics::CACHE_CONTAINER_CACHE_MISS
.with_label_values(&[&self.name])
.inc();
let _timer = metrics::CACHE_CONTAINER_LOAD_CACHE
.with_label_values(&[&self.name])
.start_timer();
moved_init(&moved_key)
.await
.transpose()
.context(error::ValueNotExistSnafu)?
let result = match self.init_strategy {
InitStrategy::Unchecked => {
self.cache
.try_get_with_by_ref(
key,
init(self.initializer.clone(), key.to_owned(), &self.name),
)
.await
}
InitStrategy::VersionChecked => {
self.cache
.try_get_with_by_ref(
key,
init_with_retry(
self.initializer.clone(),
key.to_owned(),
latest_get_backoff(),
self.version.clone(),
&self.name,
),
)
.await
}
};
match self.cache.try_get_with_by_ref(key, init).await {
match result {
Ok(value) => Ok(Some(value)),
Err(err) => match err.as_ref() {
Error::ValueNotExist { .. } => Ok(None),
@@ -296,9 +440,11 @@ mod tests {
moved_counter.fetch_add(1, Ordering::Relaxed);
Box::pin(async { Ok(Some("hi".to_string())) })
});
let invalidator: Invalidator<String, String, String> = Box::new(|cache, key| {
let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
Box::pin(async move {
cache.invalidate(key).await;
for key in keys {
cache.invalidate(*key).await;
}
Ok(())
})
});
@@ -323,4 +469,46 @@ mod tests {
assert_eq!(value, "hi");
assert_eq!(counter.load(Ordering::Relaxed), 2);
}
// Verifies that a version-checked container does not return a value whose load
// overlapped with an invalidation: the first load ("v1") is discarded and the
// retried load ("v2") is returned instead.
//
// NOTE(review): the test relies on wall-clock timing (a 50 ms sleep racing a
// 100 ms load), so it may be sensitive to a heavily loaded runner.
#[tokio::test(flavor = "multi_thread")]
async fn test_get_by_ref_returns_fresh_value_after_invalidate() {
let cache: Cache<String, String> = CacheBuilder::new(128).build();
let counter = Arc::new(AtomicI32::new(0));
let moved_counter = counter.clone();
// Each load takes 100 ms and returns "v<n>", where n is the number of loads so far.
let init: Initializer<String, String> = Arc::new(move |_| {
let counter = moved_counter.clone();
Box::pin(async move {
let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
sleep(Duration::from_millis(100)).await;
Ok(Some(format!("v{n}")))
})
});
let invalidator: Invalidator<String, String, String> = Box::new(|cache, keys| {
Box::pin(async move {
for key in keys {
cache.invalidate(*key).await;
}
Ok(())
})
});
let adv_cache = Arc::new(CacheContainer::with_strategy(
"test".to_string(),
cache,
invalidator,
init,
always_true_filter,
InitStrategy::VersionChecked,
));
// Start a get in the background; its first load is still sleeping when the
// invalidation below bumps the container version.
let moved_cache = adv_cache.clone();
let get_task = tokio::spawn(async move { moved_cache.get_by_ref("foo").await });
sleep(Duration::from_millis(50)).await;
adv_cache.invalidate(&["foo".to_string()]).await.unwrap();
// The version changed during the first load, so the initializer runs a second time.
let value = get_task.await.unwrap().unwrap().unwrap();
assert_eq!(value, "v2");
assert_eq!(counter.load(Ordering::Relaxed), 2);
}
}

View File

@@ -170,20 +170,22 @@ async fn handle_drop_flow(
fn invalidator<'a>(
cache: &'a Cache<TableId, FlownodeFlowSet>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
match ident {
CacheIdent::CreateFlow(create_flow) => handle_create_flow(cache, create_flow).await,
CacheIdent::DropFlow(drop_flow) => handle_drop_flow(cache, drop_flow).await,
CacheIdent::FlowNodeAddressChange(node_id) => {
info!(
"Invalidate flow node cache for node_id in table_flownode: {}",
node_id
);
cache.invalidate_all();
for ident in idents {
match ident {
CacheIdent::CreateFlow(create_flow) => handle_create_flow(cache, create_flow).await,
CacheIdent::DropFlow(drop_flow) => handle_drop_flow(cache, drop_flow).await,
CacheIdent::FlowNodeAddressChange(node_id) => {
info!(
"Invalidate flow node cache for node_id in table_flownode: {}",
node_id
);
cache.invalidate_all();
}
_ => {}
}
_ => {}
}
Ok(())
})

View File

@@ -58,11 +58,13 @@ fn init_factory(schema_manager: SchemaManager) -> Initializer<SchemaName, Arc<Sc
fn invalidator<'a>(
cache: &'a Cache<SchemaName, Arc<SchemaNameValue>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, crate::error::Result<()>> {
Box::pin(async move {
if let CacheIdent::SchemaName(schema_name) = ident {
cache.invalidate(schema_name).await
for ident in idents {
if let CacheIdent::SchemaName(schema_name) = ident {
cache.invalidate(schema_name).await
}
}
Ok(())
})

View File

@@ -61,11 +61,13 @@ fn init_factory(table_info_manager: TableInfoManagerRef) -> Initializer<TableId,
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<TableInfo>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
for ident in idents {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
}
}
Ok(())
})

View File

@@ -71,11 +71,13 @@ fn init_factory(table_name_manager: TableNameManagerRef) -> Initializer<TableNam
fn invalidator<'a>(
cache: &'a Cache<TableName, TableId>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
for ident in idents {
if let CacheIdent::TableName(table_name) = ident {
cache.invalidate(table_name).await
}
}
Ok(())
})

View File

@@ -19,6 +19,7 @@ use moka::future::Cache;
use snafu::OptionExt;
use store_api::storage::TableId;
use crate::cache::container::InitStrategy;
use crate::cache::{CacheContainer, Initializer};
use crate::error;
use crate::error::Result;
@@ -65,7 +66,14 @@ pub fn new_table_route_cache(
let table_info_manager = Arc::new(TableRouteManager::new(kv_backend));
let init = init_factory(table_info_manager);
CacheContainer::new(name, cache, Box::new(invalidator), init, filter)
CacheContainer::with_strategy(
name,
cache,
Box::new(invalidator),
init,
filter,
InitStrategy::VersionChecked,
)
}
fn init_factory(
@@ -92,11 +100,13 @@ fn init_factory(
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<TableRoute>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
for ident in idents {
if let CacheIdent::TableId(table_id) = ident {
cache.invalidate(table_id).await
}
}
Ok(())
})

View File

@@ -65,7 +65,7 @@ fn init_factory(table_info_manager: TableInfoManager) -> Initializer<TableId, Ar
/// Never invalidates table id schema cache.
fn invalidator<'a>(
_cache: &'a Cache<TableId, Arc<SchemaName>>,
_ident: &'a CacheIdent,
_idents: &'a [&CacheIdent],
) -> BoxFuture<'a, error::Result<()>> {
Box::pin(std::future::ready(Ok(())))
}

View File

@@ -60,11 +60,13 @@ fn init_factory(view_info_manager: ViewInfoManagerRef) -> Initializer<TableId, A
fn invalidator<'a>(
cache: &'a Cache<TableId, Arc<ViewInfoValue>>,
ident: &'a CacheIdent,
idents: &'a [&CacheIdent],
) -> BoxFuture<'a, Result<()>> {
Box::pin(async move {
if let CacheIdent::TableId(view_id) = ident {
cache.invalidate(view_id).await
for ident in idents {
if let CacheIdent::TableId(view_id) = ident {
cache.invalidate(view_id).await
}
}
Ok(())
})

View File

@@ -21,15 +21,85 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use common_telemetry::{error, info, warn};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::broadcast::{self, Receiver, Sender};
use crate::error::Result;
use crate::metasrv::MetasrvNodeInfo;
pub(crate) const CANDIDATE_LEASE_SECS: u64 = 600;
pub const CANDIDATE_LEASE_SECS: u64 = 600;
const KEEP_ALIVE_INTERVAL_SECS: u64 = CANDIDATE_LEASE_SECS / 2;
/// Wrapper around the leader's value, which stores the leader's address as text.
pub struct LeaderValue(pub String);

impl<T: AsRef<[u8]>> From<T> for LeaderValue {
    /// Builds a [`LeaderValue`] from any byte-like input.
    ///
    /// The conversion is lossy: byte sequences that are not valid UTF-8 are
    /// replaced with U+FFFD instead of producing an error.
    fn from(value: T) -> Self {
        let bytes = value.as_ref();
        Self(String::from_utf8_lossy(bytes).into_owned())
    }
}
/// Information describing a single metasrv node.
///
/// The struct derives `Serialize`/`Deserialize`. Fields annotated with
/// `#[serde(default)]` fall back to their type's default value when they are
/// absent from the serialized input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetasrvNodeInfo {
/// The metasrv's address
pub addr: String,
/// The node build version
pub version: String,
/// The node build git commit hash
pub git_commit: String,
/// The node start timestamp in milliseconds
pub start_time_ms: u64,
/// The node total cpu millicores
#[serde(default)]
pub total_cpu_millicores: i64,
/// The node total memory bytes
#[serde(default)]
pub total_memory_bytes: i64,
/// The node cpu usage in millicores
#[serde(default)]
pub cpu_usage_millicores: i64,
/// The node memory usage in bytes
#[serde(default)]
pub memory_usage_bytes: i64,
/// The node hostname
#[serde(default)]
pub hostname: String,
}
// TODO(zyy17): Allow deprecated fields for backward compatibility. Remove this when the deprecated top-level fields are removed from the proto.
#[allow(deprecated)]
impl From<MetasrvNodeInfo> for api::v1::meta::MetasrvNodeInfo {
fn from(node_info: MetasrvNodeInfo) -> Self {
Self {
peer: Some(api::v1::meta::Peer {
addr: node_info.addr,
..Default::default()
}),
// TODO(zyy17): The following top-level fields are deprecated. They are kept for backward compatibility and will be removed in a future version.
// New code should use the fields in `info.NodeInfo` instead.
version: node_info.version.clone(),
git_commit: node_info.git_commit.clone(),
start_time_ms: node_info.start_time_ms,
cpus: node_info.total_cpu_millicores as u32,
memory_bytes: node_info.total_memory_bytes as u64,
// The canonical location for node information.
info: Some(api::v1::meta::NodeInfo {
version: node_info.version,
git_commit: node_info.git_commit,
start_time_ms: node_info.start_time_ms,
total_cpu_millicores: node_info.total_cpu_millicores,
total_memory_bytes: node_info.total_memory_bytes,
cpu_usage_millicores: node_info.cpu_usage_millicores,
memory_usage_bytes: node_info.memory_usage_bytes,
cpus: node_info.total_cpu_millicores as u32,
memory_bytes: node_info.total_memory_bytes as u64,
hostname: node_info.hostname,
}),
}
}
}
/// Messages sent when the leader changes.
#[derive(Debug, Clone)]
pub enum LeaderChangeMessage {
@@ -168,3 +238,5 @@ pub trait Election: Send + Sync {
fn subscribe_leader_change(&self) -> Receiver<LeaderChangeMessage>;
}
pub type ElectionRef = Arc<dyn Election<Leader = LeaderValue>>;

View File

@@ -16,8 +16,6 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use common_meta::distributed_time_constants::{META_KEEP_ALIVE_INTERVAL_SECS, META_LEASE_SECS};
use common_meta::key::{CANDIDATES_ROOT, ELECTION_KEY};
use common_telemetry::{error, info, warn};
use etcd_client::{
Client, GetOptions, LeaderKey as EtcdLeaderKey, LeaseKeepAliveStream, LeaseKeeper, PutOptions,
@@ -27,13 +25,15 @@ use tokio::sync::broadcast;
use tokio::sync::broadcast::Receiver;
use tokio::time::{MissedTickBehavior, timeout};
use crate::distributed_time_constants::{META_KEEP_ALIVE_INTERVAL_SECS, META_LEASE_SECS};
use crate::election::{
CANDIDATE_LEASE_SECS, Election, KEEP_ALIVE_INTERVAL_SECS, LeaderChangeMessage, LeaderKey,
listen_leader_change, send_leader_change_and_set_flags,
CANDIDATE_LEASE_SECS, Election, ElectionRef, KEEP_ALIVE_INTERVAL_SECS, LeaderChangeMessage,
LeaderKey, LeaderValue, MetasrvNodeInfo, listen_leader_change,
send_leader_change_and_set_flags,
};
use crate::error;
use crate::error::Result;
use crate::metasrv::{ElectionRef, LeaderValue, MetasrvNodeInfo};
use crate::key::{CANDIDATES_ROOT, ELECTION_KEY};
impl LeaderKey for EtcdLeaderKey {
fn name(&self) -> &[u8] {
@@ -253,7 +253,7 @@ impl Election for EtcdElection {
.leader(self.election_key())
.await
.context(error::EtcdFailedSnafu)?;
let leader_value = res.kv().context(error::NoLeaderSnafu)?.value();
let leader_value = res.kv().context(error::ElectionNoLeaderSnafu)?.value();
Ok(leader_value.into())
}
}
@@ -279,7 +279,7 @@ impl EtcdElection {
ensure!(
res.ttl() > 0,
error::UnexpectedSnafu {
violated: "Failed to refresh the lease",
err_msg: "Failed to refresh the lease".to_string(),
}
);

View File

@@ -36,7 +36,7 @@ fn parse_value_and_expire_time(value: &str) -> Result<(String, Timestamp)> {
.split(LEASE_SEP)
.collect_tuple()
.with_context(|| UnexpectedSnafu {
violated: format!(
err_msg: format!(
"Invalid value {}, expect node info || {} || expire time",
value, LEASE_SEP
),
@@ -45,7 +45,7 @@ fn parse_value_and_expire_time(value: &str) -> Result<(String, Timestamp)> {
let expire_time = match Timestamp::from_str(expire_time, None) {
Ok(ts) => ts,
Err(_) => UnexpectedSnafu {
violated: format!("Invalid timestamp: {}", expire_time),
err_msg: format!("Invalid timestamp: {}", expire_time),
}
.fail()?,
};

View File

@@ -16,7 +16,6 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use common_meta::key::{CANDIDATES_ROOT, ELECTION_KEY};
use common_telemetry::{error, info, warn};
use common_time::Timestamp;
use snafu::{OptionExt, ResultExt, ensure};
@@ -29,14 +28,15 @@ use tokio::time::MissedTickBehavior;
use crate::election::rds::{LEASE_SEP, Lease, RdsLeaderKey, parse_value_and_expire_time};
use crate::election::{
Election, LeaderChangeMessage, listen_leader_change, send_leader_change_and_set_flags,
Election, ElectionRef, LeaderChangeMessage, LeaderValue, MetasrvNodeInfo, listen_leader_change,
send_leader_change_and_set_flags,
};
use crate::error::{
AcquireMySqlClientSnafu, DecodeSqlValueSnafu, DeserializeFromJsonSnafu,
LeaderLeaseChangedSnafu, LeaderLeaseExpiredSnafu, MySqlExecutionSnafu, NoLeaderSnafu, Result,
SerializeToJsonSnafu, SqlExecutionTimeoutSnafu, UnexpectedSnafu,
ElectionLeaderLeaseChangedSnafu, ElectionLeaderLeaseExpiredSnafu, ElectionNoLeaderSnafu,
MySqlExecutionSnafu, Result, SerializeToJsonSnafu, SqlExecutionTimeoutSnafu, UnexpectedSnafu,
};
use crate::metasrv::{ElectionRef, LeaderValue, MetasrvNodeInfo};
use crate::key::{CANDIDATES_ROOT, ELECTION_KEY};
struct ElectionSqlFactory<'a> {
table_name: &'a str,
@@ -592,7 +592,7 @@ impl Election for MySqlElection {
ensure!(
lease.expire_time > lease.current,
UnexpectedSnafu {
violated: format!(
err_msg: format!(
"Candidate lease expired at {:?} (current time: {:?}), key: {:?}",
lease.expire_time,
lease.current,
@@ -667,10 +667,10 @@ impl Election for MySqlElection {
let client = self.client.lock().await;
let mut executor = Executor::Default(client);
if let Some(lease) = self.get_value_with_lease(&key, &mut executor).await? {
ensure!(lease.expire_time > lease.current, NoLeaderSnafu);
ensure!(lease.expire_time > lease.current, ElectionNoLeaderSnafu);
Ok(lease.leader_value.as_bytes().into())
} else {
NoLeaderSnafu.fail()
ElectionNoLeaderSnafu.fail()
}
}
}
@@ -705,7 +705,7 @@ impl MySqlElection {
let current_time = match Timestamp::from_str(&current_time_str, None) {
Ok(ts) => ts,
Err(_) => UnexpectedSnafu {
violated: format!("Invalid timestamp: {}", current_time_str),
err_msg: format!("Invalid timestamp: {}", current_time_str),
}
.fail()?,
};
@@ -740,7 +740,7 @@ impl MySqlElection {
current = match Timestamp::from_str(current_time_str, None) {
Ok(ts) => ts,
Err(_) => UnexpectedSnafu {
violated: format!("Invalid timestamp: {}", current_time_str),
err_msg: format!("Invalid timestamp: {}", current_time_str),
}
.fail()?,
};
@@ -777,7 +777,7 @@ impl MySqlElection {
ensure!(
res == 1,
UnexpectedSnafu {
violated: format!("Failed to update key: {}", String::from_utf8_lossy(key)),
err_msg: format!("Failed to update key: {}", String::from_utf8_lossy(key)),
}
);
@@ -920,9 +920,12 @@ impl MySqlElection {
/// will be released.
/// - **Case 2**: If all checks pass, the function returns without performing any actions.
fn lease_check(&self, lease: &Option<Lease>) -> Result<Lease> {
let lease = lease.as_ref().context(NoLeaderSnafu)?;
let lease = lease.as_ref().context(ElectionNoLeaderSnafu)?;
// Case 1: Lease expired
ensure!(lease.expire_time > lease.current, LeaderLeaseExpiredSnafu);
ensure!(
lease.expire_time > lease.current,
ElectionLeaderLeaseExpiredSnafu
);
// Case 2: Everything is fine
Ok(lease.clone())
}
@@ -960,7 +963,7 @@ impl MySqlElection {
let remote_lease = self.get_value_with_lease(&key, &mut executor).await?;
ensure!(
expected_lease.map(|lease| lease.origin) == remote_lease.map(|lease| lease.origin),
LeaderLeaseChangedSnafu
ElectionLeaderLeaseChangedSnafu
);
self.delete_value(&key, &mut executor).await?;
self.put_value_with_lease(
@@ -987,12 +990,11 @@ mod tests {
use std::assert_matches::assert_matches;
use std::env;
use common_meta::maybe_skip_mysql_integration_test;
use common_telemetry::init_default_ut_logging;
use sqlx::MySqlPool;
use super::*;
use crate::error;
use crate::utils::mysql::create_mysql_pool;
use crate::{error, maybe_skip_mysql_integration_test};
async fn create_mysql_client(
table_name: Option<&str>,
@@ -1003,11 +1005,11 @@ mod tests {
let endpoint = env::var("GT_MYSQL_ENDPOINTS").unwrap_or_default();
if endpoint.is_empty() {
return UnexpectedSnafu {
violated: "MySQL endpoint is empty".to_string(),
err_msg: "MySQL endpoint is empty".to_string(),
}
.fail();
}
let pool = create_mysql_pool(&[endpoint], None).await.unwrap();
let pool = MySqlPool::connect(&endpoint).await.unwrap();
let mut client = ElectionMysqlClient::new(
pool,
execution_timeout,
@@ -1302,7 +1304,7 @@ mod tests {
let err = elected(&leader_mysql_election, table_name, Some(incorrect_lease))
.await
.unwrap_err();
assert_matches!(err, error::Error::LeaderLeaseChanged { .. });
assert_matches!(err, error::Error::ElectionLeaderLeaseChanged { .. });
let lease = get_lease(&leader_mysql_election).await;
assert!(lease.is_none());
drop_table(&leader_mysql_election.client, table_name).await;

View File

@@ -16,7 +16,6 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use common_meta::key::{CANDIDATES_ROOT, ELECTION_KEY};
use common_telemetry::{error, info, warn};
use common_time::Timestamp;
use deadpool_postgres::{Manager, Pool};
@@ -28,13 +27,15 @@ use tokio_postgres::types::ToSql;
use crate::election::rds::{LEASE_SEP, Lease, RdsLeaderKey, parse_value_and_expire_time};
use crate::election::{
Election, LeaderChangeMessage, listen_leader_change, send_leader_change_and_set_flags,
Election, ElectionRef, LeaderChangeMessage, LeaderValue, MetasrvNodeInfo, listen_leader_change,
send_leader_change_and_set_flags,
};
use crate::error::{
DeserializeFromJsonSnafu, GetPostgresClientSnafu, NoLeaderSnafu, PostgresExecutionSnafu,
Result, SerializeToJsonSnafu, SqlExecutionTimeoutSnafu, UnexpectedSnafu,
DeserializeFromJsonSnafu, ElectionNoLeaderSnafu, GetPostgresClientSnafu,
PostgresExecutionSnafu, Result, SerializeToJsonSnafu, SqlExecutionTimeoutSnafu,
UnexpectedSnafu,
};
use crate::metasrv::{ElectionRef, LeaderValue, MetasrvNodeInfo};
use crate::key::{CANDIDATES_ROOT, ELECTION_KEY};
struct ElectionSqlFactory<'a> {
lock_id: u64,
@@ -404,13 +405,13 @@ impl Election for PgElection {
.get_value_with_lease(&key)
.await?
.context(UnexpectedSnafu {
violated: format!("Failed to get lease for key: {:?}", key),
err_msg: format!("Failed to get lease for key: {:?}", key),
})?;
ensure!(
lease.expire_time > lease.current,
UnexpectedSnafu {
violated: format!(
err_msg: format!(
"Candidate lease expired at {:?} (current time {:?}), key: {:?}",
lease.expire_time, lease.current, key
),
@@ -464,11 +465,11 @@ impl Election for PgElection {
.query(&self.sql_set.campaign, &[])
.await?;
let row = res.first().context(UnexpectedSnafu {
violated: "Failed to get the result of acquiring advisory lock",
err_msg: "Failed to get the result of acquiring advisory lock".to_string(),
})?;
let is_leader = row.try_get(0).map_err(|_| {
UnexpectedSnafu {
violated: "Failed to get the result of get lock",
err_msg: "Failed to get the result of get lock".to_string(),
}
.build()
})?;
@@ -500,10 +501,10 @@ impl Election for PgElection {
} else {
let key = self.election_key();
if let Some(lease) = self.get_value_with_lease(&key).await? {
ensure!(lease.expire_time > lease.current, NoLeaderSnafu);
ensure!(lease.expire_time > lease.current, ElectionNoLeaderSnafu);
Ok(lease.leader_value.as_bytes().into())
} else {
NoLeaderSnafu.fail()
ElectionNoLeaderSnafu.fail()
}
}
}
@@ -537,7 +538,7 @@ impl PgElection {
let current_time = match Timestamp::from_str(current_time_str, None) {
Ok(ts) => ts,
Err(_) => UnexpectedSnafu {
violated: format!("Invalid timestamp: {}", current_time_str),
err_msg: format!("Invalid timestamp: {}", current_time_str),
}
.fail()?,
};
@@ -576,7 +577,7 @@ impl PgElection {
current = match Timestamp::from_str(current_time_str, None) {
Ok(ts) => ts,
Err(_) => UnexpectedSnafu {
violated: format!("Invalid timestamp: {}", current_time_str),
err_msg: format!("Invalid timestamp: {}", current_time_str),
}
.fail()?,
};
@@ -613,7 +614,7 @@ impl PgElection {
ensure!(
res == 1,
UnexpectedSnafu {
violated: format!("Failed to update key: {}", String::from_utf8_lossy(key)),
err_msg: format!("Failed to update key: {}", String::from_utf8_lossy(key)),
}
);
@@ -742,9 +743,9 @@ impl PgElection {
let lease = self
.get_value_with_lease(&key)
.await?
.context(NoLeaderSnafu)?;
.context(ElectionNoLeaderSnafu)?;
// Case 2
ensure!(lease.expire_time > lease.current, NoLeaderSnafu);
ensure!(lease.expire_time > lease.current, ElectionNoLeaderSnafu);
// Case 3
Ok(())
}
@@ -831,11 +832,11 @@ mod tests {
use std::assert_matches::assert_matches;
use std::env;
use common_meta::maybe_skip_postgres_integration_test;
use deadpool_postgres::{Config, Runtime};
use tokio_postgres::NoTls;
use super::*;
use crate::error;
use crate::utils::postgres::create_postgres_pool;
use crate::{error, maybe_skip_postgres_integration_test};
async fn create_postgres_client(
table_name: Option<&str>,
@@ -846,11 +847,13 @@ mod tests {
let endpoint = env::var("GT_POSTGRES_ENDPOINTS").unwrap_or_default();
if endpoint.is_empty() {
return UnexpectedSnafu {
violated: "Postgres endpoint is empty".to_string(),
err_msg: "Postgres endpoint is empty".to_string(),
}
.fail();
}
let pool = create_postgres_pool(&[endpoint], None, None).await.unwrap();
let mut cfg = Config::new();
cfg.url = Some(endpoint);
let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls).unwrap();
let mut pg_client = ElectionPgClient::new(
pool,
execution_timeout,

View File

@@ -338,6 +338,24 @@ pub enum Error {
location: Location,
},
#[snafu(display("Metasrv election has no leader at this moment"))]
ElectionNoLeader {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Metasrv election leader lease expired"))]
ElectionLeaderLeaseExpired {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Metasrv election leader lease changed during election"))]
ElectionLeaderLeaseChanged {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Table already exists, table: {}", table_name))]
TableAlreadyExists {
table_name: String,
@@ -714,6 +732,16 @@ pub enum Error {
#[snafu(display("Failed to get cache"))]
GetCache { source: Arc<Error> },
#[snafu(display(
"Failed to get latest cache value after {} attempts due to concurrent invalidation",
attempts
))]
GetLatestCacheRetryExceeded {
attempts: usize,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to execute via Postgres, sql: {}", sql))]
PostgresExecution {
@@ -741,6 +769,15 @@ pub enum Error {
location: Location,
},
#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to get Postgres client"))]
GetPostgresClient {
#[snafu(source)]
error: deadpool::managed::PoolError<tokio_postgres::Error>,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "pg_kvbackend")]
#[snafu(display("Failed to {} Postgres transaction", operation))]
PostgresTransaction {
@@ -795,6 +832,24 @@ pub enum Error {
location: Location,
},
#[cfg(feature = "mysql_kvbackend")]
#[snafu(display("Failed to decode sql value"))]
DecodeSqlValue {
#[snafu(source)]
error: sqlx::error::Error,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "mysql_kvbackend")]
#[snafu(display("Failed to acquire mysql client from pool"))]
AcquireMySqlClient {
#[snafu(source)]
error: sqlx::Error,
#[snafu(implicit)]
location: Location,
},
#[cfg(feature = "mysql_kvbackend")]
#[snafu(display("Failed to {} MySql transaction", operation))]
MySqlTransaction {
@@ -812,6 +867,15 @@ pub enum Error {
location: Location,
},
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
#[snafu(display("Sql execution timeout, sql: {}, duration: {:?}", sql, duration))]
SqlExecutionTimeout {
sql: String,
duration: std::time::Duration,
#[snafu(implicit)]
location: Location,
},
#[snafu(display(
"Datanode table info not found, table id: {}, datanode id: {}",
table_id,
@@ -1063,8 +1127,12 @@ impl ErrorExt for Error {
| ConnectEtcd { .. }
| MoveValues { .. }
| GetCache { .. }
| GetLatestCacheRetryExceeded { .. }
| SerializeToJson { .. }
| DeserializeFromJson { .. } => StatusCode::Internal,
| DeserializeFromJson { .. }
| ElectionNoLeader { .. }
| ElectionLeaderLeaseExpired { .. }
| ElectionLeaderLeaseChanged { .. } => StatusCode::Internal,
NoLeader { .. } => StatusCode::TableUnavailable,
ValueNotExist { .. }
@@ -1187,15 +1255,18 @@ impl ErrorExt for Error {
PostgresExecution { .. }
| CreatePostgresPool { .. }
| GetPostgresConnection { .. }
| GetPostgresClient { .. }
| PostgresTransaction { .. }
| PostgresTlsConfig { .. }
| InvalidTlsConfig { .. } => StatusCode::Internal,
#[cfg(feature = "mysql_kvbackend")]
MySqlExecution { .. } | CreateMySqlPool { .. } | MySqlTransaction { .. } => {
StatusCode::Internal
}
MySqlExecution { .. }
| CreateMySqlPool { .. }
| DecodeSqlValue { .. }
| AcquireMySqlClient { .. }
| MySqlTransaction { .. } => StatusCode::Internal,
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
RdsTransactionRetryFailed { .. } => StatusCode::Internal,
RdsTransactionRetryFailed { .. } | SqlExecutionTimeout { .. } => StatusCode::Internal,
DatanodeTableInfoNotFound { .. } => StatusCode::Internal,
}
}
@@ -1243,7 +1314,10 @@ impl Error {
/// Determine whether it is a retry later type through [StatusCode]
pub fn is_retry_later(&self) -> bool {
matches!(self, Error::RetryLater { .. })
matches!(
self,
Error::RetryLater { .. } | Error::GetLatestCacheRetryExceeded { .. }
)
}
/// Determine whether it needs to clean poisons.

View File

@@ -22,6 +22,7 @@ pub mod datanode;
pub mod ddl;
pub mod ddl_manager;
pub mod distributed_time_constants;
pub mod election;
pub mod error;
pub mod flow_name;
pub mod heartbeat;

View File

@@ -17,6 +17,8 @@ use std::sync::Arc;
use std::time::Duration;
use backon::{BackoffBuilder, ExponentialBuilder};
use common_error::ext::PlainError;
use common_error::status_code::StatusCode;
use common_event_recorder::EventRecorderRef;
use common_telemetry::tracing_context::{FutureExt, TracingContext};
use common_telemetry::{debug, error, info, tracing};
@@ -90,6 +92,45 @@ impl Drop for ProcedureGuard {
}
}
/// Collects the names of lock keys on which a child procedure conflicts with its parent.
///
/// A key requested by the child conflicts when the parent also lists it and at
/// least one of the two sides wants it exclusively:
/// - Share + Share => Compatible
/// - Exclusive + Any => Conflict
/// - Any + Exclusive => Conflict
///
/// Conflicts are reported in the order the child lists its keys.
fn find_lock_conflicts<'a>(
    parent_keys: impl Iterator<Item = &'a StringKey>,
    child_keys: impl Iterator<Item = &'a StringKey>,
) -> Vec<String> {
    use std::collections::HashMap;

    // Key name -> `true` when the parent holds that key EXCLUSIVELY.
    // An exclusive entry always wins over a shared entry for the same name.
    let held: HashMap<&'a str, bool> = parent_keys.fold(HashMap::new(), |mut acc, key| {
        match key {
            StringKey::Exclusive(name) => {
                acc.insert(name.as_str(), true);
            }
            StringKey::Share(name) => {
                acc.entry(name.as_str()).or_insert(false);
            }
        }
        acc
    });

    let mut conflicts = Vec::new();
    for key in child_keys {
        let (name, child_exclusive) = match key {
            StringKey::Exclusive(name) => (name, true),
            StringKey::Share(name) => (name, false),
        };
        match held.get(name.as_str()) {
            Some(&parent_exclusive) if parent_exclusive || child_exclusive => {
                conflicts.push(name.clone());
            }
            _ => {}
        }
    }
    conflicts
}
pub(crate) struct Runner {
pub(crate) meta: ProcedureMetaRef,
pub(crate) procedure: BoxedProcedure,
@@ -512,6 +553,41 @@ impl Runner {
async fn on_suspended(&mut self, subprocedures: Vec<ProcedureWithId>) {
let has_child = !subprocedures.is_empty();
// Pre-check: detect potential deadlocks BEFORE submitting any subprocedure.
// If a child shares conflicting lock keys with the parent, submitting it would
// cause a Hold-and-Wait deadlock — the child blocks on lock acquisition while
// the parent holds the lock and waits for the child to finish.
for sub in &subprocedures {
let conflicting = find_lock_conflicts(
self.meta.lock_key.keys_to_lock(),
sub.procedure.lock_key().keys_to_lock(),
);
if !conflicting.is_empty() {
let err_msg = format!(
"Deadlock prevented: subprocedure {}-{} shares conflicting lock key(s) {:?} \
with parent {}-{}. Parent holds these locks and would wait for child \
completion, but child cannot acquire them.",
sub.procedure.type_name(),
sub.id,
conflicting,
self.procedure.type_name(),
self.meta.id,
);
error!("{}", err_msg);
let err = Arc::new(Error::external(PlainError::new(
err_msg,
StatusCode::Internal,
)));
if self.procedure.rollback_supported() {
self.meta.set_state(ProcedureState::prepare_rollback(err));
} else {
self.meta.set_state(ProcedureState::failed(err));
}
return;
}
}
for subprocedure in subprocedures {
info!(
"Procedure {}-{} submit subprocedure {}-{}",
@@ -1939,4 +2015,169 @@ mod tests {
join_all(tasks).await;
assert_eq!(shared_atomic_value.load(Ordering::Relaxed), 2);
}
/// Verifies that a parent without rollback support ends up `Failed` when it
/// suspends on a child that requests the same exclusive lock key, and that the
/// child is never submitted.
#[tokio::test]
async fn test_on_suspend_deadlock_detected_no_rollback() {
// Parent holds Exclusive("catalog.schema.table"), child also requests Exclusive("catalog.schema.table").
// Since parent does NOT support rollback, state should become Failed.
let child_id = ProcedureId::random();
// The parent's single step: build the conflicting child and suspend on it.
let exec_fn = move |_| {
async move {
let child_exec_fn = |_| async { Ok(Status::done()) }.boxed();
let child = ProcedureAdapter {
data: "child".to_string(),
lock_key: LockKey::single_exclusive("catalog.schema.table"),
poison_keys: PoisonKeys::default(),
exec_fn: child_exec_fn,
rollback_fn: None,
};
Ok(Status::Suspended {
subprocedures: vec![ProcedureWithId {
id: child_id,
procedure: Box::new(child),
}],
persist: false,
})
}
.boxed()
};
let parent = ProcedureAdapter {
data: "parent".to_string(),
lock_key: LockKey::single_exclusive("catalog.schema.table"),
poison_keys: PoisonKeys::default(),
exec_fn,
rollback_fn: None, // No rollback support
};
let dir = create_temp_dir("deadlock_no_rollback");
let meta = parent.new_meta(ROOT_ID);
let ctx = context_without_provider(meta.id);
let object_store = test_util::new_object_store(&dir);
let procedure_store = Arc::new(ProcedureStore::from_object_store(object_store.clone()));
let mut runner = new_runner(meta.clone(), Box::new(parent), procedure_store);
runner.manager_ctx.start();
// Run exactly one step of the parent; the suspend request is handled within it.
runner.execute_once(&ctx).await;
let state = runner.meta.state();
assert!(state.is_failed(), "Expected Failed, got {state:?}");
// Verify the error exists
assert!(
state.error().is_some(),
"Failed state should contain an error"
);
// Child should NOT have been submitted
assert!(
!runner.manager_ctx.contains_procedure(child_id),
"Child procedure should not be submitted when deadlock is detected"
);
}
/// Verifies that a parent with rollback support moves to `PrepareRollback`
/// when it suspends on a child that requests the same exclusive lock key, and
/// that the child is never submitted.
#[tokio::test]
async fn test_on_suspend_deadlock_detected_with_rollback() {
// Parent holds Exclusive("catalog.schema.table"), child also requests Exclusive("catalog.schema.table").
// Since parent DOES support rollback, state should become PrepareRollback.
let child_id = ProcedureId::random();
// The parent's single step: build the conflicting child and suspend on it.
let exec_fn = move |_| {
async move {
let child_exec_fn = |_| async { Ok(Status::done()) }.boxed();
let child = ProcedureAdapter {
data: "child".to_string(),
lock_key: LockKey::single_exclusive("catalog.schema.table"),
poison_keys: PoisonKeys::default(),
exec_fn: child_exec_fn,
rollback_fn: None,
};
Ok(Status::Suspended {
subprocedures: vec![ProcedureWithId {
id: child_id,
procedure: Box::new(child),
}],
persist: false,
})
}
.boxed()
};
// A no-op rollback is enough: only its presence is needed for this test.
let rollback_fn = move |_| async move { Ok(()) }.boxed();
let parent = ProcedureAdapter {
data: "parent".to_string(),
lock_key: LockKey::single_exclusive("catalog.schema.table"),
poison_keys: PoisonKeys::default(),
exec_fn,
rollback_fn: Some(Box::new(rollback_fn)), // Supports rollback
};
let dir = create_temp_dir("deadlock_with_rollback");
let meta = parent.new_meta(ROOT_ID);
let ctx = context_without_provider(meta.id);
let object_store = test_util::new_object_store(&dir);
let procedure_store = Arc::new(ProcedureStore::from_object_store(object_store.clone()));
let mut runner = new_runner(meta.clone(), Box::new(parent), procedure_store);
runner.manager_ctx.start();
// Run exactly one step of the parent; the suspend request is handled within it.
runner.execute_once(&ctx).await;
let state = runner.meta.state();
assert!(
state.is_prepare_rollback(),
"Expected PrepareRollback, got {state:?}"
);
// Verify the error exists in PrepareRollback variant
match &state {
ProcedureState::PrepareRollback { error } => {
assert!(!error.to_string().is_empty(), "Error should not be empty");
}
_ => panic!("Expected PrepareRollback, got {state:?}"),
}
// Child should NOT have been submitted
assert!(
!runner.manager_ctx.contains_procedure(child_id),
"Child procedure should not be submitted when deadlock is detected"
);
}
/// Exercises the share/exclusive compatibility matrix of `find_lock_conflicts`.
#[test]
fn test_find_lock_conflicts() {
    use crate::procedure::StringKey;

    let share = |name: &str| StringKey::Share(name.to_string());
    let exclusive = |name: &str| StringKey::Exclusive(name.to_string());

    // (description, parent keys, child keys, expected conflicts in sorted order)
    let cases: Vec<(&str, Vec<StringKey>, Vec<StringKey>, Vec<&str>)> = vec![
        (
            "Share + Share = No conflict (Compatible)",
            vec![share("A")],
            vec![share("A")],
            vec![],
        ),
        (
            "Share + Exclusive = Conflict",
            vec![share("A")],
            vec![exclusive("A")],
            vec!["A"],
        ),
        (
            "Exclusive + Share = Conflict",
            vec![exclusive("A")],
            vec![share("A")],
            vec!["A"],
        ),
        (
            "Exclusive + Exclusive = Conflict",
            vec![exclusive("A")],
            vec![exclusive("A")],
            vec!["A"],
        ),
        (
            "Multiple keys, partial overlap",
            vec![share("A"), exclusive("B")],
            vec![
                exclusive("A"), // Conflict with Share("A")
                share("B"),     // Conflict with Exclusive("B")
                exclusive("C"), // No conflict, parent doesn't hold C
            ],
            vec!["A", "B"],
        ),
    ];

    for (description, parent, child, expected) in cases {
        let mut conflicts = super::find_lock_conflicts(parent.iter(), child.iter());
        conflicts.sort();
        let expected: Vec<String> = expected.into_iter().map(String::from).collect();
        assert_eq!(conflicts, expected, "case: {description}");
    }
}
}

View File

@@ -27,7 +27,16 @@ static GREPTIME_TIMESTAMP_CELL: OnceCell<String> = OnceCell::new();
static GREPTIME_VALUE_CELL: OnceCell<String> = OnceCell::new();
pub fn set_default_prefix(prefix: Option<&str>) -> Result<()> {
match prefix {
// Strip surrounding double quotes as a defensive measure against upstream
// sources (scripts, CI, template engines, incorrect shell escaping) that may
// pass literal `""` as the value instead of an empty string.
let stripped = prefix.map(|s| {
s.strip_prefix('"')
.and_then(|s| s.strip_suffix('"'))
.unwrap_or(s)
});
match stripped {
None => {
// use default greptime prefix
GREPTIME_TIMESTAMP_CELL.get_or_init(|| GREPTIME_TIMESTAMP.to_string());
@@ -70,3 +79,45 @@ const GREPTIME_VALUE: &str = "greptime_value";
pub const GREPTIME_COUNT: &str = "greptime_count";
/// Default physical table name
pub const GREPTIME_PHYSICAL_TABLE: &str = "greptime_physical_table";
#[cfg(test)]
mod tests {
use super::*;
// Each test runs in a separate process via `cargo nextest`, so OnceCell
// state does not leak between tests.
// NOTE(review): under plain `cargo test` all tests share one process, so the
// first test to initialize the cells may affect the others; confirm that
// these tests are only run through nextest.
/// With no prefix given, the default `greptime` prefix is used.
#[test]
fn test_set_default_prefix_none() {
set_default_prefix(None).unwrap();
assert_eq!(greptime_timestamp(), "greptime_timestamp");
assert_eq!(greptime_value(), "greptime_value");
}
/// An empty prefix yields unprefixed column names.
#[test]
fn test_set_default_prefix_empty_string() {
set_default_prefix(Some("")).unwrap();
assert_eq!(greptime_timestamp(), "timestamp");
assert_eq!(greptime_value(), "value");
}
/// A literal pair of double quotes is treated like an empty prefix.
#[test]
fn test_set_default_prefix_quoted_empty() {
// Handles upstream sources that pass literal `""` instead of an empty string
set_default_prefix(Some("\"\"")).unwrap();
assert_eq!(greptime_timestamp(), "timestamp");
assert_eq!(greptime_value(), "value");
}
/// A custom prefix is joined to the column names with an underscore.
#[test]
fn test_set_default_prefix_custom() {
set_default_prefix(Some("mydb")).unwrap();
assert_eq!(greptime_timestamp(), "mydb_timestamp");
assert_eq!(greptime_value(), "mydb_value");
}
/// A prefix containing a space and `!` is rejected with an error.
#[test]
fn test_set_default_prefix_invalid() {
assert!(set_default_prefix(Some("invalid prefix!")).is_err());
}
}

View File

@@ -16,8 +16,8 @@ mod column_schema;
pub mod constraint;
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::{fmt, mem};
use arrow::datatypes::{Field, Schema as ArrowSchema};
use datafusion_common::DFSchemaRef;
@@ -177,6 +177,26 @@ impl Schema {
&self.arrow_schema.metadata
}
/// Returns the estimated memory footprint of this schema in bytes.
///
/// The estimate is the sum of:
/// - the inline size of `self`,
/// - the allocated slots of the column schema vector,
/// - what each column schema reports beyond its own inline size,
/// - the allocated slots and key strings of the name-to-index map,
/// - the estimated size of the underlying arrow schema.
pub fn estimated_size(&self) -> usize {
    let column_slots = mem::size_of::<ColumnSchema>() * self.column_schemas.capacity();

    // The inline part of every column is already covered by `column_slots`,
    // so only the remainder is added here. A column's own estimate is not
    // guaranteed to be at least `size_of::<ColumnSchema>()` (it swaps the
    // inline data type size for the arrow type size), so subtract with
    // saturation instead of risking an unsigned underflow.
    let column_extra = self
        .column_schemas
        .iter()
        .map(|column_schema| {
            column_schema
                .estimated_size()
                .saturating_sub(mem::size_of::<ColumnSchema>())
        })
        .sum::<usize>();

    let index_slots = mem::size_of::<(String, usize)>() * self.name_to_index.capacity();
    let index_names = self
        .name_to_index
        .keys()
        .map(|name| name.capacity())
        .sum::<usize>();

    mem::size_of_val(self)
        + column_slots
        + column_extra
        + index_slots
        + index_names
        + arrow_schema_size(self.arrow_schema.as_ref())
}
/// Generate a new projected schema
///
/// # Panic
@@ -213,6 +233,17 @@ impl Schema {
}
}
/// Estimates the memory footprint of an arrow schema: its inline size, the
/// size reported by its fields, and the allocated slots plus string buffers
/// of its metadata map.
fn arrow_schema_size(schema: &ArrowSchema) -> usize {
    let metadata = &schema.metadata;
    let metadata_slots = mem::size_of::<(String, String)>() * metadata.capacity();
    let metadata_strings: usize = metadata
        .iter()
        .map(|(key, value)| key.capacity() + value.capacity())
        .sum();

    mem::size_of_val(schema) + schema.fields.size() + metadata_slots + metadata_strings
}
#[derive(Default)]
pub struct SchemaBuilder {
column_schemas: Vec<ColumnSchema>,

View File

@@ -13,8 +13,8 @@
// limitations under the License.
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use std::{fmt, mem};
use arrow::datatypes::Field;
use arrow_schema::extension::{
@@ -178,6 +178,19 @@ impl ColumnSchema {
self
}
/// Returns the estimated memory footprint of this schema.
pub fn estimated_size(&self) -> usize {
mem::size_of_val(self) - mem::size_of_val(&self.data_type)
+ self.data_type.as_arrow_type().size()
+ self.name.capacity()
+ self
.default_constraint
.as_ref()
.map(column_default_constraint_size)
.unwrap_or_default()
+ metadata_size(&self.metadata)
}
/// Set the inverted index for the column.
/// Similar to [with_inverted_index] but don't take the ownership.
///
@@ -493,6 +506,21 @@ impl ColumnSchema {
}
}
/// Estimates the heap usage of a metadata map: one `(String, String)` slot per
/// unit of map capacity, plus the allocated capacity of every key and value.
fn metadata_size(metadata: &Metadata) -> usize {
    let slots = mem::size_of::<(String, String)>() * metadata.capacity();

    let mut strings = 0usize;
    for (key, value) in metadata.iter() {
        strings += key.capacity() + value.capacity();
    }

    slots + strings
}
/// Estimates the memory used by a column default constraint, in bytes.
///
/// A function constraint is measured by the capacity of its expression
/// string; a value constraint by the `data_size` of its value reference.
fn column_default_constraint_size(default_constraint: &ColumnDefaultConstraint) -> usize {
    match default_constraint {
        ColumnDefaultConstraint::Value(literal) => literal.as_value_ref().data_size(),
        ColumnDefaultConstraint::Function(expression) => expression.capacity(),
    }
}
/// Column extended type set in column schema's metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnExtType {

View File

@@ -396,7 +396,7 @@ pub fn jsonb_to_string(val: &[u8]) -> Result<String> {
match jsonb::from_slice(val) {
Ok(jsonb_value) => {
let serialized = jsonb_value.to_string();
Ok(serialized)
fix_unicode_point(&serialized)
}
Err(e) => InvalidJsonbSnafu { error: e }.fail(),
}
@@ -405,18 +405,12 @@ pub fn jsonb_to_string(val: &[u8]) -> Result<String> {
/// Converts a json type value to serde_json::Value
pub fn jsonb_to_serde_json(val: &[u8]) -> Result<serde_json::Value> {
let json_string = jsonb_to_string(val)?;
jsonb_string_to_serde_value(&json_string)
serde_json::Value::from_str(&json_string).context(DeserializeSnafu { json: json_string })
}
/// Attempts to deserialize a JSON text into `serde_json::Value`, with a best-effort
/// fallback for Rust-style Unicode escape sequences.
/// Normalizes a JSON string by converting Rust-style Unicode escape sequences to JSON-compatible format.
///
/// This function is intended to be used on JSON strings produced from the internal
/// JSONB representation (e.g. via [`jsonb_to_string`]). It first calls
/// `serde_json::Value::from_str` directly. If that succeeds, the parsed value is
/// returned as-is.
///
/// If the initial parse fails, the input is scanned for Rust-style Unicode code
/// The input is scanned for Rust-style Unicode code
/// point escapes of the form `\\u{H...}` (a backslash, `u`, an opening brace,
/// followed by 1 to 6 hexadecimal digits, and a closing brace). Each such escape is
/// converted into JSON-compatible UTF-16 escape sequences:
@@ -427,59 +421,44 @@ pub fn jsonb_to_serde_json(val: &[u8]) -> Result<serde_json::Value> {
/// the code point is encoded as a UTF-16 surrogate pair and emitted as two consecutive
/// `\\uXXXX` sequences (as the JSON format requires).
///
/// After this normalization, the function retries parsing the resulting string as
/// JSON and returns the deserialized value or a `DeserializeSnafu` error if it
/// still cannot be parsed.
fn jsonb_string_to_serde_value(json: &str) -> Result<serde_json::Value> {
match serde_json::Value::from_str(json) {
Ok(v) => Ok(v),
Err(e) => {
// If above deserialization is failed, the JSON string might contain some Rust chars
// that are somehow incorrectly represented as Unicode code point literal. For example,
// "\u{fe0f}". We have to convert them to JSON compatible format, like "\uFE0F", then
// try to deserialize the JSON string again.
if !e.is_syntax() || !e.to_string().contains("invalid escape") {
return Err(e).context(DeserializeSnafu { json });
}
/// After this normalization, the function returns the normalized string.
fn fix_unicode_point(json: &str) -> Result<String> {
static UNICODE_CODE_POINT_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
// Match literal "\u{...}" sequences, capturing the 1 to 6 hex digits (code point range)
// inside braces.
Regex::new(r"\\u\{([0-9a-fA-F]{1,6})}").unwrap_or_else(|e| panic!("{}", e))
});
static UNICODE_CODE_POINT_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
// Match literal "\u{...}" sequences, capturing the 1 to 6 hex digits (code point range)
// inside braces.
Regex::new(r"\\u\{([0-9a-fA-F]{1,6})}").unwrap_or_else(|e| panic!("{}", e))
});
let v = UNICODE_CODE_POINT_PATTERN.replace_all(json, |caps: &Captures| {
// Extract the hex payload (without braces) and parse to a code point.
let hex = &caps[1];
let Ok(code) = u32::from_str_radix(hex, 16) else {
// On parse failure, leave the original escape sequence unchanged.
return caps[0].to_string();
};
let v = UNICODE_CODE_POINT_PATTERN.replace_all(json, |caps: &Captures| {
// Extract the hex payload (without braces) and parse to a code point.
let hex = &caps[1];
let Ok(code) = u32::from_str_radix(hex, 16) else {
// On parse failure, leave the original escape sequence unchanged.
return caps[0].to_string();
};
if code <= 0xFFFF {
// Basic Multilingual Plane: JSON can represent this directly as \uXXXX.
format!("\\u{:04X}", code)
} else if code > 0x10FFFF {
// Beyond max Unicode code point
caps[0].to_string()
} else {
// Supplementary planes: JSON needs UTF-16 surrogate pairs.
// Convert the code point to a 20-bit value.
let code = code - 0x10000;
if code <= 0xFFFF {
// Basic Multilingual Plane: JSON can represent this directly as \uXXXX.
format!("\\u{:04X}", code)
} else if code > 0x10FFFF {
// Beyond max Unicode code point
caps[0].to_string()
} else {
// Supplementary planes: JSON needs UTF-16 surrogate pairs.
// Convert the code point to a 20-bit value.
let code = code - 0x10000;
// High surrogate: top 10 bits, offset by 0xD800.
let high = 0xD800 + ((code >> 10) & 0x3FF);
// High surrogate: top 10 bits, offset by 0xD800.
let high = 0xD800 + ((code >> 10) & 0x3FF);
// Low surrogate: bottom 10 bits, offset by 0xDC00.
let low = 0xDC00 + (code & 0x3FF);
// Low surrogate: bottom 10 bits, offset by 0xDC00.
let low = 0xDC00 + (code & 0x3FF);
// Emit two \uXXXX escapes in sequence.
format!("\\u{:04X}\\u{:04X}", high, low)
}
});
serde_json::Value::from_str(&v).context(DeserializeSnafu { json })
// Emit two \uXXXX escapes in sequence.
format!("\\u{:04X}\\u{:04X}", high, low)
}
}
});
Ok(v.to_string())
}
/// Parses a string to a json type value
@@ -495,45 +474,54 @@ mod tests {
use crate::json::JsonStructureSettings;
#[test]
fn test_jsonb_string_to_serde_value() -> Result<()> {
fn test_fix_unicode_point() -> Result<()> {
let valid_cases = vec![
(r#"{"data": "simple ascii"}"#, r#"{"data":"simple ascii"}"#),
(r#"{"data": "simple ascii"}"#, r#"{"data": "simple ascii"}"#),
(
r#"{"data": "Greek sigma: \u{03a3}"}"#,
r#"{"data":"Greek sigma: Σ"}"#,
r#"{"data":"Greek sigma: \u{03a3}"}"#,
r#"{"data":"Greek sigma: \u03A3"}"#,
),
(
r#"{"data": "Joker card: \u{1f0df}"}"#,
r#"{"data":"Joker card: 🃟"}"#,
r#"{"data":"Joker card: \u{1f0df}"}"#,
r#"{"data":"Joker card: \uD83C\uDCDF"}"#,
),
(
r#"{"data": "BMP boundary: \u{ffff}"}"#,
r#"{"data":"BMP boundary: ￿"}"#,
r#"{"data":"BMP boundary: \u{ffff}"}"#,
r#"{"data":"BMP boundary: \uFFFF"}"#,
),
(
r#"{"data": "Supplementary min: \u{10000}"}"#,
r#"{"data":"Supplementary min: 𐀀"}"#,
r#"{"data":"Supplementary min: \u{10000}"}"#,
r#"{"data":"Supplementary min: \uD800\uDC00"}"#,
),
(
r#"{"data": "Supplementary max: \u{10ffff}"}"#,
r#"{"data":"Supplementary max: 􏿿"}"#,
r#"{"data":"Supplementary max: \u{10ffff}"}"#,
r#"{"data":"Supplementary max: \uDBFF\uDFFF"}"#,
),
];
for (input, expect) in valid_cases {
let v = jsonb_string_to_serde_value(input)?;
assert_eq!(v.to_string(), expect);
let v = fix_unicode_point(input)?;
assert_eq!(v, expect);
}
let invalid_cases = vec![
r#"{"data": "Invalid hex: \u{gggg}"}"#,
r#"{"data": "Beyond max Unicode code point: \u{110000}"}"#,
r#"{"data": "Out of range: \u{1100000}"}"#, // 7 digit
r#"{"data": "Empty braces: \u{}"}"#,
let invalid_escape_cases = vec![
(
r#"{"data": "Invalid hex: \u{gggg}"}"#,
r#"{"data": "Invalid hex: \u{gggg}"}"#,
),
(
r#"{"data": "Empty braces: \u{}"}"#,
r#"{"data": "Empty braces: \u{}"}"#,
),
(
r#"{"data": "Out of range: \u{1100000}"}"#,
r#"{"data": "Out of range: \u{1100000}"}"#,
),
];
for input in invalid_cases {
let result = jsonb_string_to_serde_value(input);
assert!(result.is_err());
for (input, expect) in invalid_escape_cases {
let v = fix_unicode_point(input)?;
assert_eq!(v, expect);
}
Ok(())
}

View File

@@ -16,30 +16,19 @@
#![warn(unused)]
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::sync::Arc;
use common_error::ext::BoxedError;
use common_telemetry::debug;
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::functions_aggregate::count::count_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
use datafusion::optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
use datafusion::optimizer::optimize_projections::OptimizeProjections;
use datafusion::optimizer::simplify_expressions::SimplifyExpressions;
use datafusion::optimizer::utils::NamePreserver;
use datafusion::optimizer::{Analyzer, AnalyzerRule, Optimizer, OptimizerContext};
use datafusion_common::tree_node::{
Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{Column, DFSchema, ScalarValue};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{
BinaryExpr, ColumnarValue, Expr, Literal, Operator, Projection, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use query::QueryEngine;
use query::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use query::parser::QueryLanguageParser;
@@ -52,7 +41,6 @@ use substrait::DFLogicalSubstraitConvertor;
use crate::adapter::FlownodeContext;
use crate::error::{DatafusionSnafu, Error, ExternalSnafu, UnexpectedSnafu};
use crate::expr::{TUMBLE_END, TUMBLE_START};
use crate::plan::TypedPlan;
// TODO(discord9): use `Analyzer` to manage rules if more `AnalyzerRule` is needed
@@ -63,8 +51,6 @@ pub async fn apply_df_optimizer(
let cfg = query_ctx.create_config_options();
let analyzer = Analyzer::with_rules(vec![
Arc::new(CountWildcardToTimeIndexRule),
Arc::new(AvgExpandRule),
Arc::new(TumbleExpandRule),
Arc::new(CheckGroupByRule::new()),
Arc::new(TypeCoercion::new()),
]);
@@ -127,390 +113,6 @@ pub async fn sql_to_flow_plan(
Ok(flow_plan)
}
/// Analyzer rule that expands `avg` aggregates and then lifts the resulting
/// composite aggregate expressions into the enclosing projection.
#[derive(Debug)]
struct AvgExpandRule;

impl AnalyzerRule for AvgExpandRule {
    /// Runs `expand_avg_analyzer` bottom-up over the plan (including
    /// subqueries), then `put_aggr_to_proj_analyzer` top-down over the result.
    fn analyze(
        &self,
        plan: datafusion_expr::LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
        let avg_expanded = plan.transform_up_with_subqueries(expand_avg_analyzer)?.data;
        let aggr_lifted =
            avg_expanded.transform_down_with_subqueries(put_aggr_to_proj_analyzer)?;
        Ok(aggr_lifted.data)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "avg_expand"
    }
}
/// Lifts composite aggregate expressions from an `Aggregate` node into the
/// `Projection` directly above it, leaving the aggregate with plain
/// aggregate-function expressions only.
///
/// i.e.
/// ```ignore
/// proj: avg(x)
/// -- aggr: [sum(x)/count(x) as avg(x)]
/// ```
/// becomes:
/// ```ignore
/// proj: sum(x)/count(x) as avg(x)
/// -- aggr: [sum(x), count(x)]
/// ```
///
/// Plans that are not a `Projection` over an `Aggregate`, or whose aggregate
/// did not gain any expressions from the expansion, are returned unchanged.
///
/// # Errors
///
/// Propagates errors from `name_for_alias`, from expression transformation,
/// from `recompute_schema` and from `Projection::try_new`.
fn put_aggr_to_proj_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
        && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
    {
        // Output name of each composite aggregate expression -> the
        // projection-level expression that replaces it.
        let mut replace_old_proj_exprs = HashMap::new();
        // Aggregate expressions the rewritten `Aggregate` node will carry.
        let mut expanded_aggr_exprs = vec![];
        for aggr_expr in &aggr.aggr_expr {
            if let Expr::AggregateFunction(_) = &aggr_expr {
                expanded_aggr_exprs.push(aggr_expr.clone());
            } else {
                let old_name = aggr_expr.name_for_alias()?;
                // Replace every aggregate function inside the composite
                // expression with a column reference to its output, and keep
                // the aggregate function itself for the `Aggregate` node.
                let new_proj_expr = aggr_expr
                    .clone()
                    .transform(|ch| {
                        if let Expr::AggregateFunction(_) = &ch {
                            expanded_aggr_exprs.push(ch.clone());
                            Ok(Transformed::yes(Expr::Column(Column::from_qualified_name(
                                ch.name_for_alias()?,
                            ))))
                        } else {
                            Ok(Transformed::no(ch))
                        }
                    })?
                    .data;
                replace_old_proj_exprs.insert(old_name, new_proj_expr);
            }
        }

        // Only rewrite when the expansion actually produced more aggregate
        // expressions than the node had before.
        if expanded_aggr_exprs.len() > aggr.aggr_expr.len() {
            let mut aggr = aggr.clone();
            aggr.aggr_expr = expanded_aggr_exprs;
            let mut aggr_plan = datafusion_expr::LogicalPlan::Aggregate(aggr);
            // important to recompute schema after changing aggr_expr
            aggr_plan = aggr_plan.recompute_schema()?;

            // reconstruct proj with new proj_exprs
            let mut new_proj_exprs = proj.expr.clone();
            for proj_expr in new_proj_exprs.iter_mut() {
                // First replace the projection expression as a whole when its
                // name matches a lifted aggregate ...
                if let Some(new_proj_expr) =
                    replace_old_proj_exprs.get(&proj_expr.name_for_alias()?)
                {
                    *proj_expr = new_proj_expr.clone();
                }
                // ... then replace any matching sub-expression.
                *proj_expr = proj_expr
                    .clone()
                    .transform(|expr| {
                        if let Some(new_expr) = replace_old_proj_exprs.get(&expr.name_for_alias()?)
                        {
                            Ok(Transformed::yes(new_expr.clone()))
                        } else {
                            Ok(Transformed::no(expr))
                        }
                    })?
                    .data;
            }
            let proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
                new_proj_exprs,
                Arc::new(aggr_plan),
            )?);
            return Ok(Transformed::yes(proj));
        }
    }
    Ok(Transformed::no(plan))
}
/// Rewrites every expression of `plan` with `ExpandAvgRewriter`, so that
/// `avg(<expr>)` is expressed through `sum(<expr>)` and `count(<expr>)`.
///
/// The original output name of each rewritten expression is preserved, and
/// the plan schema is recomputed afterwards.
fn expand_avg_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    // Schema visible to the plan's expressions: the merged input schemas,
    // plus the source schema when the node is a table scan.
    let mut merged_schema = merge_schema(&plan.inputs());
    if let datafusion_expr::LogicalPlan::TableScan(scan) = &plan {
        let scan_schema =
            DFSchema::try_from_qualified_schema(scan.table_name.clone(), &scan.source.schema())?;
        merged_schema.merge(&scan_schema);
    }

    let mut avg_rewriter = ExpandAvgRewriter::new(&merged_schema);
    let name_preserver = NamePreserver::new(&plan);

    // Rewrite each expression individually, restoring its original name.
    let rewritten_plan = plan.map_expressions(|expr| {
        let saved_name = name_preserver.save(&expr);
        let rewritten_expr = expr.rewrite(&mut avg_rewriter)?;
        Ok(rewritten_expr.update_data(|expr| saved_name.restore(expr)))
    })?;
    rewritten_plan.map_data(|plan| plan.recompute_schema())
}
/// Rewrites `avg(<expr>)` into
/// `CASE WHEN count(<expr>) != 0 THEN CAST(sum(<expr>) AS Float64) / count(<expr>) ELSE NULL END`.
///
/// TODO(discord9): support avg return type decimal128
///
/// see impl details at https://github.com/apache/datafusion/blob/4ad4f90d86c57226a4e0fb1f79dfaaf0d404c273/datafusion/expr/src/type_coercion/aggregates.rs#L457-L462
pub(crate) struct ExpandAvgRewriter<'a> {
    /// schema of the plan
    #[allow(unused)]
    pub(crate) schema: &'a DFSchema,
}

impl<'a> ExpandAvgRewriter<'a> {
    /// Creates a rewriter bound to the given plan schema.
    fn new(schema: &'a DFSchema) -> Self {
        Self { schema }
    }
}

impl TreeNodeRewriter for ExpandAvgRewriter<'_> {
    type Node = Expr;

    /// Replaces an `avg` aggregate function with the guarded
    /// `sum / count` expression; every other expression is left untouched.
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>, DataFusionError> {
        if let Expr::AggregateFunction(aggr_func) = &expr
            && aggr_func.func.name() == "avg"
        {
            // `sum` and `count` reuse all parameters of the original `avg`
            // call; only the aggregate function itself is swapped.
            let sum_expr = {
                let mut tmp = aggr_func.clone();
                tmp.func = sum_udaf();
                Expr::AggregateFunction(tmp)
            };
            let sum_cast = Expr::Cast(datafusion_expr::Cast {
                expr: Box::new(sum_expr),
                data_type: arrow_schema::DataType::Float64,
            });

            let count_expr = {
                let mut tmp = aggr_func.clone();
                tmp.func = count_udaf();
                Expr::AggregateFunction(tmp)
            };
            // The CASE condition refers to the count through a column named
            // after the count aggregate rather than the aggregate itself.
            let count_expr_ref =
                Expr::Column(Column::from_qualified_name(count_expr.name_for_alias()?));

            let div_expr = Box::new(Expr::BinaryExpr(BinaryExpr::new(
                Box::new(sum_cast),
                Operator::Divide,
                Box::new(count_expr),
            )));
            let not_zero = Box::new(Expr::BinaryExpr(BinaryExpr::new(
                Box::new(count_expr_ref),
                Operator::NotEq,
                Box::new(0.lit()),
            )));
            let null = Box::new(Expr::Literal(ScalarValue::Null, None));

            let case_when =
                datafusion_expr::Case::new(None, vec![(not_zero, div_expr)], Some(null));
            return Ok(Transformed::yes(Expr::Case(case_when)));
        }

        Ok(Transformed::no(expr))
    }
}
/// Analyzer rule that expands `tumble` in aggregate group expressions into
/// `tumble_start` and `tumble_end`, exposed under column names such as
/// `window_start`.
#[derive(Debug)]
struct TumbleExpandRule;

impl AnalyzerRule for TumbleExpandRule {
    /// Applies `expand_tumble_analyzer` bottom-up over the plan, including
    /// subqueries.
    fn analyze(
        &self,
        plan: datafusion_expr::LogicalPlan,
        _config: &ConfigOptions,
    ) -> datafusion_common::Result<datafusion_expr::LogicalPlan> {
        let expanded = plan.transform_up_with_subqueries(expand_tumble_analyzer)?;
        Ok(expanded.data)
    }

    /// Name under which this rule is reported.
    fn name(&self) -> &str {
        "tumble_expand"
    }
}
/// Expands a `tumble` group expression of an `Aggregate` (sitting directly
/// under a `Projection`) into a `tumble_start` / `tumble_end` pair, and
/// rewrites the projection accordingly.
///
/// Projection expressions whose name matches the original `tumble`
/// expression are replaced by the two new columns; if none matches, the two
/// columns are appended as `window_start` and `window_end`.
fn expand_tumble_analyzer(
    plan: datafusion_expr::LogicalPlan,
) -> Result<Transformed<datafusion_expr::LogicalPlan>, DataFusionError> {
    if let datafusion_expr::LogicalPlan::Projection(proj) = &plan
        && let datafusion_expr::LogicalPlan::Aggregate(aggr) = proj.input.as_ref()
    {
        let mut group_exprs = vec![];
        // Name of each `tumble` expression -> names of its start/end columns.
        let mut tumble_columns = HashMap::new();
        let mut saw_tumble = false;

        for group_expr in aggr.group_expr.iter() {
            let func = match group_expr {
                datafusion_expr::Expr::ScalarFunction(func) if func.name() == "tumble" => func,
                _ => {
                    group_exprs.push(group_expr.clone());
                    continue;
                }
            };
            saw_tumble = true;

            let start_fn = datafusion_expr::Expr::ScalarFunction(
                datafusion_expr::expr::ScalarFunction::new_udf(
                    Arc::new(TumbleExpand::new(TUMBLE_START).into()),
                    func.args.clone(),
                ),
            );
            let start_name = start_fn.name_for_alias()?;
            group_exprs.push(start_fn);

            let end_fn = datafusion_expr::Expr::ScalarFunction(
                datafusion_expr::expr::ScalarFunction::new_udf(
                    Arc::new(TumbleExpand::new(TUMBLE_END).into()),
                    func.args.clone(),
                ),
            );
            let end_name = end_fn.name_for_alias()?;
            group_exprs.push(end_fn);

            tumble_columns.insert(group_expr.name_for_alias()?, (start_name, end_name));
        }

        if !saw_tumble {
            return Ok(Transformed::no(plan));
        }

        let mut rewritten_aggr = aggr.clone();
        rewritten_aggr.group_expr = group_exprs;
        let rewritten_aggr =
            datafusion_expr::LogicalPlan::Aggregate(rewritten_aggr).recompute_schema()?;

        // replace alias in projection if needed, and add new column ref if necessary
        let mut proj_exprs = vec![];
        let mut replaced_in_proj = false;
        for proj_expr in proj.expr.iter() {
            match tumble_columns.get(&proj_expr.name_for_alias()?) {
                Some((start_name, end_name)) => {
                    proj_exprs.push(datafusion_expr::Expr::Column(Column::from_qualified_name(
                        start_name,
                    )));
                    proj_exprs.push(datafusion_expr::Expr::Column(Column::from_qualified_name(
                        end_name,
                    )));
                    replaced_in_proj = true;
                }
                None => proj_exprs.push(proj_expr.clone()),
            }
        }

        // append to end of projection if not exist
        if !replaced_in_proj {
            for (start_name, end_name) in tumble_columns.values() {
                let start_col = Column::from_qualified_name(start_name);
                let end_col = Column::from_qualified_name(end_name);
                proj_exprs.push(datafusion_expr::Expr::Column(start_col).alias("window_start"));
                proj_exprs.push(datafusion_expr::Expr::Column(end_col).alias("window_end"));
            }
        }

        let rewritten_proj = datafusion_expr::LogicalPlan::Projection(Projection::try_new(
            proj_exprs,
            Arc::new(rewritten_aggr),
        )?);
        return Ok(Transformed::yes(rewritten_proj));
    }

    Ok(Transformed::no(plan))
}
/// Placeholder scalar UDF standing in for `tumble_start` / `tumble_end`, so
/// that datafusion recognizes them as scalar functions. It only carries a
/// name and a signature; invoking it always fails.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TumbleExpand {
    signature: Signature,
    name: String,
}

impl TumbleExpand {
    /// Creates a placeholder UDF with the given function name.
    pub fn new(name: &str) -> Self {
        let signature = Signature::new(TypeSignature::UserDefined, Volatility::Immutable);
        Self {
            signature,
            name: name.to_string(),
        }
    }
}

impl ScalarUDFImpl for TumbleExpand {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    /// elide the signature for now
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Validates the argument types: a date/timestamp column, a window size
    /// given as string or interval, and an optional start time given as
    /// string, date or timestamp. On success the types are returned as-is.
    fn coerce_types(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> datafusion_common::Result<Vec<arrow_schema::DataType>> {
        use arrow_schema::DataType::*;

        let (Some(ts), Some(window)) = (arg_types.first(), arg_types.get(1)) else {
            return Err(DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size) and a third optional arg for starting time".to_string(),
            ));
        };

        if !matches!(ts, Date32 | Timestamp(_, _)) {
            return Err(DataFusionError::Plan(format!(
                "Expect timestamp column as first arg for tumble_start, found {:?}",
                ts
            )));
        }
        if !matches!(window, Utf8 | Interval(_)) {
            return Err(DataFusionError::Plan(format!(
                "Expect second arg for window size's type being interval for tumble_start, found {:?}",
                window
            )));
        }

        match arg_types.get(2) {
            Some(start_time) if !matches!(start_time, Utf8 | Date32 | Timestamp(_, _)) => {
                Err(DataFusionError::Plan(format!(
                    "Expect start_time to either be date, timestamp or string, found {:?}",
                    start_time
                )))
            }
            _ => Ok(arg_types.to_vec()),
        }
    }

    /// The return type is the type of the first argument.
    fn return_type(
        &self,
        arg_types: &[arrow_schema::DataType],
    ) -> Result<arrow_schema::DataType, DataFusionError> {
        match arg_types.first() {
            Some(first) => Ok(first.clone()),
            None => Err(DataFusionError::Plan(
                "Expect tumble function have at least two arg(timestamp column and window size)"
                    .to_string(),
            )),
        }
    }

    /// Always fails: this UDF is a planning-time placeholder only.
    fn invoke_with_args(
        &self,
        _args: ScalarFunctionArgs,
    ) -> datafusion_common::Result<ColumnarValue> {
        let message = "This function should not be executed by datafusion".to_string();
        Err(DataFusionError::Plan(message))
    }
}
/// Analyzer rule that checks all GROUP BY expressions of an aggregate query
/// and makes sure they also appear in the SELECT clause.
#[derive(Debug)]
struct CheckGroupByRule {}

View File

@@ -382,10 +382,9 @@ impl TypedPlan {
#[cfg(test)]
mod test {
use std::time::Duration;
use bytes::BytesMut;
use common_time::{IntervalMonthDayNano, Timestamp};
use common_time::IntervalMonthDayNano;
use datatypes::data_type::ConcreteDataType as CDT;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::Value;
@@ -397,898 +396,6 @@ mod test {
use crate::repr::{ColumnType, RelationType};
use crate::transform::test::{create_test_ctx, create_test_query_engine, sql_to_substrait};
/// Plans `sum(abs(number))` grouped by a 1-second `tumble` window with an
/// explicit start time, and checks the resulting flow plan: `abs` is
/// evaluated as a DataFusion scalar call (`CallDf`) inside the value plan,
/// before the `SumUInt64` reduction, and the window start/end columns are
/// appended to the output as `window_start` / `window_end`.
#[tokio::test]
async fn test_df_func_basic() {
let engine = create_test_query_engine();
let sql = "SELECT sum(abs(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
.await
.unwrap();
// The single aggregate of the reduce step: sum over value column 0.
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt64,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected =
TypedPlan {
schema: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
])
.with_key(vec![2])
.with_time_index(Some(1))
.into_named(vec![
Some("sum(abs(numbers_with_ts.number))".to_string()),
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
)
.mfp(MapFilterProject::new(2).into_safe())
.unwrap(),
),
key_val_plan: KeyValPlan {
// Keys: window floor and ceiling computed from the `ts` column.
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Duration::from_nanos(1_000_000_000),
start_time: Some(Timestamp::new_millisecond(
1625097600000,
)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Duration::from_nanos(1_000_000_000),
start_time: Some(Timestamp::new_millisecond(
1625097600000,
)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
// Value: `abs(number)` through a DataFusion scalar function, cast to uint64.
val_plan: MapFilterProject::new(2)
.map(vec![ScalarExpr::CallDf {
df_scalar_fn: DfScalarFunction::try_from_raw_fn(
RawDfScalarFn {
f: BytesMut::from(
b"\x08\x02\"\x08\x1a\x06\x12\x04\n\x02\x12\0"
.as_ref(),
),
input_schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint32_datatype(),
false,
)])
.into_unnamed(),
extensions: FunctionExtensions::from_iter(
[
(0, "tumble_start".to_string()),
(1, "tumble_end".to_string()),
(2, "abs".to_string()),
(3, "sum".to_string()),
]
.into_iter(),
),
},
)
.await
.unwrap(),
exprs: vec![ScalarExpr::Column(0)],
}
.cast(CDT::uint64_datatype())])
.unwrap()
.project(vec![2])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0))
.into_unnamed(),
),
),
// Final projection reorders the reduce output to (sum, window start, window end).
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::Column(2),
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![3, 4, 5])
.unwrap(),
},
};
assert_eq!(flow_plan, expected);
}
/// Plans `abs(sum(number))` grouped by a 1-second `tumble` window with an
/// explicit start time, and checks the resulting flow plan: the reduction
/// is a plain `SumUInt64`, and `abs` is applied afterwards as a DataFusion
/// scalar call (`CallDf`) in the outer map/filter/project step.
#[tokio::test]
async fn test_df_func_expr_tree() {
let engine = create_test_query_engine();
let sql = "SELECT abs(sum(number)) FROM numbers_with_ts GROUP BY tumble(ts, '1 second', '2021-07-01 00:00:00');";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
.await
.unwrap();
// The single aggregate of the reduce step: sum over value column 0.
let aggr_expr = AggregateExpr {
func: AggregateFunc::SumUInt64,
expr: ScalarExpr::Column(0),
distinct: false,
};
let expected = TypedPlan {
schema: RelationType::new(vec![
ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
])
.with_key(vec![2])
.with_time_index(Some(1))
.into_named(vec![
Some("abs(sum(numbers_with_ts.number))".to_string()),
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
)
.mfp(MapFilterProject::new(2).into_safe())
.unwrap(),
),
key_val_plan: KeyValPlan {
// Keys: window floor and ceiling computed from the `ts` column.
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Duration::from_nanos(1_000_000_000),
start_time: Some(Timestamp::new_millisecond(
1625097600000,
)),
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Duration::from_nanos(1_000_000_000),
start_time: Some(Timestamp::new_millisecond(
1625097600000,
)),
},
),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
// Value: the `number` column cast to uint64.
val_plan: MapFilterProject::new(2)
.map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
.unwrap()
.project(vec![2])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: vec![aggr_expr.clone()],
simple_aggrs: vec![AggrWithIndex::new(aggr_expr.clone(), 0, 0)],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
ColumnType::new(CDT::uint64_datatype(), true), //sum(number)
])
.with_key(vec![1])
.with_time_index(Some(0))
.into_named(vec![None, None, None]),
),
),
// Outer step: `abs` applied to the reduced sum (column 2), followed by the window columns.
mfp: MapFilterProject::new(3)
.map(vec![
ScalarExpr::CallDf {
df_scalar_fn: DfScalarFunction::try_from_raw_fn(RawDfScalarFn {
f: BytesMut::from(b"\"\x08\x1a\x06\x12\x04\n\x02\x12\0".as_ref()),
input_schema: RelationType::new(vec![ColumnType::new(
ConcreteDataType::uint64_datatype(),
true,
)])
.into_unnamed(),
extensions: FunctionExtensions::from_iter(
[
(0, "abs".to_string()),
(1, "tumble_start".to_string()),
(2, "tumble_end".to_string()),
(3, "sum".to_string()),
]
.into_iter(),
),
})
.await
.unwrap(),
exprs: vec![ScalarExpr::Column(2)],
},
ScalarExpr::Column(0),
ScalarExpr::Column(1),
])
.unwrap()
.project(vec![3, 4, 5])
.unwrap(),
},
};
assert_eq!(flow_plan, expected);
}
/// Plans `number, avg(number)` grouped by a 1-hour `tumble` window and by
/// `number`, and checks the resulting flow plan: `avg` is reduced as a
/// `SumUInt64` plus a `Count`, then recombined in the outer step as
/// `sum / count` in float64 when the count is not zero, and NULL otherwise.
///
/// TODO(discord9): add more illegal sql tests
#[tokio::test]
async fn test_tumble_composite() {
let engine = create_test_query_engine();
let sql =
"SELECT number, avg(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour'), number";
let plan = sql_to_substrait(engine.clone(), sql).await;
let mut ctx = create_test_ctx();
let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
.await
.unwrap();
// The two aggregates that `avg(number)` is expected to expand into.
let aggr_exprs = vec![
AggregateExpr {
func: AggregateFunc::SumUInt64,
expr: ScalarExpr::Column(0),
distinct: false,
},
AggregateExpr {
func: AggregateFunc::Count,
expr: ScalarExpr::Column(1),
distinct: false,
},
];
// Columns 3 and 4 of the reduce output are the sum and the count (see the reduce output types below).
let avg_expr = ScalarExpr::If {
cond: Box::new(ScalarExpr::Column(4).call_binary(
ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
BinaryFunc::NotEq,
)),
then: Box::new(
ScalarExpr::Column(3)
.cast(CDT::float64_datatype())
.call_binary(
ScalarExpr::Column(4).cast(CDT::float64_datatype()),
BinaryFunc::DivFloat64,
),
),
els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
};
let expected = TypedPlan {
plan: Plan::Mfp {
input: Box::new(
Plan::Reduce {
input: Box::new(
Plan::Get {
id: crate::expr::Id::Global(GlobalId::User(1)),
}
.with_types(
RelationType::new(vec![
ColumnType::new(ConcreteDataType::uint32_datatype(), false),
ColumnType::new(
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
])
.into_named(vec![
Some("number".to_string()),
Some("ts".to_string()),
]),
)
.mfp(MapFilterProject::new(2).into_safe())
.unwrap(),
),
key_val_plan: KeyValPlan {
// Keys: window floor, window ceiling, and the `number` column.
key_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowFloor {
window_size: Duration::from_nanos(3_600_000_000_000),
start_time: None,
},
),
ScalarExpr::Column(1).call_unary(
UnaryFunc::TumbleWindowCeiling {
window_size: Duration::from_nanos(3_600_000_000_000),
start_time: None,
},
),
ScalarExpr::Column(0),
])
.unwrap()
.project(vec![2, 3, 4])
.unwrap()
.into_safe(),
// Values: `number` cast to uint64 (input of the sum) and `number` as-is (input of the count).
val_plan: MapFilterProject::new(2)
.map(vec![
ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
ScalarExpr::Column(0),
])
.unwrap()
.project(vec![2, 3])
.unwrap()
.into_safe(),
},
reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
full_aggrs: aggr_exprs.clone(),
simple_aggrs: vec![
AggrWithIndex::new(aggr_exprs[0].clone(), 0, 0),
AggrWithIndex::new(aggr_exprs[1].clone(), 1, 1),
],
distinct_aggrs: vec![],
}),
}
.with_types(
RelationType::new(vec![
// keys
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start(time index)
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end(pk)
ColumnType::new(CDT::uint32_datatype(), false), // number(pk)
// values
ColumnType::new(CDT::uint64_datatype(), true), // avg.sum(number)
ColumnType::new(CDT::int64_datatype(), true), // avg.count(number)
])
.with_key(vec![1, 2])
.with_time_index(Some(0))
.into_named(vec![
None,
None,
Some("number".to_string()),
None,
None,
]),
),
),
mfp: MapFilterProject::new(5)
.map(vec![
ScalarExpr::Column(2), // number(pk)
avg_expr,
ScalarExpr::Column(0), // window start
ScalarExpr::Column(1), // window end
])
.unwrap()
.project(vec![5, 6, 7, 8])
.unwrap(),
},
schema: RelationType::new(vec![
ColumnType::new(CDT::uint32_datatype(), false), // number
ColumnType::new(CDT::float64_datatype(), true), // avg(number)
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
])
.with_key(vec![0, 3])
.with_time_index(Some(2))
.into_named(vec![
Some("number".to_string()),
Some("avg(numbers_with_ts.number)".to_string()),
Some("window_start".to_string()),
Some("window_end".to_string()),
]),
};
assert_eq!(flow_plan, expected);
}
/// `tumble(ts, '1 hour')` with the optional start-time argument omitted: the expected
/// plan reduces over the window floor/ceiling of `ts`, both built with `start_time: None`.
#[tokio::test]
async fn test_tumble_parse_optional() {
    let engine = create_test_query_engine();
    let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour')";
    let plan = sql_to_substrait(engine.clone(), sql).await;

    let mut ctx = create_test_ctx();
    let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
        .await
        .unwrap();

    // One hour, in nanoseconds.
    let window_size = Duration::from_nanos(3_600_000_000_000);
    let window_start = ScalarExpr::Column(1).call_unary(UnaryFunc::TumbleWindowFloor {
        window_size,
        start_time: None,
    });
    let window_end = ScalarExpr::Column(1).call_unary(UnaryFunc::TumbleWindowCeiling {
        window_size,
        start_time: None,
    });

    let sum_number = AggregateExpr {
        func: AggregateFunc::SumUInt64,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };

    // Source relation `(number, ts)` wrapped in an mfp without any map/filter/project.
    let source = Plan::Get {
        id: crate::expr::Id::Global(GlobalId::User(1)),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
            ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
        ])
        .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
    )
    .mfp(MapFilterProject::new(2).into_safe())
    .unwrap();

    // Keys are the two window bounds, the single value is `number` cast to u64.
    let key_val_plan = KeyValPlan {
        key_plan: MapFilterProject::new(2)
            .map(vec![window_start, window_end])
            .unwrap()
            .project(vec![2, 3])
            .unwrap()
            .into_safe(),
        val_plan: MapFilterProject::new(2)
            .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
            .unwrap()
            .project(vec![2])
            .unwrap()
            .into_safe(),
    };

    let reduce = Plan::Reduce {
        input: Box::new(source),
        key_val_plan,
        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
            full_aggrs: vec![sum_number.clone()],
            simple_aggrs: vec![AggrWithIndex::new(sum_number, 0, 0)],
            distinct_aggrs: vec![],
        }),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ColumnType::new(CDT::uint64_datatype(), true),                // sum(number)
        ])
        .with_key(vec![1])
        .with_time_index(Some(0))
        .into_named(vec![None, None, None]),
    );

    // The outer mfp reorders the reduce output into (sum, window start, window end).
    let reorder = MapFilterProject::new(3)
        .map(vec![
            ScalarExpr::Column(2),
            ScalarExpr::Column(0),
            ScalarExpr::Column(1),
        ])
        .unwrap()
        .project(vec![3, 4, 5])
        .unwrap();

    let expected = TypedPlan {
        schema: RelationType::new(vec![
            ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
        ])
        .with_key(vec![2])
        .with_time_index(Some(1))
        .into_named(vec![
            Some("sum(numbers_with_ts.number)".to_string()),
            Some("window_start".to_string()),
            Some("window_end".to_string()),
        ]),
        plan: Plan::Mfp {
            input: Box::new(reduce),
            mfp: reorder,
        },
    };
    assert_eq!(flow_plan, expected);
}
/// `tumble(ts, '1 hour', '2021-07-01 00:00:00')`: same shape as the optional variant,
/// but both window functions carry an explicit start time.
#[tokio::test]
async fn test_tumble_parse() {
    let engine = create_test_query_engine();
    let sql = "SELECT sum(number) FROM numbers_with_ts GROUP BY tumble(ts, '1 hour', '2021-07-01 00:00:00')";
    let plan = sql_to_substrait(engine.clone(), sql).await;

    let mut ctx = create_test_ctx();
    let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
        .await
        .unwrap();

    // One hour, in nanoseconds.
    let window_size = Duration::from_nanos(3_600_000_000_000);
    // 1625097600000 ms since the epoch is 2021-07-01 00:00:00 UTC.
    let start_ms = 1625097600000;
    let window_start = ScalarExpr::Column(1).call_unary(UnaryFunc::TumbleWindowFloor {
        window_size,
        start_time: Some(Timestamp::new_millisecond(start_ms)),
    });
    let window_end = ScalarExpr::Column(1).call_unary(UnaryFunc::TumbleWindowCeiling {
        window_size,
        start_time: Some(Timestamp::new_millisecond(start_ms)),
    });

    let sum_number = AggregateExpr {
        func: AggregateFunc::SumUInt64,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };

    // Source relation `(number, ts)` wrapped in an mfp without any map/filter/project.
    let source = Plan::Get {
        id: crate::expr::Id::Global(GlobalId::User(1)),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(ConcreteDataType::uint32_datatype(), false),
            ColumnType::new(ConcreteDataType::timestamp_millisecond_datatype(), false),
        ])
        .into_named(vec![Some("number".to_string()), Some("ts".to_string())]),
    )
    .mfp(MapFilterProject::new(2).into_safe())
    .unwrap();

    // Keys are the two window bounds, the single value is `number` cast to u64.
    let key_val_plan = KeyValPlan {
        key_plan: MapFilterProject::new(2)
            .map(vec![window_start, window_end])
            .unwrap()
            .project(vec![2, 3])
            .unwrap()
            .into_safe(),
        val_plan: MapFilterProject::new(2)
            .map(vec![ScalarExpr::Column(0).cast(CDT::uint64_datatype())])
            .unwrap()
            .project(vec![2])
            .unwrap()
            .into_safe(),
    };

    let reduce = Plan::Reduce {
        input: Box::new(source),
        key_val_plan,
        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
            full_aggrs: vec![sum_number.clone()],
            simple_aggrs: vec![AggrWithIndex::new(sum_number, 0, 0)],
            distinct_aggrs: vec![],
        }),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
            ColumnType::new(CDT::uint64_datatype(), true),                // sum(number)
        ])
        .with_key(vec![1])
        .with_time_index(Some(0))
        .into_unnamed(),
    );

    // The outer mfp reorders the reduce output into (sum, window start, window end).
    let reorder = MapFilterProject::new(3)
        .map(vec![
            ScalarExpr::Column(2),
            ScalarExpr::Column(0),
            ScalarExpr::Column(1),
        ])
        .unwrap()
        .project(vec![3, 4, 5])
        .unwrap();

    let expected = TypedPlan {
        schema: RelationType::new(vec![
            ColumnType::new(CDT::uint64_datatype(), true), // sum(number)
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window start
            ColumnType::new(CDT::timestamp_millisecond_datatype(), true), // window end
        ])
        .with_key(vec![2])
        .with_time_index(Some(1))
        .into_named(vec![
            Some("sum(numbers_with_ts.number)".to_string()),
            Some("window_start".to_string()),
            Some("window_end".to_string()),
        ]),
        plan: Plan::Mfp {
            input: Box::new(reduce),
            mfp: reorder,
        },
    };
    assert_eq!(flow_plan, expected);
}
/// `avg(number) ... GROUP BY number`: the expected plan reduces to a sum and a count per
/// key and computes the average in the outer mfp.
#[tokio::test]
async fn test_avg_group_by() {
    let engine = create_test_query_engine();
    let sql = "SELECT avg(number), number FROM numbers GROUP BY number";
    let plan = sql_to_substrait(engine.clone(), sql).await;

    let mut ctx = create_test_ctx();
    let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan).await;

    let sum_aggr = AggregateExpr {
        func: AggregateFunc::SumUInt64,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };
    let count_aggr = AggregateExpr {
        func: AggregateFunc::Count,
        expr: ScalarExpr::Column(1),
        distinct: false,
    };

    // Reduce output is (number, sum, count): avg = sum / count, or NULL when count == 0.
    let count_not_zero = ScalarExpr::Column(2).call_binary(
        ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
        BinaryFunc::NotEq,
    );
    let sum_div_count = ScalarExpr::Column(1)
        .cast(CDT::float64_datatype())
        .call_binary(
            ScalarExpr::Column(2).cast(CDT::float64_datatype()),
            BinaryFunc::DivFloat64,
        );
    let avg_expr = ScalarExpr::If {
        cond: Box::new(count_not_zero),
        then: Box::new(sum_div_count),
        els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
    };

    let source = Plan::Get {
        id: crate::expr::Id::Global(GlobalId::User(0)),
    }
    .with_types(
        RelationType::new(vec![ColumnType::new(
            ConcreteDataType::uint32_datatype(),
            false,
        )])
        .into_named(vec![Some("number".to_string())]),
    )
    .mfp(
        MapFilterProject::new(1)
            .project(vec![0])
            .unwrap()
            .into_safe(),
    )
    .unwrap();

    // Key is `number`; values are `number` cast to u64 (for the sum) and `number` itself
    // (for the count).
    let key_val_plan = KeyValPlan {
        key_plan: MapFilterProject::new(1)
            .map(vec![ScalarExpr::Column(0)])
            .unwrap()
            .project(vec![1])
            .unwrap()
            .into_safe(),
        val_plan: MapFilterProject::new(1)
            .map(vec![
                ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                ScalarExpr::Column(0),
            ])
            .unwrap()
            .project(vec![1, 2])
            .unwrap()
            .into_safe(),
    };

    let reduce = Plan::Reduce {
        input: Box::new(source),
        key_val_plan,
        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
            full_aggrs: vec![sum_aggr.clone(), count_aggr.clone()],
            simple_aggrs: vec![
                AggrWithIndex::new(sum_aggr, 0, 0),
                AggrWithIndex::new(count_aggr, 1, 1),
            ],
            distinct_aggrs: vec![],
        }),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(ConcreteDataType::uint32_datatype(), false), // key: number
            ColumnType::new(ConcreteDataType::uint64_datatype(), true),  // sum
            ColumnType::new(ConcreteDataType::int64_datatype(), true),   // count
        ])
        .with_key(vec![0])
        .into_named(vec![Some("number".to_string()), None, None]),
    );

    let expected = TypedPlan {
        schema: RelationType::new(vec![
            ColumnType::new(CDT::float64_datatype(), true), // avg(number: u32) -> f64
            ColumnType::new(CDT::uint32_datatype(), false), // number
        ])
        .with_key(vec![1])
        .into_named(vec![
            Some("avg(numbers.number)".to_string()),
            Some("number".to_string()),
        ]),
        plan: Plan::Mfp {
            input: Box::new(reduce),
            mfp: MapFilterProject::new(3)
                .map(vec![
                    avg_expr, // col 3
                    ScalarExpr::Column(0),
                    // TODO(discord9): optimize mfp so to remove indirect ref
                ])
                .unwrap()
                .project(vec![3, 4])
                .unwrap(),
        },
    };
    assert_eq!(flow_plan.unwrap(), expected);
}
/// `avg(number)` without grouping: the expected plan reduces to a single (sum, count)
/// pair with an empty key and computes the average in the outer mfp.
#[tokio::test]
async fn test_avg() {
    let engine = create_test_query_engine();
    let sql = "SELECT avg(number) FROM numbers";
    let plan = sql_to_substrait(engine.clone(), sql).await;

    let mut ctx = create_test_ctx();
    let flow_plan = TypedPlan::from_substrait_plan(&mut ctx, &plan)
        .await
        .unwrap();

    let sum_aggr = AggregateExpr {
        func: AggregateFunc::SumUInt64,
        expr: ScalarExpr::Column(0),
        distinct: false,
    };
    let count_aggr = AggregateExpr {
        func: AggregateFunc::Count,
        expr: ScalarExpr::Column(1),
        distinct: false,
    };

    // Reduce output is (sum, count): avg = sum / count, or NULL when count == 0.
    let count_not_zero = ScalarExpr::Column(1).call_binary(
        ScalarExpr::Literal(Value::from(0i64), CDT::int64_datatype()),
        BinaryFunc::NotEq,
    );
    let sum_div_count = ScalarExpr::Column(0)
        .cast(CDT::float64_datatype())
        .call_binary(
            ScalarExpr::Column(1).cast(CDT::float64_datatype()),
            BinaryFunc::DivFloat64,
        );
    let avg_expr = ScalarExpr::If {
        cond: Box::new(count_not_zero),
        then: Box::new(sum_div_count),
        els: Box::new(ScalarExpr::Literal(Value::Null, CDT::float64_datatype())),
    };

    let source = Box::new(
        Plan::Get {
            id: crate::expr::Id::Global(GlobalId::User(0)),
        }
        .with_types(
            RelationType::new(vec![ColumnType::new(
                ConcreteDataType::uint32_datatype(),
                false,
            )])
            .into_named(vec![Some("number".to_string())]),
        ),
    );

    // Here the projection over the source is an explicit `Plan::Mfp` node.
    let projected = Plan::Mfp {
        input: source,
        mfp: MapFilterProject::new(1).project(vec![0]).unwrap(),
    }
    .with_types(
        RelationType::new(vec![ColumnType::new(CDT::uint32_datatype(), false)])
            .into_named(vec![Some("number".to_string())]),
    );

    // Empty key; values are `number` cast to u64 (for the sum) and `number` itself
    // (for the count).
    let key_val_plan = KeyValPlan {
        key_plan: MapFilterProject::new(1)
            .project(vec![])
            .unwrap()
            .into_safe(),
        val_plan: MapFilterProject::new(1)
            .map(vec![
                ScalarExpr::Column(0).cast(CDT::uint64_datatype()),
                ScalarExpr::Column(0),
            ])
            .unwrap()
            .project(vec![1, 2])
            .unwrap()
            .into_safe(),
    };

    let reduce = Plan::Reduce {
        input: Box::new(projected),
        key_val_plan,
        reduce_plan: ReducePlan::Accumulable(AccumulablePlan {
            full_aggrs: vec![sum_aggr.clone(), count_aggr.clone()],
            simple_aggrs: vec![
                AggrWithIndex::new(sum_aggr, 0, 0),
                AggrWithIndex::new(count_aggr, 1, 1),
            ],
            distinct_aggrs: vec![],
        }),
    }
    .with_types(
        RelationType::new(vec![
            ColumnType::new(ConcreteDataType::uint64_datatype(), true), // sum
            ColumnType::new(ConcreteDataType::int64_datatype(), true),  // count
        ])
        .into_named(vec![None, None]),
    );

    let expected = TypedPlan {
        schema: RelationType::new(vec![ColumnType::new(CDT::float64_datatype(), true)])
            .into_named(vec![Some("avg(numbers.number)".to_string())]),
        plan: Plan::Mfp {
            input: Box::new(reduce),
            mfp: MapFilterProject::new(2)
                .map(vec![
                    avg_expr,
                    // TODO(discord9): optimize mfp so to remove indirect ref
                ])
                .unwrap()
                .project(vec![2])
                .unwrap(),
        },
    };
    assert_eq!(flow_plan, expected);
}
#[tokio::test]
async fn test_sum() {
let engine = create_test_query_engine();

View File

@@ -13,6 +13,7 @@
// limitations under the License.
pub mod builder;
mod dashboard;
mod grpc;
mod influxdb;
mod jaeger;

View File

@@ -0,0 +1,405 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use api::v1::value::ValueData;
use api::v1::{
ColumnDataType, ColumnDef, ColumnSchema as PbColumnSchema, Row, RowInsertRequest,
RowInsertRequests, Rows, SemanticType,
};
use async_trait::async_trait;
use common_catalog::consts::{DEFAULT_PRIVATE_SCHEMA_NAME, default_engine};
use common_error::ext::BoxedError;
use common_query::OutputData;
use common_recordbatch::util as record_util;
use common_telemetry::info;
use common_time::FOREVER;
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::col;
use datafusion::sql::TableReference;
use datafusion_expr::{DmlStatement, LogicalPlan, lit};
use datatypes::arrow::array::{Array, AsArray};
use servers::error::{
CatalogSnafu, CollectRecordbatchSnafu, DataFusionSnafu, ExecuteQuerySnafu, NotSupportedSnafu,
TableNotFoundSnafu,
};
use servers::query_handler::DashboardDefinition;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use table::TableRef;
use table::metadata::TableInfo;
use table::requests::TTL_KEY;
use table::table::adapter::DfTableProviderAdapter;
use crate::instance::Instance;
/// Name of the table that stores dashboards.
pub const DASHBOARD_TABLE_NAME: &str = "dashboard";
/// Name of the dashboard table column holding the dashboard name.
pub const DASHBOARD_TABLE_NAME_COLUMN_NAME: &str = "name";
/// Name of the dashboard table column holding the dashboard definition.
pub const DASHBOARD_TABLE_DEFINITION_COLUMN_NAME: &str = "definition";
/// Name of the dashboard table timestamp column.
pub const DASHBOARD_TABLE_CREATED_AT_COLUMN_NAME: &str = "created_at";
impl Instance {
/// Build a schema for dashboard table.
/// Returns the (time index, primary keys, column) definitions.
fn build_dashboard_schema() -> (String, Vec<String>, Vec<ColumnDef>) {
(
DASHBOARD_TABLE_CREATED_AT_COLUMN_NAME.to_string(),
vec![DASHBOARD_TABLE_NAME_COLUMN_NAME.to_string()],
vec![
ColumnDef {
name: DASHBOARD_TABLE_NAME_COLUMN_NAME.to_string(),
data_type: ColumnDataType::String as i32,
is_nullable: false,
default_constraint: vec![],
semantic_type: SemanticType::Tag as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
ColumnDef {
name: DASHBOARD_TABLE_DEFINITION_COLUMN_NAME.to_string(),
data_type: ColumnDataType::String as i32,
is_nullable: false,
default_constraint: vec![],
semantic_type: SemanticType::Field as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
ColumnDef {
name: DASHBOARD_TABLE_CREATED_AT_COLUMN_NAME.to_string(),
data_type: ColumnDataType::TimestampNanosecond as i32,
is_nullable: false,
default_constraint: vec![],
semantic_type: SemanticType::Timestamp as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
],
)
}
/// Build a column schemas for inserting a row into the dashboard table.
fn build_dashboard_insert_column_schemas() -> Vec<PbColumnSchema> {
vec![
PbColumnSchema {
column_name: DASHBOARD_TABLE_NAME_COLUMN_NAME.to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
..Default::default()
},
PbColumnSchema {
column_name: DASHBOARD_TABLE_DEFINITION_COLUMN_NAME.to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Field.into(),
..Default::default()
},
PbColumnSchema {
column_name: DASHBOARD_TABLE_CREATED_AT_COLUMN_NAME.to_string(),
datatype: ColumnDataType::TimestampNanosecond.into(),
semantic_type: SemanticType::Timestamp.into(),
..Default::default()
},
]
}
fn dashboard_query_ctx(table_info: &TableInfo) -> QueryContextRef {
QueryContextBuilder::default()
.current_catalog(table_info.catalog_name.clone())
.current_schema(table_info.schema_name.clone())
.build()
.into()
}
async fn create_dashboard_table_if_not_exists(
&self,
ctx: QueryContextRef,
) -> servers::error::Result<TableRef> {
let catalog = ctx.current_catalog();
if let Some(table) = self
.catalog_manager
.table(
catalog,
DEFAULT_PRIVATE_SCHEMA_NAME,
DASHBOARD_TABLE_NAME,
Some(&ctx),
)
.await
.context(CatalogSnafu)?
{
return Ok(table);
}
let (time_index, primary_keys, column_defs) = Self::build_dashboard_schema();
let mut table_options = HashMap::new();
table_options.insert(TTL_KEY.to_string(), FOREVER.to_string());
let mut create_table_expr = api::v1::CreateTableExpr {
catalog_name: catalog.to_string(),
schema_name: DEFAULT_PRIVATE_SCHEMA_NAME.to_string(),
table_name: DASHBOARD_TABLE_NAME.to_string(),
desc: "GreptimeDB dashboard table".to_string(),
column_defs,
time_index,
primary_keys,
create_if_not_exists: true,
table_options,
table_id: None,
engine: default_engine().to_string(),
};
self.statement_executor
.create_table_inner(&mut create_table_expr, None, ctx.clone())
.await
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
let table = self
.catalog_manager
.table(
catalog,
DEFAULT_PRIVATE_SCHEMA_NAME,
DASHBOARD_TABLE_NAME,
Some(&ctx),
)
.await
.context(CatalogSnafu)?
.context(TableNotFoundSnafu {
catalog: catalog.to_string(),
schema: DEFAULT_PRIVATE_SCHEMA_NAME.to_string(),
table: DASHBOARD_TABLE_NAME.to_string(),
})?;
Ok(table)
}
/// Insert a dashboard into the dashboard table.
async fn insert_dashboard(
&self,
name: &str,
definition: &str,
query_ctx: QueryContextRef,
) -> servers::error::Result<()> {
let table = self
.create_dashboard_table_if_not_exists(query_ctx.clone())
.await?;
let table_info = table.table_info();
let insert = RowInsertRequest {
table_name: DASHBOARD_TABLE_NAME.to_string(),
rows: Some(Rows {
schema: Self::build_dashboard_insert_column_schemas(),
rows: vec![Row {
values: vec![
ValueData::StringValue(name.to_string()).into(),
ValueData::StringValue(definition.to_string()).into(),
ValueData::TimestampNanosecondValue(0).into(),
],
}],
}),
};
let requests = RowInsertRequests {
inserts: vec![insert],
};
let output = self
.inserter
.handle_row_inserts(
requests,
Self::dashboard_query_ctx(&table_info),
&self.statement_executor,
false,
false,
)
.await
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
info!(
"Insert dashboard success, name: {}, table: {}, output: {:?}",
name,
table_info.full_table_name(),
output
);
Ok(())
}
/// List all dashboards.
async fn list_dashboards(
&self,
query_ctx: QueryContextRef,
) -> servers::error::Result<Vec<DashboardDefinition>> {
let table = if let Some(table) = self
.catalog_manager
.table(
query_ctx.current_catalog(),
DEFAULT_PRIVATE_SCHEMA_NAME,
DASHBOARD_TABLE_NAME,
Some(&query_ctx),
)
.await
.context(CatalogSnafu)?
{
table
} else {
return Ok(vec![]);
};
let table_info = table.table_info();
let dataframe = self
.query_engine
.read_table(table.clone())
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
let dataframe = dataframe
.select_columns(&[
DASHBOARD_TABLE_NAME_COLUMN_NAME,
DASHBOARD_TABLE_DEFINITION_COLUMN_NAME,
])
.context(DataFusionSnafu)?;
let plan = dataframe.into_parts().1;
let output = self
.query_engine
.execute(plan, Self::dashboard_query_ctx(&table_info))
.await
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
let stream = match output.data {
OutputData::Stream(stream) => stream,
OutputData::RecordBatches(record_batches) => record_batches.as_stream(),
_ => unreachable!(),
};
let records = record_util::collect(stream)
.await
.context(CollectRecordbatchSnafu)?;
let mut dashboards = Vec::new();
for r in &records {
let name_column = r.column(0);
let definition_column = r.column(1);
let name = name_column
.as_string_opt::<i32>()
.context(NotSupportedSnafu {
feat: "Invalid data type for greptime_private.dashboard.name",
})?;
let definition =
definition_column
.as_string_opt::<i32>()
.context(NotSupportedSnafu {
feat: "Invalid data type for greptime_private.dashboard.definition",
})?;
for i in 0..name.len() {
dashboards.push(DashboardDefinition {
name: name.value(i).to_string(),
definition: definition.value(i).to_string(),
});
}
}
Ok(dashboards)
}
/// Delete a dashboard by name.
async fn delete_dashboard(
&self,
name: &str,
query_ctx: QueryContextRef,
) -> servers::error::Result<()> {
let table = self
.create_dashboard_table_if_not_exists(query_ctx.clone())
.await?;
let table_info = table.table_info();
let dataframe = self
.query_engine
.read_table(table.clone())
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
let name_condition = col(DASHBOARD_TABLE_NAME_COLUMN_NAME).eq(lit(name));
let dataframe = dataframe.filter(name_condition).context(DataFusionSnafu)?;
let table_name = TableReference::full(
table_info.catalog_name.clone(),
table_info.schema_name.clone(),
table_info.name.clone(),
);
let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone()));
let table_source = Arc::new(DefaultTableSource::new(table_provider));
let stmt = DmlStatement::new(
table_name,
table_source,
datafusion_expr::WriteOp::Delete,
Arc::new(dataframe.into_parts().1),
);
let plan = LogicalPlan::Dml(stmt);
let output = self
.query_engine
.execute(plan, Self::dashboard_query_ctx(&table_info))
.await
.map_err(BoxedError::new)
.context(ExecuteQuerySnafu)?;
info!(
"Delete dashboard success, name: {}, table: {}, output: {:?}",
name,
table_info.full_table_name(),
output
);
Ok(())
}
}
/// `DashboardHandler` implementation that delegates each operation to the corresponding
/// inherent dashboard method of `Instance`.
#[async_trait]
impl servers::query_handler::DashboardHandler for Instance {
/// Saves a dashboard by delegating to `insert_dashboard`.
async fn save(
&self,
name: &str,
definition: &str,
ctx: QueryContextRef,
) -> servers::error::Result<()> {
self.insert_dashboard(name, definition, ctx).await
}
/// Lists dashboards by delegating to `list_dashboards`.
async fn list(&self, ctx: QueryContextRef) -> servers::error::Result<Vec<DashboardDefinition>> {
self.list_dashboards(ctx).await
}
/// Deletes a dashboard by delegating to `delete_dashboard`.
async fn delete(&self, name: &str, ctx: QueryContextRef) -> servers::error::Result<()> {
self.delete_dashboard(name, ctx).await
}
}

View File

@@ -27,7 +27,6 @@ use api::v1::{
use async_stream::try_stream;
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_base::AffectedRows;
use common_error::ext::BoxedError;
use common_grpc::flight::do_put::DoPutResponse;
use common_query::Output;
@@ -260,62 +259,6 @@ impl GrpcQueryHandler for Instance {
.context(server_error::ExecuteGrpcQuerySnafu)
}
/// Handles a single bulk-insert request.
///
/// The target table is taken from `table_ref` when it is already set; otherwise it is
/// resolved through the catalog manager and stored back into `table_ref` so later calls
/// can reuse it. Any failure of the inner steps is boxed and reported as an
/// `ExecuteGrpcRequest` server error.
async fn put_record_batch(
&self,
request: servers::grpc::flight::PutRecordBatchRequest,
table_ref: &mut Option<TableRef>,
ctx: QueryContextRef,
) -> server_error::Result<AffectedRows> {
// The inner async block uses the crate-local `Result`, so `?` can be used on every step
// before the error is converted once at the end.
let result: Result<AffectedRows> = async {
let table = if let Some(table) = table_ref {
table.clone()
} else {
let table = self
.catalog_manager()
.table(
&request.table_name.catalog_name,
&request.table_name.schema_name,
&request.table_name.table_name,
None,
)
.await
.context(CatalogSnafu)?
.with_context(|| TableNotFoundSnafu {
table_name: request.table_name.to_string(),
})?;
// Remember the resolved table for the caller.
*table_ref = Some(table.clone());
table
};
// Run the plugin interceptor hook before the permission check and the insert.
let interceptor_ref = self.plugins.get::<GrpcQueryInterceptorRef<Error>>();
let interceptor = interceptor_ref.as_ref();
interceptor.pre_bulk_insert(table.clone(), ctx.clone())?;
self.plugins
.get::<PermissionCheckerRef>()
.as_ref()
.check_permission(ctx.current_user(), PermissionReq::BulkInsert)
.context(PermissionSnafu)?;
// TODO: decide whether a limit should be checked for bulk inserts.
self.inserter
.handle_bulk_insert(
table,
request.flight_data,
request.record_batch,
request.schema_bytes,
)
.await
.context(TableOperationSnafu)
}
.await;
result
.map_err(BoxedError::new)
.context(server_error::ExecuteGrpcRequestSnafu)
}
fn handle_put_record_batch_stream(
&self,
stream: servers::grpc::flight::PutRecordBatchRequestStream,

View File

@@ -143,6 +143,8 @@ where
builder = builder.with_jaeger_handler(self.instance.clone());
}
builder = builder.with_dashboard_handler(self.instance.clone());
if let Some(configurator) = self.plugins.get::<RouterConfigurator>() {
info!("Adding extra router from plugins");
builder = builder.with_extra_router(configurator.router());

View File

@@ -24,6 +24,8 @@ use common_base::Plugins;
use common_config::Configurable;
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
use common_meta::distributed_time_constants::META_LEASE_SECS;
use common_meta::election::CANDIDATE_LEASE_SECS;
use common_meta::election::etcd::EtcdElection;
use common_meta::kv_backend::chroot::ChrootKvBackend;
use common_meta::kv_backend::etcd::EtcdStore;
use common_meta::kv_backend::memory::MemoryKvBackend;
@@ -42,9 +44,6 @@ use tonic::codec::CompressionEncoding;
use tonic::transport::server::{Router, TcpIncoming};
use crate::cluster::{MetaPeerClientBuilder, MetaPeerClientRef};
#[cfg(any(feature = "pg_kvbackend", feature = "mysql_kvbackend"))]
use crate::election::CANDIDATE_LEASE_SECS;
use crate::election::etcd::EtcdElection;
use crate::error::OtherSnafu;
use crate::metasrv::builder::MetasrvBuilder;
use crate::metasrv::{
@@ -281,7 +280,8 @@ pub async fn metasrv_builder(
etcd_client,
opts.store_key_prefix.clone(),
)
.await?;
.await
.context(error::KvBackendSnafu)?;
(kv_backend, Some(election))
}
@@ -290,10 +290,10 @@ pub async fn metasrv_builder(
use std::time::Duration;
use common_meta::distributed_time_constants::POSTGRES_KEEP_ALIVE_SECS;
use common_meta::election::rds::postgres::{ElectionPgClient, PgElection};
use common_meta::kv_backend::rds::PgStore;
use deadpool_postgres::{Config, ManagerConfig, RecyclingMethod};
use crate::election::rds::postgres::{ElectionPgClient, PgElection};
use crate::utils::postgres::create_postgres_pool;
let candidate_lease_ttl = Duration::from_secs(CANDIDATE_LEASE_SECS);
@@ -321,7 +321,8 @@ pub async fn metasrv_builder(
execution_timeout,
idle_session_timeout,
statement_timeout,
)?;
)
.context(error::KvBackendSnafu)?;
let election = PgElection::with_pg_client(
opts.grpc.server_addr.clone(),
election_client,
@@ -332,7 +333,8 @@ pub async fn metasrv_builder(
&opts.meta_table_name,
opts.meta_election_lock_id,
)
.await?;
.await
.context(error::KvBackendSnafu)?;
let pool = create_postgres_pool(&opts.store_addrs, Some(cfg), opts.backend_tls.clone())
.await?;
@@ -352,9 +354,9 @@ pub async fn metasrv_builder(
(None, BackendImpl::MysqlStore) => {
use std::time::Duration;
use common_meta::election::rds::mysql::{ElectionMysqlClient, MySqlElection};
use common_meta::kv_backend::rds::MySqlStore;
use crate::election::rds::mysql::{ElectionMysqlClient, MySqlElection};
use crate::utils::mysql::create_mysql_pool;
let pool = create_mysql_pool(&opts.store_addrs, opts.backend_tls.as_ref()).await?;
@@ -389,7 +391,8 @@ pub async fn metasrv_builder(
meta_lease_ttl,
&election_table_name,
)
.await?;
.await
.context(error::KvBackendSnafu)?;
(kv_backend, Some(election))
}
};

View File

@@ -247,7 +247,7 @@ impl MetaPeerClient {
// Safety: when self.is_leader() == false, election must not empty.
let election = self.election.as_ref().unwrap();
let leader_addr = election.leader().await?.0;
let leader_addr = election.leader().await.context(error::KvBackendSnafu)?.0;
let channel = self
.channel_manager
@@ -279,7 +279,7 @@ impl MetaPeerClient {
// Safety: when self.is_leader() == false, election must not empty.
let election = self.election.as_ref().unwrap();
let leader_addr = election.leader().await?.0;
let leader_addr = election.leader().await.context(error::KvBackendSnafu)?.0;
let channel = self
.channel_manager

View File

@@ -21,7 +21,6 @@ pub mod bootstrap;
pub mod cache_invalidator;
pub mod cluster;
pub mod discovery;
pub mod election;
pub mod error;
pub mod events;
mod failure_detector;

View File

@@ -32,6 +32,8 @@ use common_meta::ddl_manager::DdlManagerRef;
use common_meta::distributed_time_constants::{
self, BASE_HEARTBEAT_INTERVAL, default_distributed_time_constants, frontend_heartbeat_interval,
};
use common_meta::election::LeaderChangeMessage;
pub use common_meta::election::{ElectionRef, MetasrvNodeInfo};
use common_meta::key::TableMetadataManagerRef;
use common_meta::key::runtime_switch::RuntimeSwitchManagerRef;
use common_meta::kv_backend::{KvBackendRef, ResettableKvBackend, ResettableKvBackendRef};
@@ -64,7 +66,6 @@ use tokio::sync::broadcast::error::RecvError;
use crate::cluster::MetaPeerClientRef;
use crate::discovery;
use crate::election::{Election, LeaderChangeMessage};
use crate::error::{
self, InitMetadataSnafu, KvBackendSnafu, Result, StartProcedureManagerSnafu,
StartTelemetryTaskSnafu, StopProcedureManagerSnafu,
@@ -459,76 +460,6 @@ impl Context {
}
}
/// The value of the leader. It is used to store the leader's address.
pub struct LeaderValue(pub String);

impl<T: AsRef<[u8]>> From<T> for LeaderValue {
    /// Decodes the given bytes as UTF-8, replacing invalid sequences with U+FFFD.
    fn from(value: T) -> Self {
        let bytes = value.as_ref();
        LeaderValue(String::from_utf8_lossy(bytes).into_owned())
    }
}
/// Information about a metasrv node.
///
/// Fields marked `#[serde(default)]` fall back to their default value when they are
/// missing from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetasrvNodeInfo {
/// The metasrv's address.
pub addr: String,
/// The node build version.
pub version: String,
/// The node build git commit hash.
pub git_commit: String,
/// The node start timestamp in milliseconds.
pub start_time_ms: u64,
/// The node total CPU in millicores.
#[serde(default)]
pub total_cpu_millicores: i64,
/// The node total memory in bytes.
#[serde(default)]
pub total_memory_bytes: i64,
/// The node CPU usage in millicores.
#[serde(default)]
pub cpu_usage_millicores: i64,
/// The node memory usage in bytes.
#[serde(default)]
pub memory_usage_bytes: i64,
/// The node hostname.
#[serde(default)]
pub hostname: String,
}
// TODO(zyy17): Allow deprecated fields for backward compatibility. Remove this when the deprecated top-level fields are removed from the proto.
#[allow(deprecated)]
impl From<MetasrvNodeInfo> for api::v1::meta::MetasrvNodeInfo {
fn from(node_info: MetasrvNodeInfo) -> Self {
Self {
peer: Some(api::v1::meta::Peer {
addr: node_info.addr,
..Default::default()
}),
// TODO(zyy17): The following top-level fields are deprecated. They are kept for backward compatibility and will be removed in a future version.
// New code should use the fields in `info.NodeInfo` instead.
version: node_info.version.clone(),
git_commit: node_info.git_commit.clone(),
start_time_ms: node_info.start_time_ms,
cpus: node_info.total_cpu_millicores as u32,
memory_bytes: node_info.total_memory_bytes as u64,
// The canonical location for node information.
info: Some(api::v1::meta::NodeInfo {
version: node_info.version,
git_commit: node_info.git_commit,
start_time_ms: node_info.start_time_ms,
total_cpu_millicores: node_info.total_cpu_millicores,
total_memory_bytes: node_info.total_memory_bytes,
cpu_usage_millicores: node_info.cpu_usage_millicores,
memory_usage_bytes: node_info.memory_usage_bytes,
cpus: node_info.total_cpu_millicores as u32,
memory_bytes: node_info.total_memory_bytes as u64,
hostname: node_info.hostname,
}),
}
}
}
#[derive(Clone, Copy)]
pub enum SelectTarget {
Datanode,
@@ -552,7 +483,6 @@ pub struct SelectorContext {
pub type SelectorRef = Arc<dyn Selector<Context = SelectorContext, Output = Vec<Peer>>>;
pub type RegionStatAwareSelectorRef =
Arc<dyn RegionStatAwareSelector<Context = SelectorContext, Output = Vec<(RegionId, Peer)>>>;
pub type ElectionRef = Arc<dyn Election<Leader = LeaderValue>>;
pub struct MetaStateHandler {
subscribe_manager: Option<SubscriptionManagerRef>,

View File

@@ -32,7 +32,7 @@ pub struct LeaderHandler {
impl LeaderHandler {
async fn get_leader(&self) -> Result<Option<String>> {
if let Some(election) = &self.election {
let leader_addr = election.leader().await?.0;
let leader_addr = election.leader().await.context(error::KvBackendSnafu)?.0;
return Ok(Some(leader_addr));
}
Ok(None)

View File

@@ -63,7 +63,10 @@ impl cluster_server::Cluster for Metasrv {
let leader_addr = &self.options().grpc.server_addr;
let (leader, followers) = match self.election() {
Some(election) => {
let nodes = election.all_candidates().await?;
let nodes = election
.all_candidates()
.await
.context(error::KvBackendSnafu)?;
let followers = nodes
.into_iter()
.filter(|node_info| &node_info.addr != leader_addr)

View File

@@ -23,7 +23,7 @@ use api::v1::meta::{
use common_telemetry::{debug, error, info, warn};
use futures::StreamExt;
use once_cell::sync::OnceCell;
use snafu::OptionExt;
use snafu::{OptionExt, ResultExt};
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio_stream::wrappers::ReceiverStream;
@@ -148,7 +148,7 @@ async fn handle_ask_leader(_req: AskLeaderRequest, ctx: Context) -> Result<AskLe
if election.is_leader() {
ctx.server_addr
} else {
election.leader().await?.0
election.leader().await.context(error::KvBackendSnafu)?.0
}
}
None => ctx.server_addr,

View File

@@ -17,6 +17,7 @@ bytes.workspace = true
fxhash = "0.2"
common-base.workspace = true
common-error.workspace = true
common-grpc.workspace = true
common-macro.workspace = true
common-query.workspace = true
common-recordbatch.workspace = true

View File

@@ -0,0 +1,426 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::hash::Hasher;
use std::sync::Arc;
use datatypes::arrow::array::{Array, BinaryBuilder, StringArray, UInt64Array};
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::value::ValueRef;
use fxhash::FxHasher;
use mito_codec::row_converter::SparsePrimaryKeyCodec;
use snafu::ResultExt;
use store_api::storage::ColumnId;
use store_api::storage::consts::{PRIMARY_KEY_COLUMN_NAME, ReservedColumnId};
use crate::error::{EncodePrimaryKeySnafu, Result, UnexpectedRequestSnafu};
/// Info about a tag column for TSID computation and sparse primary key encoding.
///
/// The functions in this module receive these as a slice named
/// `sorted_tag_columns` and process the entries in slice order.
#[allow(dead_code)]
pub(crate) struct TagColumnInfo {
/// Column name (used for label-name hash).
pub name: String,
/// Column index in the RecordBatch; the tag array is fetched with `batch.column(index)`.
pub index: usize,
/// Column ID in the physical region; written next to the tag value in the sparse primary key.
pub column_id: ColumnId,
}
/// Computes the `__tsid` value of every row in `batch`.
///
/// A TSID is an `FxHasher` digest seeded with a hash of the label names and
/// then fed each non-null tag value of the row, every name and value being
/// followed by a `0xff` separator byte. Rows without null tags share one
/// precomputed label-name hash; rows with nulls hash only the names of the
/// tags that are present, so a null tag contributes neither its name nor a
/// value.
///
/// `tag_arrays[i]` is expected to hold the values of `sorted_tag_columns[i]`.
/// Only the row count is read from `batch`.
#[allow(dead_code)]
pub(crate) fn compute_tsid_array(
    batch: &RecordBatch,
    sorted_tag_columns: &[TagColumnInfo],
    tag_arrays: &[&StringArray],
) -> UInt64Array {
    // Hash over every label name, reused by all rows that have no null tag.
    let mut all_names = FxHasher::default();
    sorted_tag_columns.iter().for_each(|column| {
        all_names.write(column.name.as_bytes());
        all_names.write_u8(0xff);
    });
    let full_label_hash = all_names.finish();

    let tsids: Vec<u64> = (0..batch.num_rows())
        .map(|row| {
            let label_hash = if tag_arrays.iter().all(|values| !values.is_null(row)) {
                full_label_hash
            } else {
                // Some tag is null: hash only the names of the tags present in this row.
                let mut present_names = FxHasher::default();
                for (column, values) in sorted_tag_columns.iter().zip(tag_arrays) {
                    if values.is_null(row) {
                        continue;
                    }
                    present_names.write(column.name.as_bytes());
                    present_names.write_u8(0xff);
                }
                present_names.finish()
            };

            let mut value_hasher = FxHasher::default();
            value_hasher.write_u64(label_hash);
            for values in tag_arrays.iter().filter(|values| !values.is_null(row)) {
                value_hasher.write(values.value(row).as_bytes());
                value_hasher.write_u8(0xff);
            }
            value_hasher.finish()
        })
        .collect();

    UInt64Array::from(tsids)
}
/// Collects the tag columns of `batch` as `StringArray` references, in the
/// order of `sorted_tag_columns`.
///
/// # Panics
///
/// Panics if a `TagColumnInfo::index` is out of bounds for `batch` or if the
/// referenced column is not a `StringArray`.
fn build_tag_arrays<'a>(
    batch: &'a RecordBatch,
    sorted_tag_columns: &[TagColumnInfo],
) -> Vec<&'a StringArray> {
    let mut arrays = Vec::with_capacity(sorted_tag_columns.len());
    for tag_column in sorted_tag_columns {
        let column = batch.column(tag_column.index);
        let utf8 = column
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("tag column must be utf8");
        arrays.push(utf8);
    }
    arrays
}
/// Modifies a RecordBatch for sparse primary key encoding.
///
/// For every row the primary key is encoded with a schemaless
/// [`SparsePrimaryKeyCodec`]: first the reserved `table_id` and `tsid` columns,
/// then every non-null tag as `(column_id, value)` in the order given by
/// `sorted_tag_columns`. The returned batch holds the binary
/// `PRIMARY_KEY_COLUMN_NAME` column followed by the columns selected by
/// `non_tag_column_indices`; the tag columns themselves are not carried over.
///
/// # Errors
///
/// Fails with `UnexpectedRequestSnafu` when a tag column index is out of
/// bounds or the column is not a Utf8 array, when a non-tag column index is
/// out of bounds, or when the output batch cannot be assembled. Fails with
/// `EncodePrimaryKeySnafu` when the codec cannot encode a value.
#[allow(dead_code)]
pub(crate) fn modify_batch_sparse(
    batch: RecordBatch,
    table_id: u32,
    sorted_tag_columns: &[TagColumnInfo],
    non_tag_column_indices: &[usize],
) -> Result<RecordBatch> {
    let num_rows = batch.num_rows();
    let input_schema = batch.schema();

    // The batch comes from a request, so a bad index or a non-Utf8 tag column
    // must be reported as an error instead of panicking.
    let tag_arrays = sorted_tag_columns
        .iter()
        .map(|tc| {
            batch
                .columns()
                .get(tc.index)
                .and_then(|column| column.as_any().downcast_ref::<StringArray>())
                .ok_or_else(|| {
                    UnexpectedRequestSnafu {
                        reason: format!(
                            "Tag column '{}' at index {} is missing or is not a Utf8 array",
                            tc.name, tc.index
                        ),
                    }
                    .build()
                })
        })
        .collect::<Result<Vec<&StringArray>>>()?;

    // Resolve the pass-through columns before doing any per-row work so that
    // invalid indices fail fast.
    let mut passthrough_fields = Vec::with_capacity(non_tag_column_indices.len());
    let mut passthrough_columns: Vec<Arc<dyn Array>> =
        Vec::with_capacity(non_tag_column_indices.len());
    for &idx in non_tag_column_indices {
        let (Some(field), Some(column)) = (input_schema.fields().get(idx), batch.columns().get(idx))
        else {
            return UnexpectedRequestSnafu {
                reason: format!(
                    "Column index {idx} is out of bounds for a RecordBatch with {} columns",
                    batch.num_columns()
                ),
            }
            .fail();
        };
        passthrough_fields.push(field.clone());
        passthrough_columns.push(column.clone());
    }

    let codec = SparsePrimaryKeyCodec::schemaless();
    let tsid_array = compute_tsid_array(&batch, sorted_tag_columns, &tag_arrays);

    let mut pk_builder = BinaryBuilder::with_capacity(num_rows, 0);
    // Scratch buffer reused for every row's encoded key.
    let mut buffer = Vec::new();
    for row in 0..num_rows {
        buffer.clear();

        // Reserved columns are written first: table id, then tsid.
        let internal = [
            (ReservedColumnId::table_id(), ValueRef::UInt32(table_id)),
            (
                ReservedColumnId::tsid(),
                ValueRef::UInt64(tsid_array.value(row)),
            ),
        ];
        codec
            .encode_to_vec(internal.into_iter(), &mut buffer)
            .context(EncodePrimaryKeySnafu)?;

        // Null tags are omitted from the key.
        let tags = sorted_tag_columns
            .iter()
            .zip(tag_arrays.iter())
            .filter(|(_, arr)| !arr.is_null(row))
            .map(|(tc, arr)| (tc.column_id, ValueRef::String(arr.value(row))));
        codec
            .encode_to_vec(tags, &mut buffer)
            .context(EncodePrimaryKeySnafu)?;

        pk_builder.append_value(&buffer);
    }
    let pk_array = pk_builder.finish();

    let mut fields = Vec::with_capacity(1 + passthrough_fields.len());
    fields.push(Arc::new(Field::new(
        PRIMARY_KEY_COLUMN_NAME,
        DataType::Binary,
        false,
    )));
    fields.extend(passthrough_fields);

    let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(1 + passthrough_columns.len());
    columns.push(Arc::new(pk_array));
    columns.extend(passthrough_columns);

    let new_schema = Arc::new(ArrowSchema::new(fields));
    RecordBatch::try_new(new_schema, columns).map_err(|e| {
        UnexpectedRequestSnafu {
            reason: format!("Failed to build modified sparse RecordBatch: {e}"),
        }
        .build()
    })
}
// Unit tests for TSID computation and sparse primary key encoding.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use api::v1::value::ValueData;
use api::v1::{ColumnDataType, ColumnSchema, Row, Rows, SemanticType, Value};
use datatypes::arrow::array::{BinaryArray, Int64Array, StringArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
use datatypes::arrow::record_batch::RecordBatch;
use store_api::codec::PrimaryKeyEncoding;
use store_api::storage::consts::PRIMARY_KEY_COLUMN_NAME;
use super::*;
use crate::row_modifier::{RowModifier, RowsIter, TableIdInput};
// Builds a one-row batch: timestamp 1000, value 42.0, namespace "greptimedb", host "127.0.0.1".
fn build_sparse_test_batch() -> RecordBatch {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("greptime_timestamp", DataType::Int64, false),
Field::new("greptime_value", DataType::Float64, true),
Field::new("namespace", DataType::Utf8, true),
Field::new("host", DataType::Utf8, true),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(Int64Array::from(vec![1000])),
Arc::new(datatypes::arrow::array::Float64Array::from(vec![42.0])),
Arc::new(StringArray::from(vec!["greptimedb"])),
Arc::new(StringArray::from(vec!["127.0.0.1"])),
],
)
.unwrap()
}
// Tag infos for `build_sparse_test_batch`, listed in name order ("host" before "namespace").
// The column ids match the name-to-id map used in
// `test_modify_batch_sparse_matches_row_modifier`.
fn sparse_tag_columns() -> Vec<TagColumnInfo> {
vec![
TagColumnInfo {
name: "host".to_string(),
index: 3,
column_id: 3,
},
TagColumnInfo {
name: "namespace".to_string(),
index: 2,
column_id: 2,
},
]
}
// Pins the TSID of a fully populated row to a fixed constant so that any change
// to the hashing scheme is detected.
#[test]
fn test_compute_tsid_basic() {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("namespace", DataType::Utf8, true),
Field::new("host", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec!["greptimedb"])),
Arc::new(StringArray::from(vec!["127.0.0.1"])),
],
)
.unwrap();
let tag_columns: Vec<TagColumnInfo> = vec![
TagColumnInfo {
name: "host".to_string(),
index: 1,
column_id: 2,
},
TagColumnInfo {
name: "namespace".to_string(),
index: 0,
column_id: 1,
},
];
let tag_arrays = build_tag_arrays(&batch, &tag_columns);
let tsid_array = compute_tsid_array(&batch, &tag_columns, &tag_arrays);
assert_eq!(tsid_array.value(0), 2721566936019240841);
}
// A null tag must not influence the TSID: a row with tags (a, b) and a row with
// tags (a, b, c = null) are expected to hash to the same value.
#[test]
fn test_compute_tsid_with_nulls() {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new("a", DataType::Utf8, true),
Field::new("b", DataType::Utf8, true),
]));
let batch_no_null = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["A"])),
Arc::new(StringArray::from(vec!["B"])),
],
)
.unwrap();
let tag_cols_2: Vec<TagColumnInfo> = vec![
TagColumnInfo {
name: "a".to_string(),
index: 0,
column_id: 1,
},
TagColumnInfo {
name: "b".to_string(),
index: 1,
column_id: 2,
},
];
let tag_arrays_2 = build_tag_arrays(&batch_no_null, &tag_cols_2);
let tsid_no_null = compute_tsid_array(&batch_no_null, &tag_cols_2, &tag_arrays_2);
// Same values for "a" and "b", plus a third tag column "c" that is null.
let schema3 = Arc::new(ArrowSchema::new(vec![
Field::new("a", DataType::Utf8, true),
Field::new("b", DataType::Utf8, true),
Field::new("c", DataType::Utf8, true),
]));
let batch_with_null = RecordBatch::try_new(
schema3,
vec![
Arc::new(StringArray::from(vec!["A"])),
Arc::new(StringArray::from(vec!["B"])),
Arc::new(StringArray::from(vec![None as Option<&str>])),
],
)
.unwrap();
let tag_cols_3: Vec<TagColumnInfo> = vec![
TagColumnInfo {
name: "a".to_string(),
index: 0,
column_id: 1,
},
TagColumnInfo {
name: "b".to_string(),
index: 1,
column_id: 2,
},
TagColumnInfo {
name: "c".to_string(),
index: 2,
column_id: 3,
},
];
let tag_arrays_3 = build_tag_arrays(&batch_with_null, &tag_cols_3);
let tsid_with_null = compute_tsid_array(&batch_with_null, &tag_cols_3, &tag_arrays_3);
assert_eq!(tsid_no_null.value(0), tsid_with_null.value(0));
}
// Checks the output layout: primary key column first, then the non-tag columns
// in the order given by `non_tag_indices`; tag columns are gone.
#[test]
fn test_modify_batch_sparse() {
let batch = build_sparse_test_batch();
let tag_columns = sparse_tag_columns();
let non_tag_indices = vec![0, 1];
let table_id: u32 = 1025;
let modified =
modify_batch_sparse(batch, table_id, &tag_columns, &non_tag_indices).unwrap();
assert_eq!(modified.num_columns(), 3);
assert_eq!(modified.schema().field(0).name(), PRIMARY_KEY_COLUMN_NAME);
assert_eq!(modified.schema().field(1).name(), "greptime_timestamp");
assert_eq!(modified.schema().field(2).name(), "greptime_value");
}
// Cross-checks the batch path against the row path: the primary key bytes from
// `modify_batch_sparse` must equal those produced by `RowModifier::modify_rows`
// with sparse encoding for the same logical row.
#[test]
fn test_modify_batch_sparse_matches_row_modifier() {
let batch = build_sparse_test_batch();
let tag_columns = sparse_tag_columns();
let non_tag_indices = vec![0, 1];
let table_id: u32 = 1025;
let modified =
modify_batch_sparse(batch, table_id, &tag_columns, &non_tag_indices).unwrap();
let name_to_column_id: HashMap<String, ColumnId> = [
("greptime_timestamp".to_string(), 0),
("greptime_value".to_string(), 1),
("namespace".to_string(), 2),
("host".to_string(), 3),
]
.into_iter()
.collect();
// The same row as `build_sparse_test_batch`, expressed as gRPC `Rows`.
let rows = Rows {
schema: vec![
ColumnSchema {
column_name: "greptime_timestamp".to_string(),
datatype: ColumnDataType::TimestampMillisecond as i32,
semantic_type: SemanticType::Timestamp as i32,
..Default::default()
},
ColumnSchema {
column_name: "greptime_value".to_string(),
datatype: ColumnDataType::Float64 as i32,
semantic_type: SemanticType::Field as i32,
..Default::default()
},
ColumnSchema {
column_name: "namespace".to_string(),
datatype: ColumnDataType::String as i32,
semantic_type: SemanticType::Tag as i32,
..Default::default()
},
ColumnSchema {
column_name: "host".to_string(),
datatype: ColumnDataType::String as i32,
semantic_type: SemanticType::Tag as i32,
..Default::default()
},
],
rows: vec![Row {
values: vec![
Value {
value_data: Some(ValueData::TimestampMillisecondValue(1000)),
},
Value {
value_data: Some(ValueData::F64Value(42.0)),
},
Value {
value_data: Some(ValueData::StringValue("greptimedb".to_string())),
},
Value {
value_data: Some(ValueData::StringValue("127.0.0.1".to_string())),
},
],
}],
};
let row_iter = RowsIter::new(rows, &name_to_column_id);
let rows = RowModifier::default()
.modify_rows(
row_iter,
TableIdInput::Single(table_id),
PrimaryKeyEncoding::Sparse,
)
.unwrap();
// The test reads the encoded primary key from the first value of the modified row.
let ValueData::BinaryValue(expected_pk) =
rows.rows[0].values[0].value_data.clone().unwrap()
else {
panic!("expected binary primary key");
};
let actual_array = modified
.column(0)
.as_any()
.downcast_ref::<BinaryArray>()
.unwrap();
assert_eq!(actual_array.value(0), expected_pk.as_slice());
}
}

View File

@@ -13,6 +13,7 @@
// limitations under the License.
mod alter;
mod bulk_insert;
mod catchup;
mod close;
mod create;
@@ -288,9 +289,8 @@ impl RegionEngine for MetricEngine {
debug_assert_eq!(region_id, resp_region_id);
return response;
}
RegionRequest::BulkInserts(_) => {
// todo(hl): find a way to support bulk inserts in metric engine.
UnsupportedRegionRequestSnafu { request }.fail()
RegionRequest::BulkInserts(bulk) => {
self.inner.bulk_insert_region(region_id, bulk).await
}
};

View File

@@ -0,0 +1,783 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use api::v1::{ArrowIpc, ColumnDataType, SemanticType};
use bytes::Bytes;
use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_query::prelude::{greptime_timestamp, greptime_value};
use datatypes::arrow::array::{Array, Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::record_batch::RecordBatch;
use snafu::{OptionExt, ensure};
use store_api::codec::PrimaryKeyEncoding;
use store_api::metadata::RegionMetadataRef;
use store_api::region_request::{
AffectedRows, RegionBulkInsertsRequest, RegionPutRequest, RegionRequest,
};
use store_api::storage::RegionId;
use crate::batch_modifier::{TagColumnInfo, modify_batch_sparse};
use crate::engine::MetricEngineInner;
use crate::error;
use crate::error::Result;
impl MetricEngineInner {
/// Bulk-inserts logical rows into a metric region.
///
/// This method accepts a `RegionBulkInsertsRequest` whose payload is a logical
/// `RecordBatch` (timestamp, value and tag columns) for the given logical `region_id`.
///
/// The transformed batch is encoded to Arrow IPC and forwarded as a `BulkInserts`
/// request to the data region, along with the original `partition_expr_version`.
/// If the data region reports `StatusCode::Unsupported` for bulk inserts, the request
/// is transparently retried as a `Put` by converting the original logical batch into
/// `api::v1::Rows`, so callers observe the same semantics as `put_region`.
///
/// Returns the number of affected rows, or `0` if the input batch is empty.
pub async fn bulk_insert_region(
&self,
region_id: RegionId,
request: RegionBulkInsertsRequest,
) -> Result<AffectedRows> {
ensure!(
!self.is_physical_region(region_id),
error::UnsupportedRegionRequestSnafu {
request: RegionRequest::BulkInserts(request),
}
);
let (physical_region_id, data_region_id, primary_key_encoding) =
self.find_data_region_meta(region_id)?;
if primary_key_encoding != PrimaryKeyEncoding::Sparse {
return error::UnsupportedRegionRequestSnafu {
request: RegionRequest::BulkInserts(request),
}
.fail();
}
let batch = request.payload;
if batch.num_rows() == 0 {
return Ok(0);
}
let logical_metadata = self
.logical_region_metadata(physical_region_id, region_id)
.await?;
let (tag_columns, non_tag_indices) = self.resolve_tag_columns_from_metadata(
region_id,
data_region_id,
&batch,
&logical_metadata,
)?;
let modified_batch = modify_batch_sparse(
batch.clone(),
region_id.table_id(),
&tag_columns,
&non_tag_indices,
)?;
let (schema, data_header, payload) = record_batch_to_ipc(&modified_batch)?;
let partition_expr_version = request.partition_expr_version;
let request = RegionBulkInsertsRequest {
region_id: data_region_id,
payload: modified_batch,
raw_data: ArrowIpc {
schema,
data_header,
payload,
},
partition_expr_version,
};
match self
.data_region
.write_data(data_region_id, RegionRequest::BulkInserts(request))
.await
{
Ok(affected_rows) => Ok(affected_rows),
Err(err) if err.status_code() == StatusCode::Unsupported => {
// todo(hl): fallback path for PartitionTreeMemtable, remove this once we remove it
let rows = record_batch_to_rows(&batch, region_id)?;
self.put_region(
region_id,
RegionPutRequest {
rows,
hint: None,
partition_expr_version,
},
)
.await
}
Err(err) => Err(err),
}
}
fn resolve_tag_columns_from_metadata(
&self,
logical_region_id: RegionId,
data_region_id: RegionId,
batch: &RecordBatch,
logical_metadata: &RegionMetadataRef,
) -> Result<(Vec<TagColumnInfo>, Vec<usize>)> {
let tag_names: HashSet<&str> = logical_metadata
.column_metadatas
.iter()
.filter_map(|column| {
if column.semantic_type == SemanticType::Tag {
Some(column.column_schema.name.as_str())
} else {
None
}
})
.collect();
let mut tag_columns = Vec::new();
let mut non_tag_indices = Vec::new();
{
let state = self.state.read().unwrap();
let physical_columns = state
.physical_region_states()
.get(&data_region_id)
.context(error::PhysicalRegionNotFoundSnafu {
region_id: data_region_id,
})?
.physical_columns();
for (index, field) in batch.schema().fields().iter().enumerate() {
let name = field.name();
let column_id =
*physical_columns
.get(name)
.with_context(|| error::ColumnNotFoundSnafu {
name: name.clone(),
region_id: logical_region_id,
})?;
if tag_names.contains(name.as_str()) {
tag_columns.push(TagColumnInfo {
name: name.clone(),
index,
column_id,
});
} else {
non_tag_indices.push(index);
}
}
}
tag_columns.sort_by(|a, b| a.name.cmp(&b.name));
Ok((tag_columns, non_tag_indices))
}
}
fn record_batch_to_rows(batch: &RecordBatch, logical_region_id: RegionId) -> Result<api::v1::Rows> {
let schema_ref = batch.schema();
let fields = schema_ref.fields();
let mut ts_idx = None;
let mut val_idx = None;
let mut tag_indices = Vec::new();
for (idx, field) in fields.iter().enumerate() {
if field.name() == greptime_timestamp() {
ts_idx = Some(idx);
if !matches!(
field.data_type(),
datatypes::arrow::datatypes::DataType::Timestamp(
datatypes::arrow::datatypes::TimeUnit::Millisecond,
_
)
) {
return error::UnexpectedRequestSnafu {
reason: format!(
"Timestamp column '{}' in region {:?} has incompatible type: {:?}",
field.name(),
logical_region_id,
field.data_type()
),
}
.fail();
}
} else if field.name() == greptime_value() {
val_idx = Some(idx);
if !matches!(
field.data_type(),
datatypes::arrow::datatypes::DataType::Float64
) {
return error::UnexpectedRequestSnafu {
reason: format!(
"Value column '{}' in region {:?} has incompatible type: {:?}",
field.name(),
logical_region_id,
field.data_type()
),
}
.fail();
}
} else {
if !matches!(
field.data_type(),
datatypes::arrow::datatypes::DataType::Utf8
) {
return error::UnexpectedRequestSnafu {
reason: format!(
"Tag column '{}' in region {:?} must be Utf8, found: {:?}",
field.name(),
logical_region_id,
field.data_type()
),
}
.fail();
}
tag_indices.push(idx);
}
}
let ts_idx = ts_idx.with_context(|| error::UnexpectedRequestSnafu {
reason: format!(
"Timestamp column '{}' not found in RecordBatch for region {:?}",
greptime_timestamp(),
logical_region_id
),
})?;
let val_idx = val_idx.with_context(|| error::UnexpectedRequestSnafu {
reason: format!(
"Value column '{}' not found in RecordBatch for region {:?}",
greptime_value(),
logical_region_id
),
})?;
let mut schema = Vec::with_capacity(2 + tag_indices.len());
schema.push(api::v1::ColumnSchema {
column_name: greptime_timestamp().to_string(),
datatype: ColumnDataType::TimestampMillisecond as i32,
semantic_type: SemanticType::Timestamp as i32,
datatype_extension: None,
options: None,
});
schema.push(api::v1::ColumnSchema {
column_name: greptime_value().to_string(),
datatype: ColumnDataType::Float64 as i32,
semantic_type: SemanticType::Field as i32,
datatype_extension: None,
options: None,
});
for &idx in &tag_indices {
let field = &fields[idx];
schema.push(api::v1::ColumnSchema {
column_name: field.name().clone(),
datatype: ColumnDataType::String as i32,
semantic_type: SemanticType::Tag as i32,
datatype_extension: None,
options: None,
});
}
let ts_array = batch
.column(ts_idx)
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.expect("validated as TimestampMillisecond");
let val_array = batch
.column(val_idx)
.as_any()
.downcast_ref::<Float64Array>()
.expect("validated as Float64");
let tag_arrays: Vec<&StringArray> = tag_indices
.iter()
.map(|&idx| {
batch
.column(idx)
.as_any()
.downcast_ref::<StringArray>()
.expect("validated as Utf8")
})
.collect();
let num_rows = batch.num_rows();
let mut rows = Vec::with_capacity(num_rows);
for row_idx in 0..num_rows {
let mut values = Vec::with_capacity(2 + tag_arrays.len());
if ts_array.is_null(row_idx) {
values.push(api::v1::Value { value_data: None });
} else {
values.push(api::v1::Value {
value_data: Some(api::v1::value::ValueData::TimestampMillisecondValue(
ts_array.value(row_idx),
)),
});
}
if val_array.is_null(row_idx) {
values.push(api::v1::Value { value_data: None });
} else {
values.push(api::v1::Value {
value_data: Some(api::v1::value::ValueData::F64Value(
val_array.value(row_idx),
)),
});
}
for arr in &tag_arrays {
if arr.is_null(row_idx) {
values.push(api::v1::Value { value_data: None });
} else {
values.push(api::v1::Value {
value_data: Some(api::v1::value::ValueData::StringValue(
arr.value(row_idx).to_string(),
)),
});
}
}
rows.push(api::v1::Row { values });
}
Ok(api::v1::Rows { schema, rows })
}
fn record_batch_to_ipc(record_batch: &RecordBatch) -> Result<(Bytes, Bytes, Bytes)> {
let mut encoder = FlightEncoder::default();
let schema = encoder.encode_schema(record_batch.schema().as_ref());
let mut iter = encoder
.encode(FlightMessage::RecordBatch(record_batch.clone()))
.into_iter();
let Some(flight_data) = iter.next() else {
return error::UnexpectedRequestSnafu {
reason: "Failed to encode empty flight data",
}
.fail();
};
ensure!(
iter.next().is_none(),
error::UnexpectedRequestSnafu {
reason: "Bulk insert RecordBatch with dictionary arrays is unsupported".to_string(),
}
);
Ok((
schema.data_header,
flight_data.data_header,
flight_data.data_body,
))
}
// Tests for bulk inserts into metric regions, run against a `TestEnv`.
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use std::sync::Arc;
use api::v1::ArrowIpc;
use common_error::ext::ErrorExt;
use common_query::prelude::{greptime_timestamp, greptime_value};
use common_recordbatch::RecordBatches;
use datatypes::arrow::array::{Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use store_api::metric_engine_consts::MEMTABLE_PARTITION_TREE_PRIMARY_KEY_ENCODING;
use store_api::path_utils::table_dir;
use store_api::region_engine::RegionEngine;
use store_api::region_request::{RegionBulkInsertsRequest, RegionPutRequest, RegionRequest};
use store_api::storage::{RegionId, ScanRequest};
use super::record_batch_to_ipc;
use crate::error::Error;
use crate::test_util::{self, TestEnv};
// Builds a logical batch of `rows` rows: timestamps and values count up from
// `start`, and the single tag column "job" is always "tag_0".
fn build_logical_batch(start: usize, rows: usize) -> RecordBatch {
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(
greptime_timestamp(),
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(greptime_value(), DataType::Float64, true),
Field::new("job", DataType::Utf8, true),
]));
let mut ts = Vec::with_capacity(rows);
let mut values = Vec::with_capacity(rows);
let mut tags = Vec::with_capacity(rows);
for i in start..start + rows {
ts.push(i as i64);
values.push(i as f64);
tags.push("tag_0".to_string());
}
RecordBatch::try_new(
schema,
vec![
Arc::new(TimestampMillisecondArray::from(ts)),
Arc::new(Float64Array::from(values)),
Arc::new(StringArray::from(tags)),
],
)
.unwrap()
}
// Wraps `batch` in a `BulkInserts` request for `logical_region_id`, filling
// `raw_data` with the IPC encoding of the same batch.
fn build_bulk_request(logical_region_id: RegionId, batch: RecordBatch) -> RegionRequest {
let (schema, data_header, payload) = record_batch_to_ipc(&batch).unwrap();
RegionRequest::BulkInserts(RegionBulkInsertsRequest {
region_id: logical_region_id,
payload: batch,
raw_data: ArrowIpc {
schema,
data_header,
payload,
},
partition_expr_version: None,
})
}
// Creates the default physical region with the primary key encoding option set
// to "dense", then a logical region with the tag "job" on top of it.
async fn init_dense_metric_region(env: &TestEnv) -> RegionId {
let physical_region_id = env.default_physical_region_id();
env.create_physical_region(
physical_region_id,
&TestEnv::default_table_dir(),
vec![(
MEMTABLE_PARTITION_TREE_PRIMARY_KEY_ENCODING.to_string(),
"dense".to_string(),
)],
)
.await;
let logical_region_id = env.default_logical_region_id();
let request = test_util::create_logical_region_request(
&["job"],
physical_region_id,
&table_dir("test", logical_region_id.table_id()),
);
env.metric()
.handle_request(logical_region_id, RegionRequest::Create(request))
.await
.unwrap();
logical_region_id
}
// An empty batch succeeds with zero affected rows, even with default (empty) `raw_data`.
#[tokio::test]
async fn test_bulk_insert_empty_batch_returns_zero() {
let env = TestEnv::new().await;
env.init_metric_region().await;
let logical_region_id = env.default_logical_region_id();
let batch = build_logical_batch(0, 0);
let request = RegionRequest::BulkInserts(RegionBulkInsertsRequest {
region_id: logical_region_id,
payload: batch,
raw_data: ArrowIpc::default(),
partition_expr_version: None,
});
let response = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
assert_eq!(response.affected_rows, 0);
}
// Bulk inserts addressed to the physical region are rejected as unsupported.
#[tokio::test]
async fn test_bulk_insert_physical_region_rejected() {
let env = TestEnv::new().await;
env.init_metric_region().await;
let physical_region_id = env.default_physical_region_id();
let batch = build_logical_batch(0, 2);
let request = build_bulk_request(physical_region_id, batch);
let err = env
.metric()
.handle_request(physical_region_id, request)
.await
.unwrap_err();
let Some(err) = err.as_any().downcast_ref::<Error>() else {
panic!("unexpected error type");
};
assert_matches!(err, Error::UnsupportedRegionRequest { .. });
}
// A batch column that is not a known physical column fails with `ColumnNotFound`.
#[tokio::test]
async fn test_bulk_insert_unknown_column_errors() {
let env = TestEnv::new().await;
env.init_metric_region().await;
let logical_region_id = env.default_logical_region_id();
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(
greptime_timestamp(),
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(greptime_value(), DataType::Float64, true),
Field::new("nonexistent_column", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(TimestampMillisecondArray::from(vec![0i64])),
Arc::new(Float64Array::from(vec![1.0])),
Arc::new(StringArray::from(vec!["val"])),
],
)
.unwrap();
let request = build_bulk_request(logical_region_id, batch);
let err = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap_err();
let Some(err) = err.as_any().downcast_ref::<Error>() else {
panic!("unexpected error type");
};
assert_matches!(err, Error::ColumnNotFound { .. });
}
// Inserts three rows with two tag columns ("host", "region") and scans them back.
#[tokio::test]
async fn test_bulk_insert_multiple_tag_columns() {
let env = TestEnv::new().await;
let physical_region_id = env.default_physical_region_id();
env.create_physical_region(physical_region_id, &TestEnv::default_table_dir(), vec![])
.await;
let logical_region_id = env.default_logical_region_id();
let request = test_util::create_logical_region_request(
&["host", "region"],
physical_region_id,
&table_dir("test", logical_region_id.table_id()),
);
env.metric()
.handle_request(logical_region_id, RegionRequest::Create(request))
.await
.unwrap();
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(
greptime_timestamp(),
DataType::Timestamp(TimeUnit::Millisecond, None),
false,
),
Field::new(greptime_value(), DataType::Float64, true),
Field::new("host", DataType::Utf8, true),
Field::new("region", DataType::Utf8, true),
]));
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(TimestampMillisecondArray::from(vec![0i64, 1, 2])),
Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
Arc::new(StringArray::from(vec!["h1", "h2", "h1"])),
Arc::new(StringArray::from(vec!["us-east", "us-west", "eu-west"])),
],
)
.unwrap();
let request = build_bulk_request(logical_region_id, batch);
let response = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
assert_eq!(response.affected_rows, 3);
let stream = env
.metric()
.scan_to_stream(logical_region_id, ScanRequest::default())
.await
.unwrap();
let batches = RecordBatches::try_collect(stream).await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 3);
}
// Two consecutive bulk inserts (3 rows, then 5 rows with later timestamps) add up to 8 scanned rows.
#[tokio::test]
async fn test_bulk_insert_accumulates_rows() {
let env = TestEnv::new().await;
env.init_metric_region().await;
let logical_region_id = env.default_logical_region_id();
let request = build_bulk_request(logical_region_id, build_logical_batch(0, 3));
let response = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
assert_eq!(response.affected_rows, 3);
let request = build_bulk_request(logical_region_id, build_logical_batch(3, 5));
let response = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
assert_eq!(response.affected_rows, 5);
let stream = env
.metric()
.scan_to_stream(logical_region_id, ScanRequest::default())
.await
.unwrap();
let batches = RecordBatches::try_collect(stream).await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 8);
}
// Happy path on the region created by `init_metric_region`: four rows in, four rows scanned.
// NOTE(review): the test name implies that region uses sparse encoding - confirm in `TestEnv`.
#[tokio::test]
async fn test_bulk_insert_sparse_encoding() {
let env = TestEnv::new().await;
env.init_metric_region().await;
let logical_region_id = env.default_logical_region_id();
let request = build_bulk_request(logical_region_id, build_logical_batch(0, 4));
let response = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
assert_eq!(response.affected_rows, 4);
let stream = env
.metric()
.scan_to_stream(logical_region_id, ScanRequest::default())
.await
.unwrap();
let batches = RecordBatches::try_collect(stream).await.unwrap();
assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 4);
}
// Bulk inserts into a region configured with dense encoding are rejected as unsupported.
#[tokio::test]
async fn test_bulk_insert_dense_encoding_rejected() {
let env = TestEnv::new().await;
let logical_region_id = init_dense_metric_region(&env).await;
let request = build_bulk_request(logical_region_id, build_logical_batch(0, 2));
let err = env
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap_err();
let Some(err) = err.as_any().downcast_ref::<Error>() else {
panic!("unexpected error type");
};
assert_matches!(err, Error::UnsupportedRegionRequest { .. });
}
// Writes rows through `Put` in one environment and through `BulkInserts` in a
// second one, then requires both scans to pretty-print identically.
// NOTE(review): this relies on `test_util::build_rows(1, 5)` producing the same
// rows as `build_logical_batch(0, 5)` - confirm against `test_util`.
#[tokio::test]
async fn test_bulk_insert_matches_put() {
let env_put = TestEnv::new().await;
env_put.init_metric_region().await;
let logical_region_id = env_put.default_logical_region_id();
let schema = test_util::row_schema_with_tags(&["job"]);
let rows = test_util::build_rows(1, 5);
env_put
.metric()
.handle_request(
logical_region_id,
RegionRequest::Put(RegionPutRequest {
rows: api::v1::Rows { schema, rows },
hint: None,
partition_expr_version: None,
}),
)
.await
.unwrap();
let put_stream = env_put
.metric()
.scan_to_stream(logical_region_id, ScanRequest::default())
.await
.unwrap();
let put_batches = RecordBatches::try_collect(put_stream).await.unwrap();
let put_output = put_batches.pretty_print().unwrap();
let env_bulk = TestEnv::new().await;
env_bulk.init_metric_region().await;
let request = build_bulk_request(logical_region_id, build_logical_batch(0, 5));
env_bulk
.metric()
.handle_request(logical_region_id, request)
.await
.unwrap();
let bulk_stream = env_bulk
.metric()
.scan_to_stream(logical_region_id, ScanRequest::default())
.await
.unwrap();
let bulk_batches = RecordBatches::try_collect(bulk_stream).await.unwrap();
let bulk_output = bulk_batches.pretty_print().unwrap();
assert_eq!(put_output, bulk_output);
}
// Nulls in any column (timestamp, value or tag) become values whose `value_data` is `None`.
#[test]
fn test_record_batch_to_rows_with_null_values() {
use datatypes::arrow::array::{Float64Array, StringArray, TimestampMillisecondArray};
use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema, TimeUnit};
use datatypes::arrow::record_batch::RecordBatch;
use store_api::storage::RegionId;
use crate::engine::bulk_insert::record_batch_to_rows;
let schema = Arc::new(ArrowSchema::new(vec![
Field::new(
greptime_timestamp(),
DataType::Timestamp(TimeUnit::Millisecond, None),
true,
),
Field::new(greptime_value(), DataType::Float64, true),
Field::new("job", DataType::Utf8, true),
Field::new("host", DataType::Utf8, true),
]));
let ts_array = TimestampMillisecondArray::from(vec![Some(1000), None, Some(3000)]);
let val_array = Float64Array::from(vec![Some(1.0), Some(2.0), None]);
let job_array = StringArray::from(vec![Some("job1"), None, Some("job3")]);
let host_array = StringArray::from(vec![None, Some("host2"), Some("host3")]);
let batch = RecordBatch::try_new(
schema,
vec![
Arc::new(ts_array),
Arc::new(val_array),
Arc::new(job_array),
Arc::new(host_array),
],
)
.unwrap();
let region_id = RegionId::new(1, 1);
let rows = record_batch_to_rows(&batch, region_id).unwrap();
assert_eq!(rows.rows.len(), 3);
assert_eq!(rows.schema.len(), 4);
// Value order within a row: timestamp, value, job, host.
// Row 0: all non-null except host
assert!(rows.rows[0].values[0].value_data.is_some());
assert!(rows.rows[0].values[1].value_data.is_some());
assert!(rows.rows[0].values[2].value_data.is_some());
assert!(rows.rows[0].values[3].value_data.is_none());
// Row 1: null timestamp, null job
assert!(rows.rows[1].values[0].value_data.is_none());
assert!(rows.rows[1].values[1].value_data.is_some());
assert!(rows.rows[1].values[2].value_data.is_none());
assert!(rows.rows[1].values[3].value_data.is_some());
// Row 2: null value
assert!(rows.rows[2].values[0].value_data.is_some());
assert!(rows.rows[2].values[1].value_data.is_none());
assert!(rows.rows[2].values[2].value_data.is_some());
assert!(rows.rows[2].values[3].value_data.is_some());
}
}

View File

@@ -460,7 +460,7 @@ impl MetricEngineInner {
.await
}
fn find_data_region_meta(
pub(crate) fn find_data_region_meta(
&self,
logical_region_id: RegionId,
) -> Result<(RegionId, RegionId, PrimaryKeyEncoding)> {

View File

@@ -52,6 +52,7 @@
#![feature(assert_matches)]
mod batch_modifier;
pub mod config;
mod data_region;
pub mod engine;

View File

@@ -108,6 +108,11 @@ name = "memtable_bench"
harness = false
required-features = ["test"]
[[bench]]
name = "bench_cache_stream"
harness = false
required-features = ["test"]
[[bench]]
name = "bench_filter_time_partition"
harness = false

View File

@@ -0,0 +1,126 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Benchmarks for `cache_flat_range_stream` overhead.
//!
//! Compares consuming batches from a plain stream vs through the caching wrapper
//! that clones batches for the range cache.
//!
//! Run with:
//! ```sh
//! cargo bench -p mito2 --features test --bench bench_cache_stream
//! ```
use std::collections::VecDeque;
use std::sync::Arc;
use criterion::{Criterion, criterion_group, criterion_main};
use futures::TryStreamExt;
use mito_codec::row_converter::DensePrimaryKeyCodec;
use mito2::memtable::bulk::context::BulkIterContext;
use mito2::memtable::bulk::part::{BulkPartConverter, BulkPartEncoder};
use mito2::memtable::bulk::part_reader::EncodedBulkPartIter;
use mito2::read::range_cache::bench_cache_flat_range_stream;
use mito2::sst::parquet::DEFAULT_ROW_GROUP_SIZE;
use mito2::sst::{FlatSchemaOptions, to_flat_sst_arrow_schema};
use mito2::test_util::bench_util::{CpuDataGenerator, cpu_metadata};
/// Benchmarks draining a decoded bulk part as a plain record batch stream against
/// draining the same data through `bench_cache_flat_range_stream`.
///
/// Both cases rebuild the `EncodedBulkPartIter` inside the measured closure, so the
/// only difference between them is the wrapper applied to the stream.
fn cache_flat_range_stream_bench(c: &mut Criterion) {
    // 2000 hosts × 51 steps = 102,000 rows ≈ DEFAULT_ROW_GROUP_SIZE
    let host_count = 2000;
    let first_sec = 1710043200;
    let last_sec = first_sec + 510;

    let region_meta = Arc::new(cpu_metadata());
    let region_id = region_meta.region_id;
    let data_gen = CpuDataGenerator::new(region_meta.clone(), host_count, first_sec, last_sec);

    // Fold every generated mutation into one bulk part.
    let arrow_schema = to_flat_sst_arrow_schema(&region_meta, &FlatSchemaOptions::default());
    let pk_codec = Arc::new(DensePrimaryKeyCodec::new(&region_meta));
    let mut part_builder = BulkPartConverter::new(
        &region_meta,
        arrow_schema,
        DEFAULT_ROW_GROUP_SIZE,
        pk_codec,
        true, // store_pk_columns
    );
    data_gen.iter().for_each(|kvs| {
        part_builder.append_key_values(&kvs).unwrap();
    });
    let part = part_builder.convert().unwrap();

    // Encode the part to parquet once; every iteration decodes from this buffer.
    let encoded_part = BulkPartEncoder::new(region_meta.clone(), DEFAULT_ROW_GROUP_SIZE)
        .unwrap()
        .encode_part(&part)
        .unwrap()
        .unwrap();

    // Read every row group with no projection and no predicate.
    let row_group_count = encoded_part.metadata().parquet_metadata.num_row_groups();
    let row_groups: VecDeque<usize> = (0..row_group_count).collect();
    let context = Arc::new(
        BulkIterContext::new(
            region_meta.clone(),
            None, // No projection
            None, // No predicate
            false,
        )
        .unwrap(),
    );

    let runtime = tokio::runtime::Runtime::new().unwrap();
    let mut group = c.benchmark_group("cache_flat_range_stream");
    group.sample_size(10);

    // Baseline: consume the decoded batches directly.
    group.bench_function("baseline_iter_stream", |bencher| {
        bencher.iter(|| {
            runtime.block_on(async {
                let part_iter = EncodedBulkPartIter::try_new(
                    &encoded_part,
                    context.clone(),
                    row_groups.clone(),
                    None,
                    None,
                )
                .unwrap();
                let mut stream: mito2::read::BoxedRecordBatchStream =
                    Box::pin(futures::stream::iter(part_iter));
                while let Some(_batch) = stream.try_next().await.unwrap() {}
            });
        });
    });

    // Same input, but routed through the range-cache wrapper.
    group.bench_function("cache_flat_range_stream", |bencher| {
        bencher.iter(|| {
            runtime.block_on(async {
                let part_iter = EncodedBulkPartIter::try_new(
                    &encoded_part,
                    context.clone(),
                    row_groups.clone(),
                    None,
                    None,
                )
                .unwrap();
                let inner: mito2::read::BoxedRecordBatchStream =
                    Box::pin(futures::stream::iter(part_iter));
                // NOTE(review): the second argument is presumably a byte budget — confirm
                // against `bench_cache_flat_range_stream`.
                let mut stream = bench_cache_flat_range_stream(inner, 64 * 1024 * 1024, region_id);
                while let Some(_batch) = stream.try_next().await.unwrap() {}
            });
        });
    });
}
criterion_group!(benches, cache_flat_range_stream_bench);
criterion_main!(benches);

View File

@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! Benchmarks for memtable operations: writes, full scans, filtered scans,
//! bulk part conversion, record batch iteration with filters, and flat merge.
//!
//! Run with:
//! ```sh
//! cargo bench -p mito2 --features test --bench memtable_bench
//! ```
use std::sync::Arc;
use api::v1::value::ValueData;
use api::v1::{Row, Rows, SemanticType};
use criterion::{Criterion, criterion_group, criterion_main};
use datafusion_common::Column;
use datafusion_expr::{Expr, lit};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::ColumnSchema;
use mito_codec::row_converter::DensePrimaryKeyCodec;
use mito2::memtable::bulk::context::BulkIterContext;
use mito2::memtable::bulk::part::BulkPartConverter;
@@ -28,20 +30,13 @@ use mito2::memtable::bulk::part_reader::BulkPartBatchIter;
use mito2::memtable::bulk::{BulkMemtable, BulkMemtableConfig};
use mito2::memtable::partition_tree::{PartitionTreeConfig, PartitionTreeMemtable};
use mito2::memtable::time_series::TimeSeriesMemtable;
use mito2::memtable::{KeyValues, Memtable, RangesOptions};
use mito2::memtable::{IterBuilder, Memtable, RangesOptions};
use mito2::read::flat_merge::FlatMergeIterator;
use mito2::read::scan_region::PredicateGroup;
use mito2::region::options::MergeMode;
use mito2::sst::{FlatSchemaOptions, to_flat_sst_arrow_schema};
use mito2::test_util::memtable_util::{self, region_metadata_to_row_schema};
use rand::Rng;
use rand::rngs::ThreadRng;
use rand::seq::IndexedRandom;
use store_api::metadata::{
ColumnMetadata, RegionMetadata, RegionMetadataBuilder, RegionMetadataRef,
};
use store_api::storage::RegionId;
use table::predicate::Predicate;
use mito2::test_util::bench_util::{CpuDataGenerator, cpu_metadata};
use mito2::test_util::memtable_util;
/// Writes rows.
fn write_rows(c: &mut Criterion) {
@@ -105,7 +100,11 @@ fn full_scan(c: &mut Criterion) {
}
b.iter(|| {
let iter = memtable.iter(None, None, None).unwrap();
let iter = memtable
.ranges(None, RangesOptions::default())
.unwrap()
.build(None)
.unwrap();
for batch in iter {
let _batch = batch.unwrap();
}
@@ -145,7 +144,17 @@ fn filter_1_host(c: &mut Criterion) {
let predicate = generator.random_host_filter();
b.iter(|| {
let iter = memtable.iter(None, Some(predicate.clone()), None).unwrap();
let iter = memtable
.ranges(
None,
RangesOptions {
predicate: PredicateGroup::new(&metadata, predicate.exprs()).unwrap(),
..Default::default()
},
)
.unwrap()
.build(None)
.unwrap();
for batch in iter {
let _batch = batch.unwrap();
}
@@ -202,224 +211,6 @@ fn filter_1_host(c: &mut Criterion) {
});
}
/// Tag values describing one simulated host; every field is written as a string tag.
struct Host {
    hostname: String,
    region: String,
    datacenter: String,
    rack: String,
    os: String,
    arch: String,
    team: String,
    service: String,
    service_version: String,
    service_environment: String,
}

impl Host {
    /// Builds the host named `host_{id}`.
    ///
    /// Region, datacenter suffix, rack, service and service version are drawn at
    /// random; os, arch, team and service environment are fixed strings.
    fn random_with_id(id: usize) -> Host {
        let mut rng = rand::rng();
        // Draw in a fixed order: region, datacenter suffix, rack, service, version.
        let region = format!("ap-southeast-{}", rng.random_range(0..10));
        let zone = *['a', 'b', 'c', 'd', 'e'].choose(&mut rng).unwrap();
        let datacenter = format!("{}{}", region, zone);
        let rack = rng.random_range(0..100).to_string();
        let service = rng.random_range(0..100).to_string();
        let service_version = rng.random_range(0..10).to_string();

        Host {
            hostname: format!("host_{id}"),
            region,
            datacenter,
            rack,
            os: "Ubuntu16.04LTS".to_string(),
            arch: "x86".to_string(),
            team: "CHI".to_string(),
            service,
            service_version,
            service_environment: "test".to_string(),
        }
    }

    /// Appends this host's ten tag values to `values`, in field declaration order.
    fn fill_values(&self, values: &mut Vec<api::v1::Value>) {
        let tags: [&str; 10] = [
            self.hostname.as_str(),
            self.region.as_str(),
            self.datacenter.as_str(),
            self.rack.as_str(),
            self.os.as_str(),
            self.arch.as_str(),
            self.team.as_str(),
            self.service.as_str(),
            self.service_version.as_str(),
            self.service_environment.as_str(),
        ];
        values.extend(tags.iter().map(|tag| api::v1::Value {
            value_data: Some(ValueData::StringValue(tag.to_string())),
        }));
    }
}
/// Generates cpu-like rows for a fixed set of hosts over `[start_sec, end_sec)`.
struct CpuDataGenerator {
    metadata: RegionMetadataRef,
    column_schemas: Vec<api::v1::ColumnSchema>,
    hosts: Vec<Host>,
    start_sec: i64,
    end_sec: i64,
}

impl CpuDataGenerator {
    /// Creates a generator with `num_hosts` random hosts and the row schema derived
    /// from `metadata`.
    fn new(metadata: RegionMetadataRef, num_hosts: usize, start_sec: i64, end_sec: i64) -> Self {
        let hosts = Self::generate_hosts(num_hosts);
        let column_schemas = region_metadata_to_row_schema(&metadata);
        Self {
            metadata,
            column_schemas,
            hosts,
            start_sec,
            end_sec,
        }
    }

    /// Yields one [KeyValues] per 10-second step; the step index is the sequence.
    fn iter(&self) -> impl Iterator<Item = KeyValues> + '_ {
        let timestamps = (self.start_sec..self.end_sec).step_by(10);
        timestamps
            .enumerate()
            .map(move |(seq, ts)| self.build_key_values(seq, ts))
    }

    /// Builds a put mutation holding one row per host at `current_sec`.
    fn build_key_values(&self, seq: usize, current_sec: i64) -> KeyValues {
        let mut rows = Vec::with_capacity(self.hosts.len());
        for host in &self.hosts {
            let mut rng = rand::rng();
            // 1 timestamp + 10 tags + 10 fields.
            let mut values = Vec::with_capacity(21);
            values.push(api::v1::Value {
                value_data: Some(ValueData::TimestampMillisecondValue(current_sec * 1000)),
            });
            host.fill_values(&mut values);
            values.extend((0..10).map(|_| api::v1::Value {
                value_data: Some(ValueData::F64Value(Self::random_f64(&mut rng))),
            }));
            rows.push(Row { values });
        }

        let rows = Rows {
            schema: self.column_schemas.clone(),
            rows,
        };
        let mutation = api::v1::Mutation {
            op_type: api::v1::OpType::Put as i32,
            sequence: seq as u64,
            rows: Some(rows),
            write_hint: None,
        };
        KeyValues::new(&self.metadata, mutation).unwrap()
    }

    /// Returns a predicate `hostname = <random host>`.
    fn random_host_filter(&self) -> Predicate {
        Predicate::new(self.random_host_filter_exprs())
    }

    /// Returns the single expression `hostname = <random host>`.
    fn random_host_filter_exprs(&self) -> Vec<Expr> {
        let hostname = self.random_hostname();
        let column = Expr::Column(Column::from_name("hostname"));
        vec![column.eq(lit(hostname))]
    }

    /// Picks the hostname of a random generated host; panics if there are no hosts.
    fn random_hostname(&self) -> String {
        let mut rng = rand::rng();
        let host = self.hosts.choose(&mut rng).unwrap();
        host.hostname.clone()
    }

    /// Returns a whole-number value in `[30, 95)` as `f64`.
    fn random_f64(rng: &mut ThreadRng) -> f64 {
        f64::from(rng.random_range(30u32..95))
    }

    /// Creates hosts with ids `0..num_hosts`.
    fn generate_hosts(num_hosts: usize) -> Vec<Host> {
        let mut hosts = Vec::with_capacity(num_hosts);
        hosts.extend((0..num_hosts).map(Host::random_with_id));
        hosts
    }
}
/// Creates the region metadata of a TSBS cpu-like table.
///
/// Column 0 is the millisecond timestamp `ts`, columns 1-10 are nullable string
/// tags that form the primary key, and columns 11-20 are nullable float64 fields.
fn cpu_metadata() -> RegionMetadata {
    let tag_names = [
        "hostname",
        "region",
        "datacenter",
        "rack",
        "os",
        "arch",
        "team",
        "service",
        "service_version",
        "service_environment",
    ];
    let field_names = [
        "usage_user",
        "usage_system",
        "usage_idle",
        "usage_nice",
        "usage_iowait",
        "usage_irq",
        "usage_softirq",
        "usage_steal",
        "usage_guest",
        "usage_guest_nice",
    ];

    let mut builder = RegionMetadataBuilder::new(RegionId::new(1, 1));
    builder.push_column_metadata(ColumnMetadata {
        column_schema: ColumnSchema::new(
            "ts",
            ConcreteDataType::timestamp_millisecond_datatype(),
            false,
        ),
        semantic_type: SemanticType::Timestamp,
        column_id: 0,
    });

    // Tags come first, then fields; ids are assigned consecutively starting at 1.
    let tag_columns = tag_names
        .iter()
        .copied()
        .map(|name| (name, SemanticType::Tag, ConcreteDataType::string_datatype()));
    let field_columns = field_names
        .iter()
        .copied()
        .map(|name| (name, SemanticType::Field, ConcreteDataType::float64_datatype()));
    for (column_id, (name, semantic_type, data_type)) in
        (1..).zip(tag_columns.chain(field_columns))
    {
        builder.push_column_metadata(ColumnMetadata {
            column_schema: ColumnSchema::new(name, data_type, true),
            semantic_type,
            column_id,
        });
    }

    builder.primary_key(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
    builder.build().unwrap()
}
fn bulk_part_converter(c: &mut Criterion) {
let metadata = Arc::new(cpu_metadata());
let start_sec = 1710043200;

View File

@@ -21,7 +21,7 @@ use criterion::{Criterion, criterion_group, criterion_main};
use datatypes::data_type::ConcreteDataType;
use datatypes::schema::ColumnSchema;
use mito2::memtable::simple_bulk_memtable::SimpleBulkMemtable;
use mito2::memtable::{KeyValues, Memtable, MemtableRanges, RangesOptions};
use mito2::memtable::{IterBuilder, KeyValues, Memtable, MemtableRanges, RangesOptions};
use mito2::read;
use mito2::read::Source;
use mito2::read::dedup::DedupReader;
@@ -156,7 +156,11 @@ async fn flush(mem: &SimpleBulkMemtable) {
}
async fn flush_original(mem: &SimpleBulkMemtable) {
let iter = mem.iter(None, None, None).unwrap();
let iter = mem
.ranges(None, RangesOptions::default())
.unwrap()
.build(None)
.unwrap();
for b in iter {
black_box(b.unwrap());
}

View File

@@ -17,7 +17,6 @@ use std::time::{Duration, Instant};
use async_stream::try_stream;
use common_time::Timestamp;
use either::Either;
use futures::{Stream, TryStreamExt};
use object_store::services::Fs;
use object_store::util::{join_dir, with_instrument_layers};
@@ -37,7 +36,7 @@ use crate::error::{
CleanDirSnafu, DeleteIndexSnafu, DeleteIndexesSnafu, DeleteSstsSnafu, OpenDalSnafu, Result,
};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, FLUSH_ELAPSED};
use crate::read::{FlatSource, Source};
use crate::read::FlatSource;
use crate::region::options::IndexOptions;
use crate::sst::file::{FileHandle, RegionFileId, RegionIndexId};
use crate::sst::index::IndexerBuilderImpl;
@@ -47,7 +46,7 @@ use crate::sst::location::{self, region_dir_from_table_dir};
use crate::sst::parquet::reader::ParquetReaderBuilder;
use crate::sst::parquet::writer::ParquetWriter;
use crate::sst::parquet::{SstInfo, WriteOptions};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY};
use crate::sst::{DEFAULT_WRITE_BUFFER_SIZE, DEFAULT_WRITE_CONCURRENCY, FormatType};
pub type AccessLayerRef = Arc<AccessLayer>;
/// SST write results.
@@ -339,6 +338,7 @@ impl AccessLayer {
metrics: &mut Metrics,
) -> Result<SstInfoArray> {
let region_id = request.metadata.region_id;
let region_metadata = request.metadata.clone();
let cache_manager = request.cache_manager.clone();
let sst_info = if let Some(write_cache) = cache_manager.write_cache() {
@@ -391,15 +391,19 @@ impl AccessLayer {
)
.await
.with_file_cleaner(cleaner);
match request.source {
Either::Left(source) => {
match request.sst_write_format {
FormatType::PrimaryKey => {
writer
.write_all(source, request.max_sequence, write_opts)
.write_all_flat_as_primary_key(
request.source,
request.max_sequence,
write_opts,
)
.await?
}
Either::Right(flat_source) => {
FormatType::Flat => {
writer
.write_all_flat(flat_source, request.max_sequence, write_opts)
.write_all_flat(request.source, request.max_sequence, write_opts)
.await?
}
}
@@ -412,6 +416,7 @@ impl AccessLayer {
cache_manager.put_parquet_meta_data(
RegionFileId::new(region_id, sst.file_id),
parquet_metadata.clone(),
Some(region_metadata.clone()),
)
}
}
@@ -520,11 +525,12 @@ pub enum OperationType {
pub struct SstWriteRequest {
pub op_type: OperationType,
pub metadata: RegionMetadataRef,
pub source: Either<Source, FlatSource>,
pub source: FlatSource,
pub cache_manager: CacheManagerRef,
#[allow(dead_code)]
pub storage: Option<String>,
pub max_sequence: Option<SequenceNumber>,
pub sst_write_format: FormatType,
/// Configs for index
pub index_options: IndexOptions,

View File

@@ -28,6 +28,7 @@ use std::ops::Range;
use std::sync::Arc;
use bytes::Bytes;
use common_telemetry::warn;
use datatypes::arrow::record_batch::RecordBatch;
use datatypes::value::Value;
use datatypes::vectors::VectorRef;
@@ -36,8 +37,10 @@ use index::result_cache::IndexResultCache;
use moka::notification::RemovalCause;
use moka::sync::Cache;
use object_store::ObjectStore;
use parquet::file::metadata::{PageIndexPolicy, ParquetMetaData};
use parquet::file::metadata::{FileMetaData, PageIndexPolicy, ParquetMetaData};
use puffin::puffin_manager::cache::{PuffinMetadataCache, PuffinMetadataCacheRef};
use snafu::{OptionExt, ResultExt};
use store_api::metadata::RegionMetadataRef;
use store_api::storage::{ConcreteDataType, FileId, RegionId, TimeSeriesRowSelector};
use crate::cache::cache_size::parquet_meta_size;
@@ -46,10 +49,13 @@ use crate::cache::index::inverted_index::{InvertedIndexCache, InvertedIndexCache
#[cfg(feature = "vector_index")]
use crate::cache::index::vector_index::{VectorIndexCache, VectorIndexCacheRef};
use crate::cache::write_cache::WriteCacheRef;
use crate::error::{InvalidMetadataSnafu, InvalidParquetSnafu, Result};
use crate::memtable::record_batch_estimated_size;
use crate::metrics::{CACHE_BYTES, CACHE_EVICTION, CACHE_HIT, CACHE_MISS};
use crate::read::Batch;
use crate::read::range_cache::{RangeScanCacheKey, RangeScanCacheValue};
use crate::sst::file::{RegionFileId, RegionIndexId};
use crate::sst::parquet::PARQUET_METADATA_KEY;
use crate::sst::parquet::reader::MetadataCacheMetrics;
/// Metrics type key for sst meta.
@@ -64,6 +70,108 @@ const FILE_TYPE: &str = "file";
const INDEX_TYPE: &str = "index";
/// Metrics type key for selector result cache.
const SELECTOR_RESULT_TYPE: &str = "selector_result";
/// Metrics type key for range scan result cache.
const RANGE_RESULT_TYPE: &str = "range_result";
/// SST metadata as stored in the cache: the parquet footer plus the decoded region
/// metadata.
///
/// The stored footer has the `PARQUET_METADATA_KEY` entry removed; the decoded
/// [RegionMetadata] is kept next to it so readers do not have to deserialize the
/// JSON again.
#[derive(Debug)]
pub(crate) struct CachedSstMeta {
    /// Parquet footer without the region metadata JSON entry.
    parquet_metadata: Arc<ParquetMetaData>,
    /// Decoded (or caller-provided) region metadata.
    region_metadata: RegionMetadataRef,
    /// Weight charged for `region_metadata` by the cache weigher.
    region_metadata_weight: usize,
}

impl CachedSstMeta {
    /// Builds the cached value, decoding the region metadata from the footer JSON.
    pub(crate) fn try_new(file_path: &str, parquet_metadata: ParquetMetaData) -> Result<Self> {
        Self::try_new_with_region_metadata(file_path, parquet_metadata, None)
    }

    /// Builds the cached value, using `region_metadata` when given and decoding the
    /// footer JSON otherwise.
    ///
    /// The footer must carry a `PARQUET_METADATA_KEY` entry with a value even when
    /// `region_metadata` is provided, because its length feeds the weight.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parquet error when the key-value metadata, the entry or
    /// its value is missing, and an invalid-metadata error when the JSON cannot be
    /// decoded. `file_path` is only used in those errors.
    pub(crate) fn try_new_with_region_metadata(
        file_path: &str,
        parquet_metadata: ParquetMetaData,
        region_metadata: Option<RegionMetadataRef>,
    ) -> Result<Self> {
        let json = parquet_metadata
            .file_metadata()
            .key_value_metadata()
            .context(InvalidParquetSnafu {
                file: file_path,
                reason: "missing key value meta",
            })?
            .iter()
            .find(|kv| kv.key == PARQUET_METADATA_KEY)
            .with_context(|| InvalidParquetSnafu {
                file: file_path,
                reason: format!("key {} not found", PARQUET_METADATA_KEY),
            })?
            .value
            .as_ref()
            .with_context(|| InvalidParquetSnafu {
                file: file_path,
                reason: format!("No value for key {}", PARQUET_METADATA_KEY),
            })?;

        let region_metadata = if let Some(provided) = region_metadata {
            provided
        } else {
            let decoded = store_api::metadata::RegionMetadata::from_json(json)
                .context(InvalidMetadataSnafu)?;
            Arc::new(decoded)
        };

        // The weight is never lower than the size of the JSON that was stripped.
        let region_metadata_weight = region_metadata.estimated_size().max(json.len());

        Ok(Self {
            parquet_metadata: Arc::new(strip_region_metadata_from_parquet(parquet_metadata)),
            region_metadata,
            region_metadata_weight,
        })
    }

    /// Returns a shared handle to the stripped parquet footer.
    pub(crate) fn parquet_metadata(&self) -> Arc<ParquetMetaData> {
        Arc::clone(&self.parquet_metadata)
    }

    /// Returns a shared handle to the region metadata.
    pub(crate) fn region_metadata(&self) -> RegionMetadataRef {
        Arc::clone(&self.region_metadata)
    }
}
/// Rebuilds `parquet_metadata` without the `PARQUET_METADATA_KEY` key-value entry.
///
/// Row groups, column index and offset index are moved over unchanged. If no other
/// key-value entries remain, the key-value metadata becomes `None`.
fn strip_region_metadata_from_parquet(parquet_metadata: ParquetMetaData) -> ParquetMetaData {
    let original = parquet_metadata.file_metadata();

    // Keep every entry except the region metadata JSON; drop the list if it ends up empty.
    let remaining_key_values = original
        .key_value_metadata()
        .map(|key_values| {
            key_values
                .iter()
                .filter(|kv| kv.key != PARQUET_METADATA_KEY)
                .cloned()
                .collect::<Vec<_>>()
        })
        .filter(|kept| !kept.is_empty());

    let stripped = FileMetaData::new(
        original.version(),
        original.num_rows(),
        original.created_by().map(ToString::to_string),
        remaining_key_values,
        original.schema_descr_ptr(),
        original.column_orders().cloned(),
    );

    // Take the remaining parts out of the old metadata and attach them to the new footer.
    let mut old_parts = parquet_metadata.into_builder();
    let row_groups = old_parts.take_row_groups();
    let column_index = old_parts.take_column_index();
    let offset_index = old_parts.take_offset_index();

    parquet::file::metadata::ParquetMetaDataBuilder::new(stripped)
        .set_row_groups(row_groups)
        .set_column_index(column_index)
        .set_offset_index(offset_index)
        .build()
}
/// Cache strategies that may only enable a subset of caches.
#[derive(Clone)]
@@ -81,18 +189,17 @@ pub enum CacheStrategy {
}
impl CacheStrategy {
/// Gets parquet metadata with cache metrics tracking.
/// Returns the metadata and updates the provided metrics.
pub(crate) async fn get_parquet_meta_data(
/// Gets fused SST metadata with cache metrics tracking.
pub(crate) async fn get_sst_meta_data(
&self,
file_id: RegionFileId,
metrics: &mut MetadataCacheMetrics,
page_index_policy: PageIndexPolicy,
) -> Option<Arc<ParquetMetaData>> {
) -> Option<Arc<CachedSstMeta>> {
match self {
CacheStrategy::EnableAll(cache_manager) | CacheStrategy::Compaction(cache_manager) => {
cache_manager
.get_parquet_meta_data(file_id, metrics, page_index_policy)
.get_sst_meta_data(file_id, metrics, page_index_policy)
.await
}
CacheStrategy::Disabled => {
@@ -102,30 +209,48 @@ impl CacheStrategy {
}
}
/// Calls [CacheManager::get_parquet_meta_data_from_mem_cache()].
pub fn get_parquet_meta_data_from_mem_cache(
/// Calls [CacheManager::get_sst_meta_data_from_mem_cache()].
pub(crate) fn get_sst_meta_data_from_mem_cache(
&self,
file_id: RegionFileId,
) -> Option<Arc<ParquetMetaData>> {
) -> Option<Arc<CachedSstMeta>> {
match self {
CacheStrategy::EnableAll(cache_manager) => {
cache_manager.get_parquet_meta_data_from_mem_cache(file_id)
}
CacheStrategy::Compaction(cache_manager) => {
cache_manager.get_parquet_meta_data_from_mem_cache(file_id)
CacheStrategy::EnableAll(cache_manager) | CacheStrategy::Compaction(cache_manager) => {
cache_manager.get_sst_meta_data_from_mem_cache(file_id)
}
CacheStrategy::Disabled => None,
}
}
/// Calls [CacheManager::put_parquet_meta_data()].
pub fn put_parquet_meta_data(&self, file_id: RegionFileId, metadata: Arc<ParquetMetaData>) {
/// Calls [CacheManager::get_parquet_meta_data_from_mem_cache()].
pub fn get_parquet_meta_data_from_mem_cache(
&self,
file_id: RegionFileId,
) -> Option<Arc<ParquetMetaData>> {
self.get_sst_meta_data_from_mem_cache(file_id)
.map(|metadata| metadata.parquet_metadata())
}
/// Calls [CacheManager::put_sst_meta_data()].
pub(crate) fn put_sst_meta_data(&self, file_id: RegionFileId, metadata: Arc<CachedSstMeta>) {
match self {
CacheStrategy::EnableAll(cache_manager) => {
cache_manager.put_parquet_meta_data(file_id, metadata);
CacheStrategy::EnableAll(cache_manager) | CacheStrategy::Compaction(cache_manager) => {
cache_manager.put_sst_meta_data(file_id, metadata);
}
CacheStrategy::Compaction(cache_manager) => {
cache_manager.put_parquet_meta_data(file_id, metadata);
CacheStrategy::Disabled => {}
}
}
/// Calls [CacheManager::put_parquet_meta_data()].
pub fn put_parquet_meta_data(
&self,
file_id: RegionFileId,
metadata: Arc<ParquetMetaData>,
region_metadata: Option<RegionMetadataRef>,
) {
match self {
CacheStrategy::EnableAll(cache_manager) | CacheStrategy::Compaction(cache_manager) => {
cache_manager.put_parquet_meta_data(file_id, metadata, region_metadata);
}
CacheStrategy::Disabled => {}
}
@@ -223,6 +348,31 @@ impl CacheStrategy {
}
}
/// Looks up a range scan result through [CacheManager::get_range_result()].
/// Only [CacheStrategy::EnableAll] consults the cache; every other strategy returns None.
#[allow(dead_code)]
pub(crate) fn get_range_result(
    &self,
    key: &RangeScanCacheKey,
) -> Option<Arc<RangeScanCacheValue>> {
    let CacheStrategy::EnableAll(cache_manager) = self else {
        return None;
    };
    cache_manager.get_range_result(key)
}
/// Stores a range scan result through [CacheManager::put_range_result()].
/// Strategies other than [CacheStrategy::EnableAll] drop the result.
pub(crate) fn put_range_result(
    &self,
    key: RangeScanCacheKey,
    result: Arc<RangeScanCacheValue>,
) {
    match self {
        CacheStrategy::EnableAll(cache_manager) => cache_manager.put_range_result(key, result),
        CacheStrategy::Compaction(_) | CacheStrategy::Disabled => {}
    }
}
/// Calls [CacheManager::write_cache()].
/// It returns None if the strategy is [CacheStrategy::Disabled].
pub fn write_cache(&self) -> Option<&WriteCacheRef> {
@@ -324,6 +474,8 @@ pub struct CacheManager {
puffin_metadata_cache: Option<PuffinMetadataCacheRef>,
/// Cache for time series selectors.
selector_result_cache: Option<SelectorResultCache>,
/// Cache for range scan outputs in flat format.
range_result_cache: Option<RangeResultCache>,
/// Cache for index result.
index_result_cache: Option<IndexResultCache>,
}
@@ -336,6 +488,35 @@ impl CacheManager {
CacheManagerBuilder::default()
}
/// Gets fused SST metadata and records where it was found in `metrics`.
///
/// The in-memory cache is checked first. On a miss the write cache's file cache is
/// asked, and a hit there is copied into the in-memory cache. Exactly one of
/// `mem_cache_hit`, `file_cache_hit` or `cache_miss` is incremented per call.
pub(crate) async fn get_sst_meta_data(
    &self,
    file_id: RegionFileId,
    metrics: &mut MetadataCacheMetrics,
    page_index_policy: PageIndexPolicy,
) -> Option<Arc<CachedSstMeta>> {
    if let Some(cached) = self.get_sst_meta_data_from_mem_cache(file_id) {
        metrics.mem_cache_hit += 1;
        return Some(cached);
    }

    let index_key = IndexKey::new(file_id.region_id(), file_id.file_id(), FileType::Parquet);
    let from_file_cache = match &self.write_cache {
        Some(write_cache) => {
            write_cache
                .file_cache()
                .get_sst_meta_data(index_key, metrics, page_index_policy)
                .await
        }
        None => None,
    };

    match from_file_cache {
        Some(loaded) => {
            metrics.file_cache_hit += 1;
            // Promote to the in-memory cache so the next lookup skips the file cache.
            self.put_sst_meta_data(file_id, loaded.clone());
            Some(loaded)
        }
        None => {
            metrics.cache_miss += 1;
            None
        }
    }
}
/// Gets cached [ParquetMetaData] with metrics tracking.
/// Tries in-memory cache first, then file cache, updating metrics accordingly.
pub(crate) async fn get_parquet_meta_data(
@@ -344,29 +525,21 @@ impl CacheManager {
metrics: &mut MetadataCacheMetrics,
page_index_policy: PageIndexPolicy,
) -> Option<Arc<ParquetMetaData>> {
// Try to get metadata from sst meta cache
if let Some(metadata) = self.get_parquet_meta_data_from_mem_cache(file_id) {
metrics.mem_cache_hit += 1;
return Some(metadata);
}
self.get_sst_meta_data(file_id, metrics, page_index_policy)
.await
.map(|metadata| metadata.parquet_metadata())
}
// Try to get metadata from write cache
let key = IndexKey::new(file_id.region_id(), file_id.file_id(), FileType::Parquet);
if let Some(write_cache) = &self.write_cache
&& let Some(metadata) = write_cache
.file_cache()
.get_parquet_meta_data(key, metrics, page_index_policy)
.await
{
metrics.file_cache_hit += 1;
let metadata = Arc::new(metadata);
// Put metadata into sst meta cache
self.put_parquet_meta_data(file_id, metadata.clone());
return Some(metadata);
};
metrics.cache_miss += 1;
None
/// Gets cached fused SST metadata from in-memory cache.
/// This method does not perform I/O.
pub(crate) fn get_sst_meta_data_from_mem_cache(
&self,
file_id: RegionFileId,
) -> Option<Arc<CachedSstMeta>> {
self.sst_meta_cache.as_ref().and_then(|sst_meta_cache| {
let value = sst_meta_cache.get(&SstMetaKey(file_id.region_id(), file_id.file_id()));
update_hit_miss(value, SST_META_TYPE)
})
}
/// Gets cached [ParquetMetaData] from in-memory cache.
@@ -375,15 +548,12 @@ impl CacheManager {
&self,
file_id: RegionFileId,
) -> Option<Arc<ParquetMetaData>> {
// Try to get metadata from sst meta cache
self.sst_meta_cache.as_ref().and_then(|sst_meta_cache| {
let value = sst_meta_cache.get(&SstMetaKey(file_id.region_id(), file_id.file_id()));
update_hit_miss(value, SST_META_TYPE)
})
self.get_sst_meta_data_from_mem_cache(file_id)
.map(|metadata| metadata.parquet_metadata())
}
/// Puts [ParquetMetaData] into the cache.
pub fn put_parquet_meta_data(&self, file_id: RegionFileId, metadata: Arc<ParquetMetaData>) {
/// Puts fused SST metadata into the cache.
pub(crate) fn put_sst_meta_data(&self, file_id: RegionFileId, metadata: Arc<CachedSstMeta>) {
if let Some(cache) = &self.sst_meta_cache {
let key = SstMetaKey(file_id.region_id(), file_id.file_id());
CACHE_BYTES
@@ -393,6 +563,34 @@ impl CacheManager {
}
}
/// Puts [ParquetMetaData] into the SST meta cache as a [CachedSstMeta].
///
/// Does nothing when the SST meta cache is disabled. `region_metadata`, when given,
/// is stored as is instead of being decoded from the footer. If the conversion
/// fails, a warning is logged and nothing is cached.
pub fn put_parquet_meta_data(
    &self,
    file_id: RegionFileId,
    metadata: Arc<ParquetMetaData>,
    region_metadata: Option<RegionMetadataRef>,
) {
    if self.sst_meta_cache.is_none() {
        return;
    }

    // No real path is known here; this label only identifies the file in errors.
    let file_path = format!(
        "region_id={}, file_id={}",
        file_id.region_id(),
        file_id.file_id()
    );
    // Takes the inner value without copying when this is the only handle.
    let owned_metadata = Arc::unwrap_or_clone(metadata);
    let converted =
        CachedSstMeta::try_new_with_region_metadata(&file_path, owned_metadata, region_metadata);

    match converted {
        Ok(cached) => self.put_sst_meta_data(file_id, Arc::new(cached)),
        Err(err) => {
            warn!(
                err; "Failed to decode region metadata while caching parquet metadata, region_id: {}, file_id: {}",
                file_id.region_id(),
                file_id.file_id()
            );
        }
    }
}
/// Removes [ParquetMetaData] from the cache.
pub fn remove_parquet_meta_data(&self, file_id: RegionFileId) {
if let Some(cache) = &self.sst_meta_cache {
@@ -512,6 +710,31 @@ impl CacheManager {
}
}
/// Gets the cached result of a range scan and updates the hit/miss metrics.
/// Returns None without touching the metrics when the range result cache is disabled.
#[allow(dead_code)]
pub(crate) fn get_range_result(
    &self,
    key: &RangeScanCacheKey,
) -> Option<Arc<RangeScanCacheValue>> {
    let cache = self.range_result_cache.as_ref()?;
    update_hit_miss(cache.get(key), RANGE_RESULT_TYPE)
}
/// Puts a range scan result into the cache and adds its weight to the
/// `CACHE_BYTES` gauge. Does nothing when the range result cache is disabled.
pub(crate) fn put_range_result(
    &self,
    key: RangeScanCacheKey,
    result: Arc<RangeScanCacheValue>,
) {
    let Some(cache) = &self.range_result_cache else {
        return;
    };
    // Compute the weight before `key` and `result` are moved into the cache.
    let weight = range_result_cache_weight(&key, &result);
    CACHE_BYTES
        .with_label_values(&[RANGE_RESULT_TYPE])
        .add(weight.into());
    cache.insert(key, result);
}
/// Gets the write cache.
pub(crate) fn write_cache(&self) -> Option<&WriteCacheRef> {
self.write_cache.as_ref()
@@ -562,6 +785,7 @@ pub struct CacheManagerBuilder {
puffin_metadata_size: u64,
write_cache: Option<WriteCacheRef>,
selector_result_cache_size: u64,
range_result_cache_size: u64,
}
impl CacheManagerBuilder {
@@ -625,6 +849,12 @@ impl CacheManagerBuilder {
self
}
/// Sets range result cache size.
pub fn range_result_cache_size(mut self, bytes: u64) -> Self {
self.range_result_cache_size = bytes;
self
}
/// Builds the [CacheManager].
pub fn build(self) -> CacheManager {
fn to_str(cause: RemovalCause) -> &'static str {
@@ -712,6 +942,21 @@ impl CacheManagerBuilder {
})
.build()
});
let range_result_cache = (self.range_result_cache_size != 0).then(|| {
Cache::builder()
.max_capacity(self.range_result_cache_size)
.weigher(range_result_cache_weight)
.eviction_listener(move |k, v, cause| {
let size = range_result_cache_weight(&k, &v);
CACHE_BYTES
.with_label_values(&[RANGE_RESULT_TYPE])
.sub(size.into());
CACHE_EVICTION
.with_label_values(&[RANGE_RESULT_TYPE, to_str(cause)])
.inc();
})
.build()
});
CacheManager {
sst_meta_cache,
vector_cache,
@@ -723,14 +968,15 @@ impl CacheManagerBuilder {
vector_index_cache,
puffin_metadata_cache: Some(Arc::new(puffin_metadata_cache)),
selector_result_cache,
range_result_cache,
index_result_cache,
}
}
}
fn meta_cache_weight(k: &SstMetaKey, v: &Arc<ParquetMetaData>) -> u32 {
fn meta_cache_weight(k: &SstMetaKey, v: &Arc<CachedSstMeta>) -> u32 {
// We ignore the size of `Arc`.
(k.estimated_size() + parquet_meta_size(v)) as u32
(k.estimated_size() + parquet_meta_size(&v.parquet_metadata) + v.region_metadata_weight) as u32
}
fn vector_cache_weight(_k: &(ConcreteDataType, Value), v: &VectorRef) -> u32 {
@@ -746,6 +992,10 @@ fn selector_result_cache_weight(k: &SelectorResultKey, v: &Arc<SelectorResultVal
(mem::size_of_val(k) + v.estimated_size()) as u32
}
/// Returns the weight of a range result cache entry: the estimated size of the key
/// plus the estimated size of the value.
///
/// The sum is clamped to `u32::MAX`. A plain `as u32` cast would truncate modulo
/// 2^32, so an entry of 4 GiB or more would report a small weight to the cache.
fn range_result_cache_weight(k: &RangeScanCacheKey, v: &Arc<RangeScanCacheValue>) -> u32 {
    let total = k.estimated_size() + v.estimated_size();
    u32::try_from(total).unwrap_or(u32::MAX)
}
/// Updates cache hit/miss metrics.
fn update_hit_miss<T>(value: Option<T>, cache_type: &str) -> Option<T> {
if value.is_some() {
@@ -892,8 +1142,8 @@ impl SelectorResultValue {
}
}
/// Maps (region id, file id) to [ParquetMetaData].
type SstMetaCache = Cache<SstMetaKey, Arc<ParquetMetaData>>;
/// Maps (region id, file id) to fused SST metadata.
type SstMetaCache = Cache<SstMetaKey, Arc<CachedSstMeta>>;
/// Maps [Value] to a vector that holds this value repeatedly.
///
/// e.g. `"hello" => ["hello", "hello", "hello"]`
@@ -902,20 +1152,30 @@ type VectorCache = Cache<(ConcreteDataType, Value), VectorRef>;
type PageCache = Cache<PageKey, Arc<PageValue>>;
/// Maps (file id, row group id, time series row selector) to [SelectorResultValue].
type SelectorResultCache = Cache<SelectorResultKey, Arc<SelectorResultValue>>;
/// Maps partition-range scan key to cached flat batches.
type RangeResultCache = Cache<RangeScanCacheKey, Arc<RangeScanCacheValue>>;
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::v1::SemanticType;
use api::v1::index::{BloomFilterMeta, InvertedIndexMetas};
use datatypes::schema::ColumnSchema;
use datatypes::vectors::Int64Vector;
use puffin::file_metadata::FileMetadata;
use store_api::metadata::{ColumnMetadata, RegionMetadata, RegionMetadataBuilder};
use store_api::storage::ColumnId;
use super::*;
use crate::cache::index::bloom_filter_index::Tag;
use crate::cache::index::result_cache::PredicateKey;
use crate::cache::test_util::parquet_meta;
use crate::cache::test_util::{
parquet_meta, sst_parquet_meta, sst_parquet_meta_with_region_metadata,
};
use crate::read::range_cache::{
RangeScanCacheKey, RangeScanCacheValue, ScanRequestFingerprintBuilder,
};
use crate::sst::parquet::row_selection::RowGroupSelection;
#[tokio::test]
@@ -929,7 +1189,7 @@ mod tests {
let file_id = RegionFileId::new(region_id, FileId::random());
let metadata = parquet_meta();
let mut metrics = MetadataCacheMetrics::default();
cache.put_parquet_meta_data(file_id, metadata);
cache.put_parquet_meta_data(file_id, metadata, None);
assert!(
cache
.get_parquet_meta_data(file_id, &mut metrics, Default::default())
@@ -966,13 +1226,23 @@ mod tests {
.await
.is_none()
);
let metadata = parquet_meta();
cache.put_parquet_meta_data(file_id, metadata);
let (metadata, region_metadata) = sst_parquet_meta();
cache.put_parquet_meta_data(file_id, metadata, None);
let cached = cache
.get_sst_meta_data(file_id, &mut metrics, Default::default())
.await
.unwrap();
assert_eq!(region_metadata, cached.region_metadata());
assert!(
cache
.get_parquet_meta_data(file_id, &mut metrics, Default::default())
.await
.is_some()
cached
.parquet_metadata()
.file_metadata()
.key_value_metadata()
.is_none_or(|key_values| {
key_values
.iter()
.all(|key_value| key_value.key != PARQUET_METADATA_KEY)
})
);
cache.remove_parquet_meta_data(file_id);
assert!(
@@ -983,6 +1253,42 @@ mod tests {
);
}
/// Checks that region metadata handed to `put_parquet_meta_data` is the very same
/// allocation that the cached SST metadata returns afterwards.
#[tokio::test]
async fn test_parquet_meta_cache_with_provided_region_metadata() {
    let file_id = RegionFileId::new(RegionId::new(1, 1), FileId::random());
    let (parquet_metadata, provided) = sst_parquet_meta();

    let manager = CacheManager::builder().sst_meta_cache_size(2000).build();
    manager.put_parquet_meta_data(file_id, parquet_metadata, Some(provided.clone()));

    let mut cache_metrics = MetadataCacheMetrics::default();
    let entry = manager
        .get_sst_meta_data(file_id, &mut cache_metrics, Default::default())
        .await
        .unwrap();

    // Pointer equality: the provided `Arc` must be reused rather than rebuilt.
    let stored = entry.region_metadata();
    assert!(Arc::ptr_eq(&provided, &stored));
}
/// Checks that the SST meta cache weight includes the weight of the decoded region
/// metadata, and that this weight is larger than the metadata's JSON encoding.
#[test]
fn test_meta_cache_weight_accounts_for_decoded_region_metadata() {
    let wide = Arc::new(wide_region_metadata(128));
    let encoded_len = wide.to_json().unwrap().len();

    let parquet_metadata =
        Arc::unwrap_or_clone(sst_parquet_meta_with_region_metadata(wide.clone()));
    let entry = Arc::new(CachedSstMeta::try_new("test.parquet", parquet_metadata).unwrap());
    let cache_key = SstMetaKey(wide.region_id, FileId::random());

    assert!(entry.region_metadata_weight > encoded_len);

    // The weight is exactly key size + parquet metadata size + region metadata weight.
    let expected = cache_key.estimated_size()
        + parquet_meta_size(&entry.parquet_metadata)
        + entry.region_metadata_weight;
    let actual = meta_cache_weight(&cache_key, &entry) as usize;
    assert_eq!(actual, expected);
}
#[test]
fn test_repeated_vector_cache() {
let cache = CacheManager::builder().vector_cache_size(4096).build();
@@ -1028,6 +1334,50 @@ mod tests {
assert!(cache.get_selector_result(&key).is_some());
}
/// Exercises the range result cache directly on the [CacheManager] and through each
/// [CacheStrategy] variant.
#[test]
fn test_range_result_cache() {
let cache = Arc::new(
CacheManager::builder()
.range_result_cache_size(1024 * 1024)
.build(),
);
// Key for one row group of a random file, with a single filter expression.
let key = RangeScanCacheKey {
region_id: RegionId::new(1, 1),
row_groups: vec![(FileId::random(), 0)],
scan: ScanRequestFingerprintBuilder {
read_column_ids: vec![],
read_column_types: vec![],
filters: vec!["tag_0 = 1".to_string()],
time_filters: vec![],
series_row_selector: None,
append_mode: false,
filter_deleted: true,
merge_mode: crate::region::options::MergeMode::LastRow,
partition_expr_version: 0,
}
.build(),
};
// Value built from an empty `Vec` and `0`; the test only checks presence in the cache.
let value = Arc::new(RangeScanCacheValue::new(Vec::new(), 0));
// The entry is missing before the direct put and visible afterwards.
assert!(cache.get_range_result(&key).is_none());
cache.put_range_result(key.clone(), value.clone());
assert!(cache.get_range_result(&key).is_some());
// `EnableAll` sees the entry stored in the shared cache manager.
let enable_all = CacheStrategy::EnableAll(cache.clone());
assert!(enable_all.get_range_result(&key).is_some());
// `Compaction` does not return range results even though the entry is in the cache.
let compaction = CacheStrategy::Compaction(cache.clone());
assert!(compaction.get_range_result(&key).is_none());
compaction.put_range_result(key.clone(), value.clone());
// NOTE(review): the key was already inserted directly above, so this assertion holds
// whether or not the put through `Compaction` stores anything.
assert!(cache.get_range_result(&key).is_some());
// `Disabled` never returns a range result.
let disabled = CacheStrategy::Disabled;
assert!(disabled.get_range_result(&key).is_none());
disabled.put_range_result(key.clone(), value);
// NOTE(review): as above, this only shows the earlier entry is still present.
assert!(cache.get_range_result(&key).is_some());
}
#[tokio::test]
async fn test_evict_puffin_cache_clears_all_entries() {
use std::collections::{BTreeMap, HashMap};
@@ -1122,4 +1472,45 @@ mod tests {
assert!(result_cache.get(&predicate, index_id.file_id()).is_none());
assert!(puffin_metadata_cache.get_metadata(&file_id_str).is_none());
}
fn wide_region_metadata(column_count: u32) -> RegionMetadata {
let region_id = RegionId::new(1024, 7);
let mut builder = RegionMetadataBuilder::new(region_id);
let mut primary_key = Vec::new();
for column_id in 0..column_count {
let semantic_type = if column_id < 32 {
primary_key.push(column_id);
SemanticType::Tag
} else {
SemanticType::Field
};
let mut column_schema = ColumnSchema::new(
format!("wide_column_{column_id}"),
ConcreteDataType::string_datatype(),
true,
);
column_schema
.mut_metadata()
.insert(format!("cache_key_{column_id}"), "cache_value".repeat(4));
builder.push_column_metadata(ColumnMetadata {
column_schema,
semantic_type,
column_id,
});
}
builder.push_column_metadata(ColumnMetadata {
column_schema: ColumnSchema::new(
"ts",
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
semantic_type: SemanticType::Timestamp,
column_id: column_count,
});
builder.primary_key(primary_key);
builder.build().unwrap()
}
}

View File

@@ -34,7 +34,7 @@ use store_api::storage::{FileId, RegionId};
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use crate::access_layer::TempFileCleaner;
use crate::cache::{FILE_TYPE, INDEX_TYPE};
use crate::cache::{CachedSstMeta, FILE_TYPE, INDEX_TYPE};
use crate::error::{self, OpenDalSnafu, Result};
use crate::metrics::{
CACHE_BYTES, CACHE_HIT, CACHE_MISS, WRITE_CACHE_DOWNLOAD_BYTES_TOTAL,
@@ -612,6 +612,34 @@ impl FileCache {
}
}
/// Loads the parquet metadata for `key` from the file cache and wraps it into a
/// [CachedSstMeta].
///
/// Returns `None` when the parquet metadata cannot be obtained from the file cache, or
/// when building the [CachedSstMeta] fails. The latter is logged as a warning and
/// counted as a cache miss for the key's file type.
pub(crate) async fn get_sst_meta_data(
    &self,
    key: IndexKey,
    cache_metrics: &mut MetadataCacheMetrics,
    page_index_policy: PageIndexPolicy,
) -> Option<Arc<CachedSstMeta>> {
    let file_path = self.inner.cache_file_path(key);
    let parquet_metadata = self
        .get_parquet_meta_data(key, cache_metrics, page_index_policy)
        .await?;

    match CachedSstMeta::try_new(&file_path, parquet_metadata) {
        Ok(sst_meta) => Some(Arc::new(sst_meta)),
        Err(err) => {
            CACHE_MISS
                .with_label_values(&[key.file_type.metric_label()])
                .inc();
            warn!(
                err; "Failed to decode cached parquet metadata for key {:?}",
                key
            );
            None
        }
    }
}
async fn get_reader(&self, file_path: &str) -> object_store::Result<Option<Reader>> {
if self.inner.local_store.exists(file_path).await? {
Ok(Some(self.inner.local_store.reader(file_path).await?))

View File

@@ -23,8 +23,13 @@ use object_store::ObjectStore;
use object_store::services::Fs;
use parquet::arrow::ArrowWriter;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::file::metadata::ParquetMetaData;
use parquet::file::metadata::{KeyValue, ParquetMetaData};
use parquet::file::properties::WriterProperties;
use parquet::file::statistics::Statistics;
use store_api::metadata::RegionMetadataRef;
use crate::sst::parquet::PARQUET_METADATA_KEY;
use crate::test_util::sst_util::sst_region_metadata;
/// Returns a parquet meta data.
pub(crate) fn parquet_meta() -> Arc<ParquetMetaData> {
@@ -33,13 +38,43 @@ pub(crate) fn parquet_meta() -> Arc<ParquetMetaData> {
builder.metadata().clone()
}
/// Returns parquet metadata for an SST parquet file and its decoded region metadata.
pub(crate) fn sst_parquet_meta() -> (Arc<ParquetMetaData>, RegionMetadataRef) {
let region_metadata = Arc::new(sst_region_metadata());
let file_data = parquet_file_data_with_region_metadata(&region_metadata);
let builder = ParquetRecordBatchReaderBuilder::try_new(Bytes::from(file_data)).unwrap();
(builder.metadata().clone(), region_metadata)
}
/// Builds parquet metadata for a test SST file that embeds the given `region_metadata`
/// in its key-value metadata.
pub(crate) fn sst_parquet_meta_with_region_metadata(
    region_metadata: RegionMetadataRef,
) -> Arc<ParquetMetaData> {
    let encoded = Bytes::from(parquet_file_data_with_region_metadata(&region_metadata));
    let reader_builder = ParquetRecordBatchReaderBuilder::try_new(encoded).unwrap();
    Arc::clone(reader_builder.metadata())
}
/// Writes a test parquet file to a buffer without any custom key-value metadata.
fn parquet_file_data() -> Vec<u8> {
parquet_file_data_inner(None)
}
/// Writes a test parquet file to a buffer whose key-value metadata holds the JSON form
/// of `region_metadata` under [PARQUET_METADATA_KEY].
fn parquet_file_data_with_region_metadata(region_metadata: &RegionMetadataRef) -> Vec<u8> {
    let encoded = region_metadata.to_json().unwrap();
    let entries = vec![KeyValue::new(PARQUET_METADATA_KEY.to_string(), encoded)];
    parquet_file_data_inner(Some(entries))
}
fn parquet_file_data_inner(key_value_metadata: Option<Vec<KeyValue>>) -> Vec<u8> {
let col = Arc::new(Int64Array::from_iter_values([1, 2, 3])) as ArrayRef;
let to_write = RecordBatch::try_from_iter([("col", col)]).unwrap();
let mut buffer = Vec::new();
let mut writer = ArrowWriter::try_new(&mut buffer, to_write.schema(), None).unwrap();
let props = WriterProperties::builder()
.set_key_value_metadata(key_value_metadata)
.build();
let mut writer = ArrowWriter::try_new(&mut buffer, to_write.schema(), Some(props)).unwrap();
writer.write(&to_write).unwrap();
writer.close().unwrap();

View File

@@ -244,15 +244,19 @@ impl WriteCache {
.await
.with_file_cleaner(cleaner);
let sst_info = match write_request.source {
either::Left(source) => {
let sst_info = match write_request.sst_write_format {
crate::sst::FormatType::PrimaryKey => {
writer
.write_all(source, write_request.max_sequence, write_opts)
.write_all_flat_as_primary_key(
write_request.source,
write_request.max_sequence,
write_opts,
)
.await?
}
either::Right(flat_source) => {
crate::sst::FormatType::Flat => {
writer
.write_all_flat(flat_source, write_request.max_sequence, write_opts)
.write_all_flat(write_request.source, write_request.max_sequence, write_opts)
.await?
}
};
@@ -509,12 +513,13 @@ mod tests {
use crate::cache::test_util::{assert_parquet_metadata_equal, new_fs_store};
use crate::cache::{CacheManager, CacheStrategy};
use crate::error::InvalidBatchSnafu;
use crate::read::Source;
use crate::read::FlatSource;
use crate::region::options::IndexOptions;
use crate::sst::parquet::reader::ParquetReaderBuilder;
use crate::test_util::TestEnv;
use crate::test_util::sst_util::{
new_batch_by_range, new_source, sst_file_handle_with_file_id, sst_region_metadata,
new_flat_source_from_record_batches, new_record_batch_by_range,
sst_file_handle_with_file_id, sst_region_metadata,
};
#[tokio::test]
@@ -532,21 +537,22 @@ mod tests {
.create_write_cache(local_store.clone(), ReadableSize::mb(10))
.await;
// Create Source
// Create source.
let metadata = Arc::new(sst_region_metadata());
let region_id = metadata.region_id;
let source = new_source(&[
new_batch_by_range(&["a", "d"], 0, 60),
new_batch_by_range(&["b", "f"], 0, 40),
new_batch_by_range(&["b", "h"], 100, 200),
let source = new_flat_source_from_record_batches(vec![
new_record_batch_by_range(&["a", "d"], 0, 60),
new_record_batch_by_range(&["b", "f"], 0, 40),
new_record_batch_by_range(&["b", "h"], 100, 200),
]);
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source: either::Left(source),
source,
storage: None,
max_sequence: None,
sst_write_format: Default::default(),
cache_manager: Default::default(),
index_options: IndexOptions::default(),
index_config: Default::default(),
@@ -636,19 +642,20 @@ mod tests {
// Create source
let metadata = Arc::new(sst_region_metadata());
let source = new_source(&[
new_batch_by_range(&["a", "d"], 0, 60),
new_batch_by_range(&["b", "f"], 0, 40),
new_batch_by_range(&["b", "h"], 100, 200),
let source = new_flat_source_from_record_batches(vec![
new_record_batch_by_range(&["a", "d"], 0, 60),
new_record_batch_by_range(&["b", "f"], 0, 40),
new_record_batch_by_range(&["b", "h"], 100, 200),
]);
// Write to local cache and upload sst to mock remote store
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source: either::Left(source),
source,
storage: None,
max_sequence: None,
sst_write_format: Default::default(),
cache_manager: cache_manager.clone(),
index_options: IndexOptions::default(),
index_config: Default::default(),
@@ -686,9 +693,15 @@ mod tests {
.cache(CacheStrategy::EnableAll(cache_manager.clone()))
.page_index_policy(PageIndexPolicy::Optional);
let reader = builder.build().await.unwrap().unwrap();
let cached_write_parquet_metadata = crate::cache::CachedSstMeta::try_new(
"test.sst",
Arc::unwrap_or_clone(write_parquet_metadata),
)
.unwrap()
.parquet_metadata();
// Check parquet metadata
assert_parquet_metadata_equal(write_parquet_metadata, reader.parquet_metadata());
assert_parquet_metadata_equal(cached_write_parquet_metadata, reader.parquet_metadata());
}
#[tokio::test]
@@ -715,9 +728,9 @@ mod tests {
let metadata = Arc::new(sst_region_metadata());
// Creates a source that can return an error to abort the writer.
let source = Source::Iter(Box::new(
let source = FlatSource::Iter(Box::new(
[
Ok(new_batch_by_range(&["a", "d"], 0, 60)),
Ok(new_record_batch_by_range(&["a", "d"], 0, 60)),
InvalidBatchSnafu {
reason: "Abort the writer",
}
@@ -730,9 +743,10 @@ mod tests {
let write_request = SstWriteRequest {
op_type: OperationType::Flush,
metadata,
source: either::Left(source),
source,
storage: None,
max_sequence: None,
sst_write_format: Default::default(),
cache_manager: cache_manager.clone(),
index_options: IndexOptions::default(),
index_config: Default::default(),

View File

@@ -58,10 +58,10 @@ use crate::error::{
TimeRangePredicateOverflowSnafu, TimeoutSnafu,
};
use crate::metrics::{COMPACTION_STAGE_ELAPSED, INFLIGHT_COMPACTION_COUNT};
use crate::read::BoxedRecordBatchStream;
use crate::read::projection::ProjectionMapper;
use crate::read::scan_region::{PredicateGroup, ScanInput};
use crate::read::seq_scan::SeqScan;
use crate::read::{BoxedBatchReader, BoxedRecordBatchStream};
use crate::region::options::{MergeMode, RegionOptions};
use crate::region::version::VersionControlRef;
use crate::region::{ManifestContextRef, RegionLeaderState, RegionRoleState};
@@ -828,7 +828,7 @@ pub struct SerializedCompactionOutput {
output_time_range: Option<TimestampRange>,
}
/// Builders to create [BoxedBatchReader] for compaction.
/// Builders to create [BoxedRecordBatchStream] for compaction.
struct CompactionSstReaderBuilder<'a> {
metadata: RegionMetadataRef,
sst_layer: AccessLayerRef,
@@ -841,24 +841,17 @@ struct CompactionSstReaderBuilder<'a> {
}
impl CompactionSstReaderBuilder<'_> {
/// Builds [BoxedBatchReader] that reads all SST files and yields batches in primary key order.
async fn build_sst_reader(self) -> Result<BoxedBatchReader> {
let scan_input = self.build_scan_input(false)?.with_compaction(true);
SeqScan::new(scan_input).build_reader_for_compaction().await
}
/// Builds [BoxedRecordBatchStream] that reads all SST files and yields batches in flat format for compaction.
async fn build_flat_sst_reader(self) -> Result<BoxedRecordBatchStream> {
let scan_input = self.build_scan_input(true)?.with_compaction(true);
let scan_input = self.build_scan_input()?.with_compaction(true);
SeqScan::new(scan_input)
.build_flat_reader_for_compaction()
.await
}
fn build_scan_input(self, flat_format: bool) -> Result<ScanInput> {
let mapper = ProjectionMapper::all(&self.metadata, flat_format)?;
fn build_scan_input(self) -> Result<ScanInput> {
let mapper = ProjectionMapper::all(&self.metadata, true)?;
let mut scan_input = ScanInput::new(self.sst_layer, mapper)
.with_files(self.inputs.to_vec())
.with_append_mode(self.append_mode)
@@ -868,7 +861,7 @@ impl CompactionSstReaderBuilder<'_> {
// We ignore file not found error during compaction.
.with_ignore_file_not_found(true)
.with_merge_mode(self.merge_mode)
.with_flat_format(flat_format);
.with_flat_format(true);
// This serves as a workaround of https://github.com/GreptimeTeam/greptimedb/issues/3944
// by converting time ranges into predicate.

View File

@@ -43,7 +43,7 @@ use crate::error::{
use crate::manifest::action::{RegionEdit, RegionMetaAction, RegionMetaActionList};
use crate::manifest::manager::{RegionManifestManager, RegionManifestOptions};
use crate::metrics;
use crate::read::{FlatSource, Source};
use crate::read::FlatSource;
use crate::region::options::RegionOptions;
use crate::region::version::VersionRef;
use crate::region::{ManifestContext, RegionLeaderState, RegionRoleState};
@@ -356,13 +356,8 @@ impl DefaultCompactor {
time_range: output.output_time_range,
merge_mode,
};
let source = if flat_format {
let reader = builder.build_flat_sst_reader().await?;
Either::Right(FlatSource::Stream(reader))
} else {
let reader = builder.build_sst_reader().await?;
Either::Left(Source::Reader(reader))
};
let reader = builder.build_flat_sst_reader().await?;
let source = FlatSource::Stream(reader);
let mut metrics = Metrics::new(WriteType::Compaction);
let region_metadata = compaction_region.region_metadata.clone();
let sst_infos = compaction_region
@@ -375,6 +370,11 @@ impl DefaultCompactor {
cache_manager: compaction_region.cache_manager.clone(),
storage,
max_sequence: max_sequence.map(NonZero::get),
sst_write_format: if flat_format {
FormatType::Flat
} else {
FormatType::PrimaryKey
},
index_options,
index_config,
inverted_index_config,

View File

@@ -116,6 +116,8 @@ pub struct MitoConfig {
pub page_cache_size: ReadableSize,
/// Cache size for time series selector (e.g. `last_value()`). Setting it to 0 to disable the cache.
pub selector_result_cache_size: ReadableSize,
/// Cache size for flat range scan results. Setting it to 0 to disable the cache.
pub range_result_cache_size: ReadableSize,
/// Whether to enable the write cache.
pub enable_write_cache: bool,
/// File system path for write cache dir's root, defaults to `{data_home}`.
@@ -200,6 +202,7 @@ impl Default for MitoConfig {
vector_cache_size: ReadableSize::mb(512),
page_cache_size: ReadableSize::mb(512),
selector_result_cache_size: ReadableSize::mb(512),
range_result_cache_size: ReadableSize::mb(512),
enable_write_cache: false,
write_cache_path: String::new(),
write_cache_size: ReadableSize::gb(5),
@@ -336,6 +339,7 @@ impl MitoConfig {
self.vector_cache_size = mem_cache_size;
self.page_cache_size = page_cache_size;
self.selector_result_cache_size = mem_cache_size;
self.range_result_cache_size = mem_cache_size;
self.index.adjust_buffer_and_cache_size(sys_memory);
}

View File

@@ -24,7 +24,7 @@ use crate::test_util::{
CreateRequestBuilder, TestEnv, build_rows_for_key, flush_region, put_rows, rows_schema,
};
async fn test_last_row(append_mode: bool) {
async fn test_last_row(append_mode: bool, flat_format: bool) {
let mut env = TestEnv::new().await;
let engine = env.create_engine(MitoConfig::default()).await;
let region_id = RegionId::new(1, 1);
@@ -39,9 +39,12 @@ async fn test_last_row(append_mode: bool) {
env.get_kv_backend(),
)
.await;
let request = CreateRequestBuilder::new()
.insert_option("append_mode", &append_mode.to_string())
.build();
let mut request_builder =
CreateRequestBuilder::new().insert_option("append_mode", &append_mode.to_string());
if flat_format {
request_builder = request_builder.insert_option("sst_format", "flat");
}
let request = request_builder.build();
let column_schemas = rows_schema(&request);
engine
.handle_request(region_id, RegionRequest::Create(request))
@@ -106,10 +109,20 @@ async fn test_last_row(append_mode: bool) {
#[tokio::test]
async fn test_last_row_append_mode_disabled() {
test_last_row(false).await;
test_last_row(false, false).await;
}
#[tokio::test]
async fn test_last_row_append_mode_enabled() {
test_last_row(true).await;
test_last_row(true, false).await;
}
/// Runs `test_last_row` with append mode disabled on a region created with the
/// `sst_format` option set to `"flat"`.
#[tokio::test]
async fn test_last_row_flat_format_append_mode_disabled() {
test_last_row(false, true).await;
}
/// Runs `test_last_row` with append mode enabled on a region created with the
/// `sst_format` option set to `"flat"`.
#[tokio::test]
async fn test_last_row_flat_format_append_mode_enabled() {
test_last_row(true, true).await;
}

View File

@@ -15,7 +15,9 @@
use api::v1::Rows;
use common_wal::options::{WAL_OPTIONS_KEY, WalOptions};
use store_api::region_engine::{RegionEngine, RegionRole};
use store_api::region_request::{RegionCloseRequest, RegionRequest};
use store_api::region_request::{
RegionCloseRequest, RegionOpenRequest, RegionRequest, RegionTruncateRequest,
};
use store_api::storage::{RegionId, ScanRequest};
use crate::config::MitoConfig;
@@ -168,3 +170,76 @@ async fn test_close_follower_region_skip_wal() {
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(0, total_rows);
}
/// Verifies that rows written after a full truncate are still readable once the region
/// is closed and reopened, for a region created with `WalOptions::Noop`.
#[tokio::test]
async fn test_close_region_after_truncate_skip_wal() {
common_telemetry::init_default_ut_logging();
let mut env = TestEnv::with_prefix("close-truncate-skip-wal").await;
let engine = env.create_engine(MitoConfig::default()).await;
let region_id = RegionId::new(1, 1);
// Create the region with no-op WAL options.
let mut request = CreateRequestBuilder::new().build();
let wal_options = WalOptions::Noop;
request.options.insert(
WAL_OPTIONS_KEY.to_string(),
serde_json::to_string(&wal_options).unwrap(),
);
engine
.handle_request(region_id, RegionRequest::Create(request.clone()))
.await
.unwrap();
// Truncate the whole region before writing anything.
engine
.handle_request(
region_id,
RegionRequest::Truncate(RegionTruncateRequest::All),
)
.await
.unwrap();
// After the truncate, the truncated entry id must equal the last entry id.
let region = engine.get_region(region_id).unwrap();
let version_data = region.version_control.current();
assert_eq!(
version_data.version.truncated_entry_id,
Some(version_data.last_entry_id)
);
// Write 3 rows after the truncate; they must land in the memtables.
let rows = Rows {
schema: rows_schema(&request),
rows: build_rows(0, 3),
};
put_rows(&engine, region_id, rows).await;
let region = engine.get_region(region_id).unwrap();
assert!(!region.version().memtables.is_empty());
// Close the region while it still holds data in its memtables.
// NOTE(review): presumably the close persists the memtable data since the WAL is a
// no-op — confirm against the close handler.
engine
.handle_request(region_id, RegionRequest::Close(RegionCloseRequest {}))
.await
.unwrap();
// Reopen with the same table dir and options; WAL replay is not skipped.
engine
.handle_request(
region_id,
RegionRequest::Open(RegionOpenRequest {
engine: String::new(),
table_dir: request.table_dir,
path_type: store_api::region_request::PathType::Bare,
options: request.options,
skip_wal_replay: false,
checkpoint: None,
}),
)
.await
.unwrap();
// All 3 rows written before the close must be returned by a full scan.
let stream = engine
.scan_to_stream(region_id, ScanRequest::default())
.await
.unwrap();
let batches = common_recordbatch::RecordBatches::try_collect(stream)
.await
.unwrap();
let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
assert_eq!(3, total_rows);
}

View File

@@ -616,15 +616,6 @@ pub enum Error {
location: Location,
},
#[snafu(display("Failed to read arrow record batch from parquet file {}", path))]
ArrowReader {
path: String,
#[snafu(source)]
error: ArrowError,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Column not found, column: {column}"))]
ColumnNotFound {
column: String,
@@ -1349,7 +1340,6 @@ impl ErrorExt for Error {
RegionState { .. } | UpdateManifest { .. } => StatusCode::RegionNotReady,
JsonOptions { .. } => StatusCode::InvalidArguments,
EmptyRegionDir { .. } | EmptyManifestDir { .. } => StatusCode::RegionNotFound,
ArrowReader { .. } => StatusCode::StorageUnavailable,
ConvertValue { source, .. } => source.status_code(),
ApplyBloomFilterIndex { source, .. } => source.status_code(),
InvalidPartitionExpr { source, .. } => source.status_code(),

View File

@@ -22,7 +22,6 @@ use std::time::Instant;
use common_telemetry::{debug, error, info};
use datatypes::arrow::datatypes::SchemaRef;
use either::Either;
use partition::expr::PartitionExpr;
use smallvec::{SmallVec, smallvec};
use snafu::ResultExt;
@@ -41,18 +40,14 @@ use crate::error::{
};
use crate::manifest::action::{RegionEdit, RegionMetaAction, RegionMetaActionList};
use crate::memtable::bulk::ENCODE_ROW_THRESHOLD;
use crate::memtable::{
BoxedRecordBatchIterator, EncodedRange, IterBuilder, MemtableRanges, RangesOptions,
};
use crate::memtable::{BoxedRecordBatchIterator, EncodedRange, MemtableRanges, RangesOptions};
use crate::metrics::{
FLUSH_BYTES_TOTAL, FLUSH_ELAPSED, FLUSH_FAILURE_TOTAL, FLUSH_FILE_TOTAL, FLUSH_REQUESTS_TOTAL,
INFLIGHT_FLUSH_COUNT,
};
use crate::read::dedup::{DedupReader, LastNonNull, LastRow};
use crate::read::FlatSource;
use crate::read::flat_dedup::{FlatDedupIterator, FlatLastNonNull, FlatLastRow};
use crate::read::flat_merge::FlatMergeIterator;
use crate::read::merge::MergeReaderBuilder;
use crate::read::{FlatSource, Source};
use crate::region::options::{IndexOptions, MergeMode, RegionOptions};
use crate::region::version::{VersionControlData, VersionControlRef, VersionRef};
use crate::region::{ManifestContextRef, RegionLeaderState, RegionRoleState, parse_partition_expr};
@@ -62,8 +57,10 @@ use crate::request::{
};
use crate::schedule::scheduler::{Job, SchedulerRef};
use crate::sst::file::FileMeta;
use crate::sst::parquet::{DEFAULT_READ_BATCH_SIZE, DEFAULT_ROW_GROUP_SIZE, SstInfo, WriteOptions};
use crate::sst::{FlatSchemaOptions, to_flat_sst_arrow_schema};
use crate::sst::parquet::{
DEFAULT_READ_BATCH_SIZE, DEFAULT_ROW_GROUP_SIZE, SstInfo, WriteOptions, flat_format,
};
use crate::sst::{FlatSchemaOptions, FormatType, to_flat_sst_arrow_schema};
use crate::worker::WorkerListener;
/// Global write buffer (memtable) manager.
@@ -480,78 +477,29 @@ impl RegionFlushTask {
// the counter may have more series than the actual series count.
series_count += memtable_series_count;
if mem_ranges.is_record_batch() {
let flush_start = Instant::now();
let FlushFlatMemResult {
num_encoded,
num_sources,
results,
} = self
.flush_flat_mem_ranges(version, &write_opts, mem_ranges)
.await?;
encoded_part_count += num_encoded;
for (source_idx, result) in results.into_iter().enumerate() {
let (max_sequence, ssts_written, metrics) = result?;
if ssts_written.is_empty() {
// No data written.
continue;
}
common_telemetry::debug!(
"Region {} flush one memtable {} {}/{}, metrics: {:?}",
self.region_id,
memtable_id,
source_idx,
num_sources,
metrics
);
flush_metrics = flush_metrics.merge(metrics);
file_metas.extend(ssts_written.into_iter().map(|sst_info| {
flushed_bytes += sst_info.file_size;
Self::new_file_meta(
self.region_id,
max_sequence,
sst_info,
partition_expr.clone(),
)
}));
}
common_telemetry::debug!(
"Region {} flush {} memtables for {}, num_mem_ranges: {}, num_encoded: {}, num_rows: {}, flush_cost: {:?}, compact_cost: {:?}",
self.region_id,
num_sources,
memtable_id,
num_mem_ranges,
num_encoded,
num_mem_rows,
flush_start.elapsed(),
compact_cost,
);
} else {
let max_sequence = mem_ranges.max_sequence();
let source = memtable_source(mem_ranges, &version.options).await?;
// Flush to level 0.
let source = Either::Left(source);
let write_request = self.new_write_request(version, max_sequence, source);
let mut metrics = Metrics::new(WriteType::Flush);
let ssts_written = self
.access_layer
.write_sst(write_request, &write_opts, &mut metrics)
.await?;
FLUSH_FILE_TOTAL.inc_by(ssts_written.len() as u64);
let flush_start = Instant::now();
let FlushFlatMemResult {
num_encoded,
num_sources,
results,
} = self
.flush_flat_mem_ranges(version, &write_opts, mem_ranges)
.await?;
encoded_part_count += num_encoded;
for (source_idx, result) in results.into_iter().enumerate() {
let (max_sequence, ssts_written, metrics) = result?;
if ssts_written.is_empty() {
// No data written.
continue;
}
debug!(
"Region {} flush one memtable, num_mem_ranges: {}, num_rows: {}, metrics: {:?}",
self.region_id, num_mem_ranges, num_mem_rows, metrics
common_telemetry::debug!(
"Region {} flush one memtable {} {}/{}, metrics: {:?}",
self.region_id,
memtable_id,
source_idx,
num_sources,
metrics
);
flush_metrics = flush_metrics.merge(metrics);
@@ -565,7 +513,19 @@ impl RegionFlushTask {
partition_expr.clone(),
)
}));
};
}
common_telemetry::debug!(
"Region {} flush {} memtables for {}, num_mem_ranges: {}, num_encoded: {}, num_rows: {}, flush_cost: {:?}, compact_cost: {:?}",
self.region_id,
num_sources,
memtable_id,
num_mem_ranges,
num_encoded,
num_mem_rows,
flush_start.elapsed(),
compact_cost,
);
}
Ok(DoFlushMemtablesResult {
@@ -587,16 +547,17 @@ impl RegionFlushTask {
&version.metadata,
&FlatSchemaOptions::from_encoding(version.metadata.primary_key_encoding),
);
let field_column_start =
flat_format::field_column_start(&version.metadata, batch_schema.fields().len());
let flat_sources = memtable_flat_sources(
batch_schema,
mem_ranges,
&version.options,
version.metadata.primary_key.len(),
field_column_start,
)?;
let mut tasks = Vec::with_capacity(flat_sources.encoded.len() + flat_sources.sources.len());
let num_encoded = flat_sources.encoded.len();
for (source, max_sequence) in flat_sources.sources {
let source = Either::Right(source);
let write_request = self.new_write_request(version, max_sequence, source);
let access_layer = self.access_layer.clone();
let write_opts = write_opts.clone();
@@ -667,8 +628,13 @@ impl RegionFlushTask {
&self,
version: &VersionRef,
max_sequence: u64,
source: Either<Source, FlatSource>,
source: FlatSource,
) -> SstWriteRequest {
let flat_format = version
.options
.sst_format
.map(|f| f == FormatType::Flat)
.unwrap_or(self.engine_config.default_experimental_flat_format);
SstWriteRequest {
op_type: OperationType::Flush,
metadata: version.metadata.clone(),
@@ -676,6 +642,11 @@ impl RegionFlushTask {
cache_manager: self.cache_manager.clone(),
storage: version.options.storage.clone(),
max_sequence: Some(max_sequence),
sst_write_format: if flat_format {
FormatType::Flat
} else {
FormatType::PrimaryKey
},
index_options: self.index_options.clone(),
index_config: self.engine_config.index.clone(),
inverted_index_config: self.engine_config.inverted_index.clone(),
@@ -722,41 +693,6 @@ struct DoFlushMemtablesResult {
flush_metrics: Metrics,
}
/// Returns a [Source] for the given memtable.
async fn memtable_source(mem_ranges: MemtableRanges, options: &RegionOptions) -> Result<Source> {
let source = if mem_ranges.ranges.len() == 1 {
let only_range = mem_ranges.ranges.into_values().next().unwrap();
let iter = only_range.build_iter()?;
Source::Iter(iter)
} else {
// todo(hl): a workaround since sync version of MergeReader is wip.
let sources = mem_ranges
.ranges
.into_values()
.map(|r| r.build_iter().map(Source::Iter))
.collect::<Result<Vec<_>>>()?;
let merge_reader = MergeReaderBuilder::from_sources(sources).build().await?;
let maybe_dedup = if options.append_mode {
// no dedup in append mode
Box::new(merge_reader) as _
} else {
// dedup according to merge mode
match options.merge_mode.unwrap_or(MergeMode::LastRow) {
MergeMode::LastRow => {
Box::new(DedupReader::new(merge_reader, LastRow::new(false), None)) as _
}
MergeMode::LastNonNull => Box::new(DedupReader::new(
merge_reader,
LastNonNull::new(false),
None,
)) as _,
}
};
Source::Reader(maybe_dedup)
};
Ok(source)
}
struct FlatSources {
sources: SmallVec<[(FlatSource, SequenceNumber); 4]>,
encoded: SmallVec<[(EncodedRange, SequenceNumber); 4]>,

View File

@@ -28,6 +28,7 @@ use mito_codec::key_values::KeyValue;
pub use mito_codec::key_values::KeyValues;
use mito_codec::row_converter::{PrimaryKeyCodec, build_primary_key_codec};
use serde::{Deserialize, Serialize};
use snafu::ensure;
use store_api::metadata::RegionMetadataRef;
use store_api::storage::{ColumnId, SequenceNumber, SequenceRange};
@@ -231,10 +232,17 @@ impl MemtableRanges {
impl IterBuilder for MemtableRanges {
fn build(&self, _metrics: Option<MemScanMetrics>) -> Result<BoxedBatchIterator> {
UnsupportedOperationSnafu {
err_msg: "MemtableRanges does not support build iterator",
}
.fail()
ensure!(
self.ranges.len() == 1,
UnsupportedOperationSnafu {
err_msg: format!(
"Building an iterator from MemtableRanges expects 1 range, but got {}",
self.ranges.len()
),
}
);
self.ranges.values().next().unwrap().build_iter()
}
fn is_record_batch(&self) -> bool {
@@ -256,20 +264,6 @@ pub trait Memtable: Send + Sync + fmt::Debug {
/// Writes an encoded batch into the memtable.
fn write_bulk(&self, part: crate::memtable::bulk::part::BulkPart) -> Result<()>;
/// Scans the memtable.
/// `projection` selects columns to read, `None` means reading all columns.
/// `filters` are the predicates to be pushed down to memtable.
///
/// # Note
/// This method should only be used for tests.
#[cfg(any(test, feature = "test"))]
fn iter(
&self,
projection: Option<&[ColumnId]>,
predicate: Option<table::predicate::Predicate>,
sequence: Option<SequenceRange>,
) -> Result<BoxedBatchIterator>;
/// Returns the ranges in the memtable.
///
/// The returned map contains the range id and the range after applying the predicate.
@@ -543,11 +537,15 @@ pub trait IterBuilder: Send + Sync {
}
/// Returns the record batch iterator to read the range.
/// ## Note
/// Implementations should ensure the iterator yields data within given time range.
fn build_record_batch(
&self,
time_range: Option<(Timestamp, Timestamp)>,
metrics: Option<MemScanMetrics>,
) -> Result<BoxedRecordBatchIterator> {
let _metrics = metrics;
let _ = time_range;
UnsupportedOperationSnafu {
err_msg: "Record batch iterator is not supported by this memtable",
}
@@ -706,7 +704,7 @@ impl MemtableRange {
metrics: Option<MemScanMetrics>,
) -> Result<BoxedRecordBatchIterator> {
if self.context.builder.is_record_batch() {
return self.context.builder.build_record_batch(metrics);
return self.context.builder.build_record_batch(time_range, metrics);
}
if let Some(context) = self.context.batch_to_record_batch.as_ref() {

Some files were not shown because too many files have changed in this diff Show More