diff --git a/.env.example b/.env.example new file mode 100644 index 0000000000..9117d0d4c6 --- /dev/null +++ b/.env.example @@ -0,0 +1,4 @@ +# Settings for s3 test +GT_S3_BUCKET=S3 bucket +GT_S3_ACCESS_KEY_ID=S3 access key id +GT_S3_ACCESS_KEY=S3 secret access key diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 53629f24d0..1c62e4ad41 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -13,7 +13,7 @@ Please explain IN DETAIL what the changes are in this PR and why they are needed ## Checklist -- [] I have written the necessary rustdoc comments. -- [] I have added the necessary unit tests and integration tests. +- [ ] I have written the necessary rustdoc comments. +- [ ] I have added the necessary unit tests and integration tests. ## Refer to a related PR or issue link (optional) diff --git a/.github/workflows/develop.yml b/.github/workflows/develop.yml index 33f0dadc98..2cba1fa5d2 100644 --- a/.github/workflows/develop.yml +++ b/.github/workflows/develop.yml @@ -26,6 +26,13 @@ env: RUST_TOOLCHAIN: nightly-2022-07-14 jobs: + typos: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: crate-ci/typos@v1.0.4 + check: name: Check if: github.event.pull_request.draft == false @@ -42,6 +49,23 @@ jobs: - name: Run cargo check run: cargo check --workspace --all-targets + toml: + name: Toml Check + if: github.event.pull_request.draft == false + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.RUST_TOOLCHAIN }} + - name: Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install taplo + run: cargo install taplo-cli --version ^0.8 --locked + - name: Run taplo + run: taplo format --check --option "indent_string= " + # Use coverage to run test. # test: # name: Test Suite diff --git a/.github/workflows/doc-issue.yml b/.github/workflows/doc-issue.yml new file mode 100644 index 0000000000..a06102a479 --- /dev/null +++ b/.github/workflows/doc-issue.yml @@ -0,0 +1,25 @@ +name: Create Issue in docs repo on doc related changes + +on: + issues: + types: + - labeled + pull_request_target: + types: + - labeled + +jobs: + doc_issue: + if: github.event.label.name == 'doc update required' + runs-on: ubuntu-latest + steps: + - name: create an issue in doc repo + uses: dacbd/create-issue-action@main + with: + owner: GreptimeTeam + repo: docs + token: ${{ secrets.DOCS_REPO_TOKEN }} + title: Update docs for ${{ github.event.issue.title || github.event.pull_request.title }} + body: | + A document change request is generated from + ${{ github.event.issue.html_url || github.event.pull_request.html_url }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1ff1c8026d..5a98be2201 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,9 +3,8 @@ on: tags: - "v*.*.*" schedule: - # At 00:00 Everyday - # https://crontab.guru/every-day-at-midnight - - cron: '0 0 * * *' + # At 00:00 on Monday. + - cron: '0 0 * * 1' workflow_dispatch: name: Release @@ -14,7 +13,10 @@ env: RUST_TOOLCHAIN: nightly-2022-07-14 # FIXME(zyy17): Would be better to use `gh release list -L 1 | cut -f 3` to get the latest release version tag, but for a long time, we will stay at 'v0.1.0-alpha-*'. - NIGHTLY_BUILD_VERSION_PREFIX: v0.1.0-alpha + SCHEDULED_BUILD_VERSION_PREFIX: v0.1.0-alpha + + # In the future, we can change SCHEDULED_PERIOD to nightly. 
+ SCHEDULED_PERIOD: weekly jobs: build: @@ -113,25 +115,25 @@ jobs: - name: Download artifacts uses: actions/download-artifact@v3 - - name: Configure nightly build version # the version would be ${NIGHTLY_BUILD_VERSION_PREFIX}-YYYYMMDD-nightly, like v0.1.0-alpha-20221119-nightly. + - name: Configure scheduled build version # the version would be ${SCHEDULED_BUILD_VERSION_PREFIX}-YYYYMMDD-${SCHEDULED_PERIOD}, like v0.1.0-alpha-20221119-weekly. shell: bash if: github.event_name == 'schedule' run: | buildTime=`date "+%Y%m%d"` - NIGHTLY_VERSION=${{ env.NIGHTLY_BUILD_VERSION_PREFIX }}-$buildTime-nightly - echo "NIGHTLY_VERSION=${NIGHTLY_VERSION}" >> $GITHUB_ENV + SCHEDULED_BUILD_VERSION=${{ env.SCHEDULED_BUILD_VERSION_PREFIX }}-$buildTime-${{ env.SCHEDULED_PERIOD }} + echo "SCHEDULED_BUILD_VERSION=${SCHEDULED_BUILD_VERSION}" >> $GITHUB_ENV - - name: Create nightly git tag + - name: Create scheduled build git tag if: github.event_name == 'schedule' run: | - git tag ${{ env.NIGHTLY_VERSION }} + git tag ${{ env.SCHEDULED_BUILD_VERSION }} - - name: Publish nightly release # configure the different release title and tags. + - name: Publish scheduled release # configure the different release title and tags. uses: softprops/action-gh-release@v1 if: github.event_name == 'schedule' with: - name: "Release ${{ env.NIGHTLY_VERSION }}" - tag_name: ${{ env.NIGHTLY_VERSION }} + name: "Release ${{ env.SCHEDULED_BUILD_VERSION }}" + tag_name: ${{ env.SCHEDULED_BUILD_VERSION }} generate_release_notes: true files: | **/greptime-* @@ -189,13 +191,13 @@ jobs: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Configure nightly build image tag # the tag would be ${NIGHTLY_BUILD_VERSION_PREFIX}-YYYYMMDD-nightly + - name: Configure scheduled build image tag # the tag would be ${SCHEDULED_BUILD_VERSION_PREFIX}-YYYYMMDD-${SCHEDULED_PERIOD} shell: bash if: github.event_name == 'schedule' run: | buildTime=`date "+%Y%m%d"` - NIGHTLY_VERSION=${{ env.NIGHTLY_BUILD_VERSION_PREFIX }}-$buildTime-nightly - echo "IMAGE_TAG=${NIGHTLY_VERSION:1}" >> $GITHUB_ENV + SCHEDULED_BUILD_VERSION=${{ env.SCHEDULED_BUILD_VERSION_PREFIX }}-$buildTime-${{ env.SCHEDULED_PERIOD }} + echo "IMAGE_TAG=${SCHEDULED_BUILD_VERSION:1}" >> $GITHUB_ENV - name: Configure tag # If the release tag is v0.1.0, then the image version tag will be 0.1.0. 
shell: bash diff --git a/.gitignore b/.gitignore index 84cdd03cf0..1cb44bbdf1 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ debug/ # JetBrains IDE config directory .idea/ +*.iml # VSCode IDE config directory .vscode/ @@ -31,3 +32,6 @@ logs/ # Benchmark dataset benchmarks/data + +# dotenv +.env diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 83a064a734..2dbc9c6e61 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: rev: e6a795bc6b2c0958f9ef52af4863bbd7cc17238f hooks: - id: cargo-sort - args: ["--workspace", "--print"] + args: ["--workspace"] - repo: https://github.com/doublify/pre-commit-rust rev: v1.0 diff --git a/Cargo.lock b/Cargo.lock index 3a2a17a52f..9b442cd818 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -35,7 +35,20 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" dependencies = [ - "getrandom", + "getrandom 0.2.7", + "once_cell", + "version_check", +] + +[[package]] +name = "ahash" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf6ccdb167abbf410dcb915cabd428929d7f6a04980b54a11f26a39f1c7f7107" +dependencies = [ + "cfg-if", + "const-random", + "getrandom 0.2.7", "once_cell", "version_check", ] @@ -64,11 +77,11 @@ dependencies = [ [[package]] name = "aide" -version = "0.6.2" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c350c121222a7d8cc7d2efad0856ddd3a903eb62f0f5e982efb27d811c94c" +checksum = "befdff0b4683a0824fc8719ce639a252d9d62cd89c8d0004c39e2417128c1eb8" dependencies = [ - "axum 0.6.0-rc.2", + "axum 0.6.1", "bytes", "cfg-if", "http", @@ -174,6 +187,12 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4c527152e37cf757a3f78aae5a06fbeefdb07ccc535c980a3208ee3060dd544" +[[package]] +name = "arrayvec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b62fc65de8e4e7f52534fb52b0f3ed04746ae267519eef2a83941e8085068b" + [[package]] name = "arrayvec" version = "0.7.2" @@ -436,12 +455,12 @@ dependencies = [ [[package]] name = "axum" -version = "0.6.0-rc.2" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2628a243073c55aef15a1c1fe45c87f21b84f9e89ca9e7b262a180d3d03543d" +checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48" dependencies = [ "async-trait", - "axum-core 0.3.0-rc.2", + "axum-core 0.3.0", "bitflags", "bytes", "futures-util", @@ -449,13 +468,15 @@ dependencies = [ "http-body", "hyper", "itoa 1.0.3", - "matchit 0.6.0", + "matchit 0.7.0", "memchr", "mime", "percent-encoding", "pin-project-lite", + "rustversion", "serde", "serde_json", + "serde_path_to_error", "serde_urlencoded", "sync_wrapper", "tokio", @@ -483,9 +504,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.3.0-rc.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "473bd0762170028bb6b5068be9e97de2a9f0af3bf2084498d840498f47194d3d" +checksum = "79b8558f5a0581152dc94dcd289132a1d377494bdeafcd41869b3258e3e2ad92" dependencies = [ "async-trait", "bytes", @@ -493,15 +514,16 @@ dependencies = [ "http", "http-body", "mime", + "rustversion", "tower-layer", "tower-service", ] [[package]] name = "axum-macros" -version = "0.3.0-rc.1" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "247a599903eb2e02abbaf2facc6396140df7af6dcc84e64ce3b71d117702fa22" +checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621" dependencies = [ "heck 0.4.0", "proc-macro2", @@ -533,7 +555,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" dependencies = [ "futures-core", - "getrandom", + "getrandom 0.2.7", "instant", "pin-project-lite", "rand 0.8.5", @@ -577,7 +599,7 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" name = "benchmarks" version = "0.1.0" dependencies = [ - "arrow", + "arrow 10.0.0", "clap 4.0.18", "client", "indicatif", @@ -671,6 +693,17 @@ dependencies = [ "digest", ] +[[package]] +name = "blake2b_simd" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afa748e348ad3be8263be728124b24a24f268266f6f5d58af9d75f6a40b5c587" +dependencies = [ + "arrayref", + "arrayvec 0.5.2", + "constant_time_eq", +] + [[package]] name = "blake3" version = "1.3.1" @@ -678,7 +711,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a08e53fc5a564bb15bfe6fae56bd71522205f1f91893f9c0116edad6496c183f" dependencies = [ "arrayref", - "arrayvec", + "arrayvec 0.7.2", "cc", "cfg-if", "constant_time_eq", @@ -727,6 +760,17 @@ dependencies = [ "serde", ] +[[package]] +name = "build-data" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a94f9f7aab679acac7ce29ba5581c00d3971a861c3b501c5bb74c3ba0026d90" +dependencies = [ + "chrono", + "safe-lock", + "safe-regex", +] + [[package]] name = "bumpalo" version = "3.11.0" @@ -855,7 +899,6 @@ dependencies = [ "meta-client", "mito", "object-store", - "opendal", "regex", "serde", "serde_json", @@ -1072,7 +1115,7 @@ dependencies = [ "common-base", "common-error", "common-grpc", - "common-insert", + "common-grpc-expr", "common-query", "common-recordbatch", "common-time", @@ -1116,6 +1159,7 @@ dependencies = [ name = "cmd" version = "0.1.0" dependencies = [ + "build-data", "clap 3.2.22", "common-error", "common-telemetry", @@ -1143,6 +1187,18 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "comfy-table" +version = "6.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1090f39f45786ec6dc6286f8ea9c75d0a7ef0a0d3cda674cef0c3af7b307fbc2" +dependencies = [ + "crossterm", + "strum 0.24.1", + "strum_macros 0.24.3", + "unicode-width", +] + [[package]] name = "common-base" version = "0.1.0" @@ -1191,7 +1247,7 @@ dependencies = [ "common-function-macro", "common-query", "common-time", - "datafusion-common", + "datafusion-common 7.0.0", "datatypes", "libc", "num", @@ -1240,13 +1296,15 @@ dependencies = [ ] [[package]] -name = "common-insert" +name = "common-grpc-expr" version = "0.1.0" dependencies = [ "api", "async-trait", "common-base", + "common-catalog", "common-error", + "common-grpc", "common-query", "common-telemetry", "common-time", @@ -1265,7 +1323,7 @@ dependencies = [ "common-recordbatch", "common-time", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datafusion-expr", "datatypes", "snafu", @@ -1279,7 +1337,7 @@ version = "0.1.0" dependencies = [ "common-error", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datatypes", "futures", "paste", @@ -1612,6 +1670,31 @@ dependencies = [ "once_cell", ] +[[package]] +name = "crossterm" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" +dependencies = [ + "bitflags", + "crossterm_winapi", + "libc", + "mio", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ae1b35a484aa10e07fe0638d02301c5ad24de82d310ccbd2f3693da5f09bf1c" +dependencies = [ + "winapi", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -1799,7 +1882,7 @@ dependencies = [ "blake2", "blake3", "chrono", - "datafusion-common", + "datafusion-common 7.0.0", "datafusion-expr", "datafusion-row", "half 2.1.0", @@ -1847,23 +1930,24 @@ version = "0.1.0" dependencies = [ "api", "async-trait", - "axum 0.6.0-rc.2", + "axum 0.6.1", "axum-macros", "axum-test-helper", + "backon", "catalog", "client", "common-base", "common-catalog", "common-error", "common-grpc", - "common-insert", + "common-grpc-expr", "common-query", "common-recordbatch", "common-runtime", "common-telemetry", "common-time", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datatypes", "frontend", "futures", @@ -1879,6 +1963,7 @@ dependencies = [ "serde", "serde_json", "servers", + "session", "snafu", "sql", "storage", @@ -1901,7 +1986,26 @@ dependencies = [ "common-base", "common-error", "common-time", - "datafusion-common", + "datafusion-common 7.0.0", + "enum_dispatch", + "num", + "num-traits", + "ordered-float 3.1.0", + "paste", + "serde", + "serde_json", + "snafu", +] + +[[package]] +name = "datatypes2" +version = "0.1.0" +dependencies = [ + "arrow 26.0.0", + "common-base", + "common-error", + "common-time", + "datafusion-common 14.0.0", "enum_dispatch", "num", "num-traits", @@ -1965,6 +2069,17 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fd78930633bd1c6e35c4b42b1df7b0cbc6bc191146e512bb3bedf243fcc3901" +dependencies = [ + "libc", + "redox_users 0.3.5", + "winapi", +] + [[package]] name = "dirs" version = "4.0.0" @@ -1991,7 +2106,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" dependencies = [ "libc", - "redox_users", + "redox_users 0.4.3", "winapi", ] @@ -2002,7 +2117,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" dependencies = [ "libc", - "redox_users", + "redox_users 0.4.3", "winapi", ] @@ -2018,6 +2133,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "dotenv" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f" + [[package]] name = "dyn-clone" version = "1.0.9" @@ -2179,6 +2300,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "flatbuffers" +version = "22.9.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce016b9901aef3579617931fbb2df8fc9a9f7cb95a16eb8acc8148209bb9e70" +dependencies = [ + "bitflags", + "thiserror", +] + [[package]] name = "flate2" version = "1.0.24" @@ -2228,14 +2359,14 @@ dependencies = [ "common-catalog", "common-error", "common-grpc", - "common-insert", + "common-grpc-expr", "common-query", "common-recordbatch", 
"common-runtime", "common-telemetry", "common-time", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datafusion-expr", "datanode", "datatypes", @@ -2248,9 +2379,11 @@ dependencies = [ "openmetrics-parser", "prost 0.11.0", "query", + "rustls", "serde", "serde_json", "servers", + "session", "snafu", "sql", "sqlparser 0.15.0", @@ -2462,6 +2595,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.7" @@ -2647,6 +2791,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "humantime-serde" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57a3db5ea5923d99402c94e9feb261dc5ee9b4efa158b0315f788cf549cc200c" +dependencies = [ + "humantime", + "serde", +] + [[package]] name = "hyper" version = "0.14.20" @@ -3040,9 +3194,9 @@ dependencies = [ [[package]] name = "lru" -version = "0.7.8" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" +checksum = "b6e8aaa3f231bb4bd57b84b2d5dc3ae7f350265df8aa96492e0bc394a1571909" dependencies = [ "hashbrown", ] @@ -3114,9 +3268,9 @@ checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb" [[package]] name = "matchit" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dfc802da7b1cf80aefffa0c7b2f77247c8b32206cc83c270b61264f5b360a80" +checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" [[package]] name = "matrixmultiply" @@ -3319,7 +3473,7 @@ dependencies = [ "common-telemetry", "common-time", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datatypes", "futures", "log-store", @@ -3388,8 +3542,9 @@ dependencies = [ [[package]] name = "mysql_async" -version = "0.30.0" -source = "git+https://github.com/Morranto/mysql_async.git?rev=127b538#127b538e20b880e2855204e480fa784d8d08e150" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8fbd756177cfa8248baa7c5f555b9446349822bb94810c22336ec7597a72652" dependencies = [ "bytes", "crossbeam", @@ -3406,7 +3561,7 @@ dependencies = [ "percent-encoding", "pin-project", "rustls", - "rustls-pemfile 0.2.1", + "rustls-pemfile", "serde", "serde_json", "socket2", @@ -3663,6 +3818,7 @@ dependencies = [ "opendal", "tempdir", "tokio", + "uuid", ] [[package]] @@ -3699,9 +3855,9 @@ checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" [[package]] name = "opendal" -version = "0.20.1" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63b17b778cf11d10fbaaae4a5a0f82d5c6f527f96a9e4843f4e2dd6cd0dbe580" +checksum = "8c9be1e30ca12b989107a5ee5bb75468a7f538059e43255ccd4743089b42aeeb" dependencies = [ "anyhow", "async-compat", @@ -3715,6 +3871,7 @@ dependencies = [ "http", "log", "md-5", + "metrics", "once_cell", "parking_lot", "percent-encoding", @@ -3724,9 +3881,9 @@ dependencies = [ "reqwest", "serde", "serde_json", - "thiserror", "time 0.3.14", "tokio", + "tracing", "ureq", ] @@ -3743,16 +3900,18 @@ 
dependencies = [ [[package]] name = "opensrv-mysql" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4c24c12fd688cb5aa5b1a54c6ccb2e30fb9b5132debb0e89fcb432b3f73db8f" +checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc" dependencies = [ "async-trait", "byteorder", "chrono", "mysql_common", "nom", + "pin-project-lite", "tokio", + "tokio-rustls", ] [[package]] @@ -3873,7 +4032,7 @@ dependencies = [ "cfg-if", "libc", "petgraph", - "redox_syscall", + "redox_syscall 0.2.16", "smallvec", "thread-id", "windows-sys", @@ -4246,6 +4405,17 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "prettydiff" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b6176190f1637d46034820b82fbe758727ccb40da9c9fc2255d695eb05ea29c" +dependencies = [ + "ansi_term", + "prettytable-rs", + "structopt", +] + [[package]] name = "prettyplease" version = "0.1.19" @@ -4256,6 +4426,20 @@ dependencies = [ "syn", ] +[[package]] +name = "prettytable-rs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd04b170004fa2daccf418a7f8253aaf033c27760b5f225889024cf66d7ac2e" +dependencies = [ + "atty", + "csv", + "encode_unicode", + "lazy_static", + "term", + "unicode-width", +] + [[package]] name = "proc-macro-crate" version = "1.2.1" @@ -4456,7 +4640,7 @@ dependencies = [ "common-telemetry", "common-time", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datafusion-physical-expr", "datatypes", "format_num", @@ -4470,6 +4654,7 @@ dependencies = [ "rand 0.8.5", "serde", "serde_json", + "session", "snafu", "sql", "statrs", @@ -4570,7 +4755,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.7", ] [[package]] @@ -4631,6 +4816,12 @@ dependencies = [ "rand_core 0.3.1", ] +[[package]] +name = "redox_syscall" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4640,14 +4831,25 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_users" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de0737333e7a9502c789a36d7c7fa6092a49895d4faa31ca5df163857ded2e9d" +dependencies = [ + "getrandom 0.1.16", + "redox_syscall 0.1.57", + "rust-argon2", +] + [[package]] name = "redox_users" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ - "getrandom", - "redox_syscall", + "getrandom 0.2.7", + "redox_syscall 0.2.16", "thiserror", ] @@ -4688,15 +4890,15 @@ dependencies = [ [[package]] name = "reqsign" -version = "0.6.4" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22524be78041476bf8673f2720fa1000f34432b384d9ad5846b024569a4b150" +checksum = "d34ea360414ee77ddab3a8360a0c241fc77ab5e27892dcde1d2cfcc29d4e0f55" dependencies = [ "anyhow", "backon", "base64", "bytes", - "dirs", + "dirs 4.0.0", "form_urlencoded", "hex", "hmac", @@ -4741,7 +4943,7 @@ dependencies = [ 
"pin-project-lite", "rustls", "rustls-native-certs", - "rustls-pemfile 1.0.1", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", @@ -4804,6 +5006,18 @@ dependencies = [ "serde", ] +[[package]] +name = "rust-argon2" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b18820d944b33caa75a71378964ac46f58517c92b6ae5f762636247c09e78fb" +dependencies = [ + "base64", + "blake2b_simd", + "constant_time_eq", + "crossbeam-utils", +] + [[package]] name = "rust-ini" version = "0.18.0" @@ -4820,7 +5034,7 @@ version = "1.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee9164faf726e4f3ece4978b25ca877ddc6802fa77f38cdccb32c7f805ecd70c" dependencies = [ - "arrayvec", + "arrayvec 0.7.2", "num-traits", "serde", ] @@ -4862,9 +5076,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c" dependencies = [ "log", "ring", @@ -4879,20 +5093,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50" dependencies = [ "openssl-probe", - "rustls-pemfile 1.0.1", + "rustls-pemfile", "schannel", "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5eebeaeb360c87bfb72e84abdb3447159c0eaececf1bef2aecd65a8be949d1c9" -dependencies = [ - "base64", -] - [[package]] name = "rustls-pemfile" version = "1.0.1" @@ -5141,6 +5346,59 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" +[[package]] +name = "safe-lock" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "077d73db7973cccf63eb4aff1e5a34dc2459baa867512088269ea5f2f4253c90" + +[[package]] +name = "safe-proc-macro2" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "814c536dcd27acf03296c618dab7ad62d28e70abd7ba41d3f34a2ce707a2c666" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "safe-quote" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77e530f7831f3feafcd5f1aae406ac205dd998436b4007c8e80f03eca78a88f7" +dependencies = [ + "safe-proc-macro2", +] + +[[package]] +name = "safe-regex" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15289bf322e0673d52756a18194167f2378ec1a15fe884af6e2d2cb934822b0" +dependencies = [ + "safe-regex-macro", +] + +[[package]] +name = "safe-regex-compiler" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fba76fae590a2aa665279deb1f57b5098cbace01a0c5e60e262fcf55f7c51542" +dependencies = [ + "safe-proc-macro2", + "safe-quote", +] + +[[package]] +name = "safe-regex-macro" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96c2e96b5c03f158d1b16ba79af515137795f4ad4e8de3f790518aae91f1d127" +dependencies = [ + "safe-proc-macro2", + "safe-regex-compiler", +] + [[package]] name = "same-file" version = "1.0.6" @@ -5221,7 +5479,7 @@ dependencies = [ "common-time", "console", "datafusion", - "datafusion-common", + 
"datafusion-common 7.0.0", "datafusion-expr", "datafusion-physical-expr", "datatypes", @@ -5239,6 +5497,7 @@ dependencies = [ "rustpython-parser", "rustpython-vm", "serde", + "session", "snafu", "sql", "storage", @@ -5348,6 +5607,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "184c643044780f7ceb59104cef98a5a6f12cb2288a7bc701ab93a362b49fd47d" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -5367,7 +5635,7 @@ dependencies = [ "aide", "api", "async-trait", - "axum 0.6.0-rc.2", + "axum 0.6.1", "axum-macros", "axum-test-helper", "bytes", @@ -5384,6 +5652,7 @@ dependencies = [ "datatypes", "futures", "hex", + "humantime-serde", "hyper", "influxdb_line_protocol", "metrics", @@ -5397,15 +5666,20 @@ dependencies = [ "query", "rand 0.8.5", "regex", + "rustls", + "rustls-pemfile", "schemars", "script", "serde", "serde_json", + "session", "snafu", "snap", "table", "tokio", "tokio-postgres", + "tokio-postgres-rustls", + "tokio-rustls", "tokio-stream", "tokio-test", "tonic", @@ -5414,6 +5688,14 @@ dependencies = [ "tower-http", ] +[[package]] +name = "session" +version = "0.1.0" +dependencies = [ + "arc-swap", + "common-telemetry", +] + [[package]] name = "sha-1" version = "0.10.0" @@ -5462,6 +5744,27 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" +[[package]] +name = "signal-hook" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a253b5e89e2698464fc26b545c9edceb338e18a89effeeecfea192c3025be29d" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -5764,6 +6067,30 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "structopt" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" +dependencies = [ + "clap 2.34.0", + "lazy_static", + "structopt-derive", +] + +[[package]] +name = "structopt-derive" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" +dependencies = [ + "heck 0.3.3", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "strum" version = "0.24.1" @@ -5801,7 +6128,9 @@ dependencies = [ "catalog", "common-catalog", "common-error", + "common-telemetry", "datafusion", + "datafusion-expr", "datatypes", "futures", "prost 0.9.0", @@ -5866,7 +6195,7 @@ dependencies = [ "common-recordbatch", "common-telemetry", "datafusion", - "datafusion-common", + "datafusion-common 7.0.0", "datafusion-expr", "datatypes", "derive_builder", @@ -5912,11 +6241,22 @@ dependencies = [ "cfg-if", "fastrand", "libc", - "redox_syscall", + "redox_syscall 0.2.16", "remove_dir_all", "winapi", ] +[[package]] +name = "term" +version = "0.5.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd106a334b7657c10b7c540a0106114feadeb4dc314513e97df481d5d966f42" +dependencies = [ + "byteorder", + "dirs 1.0.5", + "winapi", +] + [[package]] name = "termcolor" version = "1.1.3" @@ -5936,6 +6276,38 @@ dependencies = [ "winapi", ] +[[package]] +name = "tests-integration" +version = "0.1.0" +dependencies = [ + "api", + "axum 0.6.1", + "axum-test-helper", + "catalog", + "client", + "common-catalog", + "common-runtime", + "common-telemetry", + "datanode", + "datatypes", + "dotenv", + "frontend", + "mito", + "object-store", + "once_cell", + "paste", + "rand 0.8.5", + "serde", + "serde_json", + "servers", + "snafu", + "sql", + "table", + "tempdir", + "tokio", + "uuid", +] + [[package]] name = "textwrap" version = "0.11.0" @@ -5978,7 +6350,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5fdfe0627923f7411a43ec9ec9c39c3a9b4151be313e0922042581fb6c9b717f" dependencies = [ "libc", - "redox_syscall", + "redox_syscall 0.2.16", "winapi", ] @@ -6160,6 +6532,20 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-postgres-rustls" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "606f2b73660439474394432239c82249c0d45eb5f23d91f401be1e33590444a7" +dependencies = [ + "futures", + "ring", + "rustls", + "tokio", + "tokio-postgres", + "tokio-rustls", +] + [[package]] name = "tokio-rustls" version = "0.23.4" @@ -6332,9 +6718,9 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62" +checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" [[package]] name = "tower-service" @@ -6635,6 +7021,12 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "unicode_names2" version = "0.5.1" @@ -6687,7 +7079,8 @@ version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" dependencies = [ - "getrandom", + "getrandom 0.2.7", + "serde", ] [[package]] @@ -6753,6 +7146,12 @@ dependencies = [ "try-lock", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.10.2+wasi-snapshot-preview1" diff --git a/Cargo.toml b/Cargo.toml index be37ca3790..77d94f0f37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,28 +11,32 @@ members = [ "src/common/function", "src/common/function-macro", "src/common/grpc", + "src/common/grpc-expr", "src/common/query", "src/common/recordbatch", "src/common/runtime", "src/common/substrait", - "src/common/insert", "src/common/telemetry", "src/common/time", "src/datanode", "src/datatypes", + "src/datatypes2", "src/frontend", "src/log-store", "src/meta-client", "src/meta-srv", + "src/mito", "src/object-store", "src/query", "src/script", "src/servers", + "src/session", "src/sql", "src/storage", "src/store-api", "src/table", - 
"src/mito", + "tests-integration", + "tests/runner", ] [profile.release] diff --git a/README.md b/README.md index 6f4f3398bc..c54ba35202 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,6 @@ To compile GreptimeDB from source, you'll need: find an installation instructions [here](https://grpc.io/docs/protoc-installation/). **Note that `protoc` version needs to be >= 3.15** because we have used the `optional` keyword. You can check it with `protoc --version`. - #### Build with Docker @@ -161,6 +160,8 @@ break things. Benchmark on development branch may not represent its potential performance. We release pre-built binaries constantly for functional evaluation. Do not use it in production at the moment. +For future plans, check out [GreptimeDB roadmap](https://github.com/GreptimeTeam/greptimedb/issues/669). + ## Community Our core team is thrilled too see you participate in any ways you like. When you are stuck, try to diff --git a/benchmarks/src/bin/nyc-taxi.rs b/benchmarks/src/bin/nyc-taxi.rs index fe434591f3..0ca1f33182 100644 --- a/benchmarks/src/bin/nyc-taxi.rs +++ b/benchmarks/src/bin/nyc-taxi.rs @@ -28,9 +28,8 @@ use arrow::datatypes::{DataType, Float64Type, Int64Type}; use arrow::record_batch::RecordBatch; use clap::Parser; use client::admin::Admin; -use client::api::v1::codec::InsertBatch; use client::api::v1::column::Values; -use client::api::v1::{insert_expr, Column, ColumnDataType, ColumnDef, CreateExpr, InsertExpr}; +use client::api::v1::{Column, ColumnDataType, ColumnDef, CreateExpr, InsertExpr}; use client::{Client, Database, Select}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use parquet::arrow::{ArrowReader, ParquetFileArrowReader}; @@ -100,16 +99,13 @@ async fn write_data( for record_batch in record_batch_reader { let record_batch = record_batch.unwrap(); - let row_count = record_batch.num_rows(); - let insert_batch = convert_record_batch(record_batch).into(); + let (columns, row_count) = convert_record_batch(record_batch); let insert_expr = InsertExpr { schema_name: "public".to_string(), table_name: TABLE_NAME.to_string(), - expr: Some(insert_expr::Expr::Values(insert_expr::Values { - values: vec![insert_batch], - })), - options: HashMap::default(), region_number: 0, + columns, + row_count, }; let now = Instant::now(); db.insert(insert_expr).await.unwrap(); @@ -125,7 +121,7 @@ async fn write_data( total_rpc_elapsed_ms } -fn convert_record_batch(record_batch: RecordBatch) -> InsertBatch { +fn convert_record_batch(record_batch: RecordBatch) -> (Vec, u32) { let schema = record_batch.schema(); let fields = schema.fields(); let row_count = record_batch.num_rows(); @@ -143,10 +139,7 @@ fn convert_record_batch(record_batch: RecordBatch) -> InsertBatch { columns.push(column); } - InsertBatch { - columns, - row_count: row_count as _, - } + (columns, row_count as _) } fn build_values(column: &ArrayRef) -> Values { diff --git a/codecov.yml b/codecov.yml index cdd5d34113..38422f8218 100644 --- a/codecov.yml +++ b/codecov.yml @@ -7,3 +7,4 @@ coverage: patch: off ignore: - "**/error*.rs" # ignore all error.rs files + - "tests/runner/*.rs" # ignore integration test runner diff --git a/config/datanode.example.toml b/config/datanode.example.toml index 795ad45661..a8ed29da4b 100644 --- a/config/datanode.example.toml +++ b/config/datanode.example.toml @@ -5,6 +5,7 @@ wal_dir = '/tmp/greptimedb/wal' rpc_runtime_size = 8 mysql_addr = '127.0.0.1:4406' mysql_runtime_size = 4 +enable_memory_catalog = false [storage] type = 'File' diff --git a/config/frontend.example.toml 
b/config/frontend.example.toml index b23335f51e..a26112ba22 100644 --- a/config/frontend.example.toml +++ b/config/frontend.example.toml @@ -1,6 +1,9 @@ mode = 'distributed' datanode_rpc_addr = '127.0.0.1:3001' -http_addr = '127.0.0.1:4000' + +[http_options] +addr = '127.0.0.1:4000' +timeout = "30s" [meta_client_opts] metasrv_addrs = ['127.0.0.1:3002'] diff --git a/config/standalone.example.toml b/config/standalone.example.toml index dc8b0519ab..54587a6e4d 100644 --- a/config/standalone.example.toml +++ b/config/standalone.example.toml @@ -1,7 +1,11 @@ node_id = 0 mode = 'standalone' -http_addr = '127.0.0.1:4000' wal_dir = '/tmp/greptimedb/wal/' +enable_memory_catalog = false + +[http_options] +addr = '127.0.0.1:4000' +timeout = "30s" [storage] type = 'File' diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 035ba7d7c5..8e1460e796 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -7,8 +7,8 @@ license = "Apache-2.0" [dependencies] common-base = { path = "../common/base" } -common-time = { path = "../common/time" } common-error = { path = "../common/error" } +common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } prost = "0.11" snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/api/build.rs b/src/api/build.rs index ec88ccf408..f3ff5f6600 100644 --- a/src/api/build.rs +++ b/src/api/build.rs @@ -20,7 +20,6 @@ fn main() { .file_descriptor_set_path(default_out_dir.join("greptime_fd.bin")) .compile( &[ - "greptime/v1/insert.proto", "greptime/v1/select.proto", "greptime/v1/physical_plan.proto", "greptime/v1/greptime.proto", diff --git a/src/api/greptime/v1/admin.proto b/src/api/greptime/v1/admin.proto index 9c2c95ecdf..3f253cde0f 100644 --- a/src/api/greptime/v1/admin.proto +++ b/src/api/greptime/v1/admin.proto @@ -20,6 +20,7 @@ message AdminExpr { CreateExpr create = 2; AlterExpr alter = 3; CreateDatabaseExpr create_database = 4; + DropTableExpr drop_table = 5; } } @@ -55,6 +56,12 @@ message AlterExpr { } } +message DropTableExpr { + string catalog_name = 1; + string schema_name = 2; + string table_name = 3; +} + message AddColumns { repeated AddColumn add_columns = 1; } diff --git a/src/api/greptime/v1/database.proto b/src/api/greptime/v1/database.proto index e4b651f322..1cd6a6ee3e 100644 --- a/src/api/greptime/v1/database.proto +++ b/src/api/greptime/v1/database.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package greptime.v1; +import "greptime/v1/column.proto"; import "greptime/v1/common.proto"; message DatabaseRequest { @@ -41,26 +42,16 @@ message InsertExpr { string schema_name = 1; string table_name = 2; - message Values { - repeated bytes values = 1; - } + // The data to insert, represented in columnar form. + repeated Column columns = 3; - oneof expr { - Values values = 3; + // The row_count of all columns, which includes both null and non-null values. + // + // Note: the row_count of all columns in an InsertExpr must be the same. + uint32 row_count = 4; - // TODO(LFC): Remove field "sql" in InsertExpr. - // When Frontend instance received an insertion SQL (`insert into ...`), it's anticipated to parse the SQL and - // assemble the values to insert to feed Datanode. In other words, inserting data through Datanode instance's GRPC - // interface shouldn't use SQL directly. - // Then why the "sql" field exists here? It's because the Frontend needs table schema to create the values to insert, - // which is currently not able to find anywhere. (Maybe the table schema is suppose to be fetched from Meta?) - // The "sql" field is meant to be removed in the future.
- string sql = 4; - } - - /// The region number of current insert request. + // The region number of current insert request. uint32 region_number = 5; - map<string, string> options = 6; } // TODO(jiachun) diff --git a/src/api/greptime/v1/insert.proto b/src/api/greptime/v1/insert.proto deleted file mode 100644 index 0e173723a6..0000000000 --- a/src/api/greptime/v1/insert.proto +++ /dev/null @@ -1,14 +0,0 @@ -syntax = "proto3"; - -package greptime.v1.codec; - -import "greptime/v1/column.proto"; - -message InsertBatch { - repeated Column columns = 1; - uint32 row_count = 2; -} - -message RegionNumber { - uint32 id = 1; -} diff --git a/src/api/src/lib.rs b/src/api/src/lib.rs index 73aa6c4363..d6c415d8cf 100644 --- a/src/api/src/lib.rs +++ b/src/api/src/lib.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod column_def; pub mod error; pub mod helper; pub mod prometheus; diff --git a/src/api/src/serde.rs b/src/api/src/serde.rs index 9daf64e8e6..1523bfbcfe 100644 --- a/src/api/src/serde.rs +++ b/src/api/src/serde.rs @@ -15,7 +15,7 @@ pub use prost::DecodeError; use prost::Message; -use crate::v1::codec::{InsertBatch, PhysicalPlanNode, RegionNumber, SelectResult}; +use crate::v1::codec::{PhysicalPlanNode, SelectResult}; use crate::v1::meta::TableRouteValue; macro_rules! impl_convert_with_bytes { @@ -36,10 +36,8 @@ macro_rules! impl_convert_with_bytes { }; } -impl_convert_with_bytes!(InsertBatch); impl_convert_with_bytes!(SelectResult); impl_convert_with_bytes!(PhysicalPlanNode); -impl_convert_with_bytes!(RegionNumber); impl_convert_with_bytes!(TableRouteValue); #[cfg(test)] @@ -51,52 +49,6 @@ mod tests { const SEMANTIC_TAG: i32 = 0; - #[test] - fn test_convert_insert_batch() { - let insert_batch = mock_insert_batch(); - - let bytes: Vec<u8> = insert_batch.into(); - let insert: InsertBatch = bytes.deref().try_into().unwrap(); - - assert_eq!(8, insert.row_count); - assert_eq!(1, insert.columns.len()); - - let column = &insert.columns[0]; - assert_eq!("foo", column.column_name); - assert_eq!(SEMANTIC_TAG, column.semantic_type); - assert_eq!(vec![1], column.null_mask); - assert_eq!( - vec![2, 3, 4, 5, 6, 7, 8], - column.values.as_ref().unwrap().i32_values - ); - } - - #[should_panic] - #[test] - fn test_convert_insert_batch_wrong() { - let insert_batch = mock_insert_batch(); - - let mut bytes: Vec<u8> = insert_batch.into(); - - // modify some bytes - bytes[0] = 0b1; - bytes[1] = 0b1; - - let insert: InsertBatch = bytes.deref().try_into().unwrap(); - - assert_eq!(8, insert.row_count); - assert_eq!(1, insert.columns.len()); - - let column = &insert.columns[0]; - assert_eq!("foo", column.column_name); - assert_eq!(SEMANTIC_TAG, column.semantic_type); - assert_eq!(vec![1], column.null_mask); - assert_eq!( - vec![2, 3, 4, 5, 6, 7, 8], - column.values.as_ref().unwrap().i32_values - ); - } - #[test] fn test_convert_select_result() { let select_result = mock_select_result(); @@ -143,35 +95,6 @@ mod tests { ); } - #[test] - fn test_convert_region_id() { - let region_id = RegionNumber { id: 12 }; - - let bytes: Vec<u8> = region_id.into(); - let region_id: RegionNumber = bytes.deref().try_into().unwrap(); - - assert_eq!(12, region_id.id); - } - - fn mock_insert_batch() -> InsertBatch { - let values = column::Values { - i32_values: vec![2, 3, 4, 5, 6, 7, 8], - ..Default::default() - }; - let null_mask = vec![1]; - let column = Column { - column_name: "foo".to_string(), - semantic_type: SEMANTIC_TAG, - values: Some(values), - null_mask, - ..Default::default() - };
- InsertBatch { - columns: vec![column], - row_count: 8, - } - } - fn mock_select_result() -> SelectResult { let values = column::Values { i32_values: vec![2, 3, 4, 5, 6, 7, 8], diff --git a/src/api/src/v1.rs b/src/api/src/v1.rs index 4438ce7870..380e810f09 100644 --- a/src/api/src/v1.rs +++ b/src/api/src/v1.rs @@ -21,4 +21,5 @@ pub mod codec { tonic::include_proto!("greptime.v1.codec"); } +mod column_def; pub mod meta; diff --git a/src/api/src/column_def.rs b/src/api/src/v1/column_def.rs similarity index 100% rename from src/api/src/column_def.rs rename to src/api/src/v1/column_def.rs diff --git a/src/catalog/Cargo.toml b/src/catalog/Cargo.toml index 3ea95d2b21..90adcf8e8a 100644 --- a/src/catalog/Cargo.toml +++ b/src/catalog/Cargo.toml @@ -25,7 +25,6 @@ futures = "0.3" futures-util = "0.3" lazy_static = "1.4" meta-client = { path = "../meta-client" } -opendal = "0.20" regex = "1.6" serde = "1.0" serde_json = "1.0" @@ -37,9 +36,8 @@ tokio = { version = "1.18", features = ["full"] } [dev-dependencies] chrono = "0.4" log-store = { path = "../log-store" } +mito = { path = "../mito", features = ["test"] } object-store = { path = "../object-store" } -opendal = "0.20" storage = { path = "../storage" } -mito = { path = "../mito" } tempdir = "0.3" tokio = { version = "1.0", features = ["full"] } diff --git a/src/catalog/src/error.rs b/src/catalog/src/error.rs index 24ab530f4e..05e6944cd5 100644 --- a/src/catalog/src/error.rs +++ b/src/catalog/src/error.rs @@ -94,7 +94,7 @@ pub enum Error { backtrace: Backtrace, }, - #[snafu(display("Table {} already exists", table))] + #[snafu(display("Table `{}` already exists", table))] TableExists { table: String, backtrace: Backtrace }, #[snafu(display("Schema {} already exists", schema))] @@ -109,6 +109,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Operation {} not implemented yet", operation))] + Unimplemented { + operation: String, + backtrace: Backtrace, + }, + #[snafu(display("Failed to open table, table info: {}, source: {}", table_info, source))] OpenTable { table_info: String, @@ -216,11 +222,12 @@ impl ErrorExt for Error { | Error::ValueDeserialize { .. } | Error::Io { .. } => StatusCode::StorageUnavailable, + Error::RegisterTable { .. } => StatusCode::Internal, + Error::ReadSystemCatalog { source, .. } => source.status_code(), Error::SystemCatalogTypeMismatch { source, .. } => source.status_code(), Error::InvalidCatalogValue { source, .. } => source.status_code(), - Error::RegisterTable { .. } => StatusCode::Internal, Error::TableExists { .. } => StatusCode::TableAlreadyExists, Error::SchemaExists { .. } => StatusCode::InvalidArguments, @@ -235,6 +242,8 @@ impl ErrorExt for Error { Error::InvalidTableSchema { source, .. } => source.status_code(), Error::InvalidTableInfoInCatalog { .. } => StatusCode::Unexpected, Error::Internal { source, .. } => source.status_code(), + + Error::Unimplemented { .. } => StatusCode::Unsupported, } } diff --git a/src/catalog/src/lib.rs b/src/catalog/src/lib.rs index 941d2c2580..fc7bb42b03 100644 --- a/src/catalog/src/lib.rs +++ b/src/catalog/src/lib.rs @@ -15,6 +15,7 @@ #![feature(assert_matches)] use std::any::Any; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use common_telemetry::info; @@ -83,12 +84,17 @@ pub trait CatalogManager: CatalogList { /// Starts a catalog manager. async fn start(&self) -> Result<()>; - /// Registers a table given given catalog/schema to catalog manager, - /// returns table registered. 
- async fn register_table(&self, request: RegisterTableRequest) -> Result<usize>; + /// Registers a table within a given catalog/schema to the catalog manager; + /// returns whether the table was registered. + async fn register_table(&self, request: RegisterTableRequest) -> Result<bool>; - /// Register a schema with catalog name and schema name. - async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<usize>; + /// Deregisters a table within a given catalog/schema from the catalog manager; + /// returns whether the table was deregistered. + async fn deregister_table(&self, request: DeregisterTableRequest) -> Result<bool>; + + /// Registers a schema with catalog name and schema name. Returns whether the + /// schema was registered. + async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<bool>; /// Register a system table, should be called before starting the manager. async fn register_system_table(&self, request: RegisterSystemTableRequest) @@ -123,6 +129,25 @@ pub struct RegisterTableRequest { pub table: TableRef, } +impl Debug for RegisterTableRequest { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RegisterTableRequest") + .field("catalog", &self.catalog) + .field("schema", &self.schema) + .field("table_name", &self.table_name) + .field("table_id", &self.table_id) + .field("table", &self.table.table_info()) + .finish() + } +} + +#[derive(Clone)] +pub struct DeregisterTableRequest { + pub catalog: String, + pub schema: String, + pub table_name: String, +} + #[derive(Debug, Clone)] pub struct RegisterSchemaRequest { pub catalog: String, diff --git a/src/catalog/src/local/manager.rs b/src/catalog/src/local/manager.rs index ed6783c68f..d09411cbaa 100644 --- a/src/catalog/src/local/manager.rs +++ b/src/catalog/src/local/manager.rs @@ -21,7 +21,7 @@ use common_catalog::consts::{ SYSTEM_CATALOG_NAME, SYSTEM_CATALOG_TABLE_NAME, }; use common_recordbatch::{RecordBatch, SendableRecordBatchStream}; -use common_telemetry::info; +use common_telemetry::{error, info}; use datatypes::prelude::ScalarVector; use datatypes::vectors::{BinaryVector, UInt8Vector}; use futures_util::lock::Mutex; @@ -36,7 +36,7 @@ use table::TableRef; use crate::error::{ CatalogNotFoundSnafu, IllegalManagerStateSnafu, OpenTableSnafu, ReadSystemCatalogSnafu, Result, SchemaExistsSnafu, SchemaNotFoundSnafu, SystemCatalogSnafu, SystemCatalogTypeMismatchSnafu, - TableExistsSnafu, TableNotFoundSnafu, + TableExistsSnafu, TableNotFoundSnafu, UnimplementedSnafu, }; use crate::local::memory::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use crate::system::{ @@ -46,8 +46,8 @@ use crate::system::{ use crate::tables::SystemCatalog; use crate::{ format_full_table_name, handle_system_table_request, CatalogList, CatalogManager, - CatalogProvider, CatalogProviderRef, RegisterSchemaRequest, RegisterSystemTableRequest, - RegisterTableRequest, SchemaProvider, SchemaProviderRef, + CatalogProvider, CatalogProviderRef, DeregisterTableRequest, RegisterSchemaRequest, + RegisterSystemTableRequest, RegisterTableRequest, SchemaProvider, SchemaProviderRef, }; /// A `CatalogManager` consists of a system catalog and a bunch of user catalogs.
@@ -57,6 +57,7 @@ pub struct LocalCatalogManager { engine: TableEngineRef, next_table_id: AtomicU32, init_lock: Mutex<bool>, + register_lock: Mutex<()>, system_table_requests: Mutex<Vec<RegisterSystemTableRequest>>, } @@ -76,6 +77,7 @@ impl LocalCatalogManager { engine, next_table_id: AtomicU32::new(MIN_USER_TABLE_ID), init_lock: Mutex::new(false), + register_lock: Mutex::new(()), system_table_requests: Mutex::new(Vec::default()), }) } @@ -309,7 +311,7 @@ impl CatalogManager for LocalCatalogManager { self.init().await } - async fn register_table(&self, request: RegisterTableRequest) -> Result<usize> { + async fn register_table(&self, request: RegisterTableRequest) -> Result<bool> { let started = self.init_lock.lock().await; ensure!( @@ -332,27 +334,50 @@ impl CatalogManager for LocalCatalogManager { schema_info: format!("{}.{}", catalog_name, schema_name), })?; - if schema.table_exist(&request.table_name)? { - return TableExistsSnafu { - table: format_full_table_name(catalog_name, schema_name, &request.table_name), + { + let _lock = self.register_lock.lock().await; + if let Some(existing) = schema.table(&request.table_name)? { + if existing.table_info().ident.table_id != request.table_id { + error!( + "Unexpected table register request: {:?}, existing: {:?}", + request, + existing.table_info() + ); + return TableExistsSnafu { + table: format_full_table_name( + catalog_name, + schema_name, + &request.table_name, + ), + } + .fail(); + } + // Try to register table with same table id, just ignore. + Ok(false) + } else { + // table does not exist + self.system + .register_table( + catalog_name.clone(), + schema_name.clone(), + request.table_name.clone(), + request.table_id, + ) + .await?; + schema.register_table(request.table_name, request.table)?; + Ok(true) } - .fail(); } - - self.system - .register_table( - catalog_name.clone(), - schema_name.clone(), - request.table_name.clone(), - request.table_id, - ) - .await?; - - schema.register_table(request.table_name, request.table)?; - Ok(1) } - async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<usize> { + async fn deregister_table(&self, _request: DeregisterTableRequest) -> Result<bool> { + UnimplementedSnafu { + operation: "deregister table", + } + .fail() + } + + async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<bool> { let started = self.init_lock.lock().await; ensure!( *started, @@ -367,17 +392,21 @@ impl CatalogManager for LocalCatalogManager { .catalogs .catalog(catalog_name)?
.context(CatalogNotFoundSnafu { catalog_name })?; - if catalog.schema(schema_name)?.is_some() { - return SchemaExistsSnafu { - schema: schema_name, - } - .fail(); + + { + let _lock = self.register_lock.lock().await; + ensure!( + catalog.schema(schema_name)?.is_none(), + SchemaExistsSnafu { + schema: schema_name, + } + ); + self.system + .register_schema(request.catalog, schema_name.clone()) + .await?; + catalog.register_schema(request.schema, Arc::new(MemorySchemaProvider::new()))?; + Ok(true) } - self.system - .register_schema(request.catalog, schema_name.clone()) - .await?; - catalog.register_schema(request.schema, Arc::new(MemorySchemaProvider::new()))?; - Ok(1) } async fn register_system_table(&self, request: RegisterSystemTableRequest) -> Result<()> { diff --git a/src/catalog/src/local/memory.rs b/src/catalog/src/local/memory.rs index a32b29e204..fb41058ad0 100644 --- a/src/catalog/src/local/memory.rs +++ b/src/catalog/src/local/memory.rs @@ -19,6 +19,7 @@ use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, RwLock}; use common_catalog::consts::MIN_USER_TABLE_ID; +use common_telemetry::error; use snafu::OptionExt; use table::metadata::TableId; use table::table::TableIdProvider; @@ -27,8 +28,8 @@ use table::TableRef; use crate::error::{CatalogNotFoundSnafu, Result, SchemaNotFoundSnafu, TableExistsSnafu}; use crate::schema::SchemaProvider; use crate::{ - CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, RegisterSchemaRequest, - RegisterSystemTableRequest, RegisterTableRequest, SchemaProviderRef, + CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, DeregisterTableRequest, + RegisterSchemaRequest, RegisterSystemTableRequest, RegisterTableRequest, SchemaProviderRef, }; /// Simple in-memory list of catalogs @@ -69,7 +70,7 @@ impl CatalogManager for MemoryCatalogManager { Ok(()) } - async fn register_table(&self, request: RegisterTableRequest) -> Result<usize> { + async fn register_table(&self, request: RegisterTableRequest) -> Result<bool> { let catalogs = self.catalogs.write().unwrap(); let catalog = catalogs .get(&request.catalog) @@ -84,10 +85,28 @@ impl CatalogManager for MemoryCatalogManager { })?; schema .register_table(request.table_name, request.table) - .map(|v| if v.is_some() { 0 } else { 1 }) + .map(|v| v.is_none()) } - async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<usize> { + async fn deregister_table(&self, request: DeregisterTableRequest) -> Result<bool> { + let catalogs = self.catalogs.write().unwrap(); + let catalog = catalogs + .get(&request.catalog) + .context(CatalogNotFoundSnafu { + catalog_name: &request.catalog, + })? + .clone(); + let schema = catalog + .schema(&request.schema)?
+ .with_context(|| SchemaNotFoundSnafu { + schema_info: format!("{}.{}", &request.catalog, &request.schema), + })?; + schema + .deregister_table(&request.table_name) + .map(|v| v.is_some()) + } + + async fn register_schema(&self, request: RegisterSchemaRequest) -> Result<bool> { let catalogs = self.catalogs.write().unwrap(); let catalog = catalogs .get(&request.catalog) @@ -95,11 +114,12 @@ catalog_name: &request.catalog, })?; catalog.register_schema(request.schema, Arc::new(MemorySchemaProvider::new()))?; - Ok(1) + Ok(true) } async fn register_system_table(&self, _request: RegisterSystemTableRequest) -> Result<()> { - unimplemented!() + // TODO(ruihang): support register system table request + Ok(()) } fn schema(&self, catalog: &str, schema: &str) -> Result<Option<SchemaProviderRef>> { @@ -251,11 +271,21 @@ impl SchemaProvider for MemorySchemaProvider { } fn register_table(&self, name: String, table: TableRef) -> Result<Option<TableRef>> { - if self.table_exist(name.as_str())? { - return TableExistsSnafu { table: name }.fail()?; - } let mut tables = self.tables.write().unwrap(); - Ok(tables.insert(name, table)) + if let Some(existing) = tables.get(name.as_str()) { + // if table with the same name but different table id exists, then it's a fatal bug + if existing.table_info().ident.table_id != table.table_info().ident.table_id { + error!( + "Unexpected table register: {:?}, existing: {:?}", + table.table_info(), + existing.table_info() + ); + return TableExistsSnafu { table: name }.fail()?; + } + Ok(Some(existing.clone())) + } else { + Ok(tables.insert(name, table)) + } } fn deregister_table(&self, name: &str) -> Result<Option<TableRef>> { @@ -315,7 +345,7 @@ mod tests { .unwrap() .is_none()); assert!(provider.table_exist(table_name).unwrap()); - let other_table = NumbersTable::default(); + let other_table = NumbersTable::new(12); let result = provider.register_table(table_name.to_string(), Arc::new(other_table)); let err = result.err().unwrap(); assert!(err.backtrace_opt().is_some()); @@ -340,4 +370,34 @@ .downcast_ref::() .unwrap(); } + + #[tokio::test] + pub async fn test_catalog_deregister_table() { + let catalog = MemoryCatalogManager::default(); + let schema = catalog + .schema(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME) + .unwrap() + .unwrap(); + + let register_table_req = RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "numbers".to_string(), + table_id: 2333, + table: Arc::new(NumbersTable::default()), + }; + catalog.register_table(register_table_req).await.unwrap(); + assert!(schema.table_exist("numbers").unwrap()); + + let deregister_table_req = DeregisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "numbers".to_string(), + }; + catalog + .deregister_table(deregister_table_req) + .await + .unwrap(); + assert!(!schema.table_exist("numbers").unwrap()); + } } diff --git a/src/catalog/src/remote/manager.rs b/src/catalog/src/remote/manager.rs index 5369f6ce0d..ba7c09f6c0 100644 --- a/src/catalog/src/remote/manager.rs +++ b/src/catalog/src/remote/manager.rs @@ -37,13 +37,13 @@ use tokio::sync::Mutex; use crate::error::{ CatalogNotFoundSnafu, CreateTableSnafu, InvalidCatalogValueSnafu, InvalidTableSchemaSnafu, - OpenTableSnafu, Result, SchemaNotFoundSnafu, TableExistsSnafu, + OpenTableSnafu, Result, SchemaNotFoundSnafu, TableExistsSnafu, UnimplementedSnafu, }; use crate::remote::{Kv, KvBackendRef}; use crate::{ handle_system_table_request,
CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, - RegisterSchemaRequest, RegisterSystemTableRequest, RegisterTableRequest, SchemaProvider, - SchemaProviderRef, + DeregisterTableRequest, RegisterSchemaRequest, RegisterSystemTableRequest, + RegisterTableRequest, SchemaProvider, SchemaProviderRef, }; /// Catalog manager based on metasrv. @@ -154,8 +154,8 @@ impl RemoteCatalogManager { } let table_key = TableGlobalKey::parse(&String::from_utf8_lossy(&k)) .context(InvalidCatalogValueSnafu)?; - let table_value = TableGlobalValue::parse(&String::from_utf8_lossy(&v)) - .context(InvalidCatalogValueSnafu)?; + let table_value = + TableGlobalValue::from_bytes(&v).context(InvalidCatalogValueSnafu)?; info!( "Found catalog table entry, key: {}, value: {:?}", @@ -411,7 +411,7 @@ impl CatalogManager for RemoteCatalogManager { Ok(()) } - async fn register_table(&self, request: RegisterTableRequest) -> Result { + async fn register_table(&self, request: RegisterTableRequest) -> Result { let catalog_name = request.catalog; let schema_name = request.schema; let catalog_provider = self.catalog(&catalog_name)?.context(CatalogNotFoundSnafu { @@ -430,10 +430,17 @@ impl CatalogManager for RemoteCatalogManager { .fail(); } schema_provider.register_table(request.table_name, request.table)?; - Ok(1) + Ok(true) } - async fn register_schema(&self, request: RegisterSchemaRequest) -> Result { + async fn deregister_table(&self, _request: DeregisterTableRequest) -> Result { + UnimplementedSnafu { + operation: "deregister table", + } + .fail() + } + + async fn register_schema(&self, request: RegisterSchemaRequest) -> Result { let catalog_name = request.catalog; let schema_name = request.schema; let catalog_provider = self.catalog(&catalog_name)?.context(CatalogNotFoundSnafu { @@ -441,7 +448,7 @@ impl CatalogManager for RemoteCatalogManager { })?; let schema_provider = self.new_schema_provider(&catalog_name, &schema_name); catalog_provider.register_schema(schema_name, schema_provider)?; - Ok(1) + Ok(true) } async fn register_system_table(&self, request: RegisterSystemTableRequest) -> Result<()> { diff --git a/src/catalog/src/system.rs b/src/catalog/src/system.rs index 564acc7ba5..b6555b9353 100644 --- a/src/catalog/src/system.rs +++ b/src/catalog/src/system.rs @@ -43,7 +43,6 @@ use crate::error::{ pub const ENTRY_TYPE_INDEX: usize = 0; pub const KEY_INDEX: usize = 1; -pub const TIMESTAMP_INDEX: usize = 2; pub const VALUE_INDEX: usize = 3; pub struct SystemCatalogTable { @@ -111,7 +110,7 @@ impl SystemCatalogTable { desc: Some("System catalog table".to_string()), schema: schema.clone(), region_numbers: vec![0], - primary_key_indices: vec![ENTRY_TYPE_INDEX, KEY_INDEX, TIMESTAMP_INDEX], + primary_key_indices: vec![ENTRY_TYPE_INDEX, KEY_INDEX], create_if_not_exists: true, table_options: HashMap::new(), }; @@ -456,7 +455,7 @@ mod tests { pub async fn prepare_table_engine() -> (TempDir, TableEngineRef) { let dir = TempDir::new("system-table-test").unwrap(); let store_dir = dir.path().to_string_lossy(); - let accessor = opendal::services::fs::Builder::default() + let accessor = object_store::backend::fs::Builder::default() .root(&store_dir) .build() .unwrap(); diff --git a/src/catalog/tests/local_catalog_tests.rs b/src/catalog/tests/local_catalog_tests.rs new file mode 100644 index 0000000000..2e57754077 --- /dev/null +++ b/src/catalog/tests/local_catalog_tests.rs @@ -0,0 +1,132 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file 
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use catalog::local::LocalCatalogManager; + use catalog::{CatalogManager, RegisterTableRequest}; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use common_telemetry::{error, info}; + use mito::config::EngineConfig; + use table::table::numbers::NumbersTable; + use table::TableRef; + use tokio::sync::Mutex; + + async fn create_local_catalog_manager() -> Result { + let (_dir, object_store) = + mito::table::test_util::new_test_object_store("setup_mock_engine_and_table").await; + let mock_engine = Arc::new(mito::table::test_util::MockMitoEngine::new( + EngineConfig::default(), + mito::table::test_util::MockEngine::default(), + object_store, + )); + let catalog_manager = LocalCatalogManager::try_new(mock_engine).await.unwrap(); + catalog_manager.start().await?; + Ok(catalog_manager) + } + + #[tokio::test] + async fn test_duplicate_register() { + let catalog_manager = create_local_catalog_manager().await.unwrap(); + let request = RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "test_table".to_string(), + table_id: 42, + table: Arc::new(NumbersTable::new(42)), + }; + assert!(catalog_manager + .register_table(request.clone()) + .await + .unwrap()); + + // register table with same table id will succeed with 0 as return val. 
+ assert!(!catalog_manager.register_table(request).await.unwrap()); + + let err = catalog_manager + .register_table(RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "test_table".to_string(), + table_id: 43, + table: Arc::new(NumbersTable::new(43)), + }) + .await + .unwrap_err(); + assert!( + err.to_string() + .contains("Table `greptime.public.test_table` already exists"), + "Actual error message: {}", + err + ); + } + + #[test] + fn test_concurrent_register() { + common_telemetry::init_default_ut_logging(); + let rt = Arc::new(tokio::runtime::Builder::new_multi_thread().build().unwrap()); + let catalog_manager = + Arc::new(rt.block_on(async { create_local_catalog_manager().await.unwrap() })); + + let succeed: Arc>> = Arc::new(Mutex::new(None)); + + let mut handles = Vec::with_capacity(8); + for i in 0..8 { + let catalog = catalog_manager.clone(); + let succeed = succeed.clone(); + let handle = rt.spawn(async move { + let table_id = 42 + i; + let table = Arc::new(NumbersTable::new(table_id)); + let req = RegisterTableRequest { + catalog: DEFAULT_CATALOG_NAME.to_string(), + schema: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "test_table".to_string(), + table_id, + table: table.clone(), + }; + match catalog.register_table(req).await { + Ok(res) => { + if res { + let mut succeed = succeed.lock().await; + info!("Successfully registered table: {}", table_id); + *succeed = Some(table); + } + } + Err(_) => { + error!("Failed to register table {}", table_id); + } + } + }); + handles.push(handle); + } + + rt.block_on(async move { + for handle in handles { + handle.await.unwrap(); + } + let guard = succeed.lock().await; + let table = guard.as_ref().unwrap(); + let table_registered = catalog_manager + .table(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, "test_table") + .unwrap() + .unwrap(); + assert_eq!( + table_registered.table_info().ident.table_id, + table.table_info().ident.table_id + ); + }); + } +} diff --git a/src/catalog/tests/mock.rs b/src/catalog/tests/mock.rs index 067abc7ca0..01aec6e2f8 100644 --- a/src/catalog/tests/mock.rs +++ b/src/catalog/tests/mock.rs @@ -217,7 +217,7 @@ impl TableEngine for MockTableEngine { &self, _ctx: &EngineContext, _request: DropTableRequest, - ) -> table::Result<()> { + ) -> table::Result { unimplemented!() } } diff --git a/src/catalog/tests/remote_catalog_tests.rs b/src/catalog/tests/remote_catalog_tests.rs index b43fd09889..e5d8811e71 100644 --- a/src/catalog/tests/remote_catalog_tests.rs +++ b/src/catalog/tests/remote_catalog_tests.rs @@ -202,7 +202,7 @@ mod tests { table_id, table, }; - assert_eq!(1, catalog_manager.register_table(reg_req).await.unwrap()); + assert!(catalog_manager.register_table(reg_req).await.unwrap()); assert_eq!( HashSet::from([table_name, "numbers".to_string()]), default_schema @@ -287,7 +287,7 @@ mod tests { .register_schema(schema_name.clone(), schema.clone()) .expect("Register schema should not fail"); assert!(prev.is_none()); - assert_eq!(1, catalog_manager.register_table(reg_req).await.unwrap()); + assert!(catalog_manager.register_table(reg_req).await.unwrap()); assert_eq!( HashSet::from([schema_name.clone()]), diff --git a/src/client/Cargo.toml b/src/client/Cargo.toml index 8c52397ea7..5c19f89970 100644 --- a/src/client/Cargo.toml +++ b/src/client/Cargo.toml @@ -11,9 +11,9 @@ async-stream = "0.3" common-base = { path = "../common/base" } common-error = { path = "../common/error" } common-grpc = { path = "../common/grpc" } +common-grpc-expr = { path = 
"../common/grpc-expr" } common-query = { path = "../common/query" } common-recordbatch = { path = "../common/recordbatch" } -common-insert = { path = "../common/insert" } common-time = { path = "../common/time" } datafusion = "14.0.0" datatypes = { path = "../datatypes" } diff --git a/src/client/examples/insert.rs b/src/client/examples/insert.rs index e85d45200c..66f38eded3 100644 --- a/src/client/examples/insert.rs +++ b/src/client/examples/insert.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; - -use api::v1::codec::InsertBatch; use api::v1::*; use client::{Client, Database}; + fn main() { tracing::subscriber::set_global_default(tracing_subscriber::FmtSubscriber::builder().finish()) .unwrap(); @@ -29,19 +27,19 @@ async fn run() { let client = Client::with_urls(vec!["127.0.0.1:3001"]); let db = Database::new("greptime", client); + let (columns, row_count) = insert_data(); + let expr = InsertExpr { schema_name: "public".to_string(), table_name: "demo".to_string(), - expr: Some(insert_expr::Expr::Values(insert_expr::Values { - values: insert_batches(), - })), - options: HashMap::default(), region_number: 0, + columns, + row_count, }; db.insert(expr).await.unwrap(); } -fn insert_batches() -> Vec> { +fn insert_data() -> (Vec, u32) { const SEMANTIC_TAG: i32 = 0; const SEMANTIC_FIELD: i32 = 1; const SEMANTIC_TS: i32 = 2; @@ -101,9 +99,8 @@ fn insert_batches() -> Vec> { ..Default::default() }; - let insert_batch = InsertBatch { - columns: vec![host_column, cpu_column, mem_column, ts_column], + ( + vec![host_column, cpu_column, mem_column, ts_column], row_count, - }; - vec![insert_batch.into()] + ) } diff --git a/src/client/src/admin.rs b/src/client/src/admin.rs index d872dd41d2..f70aea0356 100644 --- a/src/client/src/admin.rs +++ b/src/client/src/admin.rs @@ -58,7 +58,19 @@ impl Admin { header: Some(header), expr: Some(admin_expr::Expr::Alter(expr)), }; - Ok(self.do_requests(vec![expr]).await?.remove(0)) + self.do_request(expr).await + } + + pub async fn drop_table(&self, expr: DropTableExpr) -> Result { + let header = ExprHeader { + version: PROTOCOL_VERSION, + }; + let expr = AdminExpr { + header: Some(header), + expr: Some(admin_expr::Expr::DropTable(expr)), + }; + + self.do_request(expr).await } /// Invariants: the lengths of input vec (`Vec`) and output vec (`Vec`) are equal. 
diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 9cea7d5d85..3228a74cf8 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -23,7 +23,7 @@ use api::v1::{ }; use common_error::status_code::StatusCode; use common_grpc::{AsExecutionPlan, DefaultAsPlanImpl}; -use common_insert::column_to_vector; +use common_grpc_expr::column_to_vector; use common_query::Output; use common_recordbatch::{RecordBatch, RecordBatches}; use datafusion::physical_plan::ExecutionPlan; diff --git a/src/client/src/error.rs b/src/client/src/error.rs index add1a0989e..953fcb44f9 100644 --- a/src/client/src/error.rs +++ b/src/client/src/error.rs @@ -103,7 +103,7 @@ pub enum Error { #[snafu(display("Failed to convert column to vector, source: {}", source))] ColumnToVector { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, } diff --git a/src/cmd/Cargo.toml b/src/cmd/Cargo.toml index c446180738..8168b98788 100644 --- a/src/cmd/Cargo.toml +++ b/src/cmd/Cargo.toml @@ -15,13 +15,13 @@ common-error = { path = "../common/error" } common-telemetry = { path = "../common/telemetry", features = [ "deadlock_detection", ] } -meta-client = { path = "../meta-client" } datanode = { path = "../datanode" } frontend = { path = "../frontend" } futures = "0.3" +meta-client = { path = "../meta-client" } meta-srv = { path = "../meta-srv" } serde = "1.0" -servers = {path = "../servers"} +servers = { path = "../servers" } snafu = { version = "0.7", features = ["backtraces"] } tokio = { version = "1.18", features = ["full"] } toml = "0.5" @@ -29,3 +29,6 @@ toml = "0.5" [dev-dependencies] serde = "1.0" tempdir = "0.3" + +[build-dependencies] +build-data = "0.1.3" diff --git a/src/cmd/build.rs b/src/cmd/build.rs new file mode 100644 index 0000000000..15d858e847 --- /dev/null +++ b/src/cmd/build.rs @@ -0,0 +1,19 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +fn main() { + build_data::set_GIT_BRANCH(); + build_data::set_GIT_COMMIT(); + build_data::set_GIT_DIRTY(); +} diff --git a/src/cmd/src/bin/greptime.rs b/src/cmd/src/bin/greptime.rs index 4beb4b805d..578bee7e3b 100644 --- a/src/cmd/src/bin/greptime.rs +++ b/src/cmd/src/bin/greptime.rs @@ -20,7 +20,7 @@ use cmd::{datanode, frontend, metasrv, standalone}; use common_telemetry::logging::{error, info}; #[derive(Parser)] -#[clap(name = "greptimedb")] +#[clap(name = "greptimedb", version = print_version())] struct Command { #[clap(long, default_value = "/tmp/greptimedb/logs")] log_dir: String, @@ -70,6 +70,17 @@ impl fmt::Display for SubCommand { } } +fn print_version() -> &'static str { + concat!( + "\nbranch: ", + env!("GIT_BRANCH"), + "\ncommit: ", + env!("GIT_COMMIT"), + "\ndirty: ", + env!("GIT_DIRTY") + ) +} + #[tokio::main] async fn main() -> Result<()> { let cmd = Command::parse(); diff --git a/src/cmd/src/datanode.rs b/src/cmd/src/datanode.rs index d386bfa64e..4afc1ea619 100644 --- a/src/cmd/src/datanode.rs +++ b/src/cmd/src/datanode.rs @@ -170,6 +170,7 @@ mod tests { ObjectStoreConfig::File { data_dir } => { assert_eq!("/tmp/greptimedb/data/".to_string(), data_dir) } + ObjectStoreConfig::S3 { .. } => unreachable!(), }; } diff --git a/src/cmd/src/error.rs b/src/cmd/src/error.rs index cedc7ba19c..7856c66e16 100644 --- a/src/cmd/src/error.rs +++ b/src/cmd/src/error.rs @@ -97,10 +97,7 @@ mod tests { #[test] fn test_start_node_error() { fn throw_datanode_error() -> StdResult { - datanode::error::MissingFieldSnafu { - field: "test_field", - } - .fail() + datanode::error::MissingNodeIdSnafu {}.fail() } let e = throw_datanode_error() diff --git a/src/cmd/src/frontend.rs b/src/cmd/src/frontend.rs index b6b7b8bfad..100f411d30 100644 --- a/src/cmd/src/frontend.rs +++ b/src/cmd/src/frontend.rs @@ -21,6 +21,7 @@ use frontend::mysql::MysqlOptions; use frontend::opentsdb::OpentsdbOptions; use frontend::postgres::PostgresOptions; use meta_client::MetaClientOpts; +use servers::http::HttpOptions; use servers::Mode; use snafu::ResultExt; @@ -96,7 +97,10 @@ impl TryFrom for FrontendOptions { }; if let Some(addr) = cmd.http_addr { - opts.http_addr = Some(addr); + opts.http_options = Some(HttpOptions { + addr, + ..Default::default() + }); } if let Some(addr) = cmd.grpc_addr { opts.grpc_options = Some(GrpcOptions { @@ -141,6 +145,8 @@ impl TryFrom for FrontendOptions { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; #[test] @@ -157,7 +163,7 @@ mod tests { }; let opts: FrontendOptions = command.try_into().unwrap(); - assert_eq!(opts.http_addr, Some("127.0.0.1:1234".to_string())); + assert_eq!(opts.http_options.as_ref().unwrap().addr, "127.0.0.1:1234"); assert_eq!(opts.mysql_options.as_ref().unwrap().addr, "127.0.0.1:5678"); assert_eq!( opts.postgres_options.as_ref().unwrap().addr, @@ -188,4 +194,33 @@ mod tests { assert!(!opts.influxdb_options.unwrap().enable); } + + #[test] + fn test_read_from_config_file() { + let command = StartCommand { + http_addr: None, + grpc_addr: None, + mysql_addr: None, + postgres_addr: None, + opentsdb_addr: None, + influxdb_enable: None, + config_file: Some(format!( + "{}/../../config/frontend.example.toml", + std::env::current_dir().unwrap().as_path().to_str().unwrap() + )), + metasrv_addr: None, + }; + + let fe_opts = FrontendOptions::try_from(command).unwrap(); + assert_eq!(Mode::Distributed, fe_opts.mode); + assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr); + assert_eq!( + "127.0.0.1:4000".to_string(), + 
            fe_opts.http_options.as_ref().unwrap().addr
+        );
+        assert_eq!(
+            Duration::from_secs(30),
+            fe_opts.http_options.as_ref().unwrap().timeout
+        );
+    }
 }
diff --git a/src/cmd/src/standalone.rs b/src/cmd/src/standalone.rs
index 39b0b14fa6..b3a86e3fb3 100644
--- a/src/cmd/src/standalone.rs
+++ b/src/cmd/src/standalone.rs
@@ -25,6 +25,7 @@ use frontend::opentsdb::OpentsdbOptions;
 use frontend::postgres::PostgresOptions;
 use frontend::prometheus::PrometheusOptions;
 use serde::{Deserialize, Serialize};
+use servers::http::HttpOptions;
 use servers::Mode;
 use snafu::ResultExt;
 use tokio::try_join;
@@ -61,7 +62,7 @@ impl SubCommand {
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct StandaloneOptions {
-    pub http_addr: Option<String>,
+    pub http_options: Option<HttpOptions>,
     pub grpc_options: Option<GrpcOptions>,
     pub mysql_options: Option<MysqlOptions>,
     pub postgres_options: Option<PostgresOptions>,
@@ -71,12 +72,13 @@ pub struct StandaloneOptions {
     pub mode: Mode,
     pub wal_dir: String,
     pub storage: ObjectStoreConfig,
+    pub enable_memory_catalog: bool,
 }

 impl Default for StandaloneOptions {
     fn default() -> Self {
         Self {
-            http_addr: Some("127.0.0.1:4000".to_string()),
+            http_options: Some(HttpOptions::default()),
             grpc_options: Some(GrpcOptions::default()),
             mysql_options: Some(MysqlOptions::default()),
             postgres_options: Some(PostgresOptions::default()),
@@ -86,6 +88,7 @@ impl Default for StandaloneOptions {
             mode: Mode::Standalone,
             wal_dir: "/tmp/greptimedb/wal".to_string(),
             storage: ObjectStoreConfig::default(),
+            enable_memory_catalog: false,
         }
     }
 }
@@ -93,7 +96,7 @@ impl Default for StandaloneOptions {
 impl StandaloneOptions {
     fn frontend_options(self) -> FrontendOptions {
         FrontendOptions {
-            http_addr: self.http_addr,
+            http_options: self.http_options,
             grpc_options: self.grpc_options,
             mysql_options: self.mysql_options,
             postgres_options: self.postgres_options,
@@ -110,6 +113,7 @@ impl StandaloneOptions {
         DatanodeOptions {
             wal_dir: self.wal_dir,
             storage: self.storage,
+            enable_memory_catalog: self.enable_memory_catalog,
             ..Default::default()
         }
     }
@@ -131,18 +135,22 @@ struct StartCommand {
     influxdb_enable: bool,
     #[clap(short, long)]
     config_file: Option<String>,
+    #[clap(short = 'm', long = "memory-catalog")]
+    enable_memory_catalog: bool,
 }

 impl StartCommand {
     async fn run(self) -> Result<()> {
+        let enable_memory_catalog = self.enable_memory_catalog;
         let config_file = self.config_file.clone();
         let fe_opts = FrontendOptions::try_from(self)?;
         let dn_opts: DatanodeOptions = {
-            let opts: StandaloneOptions = if let Some(path) = config_file {
+            let mut opts: StandaloneOptions = if let Some(path) = config_file {
                 toml_loader::from_file!(&path)?
             } else {
                 StandaloneOptions::default()
             };
+            opts.enable_memory_catalog = enable_memory_catalog;
             opts.datanode_options()
         };
@@ -156,8 +164,15 @@ impl StartCommand {
             .context(StartDatanodeSnafu)?;

         let mut frontend = build_frontend(fe_opts, &dn_opts, datanode.get_instance()).await?;

+        // Start datanode instance before starting services, to avoid requests coming in before internal components are started.
+ datanode + .start_instance() + .await + .context(StartDatanodeSnafu)?; + info!("Datanode instance started"); + try_join!( - async { datanode.start().await.context(StartDatanodeSnafu) }, + async { datanode.start_services().await.context(StartDatanodeSnafu) }, async { frontend.start().await.context(StartFrontendSnafu) } )?; @@ -199,7 +214,10 @@ impl TryFrom for FrontendOptions { opts.mode = Mode::Standalone; if let Some(addr) = cmd.http_addr { - opts.http_addr = Some(addr); + opts.http_options = Some(HttpOptions { + addr, + ..Default::default() + }); } if let Some(addr) = cmd.rpc_addr { // frontend grpc addr conflict with datanode default grpc addr @@ -249,6 +267,8 @@ impl TryFrom for FrontendOptions { #[cfg(test)] mod tests { + use std::time::Duration; + use super::*; #[test] @@ -264,12 +284,20 @@ mod tests { std::env::current_dir().unwrap().as_path().to_str().unwrap() )), influxdb_enable: false, + enable_memory_catalog: false, }; let fe_opts = FrontendOptions::try_from(cmd).unwrap(); assert_eq!(Mode::Standalone, fe_opts.mode); assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr); - assert_eq!(Some("127.0.0.1:4000".to_string()), fe_opts.http_addr); + assert_eq!( + "127.0.0.1:4000".to_string(), + fe_opts.http_options.as_ref().unwrap().addr + ); + assert_eq!( + Duration::from_secs(30), + fe_opts.http_options.as_ref().unwrap().timeout + ); assert_eq!( "127.0.0.1:4001".to_string(), fe_opts.grpc_options.unwrap().addr diff --git a/src/common/catalog/src/helper.rs b/src/common/catalog/src/helper.rs index ccfe362969..dcfa08e8a7 100644 --- a/src/common/catalog/src/helper.rs +++ b/src/common/catalog/src/helper.rs @@ -28,31 +28,42 @@ use crate::error::{ DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu, }; +const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*"; + lazy_static! { - static ref CATALOG_KEY_PATTERN: Regex = - Regex::new(&format!("^{}-([a-zA-Z_]+)$", CATALOG_KEY_PREFIX)).unwrap(); + static ref CATALOG_KEY_PATTERN: Regex = Regex::new(&format!( + "^{}-({})$", + CATALOG_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN + )) + .unwrap(); } lazy_static! { static ref SCHEMA_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)$", - SCHEMA_KEY_PREFIX + "^{}-({})-({})$", + SCHEMA_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN, ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } lazy_static! { static ref TABLE_GLOBAL_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)$", - TABLE_GLOBAL_KEY_PREFIX + "^{}-({})-({})-({})$", + TABLE_GLOBAL_KEY_PREFIX, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } lazy_static! { static ref TABLE_REGIONAL_KEY_PATTERN: Regex = Regex::new(&format!( - "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)-([0-9]+)$", - TABLE_REGIONAL_KEY_PREFIX + "^{}-({})-({})-({})-([0-9]+)$", + TABLE_REGIONAL_KEY_PREFIX, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN, + ALPHANUMERICS_NAME_PATTERN )) .unwrap(); } @@ -261,6 +272,10 @@ macro_rules! define_catalog_value { .context(DeserializeCatalogEntryValueSnafu { raw: s.as_ref() }) } + pub fn from_bytes(bytes: impl AsRef<[u8]>) -> Result { + Self::parse(&String::from_utf8_lossy(bytes.as_ref())) + } + pub fn as_bytes(&self) -> Result, Error> { Ok(serde_json::to_string(self) .context(SerializeCatalogEntryValueSnafu)? 
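Two things change in `helper.rs` above: key names are matched with `ALPHANUMERICS_NAME_PATTERN` (`[a-zA-Z_][a-zA-Z0-9_]*`), so catalog, schema, and table names may now contain digits as long as they do not start with one, and every catalog value type gains a `from_bytes` constructor via the `define_catalog_value!` macro (used by `TableGlobalValue::from_bytes` in the remote manager earlier in this patch). A self-contained sketch of the pattern change using the `regex` crate; `__s` is a stand-in for the real `SCHEMA_KEY_PREFIX`, whose value is not shown in this diff:

```rust
use regex::Regex;

const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*";

fn main() {
    // Mirrors the new SCHEMA_KEY_PATTERN, with a placeholder prefix.
    let new_pattern = Regex::new(&format!(
        "^{}-({})-({})$",
        "__s", ALPHANUMERICS_NAME_PATTERN, ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();
    // The pattern it replaces allowed only letters and underscores.
    let old_pattern = Regex::new("^__s-([a-zA-Z_]+)-([a-zA-Z_]+)$").unwrap();

    // Digits are now accepted after the first character...
    assert!(new_pattern.is_match("__s-greptime-public_2"));
    assert!(!old_pattern.is_match("__s-greptime-public_2"));
    // ...but a name still must not start with a digit.
    assert!(!new_pattern.is_match("__s-greptime-2public"));
}
```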
diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index 7ea7a0088f..ce49cb5e5b 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -9,8 +9,8 @@ arc-swap = "1.0" chrono-tz = "0.6" common-error = { path = "../error" } common-function-macro = { path = "../function-macro" } -common-time = { path = "../time" } common-query = { path = "../query" } +common-time = { path = "../time" } datafusion-common = "14.0.0" datatypes = { path = "../../datatypes" } libc = "0.2" diff --git a/src/common/insert/Cargo.toml b/src/common/grpc-expr/Cargo.toml similarity index 81% rename from src/common/insert/Cargo.toml rename to src/common/grpc-expr/Cargo.toml index 8dca21eb9a..9d8580b3d7 100644 --- a/src/common/insert/Cargo.toml +++ b/src/common/grpc-expr/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "common-insert" +name = "common-grpc-expr" version = "0.1.0" edition = "2021" license = "Apache-2.0" @@ -8,10 +8,12 @@ license = "Apache-2.0" api = { path = "../../api" } async-trait = "0.1" common-base = { path = "../base" } +common-catalog = { path = "../catalog" } common-error = { path = "../error" } +common-grpc = { path = "../grpc" } +common-query = { path = "../query" } common-telemetry = { path = "../telemetry" } common-time = { path = "../time" } -common-query = { path = "../query" } datatypes = { path = "../../datatypes" } snafu = { version = "0.7", features = ["backtraces"] } table = { path = "../../table" } diff --git a/src/common/grpc-expr/src/alter.rs b/src/common/grpc-expr/src/alter.rs new file mode 100644 index 0000000000..cdef37cbcb --- /dev/null +++ b/src/common/grpc-expr/src/alter.rs @@ -0,0 +1,234 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use api::v1::alter_expr::Kind; +use api::v1::{AlterExpr, CreateExpr, DropColumns}; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use datatypes::schema::{ColumnSchema, SchemaBuilder, SchemaRef}; +use snafu::{ensure, OptionExt, ResultExt}; +use table::metadata::TableId; +use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest, CreateTableRequest}; + +use crate::error::{ + ColumnNotFoundSnafu, CreateSchemaSnafu, InvalidColumnDefSnafu, MissingFieldSnafu, + MissingTimestampColumnSnafu, Result, +}; + +/// Convert an [`AlterExpr`] to an optional [`AlterTableRequest`] +pub fn alter_expr_to_request(expr: AlterExpr) -> Result> { + match expr.kind { + Some(Kind::AddColumns(add_columns)) => { + let add_column_requests = add_columns + .add_columns + .into_iter() + .map(|ac| { + let column_def = ac.column_def.context(MissingFieldSnafu { + field: "column_def", + })?; + + let schema = + column_def + .try_as_column_schema() + .context(InvalidColumnDefSnafu { + column: &column_def.name, + })?; + Ok(AddColumnRequest { + column_schema: schema, + is_key: ac.is_key, + }) + }) + .collect::>>()?; + + let alter_kind = AlterKind::AddColumns { + columns: add_column_requests, + }; + + let request = AlterTableRequest { + catalog_name: expr.catalog_name, + schema_name: expr.schema_name, + table_name: expr.table_name, + alter_kind, + }; + Ok(Some(request)) + } + Some(Kind::DropColumns(DropColumns { drop_columns })) => { + let alter_kind = AlterKind::DropColumns { + names: drop_columns.into_iter().map(|c| c.name).collect(), + }; + + let request = AlterTableRequest { + catalog_name: expr.catalog_name, + schema_name: expr.schema_name, + table_name: expr.table_name, + alter_kind, + }; + Ok(Some(request)) + } + None => Ok(None), + } +} + +pub fn create_table_schema(expr: &CreateExpr) -> Result { + let column_schemas = expr + .column_defs + .iter() + .map(|x| { + x.try_as_column_schema() + .context(InvalidColumnDefSnafu { column: &x.name }) + }) + .collect::>>()?; + + ensure!( + column_schemas + .iter() + .any(|column| column.name == expr.time_index), + MissingTimestampColumnSnafu { + msg: format!("CreateExpr: {:?}", expr) + } + ); + + let column_schemas = column_schemas + .into_iter() + .map(|column_schema| { + if column_schema.name == expr.time_index { + column_schema.with_time_index(true) + } else { + column_schema + } + }) + .collect::>(); + + Ok(Arc::new( + SchemaBuilder::try_from(column_schemas) + .context(CreateSchemaSnafu)? 
+ .build() + .context(CreateSchemaSnafu)?, + )) +} + +pub fn create_expr_to_request(table_id: TableId, expr: CreateExpr) -> Result { + let schema = create_table_schema(&expr)?; + let primary_key_indices = expr + .primary_keys + .iter() + .map(|key| { + schema + .column_index_by_name(key) + .context(ColumnNotFoundSnafu { + column_name: key, + table_name: &expr.table_name, + }) + }) + .collect::>>()?; + + let catalog_name = expr + .catalog_name + .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()); + let schema_name = expr + .schema_name + .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()); + + let region_ids = if expr.region_ids.is_empty() { + vec![0] + } else { + expr.region_ids + }; + + Ok(CreateTableRequest { + id: table_id, + catalog_name, + schema_name, + table_name: expr.table_name, + desc: expr.desc, + schema, + region_numbers: region_ids, + primary_key_indices, + create_if_not_exists: expr.create_if_not_exists, + table_options: expr.table_options, + }) +} + +#[cfg(test)] +mod tests { + use api::v1::{AddColumn, AddColumns, ColumnDataType, ColumnDef, DropColumn}; + use datatypes::prelude::ConcreteDataType; + + use super::*; + + #[test] + fn test_alter_expr_to_request() { + let expr = AlterExpr { + catalog_name: None, + schema_name: None, + table_name: "monitor".to_string(), + + kind: Some(Kind::AddColumns(AddColumns { + add_columns: vec![AddColumn { + column_def: Some(ColumnDef { + name: "mem_usage".to_string(), + datatype: ColumnDataType::Float64 as i32, + is_nullable: false, + default_constraint: None, + }), + is_key: false, + }], + })), + }; + + let alter_request = alter_expr_to_request(expr).unwrap().unwrap(); + assert_eq!(None, alter_request.catalog_name); + assert_eq!(None, alter_request.schema_name); + assert_eq!("monitor".to_string(), alter_request.table_name); + let add_column = match alter_request.alter_kind { + AlterKind::AddColumns { mut columns } => columns.pop().unwrap(), + _ => unreachable!(), + }; + + assert!(!add_column.is_key); + assert_eq!("mem_usage", add_column.column_schema.name); + assert_eq!( + ConcreteDataType::float64_datatype(), + add_column.column_schema.data_type + ); + } + + #[test] + fn test_drop_column_expr() { + let expr = AlterExpr { + catalog_name: Some("test_catalog".to_string()), + schema_name: Some("test_schema".to_string()), + table_name: "monitor".to_string(), + + kind: Some(Kind::DropColumns(DropColumns { + drop_columns: vec![DropColumn { + name: "mem_usage".to_string(), + }], + })), + }; + + let alter_request = alter_expr_to_request(expr).unwrap().unwrap(); + assert_eq!(Some("test_catalog".to_string()), alter_request.catalog_name); + assert_eq!(Some("test_schema".to_string()), alter_request.schema_name); + assert_eq!("monitor".to_string(), alter_request.table_name); + + let mut drop_names = match alter_request.alter_kind { + AlterKind::DropColumns { names } => names, + _ => unreachable!(), + }; + assert_eq!(1, drop_names.len()); + assert_eq!("mem_usage".to_string(), drop_names.pop().unwrap()); + } +} diff --git a/src/common/insert/src/error.rs b/src/common/grpc-expr/src/error.rs similarity index 74% rename from src/common/insert/src/error.rs rename to src/common/grpc-expr/src/error.rs index dbd455b2ec..dc0df10c46 100644 --- a/src/common/insert/src/error.rs +++ b/src/common/grpc-expr/src/error.rs @@ -22,7 +22,7 @@ use snafu::{Backtrace, ErrorCompat}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { - #[snafu(display("Column {} not found in table {}", column_name, table_name))] + #[snafu(display("Column `{}` not found in 
table `{}`", column_name, table_name))] ColumnNotFound { column_name: String, table_name: String, @@ -57,8 +57,8 @@ pub enum Error { backtrace: Backtrace, }, - #[snafu(display("Missing timestamp column in request"))] - MissingTimestampColumn { backtrace: Backtrace }, + #[snafu(display("Missing timestamp column, msg: {}", msg))] + MissingTimestampColumn { msg: String, backtrace: Backtrace }, #[snafu(display("Invalid column proto: {}", err_msg))] InvalidColumnProto { @@ -70,6 +70,26 @@ pub enum Error { #[snafu(backtrace)] source: datatypes::error::Error, }, + + #[snafu(display("Missing required field in protobuf, field: {}", field))] + MissingField { field: String, backtrace: Backtrace }, + + #[snafu(display("Invalid column default constraint, source: {}", source))] + ColumnDefaultConstraint { + #[snafu(backtrace)] + source: datatypes::error::Error, + }, + + #[snafu(display( + "Invalid column proto definition, column: {}, source: {}", + column, + source + ))] + InvalidColumnDef { + column: String, + #[snafu(backtrace)] + source: api::error::Error, + }, } pub type Result = std::result::Result; @@ -87,6 +107,9 @@ impl ErrorExt for Error { | Error::MissingTimestampColumn { .. } => StatusCode::InvalidArguments, Error::InvalidColumnProto { .. } => StatusCode::InvalidArguments, Error::CreateVector { .. } => StatusCode::InvalidArguments, + Error::MissingField { .. } => StatusCode::InvalidArguments, + Error::ColumnDefaultConstraint { source, .. } => source.status_code(), + Error::InvalidColumnDef { source, .. } => source.status_code(), } } fn backtrace_opt(&self) -> Option<&Backtrace> { diff --git a/src/common/insert/src/insert.rs b/src/common/grpc-expr/src/insert.rs similarity index 82% rename from src/common/insert/src/insert.rs rename to src/common/grpc-expr/src/insert.rs index 4f597d3d37..d7687d0789 100644 --- a/src/common/insert/src/insert.rs +++ b/src/common/grpc-expr/src/insert.rs @@ -14,11 +14,9 @@ use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; -use std::ops::Deref; use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; -use api::v1::codec::InsertBatch; use api::v1::column::{SemanticType, Values}; use api::v1::{AddColumn, AddColumns, Column, ColumnDataType, ColumnDef, CreateExpr}; use common_base::BitVec; @@ -35,9 +33,8 @@ use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest, InsertRequ use table::Table; use crate::error::{ - ColumnDataTypeSnafu, ColumnNotFoundSnafu, CreateVectorSnafu, DecodeInsertSnafu, - DuplicatedTimestampColumnSnafu, IllegalInsertDataSnafu, InvalidColumnProtoSnafu, - MissingTimestampColumnSnafu, Result, + ColumnDataTypeSnafu, ColumnNotFoundSnafu, CreateVectorSnafu, DuplicatedTimestampColumnSnafu, + IllegalInsertDataSnafu, InvalidColumnProtoSnafu, MissingTimestampColumnSnafu, Result, }; const TAG_SEMANTIC_TYPE: i32 = SemanticType::Tag as i32; const TIMESTAMP_SEMANTIC_TYPE: i32 = SemanticType::Timestamp as i32; @@ -52,35 +49,25 @@ fn build_column_def(column_name: &str, datatype: i32, nullable: bool) -> ColumnD } } -pub fn find_new_columns( - schema: &SchemaRef, - insert_batches: &[InsertBatch], -) -> Result> { +pub fn find_new_columns(schema: &SchemaRef, columns: &[Column]) -> Result> { let mut columns_to_add = Vec::default(); let mut new_columns: HashSet = HashSet::default(); - for InsertBatch { columns, row_count } in insert_batches { - if *row_count == 0 || columns.is_empty() { - continue; - } - - for Column { - column_name, - semantic_type, - datatype, - .. 
- } in columns + for Column { + column_name, + semantic_type, + datatype, + .. + } in columns + { + if schema.column_schema_by_name(column_name).is_none() && !new_columns.contains(column_name) { - if schema.column_schema_by_name(column_name).is_none() - && !new_columns.contains(column_name) - { - let column_def = Some(build_column_def(column_name, *datatype, true)); - columns_to_add.push(AddColumn { - column_def, - is_key: *semantic_type == TAG_SEMANTIC_TYPE, - }); - new_columns.insert(column_name.to_string()); - } + let column_def = Some(build_column_def(column_name, *datatype, true)); + columns_to_add.push(AddColumn { + column_def, + is_key: *semantic_type == TAG_SEMANTIC_TYPE, + }); + new_columns.insert(column_name.to_string()); } } @@ -201,89 +188,84 @@ pub fn build_create_expr_from_insertion( schema_name: &str, table_id: Option, table_name: &str, - insert_batches: &[InsertBatch], + columns: &[Column], ) -> Result { let mut new_columns: HashSet = HashSet::default(); let mut column_defs = Vec::default(); let mut primary_key_indices = Vec::default(); let mut timestamp_index = usize::MAX; - for InsertBatch { columns, row_count } in insert_batches { - if *row_count == 0 || columns.is_empty() { - continue; - } - - for Column { - column_name, - semantic_type, - datatype, - .. - } in columns - { - if !new_columns.contains(column_name) { - let mut is_nullable = true; - match *semantic_type { - TAG_SEMANTIC_TYPE => primary_key_indices.push(column_defs.len()), - TIMESTAMP_SEMANTIC_TYPE => { - ensure!( - timestamp_index == usize::MAX, - DuplicatedTimestampColumnSnafu { - exists: &columns[timestamp_index].column_name, - duplicated: column_name, - } - ); - timestamp_index = column_defs.len(); - // Timestamp column must not be null. - is_nullable = false; - } - _ => {} + for Column { + column_name, + semantic_type, + datatype, + .. + } in columns + { + if !new_columns.contains(column_name) { + let mut is_nullable = true; + match *semantic_type { + TAG_SEMANTIC_TYPE => primary_key_indices.push(column_defs.len()), + TIMESTAMP_SEMANTIC_TYPE => { + ensure!( + timestamp_index == usize::MAX, + DuplicatedTimestampColumnSnafu { + exists: &columns[timestamp_index].column_name, + duplicated: column_name, + } + ); + timestamp_index = column_defs.len(); + // Timestamp column must not be null. 
+ is_nullable = false; } - - let column_def = build_column_def(column_name, *datatype, is_nullable); - column_defs.push(column_def); - new_columns.insert(column_name.to_string()); + _ => {} } + + let column_def = build_column_def(column_name, *datatype, is_nullable); + column_defs.push(column_def); + new_columns.insert(column_name.to_string()); } - - ensure!(timestamp_index != usize::MAX, MissingTimestampColumnSnafu); - let timestamp_field_name = columns[timestamp_index].column_name.clone(); - - let primary_keys = primary_key_indices - .iter() - .map(|idx| columns[*idx].column_name.clone()) - .collect::>(); - - let expr = CreateExpr { - catalog_name: Some(catalog_name.to_string()), - schema_name: Some(schema_name.to_string()), - table_name: table_name.to_string(), - desc: Some("Created on insertion".to_string()), - column_defs, - time_index: timestamp_field_name, - primary_keys, - create_if_not_exists: true, - table_options: Default::default(), - table_id, - region_ids: vec![0], // TODO:(hl): region id should be allocated by frontend - }; - - return Ok(expr); } - IllegalInsertDataSnafu.fail() + ensure!( + timestamp_index != usize::MAX, + MissingTimestampColumnSnafu { msg: table_name } + ); + let timestamp_field_name = columns[timestamp_index].column_name.clone(); + + let primary_keys = primary_key_indices + .iter() + .map(|idx| columns[*idx].column_name.clone()) + .collect::>(); + + let expr = CreateExpr { + catalog_name: Some(catalog_name.to_string()), + schema_name: Some(schema_name.to_string()), + table_name: table_name.to_string(), + desc: Some("Created on insertion".to_string()), + column_defs, + time_index: timestamp_field_name, + primary_keys, + create_if_not_exists: true, + table_options: Default::default(), + table_id, + region_ids: vec![0], // TODO:(hl): region id should be allocated by frontend + }; + + Ok(expr) } pub fn insertion_expr_to_request( catalog_name: &str, schema_name: &str, table_name: &str, - insert_batches: Vec, + insert_batches: Vec<(Vec, u32)>, table: Arc, ) -> Result { let schema = table.schema(); let mut columns_builders = HashMap::with_capacity(schema.column_schemas().len()); - for InsertBatch { columns, row_count } in insert_batches { + for (columns, row_count) in insert_batches { for Column { column_name, values, @@ -329,14 +311,6 @@ pub fn insertion_expr_to_request( }) } -#[inline] -pub fn insert_batches(bytes_vec: &[Vec]) -> Result> { - bytes_vec - .iter() - .map(|bytes| bytes.deref().try_into().context(DecodeInsertSnafu)) - .collect() -} - fn add_values_to_builder( builder: &mut VectorBuilder, values: Values, @@ -463,9 +437,8 @@ mod tests { use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; - use api::v1::codec::InsertBatch; use api::v1::column::{self, SemanticType, Values}; - use api::v1::{insert_expr, Column, ColumnDataType}; + use api::v1::{Column, ColumnDataType}; use common_base::BitVec; use common_query::physical_plan::PhysicalPlanRef; use common_query::prelude::Expr; @@ -479,11 +452,12 @@ mod tests { use table::Table; use super::{ - build_create_expr_from_insertion, convert_values, find_new_columns, insert_batches, - insertion_expr_to_request, is_null, TAG_SEMANTIC_TYPE, TIMESTAMP_SEMANTIC_TYPE, + build_create_expr_from_insertion, convert_values, insertion_expr_to_request, is_null, + TAG_SEMANTIC_TYPE, TIMESTAMP_SEMANTIC_TYPE, }; use crate::error; use crate::error::ColumnDataTypeSnafu; + use crate::insert::find_new_columns; #[inline] fn build_column_schema( @@ -508,11 +482,10 @@ mod tests { assert!(build_create_expr_from_insertion("", "", 
table_id, table_name, &[]).is_err()); - let mock_batch_bytes = mock_insert_batches(); - let insert_batches = insert_batches(&mock_batch_bytes).unwrap(); + let insert_batch = mock_insert_batch(); let create_expr = - build_create_expr_from_insertion("", "", table_id, table_name, &insert_batches) + build_create_expr_from_insertion("", "", table_id, table_name, &insert_batch.0) .unwrap(); assert_eq!(table_id, create_expr.table_id); @@ -598,9 +571,9 @@ mod tests { assert!(find_new_columns(&schema, &[]).unwrap().is_none()); - let mock_insert_bytes = mock_insert_batches(); - let insert_batches = insert_batches(&mock_insert_bytes).unwrap(); - let add_columns = find_new_columns(&schema, &insert_batches).unwrap().unwrap(); + let insert_batch = mock_insert_batch(); + + let add_columns = find_new_columns(&schema, &insert_batch.0).unwrap().unwrap(); assert_eq!(2, add_columns.add_columns.len()); let host_column = &add_columns.add_columns[0]; @@ -630,10 +603,7 @@ mod tests { fn test_insertion_expr_to_request() { let table: Arc = Arc::new(DemoTable {}); - let values = insert_expr::Values { - values: mock_insert_batches(), - }; - let insert_batches = insert_batches(&values.values).unwrap(); + let insert_batches = vec![mock_insert_batch()]; let insert_req = insertion_expr_to_request("greptime", "public", "demo", insert_batches, table).unwrap(); @@ -731,7 +701,7 @@ mod tests { } } - fn mock_insert_batches() -> Vec> { + fn mock_insert_batch() -> (Vec, u32) { let row_count = 2; let host_vals = column::Values { @@ -782,10 +752,9 @@ mod tests { datatype: ColumnDataType::Timestamp as i32, }; - let insert_batch = InsertBatch { - columns: vec![host_column, cpu_column, mem_column, ts_column], + ( + vec![host_column, cpu_column, mem_column, ts_column], row_count, - }; - vec![insert_batch.into()] + ) } } diff --git a/src/common/insert/src/lib.rs b/src/common/grpc-expr/src/lib.rs similarity index 80% rename from src/common/insert/src/lib.rs rename to src/common/grpc-expr/src/lib.rs index 3bac3e0969..71786d670f 100644 --- a/src/common/insert/src/lib.rs +++ b/src/common/grpc-expr/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(assert_matches)] // Copyright 2022 Greptime Team // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,9 +13,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+mod alter; pub mod error; mod insert; + +pub use alter::{alter_expr_to_request, create_expr_to_request, create_table_schema}; pub use insert::{ build_alter_table_request, build_create_expr_from_insertion, column_to_vector, - find_new_columns, insert_batches, insertion_expr_to_request, + find_new_columns, insertion_expr_to_request, }; diff --git a/src/common/grpc/Cargo.toml b/src/common/grpc/Cargo.toml index 1296513c8a..7665f3b721 100644 --- a/src/common/grpc/Cargo.toml +++ b/src/common/grpc/Cargo.toml @@ -12,7 +12,6 @@ common-error = { path = "../error" } common-query = { path = "../query" } common-recordbatch = { path = "../recordbatch" } common-runtime = { path = "../runtime" } -datatypes = { path = "../../datatypes" } dashmap = "5.4" datafusion = "14.0.0" snafu = { version = "0.7", features = ["backtraces"] } diff --git a/src/common/grpc/src/writer.rs b/src/common/grpc/src/writer.rs index 42e27fc887..2cd28f45af 100644 --- a/src/common/grpc/src/writer.rs +++ b/src/common/grpc/src/writer.rs @@ -14,7 +14,6 @@ use std::collections::HashMap; -use api::v1::codec::InsertBatch; use api::v1::column::{SemanticType, Values}; use api::v1::{Column, ColumnDataType}; use common_base::BitVec; @@ -24,12 +23,14 @@ use crate::error::{Result, TypeMismatchSnafu}; type ColumnName = String; +type RowCount = u32; + // TODO(fys): will remove in the future. #[derive(Default)] pub struct LinesWriter { column_name_index: HashMap, null_masks: Vec, - batch: InsertBatch, + batch: (Vec, RowCount), lines: usize, } @@ -171,20 +172,20 @@ impl LinesWriter { pub fn commit(&mut self) { let batch = &mut self.batch; - batch.row_count += 1; + batch.1 += 1; - for i in 0..batch.columns.len() { + for i in 0..batch.0.len() { let null_mask = &mut self.null_masks[i]; - if batch.row_count as usize > null_mask.len() { + if batch.1 as usize > null_mask.len() { null_mask.push(true); } } } - pub fn finish(mut self) -> InsertBatch { + pub fn finish(mut self) -> (Vec, RowCount) { let null_masks = self.null_masks; for (i, null_mask) in null_masks.into_iter().enumerate() { - let columns = &mut self.batch.columns; + let columns = &mut self.batch.0; columns[i].null_mask = null_mask.into_vec(); } self.batch @@ -204,9 +205,9 @@ impl LinesWriter { let batch = &mut self.batch; let to_insert = self.lines; let mut null_mask = BitVec::with_capacity(to_insert); - null_mask.extend(BitVec::repeat(true, batch.row_count as usize)); + null_mask.extend(BitVec::repeat(true, batch.1 as usize)); self.null_masks.push(null_mask); - batch.columns.push(Column { + batch.0.push(Column { column_name: column_name.to_string(), semantic_type: semantic_type.into(), values: Some(Values::with_capacity(datatype, to_insert)), @@ -217,7 +218,7 @@ impl LinesWriter { new_idx } }; - (column_idx, &mut self.batch.columns[column_idx]) + (column_idx, &mut self.batch.0[column_idx]) } } @@ -282,9 +283,9 @@ mod tests { writer.commit(); let insert_batch = writer.finish(); - assert_eq!(3, insert_batch.row_count); + assert_eq!(3, insert_batch.1); - let columns = insert_batch.columns; + let columns = insert_batch.0; assert_eq!(9, columns.len()); let column = &columns[0]; diff --git a/src/common/recordbatch/src/recordbatch.rs b/src/common/recordbatch/src/recordbatch.rs index b768a2f0bc..5fc886f8b9 100644 --- a/src/common/recordbatch/src/recordbatch.rs +++ b/src/common/recordbatch/src/recordbatch.rs @@ -23,6 +23,7 @@ use snafu::ResultExt; use crate::error::{self, Result}; +// TODO(yingwen): We should hold vectors in the RecordBatch. 
#[derive(Clone, Debug, PartialEq)] pub struct RecordBatch { pub schema: SchemaRef, @@ -103,6 +104,7 @@ impl<'a> Iterator for RecordBatchRowIterator<'a> { } else { let mut row = Vec::with_capacity(self.columns); + // TODO(yingwen): Get from the vector if RecordBatch also holds vectors. for col in 0..self.columns { let column_array = self.record_batch.df_recordbatch.column(col); match arrow_array_get(column_array.as_ref(), self.row_cursor) diff --git a/src/common/substrait/Cargo.toml b/src/common/substrait/Cargo.toml index 41a1f74ae3..815a986d1e 100644 --- a/src/common/substrait/Cargo.toml +++ b/src/common/substrait/Cargo.toml @@ -9,7 +9,9 @@ bytes = "1.1" catalog = { path = "../../catalog" } common-catalog = { path = "../catalog" } common-error = { path = "../error" } +common-telemetry = { path = "../telemetry" } datafusion = "14.0.0" +datafusion-expr = "14.0.0" datatypes = { path = "../../datatypes" } futures = "0.3" prost = "0.9" diff --git a/src/common/substrait/src/context.rs b/src/common/substrait/src/context.rs new file mode 100644 index 0000000000..893546ea48 --- /dev/null +++ b/src/common/substrait/src/context.rs @@ -0,0 +1,66 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; + +use substrait_proto::protobuf::extensions::simple_extension_declaration::{ + ExtensionFunction, MappingType, +}; +use substrait_proto::protobuf::extensions::SimpleExtensionDeclaration; + +#[derive(Default)] +pub struct ConvertorContext { + scalar_fn_names: HashMap, + scalar_fn_map: HashMap, +} + +impl ConvertorContext { + pub fn register_scalar_fn>(&mut self, name: S) -> u32 { + if let Some(anchor) = self.scalar_fn_names.get(name.as_ref()) { + return *anchor; + } + + let next_anchor = self.scalar_fn_map.len() as _; + self.scalar_fn_map + .insert(next_anchor, name.as_ref().to_string()); + self.scalar_fn_names + .insert(name.as_ref().to_string(), next_anchor); + next_anchor + } + + pub fn register_scalar_with_anchor>(&mut self, name: S, anchor: u32) { + self.scalar_fn_map.insert(anchor, name.as_ref().to_string()); + self.scalar_fn_names + .insert(name.as_ref().to_string(), anchor); + } + + pub fn find_scalar_fn(&self, anchor: u32) -> Option<&str> { + self.scalar_fn_map.get(&anchor).map(|s| s.as_str()) + } + + pub fn generate_function_extension(&self) -> Vec { + let mut result = Vec::with_capacity(self.scalar_fn_map.len()); + for (anchor, name) in &self.scalar_fn_map { + let declaration = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction { + extension_uri_reference: 0, + function_anchor: *anchor, + name: name.clone(), + })), + }; + result.push(declaration); + } + result + } +} diff --git a/src/common/substrait/src/df_expr.rs b/src/common/substrait/src/df_expr.rs new file mode 100644 index 0000000000..8267fa9cc1 --- /dev/null +++ b/src/common/substrait/src/df_expr.rs @@ -0,0 +1,742 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); 
+// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::str::FromStr; + +use datafusion::logical_plan::{Column, Expr}; +use datafusion_expr::{expr_fn, BuiltinScalarFunction, Operator}; +use datatypes::schema::Schema; +use snafu::{ensure, OptionExt}; +use substrait_proto::protobuf::expression::field_reference::ReferenceType as FieldReferenceType; +use substrait_proto::protobuf::expression::reference_segment::{ + ReferenceType as SegReferenceType, StructField, +}; +use substrait_proto::protobuf::expression::{ + FieldReference, ReferenceSegment, RexType, ScalarFunction, +}; +use substrait_proto::protobuf::function_argument::ArgType; +use substrait_proto::protobuf::Expression; + +use crate::context::ConvertorContext; +use crate::error::{ + EmptyExprSnafu, InvalidParametersSnafu, MissingFieldSnafu, Result, UnsupportedExprSnafu, +}; + +/// Convert substrait's `Expression` to DataFusion's `Expr`. +pub fn to_df_expr(ctx: &ConvertorContext, expression: Expression, schema: &Schema) -> Result { + let expr_rex_type = expression.rex_type.context(EmptyExprSnafu)?; + match expr_rex_type { + RexType::Literal(_) => UnsupportedExprSnafu { + name: "substrait Literal expression", + } + .fail()?, + RexType::Selection(selection) => convert_selection_rex(*selection, schema), + RexType::ScalarFunction(scalar_fn) => convert_scalar_function(ctx, scalar_fn, schema), + RexType::WindowFunction(_) + | RexType::IfThen(_) + | RexType::SwitchExpression(_) + | RexType::SingularOrList(_) + | RexType::MultiOrList(_) + | RexType::Cast(_) + | RexType::Subquery(_) + | RexType::Enum(_) => UnsupportedExprSnafu { + name: format!("substrait expression {:?}", expr_rex_type), + } + .fail()?, + } +} + +/// Convert Substrait's `FieldReference` - `DirectReference` - `StructField` to Datafusion's +/// `Column` expr. 
+pub fn convert_selection_rex(selection: FieldReference, schema: &Schema) -> Result { + if let Some(FieldReferenceType::DirectReference(direct_ref)) = selection.reference_type + && let Some(SegReferenceType::StructField(field)) = direct_ref.reference_type { + let column_name = schema.column_name_by_index(field.field as _).to_string(); + Ok(Expr::Column(Column { + relation: None, + name: column_name, + })) + } else { + InvalidParametersSnafu { + reason: "Only support direct struct reference in Selection Rex", + } + .fail() + } +} + +pub fn convert_scalar_function( + ctx: &ConvertorContext, + scalar_fn: ScalarFunction, + schema: &Schema, +) -> Result { + // convert argument + let mut inputs = VecDeque::with_capacity(scalar_fn.arguments.len()); + for arg in scalar_fn.arguments { + if let Some(ArgType::Value(sub_expr)) = arg.arg_type { + inputs.push_back(to_df_expr(ctx, sub_expr, schema)?); + } else { + InvalidParametersSnafu { + reason: "Only value expression arg is supported to be function argument", + } + .fail()?; + } + } + + // convert this scalar function + // map function name + let anchor = scalar_fn.function_reference; + let fn_name = ctx + .find_scalar_fn(anchor) + .with_context(|| InvalidParametersSnafu { + reason: format!("Unregistered scalar function reference: {}", anchor), + })?; + + // convenient util + let ensure_arg_len = |expected: usize| -> Result<()> { + ensure!( + inputs.len() == expected, + InvalidParametersSnafu { + reason: format!( + "Invalid number of scalar function {}, expected {} but found {}", + fn_name, + expected, + inputs.len() + ) + } + ); + Ok(()) + }; + + // construct DataFusion expr + let expr = match fn_name { + // begin binary exprs, with the same order of DF `Operator`'s definition. + "eq" | "equal" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().eq(inputs.pop_front().unwrap()) + } + "not_eq" | "not_equal" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .not_eq(inputs.pop_front().unwrap()) + } + "lt" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().lt(inputs.pop_front().unwrap()) + } + "lt_eq" | "lte" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .lt_eq(inputs.pop_front().unwrap()) + } + "gt" => { + ensure_arg_len(2)?; + inputs.pop_front().unwrap().gt(inputs.pop_front().unwrap()) + } + "gt_eq" | "gte" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .gt_eq(inputs.pop_front().unwrap()) + } + "plus" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Plus, + inputs.pop_front().unwrap(), + ) + } + "minus" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Minus, + inputs.pop_front().unwrap(), + ) + } + "multiply" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Multiply, + inputs.pop_front().unwrap(), + ) + } + "divide" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Divide, + inputs.pop_front().unwrap(), + ) + } + "modulo" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::Modulo, + inputs.pop_front().unwrap(), + ) + } + "and" => { + ensure_arg_len(2)?; + expr_fn::and(inputs.pop_front().unwrap(), inputs.pop_front().unwrap()) + } + "or" => { + ensure_arg_len(2)?; + expr_fn::or(inputs.pop_front().unwrap(), inputs.pop_front().unwrap()) + } + "like" => { + ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .like(inputs.pop_front().unwrap()) + } + "not_like" => { + 
ensure_arg_len(2)?; + inputs + .pop_front() + .unwrap() + .not_like(inputs.pop_front().unwrap()) + } + "is_distinct_from" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::IsDistinctFrom, + inputs.pop_front().unwrap(), + ) + } + "is_not_distinct_from" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::IsNotDistinctFrom, + inputs.pop_front().unwrap(), + ) + } + "regex_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_i_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexIMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_not_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexNotMatch, + inputs.pop_front().unwrap(), + ) + } + "regex_not_i_match" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::RegexNotIMatch, + inputs.pop_front().unwrap(), + ) + } + "bitwise_and" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::BitwiseAnd, + inputs.pop_front().unwrap(), + ) + } + "bitwise_or" => { + ensure_arg_len(2)?; + expr_fn::binary_expr( + inputs.pop_front().unwrap(), + Operator::BitwiseOr, + inputs.pop_front().unwrap(), + ) + } + // end binary exprs + // start other direct expr, with the same order of DF `Expr`'s definition. + "not" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().not() + } + "is_not_null" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().is_not_null() + } + "is_null" => { + ensure_arg_len(1)?; + inputs.pop_front().unwrap().is_null() + } + "negative" => { + ensure_arg_len(1)?; + Expr::Negative(Box::new(inputs.pop_front().unwrap())) + } + // skip GetIndexedField, unimplemented. + "between" => { + ensure_arg_len(3)?; + Expr::Between { + expr: Box::new(inputs.pop_front().unwrap()), + negated: false, + low: Box::new(inputs.pop_front().unwrap()), + high: Box::new(inputs.pop_front().unwrap()), + } + } + "not_between" => { + ensure_arg_len(3)?; + Expr::Between { + expr: Box::new(inputs.pop_front().unwrap()), + negated: true, + low: Box::new(inputs.pop_front().unwrap()), + high: Box::new(inputs.pop_front().unwrap()), + } + } + // skip Case, is covered in substrait::SwitchExpression. + // skip Cast and TryCast, is covered in substrait::Cast. + "sort" | "sort_des" => { + ensure_arg_len(1)?; + Expr::Sort { + expr: Box::new(inputs.pop_front().unwrap()), + asc: false, + nulls_first: false, + } + } + "sort_asc" => { + ensure_arg_len(1)?; + Expr::Sort { + expr: Box::new(inputs.pop_front().unwrap()), + asc: true, + nulls_first: false, + } + } + // those are datafusion built-in "scalar functions". 
+ "abs" + | "acos" + | "asin" + | "atan" + | "atan2" + | "ceil" + | "cos" + | "exp" + | "floor" + | "ln" + | "log" + | "log10" + | "log2" + | "power" + | "pow" + | "round" + | "signum" + | "sin" + | "sqrt" + | "tan" + | "trunc" + | "coalesce" + | "make_array" + | "ascii" + | "bit_length" + | "btrim" + | "char_length" + | "character_length" + | "concat" + | "concat_ws" + | "chr" + | "current_date" + | "current_time" + | "date_part" + | "datepart" + | "date_trunc" + | "datetrunc" + | "date_bin" + | "initcap" + | "left" + | "length" + | "lower" + | "lpad" + | "ltrim" + | "md5" + | "nullif" + | "octet_length" + | "random" + | "regexp_replace" + | "repeat" + | "replace" + | "reverse" + | "right" + | "rpad" + | "rtrim" + | "sha224" + | "sha256" + | "sha384" + | "sha512" + | "digest" + | "split_part" + | "starts_with" + | "strpos" + | "substr" + | "to_hex" + | "to_timestamp" + | "to_timestamp_millis" + | "to_timestamp_micros" + | "to_timestamp_seconds" + | "now" + | "translate" + | "trim" + | "upper" + | "uuid" + | "regexp_match" + | "struct" + | "from_unixtime" + | "arrow_typeof" => Expr::ScalarFunction { + fun: BuiltinScalarFunction::from_str(fn_name).unwrap(), + args: inputs.into(), + }, + // skip ScalarUDF, unimplemented. + // skip AggregateFunction, is covered in substrait::AggregateRel + // skip WindowFunction, is covered in substrait WindowFunction + // skip AggregateUDF, unimplemented. + // skip InList, unimplemented + // skip Wildcard, unimplemented. + // end other direct expr + _ => UnsupportedExprSnafu { + name: format!("scalar function {}", fn_name), + } + .fail()?, + }; + + Ok(expr) +} + +/// Convert DataFusion's `Expr` to substrait's `Expression` +pub fn expression_from_df_expr( + ctx: &mut ConvertorContext, + expr: &Expr, + schema: &Schema, +) -> Result { + let expression = match expr { + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::Alias(..) => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Column(column) => { + let field_reference = convert_column(column, schema)?; + Expression { + rex_type: Some(RexType::Selection(Box::new(field_reference))), + } + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::ScalarVariable(..) | Expr::Literal(..) 
=> UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::BinaryExpr { left, op, right } => { + let left = expression_from_df_expr(ctx, left, schema)?; + let right = expression_from_df_expr(ctx, right, schema)?; + let arguments = utils::expression_to_argument(vec![left, right]); + let op_name = utils::name_df_operator(op); + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::Not(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "not"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::IsNotNull(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "is_not_null"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::IsNull(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "is_null"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::Negative(e) => { + let arg = expression_from_df_expr(ctx, e, schema)?; + let arguments = utils::expression_to_argument(vec![arg]); + let op_name = "negative"; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::GetIndexedField { .. } => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr = expression_from_df_expr(ctx, expr, schema)?; + let low = expression_from_df_expr(ctx, low, schema)?; + let high = expression_from_df_expr(ctx, high, schema)?; + let arguments = utils::expression_to_argument(vec![expr, low, high]); + let op_name = if *negated { "not_between" } else { "between" }; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + Expr::Sort { + expr, + asc, + nulls_first: _, + } => { + let expr = expression_from_df_expr(ctx, expr, schema)?; + let arguments = utils::expression_to_argument(vec![expr]); + let op_name = if *asc { "sort_asc" } else { "sort_des" }; + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + Expr::ScalarFunction { fun, args } => { + let arguments = utils::expression_to_argument( + args.iter() + .map(|e| expression_from_df_expr(ctx, e, schema)) + .collect::<Result<Vec<_>>>()?, + ); + let op_name = utils::name_builtin_scalar_function(fun); + let function_reference = ctx.register_scalar_fn(op_name); + utils::build_scalar_function_expression(function_reference, arguments) + } + // Don't merge them with other unsupported expr arms to preserve the ordering. + Expr::ScalarUDF { .. } + | Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateUDF { ..
} + | Expr::InList { .. } + | Expr::Wildcard => UnsupportedExprSnafu { + name: expr.to_string(), + } + .fail()?, + }; + + Ok(expression) +} + +/// Convert DataFusion's `Column` expr into substrait's `FieldReference` - +/// `DirectReference` - `StructField`. +pub fn convert_column(column: &Column, schema: &Schema) -> Result<FieldReference> { + let column_name = &column.name; + let field_index = + schema + .column_index_by_name(column_name) + .with_context(|| MissingFieldSnafu { + field: format!("{:?}", column), + plan: format!("schema: {:?}", schema), + })?; + + Ok(FieldReference { + reference_type: Some(FieldReferenceType::DirectReference(ReferenceSegment { + reference_type: Some(SegReferenceType::StructField(Box::new(StructField { + field: field_index as _, + child: None, + }))), + })), + root_type: None, + }) +} + +/// Utilities specific to this `DataFusion::Expr` and `Substrait::Expression` conversion. +mod utils { + use datafusion_expr::{BuiltinScalarFunction, Operator}; + use substrait_proto::protobuf::expression::{RexType, ScalarFunction}; + use substrait_proto::protobuf::function_argument::ArgType; + use substrait_proto::protobuf::{Expression, FunctionArgument}; + + pub(crate) fn name_df_operator(op: &Operator) -> &str { + match op { + Operator::Eq => "equal", + Operator::NotEq => "not_equal", + Operator::Lt => "lt", + Operator::LtEq => "lte", + Operator::Gt => "gt", + Operator::GtEq => "gte", + Operator::Plus => "plus", + Operator::Minus => "minus", + Operator::Multiply => "multiply", + Operator::Divide => "divide", + Operator::Modulo => "modulo", + Operator::And => "and", + Operator::Or => "or", + Operator::Like => "like", + Operator::NotLike => "not_like", + Operator::IsDistinctFrom => "is_distinct_from", + Operator::IsNotDistinctFrom => "is_not_distinct_from", + Operator::RegexMatch => "regex_match", + Operator::RegexIMatch => "regex_i_match", + Operator::RegexNotMatch => "regex_not_match", + Operator::RegexNotIMatch => "regex_not_i_match", + Operator::BitwiseAnd => "bitwise_and", + Operator::BitwiseOr => "bitwise_or", + } + } + + /// Convert list of [Expression] to [FunctionArgument] vector.
+ pub(crate) fn expression_to_argument<I: IntoIterator<Item = Expression>>( + expressions: I, + ) -> Vec<FunctionArgument> { + expressions + .into_iter() + .map(|expr| FunctionArgument { + arg_type: Some(ArgType::Value(expr)), + }) + .collect() + } + + /// Convenient builder for [Expression] + pub(crate) fn build_scalar_function_expression( + function_reference: u32, + arguments: Vec<FunctionArgument>, + ) -> Expression { + Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference, + arguments, + output_type: None, + ..Default::default() + })), + } + } + + pub(crate) fn name_builtin_scalar_function(fun: &BuiltinScalarFunction) -> &str { + match fun { + BuiltinScalarFunction::Abs => "abs", + BuiltinScalarFunction::Acos => "acos", + BuiltinScalarFunction::Asin => "asin", + BuiltinScalarFunction::Atan => "atan", + BuiltinScalarFunction::Ceil => "ceil", + BuiltinScalarFunction::Cos => "cos", + BuiltinScalarFunction::Digest => "digest", + BuiltinScalarFunction::Exp => "exp", + BuiltinScalarFunction::Floor => "floor", + BuiltinScalarFunction::Ln => "ln", + BuiltinScalarFunction::Log => "log", + BuiltinScalarFunction::Log10 => "log10", + BuiltinScalarFunction::Log2 => "log2", + BuiltinScalarFunction::Round => "round", + BuiltinScalarFunction::Signum => "signum", + BuiltinScalarFunction::Sin => "sin", + BuiltinScalarFunction::Sqrt => "sqrt", + BuiltinScalarFunction::Tan => "tan", + BuiltinScalarFunction::Trunc => "trunc", + BuiltinScalarFunction::Array => "make_array", + BuiltinScalarFunction::Ascii => "ascii", + BuiltinScalarFunction::BitLength => "bit_length", + BuiltinScalarFunction::Btrim => "btrim", + BuiltinScalarFunction::CharacterLength => "character_length", + BuiltinScalarFunction::Chr => "chr", + BuiltinScalarFunction::Concat => "concat", + BuiltinScalarFunction::ConcatWithSeparator => "concat_ws", + BuiltinScalarFunction::DatePart => "date_part", + BuiltinScalarFunction::DateTrunc => "date_trunc", + BuiltinScalarFunction::InitCap => "initcap", + BuiltinScalarFunction::Left => "left", + BuiltinScalarFunction::Lpad => "lpad", + BuiltinScalarFunction::Lower => "lower", + BuiltinScalarFunction::Ltrim => "ltrim", + BuiltinScalarFunction::MD5 => "md5", + BuiltinScalarFunction::NullIf => "nullif", + BuiltinScalarFunction::OctetLength => "octet_length", + BuiltinScalarFunction::Random => "random", + BuiltinScalarFunction::RegexpReplace => "regexp_replace", + BuiltinScalarFunction::Repeat => "repeat", + BuiltinScalarFunction::Replace => "replace", + BuiltinScalarFunction::Reverse => "reverse", + BuiltinScalarFunction::Right => "right", + BuiltinScalarFunction::Rpad => "rpad", + BuiltinScalarFunction::Rtrim => "rtrim", + BuiltinScalarFunction::SHA224 => "sha224", + BuiltinScalarFunction::SHA256 => "sha256", + BuiltinScalarFunction::SHA384 => "sha384", + BuiltinScalarFunction::SHA512 => "sha512", + BuiltinScalarFunction::SplitPart => "split_part", + BuiltinScalarFunction::StartsWith => "starts_with", + BuiltinScalarFunction::Strpos => "strpos", + BuiltinScalarFunction::Substr => "substr", + BuiltinScalarFunction::ToHex => "to_hex", + BuiltinScalarFunction::ToTimestamp => "to_timestamp", + BuiltinScalarFunction::ToTimestampMillis => "to_timestamp_millis", + BuiltinScalarFunction::ToTimestampMicros => "to_timestamp_micros", + BuiltinScalarFunction::ToTimestampSeconds => "to_timestamp_seconds", + BuiltinScalarFunction::Now => "now", + BuiltinScalarFunction::Translate => "translate", + BuiltinScalarFunction::Trim => "trim", + BuiltinScalarFunction::Upper => "upper", + BuiltinScalarFunction::RegexpMatch => "regexp_match", + } + } +}
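A note on keeping these name tables in sync: `name_df_operator` and `name_builtin_scalar_function` feed the encode side, while the big `match fn_name` in `convert_scalar_function` is the decode side; every string the former two can emit must have a matching arm in the latter, or an expression will encode successfully and then fail to decode with `UnsupportedExpr`. The sketch below shows how a name travels through the `ConvertorContext` during a round trip, using only methods that appear in this PR (`register_scalar_fn`, `find_scalar_fn`); their exact return types are assumptions, not part of the diff.

```rust
// Sketch, not part of the PR. Assumes `register_scalar_fn` returns the u32
// anchor it assigned and `find_scalar_fn` resolves an anchor back to the name.
let mut ctx = ConvertorContext::default();

// Encode: `expression_from_df_expr` maps `Operator::GtEq` to "gte" via
// `name_df_operator` and stores the anchor in the substrait `ScalarFunction`.
let anchor = ctx.register_scalar_fn("gte");

// Decode: `convert_scalar_function` resolves the anchor back to "gte" and
// hits the "gt_eq" | "gte" arm, rebuilding the comparison with `Expr::gt_eq`.
assert_eq!(ctx.find_scalar_fn(anchor), Some("gte"));
```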
+#[cfg(test)] +mod test { + use datatypes::schema::ColumnSchema; + + use super::*; + + #[test] + fn expr_round_trip() { + let expr = expr_fn::and( + expr_fn::col("column_a").lt_eq(expr_fn::col("column_b")), + expr_fn::col("column_a").gt(expr_fn::col("column_b")), + ); + + let schema = Schema::new(vec![ + ColumnSchema::new( + "column_a", + datatypes::data_type::ConcreteDataType::int64_datatype(), + true, + ), + ColumnSchema::new( + "column_b", + datatypes::data_type::ConcreteDataType::float64_datatype(), + true, + ), + ]); + + let mut ctx = ConvertorContext::default(); + let substrait_expr = expression_from_df_expr(&mut ctx, &expr, &schema).unwrap(); + let converted_expr = to_df_expr(&ctx, substrait_expr, &schema).unwrap(); + + assert_eq!(expr, converted_expr); + } +} diff --git a/src/common/substrait/src/df_logical.rs b/src/common/substrait/src/df_logical.rs index 6f0573144c..8d53ef1b08 100644 --- a/src/common/substrait/src/df_logical.rs +++ b/src/common/substrait/src/df_logical.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use bytes::{Buf, Bytes, BytesMut}; use catalog::CatalogManagerRef; use common_error::prelude::BoxedError; +use common_telemetry::debug; use datafusion::datasource::TableProvider; use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema}; use datafusion::physical_plan::project_schema; @@ -24,12 +25,15 @@ use prost::Message; use snafu::{ensure, OptionExt, ResultExt}; use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect}; use substrait_proto::protobuf::expression::MaskExpression; +use substrait_proto::protobuf::extensions::simple_extension_declaration::MappingType; use substrait_proto::protobuf::plan_rel::RelType as PlanRelType; use substrait_proto::protobuf::read_rel::{NamedTable, ReadType}; use substrait_proto::protobuf::rel::RelType; -use substrait_proto::protobuf::{PlanRel, ReadRel, Rel}; +use substrait_proto::protobuf::{Plan, PlanRel, ReadRel, Rel}; use table::table::adapter::DfTableProviderAdapter; +use crate::context::ConvertorContext; +use crate::df_expr::{expression_from_df_expr, to_df_expr}; use crate::error::{ DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu, InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu, @@ -48,25 +52,15 @@ impl SubstraitPlan for DFLogicalSubstraitConvertor { type Plan = LogicalPlan; fn decode<B: Buf + Send>(&self, message: B) -> Result<Self::Plan> { - let plan_rel = PlanRel::decode(message).context(DecodeRelSnafu)?; - let rel = match plan_rel.rel_type.context(EmptyPlanSnafu)?
{ - PlanRelType::Rel(rel) => rel, - PlanRelType::Root(_) => UnsupportedPlanSnafu { - name: "Root Relation", - } - .fail()?, - }; - self.convert_rel(rel) + let plan = Plan::decode(message).context(DecodeRelSnafu)?; + self.convert_plan(plan) } fn encode(&self, plan: Self::Plan) -> Result<Bytes> { - let rel = self.convert_plan(plan)?; - let plan_rel = PlanRel { - rel_type: Some(PlanRelType::Rel(rel)), - }; + let plan = self.convert_df_plan(plan)?; let mut buf = BytesMut::new(); - plan_rel.encode(&mut buf).context(EncodeRelSnafu)?; + plan.encode(&mut buf).context(EncodeRelSnafu)?; Ok(buf.freeze()) } @@ -79,10 +73,37 @@ impl DFLogicalSubstraitConvertor { } impl DFLogicalSubstraitConvertor { - pub fn convert_rel(&self, rel: Rel) -> Result<LogicalPlan> { + pub fn convert_plan(&self, mut plan: Plan) -> Result<LogicalPlan> { + // prepare convertor context + let mut ctx = ConvertorContext::default(); + for simple_ext in plan.extensions { + if let Some(MappingType::ExtensionFunction(function_extension)) = + simple_ext.mapping_type + { + ctx.register_scalar_with_anchor( + function_extension.name, + function_extension.function_anchor, + ); + } else { + debug!("Encountered unsupported substrait extension {:?}", simple_ext); + } + } + + // extract rel + let rel = if let Some(PlanRel { rel_type }) = plan.relations.pop() + && let Some(PlanRelType::Rel(rel)) = rel_type { + rel + } else { + UnsupportedPlanSnafu { + name: "Empty or non-Rel relation", + } + .fail()? + }; let rel_type = rel.rel_type.context(EmptyPlanSnafu)?; + + // build logical plan + let logical_plan = match rel_type { - RelType::Read(read_rel) => self.convert_read_rel(read_rel), + RelType::Read(read_rel) => self.convert_read_rel(&mut ctx, read_rel), RelType::Filter(_filter_rel) => UnsupportedPlanSnafu { name: "Filter Relation", } @@ -132,9 +153,12 @@ impl DFLogicalSubstraitConvertor { Ok(logical_plan) } - fn convert_read_rel(&self, read_rel: Box<ReadRel>) -> Result<LogicalPlan> { + fn convert_read_rel( + &self, + ctx: &mut ConvertorContext, + read_rel: Box<ReadRel>, + ) -> Result<LogicalPlan> { // Extract the catalog, schema and table name from NamedTable. Assume the first three are those names. - let read_type = read_rel.read_type.context(MissingFieldSnafu { field: "read_type", plan: "Read", @@ -190,6 +214,13 @@ impl DFLogicalSubstraitConvertor { } ); + // Convert filter + let filters = if let Some(filter) = read_rel.filter { + vec![to_df_expr(ctx, *filter, &retrieved_schema)?] + } else { + vec![] + }; + // Calculate the projected schema let projected_schema = project_schema(&stored_schema, projection.as_ref()) .context(DFInternalSnafu)?
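The filter handling in `convert_read_rel` above, together with `convert_table_scan_plan` in the next hunk, makes the round trip semantically lossless but not structurally identical: encoding folds the `TableScan`'s list of filters into one conjunction, and decoding wraps the single substrait filter back into a one-element list. A minimal sketch of the fold follows, assuming the `col` helper from `expr_fn` (as used in the tests above) and a `lit` literal constructor; the import paths are assumptions for this DataFusion version.

```rust
use datafusion_expr::expr_fn::col; // assumed import path
use datafusion_expr::lit; // assumed literal constructor

// The same fold `convert_table_scan_plan` applies to `table_scan.filters`:
let filters = vec![col("a").gt(lit(1)), col("b").lt_eq(lit(2))];
let conjunction = filters.into_iter().reduce(|accum, expr| accum.and(expr));

// `[a > 1, b <= 2]` becomes `Some(a > 1 AND b <= 2)`; an empty filter list
// stays `None`, which is encoded as `filter: None` in the `ReadRel`.
assert!(conjunction.is_some());
```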
@@ -202,7 +233,7 @@ impl DFLogicalSubstraitConvertor { source: adapter, projection, projected_schema, - filters: vec![], + filters, limit: None, })) } @@ -219,8 +250,12 @@ impl DFLogicalSubstraitConvertor { } impl DFLogicalSubstraitConvertor { - pub fn convert_plan(&self, plan: LogicalPlan) -> Result<Rel> { - match plan { + pub fn convert_df_plan(&self, plan: LogicalPlan) -> Result<Plan> { + let mut ctx = ConvertorContext::default(); + + // TODO(ruihang): extract this translation logic into a separate function + // convert PlanRel + let rel = match plan { LogicalPlan::Projection(_) => UnsupportedPlanSnafu { name: "DataFusion Logical Projection", } .fail()?, @@ -258,10 +293,10 @@ impl DFLogicalSubstraitConvertor { } .fail()?, LogicalPlan::TableScan(table_scan) => { - let read_rel = self.convert_table_scan_plan(table_scan)?; - Ok(Rel { + let read_rel = self.convert_table_scan_plan(&mut ctx, table_scan)?; + Rel { rel_type: Some(RelType::Read(Box::new(read_rel))), - }) + } } LogicalPlan::EmptyRelation(_) => UnsupportedPlanSnafu { name: "DataFusion Logical EmptyRelation", @@ -284,10 +319,30 @@ impl DFLogicalSubstraitConvertor { ), } .fail()?, - } + }; + + // convert extension + let extensions = ctx.generate_function_extension(); + + // assemble PlanRel + let plan_rel = PlanRel { + rel_type: Some(PlanRelType::Rel(rel)), + }; + + Ok(Plan { + extension_uris: vec![], + extensions, + relations: vec![plan_rel], + advanced_extensions: None, + expected_type_urls: vec![], + }) } - pub fn convert_table_scan_plan(&self, table_scan: TableScan) -> Result<ReadRel> { + pub fn convert_table_scan_plan( + &self, + ctx: &mut ConvertorContext, + table_scan: TableScan, + ) -> Result<ReadRel> { let provider = table_scan .source .as_any() @@ -313,10 +368,25 @@ impl DFLogicalSubstraitConvertor { // assemble base (unprojected) schema using Table's schema. let base_schema = from_schema(&provider.table().schema())?; + // make conjunction over a list of filters and convert the result to substrait + let filter = if let Some(conjunction) = table_scan + .filters + .into_iter() + .reduce(|accum, expr| accum.and(expr)) + { + Some(Box::new(expression_from_df_expr( + ctx, + &conjunction, + &provider.table().schema(), + )?)) + } else { + None + }; + let read_rel = ReadRel { common: None, base_schema: Some(base_schema), - filter: None, + filter, projection, advanced_extension: None, read_type: Some(read_type), diff --git a/src/common/substrait/src/error.rs b/src/common/substrait/src/error.rs index 74e2112a91..c33b3679fb 100644 --- a/src/common/substrait/src/error.rs +++ b/src/common/substrait/src/error.rs @@ -23,10 +23,10 @@ use snafu::{Backtrace, ErrorCompat, Snafu}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] pub enum Error { - #[snafu(display("Unsupported physical expr: {}", name))] + #[snafu(display("Unsupported physical plan: {}", name))] UnsupportedPlan { name: String, backtrace: Backtrace }, - #[snafu(display("Unsupported physical plan: {}", name))] + #[snafu(display("Unsupported expr: {}", name))] UnsupportedExpr { name: String, backtrace: Backtrace }, #[snafu(display("Unsupported concrete type: {:?}", ty))] diff --git a/src/common/substrait/src/lib.rs b/src/common/substrait/src/lib.rs index 137796b527..c318799a3b 100644 --- a/src/common/substrait/src/lib.rs +++ b/src/common/substrait/src/lib.rs @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License.
+#![feature(let_chains)] + +mod context; +mod df_expr; mod df_logical; pub mod error; mod schema; diff --git a/src/common/time/src/timestamp.rs b/src/common/time/src/timestamp.rs index fd0f148d96..5ff20f702b 100644 --- a/src/common/time/src/timestamp.rs +++ b/src/common/time/src/timestamp.rs @@ -147,6 +147,18 @@ impl From<i64> for Timestamp { } } +impl From<Timestamp> for i64 { + fn from(t: Timestamp) -> Self { + t.value + } +} + +impl From<Timestamp> for serde_json::Value { + fn from(d: Timestamp) -> Self { + serde_json::Value::String(d.to_iso8601_string()) + } +} + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum TimeUnit { Second, @@ -197,6 +209,7 @@ impl Hash for Timestamp { #[cfg(test)] mod tests { use chrono::Offset; + use serde_json::Value; use super::*; @@ -318,4 +331,39 @@ mod tests { let ts = Timestamp::from_millis(ts_millis); assert_eq!("1969-12-31 23:59:58.999+0000", ts.to_iso8601_string()); } + + #[test] + fn test_serialize_to_json_value() { + assert_eq!( + "1970-01-01 00:00:01+0000", + match serde_json::Value::from(Timestamp::new(1, TimeUnit::Second)) { + Value::String(s) => s, + _ => unreachable!(), + } + ); + + assert_eq!( + "1970-01-01 00:00:00.001+0000", + match serde_json::Value::from(Timestamp::new(1, TimeUnit::Millisecond)) { + Value::String(s) => s, + _ => unreachable!(), + } + ); + + assert_eq!( + "1970-01-01 00:00:00.000001+0000", + match serde_json::Value::from(Timestamp::new(1, TimeUnit::Microsecond)) { + Value::String(s) => s, + _ => unreachable!(), + } + ); + + assert_eq!( + "1970-01-01 00:00:00.000000001+0000", + match serde_json::Value::from(Timestamp::new(1, TimeUnit::Nanosecond)) { + Value::String(s) => s, + _ => unreachable!(), + } + ); + } } diff --git a/src/datanode/Cargo.toml b/src/datanode/Cargo.toml index a879c62945..56a9ce25f1 100644 --- a/src/datanode/Cargo.toml +++ b/src/datanode/Cargo.toml @@ -11,13 +11,15 @@ python = ["dep:script"] [dependencies] api = { path = "../api" } async-trait = "0.1" -axum = "0.6.0-rc.2" -axum-macros = "0.3.0-rc.1" +axum = "0.6" +axum-macros = "0.3" +backon = "0.2" catalog = { path = "../catalog" } common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } common-error = { path = "../common/error" } common-grpc = { path = "../common/grpc" } +common-grpc-expr = { path = "../common/grpc-expr" } common-query = { path = "../common/query" } common-recordbatch = { path = "../common/recordbatch" } common-runtime = { path = "../common/runtime" } @@ -26,36 +28,39 @@ common-time = { path = "../common/time" } common-insert = { path = "../common/insert" } datafusion = "14.0.0" datatypes = { path = "../datatypes" } +frontend = { path = "../frontend" } futures = "0.3" hyper = { version = "0.14", features = ["full"] } log-store = { path = "../log-store" } meta-client = { path = "../meta-client" } meta-srv = { path = "../meta-srv", features = ["mock"] } metrics = "0.20" +mito = { path = "../mito", features = ["test"] } object-store = { path = "../object-store" } query = { path = "../query" } script = { path = "../script", features = ["python"], optional = true } serde = "1.0" serde_json = "1.0" servers = { path = "../servers" } +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } storage = { path = "../storage" } store-api = { path = "../store-api" } substrait = { path = "../common/substrait" } table = { path = "../table" } -mito = { path = "../mito", features = ["test"] } tokio = { version = "1.18", features = ["full"] } tokio-stream
= { version = "0.1", features = ["net"] } tonic = "0.8" tower = { version = "0.4", features = ["full"] } tower-http = { version = "0.3", features = ["full"] } -frontend = { path = "../frontend" } [dev-dependencies] -axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } +tempdir = "0.3"axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } client = { path = "../client" } common-query = { path = "../common/query" } -datafusion = "14.0.0" datafusion-common = "14.0.0" -tempdir = "0.3" +datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [ + "simd", +] } +datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" } diff --git a/src/datanode/src/datanode.rs b/src/datanode/src/datanode.rs index 89cb34eda2..ccc8b0d3c6 100644 --- a/src/datanode/src/datanode.rs +++ b/src/datanode/src/datanode.rs @@ -26,7 +26,15 @@ use crate::server::Services; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum ObjectStoreConfig { - File { data_dir: String }, + File { + data_dir: String, + }, + S3 { + bucket: String, + root: String, + access_key_id: String, + secret_access_key: String, + }, } impl Default for ObjectStoreConfig { @@ -47,6 +55,7 @@ pub struct DatanodeOptions { pub meta_client_opts: Option, pub wal_dir: String, pub storage: ObjectStoreConfig, + pub enable_memory_catalog: bool, pub mode: Mode, } @@ -61,6 +70,7 @@ impl Default for DatanodeOptions { meta_client_opts: None, wal_dir: "/tmp/greptimedb/wal".to_string(), storage: ObjectStoreConfig::default(), + enable_memory_catalog: false, mode: Mode::Standalone, } } @@ -86,9 +96,18 @@ impl Datanode { pub async fn start(&mut self) -> Result<()> { info!("Starting datanode instance..."); - self.instance.start().await?; - self.services.start(&self.opts).await?; - Ok(()) + self.start_instance().await?; + self.start_services().await + } + + /// Start only the internal component of datanode. + pub async fn start_instance(&mut self) -> Result<()> { + self.instance.start().await + } + + /// Start services of datanode. This method call will block until services are shutdown. + pub async fn start_services(&mut self) -> Result<()> { + self.services.start(&self.opts).await } pub fn get_instance(&self) -> InstanceRef { diff --git a/src/datanode/src/error.rs b/src/datanode/src/error.rs index a6ecd963a4..fa5fb8c4b4 100644 --- a/src/datanode/src/error.rs +++ b/src/datanode/src/error.rs @@ -18,6 +18,8 @@ use common_error::prelude::*; use storage::error::Error as StorageError; use table::error::Error as TableError; +use crate::datanode::ObjectStoreConfig; + /// Business error of datanode. 
#[derive(Debug, Snafu)] #[snafu(visibility(pub))] @@ -73,6 +75,13 @@ pub enum Error { source: TableError, }, + #[snafu(display("Failed to drop table {}, source: {}", table_name, source))] + DropTable { + table_name: String, + #[snafu(backtrace)] + source: BoxedError, + }, + #[snafu(display("Table not found: {}", table_name))] TableNotFound { table_name: String }, @@ -82,9 +91,6 @@ pub enum Error { table_name: String, }, - #[snafu(display("Missing required field in protobuf, field: {}", field))] - MissingField { field: String, backtrace: Backtrace }, - #[snafu(display("Missing timestamp column in request"))] MissingTimestampColumn { backtrace: Backtrace }, @@ -138,10 +144,10 @@ pub enum Error { #[snafu(display("Failed to storage engine, source: {}", source))] OpenStorageEngine { source: StorageError }, - #[snafu(display("Failed to init backend, dir: {}, source: {}", dir, source))] + #[snafu(display("Failed to init backend, config: {:#?}, source: {}", config, source))] InitBackend { - dir: String, - source: std::io::Error, + config: ObjectStoreConfig, + source: object_store::Error, backtrace: Backtrace, }, @@ -202,21 +208,16 @@ pub enum Error { source: common_grpc::Error, }, - #[snafu(display("Column datatype error, source: {}", source))] - ColumnDataType { + #[snafu(display("Failed to convert alter expr to request: {}", source))] + AlterExprToRequest { #[snafu(backtrace)] - source: api::error::Error, + source: common_grpc_expr::error::Error, }, - #[snafu(display( - "Invalid column proto definition, column: {}, source: {}", - column, - source - ))] - InvalidColumnDef { - column: String, + #[snafu(display("Failed to convert create expr to request: {}", source))] + CreateExprToRequest { #[snafu(backtrace)] - source: api::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Failed to parse SQL, source: {}", source))] @@ -263,7 +264,7 @@ pub enum Error { #[snafu(display("Failed to insert data, source: {}", source))] InsertData { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Insert batch is empty"))] @@ -306,6 +307,7 @@ impl ErrorExt for Error { Error::CreateTable { source, .. } | Error::GetTable { source, .. } | Error::AlterTable { source, .. } => source.status_code(), + Error::DropTable { source, .. } => source.status_code(), Error::Insert { source, .. } => source.status_code(), @@ -316,6 +318,8 @@ impl ErrorExt for Error { source.status_code() } + Error::AlterExprToRequest { source, .. } + | Error::CreateExprToRequest { source, .. } => source.status_code(), Error::CreateSchema { source, .. } | Error::ConvertSchema { source, .. } | Error::VectorComputation { source } => source.status_code(), @@ -324,7 +328,6 @@ impl ErrorExt for Error { | Error::InvalidSql { .. } | Error::KeyColumnNotFound { .. } | Error::InvalidPrimaryKey { .. } - | Error::MissingField { .. } | Error::MissingTimestampColumn { .. } | Error::CatalogNotFound { .. } | Error::SchemaNotFound { .. } @@ -343,10 +346,6 @@ impl ErrorExt for Error { | Error::UnsupportedExpr { .. } | Error::Catalog { .. } => StatusCode::Internal, - Error::ColumnDataType { source } | Error::InvalidColumnDef { source, .. } => { - source.status_code() - } - Error::InitBackend { .. 
} => StatusCode::StorageUnavailable, Error::OpenLogStore { source } => source.status_code(), Error::StartScriptManager { source } => source.status_code(), diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs index b5c0b028e2..27cd13e12e 100644 --- a/src/datanode/src/instance.rs +++ b/src/datanode/src/instance.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use std::time::Duration; use std::{fs, path}; +use backon::ExponentialBackoff; use catalog::remote::MetaKvBackend; use catalog::CatalogManagerRef; use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; @@ -26,8 +27,9 @@ use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOpts; use mito::config::EngineConfig as TableEngineConfig; use mito::engine::MitoEngine; -use object_store::layers::LoggingLayer; -use object_store::services::fs::Builder; +use object_store::layers::{LoggingLayer, MetricsLayer, RetryLayer, TracingLayer}; +use object_store::services::fs::Builder as FsBuilder; +use object_store::services::s3::Builder as S3Builder; use object_store::{util, ObjectStore}; use query::query_engine::{QueryEngineFactory, QueryEngineRef}; use servers::Mode; @@ -99,17 +101,29 @@ impl Instance { // create remote catalog manager let (catalog_manager, factory, table_id_provider) = match opts.mode { Mode::Standalone => { - let catalog = Arc::new( - catalog::local::LocalCatalogManager::try_new(table_engine.clone()) - .await - .context(CatalogSnafu)?, - ); - let factory = QueryEngineFactory::new(catalog.clone()); - ( - catalog.clone() as CatalogManagerRef, - factory, - Some(catalog as TableIdProviderRef), - ) + if opts.enable_memory_catalog { + let catalog = Arc::new(catalog::local::MemoryCatalogManager::default()); + let factory = QueryEngineFactory::new(catalog.clone()); + + ( + catalog.clone() as CatalogManagerRef, + factory, + Some(catalog as TableIdProviderRef), + ) + } else { + let catalog = Arc::new( + catalog::local::LocalCatalogManager::try_new(table_engine.clone()) + .await + .context(CatalogSnafu)?, + ); + let factory = QueryEngineFactory::new(catalog.clone()); + + ( + catalog.clone() as CatalogManagerRef, + factory, + Some(catalog as TableIdProviderRef), + ) + } } Mode::Distributed => { @@ -139,7 +153,11 @@ impl Instance { }; Ok(Self { query_engine: query_engine.clone(), - sql_handler: SqlHandler::new(table_engine, catalog_manager.clone()), + sql_handler: SqlHandler::new( + table_engine, + catalog_manager.clone(), + query_engine.clone(), + ), catalog_manager, physical_planner: PhysicalPlanner::new(query_engine), script_executor, @@ -170,24 +188,64 @@ impl Instance { } pub(crate) async fn new_object_store(store_config: &ObjectStoreConfig) -> Result<ObjectStore> { - // TODO(dennis): supports other backend - let data_dir = util::normalize_dir(match store_config { - ObjectStoreConfig::File { data_dir } => data_dir, - }); + let object_store = match store_config { + ObjectStoreConfig::File { data_dir } => new_fs_object_store(data_dir).await, + ObjectStoreConfig::S3 { ..
} => new_s3_object_store(store_config).await, + }; + object_store.map(|object_store| { + object_store + .layer(RetryLayer::new(ExponentialBackoff::default().with_jitter())) + .layer(MetricsLayer) + .layer(LoggingLayer) + .layer(TracingLayer) + }) +} + +pub(crate) async fn new_s3_object_store(store_config: &ObjectStoreConfig) -> Result<ObjectStore> { + let (root, secret_key, key_id, bucket) = match store_config { + ObjectStoreConfig::S3 { + bucket, + root, + access_key_id, + secret_access_key, + } => (root, secret_access_key, access_key_id, bucket), + _ => unreachable!(), + }; + + let root = util::normalize_dir(root); + info!("The s3 storage bucket is: {}, root is: {}", bucket, &root); + + let accessor = S3Builder::default() + .root(&root) + .bucket(bucket) + .access_key_id(key_id) + .secret_access_key(secret_key) + .build() + .with_context(|_| error::InitBackendSnafu { + config: store_config.clone(), + })?; + + Ok(ObjectStore::new(accessor)) +} + +pub(crate) async fn new_fs_object_store(data_dir: &str) -> Result<ObjectStore> { + let data_dir = util::normalize_dir(data_dir); fs::create_dir_all(path::Path::new(&data_dir)) .context(error::CreateDirSnafu { dir: &data_dir })?; + info!("The file storage directory is: {}", &data_dir); - info!("The storage directory is: {}", &data_dir); + let atomic_write_dir = format!("{}/.tmp/", data_dir); - let accessor = Builder::default() + let accessor = FsBuilder::default() .root(&data_dir) + .atomic_write_dir(&atomic_write_dir) .build() - .context(error::InitBackendSnafu { dir: &data_dir })?; + .context(error::InitBackendSnafu { + config: ObjectStoreConfig::File { data_dir }, + })?; - let object_store = ObjectStore::new(accessor).layer(LoggingLayer); // Add logging - - Ok(object_store) + Ok(ObjectStore::new(accessor)) } /// Create metasrv client instance and spawn heartbeat loop. diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index c863743d29..ddc03a6436 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use api::result::{build_err_result, AdminResultBuilder, ObjectResultBuilder}; use api::v1::{ - admin_expr, insert_expr, object_expr, select_expr, AdminExpr, AdminResult, CreateDatabaseExpr, + admin_expr, object_expr, select_expr, AdminExpr, AdminResult, Column, CreateDatabaseExpr, ObjectExpr, ObjectResult, SelectExpr, }; use async_trait::async_trait; @@ -22,10 +24,11 @@ use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_grpc::select::to_object_result; -use common_insert::insertion_expr_to_request; +use common_grpc_expr::insertion_expr_to_request; use common_query::Output; use query::plan::LogicalPlan; use servers::query_handler::{GrpcAdminHandler, GrpcQueryHandler}; +use session::context::QueryContext; use snafu::prelude::*; use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan}; use table::requests::CreateDatabaseRequest; @@ -44,7 +47,7 @@ impl Instance { catalog_name: &str, schema_name: &str, table_name: &str, - values: insert_expr::Values, + insert_batches: Vec<(Vec<Column>, u32)>, ) -> Result<Output> { let schema_provider = self .catalog_manager .catalog(catalog_name) .context(CatalogSnafu)?
.context(SchemaNotFoundSnafu { name: schema_name })?; - let insert_batches = - common_insert::insert_batches(&values.values).context(InsertDataSnafu)?; - ensure!(!insert_batches.is_empty(), EmptyInsertBatchSnafu); - let table = schema_provider .table(table_name) .context(CatalogSnafu)? @@ -87,10 +86,10 @@ impl Instance { catalog_name: &str, schema_name: &str, table_name: &str, - values: insert_expr::Values, + insert_batches: Vec<(Vec<Column>, u32)>, ) -> ObjectResult { match self - .execute_grpc_insert(catalog_name, schema_name, table_name, values) + .execute_grpc_insert(catalog_name, schema_name, table_name, insert_batches) .await { Ok(Output::AffectedRows(rows)) => ObjectResultBuilder::new() @@ -114,7 +113,9 @@ impl Instance { async fn do_handle_select(&self, select_expr: SelectExpr) -> Result<Output> { let expr = select_expr.expr; match expr { - Some(select_expr::Expr::Sql(sql)) => self.execute_sql(&sql).await, + Some(select_expr::Expr::Sql(sql)) => { + self.execute_sql(&sql, Arc::new(QueryContext::new())).await + } Some(select_expr::Expr::LogicalPlan(plan)) => self.execute_logical(plan).await, Some(select_expr::Expr::PhysicalPlan(api::v1::PhysicalPlan { original_ql, plan })) => { self.physical_planner @@ -170,25 +171,13 @@ impl GrpcQueryHandler for Instance { let catalog_name = DEFAULT_CATALOG_NAME; let schema_name = &insert_expr.schema_name; let table_name = &insert_expr.table_name; - let expr = insert_expr - .expr - .context(servers::error::InvalidQuerySnafu { - reason: "missing `expr` in `InsertExpr`", - })?; // TODO(fys): _region_number is for later use. let _region_number: u32 = insert_expr.region_number; - match expr { - insert_expr::Expr::Values(values) => { - self.handle_insert(catalog_name, schema_name, table_name, values) - .await - } - insert_expr::Expr::Sql(sql) => { - let output = self.execute_sql(&sql).await; - to_object_result(output).await - } - } + let insert_batches = vec![(insert_expr.columns, insert_expr.row_count)]; + self.handle_insert(catalog_name, schema_name, table_name, insert_batches) + .await } Some(object_expr::Expr::Select(select_expr)) => self.handle_select(select_expr).await, other => { @@ -211,6 +200,9 @@ impl GrpcAdminHandler for Instance { Some(admin_expr::Expr::CreateDatabase(create_database_expr)) => { self.execute_create_database(create_database_expr).await } + Some(admin_expr::Expr::DropTable(drop_table_expr)) => { + self.handle_drop_table(drop_table_expr).await + } other => { return servers::error::NotSupportedSnafu { feat: format!("{:?}", other), diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index ba1793e8f8..80149dda5c 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -13,25 +13,27 @@ // limitations under the License.
use async_trait::async_trait; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::prelude::BoxedError; use common_query::Output; +use common_recordbatch::RecordBatches; use common_telemetry::logging::{error, info}; use common_telemetry::timer; use servers::query_handler::SqlQueryHandler; +use session::context::QueryContextRef; use snafu::prelude::*; +use sql::ast::ObjectName; use sql::statements::statement::Statement; +use table::engine::TableReference; use table::requests::CreateDatabaseRequest; -use crate::error::{ - BumpTableIdSnafu, CatalogNotFoundSnafu, CatalogSnafu, ExecuteSqlSnafu, ParseSqlSnafu, Result, - SchemaNotFoundSnafu, TableIdProviderNotFoundSnafu, -}; +use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu}; use crate::instance::Instance; use crate::metric; use crate::sql::SqlRequest; impl Instance { - pub async fn execute_sql(&self, sql: &str) -> Result<Output> { + pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result<Output> { let stmt = self .query_engine .sql_to_statement(sql) @@ -41,7 +43,7 @@ impl Instance { Statement::Query(_) => { let logical_plan = self .query_engine - .statement_to_plan(stmt) + .statement_to_plan(stmt, query_ctx) .context(ExecuteSqlSnafu)?; self.query_engine @@ -50,20 +52,15 @@ impl Instance { .context(ExecuteSqlSnafu) } Statement::Insert(i) => { - let (catalog_name, schema_name, _table_name) = - i.full_table_name().context(ParseSqlSnafu)?; - - let schema_provider = self - .catalog_manager - .catalog(&catalog_name) - .context(CatalogSnafu)? - .context(CatalogNotFoundSnafu { name: catalog_name })? - .schema(&schema_name) - .context(CatalogSnafu)? - .context(SchemaNotFoundSnafu { name: schema_name })?; - - let request = self.sql_handler.insert_to_request(schema_provider, *i)?; - self.sql_handler.execute(request).await + let (catalog, schema, table) = + table_idents_to_full_name(i.table_name(), query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let request = self.sql_handler.insert_to_request( + self.catalog_manager.clone(), + *i, + table_ref, + )?; + self.sql_handler.execute(request, query_ctx).await } Statement::CreateDatabase(c) => { @@ -74,7 +71,7 @@ impl Instance { info!("Creating a new database: {}", request.db_name); self.sql_handler - .execute(SqlRequest::CreateDatabase(request)) + .execute(SqlRequest::CreateDatabase(request), query_ctx) .await } @@ -89,49 +86,116 @@ impl Instance { let _engine_name = c.engine.clone(); // TODO(hl): Select table engine by engine_name - let request = self.sql_handler.create_to_request(table_id, c)?; - let catalog_name = &request.catalog_name; - let schema_name = &request.schema_name; - let table_name = &request.table_name; + let name = c.name.clone(); + let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let request = self.sql_handler.create_to_request(table_id, c, table_ref)?; let table_id = request.id; info!( "Creating table, catalog: {:?}, schema: {:?}, table name: {:?}, table id: {}", - catalog_name, schema_name, table_name, table_id + catalog, schema, table, table_id ); self.sql_handler - .execute(SqlRequest::CreateTable(request)) + .execute(SqlRequest::CreateTable(request), query_ctx) .await } Statement::Alter(alter_table) => { - let req = self.sql_handler.alter_to_request(alter_table)?; - self.sql_handler.execute(SqlRequest::Alter(req)).await + let name =
alter_table.table_name().clone(); + let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?; + let table_ref = TableReference::full(&catalog, &schema, &table); + let req = self.sql_handler.alter_to_request(alter_table, table_ref)?; + self.sql_handler + .execute(SqlRequest::Alter(req), query_ctx) + .await + } + Statement::DropTable(drop_table) => { + let req = self.sql_handler.drop_table_to_request(drop_table); + self.sql_handler + .execute(SqlRequest::DropTable(req), query_ctx) + .await } Statement::ShowDatabases(stmt) => { self.sql_handler - .execute(SqlRequest::ShowDatabases(stmt)) + .execute(SqlRequest::ShowDatabases(stmt), query_ctx) .await } Statement::ShowTables(stmt) => { - self.sql_handler.execute(SqlRequest::ShowTables(stmt)).await + self.sql_handler + .execute(SqlRequest::ShowTables(stmt), query_ctx) + .await + } + Statement::Explain(stmt) => { + self.sql_handler + .execute(SqlRequest::Explain(Box::new(stmt)), query_ctx) + .await } Statement::DescribeTable(stmt) => { self.sql_handler - .execute(SqlRequest::DescribeTable(stmt)) + .execute(SqlRequest::DescribeTable(stmt), query_ctx) .await } Statement::ShowCreateTable(_stmt) => { unimplemented!("SHOW CREATE TABLE is unimplemented yet"); } + Statement::Use(db) => { + ensure!( + self.catalog_manager + .schema(DEFAULT_CATALOG_NAME, &db) + .context(error::CatalogSnafu)? + .is_some(), + error::SchemaNotFoundSnafu { name: &db } + ); + + query_ctx.set_current_schema(&db); + + Ok(Output::RecordBatches(RecordBatches::empty())) + } } } } +// TODO(LFC): Refactor consideration: move this function to some helper mod, +// could be done together or after `TableReference`'s refactoring, when issue #559 is resolved. +/// Converts a maybe fully-qualified table name (`<catalog>.<schema>.<table>`) to a tuple. +fn table_idents_to_full_name( + obj_name: &ObjectName, + query_ctx: QueryContextRef, +) -> Result<(String, String, String)> { + match &obj_name.0[..] { + [table] => Ok(( + DEFAULT_CATALOG_NAME.to_string(), + query_ctx.current_schema().unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()), + table.value.clone(), + )), + [schema, table] => Ok(( + DEFAULT_CATALOG_NAME.to_string(), + schema.value.clone(), + table.value.clone(), + )), + [catalog, schema, table] => Ok(( + catalog.value.clone(), + schema.value.clone(), + table.value.clone(), + )), + _ => error::InvalidSqlSnafu { + msg: format!( + "expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {}", + obj_name + ), + }.fail(), + } } + #[async_trait] impl SqlQueryHandler for Instance { - async fn do_query(&self, query: &str) -> servers::error::Result<Output> { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> servers::error::Result<Output> { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_sql(query) + self.execute_sql(query, query_ctx) .await .map_err(|e| { error!(e; "Instance failed to execute sql"); @@ -140,3 +204,78 @@ impl SqlQueryHandler for Instance { .context(servers::error::ExecuteQuerySnafu { query }) } } + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use session::context::QueryContext; + + use super::*; + + #[test] + fn test_table_idents_to_full_name() { + let my_catalog = "my_catalog"; + let my_schema = "my_schema"; + let my_table = "my_table"; + + let full = ObjectName(vec![my_catalog.into(), my_schema.into(), my_table.into()]); + let partial = ObjectName(vec![my_schema.into(), my_table.into()]); + let bare = ObjectName(vec![my_table.into()]); + + let using_schema = "foo"; + let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string())); + let empty_ctx = Arc::new(QueryContext::new()); + + assert_eq!( + table_idents_to_full_name(&full, query_ctx.clone()).unwrap(), + ( + my_catalog.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&full, empty_ctx.clone()).unwrap(), + ( + my_catalog.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + + assert_eq!( + table_idents_to_full_name(&partial, query_ctx.clone()).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&partial, empty_ctx.clone()).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + my_schema.to_string(), + my_table.to_string() + ) + ); + + assert_eq!( + table_idents_to_full_name(&bare, query_ctx).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + using_schema.to_string(), + my_table.to_string() + ) + ); + assert_eq!( + table_idents_to_full_name(&bare, empty_ctx).unwrap(), + ( + DEFAULT_CATALOG_NAME.to_string(), + DEFAULT_SCHEMA_NAME.to_string(), + my_table.to_string() + ) + ); + } +} diff --git a/src/datanode/src/lib.rs b/src/datanode/src/lib.rs index 8540c851f6..3e1aa92a76 100644 --- a/src/datanode/src/lib.rs +++ b/src/datanode/src/lib.rs @@ -22,6 +22,6 @@ mod metric; mod mock; mod script; pub mod server; -mod sql; +pub mod sql; #[cfg(test)] mod tests; diff --git a/src/datanode/src/mock.rs b/src/datanode/src/mock.rs index 240b68e04d..73b758cc13 100644 --- a/src/datanode/src/mock.rs +++ b/src/datanode/src/mock.rs @@ -58,7 +58,11 @@ impl Instance { let factory = QueryEngineFactory::new(catalog_manager.clone()); let query_engine = factory.query_engine(); - let sql_handler = SqlHandler::new(mock_engine.clone(), catalog_manager.clone()); + let sql_handler = SqlHandler::new( + mock_engine.clone(), + catalog_manager.clone(), + query_engine.clone(), + ); let physical_planner = PhysicalPlanner::new(query_engine.clone()); let script_executor = ScriptExecutor::new(catalog_manager.clone(), query_engine.clone()) .await @@ -123,7 +127,11 @@ impl Instance { ); Ok(Self { query_engine: query_engine.clone(), - sql_handler: SqlHandler::new(table_engine, catalog_manager.clone()), + sql_handler: SqlHandler::new( + table_engine, + catalog_manager.clone(), + query_engine.clone(), + ), catalog_manager, physical_planner: PhysicalPlanner::new(query_engine), script_executor, diff --git
a/src/datanode/src/server.rs b/src/datanode/src/server.rs index f2cef8ca74..e77e8c20a1 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -62,6 +62,7 @@ impl Services { Some(MysqlServer::create_server( instance.clone(), mysql_io_runtime, + Default::default(), )) } }; diff --git a/src/datanode/src/server/grpc/ddl.rs b/src/datanode/src/server/grpc/ddl.rs index 7a3980c6f6..26108eb020 100644 --- a/src/datanode/src/server/grpc/ddl.rs +++ b/src/datanode/src/server/grpc/ddl.rs @@ -15,19 +15,17 @@ use std::sync::Arc; use api::result::AdminResultBuilder; -use api::v1::alter_expr::Kind; -use api::v1::{AdminResult, AlterExpr, CreateExpr, DropColumns}; -use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use api::v1::{AdminResult, AlterExpr, CreateExpr, DropTableExpr}; use common_error::prelude::{ErrorExt, StatusCode}; +use common_grpc_expr::{alter_expr_to_request, create_expr_to_request}; use common_query::Output; use common_telemetry::{error, info}; -use datatypes::schema::{ColumnSchema, SchemaBuilder, SchemaRef}; use futures::TryFutureExt; +use session::context::QueryContext; use snafu::prelude::*; -use table::metadata::TableId; -use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest, CreateTableRequest}; +use table::requests::DropTableRequest; -use crate::error::{self, BumpTableIdSnafu, MissingFieldSnafu, Result}; +use crate::error::{AlterExprToRequestSnafu, BumpTableIdSnafu, CreateExprToRequestSnafu}; use crate::instance::Instance; use crate::sql::SqlRequest; @@ -75,9 +73,14 @@ impl Instance { } }; - let request = create_expr_to_request(table_id, expr).await; + let request = create_expr_to_request(table_id, expr).context(CreateExprToRequestSnafu); let result = futures::future::ready(request) - .and_then(|request| self.sql_handler().execute(SqlRequest::CreateTable(request))) + .and_then(|request| { + self.sql_handler().execute( + SqlRequest::CreateTable(request), + Arc::new(QueryContext::new()), + ) + }) .await; match result { Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() @@ -94,18 +97,24 @@ impl Instance { } pub(crate) async fn handle_alter(&self, expr: AlterExpr) -> AdminResult { - let request = match alter_expr_to_request(expr).transpose() { - Some(req) => req, + let request = match alter_expr_to_request(expr) + .context(AlterExprToRequestSnafu) + .transpose() + { None => { return AdminResultBuilder::default() .status_code(StatusCode::Success as u32) .mutate_result(0, 0) .build() } + Some(req) => req, }; let result = futures::future::ready(request) - .and_then(|request| self.sql_handler().execute(SqlRequest::Alter(request))) + .and_then(|request| { + self.sql_handler() + .execute(SqlRequest::Alter(request), Arc::new(QueryContext::new())) + }) .await; match result { Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() @@ -119,156 +128,50 @@ impl Instance { .build(), } } -} -async fn create_expr_to_request(table_id: TableId, expr: CreateExpr) -> Result<CreateTableRequest> { - let schema = create_table_schema(&expr)?; - let primary_key_indices = expr - .primary_keys - .iter() - .map(|key| { - schema - .column_index_by_name(key) - .context(error::KeyColumnNotFoundSnafu { name: key }) - }) - .collect::<Result<Vec<_>>>()?; - - let catalog_name = expr - .catalog_name - .unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string()); - let schema_name = expr - .schema_name - .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()); - - let region_ids = if expr.region_ids.is_empty() { - vec![0] - } else { - expr.region_ids - }; - - Ok(CreateTableRequest { - id:
table_id, - catalog_name, - schema_name, - table_name: expr.table_name, - desc: expr.desc, - schema, - region_numbers: region_ids, - primary_key_indices, - create_if_not_exists: expr.create_if_not_exists, - table_options: expr.table_options, - }) -} - -fn alter_expr_to_request(expr: AlterExpr) -> Result<Option<AlterTableRequest>> { - match expr.kind { - Some(Kind::AddColumns(add_columns)) => { - let mut add_column_requests = vec![]; - for add_column_expr in add_columns.add_columns { - let column_def = add_column_expr.column_def.context(MissingFieldSnafu { - field: "column_def", - })?; - - let schema = - column_def - .try_as_column_schema() - .context(error::InvalidColumnDefSnafu { - column: &column_def.name, - })?; - add_column_requests.push(AddColumnRequest { - column_schema: schema, - is_key: add_column_expr.is_key, - }) - } - - let alter_kind = AlterKind::AddColumns { - columns: add_column_requests, - }; - - let request = AlterTableRequest { - catalog_name: expr.catalog_name, - schema_name: expr.schema_name, - table_name: expr.table_name, - alter_kind, - }; - Ok(Some(request)) + pub(crate) async fn handle_drop_table(&self, expr: DropTableExpr) -> AdminResult { + let req = DropTableRequest { + catalog_name: expr.catalog_name, + schema_name: expr.schema_name, + table_name: expr.table_name, + }; + let result = self + .sql_handler() + .execute(SqlRequest::DropTable(req), Arc::new(QueryContext::new())) + .await; + match result { + Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default() + .status_code(StatusCode::Success as u32) + .mutate_result(rows as _, 0) + .build(), + Ok(Output::Stream(_)) | Ok(Output::RecordBatches(_)) => unreachable!(), + Err(err) => AdminResultBuilder::default() + .status_code(err.status_code() as u32) + .err_msg(err.to_string()) + .build(), } - Some(Kind::DropColumns(DropColumns { drop_columns })) => { - let alter_kind = AlterKind::DropColumns { - names: drop_columns.into_iter().map(|c| c.name).collect(), - }; - - let request = AlterTableRequest { - catalog_name: expr.catalog_name, - schema_name: expr.schema_name, - table_name: expr.table_name, - alter_kind, - }; - Ok(Some(request)) - } - None => Ok(None), } } -fn create_table_schema(expr: &CreateExpr) -> Result<SchemaRef> { - let column_schemas = expr - .column_defs - .iter() - .map(|x| { - x.try_as_column_schema() - .context(error::InvalidColumnDefSnafu { column: &x.name }) - }) - .collect::<Result<Vec<_>>>()?; - - ensure!( - column_schemas - .iter() - .any(|column| column.name == expr.time_index), - error::KeyColumnNotFoundSnafu { - name: &expr.time_index, - } - ); - - let column_schemas = column_schemas - .into_iter() - .map(|column_schema| { - if column_schema.name == expr.time_index { - column_schema.with_time_index(true) - } else { - column_schema - } - }) - .collect::<Vec<_>>(); - - Ok(Arc::new( - SchemaBuilder::try_from(column_schemas) - .context(error::CreateSchemaSnafu)?
- .build() - .context(error::CreateSchemaSnafu)?, - )) -} - #[cfg(test)] mod tests { - use api::v1::ColumnDef; + use std::sync::Arc; + + use api::v1::{ColumnDataType, ColumnDef}; use common_catalog::consts::MIN_USER_TABLE_ID; + use common_grpc_expr::create_table_schema; use datatypes::prelude::ConcreteDataType; - use datatypes::schema::ColumnDefaultConstraint; + use datatypes::schema::{ColumnDefaultConstraint, ColumnSchema, SchemaBuilder, SchemaRef}; use datatypes::value::Value; use super::*; - use crate::tests::test_util; #[tokio::test(flavor = "multi_thread")] async fn test_create_expr_to_request() { common_telemetry::init_default_ut_logging(); - let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("create_expr_to_request"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); - let expr = testing_create_expr(); - let request = create_expr_to_request(1024, expr).await.unwrap(); - assert_eq!(request.id, common_catalog::consts::MIN_USER_TABLE_ID); + let request = create_expr_to_request(1024, expr).unwrap(); + assert_eq!(request.id, MIN_USER_TABLE_ID); assert_eq!(request.catalog_name, "greptime".to_string()); assert_eq!(request.schema_name, "public".to_string()); assert_eq!(request.table_name, "my-metrics"); @@ -279,12 +182,13 @@ mod tests { let mut expr = testing_create_expr(); expr.primary_keys = vec!["host".to_string(), "not-exist-column".to_string()]; - let result = create_expr_to_request(1025, expr).await; - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Specified timestamp key or primary key column not found: not-exist-column")); + let result = create_expr_to_request(1025, expr); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Column `not-exist-column` not found in table `my-metrics`"), + "{}", + err_msg + ); } #[test] @@ -295,14 +199,16 @@ mod tests { expr.time_index = "not-exist-column".to_string(); let result = create_table_schema(&expr); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Specified timestamp key or primary key column not found: not-exist-column")); + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Missing timestamp column"), + "actual: {}", + err_msg + ); } #[test] + fn test_create_column_schema() { let column_def = ColumnDef { name: "a".to_string(), @@ -318,7 +224,7 @@ mod tests { let column_def = ColumnDef { name: "a".to_string(), - datatype: 12, // string + datatype: ColumnDataType::String as i32, is_nullable: true, default_constraint: None, }; @@ -330,7 +236,7 @@ mod tests { let default_constraint = ColumnDefaultConstraint::Value(Value::from("default value")); let column_def = ColumnDef { name: "a".to_string(), - datatype: 12, // string + datatype: ColumnDataType::String as i32, is_nullable: true, default_constraint: Some(default_constraint.clone().try_into().unwrap()), }; @@ -348,25 +254,25 @@ mod tests { let column_defs = vec![ ColumnDef { name: "host".to_string(), - datatype: 12, // string + datatype: ColumnDataType::String as i32, is_nullable: false, default_constraint: None, }, ColumnDef { name: "ts".to_string(), - datatype: 15, // timestamp + datatype: ColumnDataType::Timestamp as i32, is_nullable: false, default_constraint: None, }, ColumnDef { name: "cpu".to_string(), - datatype: 9, // float32 + datatype: ColumnDataType::Float32 as i32, is_nullable: true, default_constraint: None, }, ColumnDef { name: "memory".to_string(), - datatype: 10, // float64 + 
datatype: ColumnDataType::Float64 as i32, is_nullable: true, default_constraint: None, }, diff --git a/src/datanode/src/sql.rs b/src/datanode/src/sql.rs index 8f989badef..0a3b4a999e 100644 --- a/src/datanode/src/sql.rs +++ b/src/datanode/src/sql.rs @@ -12,22 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! sql handler - use catalog::CatalogManagerRef; use common_query::Output; -use query::sql::{describe_table, show_databases, show_tables}; +use common_telemetry::error; +use query::query_engine::QueryEngineRef; +use query::sql::{describe_table, explain, show_databases, show_tables}; +use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; use sql::statements::describe::DescribeTable; +use sql::statements::explain::Explain; use sql::statements::show::{ShowDatabases, ShowTables}; use table::engine::{EngineContext, TableEngineRef, TableReference}; use table::requests::*; use table::TableRef; -use crate::error::{self, GetTableSnafu, Result, TableNotFoundSnafu}; +use crate::error::{ExecuteSqlSnafu, GetTableSnafu, Result, TableNotFoundSnafu}; mod alter; mod create; +mod drop_table; mod insert; #[derive(Debug)] @@ -36,41 +39,61 @@ pub enum SqlRequest { CreateTable(CreateTableRequest), CreateDatabase(CreateDatabaseRequest), Alter(AlterTableRequest), + DropTable(DropTableRequest), ShowDatabases(ShowDatabases), ShowTables(ShowTables), DescribeTable(DescribeTable), + Explain(Box<Explain>), } // Handler to execute SQL except query pub struct SqlHandler { table_engine: TableEngineRef, catalog_manager: CatalogManagerRef, + query_engine: QueryEngineRef, } impl SqlHandler { - pub fn new(table_engine: TableEngineRef, catalog_manager: CatalogManagerRef) -> Self { + pub fn new( + table_engine: TableEngineRef, + catalog_manager: CatalogManagerRef, + query_engine: QueryEngineRef, + ) -> Self { Self { table_engine, catalog_manager, + query_engine, } } - pub async fn execute(&self, request: SqlRequest) -> Result<Output> { - match request { + // TODO(LFC): Refactor consideration: a context-aware "Planner". + // We now keep some query-related state (like the currently used database) in the session context; maybe + // we could create a new struct called `Planner` that stores that context and handles these queries + // there, instead of executing them here in a "static" fashion.
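(A minimal, self-contained sketch of the context-aware planner floated in the TODO above; every name here is a hypothetical stand-in, not part of this change:)

struct QueryContext {
    // Stand-in for session state such as the schema selected via `USE`.
    current_schema: String,
}

struct Planner {
    ctx: QueryContext,
}

impl Planner {
    // Qualifies a bare table name with the session's current schema, so request
    // builders would no longer need a context argument threaded through every call.
    fn qualify(&self, table: &str) -> String {
        format!("{}.{}", self.ctx.current_schema, table)
    }
}

fn main() {
    let planner = Planner {
        ctx: QueryContext {
            current_schema: "db1".to_string(),
        },
    };
    assert_eq!("db1.tb1", planner.qualify("tb1"));
}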
+ pub async fn execute(&self, request: SqlRequest, query_ctx: QueryContextRef) -> Result<Output> { + let result = match request { SqlRequest::Insert(req) => self.insert(req).await, SqlRequest::CreateTable(req) => self.create_table(req).await, SqlRequest::CreateDatabase(req) => self.create_database(req).await, SqlRequest::Alter(req) => self.alter(req).await, + SqlRequest::DropTable(req) => self.drop_table(req).await, SqlRequest::ShowDatabases(stmt) => { - show_databases(stmt, self.catalog_manager.clone()).context(error::ExecuteSqlSnafu) + show_databases(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu) } SqlRequest::ShowTables(stmt) => { - show_tables(stmt, self.catalog_manager.clone()).context(error::ExecuteSqlSnafu) + show_tables(stmt, self.catalog_manager.clone(), query_ctx).context(ExecuteSqlSnafu) } SqlRequest::DescribeTable(stmt) => { - describe_table(stmt, self.catalog_manager.clone()).context(error::ExecuteSqlSnafu) + describe_table(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu) } + SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone(), query_ctx) + .await + .context(ExecuteSqlSnafu), + }; + if let Err(e) = &result { + error!("Datanode execution error: {:?}", e); } + result } pub(crate) fn get_table<'a>(&self, table_ref: &'a TableReference) -> Result<TableRef> { @@ -94,7 +117,8 @@ mod tests { use std::any::Any; use std::sync::Arc; - use catalog::SchemaProvider; + use catalog::{CatalogList, SchemaProvider}; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::logical_plan::Expr; use common_query::physical_plan::PhysicalPlanRef; use common_time::timestamp::Timestamp; @@ -214,9 +238,17 @@ mod tests { .await .unwrap(), ); + let catalog_provider = catalog_list.catalog(DEFAULT_CATALOG_NAME).unwrap().unwrap(); + catalog_provider + .register_schema( + DEFAULT_SCHEMA_NAME.to_string(), + Arc::new(MockSchemaProvider {}), + ) + .unwrap(); + let factory = QueryEngineFactory::new(catalog_list.clone()); let query_engine = factory.query_engine(); - let sql_handler = SqlHandler::new(table_engine, catalog_list); + let sql_handler = SqlHandler::new(table_engine, catalog_list.clone(), query_engine.clone()); let stmt = match query_engine.sql_to_statement(sql).unwrap() { Statement::Insert(i) => i, @@ -224,9 +256,8 @@ mod tests { unreachable!() } }; - let schema_provider = Arc::new(MockSchemaProvider {}); let request = sql_handler - .insert_to_request(schema_provider, *stmt) + .insert_to_request(catalog_list.clone(), *stmt, TableReference::bare("demo")) .unwrap(); match request { diff --git a/src/datanode/src/sql/alter.rs b/src/datanode/src/sql/alter.rs index 077ebd0a9c..77fada09fd 100644 --- a/src/datanode/src/sql/alter.rs +++ b/src/datanode/src/sql/alter.rs @@ -16,7 +16,7 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use snafu::prelude::*; use sql::statements::alter::{AlterTable, AlterTableOperation}; -use sql::statements::{column_def_to_schema, table_idents_to_full_name}; +use sql::statements::column_def_to_schema; use table::engine::{EngineContext, TableReference}; use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest}; @@ -53,10 +53,11 @@ impl SqlHandler { Ok(Output::AffectedRows(0)) } - pub(crate) fn alter_to_request(&self, alter_table: AlterTable) -> Result<AlterTableRequest> { - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(alter_table.table_name()).context(error::ParseSqlSnafu)?; - + pub(crate) fn alter_to_request( + &self, + alter_table: AlterTable, +
table_ref: TableReference, + ) -> Result { let alter_kind = match alter_table.alter_operation() { AlterTableOperation::AddConstraint(table_constraint) => { return error::InvalidSqlSnafu { @@ -77,9 +78,9 @@ impl SqlHandler { }, }; Ok(AlterTableRequest { - catalog_name: Some(catalog_name), - schema_name: Some(schema_name), - table_name, + catalog_name: Some(table_ref.catalog.to_string()), + schema_name: Some(table_ref.schema.to_string()), + table_name: table_ref.table.to_string(), alter_kind, }) } @@ -112,7 +113,9 @@ mod tests { async fn test_alter_to_request_with_adding_column() { let handler = create_mock_sql_handler().await; let alter_table = parse_sql("ALTER TABLE my_metric_1 ADD tagk_i STRING Null;"); - let req = handler.alter_to_request(alter_table).unwrap(); + let req = handler + .alter_to_request(alter_table, TableReference::bare("my_metric_1")) + .unwrap(); assert_eq!(req.catalog_name, Some("greptime".to_string())); assert_eq!(req.schema_name, Some("public".to_string())); assert_eq!(req.table_name, "my_metric_1"); diff --git a/src/datanode/src/sql/create.rs b/src/datanode/src/sql/create.rs index d02de70211..8b75bdef3f 100644 --- a/src/datanode/src/sql/create.rs +++ b/src/datanode/src/sql/create.rs @@ -23,10 +23,10 @@ use common_telemetry::tracing::log::error; use datatypes::schema::SchemaBuilder; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::TableConstraint; +use sql::statements::column_def_to_schema; use sql::statements::create::CreateTable; -use sql::statements::{column_def_to_schema, table_idents_to_full_name}; use store_api::storage::consts::TIME_INDEX_NAME; -use table::engine::EngineContext; +use table::engine::{EngineContext, TableReference}; use table::metadata::TableId; use table::requests::*; @@ -84,7 +84,6 @@ impl SqlHandler { // determine catalog and schema from the very beginning let table_name = req.table_name.clone(); - let table_id = req.id; let table = self .table_engine .create_table(&ctx, req) @@ -97,7 +96,7 @@ impl SqlHandler { catalog: table.table_info().catalog_name.clone(), schema: table.table_info().schema_name.clone(), table_name: table_name.clone(), - table_id, + table_id: table.table_info().ident.table_id, table, }; @@ -115,13 +114,11 @@ impl SqlHandler { &self, table_id: TableId, stmt: CreateTable, + table_ref: TableReference, ) -> Result { let mut ts_index = usize::MAX; let mut primary_keys = vec![]; - let (catalog_name, schema_name, table_name) = - table_idents_to_full_name(&stmt.name).context(error::ParseSqlSnafu)?; - let col_map = stmt .columns .iter() @@ -172,7 +169,7 @@ impl SqlHandler { return ConstraintNotSupportedSnafu { constraint: format!("{:?}", c), } - .fail() + .fail(); } } } @@ -186,14 +183,6 @@ impl SqlHandler { ensure!(ts_index != usize::MAX, error::MissingTimestampColumnSnafu); - if primary_keys.is_empty() { - info!( - "Creating table: {:?}.{:?}.{} but primary key not set, use time index column: {}", - catalog_name, schema_name, table_name, ts_index - ); - primary_keys.push(ts_index); - } - let columns_schemas: Vec<_> = stmt .columns .iter() @@ -212,9 +201,9 @@ impl SqlHandler { let request = CreateTableRequest { id: table_id, - catalog_name, - schema_name, - table_name, + catalog_name: table_ref.catalog.to_string(), + schema_name: table_ref.schema.to_string(), + table_name: table_ref.table.to_string(), desc: None, schema, region_numbers: vec![0], @@ -262,7 +251,9 @@ mod tests { TIME INDEX (ts), PRIMARY KEY(host)) engine=mito with(regions=1);"#, ); - let c = handler.create_to_request(42, parsed_stmt).unwrap(); + let c = handler 
+ .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap(); assert_eq!("demo_table", c.table_name); assert_eq!(42, c.id); assert!(!c.create_if_not_exists); @@ -283,11 +274,12 @@ mod tests { memory double, PRIMARY KEY(host)) engine=mito with(regions=1);"#, ); - let error = handler.create_to_request(42, parsed_stmt).unwrap_err(); + let error = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap_err(); assert_matches!(error, Error::MissingTimestampColumn { .. }); } - /// If primary key is not specified, time index should be used as primary key. #[tokio::test] pub async fn test_primary_key_not_specified() { let handler = create_mock_sql_handler().await; @@ -300,12 +292,11 @@ mod tests { memory double, TIME INDEX (ts)) engine=mito with(regions=1);"#, ); - let c = handler.create_to_request(42, parsed_stmt).unwrap(); - assert_eq!(1, c.primary_key_indices.len()); - assert_eq!( - c.schema.timestamp_index().unwrap(), - c.primary_key_indices[0] - ); + let c = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap(); + assert!(c.primary_key_indices.is_empty()); + assert_eq!(c.schema.timestamp_index(), Some(1)); } /// Constraints specified, not column cannot be found. @@ -319,7 +310,9 @@ mod tests { TIME INDEX (ts)) engine=mito with(regions=1);"#, ); - let error = handler.create_to_request(42, parsed_stmt).unwrap_err(); + let error = handler + .create_to_request(42, parsed_stmt, TableReference::bare("demo_table")) + .unwrap_err(); assert_matches!(error, Error::KeyColumnNotFound { .. }); } @@ -339,7 +332,9 @@ mod tests { let handler = create_mock_sql_handler().await; - let error = handler.create_to_request(42, create_table).unwrap_err(); + let error = handler + .create_to_request(42, create_table, TableReference::full("c", "s", "demo")) + .unwrap_err(); assert_matches!(error, Error::InvalidPrimaryKey { .. }); } @@ -359,7 +354,9 @@ mod tests { let handler = create_mock_sql_handler().await; - let request = handler.create_to_request(42, create_table).unwrap(); + let request = handler + .create_to_request(42, create_table, TableReference::full("c", "s", "demo")) + .unwrap(); assert_eq!(42, request.id); assert_eq!("c".to_string(), request.catalog_name); diff --git a/src/datanode/src/sql/drop_table.rs b/src/datanode/src/sql/drop_table.rs new file mode 100644 index 0000000000..4a56b669c9 --- /dev/null +++ b/src/datanode/src/sql/drop_table.rs @@ -0,0 +1,71 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
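(The module body below drops a table in two phases: deregister the name from the catalog so queries stop resolving it, then let the table engine reclaim the storage. A condensed, self-contained sketch of that ordering, with stand-in traits rather than the crate's `CatalogManagerRef` and `TableEngineRef`:)

trait Catalog {
    fn deregister_table(&self, name: &str) -> Result<(), String>;
}

trait Engine {
    fn drop_table(&self, name: &str) -> Result<(), String>;
}

fn drop_table_two_phase(
    catalog: &dyn Catalog,
    engine: &dyn Engine,
    name: &str,
) -> Result<(), String> {
    // Phase 1: hide the table from new queries.
    catalog.deregister_table(name)?;
    // Phase 2: reclaim the underlying storage.
    engine.drop_table(name)
}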
+ +use catalog::DeregisterTableRequest; +use common_error::prelude::BoxedError; +use common_query::Output; +use common_telemetry::info; +use snafu::ResultExt; +use sql::statements::drop::DropTable; +use table::engine::{EngineContext, TableReference}; +use table::requests::DropTableRequest; + +use crate::error::{self, Result}; +use crate::sql::SqlHandler; + +impl SqlHandler { + pub async fn drop_table(&self, req: DropTableRequest) -> Result { + let deregister_table_req = DeregisterTableRequest { + catalog: req.catalog_name.clone(), + schema: req.schema_name.clone(), + table_name: req.table_name.clone(), + }; + + let table_reference = TableReference { + catalog: &req.catalog_name, + schema: &req.schema_name, + table: &req.table_name, + }; + let table_full_name = table_reference.to_string(); + + self.catalog_manager + .deregister_table(deregister_table_req) + .await + .map_err(BoxedError::new) + .context(error::DropTableSnafu { + table_name: table_full_name.clone(), + })?; + + let ctx = EngineContext {}; + self.table_engine() + .drop_table(&ctx, req) + .await + .map_err(BoxedError::new) + .context(error::DropTableSnafu { + table_name: table_full_name.clone(), + })?; + + info!("Successfully dropped table: {}", table_full_name); + + Ok(Output::AffectedRows(1)) + } + + pub fn drop_table_to_request(&self, drop_table: DropTable) -> DropTableRequest { + DropTableRequest { + catalog_name: drop_table.catalog_name, + schema_name: drop_table.schema_name, + table_name: drop_table.table_name, + } + } +} diff --git a/src/datanode/src/sql/insert.rs b/src/datanode/src/sql/insert.rs index 00aa59a026..8c2dae5c4a 100644 --- a/src/datanode/src/sql/insert.rs +++ b/src/datanode/src/sql/insert.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use catalog::SchemaProviderRef; +use catalog::CatalogManagerRef; use common_query::Output; use datatypes::prelude::{ConcreteDataType, VectorBuilder}; use snafu::{ensure, OptionExt, ResultExt}; @@ -23,7 +23,7 @@ use table::engine::TableReference; use table::requests::*; use crate::error::{ - CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlSnafu, + CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlValueSnafu, Result, TableNotFoundSnafu, }; use crate::sql::{SqlHandler, SqlRequest}; @@ -49,19 +49,18 @@ impl SqlHandler { pub(crate) fn insert_to_request( &self, - schema_provider: SchemaProviderRef, + catalog_manager: CatalogManagerRef, stmt: Insert, + table_ref: TableReference, ) -> Result { let columns = stmt.columns(); let values = stmt.values().context(ParseSqlValueSnafu)?; - let (catalog_name, schema_name, table_name) = - stmt.full_table_name().context(ParseSqlSnafu)?; - let table = schema_provider - .table(&table_name) + let table = catalog_manager + .table(table_ref.catalog, table_ref.schema, table_ref.table) .context(CatalogSnafu)? 
.context(TableNotFoundSnafu { - table_name: &table_name, + table_name: table_ref.table, })?; let schema = table.schema(); let columns_num = if columns.is_empty() { @@ -88,7 +87,7 @@ impl SqlHandler { let column_schema = schema.column_schema_by_name(column_name).with_context(|| { ColumnNotFoundSnafu { - table_name: &table_name, + table_name: table_ref.table, column_name: column_name.to_string(), } })?; @@ -119,9 +118,9 @@ impl SqlHandler { } Ok(SqlRequest::Insert(InsertRequest { - catalog_name, - schema_name, - table_name, + catalog_name: table_ref.catalog.to_string(), + schema_name: table_ref.schema.to_string(), + table_name: table_ref.table.to_string(), columns_values: columns_builders .into_iter() .map(|(c, _, mut b)| (c.to_owned(), b.finish())) diff --git a/src/datanode/src/tests.rs b/src/datanode/src/tests.rs index 5cb02b3453..8c460a53fd 100644 --- a/src/datanode/src/tests.rs +++ b/src/datanode/src/tests.rs @@ -12,7 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod grpc_test; -mod http_test; mod instance_test; pub(crate) mod test_util; diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 6914058ffb..b93759b3c7 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_query::Output; use common_recordbatch::util; use datafusion::arrow_print; @@ -19,6 +22,7 @@ use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::arrow::array::{Int64Array, UInt64Array, Utf8Array}; use datatypes::arrow_array::StringArray; use datatypes::prelude::ConcreteDataType; +use session::context::QueryContext; use crate::instance::Instance; use crate::tests::test_util; @@ -32,39 +36,33 @@ async fn test_create_database_and_insert_query() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance.execute_sql("create database test").await.unwrap(); + let output = execute_sql(&instance, "create database test").await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"create table greptime.test.demo( + let output = execute_sql( + &instance, + r#"create table greptime.test.demo( host STRING, cpu DOUBLE, memory DOUBLE, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"insert into test.demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into test.demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); - let query_output = instance - .execute_sql("select ts from test.demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts from test.demo order by ts").await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -88,54 +86,50 @@ async fn test_issue477_same_table_name_in_different_databases() { instance.start().await.unwrap(); // Create database a and b - let output = instance.execute_sql("create database a").await.unwrap(); + let output = 
execute_sql(&instance, "create database a").await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance.execute_sql("create database b").await.unwrap(); + let output = execute_sql(&instance, "create database b").await; assert!(matches!(output, Output::AffectedRows(1))); // Create table a.demo and b.demo - let output = instance - .execute_sql( - r#"create table a.demo( + let output = execute_sql( + &instance, + r#"create table a.demo( host STRING, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"create table b.demo( + let output = execute_sql( + &instance, + r#"create table b.demo( host STRING, ts bigint, TIME INDEX(ts) )"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Insert different data into a.demo and b.demo - let output = instance - .execute_sql( - r#"insert into a.demo(host, ts) values + let output = execute_sql( + &instance, + r#"insert into a.demo(host, ts) values ('host1', 1655276557000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql( - r#"insert into b.demo(host, ts) values + let output = execute_sql( + &instance, + r#"insert into b.demo(host, ts) values ('host2',1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Query data and assert @@ -157,7 +151,7 @@ async fn test_issue477_same_table_name_in_different_databases() { } async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) { - let query_output = instance.execute_sql(sql).await.unwrap(); + let query_output = execute_sql(instance, sql).await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -200,15 +194,14 @@ async fn setup_test_instance() -> Instance { #[tokio::test(flavor = "multi_thread")] async fn test_execute_insert() { let instance = setup_test_instance().await; - let output = instance - .execute_sql( - r#"insert into demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); } @@ -228,22 +221,17 @@ async fn test_execute_insert_query_with_i64_timestamp() { .await .unwrap(); - let output = instance - .execute_sql( - r#"insert into demo(host, cpu, memory, ts) values + let output = execute_sql( + &instance, + r#"insert into demo(host, cpu, memory, ts) values ('host1', 66.6, 1024, 1655276557000), ('host2', 88.8, 333.3, 1655276558000) "#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(2))); - let query_output = instance - .execute_sql("select ts from demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts from demo order by ts").await; match query_output { Output::Stream(s) => { let batches = util::collect(s).await.unwrap(); @@ -257,11 +245,7 @@ async fn test_execute_insert_query_with_i64_timestamp() { _ => unreachable!(), } - let query_output = instance - .execute_sql("select ts as time from demo order by ts") - .await - .unwrap(); - + let query_output = execute_sql(&instance, "select ts as time from demo order by ts").await; match query_output { Output::Stream(s) => { let batches = 
util::collect(s).await.unwrap(); @@ -282,10 +266,7 @@ async fn test_execute_query() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance - .execute_sql("select sum(number) from numbers limit 20") - .await - .unwrap(); + let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await; match output { Output::Stream(recordbatch) => { let numbers = util::collect(recordbatch).await.unwrap(); @@ -309,7 +290,7 @@ async fn test_execute_show_databases_tables() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance.execute_sql("show databases").await.unwrap(); + let output = execute_sql(&instance, "show databases").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -325,10 +306,7 @@ async fn test_execute_show_databases_tables() { _ => unreachable!(), } - let output = instance - .execute_sql("show databases like '%bl%'") - .await - .unwrap(); + let output = execute_sql(&instance, "show databases like '%bl%'").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -344,7 +322,7 @@ async fn test_execute_show_databases_tables() { _ => unreachable!(), } - let output = instance.execute_sql("show tables").await.unwrap(); + let output = execute_sql(&instance, "show tables").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -364,7 +342,7 @@ async fn test_execute_show_databases_tables() { .await .unwrap(); - let output = instance.execute_sql("show tables").await.unwrap(); + let output = execute_sql(&instance, "show tables").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -376,10 +354,7 @@ async fn test_execute_show_databases_tables() { } // show tables like [string] - let output = instance - .execute_sql("show tables like 'de%'") - .await - .unwrap(); + let output = execute_sql(&instance, "show tables like 'de%'").await; match output { Output::RecordBatches(databases) => { let databases = databases.take(); @@ -404,9 +379,9 @@ pub async fn test_execute_create() { let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); instance.start().await.unwrap(); - let output = instance - .execute_sql( - r#"create table test_table( + let output = execute_sql( + &instance, + r#"create table test_table( host string, ts timestamp, cpu double default 0, @@ -414,56 +389,24 @@ pub async fn test_execute_create() { TIME INDEX (ts), PRIMARY KEY(host) ) engine=mito with(regions=1);"#, - ) - .await - .unwrap(); + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); } -#[tokio::test(flavor = "multi_thread")] -pub async fn test_create_table_illegal_timestamp_type() { - common_telemetry::init_default_ut_logging(); - - let (opts, _guard) = - test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type"); - let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); - instance.start().await.unwrap(); - - let output = instance - .execute_sql( - r#"create table test_table( - host string, - ts bigint, - cpu double default 0, - memory double, - TIME INDEX (ts), - PRIMARY KEY(host) - ) engine=mito with(regions=1);"#, - ) - .await - .unwrap(); - match output { - Output::AffectedRows(rows) => { - assert_eq!(1, rows); - } - _ => unreachable!(), - } -} - async fn check_output_stream(output: Output, expected: Vec<&str>) { - match output { - 
Output::Stream(stream) => { - let recordbatches = util::collect(stream).await.unwrap(); - let recordbatch = recordbatches - .into_iter() - .map(|r| r.df_recordbatch) - .collect::>(); - let pretty_print = arrow_print::write(&recordbatch); - let pretty_print = pretty_print.lines().collect::>(); - assert_eq!(pretty_print, expected); - } + let recordbatches = match output { + Output::Stream(stream) => util::collect(stream).await.unwrap(), + Output::RecordBatches(recordbatches) => recordbatches.take(), _ => unreachable!(), - } + }; + let recordbatches = recordbatches + .into_iter() + .map(|r| r.df_recordbatch) + .collect::>(); + let pretty_print = arrow_print::write(&recordbatches); + let pretty_print = pretty_print.lines().collect::>(); + assert_eq!(pretty_print, expected); } #[tokio::test] @@ -479,35 +422,30 @@ async fn test_alter_table() { .await .unwrap(); // make sure table insertion is ok before altering table - instance - .execute_sql("insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)") - .await - .unwrap(); + execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)", + ) + .await; // Add column - let output = instance - .execute_sql("alter table demo add my_tag string null") - .await - .unwrap(); + let output = execute_sql(&instance, "alter table demo add my_tag string null").await; assert!(matches!(output, Output::AffectedRows(0))); - let output = instance - .execute_sql( - "insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')", - ) - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select * from demo order by ts") - .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+--------+---------------------+--------+", "| host | cpu | memory | ts | my_tag |", @@ -520,16 +458,10 @@ async fn test_alter_table() { check_output_stream(output, expected).await; // Drop a column - let output = instance - .execute_sql("alter table demo drop column memory") - .await - .unwrap(); + let output = execute_sql(&instance, "alter table demo drop column memory").await; assert!(matches!(output, Output::AffectedRows(0))); - let output = instance - .execute_sql("select * from demo order by ts") - .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+---------------------+--------+", "| host | cpu | ts | my_tag |", @@ -542,16 +474,14 @@ async fn test_alter_table() { check_output_stream(output, expected).await; // insert a new row - let output = instance - .execute_sql("insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select * from demo order by ts") 
- .await - .unwrap(); + let output = execute_sql(&instance, "select * from demo order by ts").await; let expected = vec![ "+-------+-----+---------------------+--------+", "| host | cpu | ts | my_tag |", @@ -580,27 +510,26 @@ async fn test_insert_with_default_value_for_type(type_name: &str) { ) engine=mito with(regions=1);"#, type_name ); - let output = instance.execute_sql(&create_sql).await.unwrap(); + let output = execute_sql(&instance, &create_sql).await; assert!(matches!(output, Output::AffectedRows(1))); // Insert with ts. - instance - .execute_sql("insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); // Insert without ts, so it should be filled by default value. - let output = instance - .execute_sql("insert into test_table(host, cpu) values ('host2', 2.2)") - .await - .unwrap(); + let output = execute_sql( + &instance, + "insert into test_table(host, cpu) values ('host2', 2.2)", + ) + .await; assert!(matches!(output, Output::AffectedRows(1))); - let output = instance - .execute_sql("select host, cpu from test_table") - .await - .unwrap(); + let output = execute_sql(&instance, "select host, cpu from test_table").await; let expected = vec![ "+-------+-----+", "| host | cpu |", @@ -619,3 +548,70 @@ async fn test_insert_with_default_value() { test_insert_with_default_value_for_type("timestamp").await; test_insert_with_default_value_for_type("bigint").await; } + +#[tokio::test(flavor = "multi_thread")] +async fn test_use_database() { + let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database"); + let instance = Instance::with_mock_meta_client(&opts).await.unwrap(); + instance.start().await.unwrap(); + + let output = execute_sql(&instance, "create database db1").await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db( + &instance, + "create table tb1(col_i32 int, ts bigint, TIME INDEX(ts))", + "db1", + ) + .await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db(&instance, "show tables", "db1").await; + let expected = vec![ + "+--------+", + "| Tables |", + "+--------+", + "| tb1 |", + "+--------+", + ]; + check_output_stream(output, expected).await; + + let output = execute_sql_in_db( + &instance, + r#"insert into tb1(col_i32, ts) values (1, 1655276557000)"#, + "db1", + ) + .await; + assert!(matches!(output, Output::AffectedRows(1))); + + let output = execute_sql_in_db(&instance, "select col_i32 from tb1", "db1").await; + let expected = vec![ + "+---------+", + "| col_i32 |", + "+---------+", + "| 1 |", + "+---------+", + ]; + check_output_stream(output, expected).await; + + // Making a particular database the default by means of the USE statement does not preclude + // accessing tables in other databases. 
+ let output = execute_sql(&instance, "select number from public.numbers limit 1").await; + let expected = vec![ + "+--------+", + "| number |", + "+--------+", + "| 0 |", + "+--------+", + ]; + check_output_stream(output, expected).await; +} + +async fn execute_sql(instance: &Instance, sql: &str) -> Output { + execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await +} + +async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output { + let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); + instance.execute_sql(sql, query_ctx).await.unwrap() +} diff --git a/src/datanode/src/tests/test_util.rs b/src/datanode/src/tests/test_util.rs index 4c9a390d58..a7cf8e1fe5 100644 --- a/src/datanode/src/tests/test_util.rs +++ b/src/datanode/src/tests/test_util.rs @@ -21,6 +21,7 @@ use datatypes::data_type::ConcreteDataType; use datatypes::schema::{ColumnSchema, SchemaBuilder}; use mito::config::EngineConfig; use mito::table::test_util::{new_test_object_store, MockEngine, MockMitoEngine}; +use query::QueryEngineFactory; use servers::Mode; use snafu::ResultExt; use table::engine::{EngineContext, TableEngineRef}; @@ -88,7 +89,7 @@ pub async fn create_test_table( .expect("ts is expected to be timestamp column"), ), create_if_not_exists: true, - primary_key_indices: vec![3, 0], // "host" and "ts" are primary keys + primary_key_indices: vec![0], // "host" is in primary keys table_options: HashMap::new(), region_numbers: vec![0], }, @@ -121,5 +122,9 @@ pub async fn create_mock_sql_handler() -> SqlHandler { .await .unwrap(), ); - SqlHandler::new(mock_engine, catalog_manager) + + let catalog_list = catalog::local::new_memory_catalog_list().unwrap(); + let factory = QueryEngineFactory::new(catalog_list); + + SqlHandler::new(mock_engine, catalog_manager, factory.query_engine()) } diff --git a/src/datatypes2/Cargo.toml b/src/datatypes2/Cargo.toml new file mode 100644 index 0000000000..34941606d4 --- /dev/null +++ b/src/datatypes2/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "datatypes2" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" + +[features] +default = [] +test = [] + +[dependencies] +common-base = { path = "../common/base" } +common-error = { path = "../common/error" } +common-time = { path = "../common/time" } +datafusion-common = "14.0" +enum_dispatch = "0.3" +num = "0.4" +num-traits = "0.2" +ordered-float = { version = "3.0", features = ["serde"] } +paste = "1.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +snafu = { version = "0.7", features = ["backtraces"] } +arrow = "26.0" diff --git a/src/datatypes2/src/arrow_array.rs b/src/datatypes2/src/arrow_array.rs new file mode 100644 index 0000000000..7405c8a665 --- /dev/null +++ b/src/datatypes2/src/arrow_array.rs @@ -0,0 +1,242 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use arrow::array::{ + Array, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::datatypes::DataType; +use common_time::timestamp::TimeUnit; +use common_time::Timestamp; +use snafu::OptionExt; + +use crate::data_type::ConcreteDataType; +use crate::error::{ConversionSnafu, Result}; +use crate::value::{ListValue, Value}; + +pub type BinaryArray = arrow::array::LargeBinaryArray; +pub type MutableBinaryArray = arrow::array::LargeBinaryBuilder; +pub type StringArray = arrow::array::StringArray; +pub type MutableStringArray = arrow::array::StringBuilder; + +macro_rules! cast_array { + ($arr: ident, $CastType: ty) => { + $arr.as_any() + .downcast_ref::<$CastType>() + .with_context(|| ConversionSnafu { + from: format!("{:?}", $arr.data_type()), + })? + }; +} + +// TODO(yingwen): Remove this function. +pub fn arrow_array_get(array: &dyn Array, idx: usize) -> Result { + if array.is_null(idx) { + return Ok(Value::Null); + } + + let result = match array.data_type() { + DataType::Null => Value::Null, + DataType::Boolean => Value::Boolean(cast_array!(array, BooleanArray).value(idx)), + DataType::Binary => Value::Binary(cast_array!(array, BinaryArray).value(idx).into()), + DataType::Int8 => Value::Int8(cast_array!(array, Int8Array).value(idx)), + DataType::Int16 => Value::Int16(cast_array!(array, Int16Array).value(idx)), + DataType::Int32 => Value::Int32(cast_array!(array, Int32Array).value(idx)), + DataType::Int64 => Value::Int64(cast_array!(array, Int64Array).value(idx)), + DataType::UInt8 => Value::UInt8(cast_array!(array, UInt8Array).value(idx)), + DataType::UInt16 => Value::UInt16(cast_array!(array, UInt16Array).value(idx)), + DataType::UInt32 => Value::UInt32(cast_array!(array, UInt32Array).value(idx)), + DataType::UInt64 => Value::UInt64(cast_array!(array, UInt64Array).value(idx)), + DataType::Float32 => Value::Float32(cast_array!(array, Float32Array).value(idx).into()), + DataType::Float64 => Value::Float64(cast_array!(array, Float64Array).value(idx).into()), + DataType::Utf8 => Value::String(cast_array!(array, StringArray).value(idx).into()), + DataType::Date32 => Value::Date(cast_array!(array, Date32Array).value(idx).into()), + DataType::Date64 => Value::DateTime(cast_array!(array, Date64Array).value(idx).into()), + DataType::Timestamp(t, _) => match t { + arrow::datatypes::TimeUnit::Second => Value::Timestamp(Timestamp::new( + cast_array!(array, arrow::array::TimestampSecondArray).value(idx), + TimeUnit::Second, + )), + arrow::datatypes::TimeUnit::Millisecond => Value::Timestamp(Timestamp::new( + cast_array!(array, arrow::array::TimestampMillisecondArray).value(idx), + TimeUnit::Millisecond, + )), + arrow::datatypes::TimeUnit::Microsecond => Value::Timestamp(Timestamp::new( + cast_array!(array, arrow::array::TimestampMicrosecondArray).value(idx), + TimeUnit::Microsecond, + )), + arrow::datatypes::TimeUnit::Nanosecond => Value::Timestamp(Timestamp::new( + cast_array!(array, arrow::array::TimestampNanosecondArray).value(idx), + TimeUnit::Nanosecond, + )), + }, + DataType::List(_) => { + let array = cast_array!(array, ListArray).value(idx); + let item_type = ConcreteDataType::try_from(array.data_type())?; + let values = (0..array.len()) + .map(|i| arrow_array_get(&*array, i)) + .collect::>>()?; + Value::List(ListValue::new(Some(Box::new(values)), item_type)) + } + _ => unimplemented!("Arrow array datatype: {:?}", array.data_type()), + }; + + 
Ok(result) +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{ + BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, + LargeBinaryArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, + }; + use arrow::datatypes::Int32Type; + use common_time::timestamp::{TimeUnit, Timestamp}; + use paste::paste; + + use super::*; + use crate::data_type::ConcreteDataType; + use crate::types::TimestampType; + + macro_rules! test_arrow_array_get_for_timestamps { + ( $($unit: ident), *) => { + $( + paste! { + let mut builder = arrow::array::[<Timestamp $unit Array>]::builder(3); + builder.append_value(1); + builder.append_value(0); + builder.append_value(-1); + let ts_array = Arc::new(builder.finish()) as Arc<dyn Array>; + let v = arrow_array_get(&ts_array, 1).unwrap(); + assert_eq!( + ConcreteDataType::Timestamp(TimestampType::$unit( + $crate::types::[<Timestamp $unit Type>]::default(), + )), + v.data_type() + ); + } + )* + }; + } + + #[test] + fn test_timestamp_array() { + test_arrow_array_get_for_timestamps![Second, Millisecond, Microsecond, Nanosecond]; + } + + #[test] + fn test_arrow_array_access() { + let array1 = BooleanArray::from(vec![true, true, false, false]); + assert_eq!(Value::Boolean(true), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int8Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::Int8(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt8Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt8(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int16Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::Int16(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt16Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt16(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Int32Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::Int32(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = UInt32Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt32(2), arrow_array_get(&array1, 1).unwrap()); + let array = Int64Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::Int64(2), arrow_array_get(&array, 1).unwrap()); + let array1 = UInt64Array::from(vec![1, 2, 3, 4]); + assert_eq!(Value::UInt64(2), arrow_array_get(&array1, 1).unwrap()); + let array1 = Float32Array::from(vec![1f32, 2f32, 3f32, 4f32]); + assert_eq!( + Value::Float32(2f32.into()), + arrow_array_get(&array1, 1).unwrap() + ); + let array1 = Float64Array::from(vec![1f64, 2f64, 3f64, 4f64]); + assert_eq!( + Value::Float64(2f64.into()), + arrow_array_get(&array1, 1).unwrap() + ); + + let array2 = StringArray::from(vec![Some("hello"), None, Some("world")]); + assert_eq!( + Value::String("hello".into()), + arrow_array_get(&array2, 0).unwrap() + ); + assert_eq!(Value::Null, arrow_array_get(&array2, 1).unwrap()); + + let array3 = LargeBinaryArray::from(vec![ + Some("hello".as_bytes()), + None, + Some("world".as_bytes()), + ]); + assert_eq!(Value::Null, arrow_array_get(&array3, 1).unwrap()); + + let array = TimestampSecondArray::from(vec![1, 2, 3]); + let value = arrow_array_get(&array, 1).unwrap(); + assert_eq!(value, Value::Timestamp(Timestamp::new(2, TimeUnit::Second))); + let array = TimestampMillisecondArray::from(vec![1, 2, 3]); + let value = arrow_array_get(&array, 1).unwrap(); +
assert_eq!( + value, + Value::Timestamp(Timestamp::new(2, TimeUnit::Microsecond)) + ); + let array = TimestampNanosecondArray::from(vec![1, 2, 3]); + let value = arrow_array_get(&array, 1).unwrap(); + assert_eq!( + value, + Value::Timestamp(Timestamp::new(2, TimeUnit::Nanosecond)) + ); + + // test list array + let data = vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), None, Some(6)]), + ]; + let arrow_array = ListArray::from_iter_primitive::(data); + + let v0 = arrow_array_get(&arrow_array, 0).unwrap(); + match v0 { + Value::List(list) => { + assert!(matches!(list.datatype(), ConcreteDataType::Int32(_))); + let items = list.items().as_ref().unwrap(); + assert_eq!( + **items, + vec![Value::Int32(1), Value::Int32(2), Value::Int32(3)] + ); + } + _ => unreachable!(), + } + + assert_eq!(Value::Null, arrow_array_get(&arrow_array, 1).unwrap()); + let v2 = arrow_array_get(&arrow_array, 2).unwrap(); + match v2 { + Value::List(list) => { + assert!(matches!(list.datatype(), ConcreteDataType::Int32(_))); + let items = list.items().as_ref().unwrap(); + assert_eq!(**items, vec![Value::Int32(4), Value::Null, Value::Int32(6)]); + } + _ => unreachable!(), + } + } +} diff --git a/src/datatypes2/src/data_type.rs b/src/datatypes2/src/data_type.rs new file mode 100644 index 0000000000..0d06d566b6 --- /dev/null +++ b/src/datatypes2/src/data_type.rs @@ -0,0 +1,486 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
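(The module below maps between arrow's `DataType` and the crate's `ConcreteDataType`. A round-trip usage sketch, assuming the crate is importable as `datatypes2`:)

use arrow::datatypes::DataType as ArrowDataType;
use datatypes2::data_type::{ConcreteDataType, DataType};

fn main() {
    // The infallible constructor panics on unsupported arrow types...
    let dt = ConcreteDataType::from_arrow_type(&ArrowDataType::Int32);
    assert_eq!(ArrowDataType::Int32, dt.as_arrow_type());

    // ...while TryFrom surfaces them as an UnsupportedArrowType error instead.
    assert!(ConcreteDataType::try_from(&ArrowDataType::Float16).is_err());
}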
+ +use std::sync::Arc; + +use arrow::datatypes::{DataType as ArrowDataType, TimeUnit as ArrowTimeUnit}; +use common_time::timestamp::TimeUnit; +use paste::paste; +use serde::{Deserialize, Serialize}; + +use crate::error::{self, Error, Result}; +use crate::type_id::LogicalTypeId; +use crate::types::{ + BinaryType, BooleanType, DateTimeType, DateType, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, ListType, NullType, StringType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, TimestampType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use crate::value::Value; +use crate::vectors::MutableVector; + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[enum_dispatch::enum_dispatch(DataType)] +pub enum ConcreteDataType { + Null(NullType), + Boolean(BooleanType), + + // Numeric types: + Int8(Int8Type), + Int16(Int16Type), + Int32(Int32Type), + Int64(Int64Type), + UInt8(UInt8Type), + UInt16(UInt16Type), + UInt32(UInt32Type), + UInt64(UInt64Type), + Float32(Float32Type), + Float64(Float64Type), + + // String types: + Binary(BinaryType), + String(StringType), + + // Date types: + Date(DateType), + DateTime(DateTimeType), + Timestamp(TimestampType), + + // Compound types: + List(ListType), +} + +// TODO(yingwen): Refactor these `is_xxx()` methods, such as adding a `properties()` method +// returning all these properties to the `DataType` trait +impl ConcreteDataType { + pub fn is_float(&self) -> bool { + matches!( + self, + ConcreteDataType::Float64(_) | ConcreteDataType::Float32(_) + ) + } + + pub fn is_boolean(&self) -> bool { + matches!(self, ConcreteDataType::Boolean(_)) + } + + pub fn is_stringifiable(&self) -> bool { + matches!( + self, + ConcreteDataType::String(_) + | ConcreteDataType::Date(_) + | ConcreteDataType::DateTime(_) + | ConcreteDataType::Timestamp(_) + ) + } + + pub fn is_signed(&self) -> bool { + matches!( + self, + ConcreteDataType::Int8(_) + | ConcreteDataType::Int16(_) + | ConcreteDataType::Int32(_) + | ConcreteDataType::Int64(_) + | ConcreteDataType::Date(_) + | ConcreteDataType::DateTime(_) + | ConcreteDataType::Timestamp(_) + ) + } + + pub fn is_unsigned(&self) -> bool { + matches!( + self, + ConcreteDataType::UInt8(_) + | ConcreteDataType::UInt16(_) + | ConcreteDataType::UInt32(_) + | ConcreteDataType::UInt64(_) + ) + } + + pub fn numerics() -> Vec { + vec![ + ConcreteDataType::int8_datatype(), + ConcreteDataType::int16_datatype(), + ConcreteDataType::int32_datatype(), + ConcreteDataType::int64_datatype(), + ConcreteDataType::uint8_datatype(), + ConcreteDataType::uint16_datatype(), + ConcreteDataType::uint32_datatype(), + ConcreteDataType::uint64_datatype(), + ConcreteDataType::float32_datatype(), + ConcreteDataType::float64_datatype(), + ] + } + + /// Convert arrow data type to [ConcreteDataType]. + /// + /// # Panics + /// Panic if given arrow data type is not supported. 
+ pub fn from_arrow_type(dt: &ArrowDataType) -> Self { + ConcreteDataType::try_from(dt).expect("Unimplemented type") + } + + pub fn is_null(&self) -> bool { + matches!(self, ConcreteDataType::Null(NullType)) + } +} + +impl TryFrom<&ArrowDataType> for ConcreteDataType { + type Error = Error; + + fn try_from(dt: &ArrowDataType) -> Result<ConcreteDataType> { + let concrete_type = match dt { + ArrowDataType::Null => Self::null_datatype(), + ArrowDataType::Boolean => Self::boolean_datatype(), + ArrowDataType::UInt8 => Self::uint8_datatype(), + ArrowDataType::UInt16 => Self::uint16_datatype(), + ArrowDataType::UInt32 => Self::uint32_datatype(), + ArrowDataType::UInt64 => Self::uint64_datatype(), + ArrowDataType::Int8 => Self::int8_datatype(), + ArrowDataType::Int16 => Self::int16_datatype(), + ArrowDataType::Int32 => Self::int32_datatype(), + ArrowDataType::Int64 => Self::int64_datatype(), + ArrowDataType::Float32 => Self::float32_datatype(), + ArrowDataType::Float64 => Self::float64_datatype(), + ArrowDataType::Date32 => Self::date_datatype(), + ArrowDataType::Date64 => Self::datetime_datatype(), + ArrowDataType::Timestamp(u, _) => ConcreteDataType::from_arrow_time_unit(u), + ArrowDataType::Binary | ArrowDataType::LargeBinary => Self::binary_datatype(), + ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 => Self::string_datatype(), + ArrowDataType::List(field) => Self::List(ListType::new( + ConcreteDataType::from_arrow_type(field.data_type()), + )), + _ => { + return error::UnsupportedArrowTypeSnafu { + arrow_type: dt.clone(), + } + .fail() + } + }; + + Ok(concrete_type) + } +} + +macro_rules! impl_new_concrete_type_functions { + ($($Type: ident), +) => { + paste! { + impl ConcreteDataType { + $( + pub fn [<$Type:lower _datatype>]() -> ConcreteDataType { + ConcreteDataType::$Type([<$Type Type>]::default()) + } + )+ + } + } + } +} + +impl_new_concrete_type_functions!( + Null, Boolean, UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64, + Binary, Date, DateTime, String +); + +impl ConcreteDataType { + pub fn timestamp_second_datatype() -> Self { + ConcreteDataType::Timestamp(TimestampType::Second(TimestampSecondType::default())) + } + + pub fn timestamp_millisecond_datatype() -> Self { + ConcreteDataType::Timestamp(TimestampType::Millisecond( + TimestampMillisecondType::default(), + )) + } + + pub fn timestamp_microsecond_datatype() -> Self { + ConcreteDataType::Timestamp(TimestampType::Microsecond( + TimestampMicrosecondType::default(), + )) + } + + pub fn timestamp_nanosecond_datatype() -> Self { + ConcreteDataType::Timestamp(TimestampType::Nanosecond(TimestampNanosecondType::default())) + } + + pub fn timestamp_datatype(unit: TimeUnit) -> Self { + match unit { + TimeUnit::Second => Self::timestamp_second_datatype(), + TimeUnit::Millisecond => Self::timestamp_millisecond_datatype(), + TimeUnit::Microsecond => Self::timestamp_microsecond_datatype(), + TimeUnit::Nanosecond => Self::timestamp_nanosecond_datatype(), + } + } + + /// Converts from an arrow time unit to the corresponding timestamp datatype. + pub fn from_arrow_time_unit(t: &ArrowTimeUnit) -> Self { + match t { + ArrowTimeUnit::Second => Self::timestamp_second_datatype(), + ArrowTimeUnit::Millisecond => Self::timestamp_millisecond_datatype(), + ArrowTimeUnit::Microsecond => Self::timestamp_microsecond_datatype(), + ArrowTimeUnit::Nanosecond => Self::timestamp_nanosecond_datatype(), + } + } + + pub fn list_datatype(item_type: ConcreteDataType) -> ConcreteDataType { + ConcreteDataType::List(ListType::new(item_type)) + } +} + +/// Data type abstraction.
+#[enum_dispatch::enum_dispatch] +pub trait DataType: std::fmt::Debug + Send + Sync { + /// Name of this data type. + fn name(&self) -> &str; + + /// Returns id of the Logical data type. + fn logical_type_id(&self) -> LogicalTypeId; + + /// Returns the default value of this type. + fn default_value(&self) -> Value; + + /// Convert this type as [arrow::datatypes::DataType]. + fn as_arrow_type(&self) -> ArrowDataType; + + /// Creates a mutable vector with given `capacity` of this type. + fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector>; + + /// Returns true if the data type is compatible with timestamp type so we can + /// use it as a timestamp. + fn is_timestamp_compatible(&self) -> bool; +} + +pub type DataTypeRef = Arc<dyn DataType>; + +#[cfg(test)] +mod tests { + use arrow::datatypes::Field; + + use super::*; + + #[test] + fn test_concrete_type_as_datatype_trait() { + let concrete_type = ConcreteDataType::boolean_datatype(); + + assert_eq!("Boolean", concrete_type.name()); + assert_eq!(Value::Boolean(false), concrete_type.default_value()); + assert_eq!(LogicalTypeId::Boolean, concrete_type.logical_type_id()); + assert_eq!(ArrowDataType::Boolean, concrete_type.as_arrow_type()); + } + + #[test] + fn test_from_arrow_type() { + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Null), + ConcreteDataType::Null(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Boolean), + ConcreteDataType::Boolean(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Binary), + ConcreteDataType::Binary(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::LargeBinary), + ConcreteDataType::Binary(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Int8), + ConcreteDataType::Int8(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Int16), + ConcreteDataType::Int16(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Int32), + ConcreteDataType::Int32(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Int64), + ConcreteDataType::Int64(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::UInt8), + ConcreteDataType::UInt8(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::UInt16), + ConcreteDataType::UInt16(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::UInt32), + ConcreteDataType::UInt32(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::UInt64), + ConcreteDataType::UInt64(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Float32), + ConcreteDataType::Float32(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Float64), + ConcreteDataType::Float64(_) + )); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Utf8), + ConcreteDataType::String(_) + )); + assert_eq!( + ConcreteDataType::from_arrow_type(&ArrowDataType::List(Box::new(Field::new( + "item", + ArrowDataType::Int32, + true, + )))), + ConcreteDataType::List(ListType::new(ConcreteDataType::int32_datatype())) + ); + assert!(matches!( + ConcreteDataType::from_arrow_type(&ArrowDataType::Date32), + ConcreteDataType::Date(_) + )); + } + + #[test] + fn test_from_arrow_timestamp() { + assert_eq!( + ConcreteDataType::timestamp_millisecond_datatype(), + ConcreteDataType::from_arrow_time_unit(&ArrowTimeUnit::Millisecond) + ); + assert_eq!( +
ConcreteDataType::timestamp_microsecond_datatype(), + ConcreteDataType::from_arrow_time_unit(&ArrowTimeUnit::Microsecond) + ); + assert_eq!( + ConcreteDataType::timestamp_nanosecond_datatype(), + ConcreteDataType::from_arrow_time_unit(&ArrowTimeUnit::Nanosecond) + ); + assert_eq!( + ConcreteDataType::timestamp_second_datatype(), + ConcreteDataType::from_arrow_time_unit(&ArrowTimeUnit::Second) + ); + } + + #[test] + fn test_is_timestamp_compatible() { + assert!(ConcreteDataType::timestamp_datatype(TimeUnit::Second).is_timestamp_compatible()); + assert!( + ConcreteDataType::timestamp_datatype(TimeUnit::Millisecond).is_timestamp_compatible() + ); + assert!( + ConcreteDataType::timestamp_datatype(TimeUnit::Microsecond).is_timestamp_compatible() + ); + assert!( + ConcreteDataType::timestamp_datatype(TimeUnit::Nanosecond).is_timestamp_compatible() + ); + assert!(ConcreteDataType::timestamp_second_datatype().is_timestamp_compatible()); + assert!(ConcreteDataType::timestamp_millisecond_datatype().is_timestamp_compatible()); + assert!(ConcreteDataType::timestamp_microsecond_datatype().is_timestamp_compatible()); + assert!(ConcreteDataType::timestamp_nanosecond_datatype().is_timestamp_compatible()); + assert!(ConcreteDataType::int64_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::null_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::binary_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::boolean_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::date_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::datetime_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::string_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::int32_datatype().is_timestamp_compatible()); + assert!(!ConcreteDataType::uint64_datatype().is_timestamp_compatible()); + } + + #[test] + fn test_is_null() { + assert!(ConcreteDataType::null_datatype().is_null()); + assert!(!ConcreteDataType::int32_datatype().is_null()); + } + + #[test] + fn test_is_float() { + assert!(!ConcreteDataType::int32_datatype().is_float()); + assert!(ConcreteDataType::float32_datatype().is_float()); + assert!(ConcreteDataType::float64_datatype().is_float()); + } + + #[test] + fn test_is_boolean() { + assert!(!ConcreteDataType::int32_datatype().is_boolean()); + assert!(!ConcreteDataType::float32_datatype().is_boolean()); + assert!(ConcreteDataType::boolean_datatype().is_boolean()); + } + + #[test] + fn test_is_stringifiable() { + assert!(!ConcreteDataType::int32_datatype().is_stringifiable()); + assert!(!ConcreteDataType::float32_datatype().is_stringifiable()); + assert!(ConcreteDataType::string_datatype().is_stringifiable()); + assert!(ConcreteDataType::date_datatype().is_stringifiable()); + assert!(ConcreteDataType::datetime_datatype().is_stringifiable()); + assert!(ConcreteDataType::timestamp_second_datatype().is_stringifiable()); + assert!(ConcreteDataType::timestamp_millisecond_datatype().is_stringifiable()); + assert!(ConcreteDataType::timestamp_microsecond_datatype().is_stringifiable()); + assert!(ConcreteDataType::timestamp_nanosecond_datatype().is_stringifiable()); + } + + #[test] + fn test_is_signed() { + assert!(ConcreteDataType::int8_datatype().is_signed()); + assert!(ConcreteDataType::int16_datatype().is_signed()); + assert!(ConcreteDataType::int32_datatype().is_signed()); + assert!(ConcreteDataType::int64_datatype().is_signed()); + assert!(ConcreteDataType::date_datatype().is_signed()); + 
assert!(ConcreteDataType::datetime_datatype().is_signed()); + assert!(ConcreteDataType::timestamp_second_datatype().is_signed()); + assert!(ConcreteDataType::timestamp_millisecond_datatype().is_signed()); + assert!(ConcreteDataType::timestamp_microsecond_datatype().is_signed()); + assert!(ConcreteDataType::timestamp_nanosecond_datatype().is_signed()); + + assert!(!ConcreteDataType::uint8_datatype().is_signed()); + assert!(!ConcreteDataType::uint16_datatype().is_signed()); + assert!(!ConcreteDataType::uint32_datatype().is_signed()); + assert!(!ConcreteDataType::uint64_datatype().is_signed()); + + assert!(!ConcreteDataType::float32_datatype().is_signed()); + assert!(!ConcreteDataType::float64_datatype().is_signed()); + } + + #[test] + fn test_is_unsigned() { + assert!(!ConcreteDataType::int8_datatype().is_unsigned()); + assert!(!ConcreteDataType::int16_datatype().is_unsigned()); + assert!(!ConcreteDataType::int32_datatype().is_unsigned()); + assert!(!ConcreteDataType::int64_datatype().is_unsigned()); + assert!(!ConcreteDataType::date_datatype().is_unsigned()); + assert!(!ConcreteDataType::datetime_datatype().is_unsigned()); + assert!(!ConcreteDataType::timestamp_second_datatype().is_unsigned()); + assert!(!ConcreteDataType::timestamp_millisecond_datatype().is_unsigned()); + assert!(!ConcreteDataType::timestamp_microsecond_datatype().is_unsigned()); + assert!(!ConcreteDataType::timestamp_nanosecond_datatype().is_unsigned()); + + assert!(ConcreteDataType::uint8_datatype().is_unsigned()); + assert!(ConcreteDataType::uint16_datatype().is_unsigned()); + assert!(ConcreteDataType::uint32_datatype().is_unsigned()); + assert!(ConcreteDataType::uint64_datatype().is_unsigned()); + + assert!(!ConcreteDataType::float32_datatype().is_unsigned()); + assert!(!ConcreteDataType::float64_datatype().is_unsigned()); + } + + #[test] + fn test_numerics() { + let nums = ConcreteDataType::numerics(); + assert_eq!(10, nums.len()); + } +} diff --git a/src/datatypes2/src/error.rs b/src/datatypes2/src/error.rs new file mode 100644 index 0000000000..50b49cf2b4 --- /dev/null +++ b/src/datatypes2/src/error.rs @@ -0,0 +1,144 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
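+
+// Editorial note (sketch, not part of the original patch): the variants in the
+// enum below derive snafu context selectors (e.g. `SerializeSnafu`,
+// `ParseSchemaVersionSnafu`) that attach context to a source error, roughly:
+//
+//     use snafu::ResultExt;
+//     let version: Result<u32> = "1x".parse().context(ParseSchemaVersionSnafu { value: "1x" });
+//
+// `try_parse_version` in `schema.rs` below uses exactly this pattern.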
+ +use std::any::Any; + +use common_error::prelude::{ErrorCompat, ErrorExt, Snafu, StatusCode}; +use snafu::Backtrace; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub))] +pub enum Error { + #[snafu(display("Failed to serialize data, source: {}", source))] + Serialize { + source: serde_json::Error, + backtrace: Backtrace, + }, + + #[snafu(display("Failed to deserialize data, source: {}, json: {}", source, json))] + Deserialize { + source: serde_json::Error, + backtrace: Backtrace, + json: String, + }, + + #[snafu(display("Failed to convert datafusion type: {}", from))] + Conversion { from: String, backtrace: Backtrace }, + + #[snafu(display("Bad array access, Index out of bounds: {}, size: {}", index, size))] + BadArrayAccess { + index: usize, + size: usize, + backtrace: Backtrace, + }, + + #[snafu(display("Unknown vector, {}", msg))] + UnknownVector { msg: String, backtrace: Backtrace }, + + #[snafu(display("Unsupported arrow data type, type: {:?}", arrow_type))] + UnsupportedArrowType { + arrow_type: arrow::datatypes::DataType, + backtrace: Backtrace, + }, + + #[snafu(display("Timestamp column {} not found", name,))] + TimestampNotFound { name: String, backtrace: Backtrace }, + + #[snafu(display( + "Failed to parse version in schema meta, value: {}, source: {}", + value, + source + ))] + ParseSchemaVersion { + value: String, + source: std::num::ParseIntError, + backtrace: Backtrace, + }, + + #[snafu(display("Invalid timestamp index: {}", index))] + InvalidTimestampIndex { index: usize, backtrace: Backtrace }, + + #[snafu(display("Duplicate timestamp index, exists: {}, new: {}", exists, new))] + DuplicateTimestampIndex { + exists: usize, + new: usize, + backtrace: Backtrace, + }, + + #[snafu(display("{}", msg))] + CastType { msg: String, backtrace: Backtrace }, + + #[snafu(display("Arrow failed to compute, source: {}", source))] + ArrowCompute { + source: arrow::error::ArrowError, + backtrace: Backtrace, + }, + + #[snafu(display("Unsupported column default constraint expression: {}", expr))] + UnsupportedDefaultExpr { expr: String, backtrace: Backtrace }, + + #[snafu(display("Default value should not be null for non null column"))] + NullDefault { backtrace: Backtrace }, + + #[snafu(display("Incompatible default value type, reason: {}", reason))] + DefaultValueType { + reason: String, + backtrace: Backtrace, + }, + + #[snafu(display("Duplicated metadata for {}", key))] + DuplicateMeta { key: String, backtrace: Backtrace }, +} + +impl ErrorExt for Error { + fn status_code(&self) -> StatusCode { + // Inner encoding and decoding error should not be exposed to users. 
+        StatusCode::Internal
+    }
+
+    fn backtrace_opt(&self) -> Option<&Backtrace> {
+        ErrorCompat::backtrace(self)
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+
+    use snafu::ResultExt;
+
+    use super::*;
+
+    #[test]
+    pub fn test_error() {
+        let mut map = HashMap::new();
+        map.insert(true, 1);
+        map.insert(false, 2);
+
+        let result = serde_json::to_string(&map).context(SerializeSnafu);
+        assert!(result.is_err(), "serialize result is: {:?}", result);
+        let err = serde_json::to_string(&map)
+            .context(SerializeSnafu)
+            .err()
+            .unwrap();
+        assert!(err.backtrace_opt().is_some());
+        assert_eq!(StatusCode::Internal, err.status_code());
+    }
+}
diff --git a/src/datatypes2/src/lib.rs b/src/datatypes2/src/lib.rs
new file mode 100644
index 0000000000..256d347eac
--- /dev/null
+++ b/src/datatypes2/src/lib.rs
@@ -0,0 +1,33 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#![feature(generic_associated_types)]
+#![feature(assert_matches)]
+
+pub mod arrow_array;
+pub mod data_type;
+pub mod error;
+pub mod macros;
+pub mod prelude;
+mod scalars;
+pub mod schema;
+pub mod serialize;
+mod timestamp;
+pub mod type_id;
+pub mod types;
+pub mod value;
+pub mod vectors;
+
+pub use arrow;
+pub use error::{Error, Result};
diff --git a/src/datatypes2/src/macros.rs b/src/datatypes2/src/macros.rs
new file mode 100644
index 0000000000..37c0a42e3f
--- /dev/null
+++ b/src/datatypes2/src/macros.rs
@@ -0,0 +1,68 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! Some helper macros for datatypes, copied from databend.
+
+/// Apply the macro rules to all primitive types.
+#[macro_export]
+macro_rules! for_all_primitive_types {
+    ($macro:tt $(, $x:tt)*) => {
+        $macro! {
+            [$($x),*],
+            { i8 },
+            { i16 },
+            { i32 },
+            { i64 },
+            { u8 },
+            { u16 },
+            { u32 },
+            { u64 },
+            { f32 },
+            { f64 }
+        }
+    };
+}
+
+/// Match the logical type and apply `$body` to all primitive types and
+/// `$nbody` to other types.
+#[macro_export]
+macro_rules! with_match_primitive_type_id {
+    ($key_type:expr, | $_:tt $T:ident | $body:tt, $nbody:tt) => {{
+        macro_rules! __with_ty__ {
+            ( $_ $T:ident ) => {
+                $body
+            };
+        }
+
+        use $crate::type_id::LogicalTypeId;
+        use $crate::types::{
+            Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
+            UInt32Type, UInt64Type, UInt8Type,
+        };
+        match $key_type {
+            LogicalTypeId::Int8 => __with_ty__! { Int8Type },
+            LogicalTypeId::Int16 => __with_ty__! { Int16Type },
+            LogicalTypeId::Int32 => __with_ty__! { Int32Type },
+            LogicalTypeId::Int64 => __with_ty__! { Int64Type },
+            LogicalTypeId::UInt8 => __with_ty__! { UInt8Type },
+            LogicalTypeId::UInt16 => __with_ty__! { UInt16Type },
+            LogicalTypeId::UInt32 => __with_ty__! { UInt32Type },
+            LogicalTypeId::UInt64 => __with_ty__! { UInt64Type },
+            LogicalTypeId::Float32 => __with_ty__! { Float32Type },
+            LogicalTypeId::Float64 => __with_ty__! { Float64Type },
+
+            _ => $nbody,
+        }
+    }};
+}
diff --git a/src/datatypes2/src/prelude.rs b/src/datatypes2/src/prelude.rs
new file mode 100644
index 0000000000..f6bd298316
--- /dev/null
+++ b/src/datatypes2/src/prelude.rs
@@ -0,0 +1,20 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+pub use crate::data_type::{ConcreteDataType, DataType, DataTypeRef};
+pub use crate::macros::*;
+pub use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder};
+pub use crate::type_id::LogicalTypeId;
+pub use crate::value::{Value, ValueRef};
+pub use crate::vectors::{MutableVector, Validity, Vector, VectorRef};
diff --git a/src/datatypes2/src/scalars.rs b/src/datatypes2/src/scalars.rs
new file mode 100644
index 0000000000..327ebaa629
--- /dev/null
+++ b/src/datatypes2/src/scalars.rs
@@ -0,0 +1,443 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+
+use common_time::{Date, DateTime};
+
+use crate::types::{
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type,
+    UInt64Type, UInt8Type,
+};
+use crate::value::{ListValue, ListValueRef, Value};
+use crate::vectors::{
+    BinaryVector, BooleanVector, DateTimeVector, DateVector, ListVector, MutableVector,
+    PrimitiveVector, StringVector, Vector,
+};
+
+fn get_iter_capacity<T, I: Iterator<Item = T>>(iter: &I) -> usize {
+    match iter.size_hint() {
+        (_lower, Some(upper)) => upper,
+        (0, None) => 1024,
+        (lower, None) => lower,
+    }
+}
+
+/// Owned scalar value
+/// primitive types, bool, Vec<u8> ...
+pub trait Scalar: 'static + Sized + Default + Any
+where
+    for<'a> Self::VectorType: ScalarVector<RefItem<'a> = Self::RefType<'a>>,
+{
+    type VectorType: ScalarVector;
+    type RefType<'a>: ScalarRef<'a, ScalarType = Self>
+    where
+        Self: 'a;
+    /// Get a reference of the current value.
+    fn as_scalar_ref(&self) -> Self::RefType<'_>;
+
+    /// Upcast GAT type's lifetime.
+    fn upcast_gat<'short, 'long: 'short>(long: Self::RefType<'long>) -> Self::RefType<'short>;
+}
+
+pub trait ScalarRef<'a>: std::fmt::Debug + Clone + Copy + Send + 'a {
+    /// The corresponding [`Scalar`] type.
+    type ScalarType: Scalar<RefType<'a> = Self>;
+
+    /// Convert the reference into an owned value.
+    fn to_owned_scalar(&self) -> Self::ScalarType;
+}
+
+/// A sub trait of Vector to add scalar operation support.
+// This implementation refers to Databend's [ScalarColumn](https://github.com/datafuselabs/databend/blob/main/common/datavalues/src/scalars/type_.rs)
+// and skyzh's [type-exercise-in-rust](https://github.com/skyzh/type-exercise-in-rust).
+pub trait ScalarVector: Vector + Send + Sync + Sized + 'static
+where
+    for<'a> Self::OwnedItem: Scalar<RefType<'a> = Self::RefItem<'a>>,
+{
+    type OwnedItem: Scalar;
+    /// The reference item of this vector.
+    type RefItem<'a>: ScalarRef<'a, ScalarType = Self::OwnedItem>
+    where
+        Self: 'a;
+
+    /// Iterator type of this vector.
+    type Iter<'a>: Iterator<Item = Option<Self::RefItem<'a>>>
+    where
+        Self: 'a;
+
+    /// Builder type to build this vector.
+    type Builder: ScalarVectorBuilder<VectorType = Self>;
+
+    /// Returns the reference to an element at given position.
+    ///
+    /// Note: `get()` has bad performance; avoid calling this function inside a loop.
+    ///
+    /// # Panics
+    /// Panics if `idx >= self.len()`.
+    fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>>;
+
+    /// Returns iterator of current vector.
+    fn iter_data(&self) -> Self::Iter<'_>;
+
+    fn from_slice(data: &[Self::RefItem<'_>]) -> Self {
+        let mut builder = Self::Builder::with_capacity(data.len());
+        for item in data {
+            builder.push(Some(*item));
+        }
+        builder.finish()
+    }
+
+    fn from_iterator<'a>(it: impl Iterator<Item = Self::RefItem<'a>>) -> Self {
+        let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it));
+        for item in it {
+            builder.push(Some(item));
+        }
+        builder.finish()
+    }
+
+    fn from_owned_iterator(it: impl Iterator<Item = Option<Self::OwnedItem>>) -> Self {
+        let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it));
+        for item in it {
+            match item {
+                Some(item) => builder.push(Some(item.as_scalar_ref())),
+                None => builder.push(None),
+            }
+        }
+        builder.finish()
+    }
+
+    fn from_vec<I: Into<Self::OwnedItem>>(values: Vec<I>) -> Self {
+        let it = values.into_iter();
+        let mut builder = Self::Builder::with_capacity(get_iter_capacity(&it));
+        for item in it {
+            builder.push(Some(item.into().as_scalar_ref()));
+        }
+        builder.finish()
+    }
+}
+
+/// A trait over all vector builders.
+pub trait ScalarVectorBuilder: MutableVector {
+    type VectorType: ScalarVector<Builder = Self>;
+
+    /// Create a new builder with initial `capacity`.
+    fn with_capacity(capacity: usize) -> Self;
+
+    /// Push a value into the builder.
+    fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>);
+
+    /// Finish build and return a new vector.
+    fn finish(&mut self) -> Self::VectorType;
+}
+
+macro_rules! impl_scalar_for_native {
+    ($Native: ident, $DataType: ident) => {
+        impl Scalar for $Native {
+            type VectorType = PrimitiveVector<$DataType>;
+            type RefType<'a> = $Native;
+
+            #[inline]
+            fn as_scalar_ref(&self) -> $Native {
+                *self
+            }
+
+            #[allow(clippy::needless_lifetimes)]
+            #[inline]
+            fn upcast_gat<'short, 'long: 'short>(long: $Native) -> $Native {
+                long
+            }
+        }
+
+        /// Implement [`ScalarRef`] for primitive types. Note that primitive types are both [`Scalar`] and [`ScalarRef`].
+        impl<'a> ScalarRef<'a> for $Native {
+            type ScalarType = $Native;
+
+            #[inline]
+            fn to_owned_scalar(&self) -> $Native {
+                *self
+            }
+        }
+    };
+}
+
+impl_scalar_for_native!(u8, UInt8Type);
+impl_scalar_for_native!(u16, UInt16Type);
+impl_scalar_for_native!(u32, UInt32Type);
+impl_scalar_for_native!(u64, UInt64Type);
+impl_scalar_for_native!(i8, Int8Type);
+impl_scalar_for_native!(i16, Int16Type);
+impl_scalar_for_native!(i32, Int32Type);
+impl_scalar_for_native!(i64, Int64Type);
+impl_scalar_for_native!(f32, Float32Type);
+impl_scalar_for_native!(f64, Float64Type);
+
+impl Scalar for bool {
+    type VectorType = BooleanVector;
+    type RefType<'a> = bool;
+
+    #[inline]
+    fn as_scalar_ref(&self) -> bool {
+        *self
+    }
+
+    #[allow(clippy::needless_lifetimes)]
+    #[inline]
+    fn upcast_gat<'short, 'long: 'short>(long: bool) -> bool {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for bool {
+    type ScalarType = bool;
+
+    #[inline]
+    fn to_owned_scalar(&self) -> bool {
+        *self
+    }
+}
+
+impl Scalar for String {
+    type VectorType = StringVector;
+    type RefType<'a> = &'a str;
+
+    #[inline]
+    fn as_scalar_ref(&self) -> &str {
+        self
+    }
+
+    #[inline]
+    fn upcast_gat<'short, 'long: 'short>(long: &'long str) -> &'short str {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for &'a str {
+    type ScalarType = String;
+
+    #[inline]
+    fn to_owned_scalar(&self) -> String {
+        self.to_string()
+    }
+}
+
+impl Scalar for Vec<u8> {
+    type VectorType = BinaryVector;
+    type RefType<'a> = &'a [u8];
+
+    #[inline]
+    fn as_scalar_ref(&self) -> &[u8] {
+        self
+    }
+
+    #[inline]
+    fn upcast_gat<'short, 'long: 'short>(long: &'long [u8]) -> &'short [u8] {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for &'a [u8] {
+    type ScalarType = Vec<u8>;
+
+    #[inline]
+    fn to_owned_scalar(&self) -> Vec<u8> {
+        self.to_vec()
+    }
+}
+
+impl Scalar for Date {
+    type VectorType = DateVector;
+    type RefType<'a> = Date;
+
+    fn as_scalar_ref(&self) -> Self::RefType<'_> {
+        *self
+    }
+
+    fn upcast_gat<'short, 'long: 'short>(long: Self::RefType<'long>) -> Self::RefType<'short> {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for Date {
+    type ScalarType = Date;
+
+    fn to_owned_scalar(&self) -> Self::ScalarType {
+        *self
+    }
+}
+
+impl Scalar for DateTime {
+    type VectorType = DateTimeVector;
+    type RefType<'a> = DateTime;
+
+    fn as_scalar_ref(&self) -> Self::RefType<'_> {
+        *self
+    }
+
+    fn upcast_gat<'short, 'long: 'short>(long: Self::RefType<'long>) -> Self::RefType<'short> {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for DateTime {
+    type ScalarType = DateTime;
+
+    fn to_owned_scalar(&self) -> Self::ScalarType {
+        *self
+    }
+}
+
+// Timestamp types implement Scalar and ScalarRef in `src/timestamp.rs`.
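+
+// Editorial sketch (not part of the original patch): the `Scalar`/`ScalarRef`
+// pair above is meant for cheap round-trips between owned and borrowed values,
+// e.g. for `String` / `&str`:
+//
+//     let owned = String::from("greptime");
+//     let borrowed: &str = owned.as_scalar_ref();    // borrow, no clone
+//     let back: String = borrowed.to_owned_scalar(); // clone only when needed
+//     assert_eq!(owned, back);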
+
+impl Scalar for ListValue {
+    type VectorType = ListVector;
+    type RefType<'a> = ListValueRef<'a>;
+
+    fn as_scalar_ref(&self) -> Self::RefType<'_> {
+        ListValueRef::Ref { val: self }
+    }
+
+    fn upcast_gat<'short, 'long: 'short>(long: Self::RefType<'long>) -> Self::RefType<'short> {
+        long
+    }
+}
+
+impl<'a> ScalarRef<'a> for ListValueRef<'a> {
+    type ScalarType = ListValue;
+
+    fn to_owned_scalar(&self) -> Self::ScalarType {
+        match self {
+            ListValueRef::Indexed { vector, idx } => match vector.get(*idx) {
+                // Normally we should not get `Value::Null` if the `ListValueRef` comes
+                // from the iterator of the ListVector, but we avoid panicking and just
+                // return a default list value in such a case since `ListValueRef` may
+                // be constructed manually.
+                Value::Null => ListValue::default(),
+                Value::List(v) => v,
+                _ => unreachable!(),
+            },
+            ListValueRef::Ref { val } => (*val).clone(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::data_type::ConcreteDataType;
+    use crate::timestamp::TimestampSecond;
+    use crate::vectors::{BinaryVector, Int32Vector, ListVectorBuilder, TimestampSecondVector};
+
+    fn build_vector_from_slice<T: ScalarVector>(items: &[Option<T::RefItem<'_>>]) -> T {
+        let mut builder = T::Builder::with_capacity(items.len());
+        for item in items {
+            builder.push(*item);
+        }
+        builder.finish()
+    }
+
+    fn assert_vector_eq<'a, T: ScalarVector>(expect: &[Option<T::RefItem<'a>>], vector: &'a T)
+    where
+        T::RefItem<'a>: PartialEq + std::fmt::Debug,
+    {
+        for (a, b) in expect.iter().zip(vector.iter_data()) {
+            assert_eq!(*a, b);
+        }
+    }
+
+    #[test]
+    fn test_build_i32_vector() {
+        let expect = vec![Some(1), Some(2), Some(3), None, Some(5)];
+        let vector: Int32Vector = build_vector_from_slice(&expect);
+        assert_vector_eq(&expect, &vector);
+    }
+
+    #[test]
+    fn test_build_binary_vector() {
+        let expect: Vec<Option<&[u8]>> = vec![
+            Some(b"a"),
+            Some(b"b"),
+            Some(b"c"),
+            None,
+            Some(b"e"),
+            Some(b""),
+        ];
+        let vector: BinaryVector = build_vector_from_slice(&expect);
+        assert_vector_eq(&expect, &vector);
+    }
+
+    #[test]
+    fn test_build_date_vector() {
+        let expect: Vec<Option<Date>> = vec![
+            Some(Date::new(0)),
+            Some(Date::new(-1)),
+            None,
+            Some(Date::new(1)),
+        ];
+        let vector: DateVector = build_vector_from_slice(&expect);
+        assert_vector_eq(&expect, &vector);
+    }
+
+    #[test]
+    fn test_date_scalar() {
+        let date = Date::new(1);
+        assert_eq!(date, date.as_scalar_ref());
+        assert_eq!(date, date.to_owned_scalar());
+    }
+
+    #[test]
+    fn test_datetime_scalar() {
+        let dt = DateTime::new(123);
+        assert_eq!(dt, dt.as_scalar_ref());
+        assert_eq!(dt, dt.to_owned_scalar());
+    }
+
+    #[test]
+    fn test_list_value_scalar() {
+        let list_value = ListValue::new(
+            Some(Box::new(vec![Value::Int32(123)])),
+            ConcreteDataType::int32_datatype(),
+        );
+        let list_ref = ListValueRef::Ref { val: &list_value };
+        assert_eq!(list_ref, list_value.as_scalar_ref());
+        assert_eq!(list_value, list_ref.to_owned_scalar());
+
+        let mut builder =
+            ListVectorBuilder::with_type_capacity(ConcreteDataType::int32_datatype(), 1);
+        builder.push(None);
+        builder.push(Some(list_value.as_scalar_ref()));
+        let vector = builder.finish();
+
+        let ref_on_vec = ListValueRef::Indexed {
+            vector: &vector,
+            idx: 0,
+        };
+        assert_eq!(ListValue::default(), ref_on_vec.to_owned_scalar());
+        let ref_on_vec = ListValueRef::Indexed {
+            vector: &vector,
+            idx: 1,
+        };
+        assert_eq!(list_value, ref_on_vec.to_owned_scalar());
+    }
+
+    #[test]
+    fn test_build_timestamp_vector() {
+        let expect: Vec<Option<TimestampSecond>> = vec![Some(10.into()), None, Some(42.into())];
+        let vector: TimestampSecondVector = build_vector_from_slice(&expect);
+        assert_vector_eq(&expect, &vector);
+        let val = vector.get_data(0).unwrap();
+        assert_eq!(val, val.as_scalar_ref());
+        assert_eq!(TimestampSecond::from(10), val.to_owned_scalar());
+    }
+}
diff --git a/src/datatypes2/src/schema.rs b/src/datatypes2/src/schema.rs
new file mode 100644
index 0000000000..328fe0de24
--- /dev/null
+++ b/src/datatypes2/src/schema.rs
@@ -0,0 +1,430 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod column_schema;
+mod constraint;
+mod raw;
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use arrow::datatypes::{Field, Schema as ArrowSchema};
+use snafu::{ensure, ResultExt};
+
+use crate::data_type::DataType;
+use crate::error::{self, Error, Result};
+pub use crate::schema::column_schema::{ColumnSchema, Metadata};
+pub use crate::schema::constraint::ColumnDefaultConstraint;
+pub use crate::schema::raw::RawSchema;
+
+/// Key used to store version number of the schema in metadata.
+const VERSION_KEY: &str = "greptime:version";
+
+/// A common schema, should be immutable.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Schema {
+    column_schemas: Vec<ColumnSchema>,
+    name_to_index: HashMap<String, usize>,
+    arrow_schema: Arc<ArrowSchema>,
+    /// Index of the timestamp key column.
+    ///
+    /// The timestamp key column is the column that holds the timestamp and forms
+    /// part of the primary key. None means there is no timestamp key column.
+    timestamp_index: Option<usize>,
+    /// Version of the schema.
+    ///
+    /// Initial value is zero. The version should be bumped after altering the schema.
+    version: u32,
+}
+
+impl Schema {
+    /// Initial version of the schema.
+    pub const INITIAL_VERSION: u32 = 0;
+
+    /// Create a schema from a vector of [ColumnSchema].
+    ///
+    /// # Panics
+    /// Panics when a ColumnSchema's `default_constraint` can't be serialized into json.
+    pub fn new(column_schemas: Vec<ColumnSchema>) -> Schema {
+        // Builder won't fail in this case
+        SchemaBuilder::try_from(column_schemas)
+            .unwrap()
+            .build()
+            .unwrap()
+    }
+
+    /// Try to create a schema from a vector of [ColumnSchema].
+    pub fn try_new(column_schemas: Vec<ColumnSchema>) -> Result<Schema> {
+        SchemaBuilder::try_from(column_schemas)?.build()
+    }
+
+    #[inline]
+    pub fn arrow_schema(&self) -> &Arc<ArrowSchema> {
+        &self.arrow_schema
+    }
+
+    #[inline]
+    pub fn column_schemas(&self) -> &[ColumnSchema] {
+        &self.column_schemas
+    }
+
+    pub fn column_schema_by_name(&self, name: &str) -> Option<&ColumnSchema> {
+        self.name_to_index
+            .get(name)
+            .map(|index| &self.column_schemas[*index])
+    }
+
+    /// Retrieve the column's name by index.
+    /// # Panics
+    /// This method **may** panic if the index is out of range of column schemas.
+    #[inline]
+    pub fn column_name_by_index(&self, idx: usize) -> &str {
+        &self.column_schemas[idx].name
+    }
+
+    #[inline]
+    pub fn column_index_by_name(&self, name: &str) -> Option<usize> {
+        self.name_to_index.get(name).copied()
+    }
+
+    #[inline]
+    pub fn contains_column(&self, name: &str) -> bool {
+        self.name_to_index.contains_key(name)
+    }
+
+    #[inline]
+    pub fn num_columns(&self) -> usize {
+        self.column_schemas.len()
+    }
+
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.column_schemas.is_empty()
+    }
+
+    /// Returns index of the timestamp key column.
+    #[inline]
+    pub fn timestamp_index(&self) -> Option<usize> {
+        self.timestamp_index
+    }
+
+    #[inline]
+    pub fn timestamp_column(&self) -> Option<&ColumnSchema> {
+        self.timestamp_index.map(|idx| &self.column_schemas[idx])
+    }
+
+    #[inline]
+    pub fn version(&self) -> u32 {
+        self.version
+    }
+
+    #[inline]
+    pub fn metadata(&self) -> &HashMap<String, String> {
+        &self.arrow_schema.metadata
+    }
+}
+
+#[derive(Default)]
+pub struct SchemaBuilder {
+    column_schemas: Vec<ColumnSchema>,
+    name_to_index: HashMap<String, usize>,
+    fields: Vec<Field>,
+    timestamp_index: Option<usize>,
+    version: u32,
+    metadata: HashMap<String, String>,
+}
+
+impl TryFrom<Vec<ColumnSchema>> for SchemaBuilder {
+    type Error = Error;
+
+    fn try_from(column_schemas: Vec<ColumnSchema>) -> Result<Self> {
+        SchemaBuilder::try_from_columns(column_schemas)
+    }
+}
+
+impl SchemaBuilder {
+    pub fn try_from_columns(column_schemas: Vec<ColumnSchema>) -> Result<Self> {
+        let FieldsAndIndices {
+            fields,
+            name_to_index,
+            timestamp_index,
+        } = collect_fields(&column_schemas)?;
+
+        Ok(Self {
+            column_schemas,
+            name_to_index,
+            fields,
+            timestamp_index,
+            ..Default::default()
+        })
+    }
+
+    pub fn version(mut self, version: u32) -> Self {
+        self.version = version;
+        self
+    }
+
+    /// Add a key-value pair to the metadata.
+    ///
+    /// Old metadata with the same key is overwritten.
+    pub fn add_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
+        self.metadata.insert(key.into(), value.into());
+        self
+    }
+
+    pub fn build(mut self) -> Result<Schema> {
+        if let Some(timestamp_index) = self.timestamp_index {
+            validate_timestamp_index(&self.column_schemas, timestamp_index)?;
+        }
+
+        self.metadata
+            .insert(VERSION_KEY.to_string(), self.version.to_string());
+
+        let arrow_schema = ArrowSchema::new(self.fields).with_metadata(self.metadata);
+
+        Ok(Schema {
+            column_schemas: self.column_schemas,
+            name_to_index: self.name_to_index,
+            arrow_schema: Arc::new(arrow_schema),
+            timestamp_index: self.timestamp_index,
+            version: self.version,
+        })
+    }
+}
+
+struct FieldsAndIndices {
+    fields: Vec<Field>,
+    name_to_index: HashMap<String, usize>,
+    timestamp_index: Option<usize>,
+}
+
+fn collect_fields(column_schemas: &[ColumnSchema]) -> Result<FieldsAndIndices> {
+    let mut fields = Vec::with_capacity(column_schemas.len());
+    let mut name_to_index = HashMap::with_capacity(column_schemas.len());
+    let mut timestamp_index = None;
+    for (index, column_schema) in column_schemas.iter().enumerate() {
+        if column_schema.is_time_index() {
+            ensure!(
+                timestamp_index.is_none(),
+                error::DuplicateTimestampIndexSnafu {
+                    exists: timestamp_index.unwrap(),
+                    new: index,
+                }
+            );
+            timestamp_index = Some(index);
+        }
+        let field = Field::try_from(column_schema)?;
+        fields.push(field);
+        name_to_index.insert(column_schema.name.clone(), index);
+    }
+
+    Ok(FieldsAndIndices {
+        fields,
+        name_to_index,
+        timestamp_index,
+    })
+}
+
+fn validate_timestamp_index(column_schemas: &[ColumnSchema], timestamp_index: usize) -> Result<()> {
+    ensure!(
+        timestamp_index < column_schemas.len(),
+        error::InvalidTimestampIndexSnafu {
+            index: timestamp_index,
+        }
+    );
+
+    let column_schema = &column_schemas[timestamp_index];
+    ensure!(
+        column_schema.data_type.is_timestamp_compatible(),
+        error::InvalidTimestampIndexSnafu {
+            index: timestamp_index,
+        }
+    );
+    ensure!(
+        column_schema.is_time_index(),
+        error::InvalidTimestampIndexSnafu {
+            index: timestamp_index,
+        }
+    );
+
+    Ok(())
+}
+
+pub type SchemaRef = Arc<Schema>;
+
+impl TryFrom<Arc<ArrowSchema>> for Schema {
+    type Error = Error;
+
+    fn try_from(arrow_schema: Arc<ArrowSchema>) -> Result<Self> {
+        let mut column_schemas = Vec::with_capacity(arrow_schema.fields.len());
+        let mut name_to_index = HashMap::with_capacity(arrow_schema.fields.len());
+        for field in &arrow_schema.fields {
+            let column_schema = ColumnSchema::try_from(field)?;
+            name_to_index.insert(field.name().to_string(), column_schemas.len());
+            column_schemas.push(column_schema);
+        }
+
+        let mut timestamp_index = None;
+        for (index, column_schema) in column_schemas.iter().enumerate() {
+            if column_schema.is_time_index() {
+                validate_timestamp_index(&column_schemas, index)?;
+                ensure!(
+                    timestamp_index.is_none(),
+                    error::DuplicateTimestampIndexSnafu {
+                        exists: timestamp_index.unwrap(),
+                        new: index,
+                    }
+                );
+                timestamp_index = Some(index);
+            }
+        }
+
+        let version = try_parse_version(&arrow_schema.metadata, VERSION_KEY)?;
+
+        Ok(Self {
+            column_schemas,
+            name_to_index,
+            arrow_schema,
+            timestamp_index,
+            version,
+        })
+    }
+}
+
+impl TryFrom<ArrowSchema> for Schema {
+    type Error = Error;
+
+    fn try_from(arrow_schema: ArrowSchema) -> Result<Self> {
+        let arrow_schema = Arc::new(arrow_schema);
+
+        Schema::try_from(arrow_schema)
+    }
+}
+
+fn try_parse_version(metadata: &HashMap<String, String>, key: &str) -> Result<u32> {
+    if let Some(value) = metadata.get(key) {
+        let version = value
+            .parse()
+            .context(error::ParseSchemaVersionSnafu { value })?;
+
+        Ok(version)
+    } else {
+        Ok(Schema::INITIAL_VERSION)
+    }
+}
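+
+// Editorial sketch (not part of the original patch): a typical builder flow,
+// mirroring the tests below (column names are illustrative only):
+//
+//     let column_schemas = vec![
+//         ColumnSchema::new("host", ConcreteDataType::string_datatype(), false),
+//         ColumnSchema::new("ts", ConcreteDataType::timestamp_millisecond_datatype(), false)
+//             .with_time_index(true),
+//     ];
+//     let schema = SchemaBuilder::try_from(column_schemas)?
+//         .version(1)
+//         .add_metadata("k1", "v1")
+//         .build()?;
+//     assert_eq!(Some(1), schema.timestamp_index());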
+ +#[cfg(test)] +mod tests { + use super::*; + use crate::data_type::ConcreteDataType; + + #[test] + fn test_build_empty_schema() { + let schema = SchemaBuilder::default().build().unwrap(); + assert_eq!(0, schema.num_columns()); + assert!(schema.is_empty()); + } + + #[test] + fn test_schema_no_timestamp() { + let column_schemas = vec![ + ColumnSchema::new("col1", ConcreteDataType::int32_datatype(), false), + ColumnSchema::new("col2", ConcreteDataType::float64_datatype(), true), + ]; + let schema = Schema::new(column_schemas.clone()); + + assert_eq!(2, schema.num_columns()); + assert!(!schema.is_empty()); + assert!(schema.timestamp_index().is_none()); + assert!(schema.timestamp_column().is_none()); + assert_eq!(Schema::INITIAL_VERSION, schema.version()); + + for column_schema in &column_schemas { + let found = schema.column_schema_by_name(&column_schema.name).unwrap(); + assert_eq!(column_schema, found); + } + assert!(schema.column_schema_by_name("col3").is_none()); + + let new_schema = Schema::try_from(schema.arrow_schema().clone()).unwrap(); + + assert_eq!(schema, new_schema); + assert_eq!(column_schemas, schema.column_schemas()); + } + + #[test] + fn test_metadata() { + let column_schemas = vec![ColumnSchema::new( + "col1", + ConcreteDataType::int32_datatype(), + false, + )]; + let schema = SchemaBuilder::try_from(column_schemas) + .unwrap() + .add_metadata("k1", "v1") + .build() + .unwrap(); + + assert_eq!("v1", schema.metadata().get("k1").unwrap()); + } + + #[test] + fn test_schema_with_timestamp() { + let column_schemas = vec![ + ColumnSchema::new("col1", ConcreteDataType::int32_datatype(), true), + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_millisecond_datatype(), + false, + ) + .with_time_index(true), + ]; + let schema = SchemaBuilder::try_from(column_schemas.clone()) + .unwrap() + .version(123) + .build() + .unwrap(); + + assert_eq!(1, schema.timestamp_index().unwrap()); + assert_eq!(&column_schemas[1], schema.timestamp_column().unwrap()); + assert_eq!(123, schema.version()); + + let new_schema = Schema::try_from(schema.arrow_schema().clone()).unwrap(); + assert_eq!(1, schema.timestamp_index().unwrap()); + assert_eq!(schema, new_schema); + } + + #[test] + fn test_schema_wrong_timestamp() { + let column_schemas = vec![ + ColumnSchema::new("col1", ConcreteDataType::int32_datatype(), true) + .with_time_index(true), + ColumnSchema::new("col2", ConcreteDataType::float64_datatype(), false), + ]; + assert!(SchemaBuilder::try_from(column_schemas) + .unwrap() + .build() + .is_err()); + + let column_schemas = vec![ + ColumnSchema::new("col1", ConcreteDataType::int32_datatype(), true), + ColumnSchema::new("col2", ConcreteDataType::float64_datatype(), false) + .with_time_index(true), + ]; + + assert!(SchemaBuilder::try_from(column_schemas) + .unwrap() + .build() + .is_err()); + } +} diff --git a/src/datatypes2/src/schema/column_schema.rs b/src/datatypes2/src/schema/column_schema.rs new file mode 100644 index 0000000000..0577ca6aff --- /dev/null +++ b/src/datatypes2/src/schema/column_schema.rs @@ -0,0 +1,305 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::BTreeMap;
+
+use arrow::datatypes::Field;
+use serde::{Deserialize, Serialize};
+use snafu::{ensure, ResultExt};
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::{self, Error, Result};
+use crate::schema::constraint::ColumnDefaultConstraint;
+use crate::vectors::VectorRef;
+
+pub type Metadata = BTreeMap<String, String>;
+
+/// Key used to store whether the column is time index in arrow field's metadata.
+const TIME_INDEX_KEY: &str = "greptime:time_index";
+/// Key used to store default constraint in arrow field's metadata.
+const DEFAULT_CONSTRAINT_KEY: &str = "greptime:default_constraint";
+
+/// Schema of a column, used as an immutable struct.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct ColumnSchema {
+    pub name: String,
+    pub data_type: ConcreteDataType,
+    is_nullable: bool,
+    is_time_index: bool,
+    default_constraint: Option<ColumnDefaultConstraint>,
+    metadata: Metadata,
+}
+
+impl ColumnSchema {
+    pub fn new<T: Into<String>>(
+        name: T,
+        data_type: ConcreteDataType,
+        is_nullable: bool,
+    ) -> ColumnSchema {
+        ColumnSchema {
+            name: name.into(),
+            data_type,
+            is_nullable,
+            is_time_index: false,
+            default_constraint: None,
+            metadata: Metadata::new(),
+        }
+    }
+
+    #[inline]
+    pub fn is_time_index(&self) -> bool {
+        self.is_time_index
+    }
+
+    #[inline]
+    pub fn is_nullable(&self) -> bool {
+        self.is_nullable
+    }
+
+    #[inline]
+    pub fn default_constraint(&self) -> Option<&ColumnDefaultConstraint> {
+        self.default_constraint.as_ref()
+    }
+
+    #[inline]
+    pub fn metadata(&self) -> &Metadata {
+        &self.metadata
+    }
+
+    pub fn with_time_index(mut self, is_time_index: bool) -> Self {
+        self.is_time_index = is_time_index;
+        if is_time_index {
+            self.metadata
+                .insert(TIME_INDEX_KEY.to_string(), "true".to_string());
+        } else {
+            self.metadata.remove(TIME_INDEX_KEY);
+        }
+        self
+    }
+
+    pub fn with_default_constraint(
+        mut self,
+        default_constraint: Option<ColumnDefaultConstraint>,
+    ) -> Result<Self> {
+        if let Some(constraint) = &default_constraint {
+            constraint.validate(&self.data_type, self.is_nullable)?;
+        }
+
+        self.default_constraint = default_constraint;
+        Ok(self)
+    }
+
+    /// Creates a new [`ColumnSchema`] with given metadata.
+    pub fn with_metadata(mut self, metadata: Metadata) -> Self {
+        self.metadata = metadata;
+        self
+    }
+
+    pub fn create_default_vector(&self, num_rows: usize) -> Result<Option<VectorRef>> {
+        match &self.default_constraint {
+            Some(c) => c
+                .create_default_vector(&self.data_type, self.is_nullable, num_rows)
+                .map(Some),
+            None => {
+                if self.is_nullable {
+                    // No default constraint, use null as the default value.
+                    // TODO(yingwen): Use NullVector once it supports setting a logical type.
+                    ColumnDefaultConstraint::null_value()
+                        .create_default_vector(&self.data_type, self.is_nullable, num_rows)
+                        .map(Some)
+                } else {
+                    Ok(None)
+                }
+            }
+        }
+    }
+}
+
+impl TryFrom<&Field> for ColumnSchema {
+    type Error = Error;
+
+    fn try_from(field: &Field) -> Result<ColumnSchema> {
+        let data_type = ConcreteDataType::try_from(field.data_type())?;
+        let mut metadata = field.metadata().cloned().unwrap_or_default();
+        let default_constraint = match metadata.remove(DEFAULT_CONSTRAINT_KEY) {
+            Some(json) => {
+                Some(serde_json::from_str(&json).context(error::DeserializeSnafu { json })?)
+            }
+            None => None,
+        };
+        let is_time_index = metadata.contains_key(TIME_INDEX_KEY);
+
+        Ok(ColumnSchema {
+            name: field.name().clone(),
+            data_type,
+            is_nullable: field.is_nullable(),
+            is_time_index,
+            default_constraint,
+            metadata,
+        })
+    }
+}
+
+impl TryFrom<&ColumnSchema> for Field {
+    type Error = Error;
+
+    fn try_from(column_schema: &ColumnSchema) -> Result<Field> {
+        let mut metadata = column_schema.metadata.clone();
+        if let Some(value) = &column_schema.default_constraint {
+            // Adds an additional metadata entry to store the default constraint.
+            let old = metadata.insert(
+                DEFAULT_CONSTRAINT_KEY.to_string(),
+                serde_json::to_string(&value).context(error::SerializeSnafu)?,
+            );
+
+            ensure!(
+                old.is_none(),
+                error::DuplicateMetaSnafu {
+                    key: DEFAULT_CONSTRAINT_KEY,
+                }
+            );
+        }
+
+        Ok(Field::new(
+            &column_schema.name,
+            column_schema.data_type.as_arrow_type(),
+            column_schema.is_nullable(),
+        )
+        .with_metadata(Some(metadata)))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType as ArrowDataType;
+
+    use super::*;
+    use crate::value::Value;
+
+    #[test]
+    fn test_column_schema() {
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true);
+        let field = Field::try_from(&column_schema).unwrap();
+        assert_eq!("test", field.name());
+        assert_eq!(ArrowDataType::Int32, *field.data_type());
+        assert!(field.is_nullable());
+
+        let new_column_schema = ColumnSchema::try_from(&field).unwrap();
+        assert_eq!(column_schema, new_column_schema);
+    }
+
+    #[test]
+    fn test_column_schema_with_default_constraint() {
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true)
+            .with_default_constraint(Some(ColumnDefaultConstraint::Value(Value::from(99))))
+            .unwrap();
+        assert!(column_schema
+            .metadata()
+            .get(DEFAULT_CONSTRAINT_KEY)
+            .is_none());
+
+        let field = Field::try_from(&column_schema).unwrap();
+        assert_eq!("test", field.name());
+        assert_eq!(ArrowDataType::Int32, *field.data_type());
+        assert!(field.is_nullable());
+        assert_eq!(
+            "{\"Value\":{\"Int32\":99}}",
+            field
+                .metadata()
+                .unwrap()
+                .get(DEFAULT_CONSTRAINT_KEY)
+                .unwrap()
+        );
+
+        let new_column_schema = ColumnSchema::try_from(&field).unwrap();
+        assert_eq!(column_schema, new_column_schema);
+    }
+
+    #[test]
+    fn test_column_schema_with_metadata() {
+        let mut metadata = Metadata::new();
+        metadata.insert("k1".to_string(), "v1".to_string());
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true)
+            .with_metadata(metadata)
+            .with_default_constraint(Some(ColumnDefaultConstraint::null_value()))
+            .unwrap();
+        assert_eq!("v1", column_schema.metadata().get("k1").unwrap());
+        assert!(column_schema
+            .metadata()
+            .get(DEFAULT_CONSTRAINT_KEY)
+            .is_none());
+
+        let field = Field::try_from(&column_schema).unwrap();
+        assert_eq!("v1", field.metadata().unwrap().get("k1").unwrap());
+        assert!(field
+            .metadata()
+            .unwrap()
+            .get(DEFAULT_CONSTRAINT_KEY)
+            .is_some());
+
+        let new_column_schema = ColumnSchema::try_from(&field).unwrap();
+        assert_eq!(column_schema, new_column_schema);
+    }
+
+    #[test]
+    fn test_column_schema_with_duplicate_metadata() {
+        let mut metadata = Metadata::new();
+        metadata.insert(DEFAULT_CONSTRAINT_KEY.to_string(), "v1".to_string());
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true)
+            .with_metadata(metadata)
+            .with_default_constraint(Some(ColumnDefaultConstraint::null_value()))
+            .unwrap();
+        Field::try_from(&column_schema).unwrap_err();
+    }
+
+    #[test]
+    fn test_column_schema_invalid_default_constraint() {
+        ColumnSchema::new("test", ConcreteDataType::int32_datatype(), false)
+            .with_default_constraint(Some(ColumnDefaultConstraint::null_value()))
+            .unwrap_err();
+    }
+
+    #[test]
+    fn test_column_default_constraint_try_into_from() {
+        let default_constraint = ColumnDefaultConstraint::Value(Value::from(42i64));
+
+        let bytes: Vec<u8> = default_constraint.clone().try_into().unwrap();
+        let from_value = ColumnDefaultConstraint::try_from(&bytes[..]).unwrap();
+
+        assert_eq!(default_constraint, from_value);
+    }
+
+    #[test]
+    fn test_column_schema_create_default_null() {
+        // Implicit default null.
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true);
+        let v = column_schema.create_default_vector(5).unwrap().unwrap();
+        assert_eq!(5, v.len());
+        assert!(v.only_null());
+
+        // Explicit default null.
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), true)
+            .with_default_constraint(Some(ColumnDefaultConstraint::null_value()))
+            .unwrap();
+        let v = column_schema.create_default_vector(5).unwrap().unwrap();
+        assert_eq!(5, v.len());
+        assert!(v.only_null());
+    }
+
+    #[test]
+    fn test_column_schema_no_default() {
+        let column_schema = ColumnSchema::new("test", ConcreteDataType::int32_datatype(), false);
+        assert!(column_schema.create_default_vector(5).unwrap().is_none());
+    }
+}
diff --git a/src/datatypes2/src/schema/constraint.rs b/src/datatypes2/src/schema/constraint.rs
new file mode 100644
index 0000000000..4dd3ecc14b
--- /dev/null
+++ b/src/datatypes2/src/schema/constraint.rs
@@ -0,0 +1,306 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt::{Display, Formatter};
+use std::sync::Arc;
+
+use common_time::util;
+use serde::{Deserialize, Serialize};
+use snafu::{ensure, ResultExt};
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::{self, Result};
+use crate::value::Value;
+use crate::vectors::{Int64Vector, TimestampMillisecondVector, VectorRef};
+
+const CURRENT_TIMESTAMP: &str = "current_timestamp()";
+
+/// Column's default constraint.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum ColumnDefaultConstraint {
+    // A function invocation.
+    // TODO(dennis): we save the function expression here; maybe use a struct in the future.
+    Function(String),
+    // A value
+    Value(Value),
+}
+
+impl TryFrom<&[u8]> for ColumnDefaultConstraint {
+    type Error = error::Error;
+
+    fn try_from(bytes: &[u8]) -> Result<Self> {
+        let json = String::from_utf8_lossy(bytes);
+        serde_json::from_str(&json).context(error::DeserializeSnafu { json })
+    }
+}
+
+impl TryFrom<ColumnDefaultConstraint> for Vec<u8> {
+    type Error = error::Error;
+
+    fn try_from(value: ColumnDefaultConstraint) -> std::result::Result<Vec<u8>, Self::Error> {
+        let s = serde_json::to_string(&value).context(error::SerializeSnafu)?;
+        Ok(s.into_bytes())
+    }
+}
+
+impl Display for ColumnDefaultConstraint {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ColumnDefaultConstraint::Function(expr) => write!(f, "{}", expr),
+            ColumnDefaultConstraint::Value(v) => write!(f, "{}", v),
+        }
+    }
+}
+
+impl ColumnDefaultConstraint {
+    /// Returns a default null constraint.
+    pub fn null_value() -> ColumnDefaultConstraint {
+        ColumnDefaultConstraint::Value(Value::Null)
+    }
+
+    /// Check whether the constraint is valid for columns with given `data_type`
+    /// and `is_nullable` attributes.
+    pub fn validate(&self, data_type: &ConcreteDataType, is_nullable: bool) -> Result<()> {
+        ensure!(is_nullable || !self.maybe_null(), error::NullDefaultSnafu);
+
+        match self {
+            ColumnDefaultConstraint::Function(expr) => {
+                ensure!(
+                    expr == CURRENT_TIMESTAMP,
+                    error::UnsupportedDefaultExprSnafu { expr }
+                );
+                ensure!(
+                    data_type.is_timestamp_compatible(),
+                    error::DefaultValueTypeSnafu {
+                        reason: "return value of the function must have a timestamp type",
+                    }
+                );
+            }
+            ColumnDefaultConstraint::Value(v) => {
+                if !v.is_null() {
+                    // Whether the value could be nullable has been checked before, only need
+                    // to check the type compatibility here.
+                    ensure!(
+                        data_type.logical_type_id() == v.logical_type_id(),
+                        error::DefaultValueTypeSnafu {
+                            reason: format!(
+                                "column has type {:?} but default value has type {:?}",
+                                data_type.logical_type_id(),
+                                v.logical_type_id()
+                            ),
+                        }
+                    );
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Create a vector that contains `num_rows` default values for given `data_type`.
+    ///
+    /// If `is_nullable` is `false`, this method returns an error if the created
+    /// default value is null.
+    ///
+    /// # Panics
+    /// Panics if `num_rows == 0`.
+    pub fn create_default_vector(
+        &self,
+        data_type: &ConcreteDataType,
+        is_nullable: bool,
+        num_rows: usize,
+    ) -> Result<VectorRef> {
+        assert!(num_rows > 0);
+
+        match self {
+            ColumnDefaultConstraint::Function(expr) => {
+                // Functions should also ensure their return values are not null when
+                // is_nullable is false.
+                match &expr[..] {
+                    // TODO(dennis): we only support current_timestamp right now;
+                    // it would be better to use an expression framework in the future.
+                    CURRENT_TIMESTAMP => create_current_timestamp_vector(data_type, num_rows),
+                    _ => error::UnsupportedDefaultExprSnafu { expr }.fail(),
+                }
+            }
+            ColumnDefaultConstraint::Value(v) => {
+                ensure!(is_nullable || !v.is_null(), error::NullDefaultSnafu);
+
+                // TODO(yingwen):
+                // 1. For null values, we could use NullVector once it supports a custom logical type.
+                // 2. For non-null values, we could use ConstantVector, but then any code that
+                //    downcasts the vector would fail unless it first checks whether the vector
+                //    is constant.
+                let mut mutable_vector = data_type.create_mutable_vector(1);
+                mutable_vector.push_value_ref(v.as_value_ref())?;
+                let base_vector = mutable_vector.to_vector();
+                Ok(base_vector.replicate(&[num_rows]))
+            }
+        }
+    }
+
+    /// Returns true if this constraint might create NULL.
+    fn maybe_null(&self) -> bool {
+        // Once we support more functions, we may return true if the given function
+        // could return null.
+        matches!(self, ColumnDefaultConstraint::Value(Value::Null))
+    }
+}
+
+fn create_current_timestamp_vector(
+    data_type: &ConcreteDataType,
+    num_rows: usize,
+) -> Result<VectorRef> {
+    // FIXME(yingwen): We should implement cast in VectorOp so we can cast the millisecond
+    // vector to other data types and avoid this match.
+    match data_type {
+        ConcreteDataType::Timestamp(_) => Ok(Arc::new(TimestampMillisecondVector::from_values(
+            std::iter::repeat(util::current_time_millis()).take(num_rows),
+        ))),
+        ConcreteDataType::Int64(_) => Ok(Arc::new(Int64Vector::from_values(
+            std::iter::repeat(util::current_time_millis()).take(num_rows),
+        ))),
+        _ => error::DefaultValueTypeSnafu {
+            reason: format!(
+                "Not supported to assign current timestamp to {:?} type",
+                data_type
+            ),
+        }
+        .fail(),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::error::Error;
+    use crate::vectors::Int32Vector;
+
+    #[test]
+    fn test_null_default_constraint() {
+        let constraint = ColumnDefaultConstraint::null_value();
+        assert!(constraint.maybe_null());
+        let constraint = ColumnDefaultConstraint::Value(Value::Int32(10));
+        assert!(!constraint.maybe_null());
+    }
+
+    #[test]
+    fn test_validate_null_constraint() {
+        let constraint = ColumnDefaultConstraint::null_value();
+        let data_type = ConcreteDataType::int32_datatype();
+        constraint.validate(&data_type, false).unwrap_err();
+        constraint.validate(&data_type, true).unwrap();
+    }
+
+    #[test]
+    fn test_validate_value_constraint() {
+        let constraint = ColumnDefaultConstraint::Value(Value::Int32(10));
+        let data_type = ConcreteDataType::int32_datatype();
+        constraint.validate(&data_type, false).unwrap();
+        constraint.validate(&data_type, true).unwrap();
+
+        constraint
+            .validate(&ConcreteDataType::uint32_datatype(), true)
+            .unwrap_err();
+    }
+
+    #[test]
+    fn test_validate_function_constraint() {
+        let constraint = ColumnDefaultConstraint::Function(CURRENT_TIMESTAMP.to_string());
+        constraint
+            .validate(&ConcreteDataType::timestamp_millisecond_datatype(), false)
+            .unwrap();
+        constraint
+            .validate(&ConcreteDataType::boolean_datatype(), false)
+            .unwrap_err();
+
+        let constraint = ColumnDefaultConstraint::Function("hello()".to_string());
+        constraint
+            .validate(&ConcreteDataType::timestamp_millisecond_datatype(), false)
+            .unwrap_err();
+    }
+
+    #[test]
+    fn test_create_default_vector_by_null() {
+        let constraint = ColumnDefaultConstraint::null_value();
+        let data_type = ConcreteDataType::int32_datatype();
+        constraint
+            .create_default_vector(&data_type, false, 10)
+            .unwrap_err();
+
+        let constraint = ColumnDefaultConstraint::null_value();
+        let v = constraint
+            .create_default_vector(&data_type, true, 3)
+            .unwrap();
+        assert_eq!(3, v.len());
+        for i in 0..v.len() {
+            assert_eq!(Value::Null, v.get(i));
+        }
+    }
+
+    #[test]
+    fn test_create_default_vector_by_value() {
+        let constraint = ColumnDefaultConstraint::Value(Value::Int32(10));
+        let data_type = ConcreteDataType::int32_datatype();
+        let v = constraint
+            .create_default_vector(&data_type, false, 4)
+            .unwrap();
+        let expect: VectorRef = Arc::new(Int32Vector::from_values(vec![10; 4]));
+        assert_eq!(expect, v);
+    }
+
+    #[test]
+    fn test_create_default_vector_by_func() {
+        let constraint = ColumnDefaultConstraint::Function(CURRENT_TIMESTAMP.to_string());
+        // Timestamp type.
+        let data_type = ConcreteDataType::timestamp_millisecond_datatype();
+        let v = constraint
+            .create_default_vector(&data_type, false, 4)
+            .unwrap();
+        assert_eq!(4, v.len());
+        assert!(
+            matches!(v.get(0), Value::Timestamp(_)),
+            "v {:?} is not a timestamp",
+            v.get(0)
+        );
+
+        // Int64 type.
+        let data_type = ConcreteDataType::int64_datatype();
+        let v = constraint
+            .create_default_vector(&data_type, false, 4)
+            .unwrap();
+        assert_eq!(4, v.len());
+        assert!(
+            matches!(v.get(0), Value::Int64(_)),
+            "v {:?} is not an int64",
+            v.get(0)
+        );
+
+        let constraint = ColumnDefaultConstraint::Function("no".to_string());
+        let data_type = ConcreteDataType::timestamp_millisecond_datatype();
+        constraint
+            .create_default_vector(&data_type, false, 4)
+            .unwrap_err();
+    }
+
+    #[test]
+    fn test_create_by_func_and_invalid_type() {
+        let constraint = ColumnDefaultConstraint::Function(CURRENT_TIMESTAMP.to_string());
+        let data_type = ConcreteDataType::boolean_datatype();
+        let err = constraint
+            .create_default_vector(&data_type, false, 4)
+            .unwrap_err();
+        assert!(matches!(err, Error::DefaultValueType { .. }), "{:?}", err);
+    }
+}
diff --git a/src/datatypes2/src/schema/raw.rs b/src/datatypes2/src/schema/raw.rs
new file mode 100644
index 0000000000..75f0853b4b
--- /dev/null
+++ b/src/datatypes2/src/schema/raw.rs
@@ -0,0 +1,77 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use serde::{Deserialize, Serialize};
+
+use crate::error::{Error, Result};
+use crate::schema::{ColumnSchema, Schema, SchemaBuilder};
+
+/// Struct used to serialize and deserialize [`Schema`](crate::schema::Schema).
+///
+/// This struct only contains necessary data to recover the Schema.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct RawSchema {
+    pub column_schemas: Vec<ColumnSchema>,
+    pub timestamp_index: Option<usize>,
+    pub version: u32,
+}
+
+impl TryFrom<RawSchema> for Schema {
+    type Error = Error;
+
+    fn try_from(raw: RawSchema) -> Result<Schema> {
+        SchemaBuilder::try_from(raw.column_schemas)?
+            .version(raw.version)
+            .build()
+    }
+}
+
+impl From<&Schema> for RawSchema {
+    fn from(schema: &Schema) -> RawSchema {
+        RawSchema {
+            column_schemas: schema.column_schemas.clone(),
+            timestamp_index: schema.timestamp_index,
+            version: schema.version,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::data_type::ConcreteDataType;
+
+    #[test]
+    fn test_raw_convert() {
+        let column_schemas = vec![
+            ColumnSchema::new("col1", ConcreteDataType::int32_datatype(), true),
+            ColumnSchema::new(
+                "ts",
+                ConcreteDataType::timestamp_millisecond_datatype(),
+                false,
+            )
+            .with_time_index(true),
+        ];
+        let schema = SchemaBuilder::try_from(column_schemas)
+            .unwrap()
+            .version(123)
+            .build()
+            .unwrap();
+
+        let raw = RawSchema::from(&schema);
+        let schema_new = Schema::try_from(raw).unwrap();
+
+        assert_eq!(schema, schema_new);
+    }
+}
diff --git a/src/datatypes2/src/serialize.rs b/src/datatypes2/src/serialize.rs
new file mode 100644
index 0000000000..1cbf04cedd
--- /dev/null
+++ b/src/datatypes2/src/serialize.rs
@@ -0,0 +1,20 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use crate::error::Result;
+
+pub trait Serializable: Send + Sync {
+    /// Serialize a column of values with the given type to JSON values.
+    fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>>;
+}
diff --git a/src/datatypes2/src/timestamp.rs b/src/datatypes2/src/timestamp.rs
new file mode 100644
index 0000000000..f14e91a6c6
--- /dev/null
+++ b/src/datatypes2/src/timestamp.rs
@@ -0,0 +1,135 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+use common_time::timestamp::TimeUnit;
+use common_time::Timestamp;
+use paste::paste;
+use serde::{Deserialize, Serialize};
+
+use crate::prelude::{Scalar, Value, ValueRef};
+use crate::scalars::ScalarRef;
+use crate::types::{
+    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+    TimestampSecondType, WrapperType,
+};
+use crate::vectors::{
+    TimestampMicrosecondVector, TimestampMillisecondVector, TimestampNanosecondVector,
+    TimestampSecondVector,
+};
+
+macro_rules! define_timestamp_with_unit {
+    ($unit: ident) => {
+        paste! {
+            #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+            pub struct [<Timestamp $unit>](pub Timestamp);
+
+            impl [<Timestamp $unit>] {
+                pub fn new(val: i64) -> Self {
+                    Self(Timestamp::new(val, TimeUnit::$unit))
+                }
+            }
+
+            impl Default for [<Timestamp $unit>] {
+                fn default() -> Self {
+                    Self::new(0)
+                }
+            }
+
+            impl From<[<Timestamp $unit>]> for Value {
+                fn from(t: [<Timestamp $unit>]) -> Value {
+                    Value::Timestamp(t.0)
+                }
+            }
+
+            impl From<[<Timestamp $unit>]> for serde_json::Value {
+                fn from(t: [<Timestamp $unit>]) -> Self {
+                    t.0.into()
+                }
+            }
+
+            impl From<[<Timestamp $unit>]> for ValueRef<'static> {
+                fn from(t: [<Timestamp $unit>]) -> Self {
+                    ValueRef::Timestamp(t.0)
+                }
+            }
+
+            impl Scalar for [<Timestamp $unit>] {
+                type VectorType = [<Timestamp $unit Vector>];
+                type RefType<'a> = [<Timestamp $unit>];
+
+                fn as_scalar_ref(&self) -> Self::RefType<'_> {
+                    *self
+                }
+
+                fn upcast_gat<'short, 'long: 'short>(
+                    long: Self::RefType<'long>,
+                ) -> Self::RefType<'short> {
+                    long
+                }
+            }
+
+            impl<'a> ScalarRef<'a> for [<Timestamp $unit>] {
+                type ScalarType = [<Timestamp $unit>];
+
+                fn to_owned_scalar(&self) -> Self::ScalarType {
+                    *self
+                }
+            }
+
+            impl WrapperType for [<Timestamp $unit>] {
+                type LogicalType = [<Timestamp $unit Type>];
+                type Native = i64;
+
+                fn from_native(value: Self::Native) -> Self {
+                    Self::new(value)
+                }
+
+                fn into_native(self) -> Self::Native {
+                    self.0.into()
+                }
+            }
+
+            impl From<i64> for [<Timestamp $unit>] {
+                fn from(val: i64) -> Self {
+                    [<Timestamp $unit>]::from_native(val)
+                }
+            }
+        }
+    };
+}
+
+define_timestamp_with_unit!(Second);
+define_timestamp_with_unit!(Millisecond);
+define_timestamp_with_unit!(Microsecond);
+define_timestamp_with_unit!(Nanosecond);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_timestamp_scalar() {
+        let ts = TimestampSecond::new(123);
+        assert_eq!(ts, ts.as_scalar_ref());
+        assert_eq!(ts, ts.to_owned_scalar());
+        let ts = TimestampMillisecond::new(123);
+        assert_eq!(ts, ts.as_scalar_ref());
+        assert_eq!(ts, ts.to_owned_scalar());
+        let ts = TimestampMicrosecond::new(123);
+        assert_eq!(ts, ts.as_scalar_ref());
+        assert_eq!(ts, ts.to_owned_scalar());
+        let ts = TimestampNanosecond::new(123);
+        assert_eq!(ts, ts.as_scalar_ref());
+        assert_eq!(ts, ts.to_owned_scalar());
+    }
+}
diff --git a/src/datatypes2/src/type_id.rs b/src/datatypes2/src/type_id.rs
new file mode 100644
index 0000000000..bcb7ea52b1
--- /dev/null
+++ b/src/datatypes2/src/type_id.rs
@@ -0,0 +1,93 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// Unique identifier for logical data type.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum LogicalTypeId {
+    Null,
+
+    // Numeric types:
+    Boolean,
+    Int8,
+    Int16,
+    Int32,
+    Int64,
+    UInt8,
+    UInt16,
+    UInt32,
+    UInt64,
+    Float32,
+    Float64,
+
+    // String types:
+    String,
+    Binary,
+
+    // Date & Time types:
+    /// Date representing the elapsed time since UNIX epoch (1970-01-01)
+    /// in days (32 bits).
+    Date,
+    /// Datetime representing the elapsed time since UNIX epoch (1970-01-01) in
+    /// seconds/milliseconds/microseconds/nanoseconds, determined by precision.
diff --git a/src/datatypes2/src/type_id.rs b/src/datatypes2/src/type_id.rs
new file mode 100644
index 0000000000..bcb7ea52b1
--- /dev/null
+++ b/src/datatypes2/src/type_id.rs
@@ -0,0 +1,93 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/// Unique identifier for logical data type.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum LogicalTypeId {
+    Null,
+
+    // Numeric types:
+    Boolean,
+    Int8,
+    Int16,
+    Int32,
+    Int64,
+    UInt8,
+    UInt16,
+    UInt32,
+    UInt64,
+    Float32,
+    Float64,
+
+    // String types:
+    String,
+    Binary,
+
+    // Date & Time types:
+    /// Date representing the elapsed time since UNIX epoch (1970-01-01)
+    /// in days (32 bits).
+    Date,
+    /// Datetime representing the elapsed time since UNIX epoch (1970-01-01) in
+    /// seconds/milliseconds/microseconds/nanoseconds, determined by precision.
+    DateTime,
+
+    TimestampSecond,
+    TimestampMillisecond,
+    TimestampMicrosecond,
+    TimestampNanosecond,
+
+    List,
+}
+
+impl LogicalTypeId {
+    /// Create ConcreteDataType based on this id. This method is for test only as it
+    /// would lose some info.
+    ///
+    /// # Panics
+    /// Panics if data type is not supported.
+    #[cfg(any(test, feature = "test"))]
+    pub fn data_type(&self) -> crate::data_type::ConcreteDataType {
+        use crate::data_type::ConcreteDataType;
+
+        match self {
+            LogicalTypeId::Null => ConcreteDataType::null_datatype(),
+            LogicalTypeId::Boolean => ConcreteDataType::boolean_datatype(),
+            LogicalTypeId::Int8 => ConcreteDataType::int8_datatype(),
+            LogicalTypeId::Int16 => ConcreteDataType::int16_datatype(),
+            LogicalTypeId::Int32 => ConcreteDataType::int32_datatype(),
+            LogicalTypeId::Int64 => ConcreteDataType::int64_datatype(),
+            LogicalTypeId::UInt8 => ConcreteDataType::uint8_datatype(),
+            LogicalTypeId::UInt16 => ConcreteDataType::uint16_datatype(),
+            LogicalTypeId::UInt32 => ConcreteDataType::uint32_datatype(),
+            LogicalTypeId::UInt64 => ConcreteDataType::uint64_datatype(),
+            LogicalTypeId::Float32 => ConcreteDataType::float32_datatype(),
+            LogicalTypeId::Float64 => ConcreteDataType::float64_datatype(),
+            LogicalTypeId::String => ConcreteDataType::string_datatype(),
+            LogicalTypeId::Binary => ConcreteDataType::binary_datatype(),
+            LogicalTypeId::Date => ConcreteDataType::date_datatype(),
+            LogicalTypeId::DateTime => ConcreteDataType::datetime_datatype(),
+            LogicalTypeId::TimestampSecond => ConcreteDataType::timestamp_second_datatype(),
+            LogicalTypeId::TimestampMillisecond => {
+                ConcreteDataType::timestamp_millisecond_datatype()
+            }
+            LogicalTypeId::TimestampMicrosecond => {
+                ConcreteDataType::timestamp_microsecond_datatype()
+            }
+            LogicalTypeId::TimestampNanosecond => ConcreteDataType::timestamp_nanosecond_datatype(),
+            LogicalTypeId::List => {
+                ConcreteDataType::list_datatype(ConcreteDataType::null_datatype())
+            }
+        }
+    }
+}
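The intended use is round-tripping in tests; a sketch, assuming `ConcreteDataType` exposes `logical_type_id()` through the `DataType` trait:

    // Sketch (test-only): expand an id to a concrete type and back.
    let id = LogicalTypeId::TimestampMillisecond;
    assert_eq!(id, id.data_type().logical_type_id());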
diff --git a/src/datatypes2/src/types.rs b/src/datatypes2/src/types.rs
new file mode 100644
index 0000000000..186704fdfd
--- /dev/null
+++ b/src/datatypes2/src/types.rs
@@ -0,0 +1,37 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod binary_type;
+mod boolean_type;
+mod date_type;
+mod datetime_type;
+mod list_type;
+mod null_type;
+mod primitive_type;
+mod string_type;
+
+mod timestamp_type;
+
+pub use binary_type::BinaryType;
+pub use boolean_type::BooleanType;
+pub use date_type::DateType;
+pub use datetime_type::DateTimeType;
+pub use list_type::ListType;
+pub use null_type::NullType;
+pub use primitive_type::{
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LogicalPrimitiveType,
+    NativeType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType,
+};
+pub use string_type::StringType;
+pub use timestamp_type::*;
diff --git a/src/datatypes2/src/types/binary_type.rs b/src/datatypes2/src/types/binary_type.rs
new file mode 100644
index 0000000000..0d06724fff
--- /dev/null
+++ b/src/datatypes2/src/types/binary_type.rs
@@ -0,0 +1,60 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::DataType as ArrowDataType;
+use common_base::bytes::Bytes;
+use serde::{Deserialize, Serialize};
+
+use crate::data_type::{DataType, DataTypeRef};
+use crate::scalars::ScalarVectorBuilder;
+use crate::type_id::LogicalTypeId;
+use crate::value::Value;
+use crate::vectors::{BinaryVectorBuilder, MutableVector};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct BinaryType;
+
+impl BinaryType {
+    pub fn arc() -> DataTypeRef {
+        Arc::new(Self)
+    }
+}
+
+impl DataType for BinaryType {
+    fn name(&self) -> &str {
+        "Binary"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::Binary
+    }
+
+    fn default_value(&self) -> Value {
+        Bytes::default().into()
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::LargeBinary
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(BinaryVectorBuilder::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
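A sketch of the builder-factory pattern these impls enable; `push_value_ref` and `to_vector` are assumed from the `MutableVector` trait defined elsewhere in this change:

    // Sketch: a DataType acts as a factory for its mutable vector builder.
    let mut builder = BinaryType::arc().create_mutable_vector(2);
    builder.push_value_ref(ValueRef::Binary(b"hello")).unwrap(); // assumed MutableVector API
    builder.push_value_ref(ValueRef::Null).unwrap();
    let vector = builder.to_vector(); // one binary value followed by a null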
diff --git a/src/datatypes2/src/types/boolean_type.rs b/src/datatypes2/src/types/boolean_type.rs
new file mode 100644
index 0000000000..36d92169eb
--- /dev/null
+++ b/src/datatypes2/src/types/boolean_type.rs
@@ -0,0 +1,59 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::DataType as ArrowDataType;
+use serde::{Deserialize, Serialize};
+
+use crate::data_type::{DataType, DataTypeRef};
+use crate::scalars::ScalarVectorBuilder;
+use crate::type_id::LogicalTypeId;
+use crate::value::Value;
+use crate::vectors::{BooleanVectorBuilder, MutableVector};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct BooleanType;
+
+impl BooleanType {
+    pub fn arc() -> DataTypeRef {
+        Arc::new(Self)
+    }
+}
+
+impl DataType for BooleanType {
+    fn name(&self) -> &str {
+        "Boolean"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::Boolean
+    }
+
+    fn default_value(&self) -> Value {
+        bool::default().into()
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Boolean
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(BooleanVectorBuilder::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
diff --git a/src/datatypes2/src/types/date_type.rs b/src/datatypes2/src/types/date_type.rs
new file mode 100644
index 0000000000..052b837a3d
--- /dev/null
+++ b/src/datatypes2/src/types/date_type.rs
@@ -0,0 +1,90 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::datatypes::{DataType as ArrowDataType, Date32Type};
+use common_time::Date;
+use serde::{Deserialize, Serialize};
+use snafu::OptionExt;
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::{self, Result};
+use crate::scalars::ScalarVectorBuilder;
+use crate::type_id::LogicalTypeId;
+use crate::types::LogicalPrimitiveType;
+use crate::value::{Value, ValueRef};
+use crate::vectors::{DateVector, DateVectorBuilder, MutableVector, Vector};
+
+/// Data type for Date (YYYY-MM-DD).
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct DateType;
+
+impl DataType for DateType {
+    fn name(&self) -> &str {
+        "Date"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::Date
+    }
+
+    fn default_value(&self) -> Value {
+        Value::Date(Default::default())
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Date32
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(DateVectorBuilder::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
+
+impl LogicalPrimitiveType for DateType {
+    type ArrowPrimitive = Date32Type;
+    type Native = i32;
+    type Wrapper = Date;
+
+    fn build_data_type() -> ConcreteDataType {
+        ConcreteDataType::date_datatype()
+    }
+
+    fn type_name() -> &'static str {
+        "Date"
+    }
+
+    fn cast_vector(vector: &dyn Vector) -> Result<&DateVector> {
+        vector
+            .as_any()
+            .downcast_ref::<DateVector>()
+            .with_context(|| error::CastTypeSnafu {
+                msg: format!("Failed to cast {} to DateVector", vector.vector_type_name()),
+            })
+    }
+
+    fn cast_value_ref(value: ValueRef) -> Result<Option<Date>> {
+        match value {
+            ValueRef::Null => Ok(None),
+            ValueRef::Date(v) => Ok(Some(v)),
+            other => error::CastTypeSnafu {
+                msg: format!("Failed to cast value {:?} to Date", other),
+            }
+            .fail(),
+        }
+    }
+}
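A sketch of the casting contract: `Null` passes through as `None`, a matching ref is unwrapped, and anything else is a `CastType` error.

    assert_eq!(None, DateType::cast_value_ref(ValueRef::Null).unwrap());
    assert_eq!(
        Some(Date::new(19300)),
        DateType::cast_value_ref(ValueRef::Date(Date::new(19300))).unwrap()
    );
    assert!(DateType::cast_value_ref(ValueRef::Int32(19300)).is_err());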
diff --git a/src/datatypes2/src/types/datetime_type.rs b/src/datatypes2/src/types/datetime_type.rs
new file mode 100644
index 0000000000..d74a02effe
--- /dev/null
+++ b/src/datatypes2/src/types/datetime_type.rs
@@ -0,0 +1,91 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::datatypes::{DataType as ArrowDataType, Date64Type};
+use common_time::DateTime;
+use serde::{Deserialize, Serialize};
+use snafu::OptionExt;
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::{self, Result};
+use crate::prelude::{LogicalTypeId, MutableVector, ScalarVectorBuilder, Value, ValueRef, Vector};
+use crate::types::LogicalPrimitiveType;
+use crate::vectors::{DateTimeVector, DateTimeVectorBuilder, PrimitiveVector};
+
+/// Data type for [`DateTime`].
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct DateTimeType;
+
+impl DataType for DateTimeType {
+    fn name(&self) -> &str {
+        "DateTime"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::DateTime
+    }
+
+    fn default_value(&self) -> Value {
+        Value::DateTime(DateTime::default())
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Date64
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(DateTimeVectorBuilder::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
+
+impl LogicalPrimitiveType for DateTimeType {
+    type ArrowPrimitive = Date64Type;
+    type Native = i64;
+    type Wrapper = DateTime;
+
+    fn build_data_type() -> ConcreteDataType {
+        ConcreteDataType::datetime_datatype()
+    }
+
+    fn type_name() -> &'static str {
+        "DateTime"
+    }
+
+    fn cast_vector(vector: &dyn Vector) -> Result<&PrimitiveVector<DateTimeType>> {
+        vector
+            .as_any()
+            .downcast_ref::<DateTimeVector>()
+            .with_context(|| error::CastTypeSnafu {
+                msg: format!(
+                    "Failed to cast {} to DateTimeVector",
+                    vector.vector_type_name()
+                ),
+            })
+    }
+
+    fn cast_value_ref(value: ValueRef) -> Result<Option<DateTime>> {
+        match value {
+            ValueRef::Null => Ok(None),
+            ValueRef::DateTime(v) => Ok(Some(v)),
+            other => error::CastTypeSnafu {
+                msg: format!("Failed to cast value {:?} to DateTime", other),
+            }
+            .fail(),
+        }
+    }
+}
diff --git a/src/datatypes2/src/types/list_type.rs b/src/datatypes2/src/types/list_type.rs
new file mode 100644
index 0000000000..b9875ca362
--- /dev/null
+++ b/src/datatypes2/src/types/list_type.rs
@@ -0,0 +1,95 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::datatypes::{DataType as ArrowDataType, Field};
+use serde::{Deserialize, Serialize};
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::type_id::LogicalTypeId;
+use crate::value::{ListValue, Value};
+use crate::vectors::{ListVectorBuilder, MutableVector};
+
+/// Used to represent the List datatype.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct ListType {
+    /// The type of List's item.
+    // Use Box to avoid recursive dependency, as enum ConcreteDataType depends on ListType.
+    item_type: Box<ConcreteDataType>,
+}
+
+impl Default for ListType {
+    fn default() -> Self {
+        ListType::new(ConcreteDataType::null_datatype())
+    }
+}
+
+impl ListType {
+    /// Create a new `ListType` whose item's data type is `item_type`.
+    pub fn new(item_type: ConcreteDataType) -> Self {
+        ListType {
+            item_type: Box::new(item_type),
+        }
+    }
+}
+
+impl DataType for ListType {
+    fn name(&self) -> &str {
+        "List"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::List
+    }
+
+    fn default_value(&self) -> Value {
+        Value::List(ListValue::new(None, *self.item_type.clone()))
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        let field = Box::new(Field::new("item", self.item_type.as_arrow_type(), true));
+        ArrowDataType::List(field)
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(ListVectorBuilder::with_type_capacity(
+            *self.item_type.clone(),
+            capacity,
+        ))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::value::ListValue;
+
+    #[test]
+    fn test_list_type() {
+        let t = ListType::new(ConcreteDataType::boolean_datatype());
+        assert_eq!("List", t.name());
+        assert_eq!(LogicalTypeId::List, t.logical_type_id());
+        assert_eq!(
+            Value::List(ListValue::new(None, ConcreteDataType::boolean_datatype())),
+            t.default_value()
+        );
+        assert_eq!(
+            ArrowDataType::List(Box::new(Field::new("item", ArrowDataType::Boolean, true))),
+            t.as_arrow_type()
+        );
+    }
+}
diff --git a/src/datatypes2/src/types/null_type.rs b/src/datatypes2/src/types/null_type.rs
new file mode 100644
index 0000000000..b9bb2dc752
--- /dev/null
+++ b/src/datatypes2/src/types/null_type.rs
@@ -0,0 +1,58 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::DataType as ArrowDataType;
+use serde::{Deserialize, Serialize};
+
+use crate::data_type::{DataType, DataTypeRef};
+use crate::type_id::LogicalTypeId;
+use crate::value::Value;
+use crate::vectors::{MutableVector, NullVectorBuilder};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct NullType;
+
+impl NullType {
+    pub fn arc() -> DataTypeRef {
+        Arc::new(NullType)
+    }
+}
+
+impl DataType for NullType {
+    fn name(&self) -> &str {
+        "Null"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::Null
+    }
+
+    fn default_value(&self) -> Value {
+        Value::Null
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Null
+    }
+
+    fn create_mutable_vector(&self, _capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(NullVectorBuilder::default())
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
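Because `ListType` (defined above) boxes its item type, list types can nest arbitrarily; a sketch:

    // Sketch: a list of lists of Int32 maps to nested Arrow list types.
    let t = ListType::new(ConcreteDataType::list_datatype(
        ConcreteDataType::int32_datatype(),
    ));
    // t.as_arrow_type() is List(Field("item", List(Field("item", Int32, true)), true))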
diff --git a/src/datatypes2/src/types/primitive_type.rs b/src/datatypes2/src/types/primitive_type.rs
new file mode 100644
index 0000000000..e389ca13bf
--- /dev/null
+++ b/src/datatypes2/src/types/primitive_type.rs
@@ -0,0 +1,358 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::cmp::Ordering;
+
+use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType as ArrowDataType};
+use common_time::{Date, DateTime};
+use num::NumCast;
+use serde::{Deserialize, Serialize};
+use snafu::OptionExt;
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::{self, Result};
+use crate::scalars::{Scalar, ScalarRef, ScalarVectorBuilder};
+use crate::type_id::LogicalTypeId;
+use crate::types::{DateTimeType, DateType};
+use crate::value::{Value, ValueRef};
+use crate::vectors::{MutableVector, PrimitiveVector, PrimitiveVectorBuilder, Vector};
+
+/// Data types that can be used as arrow's native type.
+pub trait NativeType: ArrowNativeType + NumCast {
+    /// Largest numeric type this primitive type can be cast to.
+    type LargestType: NativeType;
+}
+
+macro_rules! impl_native_type {
+    ($Type: ident, $LargestType: ident) => {
+        impl NativeType for $Type {
+            type LargestType = $LargestType;
+        }
+    };
+}
+
+impl_native_type!(u8, u64);
+impl_native_type!(u16, u64);
+impl_native_type!(u32, u64);
+impl_native_type!(u64, u64);
+impl_native_type!(i8, i64);
+impl_native_type!(i16, i64);
+impl_native_type!(i32, i64);
+impl_native_type!(i64, i64);
+impl_native_type!(f32, f64);
+impl_native_type!(f64, f64);
+
+/// Represents the wrapper type that wraps a native type using the `newtype pattern`,
+/// such as [Date](`common_time::Date`), which wraps the underlying native type `i32`.
+pub trait WrapperType:
+    Copy
+    + Scalar
+    + PartialEq
+    + Into<Value>
+    + Into<ValueRef<'static>>
+    + Serialize
+    + Into<serde_json::Value>
+{
+    /// Logical primitive type that this wrapper type belongs to.
+    type LogicalType: LogicalPrimitiveType<Wrapper = Self, Native = Self::Native>;
+    /// The underlying native type.
+    type Native: NativeType;
+
+    /// Convert native type into this wrapper type.
+    fn from_native(value: Self::Native) -> Self;
+
+    /// Convert this wrapper type into native type.
+    fn into_native(self) -> Self::Native;
+}
+
+/// Trait bridging the logical primitive type with [ArrowPrimitiveType].
+pub trait LogicalPrimitiveType: 'static + Sized {
+    /// Arrow primitive type of this logical type.
+    type ArrowPrimitive: ArrowPrimitiveType<Native = Self::Native>;
+    /// Native (physical) type of this logical type.
+    type Native: NativeType;
+    /// Wrapper type that the vector returns.
+    type Wrapper: WrapperType<LogicalType = Self, Native = Self::Native>
+        + for<'a> Scalar<VectorType = PrimitiveVector<Self>, RefType<'a> = Self::Wrapper>
+        + for<'a> ScalarRef<'a, ScalarType = Self::Wrapper>;
+
+    /// Construct the data type struct.
+    fn build_data_type() -> ConcreteDataType;
+
+    /// Return the name of the type.
+    fn type_name() -> &'static str;
+
+    /// Dynamic cast the vector to the concrete vector type.
+    fn cast_vector(vector: &dyn Vector) -> Result<&PrimitiveVector<Self>>;
+
+    /// Cast value ref to the primitive type.
+    fn cast_value_ref(value: ValueRef) -> Result<Option<Self::Wrapper>>;
+}
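The trait is what lets vector code stay generic over all primitive logical types. A hypothetical helper (the `len` and `get_data` calls are assumptions about the `PrimitiveVector` API defined elsewhere in this change):

    // Hypothetical: downcast a dynamic vector and read its first element
    // through the logical type's wrapper.
    fn first_value<T: LogicalPrimitiveType>(vector: &dyn Vector) -> Result<Option<T::Wrapper>> {
        let primitive = T::cast_vector(vector)?;
        if primitive.len() == 0 {
            Ok(None)
        } else {
            Ok(primitive.get_data(0)) // assumed ScalarVector-style accessor
        }
    }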
+/// A new type for [WrapperType], complementing the `Ord` feature for it. Wrapping non-ordered
+/// primitive types like `f32` and `f64` in `OrdPrimitive` makes them usable in places that
+/// require `Ord`. For example, in `Median` or `Percentile` UDAFs.
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct OrdPrimitive<T: WrapperType>(pub T);
+
+impl<T: WrapperType> OrdPrimitive<T> {
+    pub fn as_primitive(&self) -> T {
+        self.0
+    }
+}
+
+impl<T: WrapperType> Eq for OrdPrimitive<T> {}
+
+impl<T: WrapperType> PartialOrd for OrdPrimitive<T> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<T: WrapperType> Ord for OrdPrimitive<T> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        Into::<Value>::into(self.0).cmp(&Into::<Value>::into(other.0))
+    }
+}
+
+impl<T: WrapperType> From<OrdPrimitive<T>> for Value {
+    fn from(p: OrdPrimitive<T>) -> Self {
+        p.0.into()
+    }
+}
+
+macro_rules! impl_wrapper {
+    ($Type: ident, $LogicalType: ident) => {
+        impl WrapperType for $Type {
+            type LogicalType = $LogicalType;
+            type Native = $Type;
+
+            fn from_native(value: Self::Native) -> Self {
+                value
+            }
+
+            fn into_native(self) -> Self::Native {
+                self
+            }
+        }
+    };
+}
+
+impl_wrapper!(u8, UInt8Type);
+impl_wrapper!(u16, UInt16Type);
+impl_wrapper!(u32, UInt32Type);
+impl_wrapper!(u64, UInt64Type);
+impl_wrapper!(i8, Int8Type);
+impl_wrapper!(i16, Int16Type);
+impl_wrapper!(i32, Int32Type);
+impl_wrapper!(i64, Int64Type);
+impl_wrapper!(f32, Float32Type);
+impl_wrapper!(f64, Float64Type);
+
+impl WrapperType for Date {
+    type LogicalType = DateType;
+    type Native = i32;
+
+    fn from_native(value: i32) -> Self {
+        Date::new(value)
+    }
+
+    fn into_native(self) -> i32 {
+        self.val()
+    }
+}
+
+impl WrapperType for DateTime {
+    type LogicalType = DateTimeType;
+    type Native = i64;
+
+    fn from_native(value: Self::Native) -> Self {
+        DateTime::new(value)
+    }
+
+    fn into_native(self) -> Self::Native {
+        self.val()
+    }
+}
+
+macro_rules! define_logical_primitive_type {
+    ($Native: ident, $TypeId: ident, $DataType: ident) => {
+        // We need to define it as an empty struct `struct DataType {}` instead of a unit
+        // struct `struct DataType;` to ensure the serialized JSON string is compatible with
+        // the previous implementation.
+        #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
+        pub struct $DataType {}
+
+        impl LogicalPrimitiveType for $DataType {
+            type ArrowPrimitive = arrow::datatypes::$DataType;
+            type Native = $Native;
+            type Wrapper = $Native;
+
+            fn build_data_type() -> ConcreteDataType {
+                ConcreteDataType::$TypeId($DataType::default())
+            }
+
+            fn type_name() -> &'static str {
+                stringify!($TypeId)
+            }
+
+            fn cast_vector(vector: &dyn Vector) -> Result<&PrimitiveVector<$DataType>> {
+                vector
+                    .as_any()
+                    .downcast_ref::<PrimitiveVector<$DataType>>()
+                    .with_context(|| error::CastTypeSnafu {
+                        msg: format!(
+                            "Failed to cast {} to vector of primitive type {}",
+                            vector.vector_type_name(),
+                            stringify!($TypeId)
+                        ),
+                    })
+            }
+
+            fn cast_value_ref(value: ValueRef) -> Result<Option<$Native>> {
+                match value {
+                    ValueRef::Null => Ok(None),
+                    ValueRef::$TypeId(v) => Ok(Some(v.into())),
+                    other => error::CastTypeSnafu {
+                        msg: format!(
+                            "Failed to cast value {:?} to primitive type {}",
+                            other,
+                            stringify!($TypeId),
+                        ),
+                    }
+                    .fail(),
+                }
+            }
+        }
+    };
+}
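A sketch of why `OrdPrimitive` (defined above) exists: floats gain a total order by comparing through `Value`, so they can sit in ordered collections.

    use std::collections::BinaryHeap;

    let mut heap = BinaryHeap::new();
    heap.push(OrdPrimitive::<f64>(3.0));
    heap.push(OrdPrimitive::<f64>(1.0));
    assert_eq!(3.0, heap.pop().unwrap().as_primitive());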
+macro_rules! define_non_timestamp_primitive {
+    ($Native: ident, $TypeId: ident, $DataType: ident) => {
+        define_logical_primitive_type!($Native, $TypeId, $DataType);
+
+        impl DataType for $DataType {
+            fn name(&self) -> &str {
+                stringify!($TypeId)
+            }
+
+            fn logical_type_id(&self) -> LogicalTypeId {
+                LogicalTypeId::$TypeId
+            }
+
+            fn default_value(&self) -> Value {
+                $Native::default().into()
+            }
+
+            fn as_arrow_type(&self) -> ArrowDataType {
+                ArrowDataType::$TypeId
+            }
+
+            fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+                Box::new(PrimitiveVectorBuilder::<$DataType>::with_capacity(capacity))
+            }
+
+            fn is_timestamp_compatible(&self) -> bool {
+                false
+            }
+        }
+    };
+}
+
+define_non_timestamp_primitive!(u8, UInt8, UInt8Type);
+define_non_timestamp_primitive!(u16, UInt16, UInt16Type);
+define_non_timestamp_primitive!(u32, UInt32, UInt32Type);
+define_non_timestamp_primitive!(u64, UInt64, UInt64Type);
+define_non_timestamp_primitive!(i8, Int8, Int8Type);
+define_non_timestamp_primitive!(i16, Int16, Int16Type);
+define_non_timestamp_primitive!(i32, Int32, Int32Type);
+define_non_timestamp_primitive!(f32, Float32, Float32Type);
+define_non_timestamp_primitive!(f64, Float64, Float64Type);
+
+// Timestamp primitive:
+define_logical_primitive_type!(i64, Int64, Int64Type);
+
+impl DataType for Int64Type {
+    fn name(&self) -> &str {
+        "Int64"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::Int64
+    }
+
+    fn default_value(&self) -> Value {
+        Value::Int64(0)
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Int64
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(PrimitiveVectorBuilder::<Int64Type>::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        true
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::BinaryHeap;
+
+    use super::*;
+
+    #[test]
+    fn test_ord_primitive() {
+        struct Foo<T>
+        where
+            T: WrapperType,
+        {
+            heap: BinaryHeap<OrdPrimitive<T>>,
+        }
+
+        impl<T> Foo<T>
+        where
+            T: WrapperType,
+        {
+            fn push(&mut self, value: T) {
+                let value = OrdPrimitive::<T>(value);
+                self.heap.push(value);
+            }
+        }
+
+        macro_rules! test {
+            ($Type:ident) => {
+                let mut foo = Foo::<$Type> {
+                    heap: BinaryHeap::new(),
+                };
+                foo.push($Type::default());
+            };
+        }
+
+        test!(u8);
+        test!(u16);
+        test!(u32);
+        test!(u64);
+        test!(i8);
+        test!(i16);
+        test!(i32);
+        test!(i64);
+        test!(f32);
+        test!(f64);
+    }
+}
diff --git a/src/datatypes2/src/types/string_type.rs b/src/datatypes2/src/types/string_type.rs
new file mode 100644
index 0000000000..799cbbbdd3
--- /dev/null
+++ b/src/datatypes2/src/types/string_type.rs
@@ -0,0 +1,60 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::datatypes::DataType as ArrowDataType;
+use common_base::bytes::StringBytes;
+use serde::{Deserialize, Serialize};
+
+use crate::data_type::{DataType, DataTypeRef};
+use crate::prelude::ScalarVectorBuilder;
+use crate::type_id::LogicalTypeId;
+use crate::value::Value;
+use crate::vectors::{MutableVector, StringVectorBuilder};
+
+#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct StringType;
+
+impl StringType {
+    pub fn arc() -> DataTypeRef {
+        Arc::new(Self)
+    }
+}
+
+impl DataType for StringType {
+    fn name(&self) -> &str {
+        "String"
+    }
+
+    fn logical_type_id(&self) -> LogicalTypeId {
+        LogicalTypeId::String
+    }
+
+    fn default_value(&self) -> Value {
+        StringBytes::default().into()
+    }
+
+    fn as_arrow_type(&self) -> ArrowDataType {
+        ArrowDataType::Utf8
+    }
+
+    fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+        Box::new(StringVectorBuilder::with_capacity(capacity))
+    }
+
+    fn is_timestamp_compatible(&self) -> bool {
+        false
+    }
+}
diff --git a/src/datatypes2/src/types/timestamp_type.rs b/src/datatypes2/src/types/timestamp_type.rs
new file mode 100644
index 0000000000..fe86eeb8fd
--- /dev/null
+++ b/src/datatypes2/src/types/timestamp_type.rs
@@ -0,0 +1,140 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow::datatypes::{
+    DataType as ArrowDataType, TimeUnit as ArrowTimeUnit,
+    TimestampMicrosecondType as ArrowTimestampMicrosecondType,
+    TimestampMillisecondType as ArrowTimestampMillisecondType,
+    TimestampNanosecondType as ArrowTimestampNanosecondType,
+    TimestampSecondType as ArrowTimestampSecondType,
+};
+use common_time::timestamp::TimeUnit;
+use common_time::Timestamp;
+use enum_dispatch::enum_dispatch;
+use paste::paste;
+use serde::{Deserialize, Serialize};
+use snafu::OptionExt;
+
+use crate::data_type::ConcreteDataType;
+use crate::error;
+use crate::prelude::{
+    DataType, LogicalTypeId, MutableVector, ScalarVectorBuilder, Value, ValueRef, Vector,
+};
+use crate::timestamp::{
+    TimestampMicrosecond, TimestampMillisecond, TimestampNanosecond, TimestampSecond,
+};
+use crate::types::LogicalPrimitiveType;
+use crate::vectors::{
+    PrimitiveVector, TimestampMicrosecondVector, TimestampMicrosecondVectorBuilder,
+    TimestampMillisecondVector, TimestampMillisecondVectorBuilder, TimestampNanosecondVector,
+    TimestampNanosecondVectorBuilder, TimestampSecondVector, TimestampSecondVectorBuilder,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[enum_dispatch(DataType)]
+pub enum TimestampType {
+    Second(TimestampSecondType),
+    Millisecond(TimestampMillisecondType),
+    Microsecond(TimestampMicrosecondType),
+    Nanosecond(TimestampNanosecondType),
+}
+
+macro_rules! impl_data_type_for_timestamp {
+    ($unit: ident) => {
+        paste! {
+            #[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
+            pub struct [<Timestamp $unit Type>];
+
+            impl DataType for [<Timestamp $unit Type>] {
+                fn name(&self) -> &str {
+                    stringify!([<Timestamp $unit>])
+                }
+
+                fn logical_type_id(&self) -> LogicalTypeId {
+                    LogicalTypeId::[<Timestamp $unit>]
+                }
+
+                fn default_value(&self) -> Value {
+                    Value::Timestamp(Timestamp::new(0, TimeUnit::$unit))
+                }
+
+                fn as_arrow_type(&self) -> ArrowDataType {
+                    ArrowDataType::Timestamp(ArrowTimeUnit::$unit, None)
+                }
+
+                fn create_mutable_vector(&self, capacity: usize) -> Box<dyn MutableVector> {
+                    Box::new([<Timestamp $unit VectorBuilder>]::with_capacity(capacity))
+                }
+
+                fn is_timestamp_compatible(&self) -> bool {
+                    true
+                }
+            }
+
+
+            impl LogicalPrimitiveType for [<Timestamp $unit Type>] {
+                type ArrowPrimitive = [<Arrow Timestamp $unit Type>];
+                type Native = i64;
+                type Wrapper = [<Timestamp $unit>];
+
+                fn build_data_type() -> ConcreteDataType {
+                    ConcreteDataType::Timestamp(TimestampType::$unit(
+                        [<Timestamp $unit Type>]::default(),
+                    ))
+                }
+
+                fn type_name() -> &'static str {
+                    stringify!([<Timestamp $unit Type>])
+                }
+
+                fn cast_vector(vector: &dyn Vector) -> crate::Result<&PrimitiveVector<Self>> {
+                    vector
+                        .as_any()
+                        .downcast_ref::<[<Timestamp $unit Vector>]>()
+                        .with_context(|| error::CastTypeSnafu {
+                            msg: format!(
+                                "Failed to cast {} to {}",
+                                vector.vector_type_name(), stringify!([<Timestamp $unit Vector>])
+                            ),
+                        })
+                }
+
+                fn cast_value_ref(value: ValueRef) -> crate::Result<Option<Self::Wrapper>> {
+                    match value {
+                        ValueRef::Null => Ok(None),
+                        ValueRef::Timestamp(t) => match t.unit() {
+                            TimeUnit::$unit => Ok(Some([<Timestamp $unit>](t))),
+                            other => error::CastTypeSnafu {
+                                msg: format!(
+                                    "Failed to cast Timestamp value with different unit {:?} to {}",
+                                    other, stringify!([<Timestamp $unit>])
+                                ),
+                            }
+                            .fail(),
+                        },
+                        other => error::CastTypeSnafu {
+                            msg: format!("Failed to cast value {:?} to {}", other, stringify!([<Timestamp $unit>])),
+                        }
+                        .fail(),
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl_data_type_for_timestamp!(Nanosecond);
+impl_data_type_for_timestamp!(Second);
+impl_data_type_for_timestamp!(Millisecond);
+impl_data_type_for_timestamp!(Microsecond);
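A sketch of the unit check in `cast_value_ref`: a timestamp ref only casts to the type with the matching unit.

    let ms = ValueRef::Timestamp(Timestamp::new(42, TimeUnit::Millisecond));
    assert!(TimestampMillisecondType::cast_value_ref(ms).is_ok());
    assert!(TimestampSecondType::cast_value_ref(ms).is_err()); // unit mismatch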
diff --git a/src/datatypes2/src/value.rs b/src/datatypes2/src/value.rs
new file mode 100644
index 0000000000..bade88d419
--- /dev/null
+++ b/src/datatypes2/src/value.rs
@@ -0,0 +1,1275 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::cmp::Ordering;
+use std::fmt::{Display, Formatter};
+
+use common_base::bytes::{Bytes, StringBytes};
+use common_time::date::Date;
+use common_time::datetime::DateTime;
+use common_time::timestamp::{TimeUnit, Timestamp};
+use datafusion_common::ScalarValue;
+pub use ordered_float::OrderedFloat;
+use serde::{Deserialize, Serialize};
+
+use crate::error::{self, Result};
+use crate::prelude::*;
+use crate::type_id::LogicalTypeId;
+use crate::vectors::ListVector;
+
+pub type OrderedF32 = OrderedFloat<f32>;
+pub type OrderedF64 = OrderedFloat<f64>;
+
+/// Value holds a single arbitrary value of any [DataType](crate::data_type::DataType).
+///
+/// Comparison between values with different types (except Null) is not allowed.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum Value {
+    Null,
+
+    // Numeric types:
+    Boolean(bool),
+    UInt8(u8),
+    UInt16(u16),
+    UInt32(u32),
+    UInt64(u64),
+    Int8(i8),
+    Int16(i16),
+    Int32(i32),
+    Int64(i64),
+    Float32(OrderedF32),
+    Float64(OrderedF64),
+
+    // String types:
+    String(StringBytes),
+    Binary(Bytes),
+
+    // Date & Time types:
+    Date(Date),
+    DateTime(DateTime),
+    Timestamp(Timestamp),
+
+    List(ListValue),
+}
+
+impl Display for Value {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Value::Null => write!(f, "{}", self.data_type().name()),
+            Value::Boolean(v) => write!(f, "{}", v),
+            Value::UInt8(v) => write!(f, "{}", v),
+            Value::UInt16(v) => write!(f, "{}", v),
+            Value::UInt32(v) => write!(f, "{}", v),
+            Value::UInt64(v) => write!(f, "{}", v),
+            Value::Int8(v) => write!(f, "{}", v),
+            Value::Int16(v) => write!(f, "{}", v),
+            Value::Int32(v) => write!(f, "{}", v),
+            Value::Int64(v) => write!(f, "{}", v),
+            Value::Float32(v) => write!(f, "{}", v),
+            Value::Float64(v) => write!(f, "{}", v),
+            Value::String(v) => write!(f, "{}", v.as_utf8()),
+            Value::Binary(v) => {
+                let hex = v
+                    .iter()
+                    .map(|b| format!("{:02x}", b))
+                    .collect::<Vec<String>>()
+                    .join("");
+                write!(f, "{}", hex)
+            }
+            Value::Date(v) => write!(f, "{}", v),
+            Value::DateTime(v) => write!(f, "{}", v),
+            Value::Timestamp(v) => write!(f, "{}", v.to_iso8601_string()),
+            Value::List(v) => {
+                let default = Box::new(vec![]);
+                let items = v.items().as_ref().unwrap_or(&default);
+                let items = items
+                    .iter()
+                    .map(|i| i.to_string())
+                    .collect::<Vec<String>>()
+                    .join(", ");
+                write!(f, "{}[{}]", v.datatype.name(), items)
+            }
+        }
+    }
+}
+
+impl Value {
+    /// Returns data type of the value.
+    ///
+    /// # Panics
+    /// Panics if the data type is not supported.
+    pub fn data_type(&self) -> ConcreteDataType {
+        // TODO(yingwen): Implement this once all data types are implemented.
+        match self {
+            Value::Null => ConcreteDataType::null_datatype(),
+            Value::Boolean(_) => ConcreteDataType::boolean_datatype(),
+            Value::UInt8(_) => ConcreteDataType::uint8_datatype(),
+            Value::UInt16(_) => ConcreteDataType::uint16_datatype(),
+            Value::UInt32(_) => ConcreteDataType::uint32_datatype(),
+            Value::UInt64(_) => ConcreteDataType::uint64_datatype(),
+            Value::Int8(_) => ConcreteDataType::int8_datatype(),
+            Value::Int16(_) => ConcreteDataType::int16_datatype(),
+            Value::Int32(_) => ConcreteDataType::int32_datatype(),
+            Value::Int64(_) => ConcreteDataType::int64_datatype(),
+            Value::Float32(_) => ConcreteDataType::float32_datatype(),
+            Value::Float64(_) => ConcreteDataType::float64_datatype(),
+            Value::String(_) => ConcreteDataType::string_datatype(),
+            Value::Binary(_) => ConcreteDataType::binary_datatype(),
+            Value::Date(_) => ConcreteDataType::date_datatype(),
+            Value::DateTime(_) => ConcreteDataType::datetime_datatype(),
+            Value::Timestamp(v) => ConcreteDataType::timestamp_datatype(v.unit()),
+            Value::List(list) => ConcreteDataType::list_datatype(list.datatype().clone()),
+        }
+    }
+
+    /// Returns true if this is a null value.
+    pub fn is_null(&self) -> bool {
+        matches!(self, Value::Null)
+    }
+
+    /// Cast itself to [ListValue].
+    pub fn as_list(&self) -> Result<Option<&ListValue>> {
+        match self {
+            Value::Null => Ok(None),
+            Value::List(v) => Ok(Some(v)),
+            other => error::CastTypeSnafu {
+                msg: format!("Failed to cast {:?} to list value", other),
+            }
+            .fail(),
+        }
+    }
+
+    /// Cast itself to [ValueRef].
+    pub fn as_value_ref(&self) -> ValueRef {
+        match self {
+            Value::Null => ValueRef::Null,
+            Value::Boolean(v) => ValueRef::Boolean(*v),
+            Value::UInt8(v) => ValueRef::UInt8(*v),
+            Value::UInt16(v) => ValueRef::UInt16(*v),
+            Value::UInt32(v) => ValueRef::UInt32(*v),
+            Value::UInt64(v) => ValueRef::UInt64(*v),
+            Value::Int8(v) => ValueRef::Int8(*v),
+            Value::Int16(v) => ValueRef::Int16(*v),
+            Value::Int32(v) => ValueRef::Int32(*v),
+            Value::Int64(v) => ValueRef::Int64(*v),
+            Value::Float32(v) => ValueRef::Float32(*v),
+            Value::Float64(v) => ValueRef::Float64(*v),
+            Value::String(v) => ValueRef::String(v.as_utf8()),
+            Value::Binary(v) => ValueRef::Binary(v),
+            Value::Date(v) => ValueRef::Date(*v),
+            Value::DateTime(v) => ValueRef::DateTime(*v),
+            Value::List(v) => ValueRef::List(ListValueRef::Ref { val: v }),
+            Value::Timestamp(v) => ValueRef::Timestamp(*v),
+        }
+    }
+
+    /// Returns the logical type of the value.
+    pub fn logical_type_id(&self) -> LogicalTypeId {
+        match self {
+            Value::Null => LogicalTypeId::Null,
+            Value::Boolean(_) => LogicalTypeId::Boolean,
+            Value::UInt8(_) => LogicalTypeId::UInt8,
+            Value::UInt16(_) => LogicalTypeId::UInt16,
+            Value::UInt32(_) => LogicalTypeId::UInt32,
+            Value::UInt64(_) => LogicalTypeId::UInt64,
+            Value::Int8(_) => LogicalTypeId::Int8,
+            Value::Int16(_) => LogicalTypeId::Int16,
+            Value::Int32(_) => LogicalTypeId::Int32,
+            Value::Int64(_) => LogicalTypeId::Int64,
+            Value::Float32(_) => LogicalTypeId::Float32,
+            Value::Float64(_) => LogicalTypeId::Float64,
+            Value::String(_) => LogicalTypeId::String,
+            Value::Binary(_) => LogicalTypeId::Binary,
+            Value::List(_) => LogicalTypeId::List,
+            Value::Date(_) => LogicalTypeId::Date,
+            Value::DateTime(_) => LogicalTypeId::DateTime,
+            Value::Timestamp(t) => match t.unit() {
+                TimeUnit::Second => LogicalTypeId::TimestampSecond,
+                TimeUnit::Millisecond => LogicalTypeId::TimestampMillisecond,
+                TimeUnit::Microsecond => LogicalTypeId::TimestampMicrosecond,
+                TimeUnit::Nanosecond => LogicalTypeId::TimestampNanosecond,
+            },
+        }
+    }
+}
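A sketch of the borrow relationship between the two enums: `as_value_ref` borrows from the owned `Value` without cloning the payload.

    let value = Value::String("hello".into());
    assert_eq!(ValueRef::String("hello"), value.as_value_ref());
    assert_eq!(LogicalTypeId::String, value.logical_type_id());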
+macro_rules! impl_ord_for_value_like {
+    ($Type: ident, $left: ident, $right: ident) => {
+        if $left.is_null() && !$right.is_null() {
+            return Ordering::Less;
+        } else if !$left.is_null() && $right.is_null() {
+            return Ordering::Greater;
+        } else {
+            match ($left, $right) {
+                ($Type::Null, $Type::Null) => Ordering::Equal,
+                ($Type::Boolean(v1), $Type::Boolean(v2)) => v1.cmp(v2),
+                ($Type::UInt8(v1), $Type::UInt8(v2)) => v1.cmp(v2),
+                ($Type::UInt16(v1), $Type::UInt16(v2)) => v1.cmp(v2),
+                ($Type::UInt32(v1), $Type::UInt32(v2)) => v1.cmp(v2),
+                ($Type::UInt64(v1), $Type::UInt64(v2)) => v1.cmp(v2),
+                ($Type::Int8(v1), $Type::Int8(v2)) => v1.cmp(v2),
+                ($Type::Int16(v1), $Type::Int16(v2)) => v1.cmp(v2),
+                ($Type::Int32(v1), $Type::Int32(v2)) => v1.cmp(v2),
+                ($Type::Int64(v1), $Type::Int64(v2)) => v1.cmp(v2),
+                ($Type::Float32(v1), $Type::Float32(v2)) => v1.cmp(v2),
+                ($Type::Float64(v1), $Type::Float64(v2)) => v1.cmp(v2),
+                ($Type::String(v1), $Type::String(v2)) => v1.cmp(v2),
+                ($Type::Binary(v1), $Type::Binary(v2)) => v1.cmp(v2),
+                ($Type::Date(v1), $Type::Date(v2)) => v1.cmp(v2),
+                ($Type::DateTime(v1), $Type::DateTime(v2)) => v1.cmp(v2),
+                ($Type::Timestamp(v1), $Type::Timestamp(v2)) => v1.cmp(v2),
+                ($Type::List(v1), $Type::List(v2)) => v1.cmp(v2),
+                _ => panic!(
+                    "Cannot compare different values {:?} and {:?}",
+                    $left, $right
+                ),
+            }
+        }
+    };
+}
+
+impl PartialOrd for Value {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for Value {
+    fn cmp(&self, other: &Self) -> Ordering {
+        impl_ord_for_value_like!(Value, self, other)
+    }
+}
+
+macro_rules! impl_value_from {
+    ($Variant: ident, $Type: ident) => {
+        impl From<$Type> for Value {
+            fn from(value: $Type) -> Self {
+                Value::$Variant(value.into())
+            }
+        }
+
+        impl From<Option<$Type>> for Value {
+            fn from(value: Option<$Type>) -> Self {
+                match value {
+                    Some(v) => Value::$Variant(v.into()),
+                    None => Value::Null,
+                }
+            }
+        }
+    };
+}
+
+impl_value_from!(Boolean, bool);
+impl_value_from!(UInt8, u8);
+impl_value_from!(UInt16, u16);
+impl_value_from!(UInt32, u32);
+impl_value_from!(UInt64, u64);
+impl_value_from!(Int8, i8);
+impl_value_from!(Int16, i16);
+impl_value_from!(Int32, i32);
+impl_value_from!(Int64, i64);
+impl_value_from!(Float32, f32);
+impl_value_from!(Float64, f64);
+impl_value_from!(String, StringBytes);
+impl_value_from!(Binary, Bytes);
+impl_value_from!(Date, Date);
+impl_value_from!(DateTime, DateTime);
+impl_value_from!(Timestamp, Timestamp);
+
+impl From<String> for Value {
+    fn from(string: String) -> Value {
+        Value::String(string.into())
+    }
+}
+
+impl From<&str> for Value {
+    fn from(string: &str) -> Value {
+        Value::String(string.into())
+    }
+}
+
+impl From<Vec<u8>> for Value {
+    fn from(bytes: Vec<u8>) -> Value {
+        Value::Binary(bytes.into())
+    }
+}
+
+impl From<&[u8]> for Value {
+    fn from(bytes: &[u8]) -> Value {
+        Value::Binary(bytes.into())
+    }
+}
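The `Option` conversions map `None` to `Value::Null`, which is the only value that compares across types; a sketch:

    assert_eq!(Value::Null, Value::from(None::<i32>));
    assert_eq!(Value::Int32(1), Value::from(Some(1)));
    // Mixing non-null types in a comparison, e.g. Value::Int32(1) < Value::Int64(1),
    // panics by design (see impl_ord_for_value_like above).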
+
+impl TryFrom<Value> for serde_json::Value {
+    type Error = serde_json::Error;
+
+    fn try_from(value: Value) -> serde_json::Result<serde_json::Value> {
+        let json_value = match value {
+            Value::Null => serde_json::Value::Null,
+            Value::Boolean(v) => serde_json::Value::Bool(v),
+            Value::UInt8(v) => serde_json::Value::from(v),
+            Value::UInt16(v) => serde_json::Value::from(v),
+            Value::UInt32(v) => serde_json::Value::from(v),
+            Value::UInt64(v) => serde_json::Value::from(v),
+            Value::Int8(v) => serde_json::Value::from(v),
+            Value::Int16(v) => serde_json::Value::from(v),
+            Value::Int32(v) => serde_json::Value::from(v),
+            Value::Int64(v) => serde_json::Value::from(v),
+            Value::Float32(v) => serde_json::Value::from(v.0),
+            Value::Float64(v) => serde_json::Value::from(v.0),
+            Value::String(bytes) => serde_json::Value::String(bytes.as_utf8().to_string()),
+            Value::Binary(bytes) => serde_json::to_value(bytes)?,
+            Value::Date(v) => serde_json::Value::Number(v.val().into()),
+            Value::DateTime(v) => serde_json::Value::Number(v.val().into()),
+            Value::List(v) => serde_json::to_value(v)?,
+            Value::Timestamp(v) => serde_json::to_value(v.value())?,
+        };
+
+        Ok(json_value)
+    }
+}
+
+// TODO(yingwen): Consider removing the `datatype` field from `ListValue`.
+/// List value.
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
+pub struct ListValue {
+    /// List of nested Values (boxed to reduce size_of(Value)).
+    #[allow(clippy::box_collection)]
+    items: Option<Box<Vec<Value>>>,
+    /// Inner values' datatype, to distinguish empty lists of different datatypes.
+    /// Restricted by DataFusion, cannot use null datatype for empty list.
+    datatype: ConcreteDataType,
+}
+
+impl Eq for ListValue {}
+
+impl ListValue {
+    pub fn new(items: Option<Box<Vec<Value>>>, datatype: ConcreteDataType) -> Self {
+        Self { items, datatype }
+    }
+
+    pub fn items(&self) -> &Option<Box<Vec<Value>>> {
+        &self.items
+    }
+
+    pub fn datatype(&self) -> &ConcreteDataType {
+        &self.datatype
+    }
+}
+
+impl Default for ListValue {
+    fn default() -> ListValue {
+        ListValue::new(None, ConcreteDataType::null_datatype())
+    }
+}
+
+impl PartialOrd for ListValue {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for ListValue {
+    fn cmp(&self, other: &Self) -> Ordering {
+        assert_eq!(
+            self.datatype, other.datatype,
+            "Cannot compare different datatypes!"
+        );
+        self.items.cmp(&other.items)
+    }
+}
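A sketch of the JSON conversion defined above; failures can only come from `serde_json` itself, so plain scalars convert directly:

    let json = serde_json::Value::try_from(Value::Float64(3.5.into())).unwrap();
    assert_eq!(serde_json::json!(3.5), json);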
+
+impl TryFrom<ScalarValue> for Value {
+    type Error = error::Error;
+
+    fn try_from(v: ScalarValue) -> Result<Self> {
+        let v = match v {
+            ScalarValue::Null => Value::Null,
+            ScalarValue::Boolean(b) => Value::from(b),
+            ScalarValue::Float32(f) => Value::from(f),
+            ScalarValue::Float64(f) => Value::from(f),
+            ScalarValue::Int8(i) => Value::from(i),
+            ScalarValue::Int16(i) => Value::from(i),
+            ScalarValue::Int32(i) => Value::from(i),
+            ScalarValue::Int64(i) => Value::from(i),
+            ScalarValue::UInt8(u) => Value::from(u),
+            ScalarValue::UInt16(u) => Value::from(u),
+            ScalarValue::UInt32(u) => Value::from(u),
+            ScalarValue::UInt64(u) => Value::from(u),
+            ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => {
+                Value::from(s.map(StringBytes::from))
+            }
+            ScalarValue::Binary(b)
+            | ScalarValue::LargeBinary(b)
+            | ScalarValue::FixedSizeBinary(_, b) => Value::from(b.map(Bytes::from)),
+            ScalarValue::List(vs, field) => {
+                let items = if let Some(vs) = vs {
+                    let vs = vs
+                        .into_iter()
+                        .map(ScalarValue::try_into)
+                        .collect::<Result<Vec<Value>>>()?;
+                    Some(Box::new(vs))
+                } else {
+                    None
+                };
+                let datatype = ConcreteDataType::try_from(field.data_type())?;
+                Value::List(ListValue::new(items, datatype))
+            }
+            ScalarValue::Date32(d) => d.map(|x| Value::Date(Date::new(x))).unwrap_or(Value::Null),
+            ScalarValue::Date64(d) => d
+                .map(|x| Value::DateTime(DateTime::new(x)))
+                .unwrap_or(Value::Null),
+            ScalarValue::TimestampSecond(t, _) => t
+                .map(|x| Value::Timestamp(Timestamp::new(x, TimeUnit::Second)))
+                .unwrap_or(Value::Null),
+            ScalarValue::TimestampMillisecond(t, _) => t
+                .map(|x| Value::Timestamp(Timestamp::new(x, TimeUnit::Millisecond)))
+                .unwrap_or(Value::Null),
+            ScalarValue::TimestampMicrosecond(t, _) => t
+                .map(|x| Value::Timestamp(Timestamp::new(x, TimeUnit::Microsecond)))
+                .unwrap_or(Value::Null),
+            ScalarValue::TimestampNanosecond(t, _) => t
+                .map(|x| Value::Timestamp(Timestamp::new(x, TimeUnit::Nanosecond)))
+                .unwrap_or(Value::Null),
+            ScalarValue::Decimal128(_, _, _)
+            | ScalarValue::Time64(_)
+            | ScalarValue::IntervalYearMonth(_)
+            | ScalarValue::IntervalDayTime(_)
+            | ScalarValue::IntervalMonthDayNano(_)
+            | ScalarValue::Struct(_, _)
+            | ScalarValue::Dictionary(_, _) => {
+                return error::UnsupportedArrowTypeSnafu {
+                    arrow_type: v.get_datatype(),
+                }
+                .fail()
+            }
+        };
+        Ok(v)
+    }
+}
+
+/// Reference to [Value].
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ValueRef<'a> {
+    Null,
+
+    // Numeric types:
+    Boolean(bool),
+    UInt8(u8),
+    UInt16(u16),
+    UInt32(u32),
+    UInt64(u64),
+    Int8(i8),
+    Int16(i16),
+    Int32(i32),
+    Int64(i64),
+    Float32(OrderedF32),
+    Float64(OrderedF64),
+
+    // String types:
+    String(&'a str),
+    Binary(&'a [u8]),
+
+    // Date & Time types:
+    Date(Date),
+    DateTime(DateTime),
+    Timestamp(Timestamp),
+    List(ListValueRef<'a>),
+}
+
+macro_rules! impl_as_for_value_ref {
+    ($value: ident, $Variant: ident) => {
+        match $value {
+            ValueRef::Null => Ok(None),
+            ValueRef::$Variant(v) => Ok(Some(*v)),
+            other => error::CastTypeSnafu {
+                msg: format!(
+                    "Failed to cast value ref {:?} to {}",
+                    other,
+                    stringify!($Variant)
+                ),
+            }
+            .fail(),
+        }
+    };
+}
+
+impl<'a> ValueRef<'a> {
+    /// Returns true if this is null.
+    pub fn is_null(&self) -> bool {
+        matches!(self, ValueRef::Null)
+    }
+
+    /// Cast itself to binary slice.
+    pub fn as_binary(&self) -> Result<Option<&'a [u8]>> {
+        impl_as_for_value_ref!(self, Binary)
+    }
+
+    /// Cast itself to string slice.
+    pub fn as_string(&self) -> Result<Option<&'a str>> {
+        impl_as_for_value_ref!(self, String)
+    }
+
+    /// Cast itself to boolean.
+    pub fn as_boolean(&self) -> Result<Option<bool>> {
+        impl_as_for_value_ref!(self, Boolean)
+    }
+
+    /// Cast itself to [Date].
+    pub fn as_date(&self) -> Result<Option<Date>> {
+        impl_as_for_value_ref!(self, Date)
+    }
+
+    /// Cast itself to [DateTime].
+    pub fn as_datetime(&self) -> Result<Option<DateTime>> {
+        impl_as_for_value_ref!(self, DateTime)
+    }
+
+    pub fn as_timestamp(&self) -> Result<Option<Timestamp>> {
+        impl_as_for_value_ref!(self, Timestamp)
+    }
+
+    /// Cast itself to [ListValueRef].
+    pub fn as_list(&self) -> Result<Option<ListValueRef>> {
+        impl_as_for_value_ref!(self, List)
+    }
+}
+
+impl<'a> PartialOrd for ValueRef<'a> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<'a> Ord for ValueRef<'a> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        impl_ord_for_value_like!(ValueRef, self, other)
+    }
+}
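A sketch of the accessor contract generated by `impl_as_for_value_ref!`: `Ok(None)` for `Null`, `Ok(Some(_))` on a type match, and a `CastType` error otherwise.

    assert_eq!(Some("greptime"), ValueRef::String("greptime").as_string().unwrap());
    assert_eq!(None, ValueRef::Null.as_string().unwrap());
    assert!(ValueRef::Boolean(true).as_string().is_err()); // type mismatch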
+
+macro_rules! impl_value_ref_from {
+    ($Variant:ident, $Type:ident) => {
+        impl From<$Type> for ValueRef<'_> {
+            fn from(value: $Type) -> Self {
+                ValueRef::$Variant(value.into())
+            }
+        }
+
+        impl From<Option<$Type>> for ValueRef<'_> {
+            fn from(value: Option<$Type>) -> Self {
+                match value {
+                    Some(v) => ValueRef::$Variant(v.into()),
+                    None => ValueRef::Null,
+                }
+            }
+        }
+    };
+}
+
+impl_value_ref_from!(Boolean, bool);
+impl_value_ref_from!(UInt8, u8);
+impl_value_ref_from!(UInt16, u16);
+impl_value_ref_from!(UInt32, u32);
+impl_value_ref_from!(UInt64, u64);
+impl_value_ref_from!(Int8, i8);
+impl_value_ref_from!(Int16, i16);
+impl_value_ref_from!(Int32, i32);
+impl_value_ref_from!(Int64, i64);
+impl_value_ref_from!(Float32, f32);
+impl_value_ref_from!(Float64, f64);
+impl_value_ref_from!(Date, Date);
+impl_value_ref_from!(DateTime, DateTime);
+impl_value_ref_from!(Timestamp, Timestamp);
+
+impl<'a> From<&'a str> for ValueRef<'a> {
+    fn from(string: &'a str) -> ValueRef<'a> {
+        ValueRef::String(string)
+    }
+}
+
+impl<'a> From<&'a [u8]> for ValueRef<'a> {
+    fn from(bytes: &'a [u8]) -> ValueRef<'a> {
+        ValueRef::Binary(bytes)
+    }
+}
+
+impl<'a> From<Option<ListValueRef<'a>>> for ValueRef<'a> {
+    fn from(list: Option<ListValueRef<'a>>) -> ValueRef<'a> {
+        match list {
+            Some(v) => ValueRef::List(v),
+            None => ValueRef::Null,
+        }
+    }
+}
+
+/// Reference to a [ListValue].
+///
+/// Now comparison still requires some allocation (call of `to_value()`) and
+/// might be avoidable by downcasting and comparing the underlying array slice
+/// if it becomes a bottleneck.
+#[derive(Debug, Clone, Copy)]
+pub enum ListValueRef<'a> {
+    // TODO(yingwen): Consider replacing this with VectorRef.
+    Indexed { vector: &'a ListVector, idx: usize },
+    Ref { val: &'a ListValue },
+}
+
+impl<'a> ListValueRef<'a> {
+    /// Convert self to [Value]. This method would clone the underlying data.
+    fn to_value(self) -> Value {
+        match self {
+            ListValueRef::Indexed { vector, idx } => vector.get(idx),
+            ListValueRef::Ref { val } => Value::List(val.clone()),
+        }
+    }
+}
+
+impl<'a> PartialEq for ListValueRef<'a> {
+    fn eq(&self, other: &Self) -> bool {
+        self.to_value().eq(&other.to_value())
+    }
+}
+
+impl<'a> Eq for ListValueRef<'a> {}
+
+impl<'a> Ord for ListValueRef<'a> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        // Respect the order of `Value` by converting into value before comparison.
+        self.to_value().cmp(&other.to_value())
+    }
+}
+
+impl<'a> PartialOrd for ListValueRef<'a> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType as ArrowDataType;
+    use num_traits::Float;
+
+    use super::*;
+
+    #[test]
+    fn test_try_from_scalar_value() {
+        assert_eq!(
+            Value::Boolean(true),
+            ScalarValue::Boolean(Some(true)).try_into().unwrap()
+        );
+        assert_eq!(
+            Value::Boolean(false),
+            ScalarValue::Boolean(Some(false)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Boolean(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Float32(1.0f32.into()),
+            ScalarValue::Float32(Some(1.0f32)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Float32(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Float64(2.0f64.into()),
+            ScalarValue::Float64(Some(2.0f64)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Float64(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Int8(i8::MAX),
+            ScalarValue::Int8(Some(i8::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Int8(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Int16(i16::MAX),
+            ScalarValue::Int16(Some(i16::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Int16(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Int32(i32::MAX),
+            ScalarValue::Int32(Some(i32::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Int32(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Int64(i64::MAX),
+            ScalarValue::Int64(Some(i64::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Int64(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::UInt8(u8::MAX),
+            ScalarValue::UInt8(Some(u8::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::UInt8(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::UInt16(u16::MAX),
+            ScalarValue::UInt16(Some(u16::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::UInt16(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::UInt32(u32::MAX),
+            ScalarValue::UInt32(Some(u32::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::UInt32(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::UInt64(u64::MAX),
+            ScalarValue::UInt64(Some(u64::MAX)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::UInt64(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::from("hello"),
+            ScalarValue::Utf8(Some("hello".to_string()))
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Utf8(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::from("large_hello"),
+            ScalarValue::LargeUtf8(Some("large_hello".to_string()))
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::LargeUtf8(None).try_into().unwrap()
+        );
+
+        assert_eq!(
+            Value::from("world".as_bytes()),
+            ScalarValue::Binary(Some("world".as_bytes().to_vec()))
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Binary(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::from("large_world".as_bytes()),
+            ScalarValue::LargeBinary(Some("large_world".as_bytes().to_vec()))
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::LargeBinary(None).try_into().unwrap()
+        );
+
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![Value::Int32(1), Value::Null])),
+                ConcreteDataType::int32_datatype()
+            )),
+            ScalarValue::new_list(
+                Some(vec![ScalarValue::Int32(Some(1)), ScalarValue::Int32(None)]),
+                ArrowDataType::Int32,
+            )
+            .try_into()
+            .unwrap()
+        );
+        assert_eq!(
+            Value::List(ListValue::new(None, ConcreteDataType::uint32_datatype())),
+            ScalarValue::new_list(None, ArrowDataType::UInt32)
+                .try_into()
+                .unwrap()
+        );
+
+        assert_eq!(
+            Value::Date(Date::new(123)),
+            ScalarValue::Date32(Some(123)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Date32(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::DateTime(DateTime::new(456)),
+            ScalarValue::Date64(Some(456)).try_into().unwrap()
+        );
+        assert_eq!(Value::Null, ScalarValue::Date64(None).try_into().unwrap());
+
+        assert_eq!(
+            Value::Timestamp(Timestamp::new(1, TimeUnit::Second)),
+            ScalarValue::TimestampSecond(Some(1), None)
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::TimestampSecond(None, None).try_into().unwrap()
+        );
+
+        assert_eq!(
+            Value::Timestamp(Timestamp::new(1, TimeUnit::Millisecond)),
+            ScalarValue::TimestampMillisecond(Some(1), None)
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::TimestampMillisecond(None, None)
+                .try_into()
+                .unwrap()
+        );
+
+        assert_eq!(
+            Value::Timestamp(Timestamp::new(1, TimeUnit::Microsecond)),
+            ScalarValue::TimestampMicrosecond(Some(1), None)
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::TimestampMicrosecond(None, None)
+                .try_into()
+                .unwrap()
+        );
+
+        assert_eq!(
+            Value::Timestamp(Timestamp::new(1, TimeUnit::Nanosecond)),
+            ScalarValue::TimestampNanosecond(Some(1), None)
+                .try_into()
+                .unwrap()
+        );
+        assert_eq!(
+            Value::Null,
+            ScalarValue::TimestampNanosecond(None, None)
+                .try_into()
+                .unwrap()
+        );
+
+        let result: Result<Value> = ScalarValue::Decimal128(Some(1), 0, 0).try_into();
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Unsupported arrow data type, type: Decimal(0, 0)"));
+    }
+
+    #[test]
+    fn test_value_from_inner() {
+        assert_eq!(Value::Boolean(true), Value::from(true));
+        assert_eq!(Value::Boolean(false), Value::from(false));
+
+        assert_eq!(Value::UInt8(u8::MIN), Value::from(u8::MIN));
+        assert_eq!(Value::UInt8(u8::MAX), Value::from(u8::MAX));
+
+        assert_eq!(Value::UInt16(u16::MIN), Value::from(u16::MIN));
+        assert_eq!(Value::UInt16(u16::MAX), Value::from(u16::MAX));
+
+        assert_eq!(Value::UInt32(u32::MIN), Value::from(u32::MIN));
+        assert_eq!(Value::UInt32(u32::MAX), Value::from(u32::MAX));
+
+        assert_eq!(Value::UInt64(u64::MIN), Value::from(u64::MIN));
+        assert_eq!(Value::UInt64(u64::MAX), Value::from(u64::MAX));
+
+        assert_eq!(Value::Int8(i8::MIN), Value::from(i8::MIN));
+        assert_eq!(Value::Int8(i8::MAX), Value::from(i8::MAX));
+
+        assert_eq!(Value::Int16(i16::MIN), Value::from(i16::MIN));
+        assert_eq!(Value::Int16(i16::MAX), Value::from(i16::MAX));
+
+        assert_eq!(Value::Int32(i32::MIN), Value::from(i32::MIN));
+        assert_eq!(Value::Int32(i32::MAX), Value::from(i32::MAX));
+
+        assert_eq!(Value::Int64(i64::MIN), Value::from(i64::MIN));
+        assert_eq!(Value::Int64(i64::MAX), Value::from(i64::MAX));
+
+        assert_eq!(
+            Value::Float32(OrderedFloat(f32::MIN)),
+            Value::from(f32::MIN)
+        );
+        assert_eq!(
+            Value::Float32(OrderedFloat(f32::MAX)),
+            Value::from(f32::MAX)
+        );
+
+        assert_eq!(
+            Value::Float64(OrderedFloat(f64::MIN)),
+            Value::from(f64::MIN)
+        );
+        assert_eq!(
+            Value::Float64(OrderedFloat(f64::MAX)),
+            Value::from(f64::MAX)
+        );
+
+        let string_bytes = StringBytes::from("hello");
+        assert_eq!(
+            Value::String(string_bytes.clone()),
+            Value::from(string_bytes)
+        );
+
+        let bytes = Bytes::from(b"world".as_slice());
+        assert_eq!(Value::Binary(bytes.clone()), Value::from(bytes));
+    }
+
+    fn
+    fn check_type_and_value(data_type: &ConcreteDataType, value: &Value) {
+        assert_eq!(*data_type, value.data_type());
+        assert_eq!(data_type.logical_type_id(), value.logical_type_id());
+    }
+
+    #[test]
+    fn test_value_datatype() {
+        check_type_and_value(&ConcreteDataType::boolean_datatype(), &Value::Boolean(true));
+        check_type_and_value(&ConcreteDataType::uint8_datatype(), &Value::UInt8(u8::MIN));
+        check_type_and_value(
+            &ConcreteDataType::uint16_datatype(),
+            &Value::UInt16(u16::MIN),
+        );
+        check_type_and_value(
+            &ConcreteDataType::uint16_datatype(),
+            &Value::UInt16(u16::MAX),
+        );
+        check_type_and_value(
+            &ConcreteDataType::uint32_datatype(),
+            &Value::UInt32(u32::MIN),
+        );
+        check_type_and_value(
+            &ConcreteDataType::uint64_datatype(),
+            &Value::UInt64(u64::MIN),
+        );
+        check_type_and_value(&ConcreteDataType::int8_datatype(), &Value::Int8(i8::MIN));
+        check_type_and_value(&ConcreteDataType::int16_datatype(), &Value::Int16(i16::MIN));
+        check_type_and_value(&ConcreteDataType::int32_datatype(), &Value::Int32(i32::MIN));
+        check_type_and_value(&ConcreteDataType::int64_datatype(), &Value::Int64(i64::MIN));
+        check_type_and_value(
+            &ConcreteDataType::float32_datatype(),
+            &Value::Float32(OrderedFloat(f32::MIN)),
+        );
+        check_type_and_value(
+            &ConcreteDataType::float64_datatype(),
+            &Value::Float64(OrderedFloat(f64::MIN)),
+        );
+        check_type_and_value(
+            &ConcreteDataType::string_datatype(),
+            &Value::String(StringBytes::from("hello")),
+        );
+        check_type_and_value(
+            &ConcreteDataType::binary_datatype(),
+            &Value::Binary(Bytes::from(b"world".as_slice())),
+        );
+        check_type_and_value(
+            &ConcreteDataType::list_datatype(ConcreteDataType::int32_datatype()),
+            &Value::List(ListValue::new(
+                Some(Box::new(vec![Value::Int32(10)])),
+                ConcreteDataType::int32_datatype(),
+            )),
+        );
+        check_type_and_value(
+            &ConcreteDataType::list_datatype(ConcreteDataType::null_datatype()),
+            &Value::List(ListValue::default()),
+        );
+        check_type_and_value(
+            &ConcreteDataType::date_datatype(),
+            &Value::Date(Date::new(1)),
+        );
+        check_type_and_value(
+            &ConcreteDataType::datetime_datatype(),
+            &Value::DateTime(DateTime::new(1)),
+        );
+        check_type_and_value(
+            &ConcreteDataType::timestamp_millisecond_datatype(),
+            &Value::Timestamp(Timestamp::from_millis(1)),
+        );
+    }
+
+    #[test]
+    fn test_value_from_string() {
+        let hello = "hello".to_string();
+        assert_eq!(
+            Value::String(StringBytes::from(hello.clone())),
+            Value::from(hello)
+        );
+
+        let world = "world";
+        assert_eq!(Value::String(StringBytes::from(world)), Value::from(world));
+    }
+
+    #[test]
+    fn test_value_from_bytes() {
+        let hello = b"hello".to_vec();
+        assert_eq!(
+            Value::Binary(Bytes::from(hello.clone())),
+            Value::from(hello)
+        );
+
+        let world: &[u8] = b"world";
+        assert_eq!(Value::Binary(Bytes::from(world)), Value::from(world));
+    }
+
+    fn to_json(value: Value) -> serde_json::Value {
+        value.try_into().unwrap()
+    }
+
+    #[test]
+    fn test_to_json_value() {
+        assert_eq!(serde_json::Value::Null, to_json(Value::Null));
+        assert_eq!(serde_json::Value::Bool(true), to_json(Value::Boolean(true)));
+        assert_eq!(
+            serde_json::Value::Number(20u8.into()),
+            to_json(Value::UInt8(20))
+        );
+        assert_eq!(
+            serde_json::Value::Number(20i8.into()),
+            to_json(Value::Int8(20))
+        );
+        assert_eq!(
+            serde_json::Value::Number(2000u16.into()),
+            to_json(Value::UInt16(2000))
+        );
+        assert_eq!(
+            serde_json::Value::Number(2000i16.into()),
+            to_json(Value::Int16(2000))
+        );
+        assert_eq!(
+            serde_json::Value::Number(3000u32.into()),
+            to_json(Value::UInt32(3000))
+        );
+        assert_eq!(
+            serde_json::Value::Number(3000i32.into()),
+            to_json(Value::Int32(3000))
+        );
+        assert_eq!(
+            serde_json::Value::Number(4000u64.into()),
+            to_json(Value::UInt64(4000))
+        );
+        assert_eq!(
+            serde_json::Value::Number(4000i64.into()),
+            to_json(Value::Int64(4000))
+        );
+        assert_eq!(
+            serde_json::Value::from(125.0f32),
+            to_json(Value::Float32(125.0.into()))
+        );
+        assert_eq!(
+            serde_json::Value::from(125.0f64),
+            to_json(Value::Float64(125.0.into()))
+        );
+        assert_eq!(
+            serde_json::Value::String(String::from("hello")),
+            to_json(Value::String(StringBytes::from("hello")))
+        );
+        assert_eq!(
+            serde_json::Value::from(b"world".as_slice()),
+            to_json(Value::Binary(Bytes::from(b"world".as_slice())))
+        );
+        assert_eq!(
+            serde_json::Value::Number(5000i32.into()),
+            to_json(Value::Date(Date::new(5000)))
+        );
+        assert_eq!(
+            serde_json::Value::Number(5000i64.into()),
+            to_json(Value::DateTime(DateTime::new(5000)))
+        );
+
+        assert_eq!(
+            serde_json::Value::Number(1.into()),
+            to_json(Value::Timestamp(Timestamp::from_millis(1)))
+        );
+
+        let json_value: serde_json::Value =
+            serde_json::from_str(r#"{"items":[{"Int32":123}],"datatype":{"Int32":{}}}"#).unwrap();
+        assert_eq!(
+            json_value,
+            to_json(Value::List(ListValue {
+                items: Some(Box::new(vec![Value::Int32(123)])),
+                datatype: ConcreteDataType::int32_datatype(),
+            }))
+        );
+    }
+
+    #[test]
+    fn test_null_value() {
+        assert!(Value::Null.is_null());
+        assert!(!Value::Boolean(true).is_null());
+        assert!(Value::Null < Value::Boolean(false));
+        assert!(Value::Boolean(true) > Value::Null);
+        assert!(Value::Null < Value::Int32(10));
+        assert!(Value::Int32(10) > Value::Null);
+    }
+
+    #[test]
+    fn test_null_value_ref() {
+        assert!(ValueRef::Null.is_null());
+        assert!(!ValueRef::Boolean(true).is_null());
+        assert!(ValueRef::Null < ValueRef::Boolean(false));
+        assert!(ValueRef::Boolean(true) > ValueRef::Null);
+        assert!(ValueRef::Null < ValueRef::Int32(10));
+        assert!(ValueRef::Int32(10) > ValueRef::Null);
+    }
+
+    #[test]
+    fn test_as_value_ref() {
+        macro_rules! check_as_value_ref {
+            ($Variant: ident, $data: expr) => {
+                let value = Value::$Variant($data);
+                let value_ref = value.as_value_ref();
+                let expect_ref = ValueRef::$Variant($data);
+
+                assert_eq!(expect_ref, value_ref);
+            };
+        }
+
+        assert_eq!(ValueRef::Null, Value::Null.as_value_ref());
+        check_as_value_ref!(Boolean, true);
+        check_as_value_ref!(UInt8, 123);
+        check_as_value_ref!(UInt16, 123);
+        check_as_value_ref!(UInt32, 123);
+        check_as_value_ref!(UInt64, 123);
+        check_as_value_ref!(Int8, -12);
+        check_as_value_ref!(Int16, -12);
+        check_as_value_ref!(Int32, -12);
+        check_as_value_ref!(Int64, -12);
+        check_as_value_ref!(Float32, OrderedF32::from(16.0));
+        check_as_value_ref!(Float64, OrderedF64::from(16.0));
+        check_as_value_ref!(Timestamp, Timestamp::from_millis(1));
+
+        assert_eq!(
+            ValueRef::String("hello"),
+            Value::String("hello".into()).as_value_ref()
+        );
+        assert_eq!(
+            ValueRef::Binary(b"hello"),
+            Value::Binary("hello".as_bytes().into()).as_value_ref()
+        );
+
+        check_as_value_ref!(Date, Date::new(103));
+        check_as_value_ref!(DateTime, DateTime::new(1034));
+
+        let list = ListValue {
+            items: None,
+            datatype: ConcreteDataType::int32_datatype(),
+        };
+        assert_eq!(
+            ValueRef::List(ListValueRef::Ref { val: &list }),
+            Value::List(list.clone()).as_value_ref()
+        );
+    }
+
+    #[test]
+    fn test_value_ref_as() {
+        macro_rules! check_as_null {
+            ($method: ident) => {
+                assert_eq!(None, ValueRef::Null.$method().unwrap());
+            };
+        }
+
+        check_as_null!(as_binary);
+        check_as_null!(as_string);
+        check_as_null!(as_boolean);
+        check_as_null!(as_date);
+        check_as_null!(as_datetime);
+        check_as_null!(as_list);
+
+        macro_rules! check_as_correct {
+            ($data: expr, $Variant: ident, $method: ident) => {
+                assert_eq!(Some($data), ValueRef::$Variant($data).$method().unwrap());
+            };
+        }
+
+        check_as_correct!("hello", String, as_string);
+        check_as_correct!("hello".as_bytes(), Binary, as_binary);
+        check_as_correct!(true, Boolean, as_boolean);
+        check_as_correct!(Date::new(123), Date, as_date);
+        check_as_correct!(DateTime::new(12), DateTime, as_datetime);
+        let list = ListValue {
+            items: None,
+            datatype: ConcreteDataType::int32_datatype(),
+        };
+        check_as_correct!(ListValueRef::Ref { val: &list }, List, as_list);
+
+        let wrong_value = ValueRef::Int32(12345);
+        assert!(wrong_value.as_binary().is_err());
+        assert!(wrong_value.as_string().is_err());
+        assert!(wrong_value.as_boolean().is_err());
+        assert!(wrong_value.as_date().is_err());
+        assert!(wrong_value.as_datetime().is_err());
+        assert!(wrong_value.as_list().is_err());
+    }
+
+    #[test]
+    fn test_display() {
+        assert_eq!(Value::Null.to_string(), "Null");
+        assert_eq!(Value::UInt8(8).to_string(), "8");
+        assert_eq!(Value::UInt16(16).to_string(), "16");
+        assert_eq!(Value::UInt32(32).to_string(), "32");
+        assert_eq!(Value::UInt64(64).to_string(), "64");
+        assert_eq!(Value::Int8(-8).to_string(), "-8");
+        assert_eq!(Value::Int16(-16).to_string(), "-16");
+        assert_eq!(Value::Int32(-32).to_string(), "-32");
+        assert_eq!(Value::Int64(-64).to_string(), "-64");
+        assert_eq!(Value::Float32((-32.123).into()).to_string(), "-32.123");
+        assert_eq!(Value::Float64((-64.123).into()).to_string(), "-64.123");
+        assert_eq!(Value::Float64(OrderedF64::infinity()).to_string(), "inf");
+        assert_eq!(Value::Float64(OrderedF64::nan()).to_string(), "NaN");
+        assert_eq!(Value::String(StringBytes::from("123")).to_string(), "123");
+        assert_eq!(
+            Value::Binary(Bytes::from(vec![1, 2, 3])).to_string(),
+            "010203"
+        );
+        assert_eq!(Value::Date(Date::new(0)).to_string(), "1970-01-01");
+        assert_eq!(
+            Value::DateTime(DateTime::new(0)).to_string(),
+            "1970-01-01 00:00:00"
+        );
+        assert_eq!(
+            Value::Timestamp(Timestamp::new(1000, TimeUnit::Millisecond)).to_string(),
+            "1970-01-01 00:00:01+0000"
+        );
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![Value::Int8(1), Value::Int8(2)])),
+                ConcreteDataType::int8_datatype(),
+            ))
+            .to_string(),
+            "Int8[1, 2]"
+        );
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![])),
+                ConcreteDataType::timestamp_second_datatype(),
+            ))
+            .to_string(),
+            "TimestampSecondType[]"
+        );
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![])),
+                ConcreteDataType::timestamp_millisecond_datatype(),
+            ))
+            .to_string(),
+            "TimestampMillisecondType[]"
+        );
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![])),
+                ConcreteDataType::timestamp_microsecond_datatype(),
+            ))
+            .to_string(),
+            "TimestampMicrosecondType[]"
+        );
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![])),
+                ConcreteDataType::timestamp_nanosecond_datatype(),
+            ))
+            .to_string(),
+            "TimestampNanosecondType[]"
+        );
+    }
+}
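The tests above exercise the bridge between DataFusion's `ScalarValue` and the storage `Value` type. As a quick orientation, here is a minimal sketch of a defensive caller; the crate name `datatypes2` and the `value` module path are assumptions taken from this patch's layout, not confirmed public API:

    use datafusion_common::ScalarValue;
    use datatypes2::value::Value; // crate/module path assumed from this patch

    /// Convert a DataFusion scalar into a storage `Value`, treating
    /// unsupported types (e.g. `Decimal128`) as absent instead of erroring.
    fn scalar_to_value(scalar: ScalarValue) -> Option<Value> {
        // Uses the `TryFrom<ScalarValue> for Value` impl tested above.
        scalar.try_into().ok()
    }

    fn main() {
        assert_eq!(
            Some(Value::Int32(42)),
            scalar_to_value(ScalarValue::Int32(Some(42)))
        );
        assert_eq!(Some(Value::Null), scalar_to_value(ScalarValue::Int32(None)));
        assert_eq!(None, scalar_to_value(ScalarValue::Decimal128(Some(1), 0, 0)));
    }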
diff --git a/src/datatypes2/src/vectors.rs b/src/datatypes2/src/vectors.rs
new file mode 100644
index 0000000000..38fa762d4b
--- /dev/null
+++ b/src/datatypes2/src/vectors.rs
@@ -0,0 +1,309 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef};
+use snafu::ensure;
+
+use crate::data_type::ConcreteDataType;
+use crate::error::{self, Result};
+use crate::serialize::Serializable;
+use crate::value::{Value, ValueRef};
+use crate::vectors::operations::VectorOp;
+
+mod binary;
+mod boolean;
+mod constant;
+mod date;
+mod datetime;
+mod eq;
+mod helper;
+mod list;
+mod null;
+mod operations;
+mod primitive;
+mod string;
+mod timestamp;
+mod validity;
+
+pub use binary::{BinaryVector, BinaryVectorBuilder};
+pub use boolean::{BooleanVector, BooleanVectorBuilder};
+pub use constant::ConstantVector;
+pub use date::{DateVector, DateVectorBuilder};
+pub use datetime::{DateTimeVector, DateTimeVectorBuilder};
+pub use helper::Helper;
+pub use list::{ListIter, ListVector, ListVectorBuilder};
+pub use null::{NullVector, NullVectorBuilder};
+pub use primitive::{
+    Float32Vector, Float32VectorBuilder, Float64Vector, Float64VectorBuilder, Int16Vector,
+    Int16VectorBuilder, Int32Vector, Int32VectorBuilder, Int64Vector, Int64VectorBuilder,
+    Int8Vector, Int8VectorBuilder, PrimitiveIter, PrimitiveVector, PrimitiveVectorBuilder,
+    UInt16Vector, UInt16VectorBuilder, UInt32Vector, UInt32VectorBuilder, UInt64Vector,
+    UInt64VectorBuilder, UInt8Vector, UInt8VectorBuilder,
+};
+pub use string::{StringVector, StringVectorBuilder};
+pub use timestamp::{
+    TimestampMicrosecondVector, TimestampMicrosecondVectorBuilder, TimestampMillisecondVector,
+    TimestampMillisecondVectorBuilder, TimestampNanosecondVector, TimestampNanosecondVectorBuilder,
+    TimestampSecondVector, TimestampSecondVectorBuilder,
+};
+pub use validity::Validity;
+
+// TODO(yingwen): arrow 28.0 implements Clone for all arrays, we could upgrade to it and simplify
+// some codes in methods such as `to_arrow_array()` and `to_boxed_arrow_array()`.
+/// Vector of data values.
+pub trait Vector: Send + Sync + Serializable + Debug + VectorOp {
+    /// Returns the data type of the vector.
+    ///
+    /// This may require heap allocation.
+    fn data_type(&self) -> ConcreteDataType;
+
+    fn vector_type_name(&self) -> String;
+
+    /// Returns the vector as [Any](std::any::Any) so that it can be
+    /// downcast to a specific implementation.
+    fn as_any(&self) -> &dyn Any;
+
+    /// Returns the number of elements in the vector.
+    fn len(&self) -> usize;
+
+    /// Returns whether the vector is empty.
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Convert this vector to a new arrow [ArrayRef].
+    fn to_arrow_array(&self) -> ArrayRef;
+
+    /// Convert this vector to a new boxed arrow [Array].
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array>;
+
+    /// Returns the validity of the Array.
+    fn validity(&self) -> Validity;
+
+    /// Returns the memory size of the vector.
+    fn memory_size(&self) -> usize;
+
+    /// The number of null slots on this [`Vector`].
+    ///
+    /// # Implementation
+    /// This is `O(1)`.
+    fn null_count(&self) -> usize;
+
+    /// Returns true when it's a ConstantColumn.
+    fn is_const(&self) -> bool {
+        false
+    }
+
+    /// Returns whether the row is null.
+    fn is_null(&self, row: usize) -> bool;
+
+    /// Returns true if the only value this vector can contain is NULL.
+    fn only_null(&self) -> bool {
+        self.null_count() == self.len()
+    }
+
+    /// Slices the `Vector`, returning a new `VectorRef`.
+    ///
+    /// # Panics
+    /// This function panics if `offset + length > self.len()`.
+    fn slice(&self, offset: usize, length: usize) -> VectorRef;
+
+    /// Returns the clone of the value at `index`.
+    ///
+    /// # Panics
+    /// Panics if `index` is out of bounds.
+    fn get(&self, index: usize) -> Value;
+
+    /// Returns the clone of the value at `index`, or an error if `index`
+    /// is out of bounds.
+    fn try_get(&self, index: usize) -> Result<Value> {
+        ensure!(
+            index < self.len(),
+            error::BadArrayAccessSnafu {
+                index,
+                size: self.len()
+            }
+        );
+        Ok(self.get(index))
+    }
+
+    /// Returns the reference of the value at `index`.
+    ///
+    /// # Panics
+    /// Panics if `index` is out of bounds.
+    fn get_ref(&self, index: usize) -> ValueRef;
+}
+
+pub type VectorRef = Arc<dyn Vector>;
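Everything downstream consumes columns through `Arc<dyn Vector>`, so generic code can inspect any column without naming its concrete type. A minimal sketch against the trait surface just defined, assuming the crate is imported as `datatypes2`:

    use std::sync::Arc;
    use datatypes2::value::Value;
    use datatypes2::vectors::{Int32Vector, VectorRef};

    /// Count the non-null slots of any column; `null_count` is O(1), so this
    /// avoids scanning the whole vector.
    fn non_null(v: &VectorRef) -> usize {
        v.len() - v.null_count()
    }

    fn main() {
        let v: VectorRef = Arc::new(Int32Vector::from(vec![Some(1), None, Some(3)]));
        assert_eq!(2, non_null(&v));
        assert!(v.is_null(1));
        assert_eq!(Value::Int32(1), v.get(0));
    }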
+
+/// Mutable vector that could be used to build an immutable vector.
+pub trait MutableVector: Send + Sync {
+    /// Returns the data type of the vector.
+    fn data_type(&self) -> ConcreteDataType;
+
+    /// Returns the length of the vector.
+    fn len(&self) -> usize;
+
+    /// Returns whether the vector is empty.
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Convert to Any, to enable dynamic casting.
+    fn as_any(&self) -> &dyn Any;
+
+    /// Convert to mutable Any, to enable dynamic casting.
+    fn as_mut_any(&mut self) -> &mut dyn Any;
+
+    /// Convert `self` to an (immutable) [VectorRef] and reset `self`.
+    fn to_vector(&mut self) -> VectorRef;
+
+    /// Push a value ref to this mutable vector.
+    ///
+    /// Returns an error if the data type does not match.
+    fn push_value_ref(&mut self, value: ValueRef) -> Result<()>;
+
+    /// Extend this mutable vector by a slice of `vector`.
+    ///
+    /// Returns an error if the data type does not match.
+    ///
+    /// # Panics
+    /// Panics if `offset + length > vector.len()`.
+    fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()>;
+}
+
+/// Helper to define `try_from_arrow_array(array: arrow::array::ArrayRef)` function.
+macro_rules! impl_try_from_arrow_array_for_vector {
+    ($Array: ident, $Vector: ident) => {
+        impl $Vector {
+            pub fn try_from_arrow_array(
+                array: impl AsRef<dyn arrow::array::Array>,
+            ) -> crate::error::Result<$Vector> {
+                use snafu::OptionExt;
+
+                let data = array
+                    .as_ref()
+                    .as_any()
+                    .downcast_ref::<$Array>()
+                    .with_context(|| crate::error::ConversionSnafu {
+                        from: std::format!("{:?}", array.as_ref().data_type()),
+                    })?
+                    .data()
+                    .clone();
+
+                let concrete_array = $Array::from(data);
+                Ok($Vector::from(concrete_array))
+            }
+        }
+    };
+}
+
+macro_rules! impl_validity_for_vector {
+    ($array: expr) => {
+        Validity::from_array_data($array.data())
+    };
+}
+
+macro_rules! impl_get_for_vector {
+    ($array: expr, $index: ident) => {
+        if $array.is_valid($index) {
+            // Safety: The index has been checked by `is_valid()`.
+            unsafe { $array.value_unchecked($index).into() }
+        } else {
+            Value::Null
+        }
+    };
+}
+
+macro_rules! impl_get_ref_for_vector {
+    ($array: expr, $index: ident) => {
+        if $array.is_valid($index) {
+            // Safety: The index has been checked by `is_valid()`.
+            unsafe { $array.value_unchecked($index).into() }
+        } else {
+            ValueRef::Null
+        }
+    };
+}
+
+macro_rules! impl_extend_for_builder {
+    ($mutable_vector: expr, $vector: ident, $VectorType: ident, $offset: ident, $length: ident) => {{
+        use snafu::OptionExt;
+
+        let sliced_vector = $vector.slice($offset, $length);
+        let concrete_vector = sliced_vector
+            .as_any()
+            .downcast_ref::<$VectorType>()
+            .with_context(|| crate::error::CastTypeSnafu {
+                msg: format!(
+                    "Failed to cast vector from {} to {}",
+                    $vector.vector_type_name(),
+                    stringify!($VectorType)
+                ),
+            })?;
+        for value in concrete_vector.iter_data() {
+            $mutable_vector.push(value);
+        }
+        Ok(())
+    }};
+}
+
+pub(crate) use {
+    impl_extend_for_builder, impl_get_for_vector, impl_get_ref_for_vector,
+    impl_try_from_arrow_array_for_vector, impl_validity_for_vector,
+};
+
+#[cfg(test)]
+pub mod tests {
+    use arrow::array::{Array, Int32Array, UInt8Array};
+    use serde_json;
+
+    use super::*;
+    use crate::data_type::DataType;
+    use crate::types::{Int32Type, LogicalPrimitiveType};
+    use crate::vectors::helper::Helper;
+
+    #[test]
+    fn test_df_columns_to_vector() {
+        let df_column: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let vector = Helper::try_into_vector(df_column).unwrap();
+        assert_eq!(
+            Int32Type::build_data_type().as_arrow_type(),
+            vector.data_type().as_arrow_type()
+        );
+    }
+
+    #[test]
+    fn test_serialize_i32_vector() {
+        let df_column: Arc<dyn Array> = Arc::new(Int32Array::from(vec![1, 2, 3]));
+        let json_value = Helper::try_into_vector(df_column)
+            .unwrap()
+            .serialize_to_json()
+            .unwrap();
+        assert_eq!("[1,2,3]", serde_json::to_string(&json_value).unwrap());
+    }
+
+    #[test]
+    fn test_serialize_i8_vector() {
+        let df_column: Arc<dyn Array> = Arc::new(UInt8Array::from(vec![1, 2, 3]));
+        let json_value = Helper::try_into_vector(df_column)
+            .unwrap()
+            .serialize_to_json()
+            .unwrap();
+        assert_eq!("[1,2,3]", serde_json::to_string(&json_value).unwrap());
+    }
+}
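`MutableVector` is the type-erased builder that `DataType::create_mutable_vector` hands back, which the per-type tests later in this patch rely on. A minimal sketch of filling a column dynamically; it assumes `ConcreteDataType` exposes the same `create_mutable_vector(capacity)` the concrete type objects use in those tests, and the crate name `datatypes2` is again an assumption:

    use datatypes2::data_type::{ConcreteDataType, DataType};
    use datatypes2::value::ValueRef;

    fn main() -> datatypes2::error::Result<()> {
        // Pick a builder from the runtime data type rather than a concrete type.
        let mut builder = ConcreteDataType::int32_datatype().create_mutable_vector(2);
        builder.push_value_ref(ValueRef::Int32(7))?;
        builder.push_value_ref(ValueRef::Null)?;
        // Mismatched value types surface as errors, not panics.
        assert!(builder.push_value_ref(ValueRef::Boolean(true)).is_err());
        let vector = builder.to_vector(); // also resets the builder
        assert_eq!(2, vector.len());
        Ok(())
    }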
diff --git a/src/datatypes2/src/vectors/binary.rs b/src/datatypes2/src/vectors/binary.rs
new file mode 100644
index 0000000000..3b5defc8ec
--- /dev/null
+++ b/src/datatypes2/src/vectors/binary.rs
@@ -0,0 +1,353 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayBuilder, ArrayData, ArrayIter, ArrayRef};
+use snafu::ResultExt;
+
+use crate::arrow_array::{BinaryArray, MutableBinaryArray};
+use crate::data_type::ConcreteDataType;
+use crate::error::{self, Result};
+use crate::scalars::{ScalarVector, ScalarVectorBuilder};
+use crate::serialize::Serializable;
+use crate::value::{Value, ValueRef};
+use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef};
+
+/// Vector of binary strings.
+#[derive(Debug, PartialEq)]
+pub struct BinaryVector {
+    array: BinaryArray,
+}
+
+impl BinaryVector {
+    pub(crate) fn as_arrow(&self) -> &dyn Array {
+        &self.array
+    }
+
+    fn to_array_data(&self) -> ArrayData {
+        self.array.data().clone()
+    }
+
+    fn from_array_data(data: ArrayData) -> BinaryVector {
+        BinaryVector {
+            array: BinaryArray::from(data),
+        }
+    }
+}
+
+impl From<BinaryArray> for BinaryVector {
+    fn from(array: BinaryArray) -> Self {
+        Self { array }
+    }
+}
+
+impl From<Vec<Option<Vec<u8>>>> for BinaryVector {
+    fn from(data: Vec<Option<Vec<u8>>>) -> Self {
+        Self {
+            array: BinaryArray::from_iter(data),
+        }
+    }
+}
+
+impl Vector for BinaryVector {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::binary_datatype()
+    }
+
+    fn vector_type_name(&self) -> String {
+        "BinaryVector".to_string()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.array.len()
+    }
+
+    fn to_arrow_array(&self) -> ArrayRef {
+        let data = self.to_array_data();
+        Arc::new(BinaryArray::from(data))
+    }
+
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array> {
+        let data = self.to_array_data();
+        Box::new(BinaryArray::from(data))
+    }
+
+    fn validity(&self) -> Validity {
+        vectors::impl_validity_for_vector!(self.array)
+    }
+
+    fn memory_size(&self) -> usize {
+        self.array.get_buffer_memory_size()
+    }
+
+    fn null_count(&self) -> usize {
+        self.array.null_count()
+    }
+
+    fn is_null(&self, row: usize) -> bool {
+        self.array.is_null(row)
+    }
+
+    fn slice(&self, offset: usize, length: usize) -> VectorRef {
+        let data = self.array.data().slice(offset, length);
+        Arc::new(Self::from_array_data(data))
+    }
+
+    fn get(&self, index: usize) -> Value {
+        vectors::impl_get_for_vector!(self.array, index)
+    }
+
+    fn get_ref(&self, index: usize) -> ValueRef {
+        vectors::impl_get_ref_for_vector!(self.array, index)
+    }
+}
+
+impl ScalarVector for BinaryVector {
+    type OwnedItem = Vec<u8>;
+    type RefItem<'a> = &'a [u8];
+    type Iter<'a> = ArrayIter<&'a BinaryArray>;
+    type Builder = BinaryVectorBuilder;
+
+    fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>> {
+        if self.array.is_valid(idx) {
+            Some(self.array.value(idx))
+        } else {
+            None
+        }
+    }
+
+    fn iter_data(&self) -> Self::Iter<'_> {
+        self.array.iter()
+    }
+}
+
+pub struct BinaryVectorBuilder {
+    mutable_array: MutableBinaryArray,
+}
+
+impl MutableVector for BinaryVectorBuilder {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::binary_datatype()
+    }
+
+    fn len(&self) -> usize {
+        self.mutable_array.len()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn to_vector(&mut self) -> VectorRef {
+        Arc::new(self.finish())
+    }
+
+    fn push_value_ref(&mut self, value: ValueRef) -> Result<()> {
+        match value.as_binary()? {
+            Some(v) => self.mutable_array.append_value(v),
+            None => self.mutable_array.append_null(),
+        }
+        Ok(())
+    }
+
+    fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> {
+        vectors::impl_extend_for_builder!(self, vector, BinaryVector, offset, length)
+    }
+}
+
+impl ScalarVectorBuilder for BinaryVectorBuilder {
+    type VectorType = BinaryVector;
+
+    fn with_capacity(capacity: usize) -> Self {
+        Self {
+            mutable_array: MutableBinaryArray::with_capacity(capacity, 0),
+        }
+    }
+
+    fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>) {
+        match value {
+            Some(v) => self.mutable_array.append_value(v),
+            None => self.mutable_array.append_null(),
+        }
+    }
+
+    fn finish(&mut self) -> Self::VectorType {
+        BinaryVector {
+            array: self.mutable_array.finish(),
+        }
+    }
+}
+
+impl Serializable for BinaryVector {
+    fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>> {
+        self.iter_data()
+            .map(|v| match v {
+                None => Ok(serde_json::Value::Null), // if binary vector not present, map to NULL
+                Some(vec) => serde_json::to_value(vec),
+            })
+            .collect::<serde_json::Result<Vec<_>>>()
+            .context(error::SerializeSnafu)
+    }
+}
+
+vectors::impl_try_from_arrow_array_for_vector!(BinaryArray, BinaryVector);
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType as ArrowDataType;
+    use common_base::bytes::Bytes;
+    use serde_json;
+
+    use super::*;
+    use crate::arrow_array::BinaryArray;
+    use crate::data_type::DataType;
+    use crate::serialize::Serializable;
+    use crate::types::BinaryType;
+
+    #[test]
+    fn test_binary_vector_misc() {
+        let v = BinaryVector::from(BinaryArray::from_iter_values(&[
+            vec![1, 2, 3],
+            vec![1, 2, 3],
+        ]));
+
+        assert_eq!(2, v.len());
+        assert_eq!("BinaryVector", v.vector_type_name());
+        assert!(!v.is_const());
+        assert!(v.validity().is_all_valid());
+        assert!(!v.only_null());
+        assert_eq!(128, v.memory_size());
+
+        for i in 0..2 {
+            assert!(!v.is_null(i));
+            assert_eq!(Value::Binary(Bytes::from(vec![1, 2, 3])), v.get(i));
+            assert_eq!(ValueRef::Binary(&[1, 2, 3]), v.get_ref(i));
+        }
+
+        let arrow_arr = v.to_arrow_array();
+        assert_eq!(2, arrow_arr.len());
+        assert_eq!(&ArrowDataType::LargeBinary, arrow_arr.data_type());
+    }
+
+    #[test]
+    fn test_serialize_binary_vector_to_json() {
+        let vector = BinaryVector::from(BinaryArray::from_iter_values(&[
+            vec![1, 2, 3],
+            vec![1, 2, 3],
+        ]));
+
+        let json_value = vector.serialize_to_json().unwrap();
+        assert_eq!(
+            "[[1,2,3],[1,2,3]]",
+            serde_json::to_string(&json_value).unwrap()
+        );
+    }
+
+    #[test]
+    fn test_serialize_binary_vector_with_null_to_json() {
+        let mut builder = BinaryVectorBuilder::with_capacity(4);
+        builder.push(Some(&[1, 2, 3]));
+        builder.push(None);
+        builder.push(Some(&[4, 5, 6]));
+        let vector = builder.finish();
+
+        let json_value = vector.serialize_to_json().unwrap();
+        assert_eq!(
+            "[[1,2,3],null,[4,5,6]]",
+            serde_json::to_string(&json_value).unwrap()
+        );
+    }
+
+    #[test]
+    fn test_from_arrow_array() {
+        let arrow_array = BinaryArray::from_iter_values(&[vec![1, 2, 3], vec![1, 2, 3]]);
+        let original = BinaryArray::from(arrow_array.data().clone());
+        let vector = BinaryVector::from(arrow_array);
+        assert_eq!(original, vector.array);
+    }
+
assert_eq!(Value::Binary(b"hello".as_slice().into()), vector.get(0)); + assert_eq!(Value::Null, vector.get(3)); + + let mut iter = vector.iter_data(); + assert_eq!(b"hello", iter.next().unwrap().unwrap()); + assert_eq!(b"happy", iter.next().unwrap().unwrap()); + assert_eq!(b"world", iter.next().unwrap().unwrap()); + assert_eq!(None, iter.next().unwrap()); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_binary_vector_validity() { + let mut builder = BinaryVectorBuilder::with_capacity(4); + builder.push(Some(b"hello")); + builder.push(Some(b"world")); + let vector = builder.finish(); + assert_eq!(0, vector.null_count()); + assert!(vector.validity().is_all_valid()); + + let mut builder = BinaryVectorBuilder::with_capacity(3); + builder.push(Some(b"hello")); + builder.push(None); + builder.push(Some(b"world")); + let vector = builder.finish(); + assert_eq!(1, vector.null_count()); + let validity = vector.validity(); + assert!(!validity.is_set(1)); + + assert_eq!(1, validity.null_count()); + assert!(!validity.is_set(1)); + } + + #[test] + fn test_binary_vector_builder() { + let input = BinaryVector::from_slice(&[b"world", b"one", b"two"]); + + let mut builder = BinaryType::default().create_mutable_vector(3); + builder + .push_value_ref(ValueRef::Binary("hello".as_bytes())) + .unwrap(); + assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err()); + builder.extend_slice_of(&input, 1, 2).unwrap(); + assert!(builder + .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1) + .is_err()); + let vector = builder.to_vector(); + + let expect: VectorRef = Arc::new(BinaryVector::from_slice(&[b"hello", b"one", b"two"])); + assert_eq!(expect, vector); + } +} diff --git a/src/datatypes2/src/vectors/boolean.rs b/src/datatypes2/src/vectors/boolean.rs new file mode 100644 index 0000000000..2b4e5b8e10 --- /dev/null +++ b/src/datatypes2/src/vectors/boolean.rs @@ -0,0 +1,371 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::borrow::Borrow; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayBuilder, ArrayData, ArrayIter, ArrayRef, BooleanArray, BooleanBuilder, +}; +use snafu::ResultExt; + +use crate::data_type::ConcreteDataType; +use crate::error::Result; +use crate::scalars::{ScalarVector, ScalarVectorBuilder}; +use crate::serialize::Serializable; +use crate::value::{Value, ValueRef}; +use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef}; + +/// Vector of boolean. 
diff --git a/src/datatypes2/src/vectors/boolean.rs b/src/datatypes2/src/vectors/boolean.rs
new file mode 100644
index 0000000000..2b4e5b8e10
--- /dev/null
+++ b/src/datatypes2/src/vectors/boolean.rs
@@ -0,0 +1,371 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::borrow::Borrow;
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayBuilder, ArrayData, ArrayIter, ArrayRef, BooleanArray, BooleanBuilder,
+};
+use snafu::ResultExt;
+
+use crate::data_type::ConcreteDataType;
+use crate::error::Result;
+use crate::scalars::{ScalarVector, ScalarVectorBuilder};
+use crate::serialize::Serializable;
+use crate::value::{Value, ValueRef};
+use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef};
+
+/// Vector of boolean.
+#[derive(Debug, PartialEq)]
+pub struct BooleanVector {
+    array: BooleanArray,
+}
+
+impl BooleanVector {
+    pub(crate) fn as_arrow(&self) -> &dyn Array {
+        &self.array
+    }
+
+    pub(crate) fn as_boolean_array(&self) -> &BooleanArray {
+        &self.array
+    }
+
+    fn to_array_data(&self) -> ArrayData {
+        self.array.data().clone()
+    }
+
+    fn from_array_data(data: ArrayData) -> BooleanVector {
+        BooleanVector {
+            array: BooleanArray::from(data),
+        }
+    }
+
+    pub(crate) fn false_count(&self) -> usize {
+        self.array.false_count()
+    }
+}
+
+impl From<Vec<bool>> for BooleanVector {
+    fn from(data: Vec<bool>) -> Self {
+        BooleanVector {
+            array: BooleanArray::from(data),
+        }
+    }
+}
+
+impl From<BooleanArray> for BooleanVector {
+    fn from(array: BooleanArray) -> Self {
+        Self { array }
+    }
+}
+
+impl From<Vec<Option<bool>>> for BooleanVector {
+    fn from(data: Vec<Option<bool>>) -> Self {
+        BooleanVector {
+            array: BooleanArray::from(data),
+        }
+    }
+}
+
+impl<Ptr: Borrow<Option<bool>>> FromIterator<Ptr> for BooleanVector {
+    fn from_iter<I: IntoIterator<Item = Ptr>>(iter: I) -> Self {
+        BooleanVector {
+            array: BooleanArray::from_iter(iter),
+        }
+    }
+}
+
+impl Vector for BooleanVector {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::boolean_datatype()
+    }
+
+    fn vector_type_name(&self) -> String {
+        "BooleanVector".to_string()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.array.len()
+    }
+
+    fn to_arrow_array(&self) -> ArrayRef {
+        let data = self.to_array_data();
+        Arc::new(BooleanArray::from(data))
+    }
+
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array> {
+        let data = self.to_array_data();
+        Box::new(BooleanArray::from(data))
+    }
+
+    fn validity(&self) -> Validity {
+        vectors::impl_validity_for_vector!(self.array)
+    }
+
+    fn memory_size(&self) -> usize {
+        self.array.get_buffer_memory_size()
+    }
+
+    fn null_count(&self) -> usize {
+        self.array.null_count()
+    }
+
+    fn is_null(&self, row: usize) -> bool {
+        self.array.is_null(row)
+    }
+
+    fn slice(&self, offset: usize, length: usize) -> VectorRef {
+        let data = self.array.data().slice(offset, length);
+        Arc::new(Self::from_array_data(data))
+    }
+
+    fn get(&self, index: usize) -> Value {
+        vectors::impl_get_for_vector!(self.array, index)
+    }
+
+    fn get_ref(&self, index: usize) -> ValueRef {
+        vectors::impl_get_ref_for_vector!(self.array, index)
+    }
+}
+
+impl ScalarVector for BooleanVector {
+    type OwnedItem = bool;
+    type RefItem<'a> = bool;
+    type Iter<'a> = ArrayIter<&'a BooleanArray>;
+    type Builder = BooleanVectorBuilder;
+
+    fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>> {
+        if self.array.is_valid(idx) {
+            Some(self.array.value(idx))
+        } else {
+            None
+        }
+    }
+
+    fn iter_data(&self) -> Self::Iter<'_> {
+        self.array.iter()
+    }
+}
+
+pub struct BooleanVectorBuilder {
+    mutable_array: BooleanBuilder,
+}
+
+impl MutableVector for BooleanVectorBuilder {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::boolean_datatype()
+    }
+
+    fn len(&self) -> usize {
+        self.mutable_array.len()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn to_vector(&mut self) -> VectorRef {
+        Arc::new(self.finish())
+    }
+
+    fn push_value_ref(&mut self, value: ValueRef) -> Result<()> {
+        match value.as_boolean()? {
+            Some(v) => self.mutable_array.append_value(v),
+            None => self.mutable_array.append_null(),
+        }
+        Ok(())
+    }
+
+    fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> {
+        vectors::impl_extend_for_builder!(self, vector, BooleanVector, offset, length)
+    }
+}
+
+impl ScalarVectorBuilder for BooleanVectorBuilder {
+    type VectorType = BooleanVector;
+
+    fn with_capacity(capacity: usize) -> Self {
+        Self {
+            mutable_array: BooleanBuilder::with_capacity(capacity),
+        }
+    }
+
+    fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>) {
+        match value {
+            Some(v) => self.mutable_array.append_value(v),
+            None => self.mutable_array.append_null(),
+        }
+    }
+
+    fn finish(&mut self) -> Self::VectorType {
+        BooleanVector {
+            array: self.mutable_array.finish(),
+        }
+    }
+}
+
+impl Serializable for BooleanVector {
+    fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>> {
+        self.iter_data()
+            .map(serde_json::to_value)
+            .collect::<serde_json::Result<Vec<_>>>()
+            .context(crate::error::SerializeSnafu)
+    }
+}
+
+vectors::impl_try_from_arrow_array_for_vector!(BooleanArray, BooleanVector);
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType as ArrowDataType;
+    use serde_json;
+
+    use super::*;
+    use crate::data_type::DataType;
+    use crate::serialize::Serializable;
+    use crate::types::BooleanType;
+
+    #[test]
+    fn test_boolean_vector_misc() {
+        let bools = vec![true, false, true, true, false, false, true, true, false];
+        let v = BooleanVector::from(bools.clone());
+        assert_eq!(9, v.len());
+        assert_eq!("BooleanVector", v.vector_type_name());
+        assert!(!v.is_const());
+        assert!(v.validity().is_all_valid());
+        assert!(!v.only_null());
+        assert_eq!(64, v.memory_size());
+
+        for (i, b) in bools.iter().enumerate() {
+            assert!(!v.is_null(i));
+            assert_eq!(Value::Boolean(*b), v.get(i));
+            assert_eq!(ValueRef::Boolean(*b), v.get_ref(i));
+        }
+
+        let arrow_arr = v.to_arrow_array();
+        assert_eq!(9, arrow_arr.len());
+        assert_eq!(&ArrowDataType::Boolean, arrow_arr.data_type());
+    }
+
+    #[test]
+    fn test_serialize_boolean_vector_to_json() {
+        let vector = BooleanVector::from(vec![true, false, true, true, false, false]);
+
+        let json_value = vector.serialize_to_json().unwrap();
+        assert_eq!(
+            "[true,false,true,true,false,false]",
+            serde_json::to_string(&json_value).unwrap(),
+        );
+    }
+
+    #[test]
+    fn test_serialize_boolean_vector_with_null_to_json() {
+        let vector = BooleanVector::from(vec![Some(true), None, Some(false)]);
+
+        let json_value = vector.serialize_to_json().unwrap();
+        assert_eq!(
+            "[true,null,false]",
+            serde_json::to_string(&json_value).unwrap(),
+        );
+    }
+
+    #[test]
+    fn test_boolean_vector_from_vec() {
+        let input = vec![false, true, false, true];
+        let vec = BooleanVector::from(input.clone());
+        assert_eq!(4, vec.len());
+        for (i, v) in input.into_iter().enumerate() {
+            assert_eq!(Some(v), vec.get_data(i), "failed at {}", i)
+        }
+    }
+
+    #[test]
+    fn test_boolean_vector_from_iter() {
+        let input = vec![Some(false), Some(true), Some(false), Some(true)];
+        let vec = input.iter().collect::<BooleanVector>();
+        assert_eq!(4, vec.len());
+        for (i, v) in input.into_iter().enumerate() {
+            assert_eq!(v, vec.get_data(i), "failed at {}", i)
+        }
+    }
+
+    #[test]
+    fn test_boolean_vector_from_vec_option() {
+        let input = vec![Some(false), Some(true), None, Some(true)];
+        let vec = BooleanVector::from(input.clone());
+        assert_eq!(4, vec.len());
+        for (i, v) in input.into_iter().enumerate() {
+            assert_eq!(v, vec.get_data(i), "failed at {}", i)
+        }
+    }
+
+    #[test]
+    fn test_boolean_vector_build_get() {
+        let input = [Some(true), None, Some(false)];
+        let mut builder = BooleanVectorBuilder::with_capacity(3);
+        for v in input {
+            builder.push(v);
+        }
+        let vector = builder.finish();
+        assert_eq!(input.len(), vector.len());
+
+        let res: Vec<_> = vector.iter_data().collect();
+        assert_eq!(input, &res[..]);
+
+        for (i, v) in input.into_iter().enumerate() {
+            assert_eq!(v, vector.get_data(i));
+            assert_eq!(Value::from(v), vector.get(i));
+        }
+    }
+
+    #[test]
+    fn test_boolean_vector_validity() {
+        let vector = BooleanVector::from(vec![Some(true), None, Some(false)]);
+        assert_eq!(1, vector.null_count());
+        let validity = vector.validity();
+        assert_eq!(1, validity.null_count());
+        assert!(!validity.is_set(1));
+
+        let vector = BooleanVector::from(vec![true, false, false]);
+        assert_eq!(0, vector.null_count());
+        assert!(vector.validity().is_all_valid());
+    }
+
+    #[test]
+    fn test_boolean_vector_builder() {
+        let input = BooleanVector::from_slice(&[true, false, true]);
+
+        let mut builder = BooleanType::default().create_mutable_vector(3);
+        builder.push_value_ref(ValueRef::Boolean(true)).unwrap();
+        assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err());
+        builder.extend_slice_of(&input, 1, 2).unwrap();
+        assert!(builder
+            .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1)
+            .is_err());
+        let vector = builder.to_vector();
+
+        let expect: VectorRef = Arc::new(BooleanVector::from_slice(&[true, false, true]));
+        assert_eq!(expect, vector);
+    }
+}
diff --git a/src/datatypes2/src/vectors/constant.rs b/src/datatypes2/src/vectors/constant.rs
new file mode 100644
index 0000000000..87739e9131
--- /dev/null
+++ b/src/datatypes2/src/vectors/constant.rs
@@ -0,0 +1,218 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayRef};
+use snafu::ResultExt;
+
+use crate::data_type::ConcreteDataType;
+use crate::error::{Result, SerializeSnafu};
+use crate::serialize::Serializable;
+use crate::value::{Value, ValueRef};
+use crate::vectors::{BooleanVector, Helper, Validity, Vector, VectorRef};
+
+#[derive(Clone)]
+pub struct ConstantVector {
+    length: usize,
+    vector: VectorRef,
+}
+
+impl ConstantVector {
+    /// Create a new [ConstantVector].
+    ///
+    /// # Panics
+    /// Panics if `vector.len() != 1`.
+    pub fn new(vector: VectorRef, length: usize) -> Self {
+        assert_eq!(1, vector.len());
+
+        // Avoid const recursion.
+        if vector.is_const() {
+            let vec: &ConstantVector = unsafe { Helper::static_cast(&vector) };
+            return Self::new(vec.inner().clone(), length);
+        }
+        Self { vector, length }
+    }
+
+    pub fn inner(&self) -> &VectorRef {
+        &self.vector
+    }
+
+    /// Returns the constant value.
+    pub fn get_constant_ref(&self) -> ValueRef {
+        self.vector.get_ref(0)
+    }
+
+    pub(crate) fn replicate_vector(&self, offsets: &[usize]) -> VectorRef {
+        assert_eq!(offsets.len(), self.len());
+
+        if offsets.is_empty() {
+            return self.slice(0, 0);
+        }
+
+        Arc::new(ConstantVector::new(
+            self.vector.clone(),
+            *offsets.last().unwrap(),
+        ))
+    }
+
+    pub(crate) fn filter_vector(&self, filter: &BooleanVector) -> Result<VectorRef> {
+        let length = self.len() - filter.false_count();
+        if length == self.len() {
+            return Ok(Arc::new(self.clone()));
+        }
+        Ok(Arc::new(ConstantVector::new(self.inner().clone(), length)))
+    }
+}
+
+impl Vector for ConstantVector {
+    fn data_type(&self) -> ConcreteDataType {
+        self.vector.data_type()
+    }
+
+    fn vector_type_name(&self) -> String {
+        "ConstantVector".to_string()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.length
+    }
+
+    fn to_arrow_array(&self) -> ArrayRef {
+        let v = self.vector.replicate(&[self.length]);
+        v.to_arrow_array()
+    }
+
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array> {
+        let v = self.vector.replicate(&[self.length]);
+        v.to_boxed_arrow_array()
+    }
+
+    fn is_const(&self) -> bool {
+        true
+    }
+
+    fn validity(&self) -> Validity {
+        if self.vector.is_null(0) {
+            Validity::all_null(self.length)
+        } else {
+            Validity::all_valid(self.length)
+        }
+    }
+
+    fn memory_size(&self) -> usize {
+        self.vector.memory_size()
+    }
+
+    fn is_null(&self, _row: usize) -> bool {
+        self.vector.is_null(0)
+    }
+
+    fn only_null(&self) -> bool {
+        self.vector.is_null(0)
+    }
+
+    fn slice(&self, _offset: usize, length: usize) -> VectorRef {
+        Arc::new(Self {
+            vector: self.vector.clone(),
+            length,
+        })
+    }
+
+    fn get(&self, _index: usize) -> Value {
+        self.vector.get(0)
+    }
+
+    fn get_ref(&self, _index: usize) -> ValueRef {
+        self.vector.get_ref(0)
+    }
+
+    fn null_count(&self) -> usize {
+        if self.only_null() {
+            self.len()
+        } else {
+            0
+        }
+    }
+}
+
+impl fmt::Debug for ConstantVector {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "ConstantVector([{:?}; {}])", self.get(0), self.len())
+    }
+}
+
+impl Serializable for ConstantVector {
+    fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>> {
+        std::iter::repeat(self.get(0))
+            .take(self.len())
+            .map(serde_json::Value::try_from)
+            .collect::<serde_json::Result<Vec<_>>>()
+            .context(SerializeSnafu)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::DataType as ArrowDataType;
+
+    use super::*;
+    use crate::vectors::Int32Vector;
+
+    #[test]
+    fn test_constant_vector_misc() {
+        let a = Int32Vector::from_slice(vec![1]);
+        let c = ConstantVector::new(Arc::new(a), 10);
+
+        assert_eq!("ConstantVector", c.vector_type_name());
+        assert!(c.is_const());
+        assert_eq!(10, c.len());
+        assert!(c.validity().is_all_valid());
+        assert!(!c.only_null());
+        assert_eq!(64, c.memory_size());
+
+        for i in 0..10 {
+            assert!(!c.is_null(i));
+            assert_eq!(Value::Int32(1), c.get(i));
+        }
+
+        let arrow_arr = c.to_arrow_array();
+        assert_eq!(10, arrow_arr.len());
+        assert_eq!(&ArrowDataType::Int32, arrow_arr.data_type());
+    }
+
+    #[test]
+    fn test_debug_null_array() {
+        let a = Int32Vector::from_slice(vec![1]);
+        let c = ConstantVector::new(Arc::new(a), 10);
+
+        let s = format!("{:?}", c);
+        assert_eq!(s, "ConstantVector([Int32(1); 10])");
+    }
+
+    #[test]
+    fn test_serialize_json() {
+        let a = Int32Vector::from_slice(vec![1]);
+        let c = ConstantVector::new(Arc::new(a), 10);
+
+        let s = serde_json::to_string(&c.serialize_to_json().unwrap()).unwrap();
+        assert_eq!(s, "[1,1,1,1,1,1,1,1,1,1]");
+    }
+}
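`ConstantVector` stores one element plus a logical length, so the filter and replicate paths above only adjust the length until an arrow array is actually requested. A minimal sketch (crate name `datatypes2` assumed):

    use std::sync::Arc;
    use datatypes2::vectors::{ConstantVector, Int32Vector, Vector};

    fn main() {
        // One stored element stands in for 1000 logical rows.
        let c = ConstantVector::new(Arc::new(Int32Vector::from_slice(vec![7])), 1000);
        assert_eq!(1000, c.len());
        // Memory cost stays that of the single inner element.
        assert_eq!(c.inner().memory_size(), c.memory_size());
        // Materialization happens only when arrow needs a real array.
        assert_eq!(1000, c.to_arrow_array().len());
    }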
diff --git a/src/datatypes2/src/vectors/date.rs b/src/datatypes2/src/vectors/date.rs
new file mode 100644
index 0000000000..d0a66b80fb
--- /dev/null
+++ b/src/datatypes2/src/vectors/date.rs
@@ -0,0 +1,103 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use crate::types::DateType;
+use crate::vectors::{PrimitiveVector, PrimitiveVectorBuilder};
+
+/// Vector for [`Date`](common_time::Date).
+pub type DateVector = PrimitiveVector<DateType>;
+/// Builder to build [DateVector].
+pub type DateVectorBuilder = PrimitiveVectorBuilder<DateType>;
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::array::Array;
+    use common_time::date::Date;
+
+    use super::*;
+    use crate::data_type::DataType;
+    use crate::scalars::{ScalarVector, ScalarVectorBuilder};
+    use crate::serialize::Serializable;
+    use crate::types::DateType;
+    use crate::value::{Value, ValueRef};
+    use crate::vectors::{Vector, VectorRef};
+
+    #[test]
+    fn test_build_date_vector() {
+        let mut builder = DateVectorBuilder::with_capacity(4);
+        builder.push(Some(Date::new(1)));
+        builder.push(None);
+        builder.push(Some(Date::new(-1)));
+        let vector = builder.finish();
+        assert_eq!(3, vector.len());
+        assert_eq!(Value::Date(Date::new(1)), vector.get(0));
+        assert_eq!(ValueRef::Date(Date::new(1)), vector.get_ref(0));
+        assert_eq!(Some(Date::new(1)), vector.get_data(0));
+        assert_eq!(None, vector.get_data(1));
+        assert_eq!(Value::Null, vector.get(1));
+        assert_eq!(ValueRef::Null, vector.get_ref(1));
+        assert_eq!(Some(Date::new(-1)), vector.get_data(2));
+        let mut iter = vector.iter_data();
+        assert_eq!(Some(Date::new(1)), iter.next().unwrap());
+        assert_eq!(None, iter.next().unwrap());
+        assert_eq!(Some(Date::new(-1)), iter.next().unwrap());
+    }
+
+    #[test]
+    fn test_date_scalar() {
+        let vector = DateVector::from_slice(&[1, 2]);
+        assert_eq!(2, vector.len());
+        assert_eq!(Some(Date::new(1)), vector.get_data(0));
+        assert_eq!(Some(Date::new(2)), vector.get_data(1));
+    }
+
+    #[test]
+    fn test_date_vector_builder() {
+        let input = DateVector::from_slice(&[1, 2, 3]);
+
+        let mut builder = DateType::default().create_mutable_vector(3);
+        builder
+            .push_value_ref(ValueRef::Date(Date::new(5)))
+            .unwrap();
+        assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err());
+        builder.extend_slice_of(&input, 1, 2).unwrap();
+        assert!(builder
+            .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1)
+            .is_err());
+        let vector = builder.to_vector();
+
+        let expect: VectorRef = Arc::new(DateVector::from_slice(&[5, 2, 3]));
+        assert_eq!(expect, vector);
+    }
+
+    #[test]
+    fn test_date_from_arrow() {
+        let vector = DateVector::from_slice(&[1, 2]);
+        let arrow = vector.as_arrow().slice(0, vector.len());
+        let vector2 = DateVector::try_from_arrow_array(&arrow).unwrap();
+        assert_eq!(vector, vector2);
+    }
+
+    #[test]
+    fn test_serialize_date_vector() {
+        let vector = DateVector::from_slice(&[-1, 0, 1]);
+        let serialized_json = serde_json::to_string(&vector.serialize_to_json().unwrap()).unwrap();
+        assert_eq!(
+            r#"["1969-12-31","1970-01-01","1970-01-02"]"#,
+            serialized_json
+        );
+    }
+}
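`DateVector` stores `i32` days since the Unix epoch, which is why `-1` serializes to `1969-12-31` in the test above. A tiny sketch (crate name assumed):

    use common_time::date::Date;
    use datatypes2::scalars::ScalarVector;
    use datatypes2::vectors::DateVector;

    fn main() {
        let v = DateVector::from_slice(&[-1, 0, 1]);
        // Day 0 is the epoch itself.
        assert_eq!(Some(Date::new(0)), v.get_data(1));
    }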
r#"["1969-12-31","1970-01-01","1970-01-02"]"#, + serialized_json + ); + } +} diff --git a/src/datatypes2/src/vectors/datetime.rs b/src/datatypes2/src/vectors/datetime.rs new file mode 100644 index 0000000000..a40a3e54d3 --- /dev/null +++ b/src/datatypes2/src/vectors/datetime.rs @@ -0,0 +1,116 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::types::DateTimeType; +use crate::vectors::{PrimitiveVector, PrimitiveVectorBuilder}; + +/// Vector of [`DateTime`](common_time::Date) +pub type DateTimeVector = PrimitiveVector; +/// Builder for [`DateTimeVector`]. +pub type DateTimeVectorBuilder = PrimitiveVectorBuilder; + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{Array, PrimitiveArray}; + use common_time::DateTime; + use datafusion_common::from_slice::FromSlice; + + use super::*; + use crate::data_type::DataType; + use crate::prelude::{ + ConcreteDataType, ScalarVector, ScalarVectorBuilder, Value, ValueRef, Vector, VectorRef, + }; + use crate::serialize::Serializable; + + #[test] + fn test_datetime_vector() { + let v = DateTimeVector::new(PrimitiveArray::from_slice(&[1, 2, 3])); + assert_eq!(ConcreteDataType::datetime_datatype(), v.data_type()); + assert_eq!(3, v.len()); + assert_eq!("DateTimeVector", v.vector_type_name()); + assert_eq!( + &arrow::datatypes::DataType::Date64, + v.to_arrow_array().data_type() + ); + + assert_eq!(Some(DateTime::new(1)), v.get_data(0)); + assert_eq!(Value::DateTime(DateTime::new(1)), v.get(0)); + assert_eq!(ValueRef::DateTime(DateTime::new(1)), v.get_ref(0)); + + let mut iter = v.iter_data(); + assert_eq!(Some(DateTime::new(1)), iter.next().unwrap()); + assert_eq!(Some(DateTime::new(2)), iter.next().unwrap()); + assert_eq!(Some(DateTime::new(3)), iter.next().unwrap()); + assert!(!v.is_null(0)); + assert_eq!(64, v.memory_size()); + + if let Value::DateTime(d) = v.get(0) { + assert_eq!(1, d.val()); + } else { + unreachable!() + } + assert_eq!( + "[\"1970-01-01 00:00:01\",\"1970-01-01 00:00:02\",\"1970-01-01 00:00:03\"]", + serde_json::to_string(&v.serialize_to_json().unwrap()).unwrap() + ); + } + + #[test] + fn test_datetime_vector_builder() { + let mut builder = DateTimeVectorBuilder::with_capacity(3); + builder.push(Some(DateTime::new(1))); + builder.push(None); + builder.push(Some(DateTime::new(-1))); + + let v = builder.finish(); + assert_eq!(ConcreteDataType::datetime_datatype(), v.data_type()); + assert_eq!(Value::DateTime(DateTime::new(1)), v.get(0)); + assert_eq!(Value::Null, v.get(1)); + assert_eq!(Value::DateTime(DateTime::new(-1)), v.get(2)); + + let input = DateTimeVector::from_wrapper_slice(&[ + DateTime::new(1), + DateTime::new(2), + DateTime::new(3), + ]); + + let mut builder = DateTimeType::default().create_mutable_vector(3); + builder + .push_value_ref(ValueRef::DateTime(DateTime::new(5))) + .unwrap(); + assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err()); + builder.extend_slice_of(&input, 1, 2).unwrap(); + assert!(builder + 
+    #[test]
+    fn test_datetime_from_arrow() {
+        let vector = DateTimeVector::from_wrapper_slice(&[DateTime::new(1), DateTime::new(2)]);
+        let arrow = vector.as_arrow().slice(0, vector.len());
+        let vector2 = DateTimeVector::try_from_arrow_array(&arrow).unwrap();
+        assert_eq!(vector, vector2);
+    }
+}
diff --git a/src/datatypes2/src/vectors/eq.rs b/src/datatypes2/src/vectors/eq.rs
new file mode 100644
index 0000000000..55359026d4
--- /dev/null
+++ b/src/datatypes2/src/vectors/eq.rs
@@ -0,0 +1,228 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use crate::data_type::DataType;
+use crate::types::TimestampType;
+use crate::vectors::constant::ConstantVector;
+use crate::vectors::{
+    BinaryVector, BooleanVector, DateTimeVector, DateVector, ListVector, PrimitiveVector,
+    StringVector, TimestampMicrosecondVector, TimestampMillisecondVector,
+    TimestampNanosecondVector, TimestampSecondVector, Vector,
+};
+use crate::with_match_primitive_type_id;
+
+impl Eq for dyn Vector + '_ {}
+
+impl PartialEq for dyn Vector + '_ {
+    fn eq(&self, other: &dyn Vector) -> bool {
+        equal(self, other)
+    }
+}
+
+impl PartialEq<dyn Vector> for Arc<dyn Vector> {
+    fn eq(&self, other: &dyn Vector) -> bool {
+        equal(&**self, other)
+    }
+}
+
+macro_rules! is_vector_eq {
+    ($VectorType: ident, $lhs: ident, $rhs: ident) => {{
+        let lhs = $lhs.as_any().downcast_ref::<$VectorType>().unwrap();
+        let rhs = $rhs.as_any().downcast_ref::<$VectorType>().unwrap();
+
+        lhs == rhs
+    }};
+}
+
+fn equal(lhs: &dyn Vector, rhs: &dyn Vector) -> bool {
+    if lhs.data_type() != rhs.data_type() || lhs.len() != rhs.len() {
+        return false;
+    }
+
+    if lhs.is_const() || rhs.is_const() {
+        // Length has been checked before, so we only need to compare the inner
+        // vectors here.
+        return equal(
+            &**lhs
+                .as_any()
+                .downcast_ref::<ConstantVector>()
+                .unwrap()
+                .inner(),
+            &**rhs
+                .as_any()
+                .downcast_ref::<ConstantVector>()
+                .unwrap()
+                .inner(),
+        );
+    }
+
+    use crate::data_type::ConcreteDataType::*;
+
+    let lhs_type = lhs.data_type();
+    match lhs.data_type() {
+        Null(_) => true,
+        Boolean(_) => is_vector_eq!(BooleanVector, lhs, rhs),
+        Binary(_) => is_vector_eq!(BinaryVector, lhs, rhs),
+        String(_) => is_vector_eq!(StringVector, lhs, rhs),
+        Date(_) => is_vector_eq!(DateVector, lhs, rhs),
+        DateTime(_) => is_vector_eq!(DateTimeVector, lhs, rhs),
+        Timestamp(t) => match t {
+            TimestampType::Second(_) => {
+                is_vector_eq!(TimestampSecondVector, lhs, rhs)
+            }
+            TimestampType::Millisecond(_) => {
+                is_vector_eq!(TimestampMillisecondVector, lhs, rhs)
+            }
+            TimestampType::Microsecond(_) => {
+                is_vector_eq!(TimestampMicrosecondVector, lhs, rhs)
+            }
+            TimestampType::Nanosecond(_) => {
+                is_vector_eq!(TimestampNanosecondVector, lhs, rhs)
+            }
+        },
+        List(_) => is_vector_eq!(ListVector, lhs, rhs),
+        UInt8(_) | UInt16(_) | UInt32(_) | UInt64(_) | Int8(_) | Int16(_) | Int32(_) | Int64(_)
+        | Float32(_) | Float64(_) => {
+            with_match_primitive_type_id!(lhs_type.logical_type_id(), |$T| {
+                let lhs = lhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
+                let rhs = rhs.as_any().downcast_ref::<PrimitiveVector<$T>>().unwrap();
+
+                lhs == rhs
+            },
+            {
+                unreachable!("should not compare {} with {}", lhs.vector_type_name(), rhs.vector_type_name())
+            })
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::vectors::{
+        list, Float32Vector, Float64Vector, Int16Vector, Int32Vector, Int64Vector, Int8Vector,
+        NullVector, UInt16Vector, UInt32Vector, UInt64Vector, UInt8Vector, VectorRef,
+    };
+
+    fn assert_vector_ref_eq(vector: VectorRef) {
+        let rhs = vector.clone();
+        assert_eq!(vector, rhs);
+        assert_dyn_vector_eq(&*vector, &*rhs);
+    }
+
+    fn assert_dyn_vector_eq(lhs: &dyn Vector, rhs: &dyn Vector) {
+        assert_eq!(lhs, rhs);
+    }
+
+    fn assert_vector_ref_ne(lhs: VectorRef, rhs: VectorRef) {
+        assert_ne!(lhs, rhs);
+    }
+
+    #[test]
+    fn test_vector_eq() {
+        assert_vector_ref_eq(Arc::new(BinaryVector::from(vec![
+            Some(b"hello".to_vec()),
+            Some(b"world".to_vec()),
+        ])));
+        assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
+        assert_vector_ref_eq(Arc::new(ConstantVector::new(
+            Arc::new(BooleanVector::from(vec![true])),
+            5,
+        )));
+        assert_vector_ref_eq(Arc::new(BooleanVector::from(vec![true, false])));
+        assert_vector_ref_eq(Arc::new(DateVector::from(vec![Some(100), Some(120)])));
+        assert_vector_ref_eq(Arc::new(DateTimeVector::from(vec![Some(100), Some(120)])));
+        assert_vector_ref_eq(Arc::new(TimestampSecondVector::from_values([100, 120])));
+        assert_vector_ref_eq(Arc::new(TimestampMillisecondVector::from_values([
+            100, 120,
+        ])));
+        assert_vector_ref_eq(Arc::new(TimestampMicrosecondVector::from_values([
+            100, 120,
+        ])));
+        assert_vector_ref_eq(Arc::new(TimestampNanosecondVector::from_values([100, 120])));
+
+        let list_vector = list::tests::new_list_vector(&[
+            Some(vec![Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), Some(4)]),
+        ]);
+        assert_vector_ref_eq(Arc::new(list_vector));
+
+        assert_vector_ref_eq(Arc::new(NullVector::new(4)));
+        assert_vector_ref_eq(Arc::new(StringVector::from(vec![
+            Some("hello"),
+            Some("world"),
+        ])));
+
+        assert_vector_ref_eq(Arc::new(Int8Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(UInt8Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(Int16Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(UInt16Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(Int32Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(UInt32Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(Int64Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(UInt64Vector::from_slice(&[1, 2, 3, 4])));
+        assert_vector_ref_eq(Arc::new(Float32Vector::from_slice(&[1.0, 2.0, 3.0, 4.0])));
+        assert_vector_ref_eq(Arc::new(Float64Vector::from_slice(&[1.0, 2.0, 3.0, 4.0])));
+    }
+
+    #[test]
+    fn test_vector_ne() {
+        assert_vector_ref_ne(
+            Arc::new(Int32Vector::from_slice(&[1, 2, 3, 4])),
+            Arc::new(Int32Vector::from_slice(&[1, 2])),
+        );
+        assert_vector_ref_ne(
+            Arc::new(Int32Vector::from_slice(&[1, 2, 3, 4])),
+            Arc::new(Int8Vector::from_slice(&[1, 2, 3, 4])),
+        );
+        assert_vector_ref_ne(
+            Arc::new(Int32Vector::from_slice(&[1, 2, 3, 4])),
+            Arc::new(BooleanVector::from(vec![true, true])),
+        );
+        assert_vector_ref_ne(
+            Arc::new(ConstantVector::new(
+                Arc::new(BooleanVector::from(vec![true])),
+                5,
+            )),
+            Arc::new(ConstantVector::new(
+                Arc::new(BooleanVector::from(vec![true])),
+                4,
+            )),
+        );
+        assert_vector_ref_ne(
+            Arc::new(ConstantVector::new(
+                Arc::new(BooleanVector::from(vec![true])),
+                5,
+            )),
+            Arc::new(ConstantVector::new(
+                Arc::new(BooleanVector::from(vec![false])),
+                4,
+            )),
+        );
+        assert_vector_ref_ne(
+            Arc::new(ConstantVector::new(
+                Arc::new(BooleanVector::from(vec![true])),
+                5,
+            )),
+            Arc::new(ConstantVector::new(
+                Arc::new(Int32Vector::from_slice(vec![1])),
+                4,
+            )),
+        );
+        assert_vector_ref_ne(Arc::new(NullVector::new(5)), Arc::new(NullVector::new(8)));
+    }
+}
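With `PartialEq` implemented for `dyn Vector`, whole columns compare by data type, length, and contents, and std's blanket `PartialEq` for `Arc<T: PartialEq>` makes `VectorRef == VectorRef` work directly. A minimal sketch (crate name `datatypes2` assumed):

    use std::sync::Arc;
    use datatypes2::vectors::{ConstantVector, Int32Vector, VectorRef};

    fn main() {
        let a: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(Int32Vector::from_slice(vec![1])),
            3,
        ));
        let b: VectorRef = Arc::new(ConstantVector::new(
            Arc::new(Int32Vector::from_slice(vec![1])),
            3,
        ));
        // Compares via `equal`: same type, same length, same inner value.
        assert_eq!(a, b);
    }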
+pub struct Helper;
+
+impl Helper {
+    /// Get a pointer to the underlying data of this vector.
+    /// Can be useful for fast comparisons.
+    /// # Safety
+    /// Assumes that the `vector` is T.
+    pub unsafe fn static_cast<T: Any>(vector: &VectorRef) -> &T {
+        let object = vector.as_ref();
+        debug_assert!(object.as_any().is::<T>());
+        &*(object as *const dyn Vector as *const T)
+    }
+
+    pub fn check_get_scalar<T: Scalar>(vector: &VectorRef) -> Result<&<T as Scalar>::VectorType> {
+        let arr = vector
+            .as_any()
+            .downcast_ref::<<T as Scalar>::VectorType>()
+            .with_context(|| error::UnknownVectorSnafu {
+                msg: format!(
+                    "downcast vector error, vector type: {:?}, expected vector: {:?}",
+                    vector.vector_type_name(),
+                    std::any::type_name::<T>(),
+                ),
+            });
+        arr
+    }
+
+    pub fn check_get<T: 'static + Vector>(vector: &VectorRef) -> Result<&T> {
+        let arr = vector
+            .as_any()
+            .downcast_ref::<T>()
+            .with_context(|| error::UnknownVectorSnafu {
+                msg: format!(
+                    "downcast vector error, vector type: {:?}, expected vector: {:?}",
+                    vector.vector_type_name(),
+                    std::any::type_name::<T>(),
+                ),
+            });
+        arr
+    }
+
+    pub fn check_get_mutable_vector<T: 'static + MutableVector>(
+        vector: &mut dyn MutableVector,
+    ) -> Result<&mut T> {
+        let ty = vector.data_type();
+        let arr = vector
+            .as_mut_any()
+            .downcast_mut()
+            .with_context(|| error::UnknownVectorSnafu {
+                msg: format!(
+                    "downcast vector error, vector type: {:?}, expected vector: {:?}",
+                    ty,
+                    std::any::type_name::<T>(),
+                ),
+            });
+        arr
+    }
+
+    pub fn check_get_scalar_vector<T: Scalar>(
+        vector: &VectorRef,
+    ) -> Result<&<T as Scalar>::VectorType> {
+        let arr = vector
+            .as_any()
+            .downcast_ref::<<T as Scalar>::VectorType>()
+            .with_context(|| error::UnknownVectorSnafu {
+                msg: format!(
+                    "downcast vector error, vector type: {:?}, expected vector: {:?}",
+                    vector.vector_type_name(),
+                    std::any::type_name::<T>(),
+                ),
+            });
+        arr
+    }
+
+    /// Try to cast an arrow scalar value into a vector.
+    pub fn try_from_scalar_value(value: ScalarValue, length: usize) -> Result<VectorRef> {
+        let vector = match value {
+            ScalarValue::Null => ConstantVector::new(Arc::new(NullVector::new(1)), length),
+            ScalarValue::Boolean(v) => {
+                ConstantVector::new(Arc::new(BooleanVector::from(vec![v])), length)
+            }
+            ScalarValue::Float32(v) => {
+                ConstantVector::new(Arc::new(Float32Vector::from(vec![v])), length)
+            }
+            ScalarValue::Float64(v) => {
+                ConstantVector::new(Arc::new(Float64Vector::from(vec![v])), length)
+            }
+            ScalarValue::Int8(v) => {
+                ConstantVector::new(Arc::new(Int8Vector::from(vec![v])), length)
+            }
+            ScalarValue::Int16(v) => {
+                ConstantVector::new(Arc::new(Int16Vector::from(vec![v])), length)
+            }
+            ScalarValue::Int32(v) => {
+                ConstantVector::new(Arc::new(Int32Vector::from(vec![v])), length)
+            }
+            ScalarValue::Int64(v) => {
+                ConstantVector::new(Arc::new(Int64Vector::from(vec![v])), length)
+            }
+            ScalarValue::UInt8(v) => {
+                ConstantVector::new(Arc::new(UInt8Vector::from(vec![v])), length)
+            }
+            ScalarValue::UInt16(v) => {
+                ConstantVector::new(Arc::new(UInt16Vector::from(vec![v])), length)
+            }
+            ScalarValue::UInt32(v) => {
+                ConstantVector::new(Arc::new(UInt32Vector::from(vec![v])), length)
+            }
+            ScalarValue::UInt64(v) => {
+                ConstantVector::new(Arc::new(UInt64Vector::from(vec![v])), length)
+            }
+            ScalarValue::Utf8(v) | ScalarValue::LargeUtf8(v) => {
+                ConstantVector::new(Arc::new(StringVector::from(vec![v])), length)
+            }
+            ScalarValue::Binary(v)
+            | ScalarValue::LargeBinary(v)
+            | ScalarValue::FixedSizeBinary(_, v) => {
+                ConstantVector::new(Arc::new(BinaryVector::from(vec![v])), length)
+            }
+            ScalarValue::List(v, field) => {
+                let item_type = ConcreteDataType::try_from(field.data_type())?;
+                let mut builder = ListVectorBuilder::with_type_capacity(item_type.clone(), 1);
+                if let Some(values) = v {
+                    let values = values
+                        .into_iter()
+                        .map(ScalarValue::try_into)
+                        .collect::<Result<Vec<_>>>()?;
+                    let list_value = ListValue::new(Some(Box::new(values)), item_type);
+                    builder.push(Some(ListValueRef::Ref { val: &list_value }));
+                } else {
+                    builder.push(None);
+                }
+                let list_vector = builder.to_vector();
+                ConstantVector::new(list_vector, length)
+            }
+            ScalarValue::Date32(v) => {
+                ConstantVector::new(Arc::new(DateVector::from(vec![v])), length)
+            }
+            ScalarValue::Date64(v) => {
+                ConstantVector::new(Arc::new(DateTimeVector::from(vec![v])), length)
+            }
+            ScalarValue::TimestampSecond(v, _) => {
+                // Timezone is unimplemented now.
+                ConstantVector::new(Arc::new(TimestampSecondVector::from(vec![v])), length)
+            }
+            ScalarValue::TimestampMillisecond(v, _) => {
+                // Timezone is unimplemented now.
+                ConstantVector::new(Arc::new(TimestampMillisecondVector::from(vec![v])), length)
+            }
+            ScalarValue::TimestampMicrosecond(v, _) => {
+                // Timezone is unimplemented now.
+                ConstantVector::new(Arc::new(TimestampMicrosecondVector::from(vec![v])), length)
+            }
+            ScalarValue::TimestampNanosecond(v, _) => {
+                // Timezone is unimplemented now.
+                ConstantVector::new(Arc::new(TimestampNanosecondVector::from(vec![v])), length)
+            }
+            ScalarValue::Decimal128(_, _, _)
+            | ScalarValue::Time64(_)
+            | ScalarValue::IntervalYearMonth(_)
+            | ScalarValue::IntervalDayTime(_)
+            | ScalarValue::IntervalMonthDayNano(_)
+            | ScalarValue::Struct(_, _)
+            | ScalarValue::Dictionary(_, _) => {
+                return error::ConversionSnafu {
+                    from: format!("Unsupported scalar value: {}", value),
+                }
+                .fail()
+            }
+        };
+
+        Ok(Arc::new(vector))
+    }
+
+    /// Try to cast an arrow array into a vector.
+    ///
+    /// # Panics
+    /// Panics if the given arrow data type is not supported.
+    pub fn try_into_vector(array: impl AsRef<dyn Array>) -> Result<VectorRef> {
+        Ok(match array.as_ref().data_type() {
+            ArrowDataType::Null => Arc::new(NullVector::try_from_arrow_array(array)?),
+            ArrowDataType::Boolean => Arc::new(BooleanVector::try_from_arrow_array(array)?),
+            ArrowDataType::LargeBinary => Arc::new(BinaryVector::try_from_arrow_array(array)?),
+            ArrowDataType::Int8 => Arc::new(Int8Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Int16 => Arc::new(Int16Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Int32 => Arc::new(Int32Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Int64 => Arc::new(Int64Vector::try_from_arrow_array(array)?),
+            ArrowDataType::UInt8 => Arc::new(UInt8Vector::try_from_arrow_array(array)?),
+            ArrowDataType::UInt16 => Arc::new(UInt16Vector::try_from_arrow_array(array)?),
+            ArrowDataType::UInt32 => Arc::new(UInt32Vector::try_from_arrow_array(array)?),
+            ArrowDataType::UInt64 => Arc::new(UInt64Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Float32 => Arc::new(Float32Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Float64 => Arc::new(Float64Vector::try_from_arrow_array(array)?),
+            ArrowDataType::Utf8 => Arc::new(StringVector::try_from_arrow_array(array)?),
+            ArrowDataType::Date32 => Arc::new(DateVector::try_from_arrow_array(array)?),
+            ArrowDataType::Date64 => Arc::new(DateTimeVector::try_from_arrow_array(array)?),
+            ArrowDataType::List(_) => Arc::new(ListVector::try_from_arrow_array(array)?),
+            ArrowDataType::Timestamp(unit, _) => match unit {
+                TimeUnit::Second => Arc::new(TimestampSecondVector::try_from_arrow_array(array)?),
+                TimeUnit::Millisecond => {
+                    Arc::new(TimestampMillisecondVector::try_from_arrow_array(array)?)
+                }
+                TimeUnit::Microsecond => {
+                    Arc::new(TimestampMicrosecondVector::try_from_arrow_array(array)?)
+                }
+                TimeUnit::Nanosecond => {
+                    Arc::new(TimestampNanosecondVector::try_from_arrow_array(array)?)
+                }
+            },
+            ArrowDataType::Float16
+            | ArrowDataType::Time32(_)
+            | ArrowDataType::Time64(_)
+            | ArrowDataType::Duration(_)
+            | ArrowDataType::Interval(_)
+            | ArrowDataType::Binary
+            | ArrowDataType::FixedSizeBinary(_)
+            | ArrowDataType::LargeUtf8
+            | ArrowDataType::LargeList(_)
+            | ArrowDataType::FixedSizeList(_, _)
+            | ArrowDataType::Struct(_)
+            | ArrowDataType::Union(_, _, _)
+            | ArrowDataType::Dictionary(_, _)
+            | ArrowDataType::Decimal128(_, _)
+            | ArrowDataType::Decimal256(_, _)
+            | ArrowDataType::Map(_, _) => {
+                unimplemented!("Arrow array datatype: {:?}", array.as_ref().data_type())
+            }
+        })
+    }
+
+    /// Try to cast a slice of `arrays` into vectors.
+    pub fn try_into_vectors(arrays: &[ArrayRef]) -> Result<Vec<VectorRef>> {
+        arrays.iter().map(Self::try_into_vector).collect()
+    }
+
+    /// Perform a SQL `LIKE` operation on `names` with the scalar pattern `s`.
+    pub fn like_utf8(names: Vec<String>, s: &str) -> Result<VectorRef> {
+        let array = StringArray::from(names);
+
+        let filter = comparison::like_utf8_scalar(&array, s).context(error::ArrowComputeSnafu)?;
+
+        let result = compute::filter(&array, &filter).context(error::ArrowComputeSnafu)?;
+        Helper::try_into_vector(result)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::array::{
+        ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array, Int16Array,
+        Int32Array, Int64Array, Int8Array, LargeBinaryArray, ListArray, NullArray,
+        TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+        TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+    };
+    use arrow::datatypes::{Field, Int32Type};
+    use common_time::{Date, DateTime};
+
+    use super::*;
+    use crate::value::Value;
+    use crate::vectors::ConcreteDataType;
+
+    #[test]
+    fn test_try_into_vectors() {
+        let arrays: Vec<ArrayRef> = vec![
+            Arc::new(Int32Array::from(vec![1])),
+            Arc::new(Int32Array::from(vec![2])),
+            Arc::new(Int32Array::from(vec![3])),
+        ];
+        let vectors = Helper::try_into_vectors(&arrays);
+        assert!(vectors.is_ok());
+        let vectors = vectors.unwrap();
+        vectors.iter().for_each(|v| assert_eq!(1, v.len()));
+        assert_eq!(Value::Int32(1), vectors[0].get(0));
+        assert_eq!(Value::Int32(2), vectors[1].get(0));
+        assert_eq!(Value::Int32(3), vectors[2].get(0));
+    }
+
+    #[test]
+    fn test_try_into_date_vector() {
+        let vector = DateVector::from(vec![Some(1), Some(2), None]);
+        let arrow_array = vector.to_arrow_array();
+        assert_eq!(&ArrowDataType::Date32, arrow_array.data_type());
+        let vector_converted = Helper::try_into_vector(arrow_array).unwrap();
+        assert_eq!(vector.len(), vector_converted.len());
+        for i in 0..vector_converted.len() {
+            assert_eq!(vector.get(i), vector_converted.get(i));
+        }
+    }
+
+    #[test]
+    fn test_try_from_scalar_date_value() {
+        let vector = Helper::try_from_scalar_value(ScalarValue::Date32(Some(42)), 3).unwrap();
+        assert_eq!(ConcreteDataType::date_datatype(), vector.data_type());
+        assert_eq!(3, vector.len());
+        for i in 0..vector.len() {
+            assert_eq!(Value::Date(Date::new(42)), vector.get(i));
+        }
+    }
+
+    #[test]
+    fn test_try_from_scalar_datetime_value() {
+        let vector = Helper::try_from_scalar_value(ScalarValue::Date64(Some(42)), 3).unwrap();
+        assert_eq!(ConcreteDataType::datetime_datatype(), vector.data_type());
+        assert_eq!(3, vector.len());
+        for i in 0..vector.len() {
+            assert_eq!(Value::DateTime(DateTime::new(42)), vector.get(i));
+        }
+    }
+
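+    // Illustrative sketch (same pattern as the date tests above): integer scalars also
+    // expand to constant vectors. Only APIs already exercised in this file are used.
+    #[test]
+    fn test_try_from_scalar_int32_value() {
+        let vector = Helper::try_from_scalar_value(ScalarValue::Int32(Some(42)), 3).unwrap();
+        assert_eq!(ConcreteDataType::int32_datatype(), vector.data_type());
+        assert_eq!(3, vector.len());
+        for i in 0..vector.len() {
+            assert_eq!(Value::Int32(42), vector.get(i));
+        }
+    }
+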
+    #[test]
+    fn test_try_from_list_value() {
+        let value = ScalarValue::List(
+            Some(vec![
+                ScalarValue::Int32(Some(1)),
+                ScalarValue::Int32(Some(2)),
+            ]),
+            Box::new(Field::new("item", ArrowDataType::Int32, true)),
+        );
+        let vector = Helper::try_from_scalar_value(value, 3).unwrap();
+        assert_eq!(
+            ConcreteDataType::list_datatype(ConcreteDataType::int32_datatype()),
+            vector.data_type()
+        );
+        assert_eq!(3, vector.len());
+        for i in 0..vector.len() {
+            let v = vector.get(i);
+            let items = v.as_list().unwrap().unwrap().items().as_ref().unwrap();
+            assert_eq!(vec![Value::Int32(1), Value::Int32(2)], **items);
+        }
+    }
+
+    #[test]
+    fn test_like_utf8() {
+        fn assert_vector(expected: Vec<&str>, actual: &VectorRef) {
+            let actual = actual.as_any().downcast_ref::<StringVector>().unwrap();
+            assert_eq!(*actual, StringVector::from(expected));
+        }
+
+        let names: Vec<String> = vec!["greptime", "hello", "public", "world"]
+            .into_iter()
+            .map(|x| x.to_string())
+            .collect();
+
+        let ret = Helper::like_utf8(names.clone(), "%ll%").unwrap();
+        assert_vector(vec!["hello"], &ret);
+
+        let ret = Helper::like_utf8(names.clone(), "%time").unwrap();
+        assert_vector(vec!["greptime"], &ret);
+
+        let ret = Helper::like_utf8(names.clone(), "%ld").unwrap();
+        assert_vector(vec!["world"], &ret);
+
+        let ret = Helper::like_utf8(names, "%").unwrap();
+        assert_vector(vec!["greptime", "hello", "public", "world"], &ret);
+    }
+
+    fn check_try_into_vector(array: impl Array + 'static) {
+        let array: ArrayRef = Arc::new(array);
+        let vector = Helper::try_into_vector(array.clone()).unwrap();
+        assert_eq!(&array, &vector.to_arrow_array());
+    }
+
+    #[test]
+    fn test_try_into_vector() {
+        check_try_into_vector(NullArray::new(2));
+        check_try_into_vector(BooleanArray::from(vec![true, false]));
+        check_try_into_vector(LargeBinaryArray::from(vec![
+            "hello".as_bytes(),
+            "world".as_bytes(),
+        ]));
+        check_try_into_vector(Int8Array::from(vec![1, 2, 3]));
+        check_try_into_vector(Int16Array::from(vec![1, 2, 3]));
+        check_try_into_vector(Int32Array::from(vec![1, 2, 3]));
+        check_try_into_vector(Int64Array::from(vec![1, 2, 3]));
+        check_try_into_vector(UInt8Array::from(vec![1, 2, 3]));
+        check_try_into_vector(UInt16Array::from(vec![1, 2, 3]));
+        check_try_into_vector(UInt32Array::from(vec![1, 2, 3]));
+        check_try_into_vector(UInt64Array::from(vec![1, 2, 3]));
+        check_try_into_vector(Float32Array::from(vec![1.0, 2.0, 3.0]));
+        check_try_into_vector(Float64Array::from(vec![1.0, 2.0, 3.0]));
+        check_try_into_vector(StringArray::from(vec!["hello", "world"]));
+        check_try_into_vector(Date32Array::from(vec![1, 2, 3]));
+        check_try_into_vector(Date64Array::from(vec![1, 2, 3]));
+        let data = vec![None, Some(vec![Some(6), Some(7)])];
+        let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(data);
+        check_try_into_vector(list_array);
+        check_try_into_vector(TimestampSecondArray::from(vec![1, 2, 3]));
+        check_try_into_vector(TimestampMillisecondArray::from(vec![1, 2, 3]));
+        check_try_into_vector(TimestampMicrosecondArray::from(vec![1, 2, 3]));
+        check_try_into_vector(TimestampNanosecondArray::from(vec![1, 2, 3]));
+    }
+}
diff --git a/src/datatypes2/src/vectors/list.rs b/src/datatypes2/src/vectors/list.rs
new file mode 100644
index 0000000000..747e03557b
--- /dev/null
+++ b/src/datatypes2/src/vectors/list.rs
@@ -0,0 +1,747 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayData, ArrayRef, BooleanBufferBuilder, Int32BufferBuilder, ListArray,
+};
+use arrow::buffer::Buffer;
+use arrow::datatypes::DataType as ArrowDataType;
+use serde_json::Value as JsonValue;
+
+use crate::data_type::{ConcreteDataType, DataType};
+use crate::error::Result;
+use crate::scalars::{ScalarVector, ScalarVectorBuilder};
+use crate::serialize::Serializable;
+use crate::types::ListType;
+use crate::value::{ListValue, ListValueRef, Value, ValueRef};
+use crate::vectors::{self, Helper, MutableVector, Validity, Vector, VectorRef};
+
+/// Vector of Lists, basically backed by Arrow's `ListArray`.
+#[derive(Debug, PartialEq)]
+pub struct ListVector {
+    array: ListArray,
+    /// The datatype of the items in the list.
+    item_type: ConcreteDataType,
+}
+
+impl ListVector {
+    /// Iterate elements as [VectorRef].
+    pub fn values_iter(&self) -> impl Iterator<Item = Result<Option<VectorRef>>> + '_ {
+        self.array
+            .iter()
+            .map(|value_opt| value_opt.map(Helper::try_into_vector).transpose())
+    }
+
+    fn to_array_data(&self) -> ArrayData {
+        self.array.data().clone()
+    }
+
+    fn from_array_data_and_type(data: ArrayData, item_type: ConcreteDataType) -> Self {
+        Self {
+            array: ListArray::from(data),
+            item_type,
+        }
+    }
+
+    pub(crate) fn as_arrow(&self) -> &dyn Array {
+        &self.array
+    }
+}
+
+impl Vector for ListVector {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::List(ListType::new(self.item_type.clone()))
+    }
+
+    fn vector_type_name(&self) -> String {
+        "ListVector".to_string()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.array.len()
+    }
+
+    fn to_arrow_array(&self) -> ArrayRef {
+        let data = self.to_array_data();
+        Arc::new(ListArray::from(data))
+    }
+
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array> {
+        let data = self.to_array_data();
+        Box::new(ListArray::from(data))
+    }
+
+    fn validity(&self) -> Validity {
+        vectors::impl_validity_for_vector!(self.array)
+    }
+
+    fn memory_size(&self) -> usize {
+        self.array.get_buffer_memory_size()
+    }
+
+    fn null_count(&self) -> usize {
+        self.array.null_count()
+    }
+
+    fn is_null(&self, row: usize) -> bool {
+        self.array.is_null(row)
+    }
+
+    fn slice(&self, offset: usize, length: usize) -> VectorRef {
+        let data = self.array.data().slice(offset, length);
+        Arc::new(Self::from_array_data_and_type(data, self.item_type.clone()))
+    }
+
+    fn get(&self, index: usize) -> Value {
+        if !self.array.is_valid(index) {
+            return Value::Null;
+        }
+
+        let array = &self.array.value(index);
+        let vector = Helper::try_into_vector(array).unwrap_or_else(|_| {
+            panic!(
+                "arrow array with datatype {:?} cannot be converted to our vector",
+                array.data_type()
+            )
+        });
+        let values = (0..vector.len())
+            .map(|i| vector.get(i))
+            .collect::<Vec<_>>();
+        Value::List(ListValue::new(
+            Some(Box::new(values)),
+            self.item_type.clone(),
+        ))
+    }
+
+    fn get_ref(&self, index: usize) -> ValueRef {
+        ValueRef::List(ListValueRef::Indexed {
+            vector: self,
+            idx: index,
+        })
+    }
+}
+
+impl Serializable for ListVector {
+    fn serialize_to_json(&self) -> Result<Vec<JsonValue>> {
+        self.array
+            .iter()
+            .map(|v| match v {
+                None => Ok(JsonValue::Null),
+                Some(v) => Helper::try_into_vector(v)
+                    .and_then(|v| v.serialize_to_json())
+                    .map(JsonValue::Array),
+            })
+            .collect()
+    }
+}
+
+impl From<ListArray> for ListVector {
+    fn from(array: ListArray) -> Self {
+        let item_type = ConcreteDataType::from_arrow_type(match array.data_type() {
+            ArrowDataType::List(field) => field.data_type(),
+            other => panic!(
+                "Try to create ListVector from an arrow array with type {:?}",
+                other
+            ),
+        });
+        Self { array, item_type }
+    }
+}
+
+vectors::impl_try_from_arrow_array_for_vector!(ListArray, ListVector);
+
+pub struct ListIter<'a> {
+    vector: &'a ListVector,
+    idx: usize,
+}
+
+impl<'a> ListIter<'a> {
+    fn new(vector: &'a ListVector) -> ListIter {
+        ListIter { vector, idx: 0 }
+    }
+}
+
+impl<'a> Iterator for ListIter<'a> {
+    type Item = Option<ListValueRef<'a>>;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.idx >= self.vector.len() {
+            return None;
+        }
+
+        let idx = self.idx;
+        self.idx += 1;
+
+        if self.vector.is_null(idx) {
+            return Some(None);
+        }
+
+        Some(Some(ListValueRef::Indexed {
+            vector: self.vector,
+            idx,
+        }))
+    }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.vector.len(), Some(self.vector.len()))
+    }
+}
+
+impl ScalarVector for ListVector {
+    type OwnedItem = ListValue;
+    type RefItem<'a> = ListValueRef<'a>;
+    type Iter<'a> = ListIter<'a>;
+    type Builder = ListVectorBuilder;
+
+    fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>> {
+        if self.array.is_valid(idx) {
+            Some(ListValueRef::Indexed { vector: self, idx })
+        } else {
+            None
+        }
+    }
+
+    fn iter_data(&self) -> Self::Iter<'_> {
+        ListIter::new(self)
+    }
+}
+
+// Ports from arrow's GenericListBuilder.
+// See https://github.com/apache/arrow-rs/blob/94565bca99b5d9932a3e9a8e094aaf4e4384b1e5/arrow-array/src/builder/generic_list_builder.rs
+/// [ListVector] builder.
+pub struct ListVectorBuilder {
+    item_type: ConcreteDataType,
+    offsets_builder: Int32BufferBuilder,
+    null_buffer_builder: NullBufferBuilder,
+    values_builder: Box<dyn MutableVector>,
+}
+
+impl ListVectorBuilder {
+    /// Creates a new [`ListVectorBuilder`]. `item_type` is the data type of the list item, `capacity`
+    /// is the number of items to pre-allocate space for in this builder.
+    pub fn with_type_capacity(item_type: ConcreteDataType, capacity: usize) -> ListVectorBuilder {
+        let mut offsets_builder = Int32BufferBuilder::new(capacity + 1);
+        offsets_builder.append(0);
+        // The actual required capacity might be greater than the capacity of the `ListVector`
+        // if the child vector has more than one element.
+        let values_builder = item_type.create_mutable_vector(capacity);
+
+        ListVectorBuilder {
+            item_type,
+            offsets_builder,
+            null_buffer_builder: NullBufferBuilder::new(capacity),
+            values_builder,
+        }
+    }
+
+    /// Finish the current variable-length list vector slot.
+    fn finish_list(&mut self, is_valid: bool) {
+        self.offsets_builder
+            .append(i32::try_from(self.values_builder.len()).unwrap());
+        self.null_buffer_builder.append(is_valid);
+    }
+
+    fn push_null(&mut self) {
+        self.finish_list(false);
+    }
+
+    fn push_list_value(&mut self, list_value: &ListValue) -> Result<()> {
+        if let Some(items) = list_value.items() {
+            for item in &**items {
+                self.values_builder.push_value_ref(item.as_value_ref())?;
+            }
+        }
+
+        self.finish_list(true);
+        Ok(())
+    }
+}
+
+impl MutableVector for ListVectorBuilder {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::list_datatype(self.item_type.clone())
+    }
+
+    fn len(&self) -> usize {
+        self.null_buffer_builder.len()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn to_vector(&mut self) -> VectorRef {
+        Arc::new(self.finish())
+    }
+
+    fn push_value_ref(&mut self, value: ValueRef) -> Result<()> {
+        if let Some(list_ref) = value.as_list()? {
+            match list_ref {
+                ListValueRef::Indexed { vector, idx } => match vector.get(idx).as_list()? {
+                    Some(list_value) => self.push_list_value(list_value)?,
+                    None => self.push_null(),
+                },
+                ListValueRef::Ref { val } => self.push_list_value(val)?,
+            }
+        } else {
+            self.push_null();
+        }
+
+        Ok(())
+    }
+
+    fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> {
+        for idx in offset..offset + length {
+            let value = vector.get_ref(idx);
+            self.push_value_ref(value)?;
+        }
+
+        Ok(())
+    }
+}
+
+impl ScalarVectorBuilder for ListVectorBuilder {
+    type VectorType = ListVector;
+
+    fn with_capacity(_capacity: usize) -> Self {
+        panic!("Must use ListVectorBuilder::with_type_capacity()");
+    }
+
+    fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>) {
+        // We expect the input ListValue has the same inner type as the builder when using
+        // push(), so just panic if `push_value_ref()` returns an error, which indicates an
+        // invalid input value type.
+        self.push_value_ref(value.into()).unwrap_or_else(|e| {
+            panic!(
+                "Failed to push value, expect value type {:?}, err:{}",
+                self.item_type, e
+            );
+        });
+    }
+
+    fn finish(&mut self) -> Self::VectorType {
+        let len = self.len();
+        let values_vector = self.values_builder.to_vector();
+        let values_arr = values_vector.to_arrow_array();
+        let values_data = values_arr.data();
+
+        let offset_buffer = self.offsets_builder.finish();
+        let null_bit_buffer = self.null_buffer_builder.finish();
+        // Re-initialize the offsets_builder.
+        self.offsets_builder.append(0);
+        let data_type = ConcreteDataType::list_datatype(self.item_type.clone()).as_arrow_type();
+        let array_data_builder = ArrayData::builder(data_type)
+            .len(len)
+            .add_buffer(offset_buffer)
+            .add_child_data(values_data.clone())
+            .null_bit_buffer(null_bit_buffer);
+
+        let array_data = unsafe { array_data_builder.build_unchecked() };
+        let array = ListArray::from(array_data);
+
+        ListVector {
+            array,
+            item_type: self.item_type.clone(),
+        }
+    }
+}
+
+// Ports from https://github.com/apache/arrow-rs/blob/94565bca99b5d9932a3e9a8e094aaf4e4384b1e5/arrow-array/src/builder/null_buffer_builder.rs
+/// Builder for creating the null bit buffer.
+/// This builder only materializes the buffer when we append `false`.
+/// If you only append `true`s to the builder, what you get will be
+/// `None` when calling [`finish`](#method.finish).
+/// This optimization is **very** important for performance.
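+///
+/// For example, appending `[true, true, false]` only materializes the bitmap at the
+/// third append, while a builder that only ever saw `true`s keeps `bitmap_builder`
+/// as `None` and returns `None` from `finish()`.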
+#[derive(Debug)]
+struct NullBufferBuilder {
+    bitmap_builder: Option<BooleanBufferBuilder>,
+    /// Store the length of the buffer before materializing.
+    len: usize,
+    capacity: usize,
+}
+
+impl NullBufferBuilder {
+    /// Creates a new empty builder.
+    /// `capacity` is the number of bits in the null buffer.
+    fn new(capacity: usize) -> Self {
+        Self {
+            bitmap_builder: None,
+            len: 0,
+            capacity,
+        }
+    }
+
+    fn len(&self) -> usize {
+        if let Some(b) = &self.bitmap_builder {
+            b.len()
+        } else {
+            self.len
+        }
+    }
+
+    /// Appends a `true` into the builder
+    /// to indicate that this item is not null.
+    #[inline]
+    fn append_non_null(&mut self) {
+        if let Some(buf) = self.bitmap_builder.as_mut() {
+            buf.append(true)
+        } else {
+            self.len += 1;
+        }
+    }
+
+    /// Appends a `false` into the builder
+    /// to indicate that this item is null.
+    #[inline]
+    fn append_null(&mut self) {
+        self.materialize_if_needed();
+        self.bitmap_builder.as_mut().unwrap().append(false);
+    }
+
+    /// Appends a boolean value into the builder.
+    #[inline]
+    fn append(&mut self, not_null: bool) {
+        if not_null {
+            self.append_non_null()
+        } else {
+            self.append_null()
+        }
+    }
+
+    /// Builds the null buffer and resets the builder.
+    /// Returns `None` if the builder only contains `true`s.
+    fn finish(&mut self) -> Option<Buffer> {
+        let buf = self.bitmap_builder.as_mut().map(|b| b.finish());
+        self.bitmap_builder = None;
+        self.len = 0;
+        buf
+    }
+
+    #[inline]
+    fn materialize_if_needed(&mut self) {
+        if self.bitmap_builder.is_none() {
+            self.materialize()
+        }
+    }
+
+    #[cold]
+    fn materialize(&mut self) {
+        if self.bitmap_builder.is_none() {
+            let mut b = BooleanBufferBuilder::new(self.len.max(self.capacity));
+            b.append_n(self.len, true);
+            self.bitmap_builder = Some(b);
+        }
+    }
+}
+
+#[cfg(test)]
+pub mod tests {
+    use arrow::array::{Int32Array, Int32Builder, ListBuilder};
+    use serde_json::json;
+
+    use super::*;
+    use crate::scalars::ScalarRef;
+    use crate::types::ListType;
+    use crate::vectors::Int32Vector;
+
+    pub fn new_list_vector(data: &[Option<Vec<Option<i32>>>]) -> ListVector {
+        let mut builder =
+            ListVectorBuilder::with_type_capacity(ConcreteDataType::int32_datatype(), 8);
+        for vec_opt in data {
+            if let Some(vec) = vec_opt {
+                let values = vec.iter().map(|v| Value::from(*v)).collect();
+                let values = Some(Box::new(values));
+                let list_value = ListValue::new(values, ConcreteDataType::int32_datatype());
+
+                builder.push(Some(ListValueRef::Ref { val: &list_value }));
+            } else {
+                builder.push(None);
+            }
+        }
+
+        builder.finish()
+    }
+
+    fn new_list_array(data: &[Option<Vec<Option<i32>>>]) -> ListArray {
+        let mut builder = ListBuilder::new(Int32Builder::new());
+        for vec_opt in data {
+            if let Some(vec) = vec_opt {
+                for value_opt in vec {
+                    builder.values().append_option(*value_opt);
+                }
+
+                builder.append(true);
+            } else {
+                builder.append(false);
+            }
+        }
+
+        builder.finish()
+    }
+
+    #[test]
+    fn test_list_vector() {
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(4), None, Some(6)]),
+        ];
+
+        let list_vector = new_list_vector(&data);
+
+        assert_eq!(
+            ConcreteDataType::List(ListType::new(ConcreteDataType::int32_datatype())),
+            list_vector.data_type()
+        );
+        assert_eq!("ListVector", list_vector.vector_type_name());
+        assert_eq!(3, list_vector.len());
+        assert!(!list_vector.is_null(0));
+        assert!(list_vector.is_null(1));
+        assert!(!list_vector.is_null(2));
+
+        let arrow_array = new_list_array(&data);
+        assert_eq!(
+            arrow_array,
+            *list_vector
+                .to_arrow_array()
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .unwrap()
+        );
+        let validity = list_vector.validity();
+        assert!(!validity.is_all_null());
+        assert!(!validity.is_all_valid());
+        assert!(validity.is_set(0));
+        assert!(!validity.is_set(1));
+        assert!(validity.is_set(2));
+        assert_eq!(256, list_vector.memory_size());
+
+        let slice = list_vector.slice(0, 2).to_arrow_array();
+        let sliced_array = slice.as_any().downcast_ref::<ListArray>().unwrap();
+        assert_eq!(
+            Int32Array::from_iter_values([1, 2, 3]),
+            *sliced_array
+                .value(0)
+                .as_any()
+                .downcast_ref::<Int32Array>()
+                .unwrap()
+        );
+        assert!(sliced_array.is_null(1));
+
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![
+                    Value::Int32(1),
+                    Value::Int32(2),
+                    Value::Int32(3)
+                ])),
+                ConcreteDataType::int32_datatype()
+            )),
+            list_vector.get(0)
+        );
+        let value_ref = list_vector.get_ref(0);
+        assert!(matches!(
+            value_ref,
+            ValueRef::List(ListValueRef::Indexed { .. })
+        ));
+        let value_ref = list_vector.get_ref(1);
+        if let ValueRef::List(ListValueRef::Indexed { idx, .. }) = value_ref {
+            assert_eq!(1, idx);
+        } else {
+            unreachable!()
+        }
+        assert_eq!(Value::Null, list_vector.get(1));
+        assert_eq!(
+            Value::List(ListValue::new(
+                Some(Box::new(vec![
+                    Value::Int32(4),
+                    Value::Null,
+                    Value::Int32(6)
+                ])),
+                ConcreteDataType::int32_datatype()
+            )),
+            list_vector.get(2)
+        );
+    }
+
+    #[test]
+    fn test_from_arrow_array() {
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(4), None, Some(6)]),
+        ];
+
+        let arrow_array = new_list_array(&data);
+        let array_ref: ArrayRef = Arc::new(arrow_array);
+        let expect = new_list_vector(&data);
+
+        // Test try from ArrayRef
+        let list_vector = ListVector::try_from_arrow_array(array_ref).unwrap();
+        assert_eq!(expect, list_vector);
+
+        // Test from
+        let arrow_array = new_list_array(&data);
+        let list_vector = ListVector::from(arrow_array);
+        assert_eq!(expect, list_vector);
+    }
+
+    #[test]
+    fn test_iter_list_vector_values() {
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(4), None, Some(6)]),
+        ];
+
+        let list_vector = new_list_vector(&data);
+
+        assert_eq!(
+            ConcreteDataType::List(ListType::new(ConcreteDataType::int32_datatype())),
+            list_vector.data_type()
+        );
+        let mut iter = list_vector.values_iter();
+        assert_eq!(
+            Arc::new(Int32Vector::from_slice(&[1, 2, 3])) as VectorRef,
+            *iter.next().unwrap().unwrap().unwrap()
+        );
+        assert!(iter.next().unwrap().unwrap().is_none());
+        assert_eq!(
+            Arc::new(Int32Vector::from(vec![Some(4), None, Some(6)])) as VectorRef,
+            *iter.next().unwrap().unwrap().unwrap(),
+        );
+        assert!(iter.next().is_none())
+    }
+
+    #[test]
+    fn test_serialize_to_json() {
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(4), None, Some(6)]),
+        ];
+
+        let list_vector = new_list_vector(&data);
+        assert_eq!(
+            vec![json!([1, 2, 3]), json!(null), json!([4, null, 6]),],
+            list_vector.serialize_to_json().unwrap()
+        );
+    }
+
+    #[test]
+    fn test_list_vector_builder() {
+        let mut builder =
+            ListType::new(ConcreteDataType::int32_datatype()).create_mutable_vector(3);
+        builder
+            .push_value_ref(ValueRef::List(ListValueRef::Ref {
+                val: &ListValue::new(
+                    Some(Box::new(vec![
+                        Value::Int32(4),
+                        Value::Null,
+                        Value::Int32(6),
+                    ])),
+                    ConcreteDataType::int32_datatype(),
+                ),
+            }))
+            .unwrap();
+        assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err());
+
+        let data = vec![
+            Some(vec![Some(1), Some(2), Some(3)]),
+            None,
+            Some(vec![Some(7), Some(8), None]),
+        ];
+        let input = new_list_vector(&data);
+        builder.extend_slice_of(&input, 1, 2).unwrap();
+        
assert!(builder + .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1) + .is_err()); + let vector = builder.to_vector(); + + let expect: VectorRef = Arc::new(new_list_vector(&[ + Some(vec![Some(4), None, Some(6)]), + None, + Some(vec![Some(7), Some(8), None]), + ])); + assert_eq!(expect, vector); + } + + #[test] + fn test_list_vector_for_scalar() { + let mut builder = + ListVectorBuilder::with_type_capacity(ConcreteDataType::int32_datatype(), 2); + builder.push(None); + builder.push(Some(ListValueRef::Ref { + val: &ListValue::new( + Some(Box::new(vec![ + Value::Int32(4), + Value::Null, + Value::Int32(6), + ])), + ConcreteDataType::int32_datatype(), + ), + })); + let vector = builder.finish(); + + let expect = new_list_vector(&[None, Some(vec![Some(4), None, Some(6)])]); + assert_eq!(expect, vector); + + assert!(vector.get_data(0).is_none()); + assert_eq!( + ListValueRef::Indexed { + vector: &vector, + idx: 1 + }, + vector.get_data(1).unwrap() + ); + assert_eq!( + *vector.get(1).as_list().unwrap().unwrap(), + vector.get_data(1).unwrap().to_owned_scalar() + ); + + let mut iter = vector.iter_data(); + assert!(iter.next().unwrap().is_none()); + assert_eq!( + ListValueRef::Indexed { + vector: &vector, + idx: 1 + }, + iter.next().unwrap().unwrap() + ); + assert!(iter.next().is_none()); + + let mut iter = vector.iter_data(); + assert_eq!(2, iter.size_hint().0); + assert_eq!( + ListValueRef::Indexed { + vector: &vector, + idx: 1 + }, + iter.nth(1).unwrap().unwrap() + ); + } +} diff --git a/src/datatypes2/src/vectors/null.rs b/src/datatypes2/src/vectors/null.rs new file mode 100644 index 0000000000..bb66e09b39 --- /dev/null +++ b/src/datatypes2/src/vectors/null.rs @@ -0,0 +1,282 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use arrow::array::{Array, ArrayData, ArrayRef, NullArray}; +use snafu::{ensure, OptionExt}; + +use crate::data_type::ConcreteDataType; +use crate::error::{self, Result}; +use crate::serialize::Serializable; +use crate::types::NullType; +use crate::value::{Value, ValueRef}; +use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef}; + +/// A vector where all elements are nulls. +#[derive(PartialEq)] +pub struct NullVector { + array: NullArray, +} + +// TODO(yingwen): Support null vector with other logical types. +impl NullVector { + /// Create a new `NullVector` with `n` elements. 
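+    /// For example, `NullVector::new(3)` yields a vector of length 3 where every
+    /// element is `Value::Null`.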
+    pub fn new(n: usize) -> Self {
+        Self {
+            array: NullArray::new(n),
+        }
+    }
+
+    pub(crate) fn as_arrow(&self) -> &dyn Array {
+        &self.array
+    }
+
+    fn to_array_data(&self) -> ArrayData {
+        self.array.data().clone()
+    }
+}
+
+impl From<NullArray> for NullVector {
+    fn from(array: NullArray) -> Self {
+        Self { array }
+    }
+}
+
+impl Vector for NullVector {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::Null(NullType::default())
+    }
+
+    fn vector_type_name(&self) -> String {
+        "NullVector".to_string()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn len(&self) -> usize {
+        self.array.len()
+    }
+
+    fn to_arrow_array(&self) -> ArrayRef {
+        // TODO(yingwen): Replace this by clone after upgrading to arrow 28.0.
+        let data = self.to_array_data();
+        Arc::new(NullArray::from(data))
+    }
+
+    fn to_boxed_arrow_array(&self) -> Box<dyn Array> {
+        let data = self.to_array_data();
+        Box::new(NullArray::from(data))
+    }
+
+    fn validity(&self) -> Validity {
+        Validity::all_null(self.array.len())
+    }
+
+    fn memory_size(&self) -> usize {
+        0
+    }
+
+    fn null_count(&self) -> usize {
+        self.array.null_count()
+    }
+
+    fn is_null(&self, _row: usize) -> bool {
+        true
+    }
+
+    fn only_null(&self) -> bool {
+        true
+    }
+
+    fn slice(&self, _offset: usize, length: usize) -> VectorRef {
+        Arc::new(Self::new(length))
+    }
+
+    fn get(&self, _index: usize) -> Value {
+        // Skips bound check for null array.
+        Value::Null
+    }
+
+    fn get_ref(&self, _index: usize) -> ValueRef {
+        // Skips bound check for null array.
+        ValueRef::Null
+    }
+}
+
+impl fmt::Debug for NullVector {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "NullVector({})", self.len())
+    }
+}
+
+impl Serializable for NullVector {
+    fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>> {
+        Ok(std::iter::repeat(serde_json::Value::Null)
+            .take(self.len())
+            .collect())
+    }
+}
+
+vectors::impl_try_from_arrow_array_for_vector!(NullArray, NullVector);
+
+#[derive(Default)]
+pub struct NullVectorBuilder {
+    length: usize,
+}
+
+impl MutableVector for NullVectorBuilder {
+    fn data_type(&self) -> ConcreteDataType {
+        ConcreteDataType::null_datatype()
+    }
+
+    fn len(&self) -> usize {
+        self.length
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn as_mut_any(&mut self) -> &mut dyn Any {
+        self
+    }
+
+    fn to_vector(&mut self) -> VectorRef {
+        let vector = Arc::new(NullVector::new(self.length));
+        self.length = 0;
+        vector
+    }
+
+    fn push_value_ref(&mut self, value: ValueRef) -> Result<()> {
+        ensure!(
+            value.is_null(),
+            error::CastTypeSnafu {
+                msg: format!("Failed to cast value ref {:?} to null", value),
+            }
+        );
+
+        self.length += 1;
+        Ok(())
+    }
+
+    fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> {
+        vector
+            .as_any()
+            .downcast_ref::<NullVector>()
+            .with_context(|| error::CastTypeSnafu {
+                msg: format!(
+                    "Failed to convert vector from {} to NullVector",
+                    vector.vector_type_name()
+                ),
+            })?;
+        assert!(
+            offset + length <= vector.len(),
+            "offset {} + length {} must not exceed {}",
+            offset,
+            length,
+            vector.len()
+        );
+
+        self.length += length;
+        Ok(())
+    }
+}
+
+pub(crate) fn replicate_null(vector: &NullVector, offsets: &[usize]) -> VectorRef {
+    assert_eq!(offsets.len(), vector.len());
+
+    if offsets.is_empty() {
+        return vector.slice(0, 0);
+    }
+
+    Arc::new(NullVector::new(*offsets.last().unwrap()))
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json;
+
+    use super::*;
+    use crate::data_type::DataType;
+
+    #[test]
+    fn test_null_vector_misc() {
+        let v = NullVector::new(32);
+
+        assert_eq!(v.len(), 32);
+        assert_eq!(0, v.memory_size());
+        let arrow_arr = v.to_arrow_array();
+        assert_eq!(arrow_arr.null_count(), 32);
+
+        let array2 = arrow_arr.slice(8, 16);
+        assert_eq!(array2.len(), 16);
+        assert_eq!(array2.null_count(), 16);
+
+        assert_eq!("NullVector", v.vector_type_name());
+        assert!(!v.is_const());
+        assert!(v.validity().is_all_null());
+        assert!(v.only_null());
+
+        for i in 0..32 {
+            assert!(v.is_null(i));
+            assert_eq!(Value::Null, v.get(i));
+            assert_eq!(ValueRef::Null, v.get_ref(i));
+        }
+    }
+
+    #[test]
+    fn test_debug_null_vector() {
+        let array = NullVector::new(1024 * 1024);
+        assert_eq!(format!("{:?}", array), "NullVector(1048576)");
+    }
+
+    #[test]
+    fn test_serialize_json() {
+        let vector = NullVector::new(3);
+        let json_value = vector.serialize_to_json().unwrap();
+        assert_eq!(
+            "[null,null,null]",
+            serde_json::to_string(&json_value).unwrap()
+        );
+    }
+
+    #[test]
+    fn test_null_vector_validity() {
+        let vector = NullVector::new(5);
+        assert!(vector.validity().is_all_null());
+        assert_eq!(5, vector.null_count());
+    }
+
+    #[test]
+    fn test_null_vector_builder() {
+        let mut builder = NullType::default().create_mutable_vector(3);
+        builder.push_value_ref(ValueRef::Null).unwrap();
+        assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err());
+
+        let input = NullVector::new(3);
+        builder.extend_slice_of(&input, 1, 2).unwrap();
+        assert!(builder
+            .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1)
+            .is_err());
+        let vector = builder.to_vector();
+
+        let expect: VectorRef = Arc::new(input);
+        assert_eq!(expect, vector);
+    }
+}
diff --git a/src/datatypes2/src/vectors/operations.rs b/src/datatypes2/src/vectors/operations.rs
new file mode 100644
index 0000000000..70ddb4a031
--- /dev/null
+++ b/src/datatypes2/src/vectors/operations.rs
@@ -0,0 +1,127 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+mod filter;
+mod find_unique;
+mod replicate;
+
+use common_base::BitVec;
+
+use crate::error::Result;
+use crate::types::LogicalPrimitiveType;
+use crate::vectors::constant::ConstantVector;
+use crate::vectors::{
+    BinaryVector, BooleanVector, ListVector, NullVector, PrimitiveVector, StringVector, Vector,
+    VectorRef,
+};
+
+/// Vector compute operations.
+pub trait VectorOp {
+    /// Copies each element according to the `offsets` parameter.
+    /// - the `i-th` element should be copied `offsets[i] - offsets[i - 1]` times
+    /// - the `0-th` element should be copied `offsets[0]` times
+    ///
+    /// # Panics
+    /// Panics if `offsets.len() != self.len()`.
+    fn replicate(&self, offsets: &[usize]) -> VectorRef;
+
+    /// Marks the `i-th` bit of `selected` as `true` if the `i-th` element of `self` is unique,
+    /// which means no preceding element has the same value as it.
+    ///
+    /// The caller should ensure
+    /// 1. the length of the `selected` bitmap is no less than `vector.len()`;
+    /// 2. `vector` and `prev_vector` are sorted.
+    ///
+    /// If there are multiple duplicate elements, this function retains the **first** element.
+    /// The first element is considered unique if it is different from its previous element,
+    /// that is, the last element of `prev_vector`.
+    ///
+    /// # Panics
+    /// Panics if
+    /// - `selected.len() < self.len()`.
+    /// - `prev_vector` and `self` have different data types.
+    fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>);
+
+    /// Filters the vector, returning elements matching the `filter` (i.e. where the values are true).
+    ///
+    /// Note that nulls in `filter` are interpreted as `false`, which leads to those elements
+    /// being masked out.
+    fn filter(&self, filter: &BooleanVector) -> Result<VectorRef>;
+}
+
+macro_rules! impl_scalar_vector_op {
+    ($($VectorType: ident),+) => {$(
+        impl VectorOp for $VectorType {
+            fn replicate(&self, offsets: &[usize]) -> VectorRef {
+                replicate::replicate_scalar(self, offsets)
+            }
+
+            fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>) {
+                let prev_vector = prev_vector.map(|pv| pv.as_any().downcast_ref::<$VectorType>().unwrap());
+                find_unique::find_unique_scalar(self, selected, prev_vector);
+            }
+
+            fn filter(&self, filter: &BooleanVector) -> Result<VectorRef> {
+                filter::filter_non_constant!(self, $VectorType, filter)
+            }
+        }
+    )+};
+}
+
+impl_scalar_vector_op!(BinaryVector, BooleanVector, ListVector, StringVector);
+
+impl<T: LogicalPrimitiveType> VectorOp for PrimitiveVector<T> {
+    fn replicate(&self, offsets: &[usize]) -> VectorRef {
+        std::sync::Arc::new(replicate::replicate_primitive(self, offsets))
+    }
+
+    fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>) {
+        let prev_vector =
+            prev_vector.and_then(|pv| pv.as_any().downcast_ref::<PrimitiveVector<T>>());
+        find_unique::find_unique_scalar(self, selected, prev_vector);
+    }
+
+    fn filter(&self, filter: &BooleanVector) -> Result<VectorRef> {
+        filter::filter_non_constant!(self, PrimitiveVector<T>, filter)
+    }
+}
+
+impl VectorOp for NullVector {
+    fn replicate(&self, offsets: &[usize]) -> VectorRef {
+        replicate::replicate_null(self, offsets)
+    }
+
+    fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>) {
+        let prev_vector = prev_vector.and_then(|pv| pv.as_any().downcast_ref::<NullVector>());
+        find_unique::find_unique_null(self, selected, prev_vector);
+    }
+
+    fn filter(&self, filter: &BooleanVector) -> Result<VectorRef> {
+        filter::filter_non_constant!(self, NullVector, filter)
+    }
+}
+
+impl VectorOp for ConstantVector {
+    fn replicate(&self, offsets: &[usize]) -> VectorRef {
+        self.replicate_vector(offsets)
+    }
+
+    fn find_unique(&self, selected: &mut BitVec, prev_vector: Option<&dyn Vector>) {
+        let prev_vector = prev_vector.and_then(|pv| pv.as_any().downcast_ref::<ConstantVector>());
+        find_unique::find_unique_constant(self, selected, prev_vector);
+    }
+
+    fn filter(&self, filter: &BooleanVector) -> Result<VectorRef> {
+        self.filter_vector(filter)
+    }
+}
diff --git a/src/datatypes2/src/vectors/operations/filter.rs b/src/datatypes2/src/vectors/operations/filter.rs
new file mode 100644
index 0000000000..8368a6afb4
--- /dev/null
+++ b/src/datatypes2/src/vectors/operations/filter.rs
@@ -0,0 +1,145 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +macro_rules! filter_non_constant { + ($vector: expr, $VectorType: ty, $filter: ident) => {{ + use std::sync::Arc; + + use arrow::compute; + use snafu::ResultExt; + + let arrow_array = $vector.as_arrow(); + let filtered = compute::filter(arrow_array, $filter.as_boolean_array()) + .context(crate::error::ArrowComputeSnafu)?; + Ok(Arc::new(<$VectorType>::try_from_arrow_array(filtered)?)) + }}; +} + +pub(crate) use filter_non_constant; + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_time::{Date, DateTime}; + + use crate::scalars::ScalarVector; + use crate::timestamp::{ + TimestampMicrosecond, TimestampMillisecond, TimestampNanosecond, TimestampSecond, + }; + use crate::types::WrapperType; + use crate::vectors::constant::ConstantVector; + use crate::vectors::{ + BooleanVector, Int32Vector, NullVector, StringVector, VectorOp, VectorRef, + }; + + fn check_filter_primitive(expect: &[i32], input: &[i32], filter: &[bool]) { + let v = Int32Vector::from_slice(&input); + let filter = BooleanVector::from_slice(filter); + let out = v.filter(&filter).unwrap(); + + let expect: VectorRef = Arc::new(Int32Vector::from_slice(&expect)); + assert_eq!(expect, out); + } + + #[test] + fn test_filter_primitive() { + check_filter_primitive(&[], &[], &[]); + check_filter_primitive(&[5], &[5], &[true]); + check_filter_primitive(&[], &[5], &[false]); + check_filter_primitive(&[], &[5, 6], &[false, false]); + check_filter_primitive(&[5, 6], &[5, 6], &[true, true]); + check_filter_primitive(&[], &[5, 6, 7], &[false, false, false]); + check_filter_primitive(&[5], &[5, 6, 7], &[true, false, false]); + check_filter_primitive(&[6], &[5, 6, 7], &[false, true, false]); + check_filter_primitive(&[7], &[5, 6, 7], &[false, false, true]); + check_filter_primitive(&[5, 7], &[5, 6, 7], &[true, false, true]); + } + + fn check_filter_constant(expect_length: usize, input_length: usize, filter: &[bool]) { + let v = ConstantVector::new(Arc::new(Int32Vector::from_slice(&[123])), input_length); + let filter = BooleanVector::from_slice(filter); + let out = v.filter(&filter).unwrap(); + + assert!(out.is_const()); + assert_eq!(expect_length, out.len()); + } + + #[test] + fn test_filter_constant() { + check_filter_constant(0, 0, &[]); + check_filter_constant(1, 1, &[true]); + check_filter_constant(0, 1, &[false]); + check_filter_constant(1, 2, &[false, true]); + check_filter_constant(2, 2, &[true, true]); + check_filter_constant(1, 4, &[false, false, false, true]); + check_filter_constant(2, 4, &[false, true, false, true]); + } + + #[test] + fn test_filter_scalar() { + let v = StringVector::from_slice(&["0", "1", "2", "3"]); + let filter = BooleanVector::from_slice(&[false, true, false, true]); + let out = v.filter(&filter).unwrap(); + + let expect: VectorRef = Arc::new(StringVector::from_slice(&["1", "3"])); + assert_eq!(expect, out); + } + + #[test] + fn test_filter_null() { + let v = NullVector::new(5); + let filter = BooleanVector::from_slice(&[false, true, false, true, true]); + let out = v.filter(&filter).unwrap(); + + let expect: VectorRef = Arc::new(NullVector::new(3)); + assert_eq!(expect, out); + } + + macro_rules! 
impl_filter_date_like_test {
+        ($VectorType: ident, $ValueType: ident, $method: ident) => {{
+            use std::sync::Arc;
+
+            use $crate::vectors::{$VectorType, VectorRef};
+
+            let v = $VectorType::from_iterator((0..5).map($ValueType::$method));
+            let filter = BooleanVector::from_slice(&[false, true, false, true, true]);
+            let out = v.filter(&filter).unwrap();
+
+            let expect: VectorRef = Arc::new($VectorType::from_iterator(
+                [1, 3, 4].into_iter().map($ValueType::$method),
+            ));
+            assert_eq!(expect, out);
+        }};
+    }
+
+    #[test]
+    fn test_filter_date_like() {
+        impl_filter_date_like_test!(DateVector, Date, new);
+        impl_filter_date_like_test!(DateTimeVector, DateTime, new);
+
+        impl_filter_date_like_test!(TimestampSecondVector, TimestampSecond, from_native);
+        impl_filter_date_like_test!(
+            TimestampMillisecondVector,
+            TimestampMillisecond,
+            from_native
+        );
+        impl_filter_date_like_test!(
+            TimestampMicrosecondVector,
+            TimestampMicrosecond,
+            from_native
+        );
+        impl_filter_date_like_test!(TimestampNanosecondVector, TimestampNanosecond, from_native);
+    }
+}
diff --git a/src/datatypes2/src/vectors/operations/find_unique.rs b/src/datatypes2/src/vectors/operations/find_unique.rs
new file mode 100644
index 0000000000..7116a9e90d
--- /dev/null
+++ b/src/datatypes2/src/vectors/operations/find_unique.rs
@@ -0,0 +1,367 @@
+// Copyright 2022 Greptime Team
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_base::BitVec;
+
+use crate::scalars::ScalarVector;
+use crate::vectors::constant::ConstantVector;
+use crate::vectors::{NullVector, Vector};
+
+// To implement `find_unique()` correctly, we need to keep in mind that it always marks an
+// element as selected when it is different from the previous one, and leaves `selected`
+// unchanged in any other case.
+pub(crate) fn find_unique_scalar<'a, T: ScalarVector>(
+    vector: &'a T,
+    selected: &'a mut BitVec,
+    prev_vector: Option<&'a T>,
+) where
+    T::RefItem<'a>: PartialEq,
+{
+    assert!(selected.len() >= vector.len());
+
+    if vector.is_empty() {
+        return;
+    }
+
+    for ((i, current), next) in vector
+        .iter_data()
+        .enumerate()
+        .zip(vector.iter_data().skip(1))
+    {
+        if current != next {
+            // If the next element is different, we mark it as selected.
+            selected.set(i + 1, true);
+        }
+    }
+
+    // Marks the first element as selected if it is different from the previous element;
+    // otherwise the selected bitmap is kept unchanged.
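+    // For example, with `vector = [5, 6, 6]` and `prev_vector = [4, 5]`, the loop above
+    // selects only index 1, and index 0 stays unselected here because the last element
+    // of `prev_vector` equals `vector[0]` (see `test_find_unique_scalar` below).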
+    let is_first_not_duplicate = prev_vector
+        .map(|pv| {
+            if pv.is_empty() {
+                true
+            } else {
+                let last = pv.get_data(pv.len() - 1);
+                last != vector.get_data(0)
+            }
+        })
+        .unwrap_or(true);
+    if is_first_not_duplicate {
+        selected.set(0, true);
+    }
+}
+
+pub(crate) fn find_unique_null(
+    vector: &NullVector,
+    selected: &mut BitVec,
+    prev_vector: Option<&NullVector>,
+) {
+    if vector.is_empty() {
+        return;
+    }
+
+    let is_first_not_duplicate = prev_vector.map(NullVector::is_empty).unwrap_or(true);
+    if is_first_not_duplicate {
+        selected.set(0, true);
+    }
+}
+
+pub(crate) fn find_unique_constant(
+    vector: &ConstantVector,
+    selected: &mut BitVec,
+    prev_vector: Option<&ConstantVector>,
+) {
+    if vector.is_empty() {
+        return;
+    }
+
+    let is_first_not_duplicate = prev_vector
+        .map(|pv| {
+            if pv.is_empty() {
+                true
+            } else {
+                vector.get_constant_ref() != pv.get_constant_ref()
+            }
+        })
+        .unwrap_or(true);
+
+    if is_first_not_duplicate {
+        selected.set(0, true);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use common_time::{Date, DateTime};
+
+    use super::*;
+    use crate::timestamp::*;
+    use crate::vectors::{Int32Vector, StringVector, Vector, VectorOp};
+
+    fn check_bitmap(expect: &[bool], selected: &BitVec) {
+        let actual = selected.iter().collect::<Vec<_>>();
+        assert_eq!(expect, actual);
+    }
+
+    fn check_find_unique_scalar(expect: &[bool], input: &[i32], prev: Option<&[i32]>) {
+        check_find_unique_scalar_opt(expect, input.iter().map(|v| Some(*v)), prev);
+    }
+
+    fn check_find_unique_scalar_opt(
+        expect: &[bool],
+        input: impl Iterator<Item = Option<i32>>,
+        prev: Option<&[i32]>,
+    ) {
+        let input = Int32Vector::from(input.collect::<Vec<_>>());
+        let prev = prev.map(Int32Vector::from_slice);
+
+        let mut selected = BitVec::repeat(false, input.len());
+        input.find_unique(&mut selected, prev.as_ref().map(|v| v as _));
+
+        check_bitmap(expect, &selected);
+    }
+
+    #[test]
+    fn test_find_unique_scalar() {
+        check_find_unique_scalar(&[], &[], None);
+        check_find_unique_scalar(&[true], &[1], None);
+        check_find_unique_scalar(&[true, false], &[1, 1], None);
+        check_find_unique_scalar(&[true, true], &[1, 2], None);
+        check_find_unique_scalar(&[true, true, true, true], &[1, 2, 3, 4], None);
+        check_find_unique_scalar(&[true, false, true, false], &[1, 1, 3, 3], None);
+        check_find_unique_scalar(&[true, false, false, false, true], &[2, 2, 2, 2, 3], None);
+
+        check_find_unique_scalar(&[true], &[5], Some(&[]));
+        check_find_unique_scalar(&[true], &[5], Some(&[3]));
+        check_find_unique_scalar(&[false], &[5], Some(&[5]));
+        check_find_unique_scalar(&[false], &[5], Some(&[4, 5]));
+        check_find_unique_scalar(&[false, true], &[5, 6], Some(&[4, 5]));
+        check_find_unique_scalar(&[false, true, false], &[5, 6, 6], Some(&[4, 5]));
+        check_find_unique_scalar(
+            &[false, true, false, true, true],
+            &[5, 6, 6, 7, 8],
+            Some(&[4, 5]),
+        );
+
+        check_find_unique_scalar_opt(
+            &[true, true, false, true, false],
+            [Some(1), Some(2), Some(2), None, None].into_iter(),
+            None,
+        );
+    }
+
+    #[test]
+    fn test_find_unique_scalar_multi_times_with_prev() {
+        let prev = Int32Vector::from_slice(&[1]);
+
+        let v1 = Int32Vector::from_slice(&[2, 3, 4]);
+        let mut selected = BitVec::repeat(false, v1.len());
+        v1.find_unique(&mut selected, Some(&prev));
+
+        // Though the elements in v2 are the same as prev's, we should still keep the flags.
+        let v2 = Int32Vector::from_slice(&[1, 1, 1]);
+        v2.find_unique(&mut selected, Some(&prev));
+
+        check_bitmap(&[true, true, true], &selected);
+    }
+
+    fn new_bitmap(bits: &[bool]) -> BitVec {
+        BitVec::from_iter(bits)
+    }
+
+    #[test]
+    fn test_find_unique_scalar_with_prev() {
+        let prev = Int32Vector::from_slice(&[1]);
+
+        let mut selected = new_bitmap(&[true, false, true, false]);
+        let v = Int32Vector::from_slice(&[2, 3, 4, 5]);
+        v.find_unique(&mut selected, Some(&prev));
+        // All elements are different.
+        check_bitmap(&[true, true, true, true], &selected);
+
+        let mut selected = new_bitmap(&[true, false, true, false]);
+        let v = Int32Vector::from_slice(&[1, 2, 3, 4]);
+        v.find_unique(&mut selected, Some(&prev));
+        // Though the first element is a duplicate, we keep the flag unchanged.
+        check_bitmap(&[true, true, true, true], &selected);
+
+        // Same case as above, but now `prev` is None.
+        let mut selected = new_bitmap(&[true, false, true, false]);
+        let v = Int32Vector::from_slice(&[1, 2, 3, 4]);
+        v.find_unique(&mut selected, None);
+        check_bitmap(&[true, true, true, true], &selected);
+
+        // Same case as above, but now `prev` is empty.
+        let mut selected = new_bitmap(&[true, false, true, false]);
+        let v = Int32Vector::from_slice(&[1, 2, 3, 4]);
+        v.find_unique(&mut selected, Some(&Int32Vector::from_slice(&[])));
+        check_bitmap(&[true, true, true, true], &selected);
+
+        let mut selected = new_bitmap(&[false, false, false, false]);
+        let v = Int32Vector::from_slice(&[2, 2, 4, 5]);
+        v.find_unique(&mut selected, Some(&prev));
+        // Only v[1] is a duplicate.
+        check_bitmap(&[true, false, true, true], &selected);
+    }
+
+    fn check_find_unique_null(len: usize) {
+        let input = NullVector::new(len);
+        let mut selected = BitVec::repeat(false, input.len());
+        input.find_unique(&mut selected, None);
+
+        let mut expect = vec![false; len];
+        if !expect.is_empty() {
+            expect[0] = true;
+        }
+        check_bitmap(&expect, &selected);
+
+        let mut selected = BitVec::repeat(false, input.len());
+        let prev = Some(NullVector::new(1));
+        input.find_unique(&mut selected, prev.as_ref().map(|v| v as _));
+        let expect = vec![false; len];
+        check_bitmap(&expect, &selected);
+    }
+
+    #[test]
+    fn test_find_unique_null() {
+        for len in 0..5 {
+            check_find_unique_null(len);
+        }
+    }
+
+    #[test]
+    fn test_find_unique_null_with_prev() {
+        let prev = NullVector::new(1);
+
+        // Keep flags unchanged.
+        let mut selected = new_bitmap(&[true, false, true, false]);
+        let v = NullVector::new(4);
+        v.find_unique(&mut selected, Some(&prev));
+        check_bitmap(&[true, false, true, false], &selected);
+
+        // Keep flags unchanged.
+        let mut selected = new_bitmap(&[false, false, true, false]);
+        v.find_unique(&mut selected, Some(&prev));
+        check_bitmap(&[false, false, true, false], &selected);
+
+        // Prev is None, select first element.
+        let mut selected = new_bitmap(&[false, false, true, false]);
+        v.find_unique(&mut selected, None);
+        check_bitmap(&[true, false, true, false], &selected);
+
+        // Prev is empty, select first element.
+ let mut selected = new_bitmap(&[false, false, true, false]); + v.find_unique(&mut selected, Some(&NullVector::new(0))); + check_bitmap(&[true, false, true, false], &selected); + } + + fn check_find_unique_constant(len: usize) { + let input = ConstantVector::new(Arc::new(Int32Vector::from_slice(&[8])), len); + let mut selected = BitVec::repeat(false, len); + input.find_unique(&mut selected, None); + + let mut expect = vec![false; len]; + if !expect.is_empty() { + expect[0] = true; + } + check_bitmap(&expect, &selected); + + let mut selected = BitVec::repeat(false, len); + let prev = Some(ConstantVector::new( + Arc::new(Int32Vector::from_slice(&[8])), + 1, + )); + input.find_unique(&mut selected, prev.as_ref().map(|v| v as _)); + let expect = vec![false; len]; + check_bitmap(&expect, &selected); + } + + #[test] + fn test_find_unique_constant() { + for len in 0..5 { + check_find_unique_constant(len); + } + } + + #[test] + fn test_find_unique_constant_with_prev() { + let prev = ConstantVector::new(Arc::new(Int32Vector::from_slice(&[1])), 1); + + // Keep flags unchanged. + let mut selected = new_bitmap(&[true, false, true, false]); + let v = ConstantVector::new(Arc::new(Int32Vector::from_slice(&[1])), 4); + v.find_unique(&mut selected, Some(&prev)); + check_bitmap(&[true, false, true, false], &selected); + + // Keep flags unchanged. + let mut selected = new_bitmap(&[false, false, true, false]); + v.find_unique(&mut selected, Some(&prev)); + check_bitmap(&[false, false, true, false], &selected); + + // Prev is None, select first element. + let mut selected = new_bitmap(&[false, false, true, false]); + v.find_unique(&mut selected, None); + check_bitmap(&[true, false, true, false], &selected); + + // Prev is empty, select first element. + let mut selected = new_bitmap(&[false, false, true, false]); + v.find_unique( + &mut selected, + Some(&ConstantVector::new( + Arc::new(Int32Vector::from_slice(&[1])), + 0, + )), + ); + check_bitmap(&[true, false, true, false], &selected); + + // Different constant vector. + let mut selected = new_bitmap(&[false, false, true, false]); + let v = ConstantVector::new(Arc::new(Int32Vector::from_slice(&[2])), 4); + v.find_unique(&mut selected, Some(&prev)); + check_bitmap(&[true, false, true, false], &selected); + } + + #[test] + fn test_find_unique_string() { + let input = StringVector::from_slice(&["a", "a", "b", "c"]); + let mut selected = BitVec::repeat(false, 4); + input.find_unique(&mut selected, None); + let expect = vec![true, false, true, true]; + check_bitmap(&expect, &selected); + } + + macro_rules! 
impl_find_unique_date_like_test { + ($VectorType: ident, $ValueType: ident, $method: ident) => {{ + use $crate::vectors::$VectorType; + + let v = $VectorType::from_iterator([8, 8, 9, 10].into_iter().map($ValueType::$method)); + let mut selected = BitVec::repeat(false, 4); + v.find_unique(&mut selected, None); + let expect = vec![true, false, true, true]; + check_bitmap(&expect, &selected); + }}; + } + + #[test] + fn test_find_unique_date_like() { + impl_find_unique_date_like_test!(DateVector, Date, new); + impl_find_unique_date_like_test!(DateTimeVector, DateTime, new); + impl_find_unique_date_like_test!(TimestampSecondVector, TimestampSecond, from); + impl_find_unique_date_like_test!(TimestampMillisecondVector, TimestampMillisecond, from); + impl_find_unique_date_like_test!(TimestampMicrosecondVector, TimestampMicrosecond, from); + impl_find_unique_date_like_test!(TimestampNanosecondVector, TimestampNanosecond, from); + } +} diff --git a/src/datatypes2/src/vectors/operations/replicate.rs b/src/datatypes2/src/vectors/operations/replicate.rs new file mode 100644 index 0000000000..8216517fc6 --- /dev/null +++ b/src/datatypes2/src/vectors/operations/replicate.rs @@ -0,0 +1,170 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
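For orientation before the code: the `replicate` operations below take cumulative end offsets, so `offsets[i]` is the end position of element `i` in the output and element `i` is repeated `offsets[i] - offsets[i - 1]` times (with an implicit leading zero). A minimal standalone sketch of that convention, using hypothetical names and no greptimedb types:

fn replicate_slice(values: &[i32], offsets: &[usize]) -> Vec<i32> {
    // `offsets` must be non-decreasing and exactly as long as `values`.
    assert_eq!(values.len(), offsets.len());
    let mut out = Vec::with_capacity(offsets.last().copied().unwrap_or(0));
    let mut prev = 0;
    for (value, &offset) in values.iter().zip(offsets) {
        // Repeat each value up to its cumulative end offset.
        out.extend(std::iter::repeat(*value).take(offset - prev));
        prev = offset;
    }
    out
}

// Example: replicate_slice(&[0, 1, 2], &[1, 3, 6]) yields [0, 1, 1, 2, 2, 2].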
+ +use crate::prelude::*; +pub(crate) use crate::vectors::null::replicate_null; +pub(crate) use crate::vectors::primitive::replicate_primitive; + +pub(crate) fn replicate_scalar<C: ScalarVector>(c: &C, offsets: &[usize]) -> VectorRef { + assert_eq!(offsets.len(), c.len()); + + if offsets.is_empty() { + return c.slice(0, 0); + } + let mut builder = <<C as ScalarVector>::Builder>::with_capacity(c.len()); + + let mut previous_offset = 0; + for (i, offset) in offsets.iter().enumerate() { + let data = c.get_data(i); + for _ in previous_offset..*offset { + builder.push(data); + } + previous_offset = *offset; + } + builder.to_vector() +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_time::timestamp::TimeUnit; + use common_time::{Date, DateTime, Timestamp}; + use paste::paste; + + use super::*; + use crate::vectors::constant::ConstantVector; + use crate::vectors::{Int32Vector, NullVector, StringVector, VectorOp}; + + #[test] + fn test_replicate_primitive() { + let v = Int32Vector::from_iterator(0..5); + let offsets = [0, 1, 2, 3, 4]; + + let v = v.replicate(&offsets); + assert_eq!(4, v.len()); + + for i in 0..4 { + assert_eq!(Value::Int32(i as i32 + 1), v.get(i)); + } + } + + #[test] + fn test_replicate_nullable_primitive() { + let v = Int32Vector::from(vec![None, Some(1), None, Some(2)]); + let offsets = [2, 4, 6, 8]; + let v = v.replicate(&offsets); + assert_eq!(8, v.len()); + + let expect: VectorRef = Arc::new(Int32Vector::from(vec![ + None, + None, + Some(1), + Some(1), + None, + None, + Some(2), + Some(2), + ])); + assert_eq!(expect, v); + } + + #[test] + fn test_replicate_scalar() { + let v = StringVector::from_slice(&["0", "1", "2", "3"]); + let offsets = [1, 3, 5, 6]; + + let v = v.replicate(&offsets); + assert_eq!(6, v.len()); + + let expect: VectorRef = Arc::new(StringVector::from_slice(&["0", "1", "1", "2", "2", "3"])); + assert_eq!(expect, v); + } + + #[test] + fn test_replicate_constant() { + let v = Arc::new(StringVector::from_slice(&["hello"])); + let cv = ConstantVector::new(v.clone(), 2); + let offsets = [1, 4]; + + let cv = cv.replicate(&offsets); + assert_eq!(4, cv.len()); + + let expect: VectorRef = Arc::new(ConstantVector::new(v, 4)); + assert_eq!(expect, cv); + } + + #[test] + fn test_replicate_null() { + let v = NullVector::new(0); + let offsets = []; + let v = v.replicate(&offsets); + assert!(v.is_empty()); + + let v = NullVector::new(3); + let offsets = [1, 3, 5]; + + let v = v.replicate(&offsets); + assert_eq!(5, v.len()); + } + + macro_rules! impl_replicate_date_like_test { + ($VectorType: ident, $ValueType: ident, $method: ident) => {{ + use $crate::vectors::$VectorType; + + let v = $VectorType::from_iterator((0..5).map($ValueType::$method)); + let offsets = [0, 1, 2, 3, 4]; + + let v = v.replicate(&offsets); + assert_eq!(4, v.len()); + + for i in 0..4 { + assert_eq!( + Value::$ValueType($ValueType::$method((i as i32 + 1).into())), + v.get(i) + ); + } + }}; + } + + macro_rules! 
impl_replicate_timestamp_test { + ($unit: ident) => {{ + paste!{ + use $crate::vectors::[<Timestamp $unit Vector>]; + use $crate::timestamp::[<Timestamp $unit>]; + let v = [<Timestamp $unit Vector>]::from_iterator((0..5).map([<Timestamp $unit>]::from)); + let offsets = [0, 1, 2, 3, 4]; + let v = v.replicate(&offsets); + assert_eq!(4, v.len()); + for i in 0..4 { + assert_eq!( + Value::Timestamp(Timestamp::new(i as i64 + 1, TimeUnit::$unit)), + v.get(i) + ); + } + } + }}; + } + + #[test] + fn test_replicate_date_like() { + impl_replicate_date_like_test!(DateVector, Date, new); + impl_replicate_date_like_test!(DateTimeVector, DateTime, new); + + impl_replicate_timestamp_test!(Second); + impl_replicate_timestamp_test!(Millisecond); + impl_replicate_timestamp_test!(Microsecond); + impl_replicate_timestamp_test!(Nanosecond); + } +} diff --git a/src/datatypes2/src/vectors/primitive.rs b/src/datatypes2/src/vectors/primitive.rs new file mode 100644 index 0000000000..7829c31731 --- /dev/null +++ b/src/datatypes2/src/vectors/primitive.rs @@ -0,0 +1,552 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +use arrow::array::{ + Array, ArrayBuilder, ArrayData, ArrayIter, ArrayRef, PrimitiveArray, PrimitiveBuilder, +}; +use serde_json::Value as JsonValue; +use snafu::OptionExt; + +use crate::data_type::ConcreteDataType; +use crate::error::{self, Result}; +use crate::scalars::{Scalar, ScalarRef, ScalarVector, ScalarVectorBuilder}; +use crate::serialize::Serializable; +use crate::types::{ + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LogicalPrimitiveType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, WrapperType, +}; +use crate::value::{Value, ValueRef}; +use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef}; + +pub type UInt8Vector = PrimitiveVector<UInt8Type>; +pub type UInt16Vector = PrimitiveVector<UInt16Type>; +pub type UInt32Vector = PrimitiveVector<UInt32Type>; +pub type UInt64Vector = PrimitiveVector<UInt64Type>; + +pub type Int8Vector = PrimitiveVector<Int8Type>; +pub type Int16Vector = PrimitiveVector<Int16Type>; +pub type Int32Vector = PrimitiveVector<Int32Type>; +pub type Int64Vector = PrimitiveVector<Int64Type>; + +pub type Float32Vector = PrimitiveVector<Float32Type>; +pub type Float64Vector = PrimitiveVector<Float64Type>; + +/// Vector for primitive data types. +pub struct PrimitiveVector<T: LogicalPrimitiveType> { + array: PrimitiveArray<T::ArrowPrimitive>, +} + +impl<T: LogicalPrimitiveType> PrimitiveVector<T> { + pub fn new(array: PrimitiveArray<T::ArrowPrimitive>) -> Self { + Self { array } + } + + pub fn try_from_arrow_array(array: impl AsRef<dyn Array>) -> Result<Self> { + let data = array + .as_ref() + .as_any() + .downcast_ref::<PrimitiveArray<T::ArrowPrimitive>>() + .with_context(|| error::ConversionSnafu { + from: format!("{:?}", array.as_ref().data_type()), + })?
+ .data() + .clone(); + let concrete_array = PrimitiveArray::<T::ArrowPrimitive>::from(data); + Ok(Self::new(concrete_array)) + } + + pub fn from_slice<P: AsRef<[T::Native]>>(slice: P) -> Self { + let iter = slice.as_ref().iter().copied(); + Self { + array: PrimitiveArray::from_iter_values(iter), + } + } + + pub fn from_wrapper_slice<P: AsRef<[T::Wrapper]>>(slice: P) -> Self { + let iter = slice.as_ref().iter().copied().map(WrapperType::into_native); + Self { + array: PrimitiveArray::from_iter_values(iter), + } + } + + pub fn from_vec(array: Vec<T::Native>) -> Self { + Self { + array: PrimitiveArray::from_iter_values(array), + } + } + + pub fn from_values<I: IntoIterator<Item = T::Native>>(iter: I) -> Self { + Self { + array: PrimitiveArray::from_iter_values(iter), + } + } + + pub(crate) fn as_arrow(&self) -> &PrimitiveArray<T::ArrowPrimitive> { + &self.array + } + + fn to_array_data(&self) -> ArrayData { + self.array.data().clone() + } + + fn from_array_data(data: ArrayData) -> Self { + Self { + array: PrimitiveArray::from(data), + } + } + + // To distinguish from `Vector::slice()`. + fn get_slice(&self, offset: usize, length: usize) -> Self { + let data = self.array.data().slice(offset, length); + Self::from_array_data(data) + } +} + +impl<T: LogicalPrimitiveType> Vector for PrimitiveVector<T> { + fn data_type(&self) -> ConcreteDataType { + T::build_data_type() + } + + fn vector_type_name(&self) -> String { + format!("{}Vector", T::type_name()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.array.len() + } + + fn to_arrow_array(&self) -> ArrayRef { + let data = self.to_array_data(); + Arc::new(PrimitiveArray::<T::ArrowPrimitive>::from(data)) + } + + fn to_boxed_arrow_array(&self) -> Box<dyn Array> { + let data = self.to_array_data(); + Box::new(PrimitiveArray::<T::ArrowPrimitive>::from(data)) + } + + fn validity(&self) -> Validity { + vectors::impl_validity_for_vector!(self.array) + } + + fn memory_size(&self) -> usize { + self.array.get_buffer_memory_size() + } + + fn null_count(&self) -> usize { + self.array.null_count() + } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + let data = self.array.data().slice(offset, length); + Arc::new(Self::from_array_data(data)) + } + + fn get(&self, index: usize) -> Value { + if self.array.is_valid(index) { + // Safety: The index has been checked by `is_valid()`. + let wrapper = unsafe { T::Wrapper::from_native(self.array.value_unchecked(index)) }; + wrapper.into() + } else { + Value::Null + } + } + + fn get_ref(&self, index: usize) -> ValueRef { + if self.array.is_valid(index) { + // Safety: The index has been checked by `is_valid()`. 
+ let wrapper = unsafe { T::Wrapper::from_native(self.array.value_unchecked(index)) }; + wrapper.into() + } else { + ValueRef::Null + } + } +} + +impl<T: LogicalPrimitiveType> fmt::Debug for PrimitiveVector<T> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("PrimitiveVector") + .field("array", &self.array) + .finish() + } +} + +impl<T: LogicalPrimitiveType> From<PrimitiveArray<T::ArrowPrimitive>> for PrimitiveVector<T> { + fn from(array: PrimitiveArray<T::ArrowPrimitive>) -> Self { + Self { array } + } +} + +impl<T: LogicalPrimitiveType> From<Vec<Option<T::Native>>> for PrimitiveVector<T> { + fn from(v: Vec<Option<T::Native>>) -> Self { + Self { + array: PrimitiveArray::from_iter(v), + } + } +} + +pub struct PrimitiveIter<'a, T: LogicalPrimitiveType> { + iter: ArrayIter<&'a PrimitiveArray<T::ArrowPrimitive>>, +} + +impl<'a, T: LogicalPrimitiveType> Iterator for PrimitiveIter<'a, T> { + type Item = Option<T::Wrapper>; + + fn next(&mut self) -> Option<Option<T::Wrapper>> { + self.iter + .next() + .map(|item| item.map(T::Wrapper::from_native)) + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.iter.size_hint() + } +} + +impl<T: LogicalPrimitiveType> ScalarVector for PrimitiveVector<T> { + type OwnedItem = T::Wrapper; + type RefItem<'a> = T::Wrapper; + type Iter<'a> = PrimitiveIter<'a, T>; + type Builder = PrimitiveVectorBuilder<T>; + + fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>> { + if self.array.is_valid(idx) { + Some(T::Wrapper::from_native(self.array.value(idx))) + } else { + None + } + } + + fn iter_data(&self) -> Self::Iter<'_> { + PrimitiveIter { + iter: self.array.iter(), + } + } +} + +impl<T: LogicalPrimitiveType> Serializable for PrimitiveVector<T> { + fn serialize_to_json(&self) -> Result<Vec<JsonValue>> { + let res = self + .iter_data() + .map(|v| match v { + None => serde_json::Value::Null, + // use WrapperType's Into<JsonValue> bound instead of + // serde_json::to_value to facilitate customized serialization + // for WrapperType + Some(v) => v.into(), + }) + .collect::<Vec<_>>(); + Ok(res) + } +} + +impl<T: LogicalPrimitiveType> PartialEq for PrimitiveVector<T> { + fn eq(&self, other: &PrimitiveVector<T>) -> bool { + self.array == other.array + } +} + +pub type UInt8VectorBuilder = PrimitiveVectorBuilder<UInt8Type>; +pub type UInt16VectorBuilder = PrimitiveVectorBuilder<UInt16Type>; +pub type UInt32VectorBuilder = PrimitiveVectorBuilder<UInt32Type>; +pub type UInt64VectorBuilder = PrimitiveVectorBuilder<UInt64Type>; + +pub type Int8VectorBuilder = PrimitiveVectorBuilder<Int8Type>; +pub type Int16VectorBuilder = PrimitiveVectorBuilder<Int16Type>; +pub type Int32VectorBuilder = PrimitiveVectorBuilder<Int32Type>; +pub type Int64VectorBuilder = PrimitiveVectorBuilder<Int64Type>; + +pub type Float32VectorBuilder = PrimitiveVectorBuilder<Float32Type>; +pub type Float64VectorBuilder = PrimitiveVectorBuilder<Float64Type>; + +/// Builder to build a primitive vector. +pub struct PrimitiveVectorBuilder<T: LogicalPrimitiveType> { + mutable_array: PrimitiveBuilder<T::ArrowPrimitive>, +} + +impl<T: LogicalPrimitiveType> MutableVector for PrimitiveVectorBuilder<T> { + fn data_type(&self) -> ConcreteDataType { + T::build_data_type() + } + + fn len(&self) -> usize { + self.mutable_array.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(self.finish()) + } + + fn push_value_ref(&mut self, value: ValueRef) -> Result<()> { + let primitive = T::cast_value_ref(value)?; + match primitive { + Some(v) => self.mutable_array.append_value(v.into_native()), + None => self.mutable_array.append_null(), + } + Ok(()) + } + + fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> { + let primitive = T::cast_vector(vector)?; + // Slice the underlying array to avoid creating a new Arc. 
+ let slice = primitive.get_slice(offset, length); + for v in slice.iter_data() { + self.push(v); + } + Ok(()) + } +} + +impl<T> ScalarVectorBuilder for PrimitiveVectorBuilder<T> +where + T: LogicalPrimitiveType, + T::Wrapper: Scalar<VectorType = PrimitiveVector<T>>, + for<'a> T::Wrapper: ScalarRef<'a, ScalarType = T::Wrapper>, + for<'a> T::Wrapper: Scalar<RefType<'a> = T::Wrapper>, +{ + type VectorType = PrimitiveVector<T>; + + fn with_capacity(capacity: usize) -> Self { + Self { + mutable_array: PrimitiveBuilder::with_capacity(capacity), + } + } + + fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>) { + self.mutable_array + .append_option(value.map(|v| v.into_native())); + } + + fn finish(&mut self) -> Self::VectorType { + PrimitiveVector { + array: self.mutable_array.finish(), + } + } +} + +pub(crate) fn replicate_primitive<T: LogicalPrimitiveType>( + vector: &PrimitiveVector<T>, + offsets: &[usize], +) -> PrimitiveVector<T> { + assert_eq!(offsets.len(), vector.len()); + + if offsets.is_empty() { + return vector.get_slice(0, 0); + } + + let mut builder = PrimitiveVectorBuilder::<T>::with_capacity(*offsets.last().unwrap() as usize); + + let mut previous_offset = 0; + + for (offset, value) in offsets.iter().zip(vector.array.iter()) { + let repeat_times = *offset - previous_offset; + match value { + Some(data) => { + unsafe { + // Safety: std::iter::Repeat and std::iter::Take implement TrustedLen. + builder + .mutable_array + .append_trusted_len_iter(std::iter::repeat(data).take(repeat_times)); + } + } + None => { + builder.mutable_array.append_nulls(repeat_times); + } + } + previous_offset = *offset; + } + builder.finish() +} + +#[cfg(test)] +mod tests { + use arrow::array::Int32Array; + use arrow::datatypes::DataType as ArrowDataType; + use serde_json; + + use super::*; + use crate::data_type::DataType; + use crate::serialize::Serializable; + use crate::types::Int64Type; + + fn check_vec(v: Int32Vector) { + assert_eq!(4, v.len()); + assert_eq!("Int32Vector", v.vector_type_name()); + assert!(!v.is_const()); + assert!(v.validity().is_all_valid()); + assert!(!v.only_null()); + + for i in 0..4 { + assert!(!v.is_null(i)); + assert_eq!(Value::Int32(i as i32 + 1), v.get(i)); + assert_eq!(ValueRef::Int32(i as i32 + 1), v.get_ref(i)); + } + + let json_value = v.serialize_to_json().unwrap(); + assert_eq!("[1,2,3,4]", serde_json::to_string(&json_value).unwrap(),); + + let arrow_arr = v.to_arrow_array(); + assert_eq!(4, arrow_arr.len()); + assert_eq!(&ArrowDataType::Int32, arrow_arr.data_type()); + } + + #[test] + fn test_from_values() { + let v = Int32Vector::from_values(vec![1, 2, 3, 4]); + check_vec(v); + } + + #[test] + fn test_from_vec() { + let v = Int32Vector::from_vec(vec![1, 2, 3, 4]); + check_vec(v); + } + + #[test] + fn test_from_slice() { + let v = Int32Vector::from_slice(vec![1, 2, 3, 4]); + check_vec(v); + } + + #[test] + fn test_serialize_primitive_vector_with_null_to_json() { + let input = [Some(1i32), Some(2i32), None, Some(4i32), None]; + let mut builder = Int32VectorBuilder::with_capacity(input.len()); + for v in input { + builder.push(v); + } + let vector = builder.finish(); + + let json_value = vector.serialize_to_json().unwrap(); + assert_eq!( + "[1,2,null,4,null]", + serde_json::to_string(&json_value).unwrap(), + ); + } + + #[test] + fn test_from_arrow_array() { + let arrow_array = Int32Array::from(vec![1, 2, 3, 4]); + let v = Int32Vector::from(arrow_array); + check_vec(v); + } + + #[test] + fn test_primitive_vector_build_get() { + let input = [Some(1i32), Some(2i32), None, Some(4i32), None]; + let mut builder = Int32VectorBuilder::with_capacity(input.len()); + 
for v in input { + builder.push(v); + } + let vector = builder.finish(); + assert_eq!(input.len(), vector.len()); + + for (i, v) in input.into_iter().enumerate() { + assert_eq!(v, vector.get_data(i)); + assert_eq!(Value::from(v), vector.get(i)); + } + + let res: Vec<_> = vector.iter_data().collect(); + assert_eq!(input, &res[..]); + } + + #[test] + fn test_primitive_vector_validity() { + let input = [Some(1i32), Some(2i32), None, None]; + let mut builder = Int32VectorBuilder::with_capacity(input.len()); + for v in input { + builder.push(v); + } + let vector = builder.finish(); + assert_eq!(2, vector.null_count()); + let validity = vector.validity(); + assert_eq!(2, validity.null_count()); + assert!(!validity.is_set(2)); + assert!(!validity.is_set(3)); + + let vector = Int32Vector::from_slice(vec![1, 2, 3, 4]); + assert_eq!(0, vector.null_count()); + assert!(vector.validity().is_all_valid()); + } + + #[test] + fn test_memory_size() { + let v = Int32Vector::from_slice((0..5).collect::<Vec<i32>>()); + assert_eq!(64, v.memory_size()); + let v = Int64Vector::from(vec![Some(0i64), Some(1i64), Some(2i64), None, None]); + assert_eq!(128, v.memory_size()); + } + + #[test] + fn test_primitive_vector_builder() { + let mut builder = Int64Type::default().create_mutable_vector(3); + builder.push_value_ref(ValueRef::Int64(123)).unwrap(); + assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err()); + + let input = Int64Vector::from_slice(&[7, 8, 9]); + builder.extend_slice_of(&input, 1, 2).unwrap(); + assert!(builder + .extend_slice_of(&Int32Vector::from_slice(&[13]), 0, 1) + .is_err()); + let vector = builder.to_vector(); + + let expect: VectorRef = Arc::new(Int64Vector::from_slice(&[123, 8, 9])); + assert_eq!(expect, vector); + } + + #[test] + fn test_from_wrapper_slice() { + macro_rules! test_from_wrapper_slice { + ($vec: ident, $ty: ident) => { + let from_wrapper_slice = $vec::from_wrapper_slice(&[ + $ty::from_native($ty::MAX), + $ty::from_native($ty::MIN), + ]); + let from_slice = $vec::from_slice(&[$ty::MAX, $ty::MIN]); + assert_eq!(from_wrapper_slice, from_slice); + }; + } + + test_from_wrapper_slice!(UInt8Vector, u8); + test_from_wrapper_slice!(Int8Vector, i8); + test_from_wrapper_slice!(UInt16Vector, u16); + test_from_wrapper_slice!(Int16Vector, i16); + test_from_wrapper_slice!(UInt32Vector, u32); + test_from_wrapper_slice!(Int32Vector, i32); + test_from_wrapper_slice!(UInt64Vector, u64); + test_from_wrapper_slice!(Int64Vector, i64); + test_from_wrapper_slice!(Float32Vector, f32); + test_from_wrapper_slice!(Float64Vector, f64); + } +} diff --git a/src/datatypes2/src/vectors/string.rs b/src/datatypes2/src/vectors/string.rs new file mode 100644 index 0000000000..252116b3b2 --- /dev/null +++ b/src/datatypes2/src/vectors/string.rs @@ -0,0 +1,370 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
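A note on the builder flow used throughout these vector files: every builder follows the same `ScalarVectorBuilder` sequence of `with_capacity`, repeated `push(Option<...>)` where `None` encodes a null slot, then `finish` (or `to_vector` via `MutableVector`). A hedged usage sketch, assuming the `datatypes2` crate and module paths introduced in this diff:

use datatypes2::scalars::{ScalarVector, ScalarVectorBuilder};
use datatypes2::vectors::Int32VectorBuilder;

fn build_example() {
    // Capacity is only a hint; pushes beyond it still work.
    let mut builder = Int32VectorBuilder::with_capacity(3);
    builder.push(Some(1));
    builder.push(None); // null slot
    builder.push(Some(3));
    let vector = builder.finish();
    assert_eq!(None, vector.get_data(1));
    assert_eq!(Some(3), vector.get_data(2));
}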
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{Array, ArrayBuilder, ArrayData, ArrayIter, ArrayRef}; +use snafu::ResultExt; + +use crate::arrow_array::{MutableStringArray, StringArray}; +use crate::data_type::ConcreteDataType; +use crate::error::{self, Result}; +use crate::scalars::{ScalarVector, ScalarVectorBuilder}; +use crate::serialize::Serializable; +use crate::value::{Value, ValueRef}; +use crate::vectors::{self, MutableVector, Validity, Vector, VectorRef}; + +/// Vector of strings. +#[derive(Debug, PartialEq)] +pub struct StringVector { + array: StringArray, +} + +impl StringVector { + pub(crate) fn as_arrow(&self) -> &dyn Array { + &self.array + } + + fn to_array_data(&self) -> ArrayData { + self.array.data().clone() + } + + fn from_array_data(data: ArrayData) -> Self { + Self { + array: StringArray::from(data), + } + } +} + +impl From<StringArray> for StringVector { + fn from(array: StringArray) -> Self { + Self { array } + } +} + +impl From<Vec<Option<String>>> for StringVector { + fn from(data: Vec<Option<String>>) -> Self { + Self { + array: StringArray::from_iter(data), + } + } +} + +impl From<Vec<Option<&str>>> for StringVector { + fn from(data: Vec<Option<&str>>) -> Self { + Self { + array: StringArray::from_iter(data), + } + } +} + +impl From<&[Option<String>]> for StringVector { + fn from(data: &[Option<String>]) -> Self { + Self { + array: StringArray::from_iter(data), + } + } +} + +impl From<&[Option<&str>]> for StringVector { + fn from(data: &[Option<&str>]) -> Self { + Self { + array: StringArray::from_iter(data), + } + } +} + +impl From<Vec<String>> for StringVector { + fn from(data: Vec<String>) -> Self { + Self { + array: StringArray::from_iter(data.into_iter().map(Some)), + } + } +} + +impl From<Vec<&str>> for StringVector { + fn from(data: Vec<&str>) -> Self { + Self { + array: StringArray::from_iter(data.into_iter().map(Some)), + } + } +} + +impl Vector for StringVector { + fn data_type(&self) -> ConcreteDataType { + ConcreteDataType::string_datatype() + } + + fn vector_type_name(&self) -> String { + "StringVector".to_string() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn len(&self) -> usize { + self.array.len() + } + + fn to_arrow_array(&self) -> ArrayRef { + let data = self.to_array_data(); + Arc::new(StringArray::from(data)) + } + + fn to_boxed_arrow_array(&self) -> Box<dyn Array> { + let data = self.to_array_data(); + Box::new(StringArray::from(data)) + } + + fn validity(&self) -> Validity { + vectors::impl_validity_for_vector!(self.array) + } + + fn memory_size(&self) -> usize { + self.array.get_buffer_memory_size() + } + + fn null_count(&self) -> usize { + self.array.null_count() + } + + fn is_null(&self, row: usize) -> bool { + self.array.is_null(row) + } + + fn slice(&self, offset: usize, length: usize) -> VectorRef { + let data = self.array.data().slice(offset, length); + Arc::new(Self::from_array_data(data)) + } + + fn get(&self, index: usize) -> Value { + vectors::impl_get_for_vector!(self.array, index) + } + + fn get_ref(&self, index: usize) -> ValueRef { + vectors::impl_get_ref_for_vector!(self.array, index) + } +} + +impl ScalarVector for StringVector { + type OwnedItem = String; + type RefItem<'a> = &'a str; + type Iter<'a> = ArrayIter<&'a StringArray>; + type Builder = StringVectorBuilder; + + fn get_data(&self, idx: usize) -> Option<Self::RefItem<'_>> { + if self.array.is_valid(idx) { + Some(self.array.value(idx)) + } else { + None + } + } + + fn iter_data(&self) -> Self::Iter<'_> { + self.array.iter() + } +} + +pub struct StringVectorBuilder { + mutable_array: MutableStringArray, +} + +impl MutableVector for StringVectorBuilder { + fn data_type(&self) -> 
ConcreteDataType { + ConcreteDataType::string_datatype() + } + + fn len(&self) -> usize { + self.mutable_array.len() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn to_vector(&mut self) -> VectorRef { + Arc::new(self.finish()) + } + + fn push_value_ref(&mut self, value: ValueRef) -> Result<()> { + match value.as_string()? { + Some(v) => self.mutable_array.append_value(v), + None => self.mutable_array.append_null(), + } + Ok(()) + } + + fn extend_slice_of(&mut self, vector: &dyn Vector, offset: usize, length: usize) -> Result<()> { + vectors::impl_extend_for_builder!(self, vector, StringVector, offset, length) + } +} + +impl ScalarVectorBuilder for StringVectorBuilder { + type VectorType = StringVector; + + fn with_capacity(capacity: usize) -> Self { + Self { + mutable_array: MutableStringArray::with_capacity(capacity, 0), + } + } + + fn push(&mut self, value: Option<<Self::VectorType as ScalarVector>::RefItem<'_>>) { + match value { + Some(v) => self.mutable_array.append_value(v), + None => self.mutable_array.append_null(), + } + } + + fn finish(&mut self) -> Self::VectorType { + StringVector { + array: self.mutable_array.finish(), + } + } +} + +impl Serializable for StringVector { + fn serialize_to_json(&self) -> Result<Vec<serde_json::Value>> { + self.iter_data() + .map(serde_json::to_value) + .collect::<serde_json::Result<Vec<_>>>() + .context(error::SerializeSnafu) + } +} + +vectors::impl_try_from_arrow_array_for_vector!(StringArray, StringVector); + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType; + + use super::*; + + #[test] + fn test_string_vector_build_get() { + let mut builder = StringVectorBuilder::with_capacity(4); + builder.push(Some("hello")); + builder.push(None); + builder.push(Some("world")); + let vector = builder.finish(); + + assert_eq!(Some("hello"), vector.get_data(0)); + assert_eq!(None, vector.get_data(1)); + assert_eq!(Some("world"), vector.get_data(2)); + + // Get out of bound + assert!(vector.try_get(3).is_err()); + + assert_eq!(Value::String("hello".into()), vector.get(0)); + assert_eq!(Value::Null, vector.get(1)); + assert_eq!(Value::String("world".into()), vector.get(2)); + + let mut iter = vector.iter_data(); + assert_eq!("hello", iter.next().unwrap().unwrap()); + assert_eq!(None, iter.next().unwrap()); + assert_eq!("world", iter.next().unwrap().unwrap()); + assert_eq!(None, iter.next()); + } + + #[test] + fn test_string_vector_builder() { + let mut builder = StringVectorBuilder::with_capacity(3); + builder.push_value_ref(ValueRef::String("hello")).unwrap(); + assert!(builder.push_value_ref(ValueRef::Int32(123)).is_err()); + + let input = StringVector::from_slice(&["world", "one", "two"]); + builder.extend_slice_of(&input, 1, 2).unwrap(); + assert!(builder + .extend_slice_of(&crate::vectors::Int32Vector::from_slice(&[13]), 0, 1) + .is_err()); + let vector = builder.to_vector(); + + let expect: VectorRef = Arc::new(StringVector::from_slice(&["hello", "one", "two"])); + assert_eq!(expect, vector); + } + + #[test] + fn test_string_vector_misc() { + let strs = vec!["hello", "greptime", "rust"]; + let v = StringVector::from(strs.clone()); + assert_eq!(3, v.len()); + assert_eq!("StringVector", v.vector_type_name()); + assert!(!v.is_const()); + assert!(v.validity().is_all_valid()); + assert!(!v.only_null()); + assert_eq!(128, v.memory_size()); + + for (i, s) in strs.iter().enumerate() { + assert_eq!(Value::from(*s), v.get(i)); + assert_eq!(ValueRef::from(*s), v.get_ref(i)); + assert_eq!(Value::from(*s), v.try_get(i).unwrap()); + } + + let arrow_arr = v.to_arrow_array(); 
+ assert_eq!(3, arrow_arr.len()); + assert_eq!(&DataType::Utf8, arrow_arr.data_type()); + } + + #[test] + fn test_serialize_string_vector() { + let mut builder = StringVectorBuilder::with_capacity(3); + builder.push(Some("hello")); + builder.push(None); + builder.push(Some("world")); + let string_vector = builder.finish(); + let serialized = + serde_json::to_string(&string_vector.serialize_to_json().unwrap()).unwrap(); + assert_eq!(r#"["hello",null,"world"]"#, serialized); + } + + #[test] + fn test_from_arrow_array() { + let mut builder = MutableStringArray::new(); + builder.append_option(Some("A")); + builder.append_option(Some("B")); + builder.append_null(); + builder.append_option(Some("D")); + let string_array: StringArray = builder.finish(); + let vector = StringVector::from(string_array); + assert_eq!( + r#"["A","B",null,"D"]"#, + serde_json::to_string(&vector.serialize_to_json().unwrap()).unwrap(), + ); + } + + #[test] + fn test_from_non_option_string() { + let nul = String::from_utf8(vec![0]).unwrap(); + let corpus = vec!["😅😅😅", "😍😍😍😍", "🥵🥵", nul.as_str()]; + let vector = StringVector::from(corpus); + let serialized = serde_json::to_string(&vector.serialize_to_json().unwrap()).unwrap(); + assert_eq!(r#"["😅😅😅","😍😍😍😍","🥵🥵","\u0000"]"#, serialized); + + let corpus = vec![ + "🀀🀀🀀".to_string(), + "🀁🀁🀁".to_string(), + "🀂🀂🀂".to_string(), + "🀃🀃🀃".to_string(), + "🀆🀆".to_string(), + ]; + let vector = StringVector::from(corpus); + let serialized = serde_json::to_string(&vector.serialize_to_json().unwrap()).unwrap(); + assert_eq!(r#"["🀀🀀🀀","🀁🀁🀁","🀂🀂🀂","🀃🀃🀃","🀆🀆"]"#, serialized); + } +} diff --git a/src/datatypes2/src/vectors/timestamp.rs b/src/datatypes2/src/vectors/timestamp.rs new file mode 100644 index 0000000000..5d9f7f2ed1 --- /dev/null +++ b/src/datatypes2/src/vectors/timestamp.rs @@ -0,0 +1,31 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::types::{ + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, +}; +use crate::vectors::{PrimitiveVector, PrimitiveVectorBuilder}; + +pub type TimestampSecondVector = PrimitiveVector<TimestampSecondType>; +pub type TimestampSecondVectorBuilder = PrimitiveVectorBuilder<TimestampSecondType>; + +pub type TimestampMillisecondVector = PrimitiveVector<TimestampMillisecondType>; +pub type TimestampMillisecondVectorBuilder = PrimitiveVectorBuilder<TimestampMillisecondType>; + +pub type TimestampMicrosecondVector = PrimitiveVector<TimestampMicrosecondType>; +pub type TimestampMicrosecondVectorBuilder = PrimitiveVectorBuilder<TimestampMicrosecondType>; + +pub type TimestampNanosecondVector = PrimitiveVector<TimestampNanosecondType>; +pub type TimestampNanosecondVectorBuilder = PrimitiveVectorBuilder<TimestampNanosecondType>; diff --git a/src/datatypes2/src/vectors/validity.rs b/src/datatypes2/src/vectors/validity.rs new file mode 100644 index 0000000000..01c7faa789 --- /dev/null +++ b/src/datatypes2/src/vectors/validity.rs @@ -0,0 +1,159 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow::array::ArrayData; +use arrow::bitmap::Bitmap; + +#[derive(Debug, PartialEq)] +enum ValidityKind<'a> { + /// Whether the array slot is valid or not (null). + Slots { + bitmap: &'a Bitmap, + len: usize, + null_count: usize, + }, + /// All slots are valid. + AllValid { len: usize }, + /// All slots are null. + AllNull { len: usize }, +} + +/// Validity of a vector. +#[derive(Debug, PartialEq)] +pub struct Validity<'a> { + kind: ValidityKind<'a>, +} + +impl<'a> Validity<'a> { + /// Creates a `Validity` from [`ArrayData`]. + pub fn from_array_data(data: &'a ArrayData) -> Validity<'a> { + match data.null_bitmap() { + Some(bitmap) => Validity { + kind: ValidityKind::Slots { + bitmap, + len: data.len(), + null_count: data.null_count(), + }, + }, + None => Validity::all_valid(data.len()), + } + } + + /// Returns `Validity` that all elements are valid. + pub fn all_valid(len: usize) -> Validity<'a> { + Validity { + kind: ValidityKind::AllValid { len }, + } + } + + /// Returns `Validity` that all elements are null. + pub fn all_null(len: usize) -> Validity<'a> { + Validity { + kind: ValidityKind::AllNull { len }, + } + } + + /// Returns whether `i-th` bit is set. + pub fn is_set(&self, i: usize) -> bool { + match self.kind { + ValidityKind::Slots { bitmap, .. } => bitmap.is_set(i), + ValidityKind::AllValid { len } => i < len, + ValidityKind::AllNull { .. } => false, + } + } + + /// Returns true if all bits are null. + pub fn is_all_null(&self) -> bool { + match self.kind { + ValidityKind::Slots { + len, null_count, .. + } => len == null_count, + ValidityKind::AllValid { .. } => false, + ValidityKind::AllNull { .. } => true, + } + } + + /// Returns true if all bits are valid. + pub fn is_all_valid(&self) -> bool { + match self.kind { + ValidityKind::Slots { null_count, .. } => null_count == 0, + ValidityKind::AllValid { .. } => true, + ValidityKind::AllNull { .. } => false, + } + } + + /// The number of null slots on this [`Vector`]. + pub fn null_count(&self) -> usize { + match self.kind { + ValidityKind::Slots { null_count, .. } => null_count, + ValidityKind::AllValid { .. 
} => 0, + ValidityKind::AllNull { len } => len, + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, Int32Array}; + + use super::*; + + #[test] + fn test_all_valid() { + let validity = Validity::all_valid(5); + assert!(validity.is_all_valid()); + assert!(!validity.is_all_null()); + assert_eq!(0, validity.null_count()); + for i in 0..5 { + assert!(validity.is_set(i)); + } + assert!(!validity.is_set(5)); + } + + #[test] + fn test_all_null() { + let validity = Validity::all_null(5); + assert!(validity.is_all_null()); + assert!(!validity.is_all_valid()); + assert_eq!(5, validity.null_count()); + for i in 0..5 { + assert!(!validity.is_set(i)); + } + assert!(!validity.is_set(5)); + } + + #[test] + fn test_from_array_data() { + let array = Int32Array::from_iter([None, Some(1), None]); + let validity = Validity::from_array_data(array.data()); + assert_eq!(2, validity.null_count()); + assert!(!validity.is_set(0)); + assert!(validity.is_set(1)); + assert!(!validity.is_set(2)); + assert!(!validity.is_all_null()); + assert!(!validity.is_all_valid()); + + let array = Int32Array::from_iter([None, None]); + let validity = Validity::from_array_data(array.data()); + assert!(validity.is_all_null()); + assert!(!validity.is_all_valid()); + assert_eq!(2, validity.null_count()); + + let array = Int32Array::from_iter_values([1, 2]); + let validity = Validity::from_array_data(array.data()); + assert!(!validity.is_all_null()); + assert!(validity.is_all_valid()); + assert_eq!(0, validity.null_count()); + } +} diff --git a/src/frontend/Cargo.toml b/src/frontend/Cargo.toml index 69cd4d2861..6561d62460 100644 --- a/src/frontend/Cargo.toml +++ b/src/frontend/Cargo.toml @@ -12,11 +12,12 @@ catalog = { path = "../catalog" } chrono = "0.4" client = { path = "../client" } common-base = { path = "../common/base" } +common-catalog = { path = "../common/catalog" } common-error = { path = "../common/error" } common-grpc = { path = "../common/grpc" } +common-grpc-expr = { path = "../common/grpc-expr" } common-query = { path = "../common/query" } common-recordbatch = { path = "../common/recordbatch" } -common-catalog = { path = "../common/catalog" } common-runtime = { path = "../common/runtime" } common-telemetry = { path = "../common/telemetry" } common-time = { path = "../common/time" } @@ -33,12 +34,14 @@ moka = { version = "0.9", features = ["future"] } openmetrics-parser = "0.4" prost = "0.11" query = { path = "../query" } +rustls = "0.20" serde = "1.0" serde_json = "1.0" -sqlparser = "0.15" servers = { path = "../servers" } +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } +sqlparser = "0.15" store-api = { path = "../store-api" } table = { path = "../table" } tokio = { version = "1.18", features = ["full"] } diff --git a/src/frontend/src/catalog.rs b/src/frontend/src/catalog.rs index 2e5d7b64d4..86356db08c 100644 --- a/src/frontend/src/catalog.rs +++ b/src/frontend/src/catalog.rs @@ -19,8 +19,9 @@ use std::sync::Arc; use catalog::error::{self as catalog_err, InvalidCatalogValueSnafu}; use catalog::remote::{Kv, KvBackendRef}; use catalog::{ - CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, RegisterSchemaRequest, - RegisterSystemTableRequest, RegisterTableRequest, SchemaProvider, SchemaProviderRef, + CatalogList, CatalogManager, CatalogProvider, CatalogProviderRef, DeregisterTableRequest, + RegisterSchemaRequest, RegisterSystemTableRequest, RegisterTableRequest, SchemaProvider, + SchemaProviderRef, }; use 
common_catalog::{CatalogKey, SchemaKey, TableGlobalKey, TableGlobalValue}; use futures::StreamExt; @@ -65,17 +66,21 @@ impl CatalogManager for FrontendCatalogManager { Ok(()) } - async fn register_table( + async fn register_table(&self, _request: RegisterTableRequest) -> catalog::error::Result<bool> { + unimplemented!() + } + + async fn deregister_table( &self, - _request: RegisterTableRequest, - ) -> catalog::error::Result<usize> { + _request: DeregisterTableRequest, + ) -> catalog::error::Result<bool> { unimplemented!() } async fn register_schema( &self, _request: RegisterSchemaRequest, - ) -> catalog::error::Result<usize> { + ) -> catalog::error::Result<bool> { unimplemented!() } @@ -273,8 +278,7 @@ impl SchemaProvider for FrontendSchemaProvider { } Some(r) => r, }; - let val = TableGlobalValue::parse(String::from_utf8_lossy(&res.1)) - .context(InvalidCatalogValueSnafu)?; + let val = TableGlobalValue::from_bytes(&res.1).context(InvalidCatalogValueSnafu)?; let table = Arc::new(DistTable::new( table_name, diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 9a23a2320a..9b6275c7bf 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -250,6 +250,12 @@ pub enum Error { source: client::Error, }, + #[snafu(display("Failed to drop table, source: {}", source))] + DropTable { + #[snafu(backtrace)] + source: client::Error, + }, + #[snafu(display("Failed to insert values to table, source: {}", source))] Insert { #[snafu(backtrace)] @@ -277,25 +283,25 @@ pub enum Error { #[snafu(display("Failed to build CreateExpr on insertion: {}", source))] BuildCreateExprOnInsertion { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Failed to find new columns on insertion: {}", source))] FindNewColumnsOnInsertion { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Failed to deserialize insert batching: {}", source))] DeserializeInsertBatch { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Failed to deserialize insert batching: {}", source))] InsertBatchToRequest { #[snafu(backtrace)] - source: common_insert::error::Error, + source: common_grpc_expr::error::Error, }, #[snafu(display("Failed to find catalog by name: {}", catalog_name))] @@ -427,6 +433,18 @@ pub enum Error { #[snafu(display("Missing meta_client_opts section in config"))] MissingMetasrvOpts { backtrace: Backtrace }, + + #[snafu(display("Failed to convert AlterExpr to AlterRequest, source: {}", source))] + AlterExprToRequest { + #[snafu(backtrace)] + source: common_grpc_expr::error::Error, + }, + + #[snafu(display("Failed to find leaders when altering table, table: {}", table))] + LeaderNotFound { table: String, backtrace: Backtrace }, + + #[snafu(display("Table already exists: `{}`", table))] + TableAlreadyExist { table: String, backtrace: Backtrace }, } pub type Result = std::result::Result; @@ -497,23 +515,27 @@ impl ErrorExt for Error { Error::BumpTableId { source, .. } => source.status_code(), Error::SchemaNotFound { .. } => StatusCode::InvalidArguments, Error::CatalogNotFound { .. } => StatusCode::InvalidArguments, - Error::CreateTable { source, .. } => source.status_code(), - Error::AlterTable { source, .. } => source.status_code(), - Error::Insert { source, .. } => source.status_code(), + Error::CreateTable { source, .. } + | Error::AlterTable { source, .. } + | Error::DropTable { source } + | Error::Select { source, ..
} + | Error::CreateDatabase { source, .. } + | Error::CreateTableOnInsertion { source, .. } + | Error::AlterTableOnInsertion { source, .. } + | Error::Insert { source, .. } => source.status_code(), Error::BuildCreateExprOnInsertion { source, .. } => source.status_code(), - Error::CreateTableOnInsertion { source, .. } => source.status_code(), - Error::AlterTableOnInsertion { source, .. } => source.status_code(), - Error::Select { source, .. } => source.status_code(), Error::FindNewColumnsOnInsertion { source, .. } => source.status_code(), Error::DeserializeInsertBatch { source, .. } => source.status_code(), Error::PrimaryKeyNotFound { .. } => StatusCode::InvalidArguments, Error::ExecuteSql { source, .. } => source.status_code(), Error::InsertBatchToRequest { source, .. } => source.status_code(), - Error::CreateDatabase { source, .. } => source.status_code(), Error::CollectRecordbatchStream { source } | Error::CreateRecordbatches { source } => { source.status_code() } Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments, + Error::AlterExprToRequest { source, .. } => source.status_code(), + Error::LeaderNotFound { .. } => StatusCode::StorageUnavailable, + Error::TableAlreadyExist { .. } => StatusCode::TableAlreadyExists, } } diff --git a/src/frontend/src/expr_factory.rs b/src/frontend/src/expr_factory.rs index c8c7646c2c..9f406ace0b 100644 --- a/src/frontend/src/expr_factory.rs +++ b/src/frontend/src/expr_factory.rs @@ -16,8 +16,7 @@ use std::collections::HashMap; use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; -use api::v1::codec::InsertBatch; -use api::v1::{ColumnDataType, CreateExpr}; +use api::v1::{Column, ColumnDataType, CreateExpr}; use datatypes::schema::ColumnSchema; use snafu::{ensure, ResultExt}; use sql::statements::create::{CreateTable, TIME_INDEX}; @@ -35,12 +34,12 @@ pub type CreateExprFactoryRef = Arc<dyn CreateExprFactory + Send + Sync>; pub trait CreateExprFactory { async fn create_expr_by_stmt(&self, stmt: &CreateTable) -> Result<CreateExpr>; - async fn create_expr_by_insert_batch( + async fn create_expr_by_columns( &self, catalog_name: &str, schema_name: &str, table_name: &str, - batch: &[InsertBatch], + columns: &[Column], ) -> crate::error::Result<CreateExpr>; } @@ -53,20 +52,20 @@ impl CreateExprFactory for DefaultCreateExprFactory { create_to_expr(None, vec![0], stmt) } - async fn create_expr_by_insert_batch( + async fn create_expr_by_columns( &self, catalog_name: &str, schema_name: &str, table_name: &str, - batch: &[InsertBatch], + columns: &[Column], ) -> Result<CreateExpr> { let table_id = None; let create_expr = common_grpc_expr::build_create_expr_from_insertion( catalog_name, schema_name, table_id, table_name, columns, ) .context(BuildCreateExprOnInsertionSnafu)?; diff --git a/src/frontend/src/frontend.rs b/src/frontend/src/frontend.rs index 8a5a538449..521ed6c834 100644 --- a/src/frontend/src/frontend.rs +++ b/src/frontend/src/frontend.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use meta_client::MetaClientOpts; use serde::{Deserialize, Serialize}; +use servers::http::HttpOptions; use servers::Mode; use snafu::prelude::*; @@ -31,7 +32,7 @@ use crate::server::Services; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct FrontendOptions { - pub http_addr: Option<String>, + pub http_options: Option<HttpOptions>, pub grpc_options: Option<GrpcOptions>, pub mysql_options: Option<MysqlOptions>, pub postgres_options: Option<PostgresOptions>, @@ -46,7 +47,7 @@ impl Default for FrontendOptions { fn default() -> Self { Self { - http_addr: Some("127.0.0.1:4000".to_string()), 
+ http_options: Some(HttpOptions::default()), grpc_options: Some(GrpcOptions::default()), mysql_options: Some(MysqlOptions::default()), postgres_options: Some(PostgresOptions::default()), diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 57dd9a634f..b1c04389a7 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -17,17 +17,16 @@ mod influxdb; mod opentsdb; mod prometheus; -use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use api::result::ObjectResultBuilder; use api::v1::alter_expr::Kind; -use api::v1::codec::InsertBatch; use api::v1::object_expr::Expr; use api::v1::{ - admin_expr, insert_expr, select_expr, AddColumns, AdminExpr, AdminResult, AlterExpr, - CreateDatabaseExpr, CreateExpr, InsertExpr, ObjectExpr, ObjectResult as GrpcObjectResult, + admin_expr, select_expr, AddColumns, AdminExpr, AdminResult, AlterExpr, Column, + CreateDatabaseExpr, CreateExpr, DropTableExpr, InsertExpr, ObjectExpr, + ObjectResult as GrpcObjectResult, }; use async_trait::async_trait; use catalog::remote::MetaKvBackend; @@ -39,6 +38,7 @@ use common_error::prelude::{BoxedError, StatusCode}; use common_grpc::channel_manager::{ChannelConfig, ChannelManager}; use common_grpc::select::to_object_result; use common_query::Output; +use common_recordbatch::RecordBatches; use common_telemetry::{debug, error, info}; use distributed::DistInstance; use meta_client::client::MetaClientBuilder; @@ -48,10 +48,12 @@ use servers::query_handler::{ PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, }; use servers::{error as server_error, Mode}; +use session::context::{QueryContext, QueryContextRef}; use snafu::prelude::*; use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::create::Partitions; +use sql::statements::explain::Explain; use sql::statements::insert::Insert; use sql::statements::statement::Statement; @@ -59,13 +61,14 @@ use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ self, AlterTableOnInsertionSnafu, AlterTableSnafu, CatalogNotFoundSnafu, CatalogSnafu, - CreateDatabaseSnafu, CreateTableSnafu, DeserializeInsertBatchSnafu, - FindNewColumnsOnInsertionSnafu, InsertSnafu, MissingMetasrvOptsSnafu, Result, - SchemaNotFoundSnafu, SelectSnafu, + CreateDatabaseSnafu, CreateTableSnafu, DropTableSnafu, FindNewColumnsOnInsertionSnafu, + InsertSnafu, MissingMetasrvOptsSnafu, Result, SchemaNotFoundSnafu, SelectSnafu, + UnsupportedExprSnafu, }; use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory}; use crate::frontend::FrontendOptions; use crate::sql::insert_to_request; +use crate::table::insert::insert_request_to_insert_batch; use crate::table::route::TableRoutes; #[async_trait] @@ -210,10 +213,15 @@ impl Instance { self.script_handler = Some(handler); } - pub async fn handle_select(&self, expr: Select, stmt: Statement) -> Result<Output> { + async fn handle_select( + &self, + expr: Select, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result<Output> { if let Some(dist_instance) = &self.dist_instance { let Select::Sql(sql) = expr; - dist_instance.handle_sql(&sql, stmt).await + dist_instance.handle_sql(&sql, stmt, query_ctx).await } else { // TODO(LFC): Refactor consideration: Datanode should directly execute statement in standalone mode to avoid parse SQL again. // Find a better way to execute query between Frontend and Datanode in standalone mode. 
@@ -268,15 +276,52 @@ impl Instance { /// Handle alter expr pub async fn handle_alter(&self, expr: AlterExpr) -> Result<Output> { - self.admin(expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME)) - .alter(expr) - .await - .and_then(admin_result_to_output) - .context(AlterTableSnafu) + match &self.dist_instance { + Some(dist_instance) => dist_instance.handle_alter_table(expr).await, + None => self + .admin(expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME)) + .alter(expr) + .await + .and_then(admin_result_to_output) + .context(AlterTableSnafu), + } + } + + /// Handle drop table expr + pub async fn handle_drop_table(&self, expr: DropTableExpr) -> Result<Output> { + match self.mode { + Mode::Standalone => self + .admin(&expr.schema_name) + .drop_table(expr) + .await + .and_then(admin_result_to_output) + .context(DropTableSnafu), + // TODO(ruihang): support drop table in distributed mode + Mode::Distributed => UnsupportedExprSnafu { + name: "Distributed DROP TABLE", + } + .fail(), + } + } + + /// Handle explain expr + pub async fn handle_explain( + &self, + sql: &str, + explain_stmt: Explain, + query_ctx: QueryContextRef, + ) -> Result<Output> { + if let Some(dist_instance) = &self.dist_instance { + dist_instance + .handle_sql(sql, Statement::Explain(explain_stmt), query_ctx) + .await + } else { + Ok(Output::AffectedRows(0)) + } } /// Handle batch inserts - pub async fn handle_inserts(&self, insert_expr: &[InsertExpr]) -> Result<Output> { + pub async fn handle_inserts(&self, insert_expr: Vec<InsertExpr>) -> Result<Output> { let mut success = 0; for expr in insert_expr { match self.handle_insert(expr).await? { @@ -288,68 +333,20 @@ } } /// Handle insert. For 'values' insertion, create/alter the destination table on demand. - pub async fn handle_insert(&self, insert_expr: &InsertExpr) -> Result<Output> { + pub async fn handle_insert(&self, mut insert_expr: InsertExpr) -> Result<Output> { let table_name = &insert_expr.table_name; let catalog_name = DEFAULT_CATALOG_NAME; let schema_name = &insert_expr.schema_name; - if let Some(expr) = &insert_expr.expr { - match expr { - api::v1::insert_expr::Expr::Values(values) => { - // TODO(hl): gRPC should also support partitioning. - let region_number = 0; - self.handle_insert_values( - catalog_name, - schema_name, - table_name, - region_number, - values, - ) - .await - } - api::v1::insert_expr::Expr::Sql(_) => { - // Frontend does not comprehend insert request that is raw SQL string - self.database(schema_name) - .insert(insert_expr.clone()) - .await - .and_then(Output::try_from) - .context(InsertSnafu) - } - } - } else { - // expr is empty - Ok(Output::AffectedRows(0)) - } - } + let columns = &insert_expr.columns; + + self.create_or_alter_table_on_demand(catalog_name, schema_name, table_name, columns) + .await?; + + insert_expr.region_number = 0; - /// Handle insert requests in frontend - /// If insert is SQL string flavor, just forward to datanode - /// If insert is parsed InsertExpr, frontend should comprehend the schema and create/alter table on demand. 
- pub async fn handle_insert_values( - &self, - catalog_name: &str, - schema_name: &str, - table_name: &str, - region_number: u32, - values: &insert_expr::Values, - ) -> Result<Output> { - let insert_batches = - common_insert::insert_batches(&values.values).context(DeserializeInsertBatchSnafu)?; - self.create_or_alter_table_on_demand( - catalog_name, - schema_name, - table_name, - &insert_batches, - ) - .await?; self.database(schema_name) - .insert(InsertExpr { - schema_name: schema_name.to_string(), - table_name: table_name.to_string(), - region_number, - options: Default::default(), - expr: Some(insert_expr::Expr::Values(values.clone())), - }) + .insert(insert_expr) .await .and_then(Output::try_from) .context(InsertSnafu) @@ -363,7 +360,7 @@ impl Instance { catalog_name: &str, schema_name: &str, table_name: &str, - insert_batches: &[InsertBatch], + columns: &[Column], ) -> Result<()> { match self .catalog_manager @@ -385,13 +382,8 @@ "Table {}.{}.{} does not exist, try create table", catalog_name, schema_name, table_name, ); - self.create_table_by_insert_batches( - catalog_name, - schema_name, - table_name, - insert_batches, - ) - .await?; + self.create_table_by_columns(catalog_name, schema_name, table_name, columns) + .await?; info!( "Successfully created table on insertion: {}.{}.{}", catalog_name, schema_name, table_name @@ -399,7 +391,8 @@ } Some(table) => { let schema = table.schema(); + - if let Some(add_columns) = common_insert::find_new_columns(&schema, insert_batches) + if let Some(add_columns) = common_grpc_expr::find_new_columns(&schema, columns) .context(FindNewColumnsOnInsertionSnafu)? { info!( @@ -424,17 +417,17 @@ } /// Infer create table expr from inserting data - async fn create_table_by_insert_batches( + async fn create_table_by_columns( + &self, + catalog_name: &str, + schema_name: &str, + table_name: &str, - insert_batches: &[InsertBatch], + columns: &[Column], + ) -> Result<Output> { // Create table automatically, build schema from data. let create_expr = self .create_expr_factory - .create_expr_by_insert_batch(catalog_name, schema_name, table_name, insert_batches) + .create_expr_by_columns(catalog_name, schema_name, table_name, columns) .await?; info!( @@ -495,9 +488,10 @@ let insert_request = insert_to_request(&schema_provider, *insert)?; - let batch = crate::table::insert::insert_request_to_insert_batch(&insert_request)?; + let (columns, _row_count) = + crate::table::insert::insert_request_to_insert_batch(&insert_request)?; - self.create_or_alter_table_on_demand(&catalog, &schema, &table, &[batch]) + self.create_or_alter_table_on_demand(&catalog, &schema, &table, &columns) .await?; let table = schema_provider .table(&table) .context(error::TableSnafu) } + + fn stmt_to_insert_batch( + &self, + catalog: &str, + schema: &str, + insert: Box<Insert>, + ) -> Result<(Vec<Column>, u32)> { + let catalog_provider = self.get_catalog(catalog)?; + let schema_provider = Self::get_schema(catalog_provider, schema)?; + + let insert_request = insert_to_request(&schema_provider, *insert)?; + insert_request_to_insert_batch(&insert_request) + } + + fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> { + let catalog_manager = &self.catalog_manager; + if let Some(catalog_manager) = catalog_manager { + ensure!( + catalog_manager + .schema(DEFAULT_CATALOG_NAME, &db) + .context(error::CatalogSnafu)?
+ .is_some(), + error::SchemaNotFoundSnafu { schema_info: &db } + ); + + query_ctx.set_current_schema(&db); + + Ok(Output::RecordBatches(RecordBatches::empty())) + } else { + // TODO(LFC): Handle "use" stmt here. + unimplemented!() + } + } } #[async_trait] @@ -550,20 +577,26 @@ fn parse_stmt(sql: &str) -> Result { #[async_trait] impl SqlQueryHandler for Instance { - async fn do_query(&self, query: &str) -> server_error::Result { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> server_error::Result { let stmt = parse_stmt(query) .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { query })?; match stmt { - Statement::Query(_) => self - .handle_select(Select::Sql(query.to_string()), stmt) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), + Statement::ShowDatabases(_) + | Statement::ShowTables(_) + | Statement::DescribeTable(_) + | Statement::Query(_) => { + self.handle_select(Select::Sql(query.to_string()), stmt, query_ctx) + .await + } Statement::Insert(insert) => match self.mode { Mode::Standalone => { - let (_, schema_name, table_name) = insert + let (catalog_name, schema_name, table_name) = insert .full_table_name() .context(error::ParseSqlSnafu) .map_err(BoxedError::new) @@ -571,17 +604,19 @@ impl SqlQueryHandler for Instance { msg: "Failed to get table name", })?; + let (columns, row_count) = self + .stmt_to_insert_batch(&catalog_name, &schema_name, insert) + .map_err(BoxedError::new) + .context(server_error::ExecuteQuerySnafu { query })?; + let expr = InsertExpr { schema_name, table_name, - expr: Some(insert_expr::Expr::Sql(query.to_string())), region_number: 0, - options: HashMap::default(), + columns, + row_count, }; - self.handle_insert(&expr) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + self.handle_insert(expr).await } Mode::Distributed => { let affected = self @@ -604,39 +639,36 @@ impl SqlQueryHandler for Instance { self.handle_create_table(create_expr, create.partitions) .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) } - - Statement::ShowDatabases(_) - | Statement::ShowTables(_) - | Statement::DescribeTable(_) => self - .handle_select(Select::Sql(query.to_string()), stmt) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), - Statement::CreateDatabase(c) => { let expr = CreateDatabaseExpr { database_name: c.name.to_string(), }; - self.handle_create_database(expr) - .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + self.handle_create_database(expr).await } - Statement::Alter(alter_stmt) => self - .handle_alter( + Statement::Alter(alter_stmt) => { + self.handle_alter( AlterExpr::try_from(alter_stmt) .map_err(BoxedError::new) .context(server_error::ExecuteAlterSnafu { query })?, ) .await - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }), - Statement::ShowCreateTable(_) => { - return server_error::NotSupportedSnafu { feat: query }.fail() } + Statement::DropTable(drop_stmt) => { + let expr = DropTableExpr { + catalog_name: drop_stmt.catalog_name, + schema_name: drop_stmt.schema_name, + table_name: drop_stmt.table_name, + }; + self.handle_drop_table(expr).await + } + Statement::Explain(explain_stmt) => { + self.handle_explain(query, explain_stmt, query_ctx).await + } + Statement::ShowCreateTable(_) => { + return server_error::NotSupportedSnafu { feat: query }.fail(); + } + Statement::Use(db) => 
self.handle_use(db, query_ctx), } .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { query }) @@ -674,7 +706,8 @@ impl GrpcQueryHandler for Instance { if let Some(expr) = &query.expr { match expr { Expr::Insert(insert) => { - let result = self.handle_insert(insert).await; + // TODO(fys): refactor, avoid clone + let result = self.handle_insert(insert.clone()).await; result .map(|o| match o { Output::AffectedRows(rows) => ObjectResultBuilder::new() @@ -699,7 +732,8 @@ impl GrpcQueryHandler for Instance { })?; match select { select_expr::Expr::Sql(sql) => { - let output = SqlQueryHandler::do_query(self, sql).await; + let query_ctx = Arc::new(QueryContext::new()); + let output = SqlQueryHandler::do_query(self, sql, query_ctx).await; Ok(to_object_result(output).await) } _ => { @@ -739,6 +773,7 @@ fn get_schema_name(expr: &AdminExpr) -> &str { Some(admin_expr::Expr::Create(expr)) => expr.schema_name.as_deref(), Some(admin_expr::Expr::Alter(expr)) => expr.schema_name.as_deref(), Some(admin_expr::Expr::CreateDatabase(_)) | None => Some(DEFAULT_SCHEMA_NAME), + Some(admin_expr::Expr::DropTable(expr)) => Some(expr.schema_name.as_ref()), }; schema_name.unwrap_or(DEFAULT_SCHEMA_NAME) } @@ -765,7 +800,7 @@ impl GrpcAdminHandler for Instance { mod tests { use std::assert_matches::assert_matches; - use api::v1::codec::{InsertBatch, SelectResult}; + use api::v1::codec::SelectResult; use api::v1::column::SemanticType; use api::v1::{ admin_expr, admin_result, column, object_expr, object_result, select_expr, Column, @@ -779,6 +814,8 @@ mod tests { #[tokio::test] async fn test_execute_sql() { + let query_ctx = Arc::new(QueryContext::new()); + let instance = tests::create_frontend_instance().await; let sql = r#"CREATE TABLE demo( @@ -788,9 +825,11 @@ mod tests { memory DOUBLE NULL, disk_util DOUBLE DEFAULT 9.9, TIME INDEX (ts), - PRIMARY KEY(ts, host) + PRIMARY KEY(host) ) engine=mito with(regions=1);"#; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), _ => unreachable!(), @@ -801,14 +840,18 @@ mod tests { ('frontend.host2', null, null, 2000), ('frontend.host3', 3.3, 300, 3000) "#; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 3), _ => unreachable!(), } let sql = "select * from demo"; - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::RecordBatches(recordbatches) => { let pretty_print = recordbatches.pretty_print(); @@ -828,7 +871,9 @@ mod tests { }; let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition - let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap(); + let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) + .await + .unwrap(); match output { Output::RecordBatches(recordbatches) => { let pretty_print = recordbatches.pretty_print(); @@ -924,22 +969,19 @@ mod tests { ); // insert - let values = vec![InsertBatch { - columns: vec![ - expected_host_col.clone(), - expected_cpu_col.clone(), - expected_mem_col.clone(), - expected_ts_col.clone(), - ], - row_count: 4, - } - .into()]; + let 
columns = vec![ + expected_host_col.clone(), + expected_cpu_col.clone(), + expected_mem_col.clone(), + expected_ts_col.clone(), + ]; + let row_count = 4; let insert_expr = InsertExpr { schema_name: "public".to_string(), table_name: "demo".to_string(), - expr: Some(insert_expr::Expr::Values(insert_expr::Values { values })), - options: HashMap::default(), region_number: 0, + columns, + row_count, }; let object_expr = ObjectExpr { header: Some(ExprHeader::default()), @@ -1035,7 +1077,7 @@ mod tests { desc: None, column_defs, time_index: "ts".to_string(), - primary_keys: vec!["ts".to_string(), "host".to_string()], + primary_keys: vec!["host".to_string()], create_if_not_exists: true, table_options: Default::default(), table_id: None, diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index f3b133096f..a96f817035 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -16,13 +16,14 @@ use std::collections::HashMap; use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; -use api::v1::{CreateDatabaseExpr, CreateExpr}; +use api::v1::{AlterExpr, CreateDatabaseExpr, CreateExpr}; +use catalog::CatalogList; use chrono::DateTime; use client::admin::{admin_result_to_output, Admin}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::{SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue}; use common_query::Output; -use common_telemetry::{debug, info}; +use common_telemetry::{debug, error, info}; use datatypes::prelude::ConcreteDataType; use datatypes::schema::RawSchema; use meta_client::client::MetaClient; @@ -30,8 +31,9 @@ use meta_client::rpc::{ CreateRequest as MetaCreateRequest, Partition as MetaPartition, PutRequest, RouteResponse, TableName, TableRoute, }; -use query::sql::{describe_table, show_databases, show_tables}; +use query::sql::{describe_table, explain, show_databases, show_tables}; use query::{QueryEngineFactory, QueryEngineRef}; +use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::create::Partitions; use sql::statements::sql_value_to_value; @@ -42,10 +44,12 @@ use table::metadata::{RawTableInfo, RawTableMeta, TableIdent, TableType}; use crate::catalog::FrontendCatalogManager; use crate::datanode::DatanodeClients; use crate::error::{ - self, CatalogEntrySerdeSnafu, ColumnDataTypeSnafu, PrimaryKeyNotFoundSnafu, RequestMetaSnafu, - Result, StartMetaClientSnafu, + self, CatalogEntrySerdeSnafu, CatalogNotFoundSnafu, CatalogSnafu, ColumnDataTypeSnafu, + PrimaryKeyNotFoundSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, StartMetaClientSnafu, + TableNotFoundSnafu, }; use crate::partitioning::{PartitionBound, PartitionDef}; +use crate::table::DistTable; #[derive(Clone)] pub(crate) struct DistInstance { @@ -125,26 +129,31 @@ impl DistInstance { Ok(Output::AffectedRows(region_routes.len())) } - pub(crate) async fn handle_sql(&self, sql: &str, stmt: Statement) -> Result { + pub(crate) async fn handle_sql( + &self, + sql: &str, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { match stmt { Statement::Query(_) => { let plan = self .query_engine - .statement_to_plan(stmt) + .statement_to_plan(stmt, query_ctx) .context(error::ExecuteSqlSnafu { sql })?; - self.query_engine - .execute(&plan) - .await - .context(error::ExecuteSqlSnafu { sql }) + self.query_engine.execute(&plan).await + } + Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()), + 
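The handle_sql rewrite in the hunk above moves the error adapter out of the individual match arms and applies it once to the whole match expression. A sketch of that consolidation pattern, with plain String errors standing in for the snafu contexts used in the real code:

enum Stmt {
    Query,
    ShowDatabases,
    ShowTables,
}

fn query() -> Result<u32, String> { Ok(0) }
fn show_databases() -> Result<u32, String> { Ok(1) }
fn show_tables() -> Result<u32, String> { Ok(2) }

fn handle_sql(stmt: Stmt, sql: &str) -> Result<u32, String> {
    match stmt {
        Stmt::Query => query(),
        Stmt::ShowDatabases => show_databases(),
        Stmt::ShowTables => show_tables(),
    }
    // One shared adapter instead of repeating it in every arm.
    .map_err(|e| format!("Failed to execute sql '{sql}': {e}"))
}

fn main() {
    assert_eq!(handle_sql(Stmt::ShowTables, "SHOW TABLES"), Ok(2));
}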
Statement::ShowTables(stmt) => { + show_tables(stmt, self.catalog_manager.clone(), query_ctx) + } + Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()), + Statement::Explain(stmt) => { + explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await } - Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), - Statement::ShowTables(stmt) => show_tables(stmt, self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), - Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()) - .context(error::ExecuteSqlSnafu { sql }), _ => unreachable!(), } + .context(error::ExecuteSqlSnafu { sql }) } /// Handles distributed database creation @@ -166,6 +175,34 @@ impl DistInstance { Ok(Output::AffectedRows(1)) } + pub async fn handle_alter_table(&self, expr: AlterExpr) -> Result { + let catalog_name = expr.catalog_name.as_deref().unwrap_or(DEFAULT_CATALOG_NAME); + let schema_name = expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME); + let table_name = expr.table_name.as_str(); + let table = self + .catalog_manager + .catalog(catalog_name) + .context(CatalogSnafu)? + .context(CatalogNotFoundSnafu { catalog_name })? + .schema(schema_name) + .context(CatalogSnafu)? + .context(SchemaNotFoundSnafu { + schema_info: format!("{}.{}", catalog_name, schema_name), + })? + .table(table_name) + .context(CatalogSnafu)? + .context(TableNotFoundSnafu { + table_name: format!("{}.{}.{}", catalog_name, schema_name, table_name), + })?; + + let dist_table = table + .as_any() + .downcast_ref::() + .expect("Table impl must be DistTable in distributed mode"); + dist_table.alter_by_expr(expr).await?; + Ok(Output::AffectedRows(0)) + } + async fn create_table_in_meta( &self, create_table: &CreateExpr, @@ -205,17 +242,32 @@ impl DistInstance { catalog_name: table_name.catalog_name.clone(), schema_name: table_name.schema_name.clone(), table_name: table_name.table_name.clone(), - }; + } + .to_string(); let value = create_table_global_value(create_table, table_route)? .as_bytes() .context(error::CatalogEntrySerdeSnafu)?; - self.catalog_manager + if let Err(existing) = self + .catalog_manager .backend() - .set(key.to_string().as_bytes(), &value) + .compare_and_set(key.as_bytes(), &[], &value) .await - .context(error::CatalogSnafu) + .context(CatalogSnafu)? 
+ { + let existing_bytes = existing.unwrap(); // this unwrap is safe since we compared against empty bytes and the compare-and-set failed + let existing_value = + TableGlobalValue::from_bytes(&existing_bytes).context(CatalogEntrySerdeSnafu)?; + if existing_value.table_info.ident.table_id != create_table.table_id.unwrap() { + error!( + "Table with name {} already exists, value in catalog: {:?}", + key, existing_bytes + ); + return error::TableAlreadyExistSnafu { table: key }.fail(); + } + } + + Ok(()) + } } diff --git a/src/frontend/src/instance/influxdb.rs b/src/frontend/src/instance/influxdb.rs index 846d4f4527..aa25cffb8f 100644 --- a/src/frontend/src/instance/influxdb.rs +++ b/src/frontend/src/instance/influxdb.rs @@ -14,13 +14,11 @@ use std::collections::HashMap; -use api::v1::codec::InsertBatch; -use api::v1::insert_expr::Expr; -use api::v1::InsertExpr; +use api::v1::{Column, InsertExpr}; use async_trait::async_trait; use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_error::prelude::BoxedError; -use common_insert::column_to_vector; +use common_grpc_expr::column_to_vector; use servers::influxdb::InfluxdbRequest; use servers::query_handler::InfluxdbLineProtocolHandler; use servers::{error as server_error, Mode}; @@ -28,7 +26,7 @@ use snafu::{OptionExt, ResultExt}; use table::requests::InsertRequest; use crate::error; -use crate::error::{DeserializeInsertBatchSnafu, InsertBatchToRequestSnafu, Result}; +use crate::error::{InsertBatchToRequestSnafu, Result}; use crate::instance::Instance; #[async_trait] @@ -37,7 +35,7 @@ impl InfluxdbLineProtocolHandler for Instance { match self.mode { Mode::Standalone => { let exprs: Vec<InsertExpr> = request.try_into()?; - self.handle_inserts(&exprs) + self.handle_inserts(exprs) .await .map_err(BoxedError::new) .context(server_error::ExecuteQuerySnafu { @@ -61,53 +59,41 @@ impl InfluxdbLineProtocolHandler for Instance { impl Instance { pub(crate) async fn dist_insert(&self, inserts: Vec<InsertExpr>) -> Result<usize> { let mut joins = Vec::with_capacity(inserts.len()); - let catalog_name = DEFAULT_CATALOG_NAME.to_string(); + let catalog_name = DEFAULT_CATALOG_NAME; for insert in inserts { let self_clone = self.clone(); - let insert_batches = match &insert.expr.unwrap() { - Expr::Values(values) => common_insert::insert_batches(&values.values) - .context(DeserializeInsertBatchSnafu)?, - Expr::Sql(_) => unreachable!(), - }; - self.create_or_alter_table_on_demand( - DEFAULT_CATALOG_NAME, - &insert.schema_name, - &insert.table_name, - &insert_batches, - ) - .await?; + let schema_name = insert.schema_name.to_string(); + let table_name = insert.table_name.to_string(); - let schema_name = insert.schema_name.clone(); - let table_name = insert.table_name.clone(); + let columns = &insert.columns; + let row_count = insert.row_count; - for insert_batch in &insert_batches { - let catalog_name = catalog_name.clone(); - let schema_name = schema_name.clone(); - let table_name = table_name.clone(); - let request = Self::insert_batch_to_request( - DEFAULT_CATALOG_NAME, - &schema_name, - &table_name, - insert_batch, - )?; - // TODO(fys): need a separate runtime here - let self_clone = self_clone.clone(); - let join = tokio::spawn(async move { - let catalog = self_clone.get_catalog(&catalog_name)?; - let schema = Self::get_schema(catalog, &schema_name)?; - let table = schema - .table(&table_name) - .context(error::CatalogSnafu)? 
- .context(error::TableNotFoundSnafu { - table_name: &table_name, - })?; + self.create_or_alter_table_on_demand(catalog_name, &schema_name, &table_name, columns) + .await?; - table.insert(request).await.context(error::TableSnafu) - }); - joins.push(join); - } + let request = Self::columns_to_request( + catalog_name, + &schema_name, + &table_name, + columns, + row_count, + )?; + + // TODO(fys): need a separate runtime here + let self_clone = self_clone.clone(); + let join = tokio::spawn(async move { + let catalog = self_clone.get_catalog(catalog_name)?; + let schema = Self::get_schema(catalog, &schema_name)?; + let table = schema + .table(&table_name) + .context(error::CatalogSnafu)? + .context(error::TableNotFoundSnafu { table_name })?; + + table.insert(request).await.context(error::TableSnafu) + }); + joins.push(join); } let mut affected = 0; @@ -119,16 +105,16 @@ impl Instance { Ok(affected) } - fn insert_batch_to_request( + fn columns_to_request( catalog_name: &str, schema_name: &str, table_name: &str, - batches: &InsertBatch, + columns: &[Column], + row_count: u32, ) -> Result { - let mut vectors = HashMap::with_capacity(batches.columns.len()); - for col in &batches.columns { - let vector = - column_to_vector(col, batches.row_count).context(InsertBatchToRequestSnafu)?; + let mut vectors = HashMap::with_capacity(columns.len()); + for col in columns { + let vector = column_to_vector(col, row_count).context(InsertBatchToRequestSnafu)?; vectors.insert(col.column_name.clone(), vector); } Ok(InsertRequest { diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 1a7db09014..66b04b1317 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -53,16 +53,19 @@ impl OpentsdbProtocolHandler for Instance { impl Instance { async fn insert_opentsdb_metric(&self, data_point: &DataPoint) -> Result<()> { let expr = data_point.as_grpc_insert(); - self.handle_insert(&expr).await?; + self.handle_insert(expr).await?; Ok(()) } } #[cfg(test)] mod tests { + use std::sync::Arc; + use common_query::Output; use datafusion::arrow_print; use servers::query_handler::SqlQueryHandler; + use session::context::QueryContext; use super::*; use crate::tests; @@ -121,7 +124,7 @@ mod tests { assert!(result.is_ok()); let output = instance - .do_query("select * from my_metric_1") + .do_query("select * from my_metric_1", Arc::new(QueryContext::new())) .await .unwrap(); match output { diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs index 246542893d..b6f322beb2 100644 --- a/src/frontend/src/instance/prometheus.rs +++ b/src/frontend/src/instance/prometheus.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
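dist_insert above now spawns one task per insert expression and sums the affected-row counts as the join handles resolve. A runnable sketch of that fan-out shape, assuming tokio with the macros and rt features, and plain String errors in place of the crate's error types:

async fn insert_one(rows: u32) -> Result<u32, String> {
    // stand-in for looking up the table and calling table.insert(request)
    Ok(rows)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let inserts = vec![3u32, 1, 4];
    let mut joins = Vec::with_capacity(inserts.len());
    for rows in inserts {
        // the original leaves a TODO to run these on a dedicated runtime
        joins.push(tokio::spawn(insert_one(rows)));
    }
    let mut affected = 0;
    for join in joins {
        // the outer ? surfaces a panicked or cancelled task, the inner ? the insert error
        affected += join.await.map_err(|e| e.to_string())??;
    }
    assert_eq!(affected, 8);
    Ok(())
}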
+use std::sync::Arc; + use api::prometheus::remote::read_request::ResponseType; use api::prometheus::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest}; use async_trait::async_trait; @@ -25,6 +27,7 @@ use servers::error::{self, Result as ServerResult}; use servers::prometheus::{self, Metrics}; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse}; use servers::Mode; +use session::context::QueryContext; use snafu::{OptionExt, ResultExt}; use crate::instance::{parse_stmt, Instance}; @@ -93,7 +96,10 @@ impl Instance { let object_result = if let Some(dist_instance) = &self.dist_instance { let output = futures::future::ready(parse_stmt(&sql)) - .and_then(|stmt| dist_instance.handle_sql(&sql, stmt)) + .and_then(|stmt| { + let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); + dist_instance.handle_sql(&sql, stmt, query_ctx) + }) .await; to_object_result(output).await.try_into() } else { @@ -115,7 +121,7 @@ impl PrometheusProtocolHandler for Instance { Mode::Standalone => { let exprs = prometheus::write_request_to_insert_exprs(database, request)?; let futures = exprs - .iter() + .into_iter() .map(|e| self.handle_insert(e)) .collect::>(); let res = futures_util::future::join_all(futures) diff --git a/src/frontend/src/mysql.rs b/src/frontend/src/mysql.rs index 71bb600753..a0f8ef7961 100644 --- a/src/frontend/src/mysql.rs +++ b/src/frontend/src/mysql.rs @@ -12,12 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use serde::{Deserialize, Serialize}; +use servers::tls::TlsOption; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct MysqlOptions { pub addr: String, pub runtime_size: usize, + #[serde(default = "Default::default")] + pub tls: Arc, } impl Default for MysqlOptions { @@ -25,6 +30,7 @@ impl Default for MysqlOptions { Self { addr: "127.0.0.1:4002".to_string(), runtime_size: 2, + tls: Arc::new(TlsOption::default()), } } } diff --git a/src/frontend/src/postgres.rs b/src/frontend/src/postgres.rs index 41a11233bc..0b8c7d44e2 100644 --- a/src/frontend/src/postgres.rs +++ b/src/frontend/src/postgres.rs @@ -12,13 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. 
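The new tls fields in these option structs are marked #[serde(default = "Default::default")], so configuration files written before TLS support still deserialize. A sketch of that pattern, assuming serde (with the derive feature) and serde_json as dependencies, and a simplified TlsOption standing in for the real one in servers::tls (which sits behind an Arc):

use serde::Deserialize;

#[derive(Debug, Default, Deserialize)]
struct TlsOption {
    enabled: bool,
}

#[derive(Debug, Deserialize)]
struct MysqlOptions {
    addr: String,
    runtime_size: usize,
    // Missing from old config files; falls back to TlsOption::default().
    #[serde(default = "Default::default")]
    tls: TlsOption,
}

fn main() {
    let opts: MysqlOptions =
        serde_json::from_str(r#"{"addr": "127.0.0.1:4002", "runtime_size": 2}"#).unwrap();
    assert_eq!(opts.runtime_size, 2);
    assert!(!opts.tls.enabled);
}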
+use std::sync::Arc; + use serde::{Deserialize, Serialize}; +use servers::tls::TlsOption; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct PostgresOptions { pub addr: String, pub runtime_size: usize, pub check_pwd: bool, + #[serde(default = "Default::default")] + pub tls: Arc<TlsOption>, } impl Default for PostgresOptions { @@ -27,6 +32,7 @@ addr: "127.0.0.1:4003".to_string(), runtime_size: 2, check_pwd: false, + tls: Default::default(), } } } diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 9ba9587600..8eee23da0c 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -69,7 +69,8 @@ impl Services { .context(error::RuntimeResourceSnafu)?, ); - let mysql_server = MysqlServer::create_server(instance.clone(), mysql_io_runtime); + let mysql_server = + MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone()); Some((mysql_server, mysql_addr)) } else { @@ -90,6 +91,7 @@ let pg_server = Box::new(PostgresServer::new( instance.clone(), opts.check_pwd, + opts.tls.clone(), pg_io_runtime, )) as Box<dyn Server>; @@ -116,10 +118,10 @@ None }; - let http_server_and_addr = if let Some(http_addr) = &opts.http_addr { - let http_addr = parse_addr(http_addr)?; + let http_server_and_addr = if let Some(http_options) = &opts.http_options { + let http_addr = parse_addr(&http_options.addr)?; - let mut http_server = HttpServer::new(instance.clone()); + let mut http_server = HttpServer::new(instance.clone(), http_options.clone()); if opentsdb_server_and_addr.is_some() { http_server.set_opentsdb_handler(instance.clone()); } diff --git a/src/frontend/src/table.rs b/src/frontend/src/table.rs index 0c07fd14f7..8f97ba12f7 100644 --- a/src/frontend/src/table.rs +++ b/src/frontend/src/table.rs @@ -18,13 +18,17 @@ use std::any::Any; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use api::v1::AlterExpr; use async_trait::async_trait; +use client::admin::Admin; use client::Database; +use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_query::error::Result as QueryResult; use common_query::logical_plan::Expr; use common_query::physical_plan::{PhysicalPlan, PhysicalPlanRef}; use common_recordbatch::adapter::AsyncRecordBatchStreamAdapter; use common_recordbatch::{RecordBatches, SendableRecordBatchStream}; +use common_telemetry::debug; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::Expr as DfExpr; use datafusion::physical_plan::{ @@ -43,7 +47,7 @@ use table::Table; use tokio::sync::RwLock; use crate::datanode::DatanodeClients; -use crate::error::{self, Error, Result}; +use crate::error::{self, Error, LeaderNotFoundSnafu, RequestDatanodeSnafu, Result}; use crate::partitioning::columns::RangeColumnsPartitionRule; use crate::partitioning::range::RangePartitionRule; use crate::partitioning::{ @@ -348,6 +352,36 @@ impl DistTable { }; Ok(partition_rule) } + + /// Define an `alter_by_expr` instead of impl [`Table::alter`] to avoid redundant conversion between + /// [`table::requests::AlterTableRequest`] and [`AlterExpr`]. 
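A standalone sketch of the guard-then-fan-out behavior alter_by_expr implements below: fail fast when no region leader is known, otherwise send the alter to every leader. Stub types and String errors stand in for the snafu error variants and the gRPC Admin client:

#[derive(Clone, Debug)]
struct AlterExpr {
    table_name: String,
}

fn alter_on_datanode(datanode: &str, expr: AlterExpr) -> Result<(), String> {
    // stand-in for Admin::alter sent over gRPC
    println!("sent alter {:?} to {}", expr, datanode);
    Ok(())
}

fn alter_by_expr(leaders: &[String], expr: AlterExpr) -> Result<(), String> {
    if leaders.is_empty() {
        // mirrors ensure!(!leaders.is_empty(), LeaderNotFoundSnafu { .. })
        return Err(format!("leader not found for table {}", expr.table_name));
    }
    for datanode in leaders {
        // every region leader must apply the schema change
        alter_on_datanode(datanode, expr.clone())?;
    }
    Ok(())
}

fn main() {
    let expr = AlterExpr { table_name: "demo".to_string() };
    assert!(alter_by_expr(&[], expr.clone()).is_err());
    assert!(alter_by_expr(&["dn-1".to_string()], expr).is_ok());
}

As the TODO in the real method notes, the per-datanode results are not yet tracked by any global DDL task tracker, so this is fire-and-check rather than transactional.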
+ pub(crate) async fn alter_by_expr(&self, expr: AlterExpr) -> Result<()> { + let table_routes = self.table_routes.get_route(&self.table_name).await?; + let leaders = table_routes.find_leaders(); + ensure!( + !leaders.is_empty(), + LeaderNotFoundSnafu { + table: format!( + "{:?}.{:?}.{}", + expr.catalog_name, expr.schema_name, expr.table_name + ) + } + ); + for datanode in leaders { + let admin = Admin::new( + DEFAULT_CATALOG_NAME, + self.datanode_clients.get_client(&datanode).await, + ); + debug!("Sent alter table {:?} to {:?}", expr, admin); + let result = admin + .alter(expr.clone()) + .await + .context(RequestDatanodeSnafu)?; + debug!("Alter table result: {:?}", result); + // TODO(hl): We should further check and track alter result in some global DDL task tracker + } + Ok(()) + } } fn project_schema(table_schema: SchemaRef, projection: &Option>) -> SchemaRef { @@ -477,9 +511,8 @@ impl PartitionExec { mod test { use std::time::Duration; - use api::v1::codec::InsertBatch; use api::v1::column::SemanticType; - use api::v1::{column, insert_expr, Column, ColumnDataType}; + use api::v1::{column, Column, ColumnDataType}; use catalog::remote::MetaKvBackend; use common_recordbatch::util; use datafusion::arrow_print; @@ -936,8 +969,8 @@ mod test { start_ts: i64, ) { let rows = data.len() as u32; - let values = vec![InsertBatch { - columns: vec![ + let values = vec![( + vec![ Column { column_name: "ts".to_string(), values: Some(column::Values { @@ -967,10 +1000,8 @@ mod test { ..Default::default() }, ], - row_count: rows, - } - .into()]; - let values = insert_expr::Values { values }; + rows, + )]; dn_instance .execute_grpc_insert( &table_name.catalog_name, diff --git a/src/frontend/src/table/insert.rs b/src/frontend/src/table/insert.rs index ceb6780e13..409632474f 100644 --- a/src/frontend/src/table/insert.rs +++ b/src/frontend/src/table/insert.rs @@ -16,10 +16,8 @@ use std::collections::HashMap; use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; -use api::v1::codec::InsertBatch; use api::v1::column::SemanticType; -use api::v1::insert_expr::Expr; -use api::v1::{codec, insert_expr, Column, InsertExpr, MutateResult}; +use api::v1::{Column, InsertExpr, MutateResult}; use client::{Database, ObjectResult}; use datatypes::prelude::ConcreteDataType; use snafu::{ensure, OptionExt, ResultExt}; @@ -84,7 +82,7 @@ impl DistTable { } } -pub fn insert_request_to_insert_batch(insert: &InsertRequest) -> Result { +pub fn insert_request_to_insert_batch(insert: &InsertRequest) -> Result<(Vec, u32)> { let mut row_count = None; let columns = insert @@ -127,24 +125,20 @@ pub fn insert_request_to_insert_batch(insert: &InsertRequest) -> Result>>()?; - let insert_batch = codec::InsertBatch { - columns, - row_count: row_count.map(|rows| rows as u32).unwrap_or(0), - }; - Ok(insert_batch) + let row_count = row_count.unwrap_or(0) as u32; + + Ok((columns, row_count)) } fn to_insert_expr(region_number: RegionNumber, insert: InsertRequest) -> Result { let table_name = insert.table_name.clone(); - let insert_batch = insert_request_to_insert_batch(&insert)?; + let (columns, row_count) = insert_request_to_insert_batch(&insert)?; Ok(InsertExpr { schema_name: insert.schema_name, table_name, - expr: Some(Expr::Values(insert_expr::Values { - values: vec![insert_batch.into()], - })), region_number, - options: Default::default(), + columns, + row_count, }) } @@ -152,8 +146,6 @@ fn to_insert_expr(region_number: RegionNumber, insert: InsertRequest) -> Result< mod tests { use std::collections::HashMap; - use 
api::v1::codec::InsertBatch; - use api::v1::insert_expr::Expr; use api::v1::{ColumnDataType, InsertExpr}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use datatypes::prelude::ConcreteDataType; @@ -199,16 +191,7 @@ mod tests { let table_name = insert_expr.table_name; assert_eq!("demo", table_name); - let expr = insert_expr.expr.as_ref().unwrap(); - let vals = match expr { - Expr::Values(vals) => vals, - Expr::Sql(_) => unreachable!(), - }; - - let batch: &[u8] = vals.values[0].as_ref(); - let vals: InsertBatch = batch.try_into().unwrap(); - - for column in vals.columns { + for column in insert_expr.columns { let name = column.column_name; if name == "id" { assert_eq!(0, column.null_mask[0]); diff --git a/src/meta-srv/src/service/router.rs b/src/meta-srv/src/service/router.rs index 11226fca1a..ba924e61d2 100644 --- a/src/meta-srv/src/service/router.rs +++ b/src/meta-srv/src/service/router.rs @@ -216,8 +216,7 @@ async fn get_table_global_value( let tv = get_from_store(kv_store, tg_key).await?; match tv { Some(tv) => { - let tv = TableGlobalValue::parse(&String::from_utf8_lossy(&tv)) - .context(error::InvalidCatalogValueSnafu)?; + let tv = TableGlobalValue::from_bytes(&tv).context(error::InvalidCatalogValueSnafu)?; Ok(Some(tv)) } None => Ok(None), diff --git a/src/mito/src/engine.rs b/src/mito/src/engine.rs index 627ef5bdf4..845493d745 100644 --- a/src/mito/src/engine.rs +++ b/src/mito/src/engine.rs @@ -21,7 +21,7 @@ use common_error::ext::BoxedError; use common_telemetry::logging; use datatypes::schema::SchemaRef; use object_store::ObjectStore; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use store_api::storage::{ ColumnDescriptorBuilder, ColumnFamilyDescriptor, ColumnFamilyDescriptorBuilder, ColumnId, CreateOptions, EngineContext as StorageEngineContext, OpenOptions, RegionDescriptorBuilder, @@ -37,7 +37,8 @@ use tokio::sync::Mutex; use crate::config::EngineConfig; use crate::error::{ self, BuildColumnDescriptorSnafu, BuildColumnFamilyDescriptorSnafu, BuildRegionDescriptorSnafu, - BuildRowKeyDescriptorSnafu, MissingTimestampIndexSnafu, Result, TableExistsSnafu, + BuildRowKeyDescriptorSnafu, InvalidPrimaryKeySnafu, MissingTimestampIndexSnafu, Result, + TableExistsSnafu, }; use crate::table::MitoTable; @@ -57,8 +58,8 @@ fn region_id(table_id: TableId, n: u32) -> RegionId { } #[inline] -fn table_dir(schema_name: &str, table_name: &str) -> String { - format!("{}/{}/", schema_name, table_name) +fn table_dir(schema_name: &str, table_name: &str, table_id: TableId) -> String { + format!("{}/{}_{}/", schema_name, table_name, table_id) } /// [TableEngine] implementation. @@ -123,14 +124,14 @@ impl TableEngine for MitoEngine { async fn drop_table( &self, _ctx: &EngineContext, - _request: DropTableRequest, - ) -> TableResult<()> { - unimplemented!(); + request: DropTableRequest, + ) -> TableResult { + Ok(self.inner.drop_table(request).await?) } } struct MitoEngineInner { - /// All tables opened by the engine. + /// All tables opened by the engine. Map key is formatted [TableReference]. /// /// Writing to `tables` should also hold the `table_mutex`. 
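table_dir above now embeds the table id in the on-disk path, so a table re-created under the same name never collides with data left behind by a dropped predecessor. The helper is small enough to mirror directly:

type TableId = u32;

fn table_dir(schema_name: &str, table_name: &str, table_id: TableId) -> String {
    format!("{}/{}_{}/", schema_name, table_name, table_id)
}

fn main() {
    assert_eq!("public/test_table_1024/", table_dir("public", "test_table", 1024));
    // same name, different id: distinct directories
    assert_ne!(table_dir("public", "demo", 1), table_dir("public", "demo", 2));
}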
tables: RwLock>, @@ -248,6 +249,27 @@ fn build_column_family( )) } +fn validate_create_table_request(request: &CreateTableRequest) -> Result<()> { + let ts_index = request + .schema + .timestamp_index() + .context(MissingTimestampIndexSnafu { + table_name: &request.table_name, + })?; + + ensure!( + !request + .primary_key_indices + .iter() + .any(|index| *index == ts_index), + InvalidPrimaryKeySnafu { + msg: "time index column can't be included in primary key" + } + ); + + Ok(()) +} + impl MitoEngineInner { async fn create_table( &self, @@ -263,6 +285,8 @@ impl MitoEngineInner { table: table_name, }; + validate_create_table_request(&request)?; + if let Some(table) = self.get_table(&table_ref) { if request.create_if_not_exists { return Ok(table); @@ -317,7 +341,7 @@ impl MitoEngineInner { } } - let table_dir = table_dir(schema_name, table_name); + let table_dir = table_dir(schema_name, table_name, table_id); let opts = CreateOptions { parent_dir: table_dir.clone(), }; @@ -396,13 +420,13 @@ impl MitoEngineInner { return Ok(Some(table)); } + let table_id = request.table_id; let engine_ctx = StorageEngineContext::default(); - let table_dir = table_dir(schema_name, table_name); + let table_dir = table_dir(schema_name, table_name, table_id); let opts = OpenOptions { parent_dir: table_dir.to_string(), }; - let table_id = request.table_id; // TODO(dennis): supports multi regions; assert_eq!(request.region_numbers.len(), 1); let region_number = request.region_numbers[0]; @@ -464,6 +488,22 @@ impl MitoEngineInner { .context(error::AlterTableSnafu { table_name })?; Ok(table) } + + /// Drop table. Returns whether a table is dropped (true) or not exist (false). + async fn drop_table(&self, req: DropTableRequest) -> Result { + let table_reference = TableReference { + catalog: &req.catalog_name, + schema: &req.schema_name, + table: &req.table_name, + }; + // todo(ruihang): reclaim persisted data + Ok(self + .tables + .write() + .unwrap() + .remove(&table_reference.to_string()) + .is_some()) + } } impl MitoEngineInner { @@ -626,8 +666,57 @@ mod tests { #[test] fn test_table_dir() { - assert_eq!("public/test_table/", table_dir("public", "test_table")); - assert_eq!("prometheus/demo/", table_dir("prometheus", "demo")); + assert_eq!( + "public/test_table_1024/", + table_dir("public", "test_table", 1024) + ); + assert_eq!( + "prometheus/demo_1024/", + table_dir("prometheus", "demo", 1024) + ); + } + + #[test] + fn test_validate_create_table_request() { + let table_name = "test_validate_create_table_request"; + let column_schemas = vec![ + ColumnSchema::new("name", ConcreteDataType::string_datatype(), false), + ColumnSchema::new( + "ts", + ConcreteDataType::timestamp_datatype(common_time::timestamp::TimeUnit::Millisecond), + true, + ) + .with_time_index(true), + ]; + + let schema = Arc::new( + SchemaBuilder::try_from(column_schemas) + .unwrap() + .build() + .expect("ts must be timestamp column"), + ); + + let mut request = CreateTableRequest { + id: 1, + catalog_name: "greptime".to_string(), + schema_name: "public".to_string(), + table_name: table_name.to_string(), + desc: Some("a test table".to_string()), + schema, + create_if_not_exists: true, + // put ts into primary keys + primary_key_indices: vec![0, 1], + table_options: HashMap::new(), + region_numbers: vec![0], + }; + + let err = validate_create_table_request(&request).unwrap_err(); + assert!(err + .to_string() + .contains("Invalid primary key: time index column can't be included in primary key")); + + request.primary_key_indices = vec![0]; + 
assert!(validate_create_table_request(&request).is_ok()); } #[tokio::test] @@ -961,4 +1050,69 @@ mod tests { assert_eq!(new_schema.timestamp_column(), old_schema.timestamp_column()); assert_eq!(new_schema.version(), old_schema.version() + 1); } + + #[tokio::test] + async fn test_drop_table() { + common_telemetry::init_default_ut_logging(); + let ctx = EngineContext::default(); + + let (_engine, table_engine, table, _object_store, _dir) = + test_util::setup_mock_engine_and_table().await; + let engine_ctx = EngineContext {}; + + let table_info = table.table_info(); + let table_reference = TableReference { + catalog: DEFAULT_CATALOG_NAME, + schema: DEFAULT_SCHEMA_NAME, + table: &table_info.name, + }; + + let create_table_request = CreateTableRequest { + id: 1, + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: DEFAULT_SCHEMA_NAME.to_string(), + table_name: table_info.name.to_string(), + schema: table_info.meta.schema.clone(), + create_if_not_exists: true, + desc: None, + primary_key_indices: Vec::default(), + table_options: HashMap::new(), + region_numbers: vec![0], + }; + + let created_table = table_engine + .create_table(&ctx, create_table_request) + .await + .unwrap(); + assert_eq!(table_info, created_table.table_info()); + assert!(table_engine.table_exists(&engine_ctx, &table_reference)); + + let drop_table_request = DropTableRequest { + catalog_name: table_reference.catalog.to_string(), + schema_name: table_reference.schema.to_string(), + table_name: table_reference.table.to_string(), + }; + let table_dropped = table_engine + .drop_table(&engine_ctx, drop_table_request) + .await + .unwrap(); + assert!(table_dropped); + assert!(!table_engine.table_exists(&engine_ctx, &table_reference)); + + // should be able to re-create + let request = CreateTableRequest { + id: 2, + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: DEFAULT_SCHEMA_NAME.to_string(), + table_name: table_info.name.to_string(), + schema: table_info.meta.schema.clone(), + create_if_not_exists: false, + desc: None, + primary_key_indices: Vec::default(), + table_options: HashMap::new(), + region_numbers: vec![0], + }; + table_engine.create_table(&ctx, request).await.unwrap(); + assert!(table_engine.table_exists(&engine_ctx, &table_reference)); + } } diff --git a/src/mito/src/error.rs b/src/mito/src/error.rs index ff3321ef7a..ff29e72a81 100644 --- a/src/mito/src/error.rs +++ b/src/mito/src/error.rs @@ -56,6 +56,9 @@ pub enum Error { backtrace: Backtrace, }, + #[snafu(display("Invalid primary key: {}", msg))] + InvalidPrimaryKey { msg: String, backtrace: Backtrace }, + #[snafu(display("Missing timestamp index for table: {}", table_name))] MissingTimestampIndex { table_name: String, @@ -214,6 +217,7 @@ impl ErrorExt for Error { | BuildRegionDescriptor { .. } | TableExists { .. } | ProjectedColumnNotFound { .. } + | InvalidPrimaryKey { .. } | MissingTimestampIndex { .. } | UnsupportedDefaultConstraint { .. } | TableNotFound { .. 
} => StatusCode::InvalidArguments, diff --git a/src/object-store/Cargo.toml b/src/object-store/Cargo.toml index e7e63109e1..a6bef20256 100644 --- a/src/object-store/Cargo.toml +++ b/src/object-store/Cargo.toml @@ -6,10 +6,11 @@ license = "Apache-2.0" [dependencies] futures = { version = "0.3" } -opendal = "0.20" +opendal = { version = "0.21", features = ["layers-tracing", "layers-metrics"] } tokio = { version = "1.0", features = ["full"] } [dev-dependencies] anyhow = "1.0" common-telemetry = { path = "../common/telemetry" } tempdir = "0.3" +uuid = { version = "1", features = ["serde", "v4"] } diff --git a/src/object-store/src/lib.rs b/src/object-store/src/lib.rs index 2be43fa5c7..7d6673d647 100644 --- a/src/object-store/src/lib.rs +++ b/src/object-store/src/lib.rs @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub use opendal::io_util::SeekableReader; +pub use opendal::raw::SeekableReader; pub use opendal::{ - layers, services, Accessor, Layer, Object, ObjectEntry, ObjectMetadata, ObjectMode, - ObjectStreamer, Operator as ObjectStore, + layers, services, Error, ErrorKind, Layer, Object, ObjectLister, ObjectMetadata, ObjectMode, + Operator as ObjectStore, Result, }; pub mod backend; +pub mod test_util; pub mod util; diff --git a/src/object-store/src/test_util.rs b/src/object-store/src/test_util.rs new file mode 100644 index 0000000000..d443aaf005 --- /dev/null +++ b/src/object-store/src/test_util.rs @@ -0,0 +1,35 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::{ObjectStore, Result}; + +pub struct TempFolder { + store: ObjectStore, + // The path under root. 
+ path: String, +} + +impl TempFolder { + pub fn new(store: &ObjectStore, path: &str) -> Self { + Self { + store: store.clone(), + path: path.to_string(), + } + } + + pub async fn remove_all(&mut self) -> Result<()> { + let batch = self.store.batch(); + batch.remove_all(&self.path).await + } +} diff --git a/src/object-store/src/util.rs b/src/object-store/src/util.rs index 01bb9e5360..298069ab3b 100644 --- a/src/object-store/src/util.rs +++ b/src/object-store/src/util.rs @@ -14,9 +14,9 @@ use futures::TryStreamExt; -use crate::{ObjectEntry, ObjectStreamer}; +use crate::{Object, ObjectLister}; -pub async fn collect(stream: ObjectStreamer) -> Result, std::io::Error> { +pub async fn collect(stream: ObjectLister) -> Result, opendal::Error> { stream.try_collect::>().await } diff --git a/src/object-store/tests/object_store_test.rs b/src/object-store/tests/object_store_test.rs index 27fa76262b..33cba429fd 100644 --- a/src/object-store/tests/object_store_test.rs +++ b/src/object-store/tests/object_store_test.rs @@ -17,7 +17,8 @@ use std::env; use anyhow::Result; use common_telemetry::logging; use object_store::backend::{fs, s3}; -use object_store::{util, Object, ObjectMode, ObjectStore, ObjectStreamer}; +use object_store::test_util::TempFolder; +use object_store::{util, Object, ObjectLister, ObjectMode, ObjectStore}; use tempdir::TempDir; async fn test_object_crud(store: &ObjectStore) -> Result<()> { @@ -61,7 +62,7 @@ async fn test_object_list(store: &ObjectStore) -> Result<()> { // List objects let o: Object = store.object("/"); - let obs: ObjectStreamer = o.list().await?; + let obs: ObjectLister = o.list().await?; let objects = util::collect(obs).await?; assert_eq!(3, objects.len()); @@ -74,7 +75,7 @@ async fn test_object_list(store: &ObjectStore) -> Result<()> { assert_eq!(1, objects.len()); // Only o2 is exists - let o2 = &objects[0].clone().into_object(); + let o2 = &objects[0].clone(); let bs = o2.read().await?; assert_eq!("Hello, object2!", String::from_utf8(bs)?); // Delete o2 @@ -88,10 +89,12 @@ async fn test_object_list(store: &ObjectStore) -> Result<()> { #[tokio::test] async fn test_fs_backend() -> Result<()> { + let data_dir = TempDir::new("test_fs_backend")?; let tmp_dir = TempDir::new("test_fs_backend")?; let store = ObjectStore::new( fs::Builder::default() - .root(&tmp_dir.path().to_string_lossy()) + .root(&data_dir.path().to_string_lossy()) + .atomic_write_dir(&tmp_dir.path().to_string_lossy()) .build()?, ); @@ -108,15 +111,21 @@ async fn test_s3_backend() -> Result<()> { if !bucket.is_empty() { logging::info!("Running s3 test."); + let root = uuid::Uuid::new_v4().to_string(); + let accessor = s3::Builder::default() + .root(&root) .access_key_id(&env::var("GT_S3_ACCESS_KEY_ID")?) .secret_access_key(&env::var("GT_S3_ACCESS_KEY")?) 
.bucket(&bucket) .build()?; let store = ObjectStore::new(accessor); + + let mut guard = TempFolder::new(&store, "/"); test_object_crud(&store).await?; test_object_list(&store).await?; + guard.remove_all().await?; } } diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 604da77377..34c3977d90 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -25,6 +25,7 @@ metrics = "0.20" once_cell = "1.10" serde = "1.0" serde_json = "1.0" +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } table = { path = "../table" } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 7c65c758b3..8dda26a5db 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -32,6 +32,7 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use common_telemetry::timer; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; +use session::context::QueryContextRef; use snafu::{OptionExt, ResultExt}; use sql::dialect::GenericDialect; use sql::parser::ParserContext; @@ -46,7 +47,7 @@ use crate::physical_optimizer::PhysicalOptimizer; use crate::physical_planner::PhysicalPlanner; use crate::plan::LogicalPlan; use crate::planner::Planner; -use crate::query_engine::{QueryContext, QueryEngineState}; +use crate::query_engine::{QueryEngineContext, QueryEngineState}; use crate::{metric, QueryEngine}; pub(crate) struct DatafusionQueryEngine { @@ -61,6 +62,7 @@ impl DatafusionQueryEngine { } } +// TODO(LFC): Refactor consideration: extract a "Planner" that stores query context and execute queries inside. #[async_trait::async_trait] impl QueryEngine for DatafusionQueryEngine { fn name(&self) -> &str { @@ -75,21 +77,25 @@ impl QueryEngine for DatafusionQueryEngine { Ok(statement.remove(0)) } - fn statement_to_plan(&self, stmt: Statement) -> Result { - let context_provider = DfContextProviderAdapter::new(self.state.clone()); + fn statement_to_plan( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { + let context_provider = DfContextProviderAdapter::new(self.state.clone(), query_ctx); let planner = DfPlanner::new(&context_provider); planner.statement_to_plan(stmt) } - fn sql_to_plan(&self, sql: &str) -> Result { + fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result { let _timer = timer!(metric::METRIC_PARSE_SQL_ELAPSED); let stmt = self.sql_to_statement(sql)?; - self.statement_to_plan(stmt) + self.statement_to_plan(stmt, query_ctx) } async fn execute(&self, plan: &LogicalPlan) -> Result { - let mut ctx = QueryContext::new(self.state.clone()); + let mut ctx = QueryEngineContext::new(self.state.clone()); let logical_plan = self.optimize_logical_plan(&mut ctx, plan)?; let physical_plan = self.create_physical_plan(&mut ctx, &logical_plan).await?; let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?; @@ -100,7 +106,7 @@ impl QueryEngine for DatafusionQueryEngine { } async fn execute_physical(&self, plan: &Arc) -> Result { - let ctx = QueryContext::new(self.state.clone()); + let ctx = QueryEngineContext::new(self.state.clone()); Ok(Output::Stream(self.execute_stream(&ctx, plan).await?)) } @@ -127,7 +133,7 @@ impl QueryEngine for DatafusionQueryEngine { impl LogicalOptimizer for DatafusionQueryEngine { fn optimize_logical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, plan: &LogicalPlan, ) -> Result { let _timer = 
timer!(metric::METRIC_OPTIMIZE_LOGICAL_ELAPSED); @@ -151,7 +157,7 @@ impl LogicalOptimizer for DatafusionQueryEngine { impl PhysicalPlanner for DatafusionQueryEngine { async fn create_physical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, logical_plan: &LogicalPlan, ) -> Result> { let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED); @@ -183,7 +189,7 @@ impl PhysicalPlanner for DatafusionQueryEngine { impl PhysicalOptimizer for DatafusionQueryEngine { fn optimize_physical_plan( &self, - _ctx: &mut QueryContext, + _: &mut QueryEngineContext, plan: Arc, ) -> Result> { let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED); @@ -211,7 +217,7 @@ impl PhysicalOptimizer for DatafusionQueryEngine { impl QueryExecutor for DatafusionQueryEngine { async fn execute_stream( &self, - ctx: &QueryContext, + ctx: &QueryEngineContext, plan: &Arc, ) -> Result { let _timer = timer!(metric::METRIC_EXEC_PLAN_ELAPSED); @@ -250,6 +256,7 @@ mod tests { use common_recordbatch::util; use datafusion::field_util::{FieldExt, SchemaExt}; use datatypes::arrow::array::UInt64Array; + use session::context::QueryContext; use table::table::numbers::NumbersTable; use crate::query_engine::{QueryEngineFactory, QueryEngineRef}; @@ -277,7 +284,9 @@ mod tests { let engine = create_test_engine(); let sql = "select sum(number) from numbers limit 20"; - let plan = engine.sql_to_plan(sql).unwrap(); + let plan = engine + .sql_to_plan(sql, Arc::new(QueryContext::new())) + .unwrap(); assert_eq!( format!("{:?}", plan), @@ -293,7 +302,9 @@ mod tests { let engine = create_test_engine(); let sql = "select sum(number) from numbers limit 20"; - let plan = engine.sql_to_plan(sql).unwrap(); + let plan = engine + .sql_to_plan(sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); match output { diff --git a/src/query/src/datafusion/planner.rs b/src/query/src/datafusion/planner.rs index 6d5dcae527..6d70109e74 100644 --- a/src/query/src/datafusion/planner.rs +++ b/src/query/src/datafusion/planner.rs @@ -21,7 +21,9 @@ use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::udf::ScalarUDF; use datafusion::sql::planner::{ContextProvider, SqlToRel}; use datatypes::arrow::datatypes::DataType; +use session::context::QueryContextRef; use snafu::ResultExt; +use sql::statements::explain::Explain; use sql::statements::query::Query; use sql::statements::statement::Statement; @@ -53,6 +55,18 @@ impl<'a, S: ContextProvider + Send + Sync> DfPlanner<'a, S> { Ok(LogicalPlan::DfPlan(result)) } + + /// Converts EXPLAIN statement to logical plan. 
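The session-level QueryContext threaded through these APIs is distinct from the engine-internal context renamed to QueryEngineContext above. A sketch of the minimal surface assumed by the call sites in this diff (new(), with_current_schema(), current_schema(), set_current_schema()); the shapes are inferred from usage here, not copied from the session crate's exact definition:

use std::sync::{Arc, RwLock};

#[derive(Debug, Default)]
pub struct QueryContext {
    current_schema: RwLock<Option<String>>,
}

pub type QueryContextRef = Arc<QueryContext>;

impl QueryContext {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_current_schema(schema: String) -> Self {
        Self { current_schema: RwLock::new(Some(schema)) }
    }

    pub fn current_schema(&self) -> Option<String> {
        self.current_schema.read().unwrap().clone()
    }

    // interior mutability lets a "USE db" statement update a shared context
    pub fn set_current_schema(&self, schema: &str) {
        *self.current_schema.write().unwrap() = Some(schema.to_string());
    }
}

fn main() {
    let ctx: QueryContextRef = Arc::new(QueryContext::new());
    assert!(ctx.current_schema().is_none());
    ctx.set_current_schema("public");
    assert_eq!(ctx.current_schema().as_deref(), Some("public"));
}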
+ pub fn explain_to_plan(&self, explain: Explain) -> Result { + let result = self + .sql_to_rel + .sql_statement_to_plan(explain.inner.clone()) + .context(error::PlanSqlSnafu { + sql: explain.to_string(), + })?; + + Ok(LogicalPlan::DfPlan(result)) + } } impl<'a, S> Planner for DfPlanner<'a, S> @@ -63,6 +77,7 @@ where fn statement_to_plan(&self, statement: Statement) -> Result { match statement { Statement::Query(qb) => self.query_to_plan(qb), + Statement::Explain(explain) => self.explain_to_plan(explain), Statement::ShowTables(_) | Statement::ShowDatabases(_) | Statement::ShowCreateTable(_) @@ -70,18 +85,21 @@ where | Statement::CreateTable(_) | Statement::CreateDatabase(_) | Statement::Alter(_) - | Statement::Insert(_) => unreachable!(), + | Statement::Insert(_) + | Statement::DropTable(_) + | Statement::Use(_) => unreachable!(), } } } pub(crate) struct DfContextProviderAdapter { state: QueryEngineState, + query_ctx: QueryContextRef, } impl DfContextProviderAdapter { - pub(crate) fn new(state: QueryEngineState) -> Self { - Self { state } + pub(crate) fn new(state: QueryEngineState, query_ctx: QueryContextRef) -> Self { + Self { state, query_ctx } } } @@ -89,11 +107,18 @@ impl DfContextProviderAdapter { /// manage UDFs, UDAFs, variables by ourself in future. impl ContextProvider for DfContextProviderAdapter { fn get_table_provider(&self, name: TableReference) -> Option> { - self.state - .df_context() - .state - .lock() - .get_table_provider(name) + let schema = self.query_ctx.current_schema(); + let execution_ctx = self.state.df_context().state.lock(); + match name { + TableReference::Bare { table } if schema.is_some() => { + execution_ctx.get_table_provider(TableReference::Partial { + // unwrap safety: checked in this match's arm + schema: &schema.unwrap(), + table, + }) + } + _ => execution_ctx.get_table_provider(name), + } } fn get_function_meta(&self, name: &str) -> Option> { diff --git a/src/query/src/executor.rs b/src/query/src/executor.rs index 46eeaf97fa..52664940fb 100644 --- a/src/query/src/executor.rs +++ b/src/query/src/executor.rs @@ -18,14 +18,14 @@ use common_query::physical_plan::PhysicalPlan; use common_recordbatch::SendableRecordBatchStream; use crate::error::Result; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; /// Executor to run [ExecutionPlan]. 
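The get_table_provider change above qualifies bare table references with the session's current schema before lookup, while already-qualified names pass through untouched. A minimal sketch of that qualification rule, with a simplified enum standing in for datafusion's TableReference:

#[derive(Debug, PartialEq)]
enum TableRef<'a> {
    Bare { table: &'a str },
    Partial { schema: &'a str, table: &'a str },
}

fn qualify<'a>(name: TableRef<'a>, current_schema: &'a Option<String>) -> TableRef<'a> {
    match (name, current_schema) {
        // a bare name picks up the session schema, if one is set
        (TableRef::Bare { table }, Some(schema)) => TableRef::Partial {
            schema: schema.as_str(),
            table,
        },
        (other, _) => other,
    }
}

fn main() {
    let session_schema = Some("my_db".to_string());
    assert_eq!(
        qualify(TableRef::Bare { table: "demo" }, &session_schema),
        TableRef::Partial { schema: "my_db", table: "demo" },
    );
    // already-qualified names are left alone
    assert_eq!(
        qualify(TableRef::Partial { schema: "public", table: "t" }, &session_schema),
        TableRef::Partial { schema: "public", table: "t" },
    );
}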
#[async_trait::async_trait] pub trait QueryExecutor { async fn execute_stream( &self, - ctx: &QueryContext, + ctx: &QueryEngineContext, plan: &Arc, ) -> Result; } diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs index 14aa4f773c..5b25707dc7 100644 --- a/src/query/src/lib.rs +++ b/src/query/src/lib.rs @@ -26,4 +26,6 @@ pub mod planner; pub mod query_engine; pub mod sql; -pub use crate::query_engine::{QueryContext, QueryEngine, QueryEngineFactory, QueryEngineRef}; +pub use crate::query_engine::{ + QueryEngine, QueryEngineContext, QueryEngineFactory, QueryEngineRef, +}; diff --git a/src/query/src/logical_optimizer.rs b/src/query/src/logical_optimizer.rs index 8c35f856a5..266a1a4233 100644 --- a/src/query/src/logical_optimizer.rs +++ b/src/query/src/logical_optimizer.rs @@ -14,12 +14,12 @@ use crate::error::Result; use crate::plan::LogicalPlan; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; pub trait LogicalOptimizer { fn optimize_logical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, plan: &LogicalPlan, ) -> Result; } diff --git a/src/query/src/physical_optimizer.rs b/src/query/src/physical_optimizer.rs index c96d27a7f7..a75c629057 100644 --- a/src/query/src/physical_optimizer.rs +++ b/src/query/src/physical_optimizer.rs @@ -17,12 +17,12 @@ use std::sync::Arc; use common_query::physical_plan::PhysicalPlan; use crate::error::Result; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; pub trait PhysicalOptimizer { fn optimize_physical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, plan: Arc, ) -> Result>; } diff --git a/src/query/src/physical_planner.rs b/src/query/src/physical_planner.rs index 2118f1cc82..40213a1346 100644 --- a/src/query/src/physical_planner.rs +++ b/src/query/src/physical_planner.rs @@ -18,7 +18,7 @@ use common_query::physical_plan::PhysicalPlan; use crate::error::Result; use crate::plan::LogicalPlan; -use crate::query_engine::QueryContext; +use crate::query_engine::QueryEngineContext; /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. 
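These renamed traits are driven in sequence by the engine's execute(): optimize the logical plan, lower it to a physical plan, optimize that, then execute the stream. A toy pipeline showing that staged flow, with strings standing in for the plan types and a unit struct for the engine context:

struct EngineContext;

fn optimize_logical(_: &mut EngineContext, plan: &str) -> String {
    format!("optimized({plan})")
}

fn create_physical(_: &mut EngineContext, plan: &str) -> String {
    format!("physical({plan})")
}

fn optimize_physical(_: &mut EngineContext, plan: String) -> String {
    format!("optimized({plan})")
}

fn execute_stream(_: &EngineContext, plan: &str) -> String {
    format!("stream({plan})")
}

fn main() {
    let mut ctx = EngineContext;
    let logical = optimize_logical(&mut ctx, "logical_plan");
    let physical = create_physical(&mut ctx, &logical);
    let physical = optimize_physical(&mut ctx, physical);
    // prints stream(optimized(physical(optimized(logical_plan))))
    println!("{}", execute_stream(&ctx, &physical));
}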
@@ -27,7 +27,7 @@ pub trait PhysicalPlanner { /// Create a physical plan from a logical plan async fn create_physical_plan( &self, - ctx: &mut QueryContext, + ctx: &mut QueryEngineContext, logical_plan: &LogicalPlan, ) -> Result>; } diff --git a/src/query/src/query_engine.rs b/src/query/src/query_engine.rs index 36d2c54d1a..110f78e6f3 100644 --- a/src/query/src/query_engine.rs +++ b/src/query/src/query_engine.rs @@ -23,12 +23,13 @@ use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY}; use common_query::physical_plan::PhysicalPlan; use common_query::prelude::ScalarUdf; use common_query::Output; +use session::context::QueryContextRef; use sql::statements::statement::Statement; use crate::datafusion::DatafusionQueryEngine; use crate::error::Result; use crate::plan::LogicalPlan; -pub use crate::query_engine::context::QueryContext; +pub use crate::query_engine::context::QueryEngineContext; pub use crate::query_engine::state::QueryEngineState; #[async_trait::async_trait] @@ -37,9 +38,10 @@ pub trait QueryEngine: Send + Sync { fn sql_to_statement(&self, sql: &str) -> Result; - fn statement_to_plan(&self, stmt: Statement) -> Result; + fn statement_to_plan(&self, stmt: Statement, query_ctx: QueryContextRef) + -> Result; - fn sql_to_plan(&self, sql: &str) -> Result; + fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result; async fn execute(&self, plan: &LogicalPlan) -> Result; diff --git a/src/query/src/query_engine/context.rs b/src/query/src/query_engine/context.rs index c5b5d20c2d..c54cb8b595 100644 --- a/src/query/src/query_engine/context.rs +++ b/src/query/src/query_engine/context.rs @@ -16,11 +16,11 @@ use crate::query_engine::state::QueryEngineState; #[derive(Debug)] -pub struct QueryContext { +pub struct QueryEngineContext { state: QueryEngineState, } -impl QueryContext { +impl QueryEngineContext { pub fn new(state: QueryEngineState) -> Self { Self { state } } diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index 51b81d01af..2854fed7fc 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -22,11 +22,15 @@ use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{Helper, StringVector}; use once_cell::sync::Lazy; +use session::context::QueryContextRef; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::describe::DescribeTable; +use sql::statements::explain::Explain; use sql::statements::show::{ShowDatabases, ShowKind, ShowTables}; +use sql::statements::statement::Statement; use crate::error::{self, Result}; +use crate::QueryEngineRef; const SCHEMAS_COLUMN: &str = "Schemas"; const TABLES_COLUMN: &str = "Tables"; @@ -106,7 +110,11 @@ pub fn show_databases(stmt: ShowDatabases, catalog_manager: CatalogManagerRef) - Ok(Output::RecordBatches(records)) } -pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Result { +pub fn show_tables( + stmt: ShowTables, + catalog_manager: CatalogManagerRef, + query_ctx: QueryContextRef, +) -> Result { // TODO(LFC): supports WHERE ensure!( matches!(stmt.kind, ShowKind::All | ShowKind::Like(_)), @@ -115,9 +123,15 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu } ); - let schema = stmt.database.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME); + let schema = if let Some(database) = stmt.database { + database + } else { + query_ctx + .current_schema() + .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()) + }; let schema = catalog_manager - .schema(DEFAULT_CATALOG_NAME, schema) + .schema(DEFAULT_CATALOG_NAME, &schema) 
.context(error::CatalogSnafu)? .context(error::SchemaNotFoundSnafu { schema })?; let tables = schema.table_names().context(error::CatalogSnafu)?; @@ -138,6 +152,15 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu Ok(Output::RecordBatches(records)) } +pub async fn explain( + stmt: Box, + query_engine: QueryEngineRef, + query_ctx: QueryContextRef, +) -> Result { + let plan = query_engine.statement_to_plan(Statement::Explain(*stmt), query_ctx)?; + query_engine.execute(&plan).await +} + pub fn describe_table(stmt: DescribeTable, catalog_manager: CatalogManagerRef) -> Result { let catalog = stmt.catalog_name.as_str(); let schema = stmt.schema_name.as_str(); diff --git a/src/query/tests/argmax_test.rs b/src/query/tests/argmax_test.rs index 23ff4785ac..11f0167a09 100644 --- a/src/query/tests/argmax_test.rs +++ b/src/query/tests/argmax_test.rs @@ -24,6 +24,7 @@ use datatypes::types::PrimitiveElement; use function::{create_query_engine, get_numbers_from_table}; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_argmax_aggregator() -> Result<()> { @@ -95,7 +96,9 @@ async fn execute_argmax<'a>( "select ARGMAX({}) as argmax from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/argmin_test.rs b/src/query/tests/argmin_test.rs index 0e02f9e4a2..2a509f05fd 100644 --- a/src/query/tests/argmin_test.rs +++ b/src/query/tests/argmin_test.rs @@ -25,6 +25,7 @@ use datatypes::types::PrimitiveElement; use function::{create_query_engine, get_numbers_from_table}; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_argmin_aggregator() -> Result<()> { @@ -96,7 +97,9 @@ async fn execute_argmin<'a>( "select argmin({}) as argmin from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/function.rs b/src/query/tests/function.rs index f5ecba91ee..040dfa7a6b 100644 --- a/src/query/tests/function.rs +++ b/src/query/tests/function.rs @@ -27,6 +27,7 @@ use datatypes::vectors::PrimitiveVector; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; +use session::context::QueryContext; use table::test_util::MemTable; pub fn create_query_engine() -> Arc { @@ -80,7 +81,9 @@ where for<'a> T: Scalar = T>, { let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/mean_test.rs b/src/query/tests/mean_test.rs index 1b068f2456..705dea797d 100644 --- a/src/query/tests/mean_test.rs +++ b/src/query/tests/mean_test.rs @@ -28,6 +28,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_mean_aggregator() -> Result<()> { @@ -89,7 +90,9 @@ async fn execute_mean<'a>( engine: Arc, ) -> 
RecordResult> { let sql = format!("select MEAN({}) as mean from {}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index dbd0427752..4e05183861 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -36,6 +36,7 @@ use datatypes::with_match_primitive_type_id; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngineFactory; +use session::context::QueryContext; use table::test_util::MemTable; #[derive(Debug, Default)] @@ -228,7 +229,7 @@ where "select MY_SUM({}) as my_sum from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql)?; + let plan = engine.sql_to_plan(&sql, Arc::new(QueryContext::new()))?; let output = engine.execute(&plan).await?; let recordbatch_stream = match output { diff --git a/src/query/tests/percentile_test.rs b/src/query/tests/percentile_test.rs index c504da231a..6e210a0494 100644 --- a/src/query/tests/percentile_test.rs +++ b/src/query/tests/percentile_test.rs @@ -30,6 +30,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::{QueryEngine, QueryEngineFactory}; +use session::context::QueryContext; use table::test_util::MemTable; #[tokio::test] @@ -53,7 +54,9 @@ async fn test_percentile_aggregator() -> Result<()> { async fn test_percentile_correctness() -> Result<()> { let engine = create_correctness_engine(); let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers"); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { @@ -113,7 +116,9 @@ async fn execute_percentile<'a>( "select PERCENTILE({},50.0) as percentile from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/polyval_test.rs b/src/query/tests/polyval_test.rs index 55285e20e1..f2e60c0217 100644 --- a/src/query/tests/polyval_test.rs +++ b/src/query/tests/polyval_test.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; #[tokio::test] async fn test_polyval_aggregator() -> Result<()> { @@ -92,7 +93,9 @@ async fn execute_polyval<'a>( "select POLYVAL({}, 0) as polyval from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index 26afd8c9cc..cf640afba4 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -37,6 +37,7 @@ use query::plan::LogicalPlan; use query::query_engine::QueryEngineFactory; use query::QueryEngine; use rand::Rng; +use session::context::QueryContext; 
use table::table::adapter::DfTableProviderAdapter; use table::table::numbers::NumbersTable; use table::test_util::MemTable; @@ -134,7 +135,10 @@ async fn test_udf() -> Result<()> { engine.register_udf(udf); - let plan = engine.sql_to_plan("select pow(number, number) as p from numbers limit 10")?; + let plan = engine.sql_to_plan( + "select pow(number, number) as p from numbers limit 10", + Arc::new(QueryContext::new()), + )?; let output = engine.execute(&plan).await?; let recordbatch = match output { @@ -242,7 +246,9 @@ where for<'a> T: Scalar = T>, { let sql = format!("SELECT {} FROM {}", column_name, table_name); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { @@ -330,7 +336,9 @@ async fn execute_median<'a>( "select MEDIAN({}) as median from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/scipy_stats_norm_cdf_test.rs b/src/query/tests/scipy_stats_norm_cdf_test.rs index 572a433683..815501a314 100644 --- a/src/query/tests/scipy_stats_norm_cdf_test.rs +++ b/src/query/tests/scipy_stats_norm_cdf_test.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; use statrs::distribution::{ContinuousCDF, Normal}; use statrs::statistics::Statistics; @@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_cdf<'a>( "select SCIPYSTATSNORMCDF({},2.0) as scipy_stats_norm_cdf from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/query/tests/scipy_stats_norm_pdf.rs b/src/query/tests/scipy_stats_norm_pdf.rs index efbf0ddec3..dd5e0fc7fc 100644 --- a/src/query/tests/scipy_stats_norm_pdf.rs +++ b/src/query/tests/scipy_stats_norm_pdf.rs @@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table}; use num_traits::AsPrimitive; use query::error::Result; use query::QueryEngine; +use session::context::QueryContext; use statrs::distribution::{Continuous, Normal}; use statrs::statistics::Statistics; @@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_pdf<'a>( "select SCIPYSTATSNORMPDF({},2.0) as scipy_stats_norm_pdf from {}", column_name, table_name ); - let plan = engine.sql_to_plan(&sql).unwrap(); + let plan = engine + .sql_to_plan(&sql, Arc::new(QueryContext::new())) + .unwrap(); let output = engine.execute(&plan).await.unwrap(); let recordbatch_stream = match output { diff --git a/src/script/Cargo.toml b/src/script/Cargo.toml index 2fa0b15286..ff69720053 100644 --- a/src/script/Cargo.toml +++ b/src/script/Cargo.toml @@ -7,16 +7,16 @@ license = "Apache-2.0" [features] default = ["python"] python = [ - "dep:datafusion", - "dep:datafusion-expr", - "dep:datafusion-physical-expr", - "dep:rustpython-vm", - "dep:rustpython-parser", - "dep:rustpython-compiler", - "dep:rustpython-compiler-core", - "dep:rustpython-bytecode", - "dep:rustpython-ast", - "dep:paste", + "dep:datafusion", + "dep:datafusion-expr", + "dep:datafusion-physical-expr", + 
"dep:rustpython-vm", + "dep:rustpython-parser", + "dep:rustpython-compiler", + "dep:rustpython-compiler-core", + "dep:rustpython-bytecode", + "dep:rustpython-ast", + "dep:paste", ] [dependencies] @@ -45,9 +45,10 @@ rustpython-compiler = { git = "https://github.com/RustPython/RustPython", option rustpython-compiler-core = { git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d" } rustpython-parser = { git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d" } rustpython-vm = { git = "https://github.com/RustPython/RustPython", optional = true, rev = "02a1d1d", features = [ - "default", - "freeze-stdlib", + "default", + "freeze-stdlib", ] } +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } sql = { path = "../sql" } table = { path = "../table" } @@ -55,10 +56,10 @@ tokio = { version = "1.0", features = ["full"] } [dev-dependencies] log-store = { path = "../log-store" } +mito = { path = "../mito", features = ["test"] } ron = "0.7" serde = { version = "1.0", features = ["derive"] } storage = { path = "../storage" } -mito = { path = "../mito", features = ["test"] } tempdir = "0.3" tokio = { version = "1.18", features = ["full"] } tokio-test = "0.4" diff --git a/src/script/src/python/coprocessor.rs b/src/script/src/python/coprocessor.rs index 214deaa24b..bb32494dbf 100644 --- a/src/script/src/python/coprocessor.rs +++ b/src/script/src/python/coprocessor.rs @@ -15,11 +15,13 @@ pub mod compile; pub mod parse; +use std::cell::RefCell; use std::collections::HashMap; use std::result::Result as StdResult; use std::sync::Arc; use common_recordbatch::RecordBatch; +use common_telemetry::info; use datafusion_common::record_batch::RecordBatch as DfRecordBatch; use datatypes::arrow; use datatypes::arrow::array::{Array, ArrayRef}; @@ -46,6 +48,8 @@ use crate::python::error::{ use crate::python::utils::{format_py_error, is_instance, py_vec_obj_to_array}; use crate::python::PyVector; +thread_local!(static INTERPRETER: RefCell>> = RefCell::new(None)); + #[cfg_attr(test, derive(Deserialize))] #[derive(Debug, Clone, PartialEq, Eq)] pub struct AnnotationInfo { @@ -114,11 +118,12 @@ impl Coprocessor { let AnnotationInfo { datatype: ty, is_nullable, - } = anno[idx].to_owned().unwrap_or_else(|| - // default to be not nullable and use DataType inferred by PyVector itself - AnnotationInfo{ - datatype: Some(real_ty.to_owned()), - is_nullable: false + } = anno[idx].to_owned().unwrap_or_else(|| { + // default to be not nullable and use DataType inferred by PyVector itself + AnnotationInfo { + datatype: Some(real_ty.to_owned()), + is_nullable: false, + } }); Field::new( name, @@ -282,7 +287,7 @@ fn check_args_anno_real_type( anno_ty .to_owned() .map(|v| v.datatype == None // like a vector[_] - || v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable) + || v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable) .unwrap_or(true), OtherSnafu { reason: format!( @@ -380,7 +385,7 @@ pub(crate) fn exec_with_cached_vm( copr: &Coprocessor, rb: &DfRecordBatch, args: Vec, - vm: &Interpreter, + vm: &Arc, ) -> Result { vm.enter(|vm| -> Result { PyVector::make_class(&vm.ctx); @@ -421,10 +426,18 @@ pub(crate) fn exec_with_cached_vm( } /// init interpreter with type PyVector and Module: greptime -pub(crate) fn init_interpreter() -> Interpreter { - vm::Interpreter::with_init(Default::default(), |vm| { - PyVector::make_class(&vm.ctx); - vm.add_native_module("greptime", Box::new(greptime_builtin::make_module)); 
+pub(crate) fn init_interpreter() -> Arc { + INTERPRETER.with(|i| { + i.borrow_mut() + .get_or_insert_with(|| { + let interpreter = Arc::new(vm::Interpreter::with_init(Default::default(), |vm| { + PyVector::make_class(&vm.ctx); + vm.add_native_module("greptime", Box::new(greptime_builtin::make_module)); + })); + info!("Initialized Python interpreter."); + interpreter + }) + .clone() }) } diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 7a6c02c858..7ad5390f7b 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -26,6 +26,7 @@ use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStre use datatypes::schema::SchemaRef; use futures::Stream; use query::QueryEngineRef; +use session::context::QueryContext; use snafu::{ensure, ResultExt}; use sql::statements::statement::Statement; @@ -93,7 +94,9 @@ impl Script for PyScript { matches!(stmt, Statement::Query { .. }), error::UnsupportedSqlSnafu { sql } ); - let plan = self.query_engine.statement_to_plan(stmt)?; + let plan = self + .query_engine + .statement_to_plan(stmt, Arc::new(QueryContext::new()))?; let res = self.query_engine.execute(&plan).await?; let copr = self.copr.clone(); match res { diff --git a/src/script/src/python/vector.rs b/src/script/src/python/vector.rs index 2ebc3e0ad4..4a432df602 100644 --- a/src/script/src/python/vector.rs +++ b/src/script/src/python/vector.rs @@ -1115,12 +1115,13 @@ pub mod tests { } pub fn execute_script( + interpreter: &rustpython_vm::Interpreter, script: &str, test_vec: Option, predicate: PredicateFn, ) -> Result<(PyObjectRef, Option), PyRef> { let mut pred_res = None; - rustpython_vm::Interpreter::without_stdlib(Default::default()) + interpreter .enter(|vm| { PyVector::make_class(&vm.ctx); let scope = vm.new_scope_with_builtins(); @@ -1208,8 +1209,10 @@ pub mod tests { Some(|v, vm| is_eq(v, 2.0, vm)), ), ]; + + let interpreter = rustpython_vm::Interpreter::without_stdlib(Default::default()); for (code, pred) in snippet { - let result = execute_script(code, None, pred); + let result = execute_script(&interpreter, code, None, pred); println!( "\u{001B}[35m{code}\u{001B}[0m: {:?}{}", diff --git a/src/script/src/table.rs b/src/script/src/table.rs index 6eac358244..abc0279a3f 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -28,6 +28,7 @@ use datatypes::prelude::{ConcreteDataType, ScalarVector}; use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder}; use datatypes::vectors::{StringVector, TimestampVector, VectorRef}; use query::QueryEngineRef; +use session::context::QueryContext; use snafu::{ensure, OptionExt, ResultExt}; use table::requests::{CreateTableRequest, InsertRequest}; @@ -59,9 +60,9 @@ impl ScriptsTable { table_name: SCRIPTS_TABLE_NAME.to_string(), desc: Some("Scripts table".to_string()), schema, - // name and timestamp as primary key region_numbers: vec![0], - primary_key_indices: vec![0, 3], + // name as primary key + primary_key_indices: vec![0], create_if_not_exists: true, table_options: HashMap::default(), }; @@ -151,7 +152,7 @@ impl ScriptsTable { let plan = self .query_engine - .sql_to_plan(&sql) + .sql_to_plan(&sql, Arc::new(QueryContext::new())) .context(FindScriptSnafu { name })?; let stream = match self diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 5a74c223fa..2e2c133416 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -5,11 +5,11 @@ edition = "2021" license = "Apache-2.0" [dependencies] -aide = { version = "0.6", features = ["axum"] } 
+aide = { version = "0.9", features = ["axum"] } api = { path = "../api" } async-trait = "0.1" -axum = "0.6.0-rc.2" -axum-macros = "0.3.0-rc.1" +axum = "0.6" +axum-macros = "0.3" bytes = "1.2" common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } @@ -23,24 +23,29 @@ common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } futures = "0.3" hex = { version = "0.4" } +humantime-serde = "1.1" hyper = { version = "0.14", features = ["full"] } influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" } metrics = "0.20" num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" -opensrv-mysql = "0.2" +opensrv-mysql = "0.3" pgwire = "0.5" prost = "0.11" -regex = "1.6" rand = "0.8" +regex = "1.6" +rustls = "0.20" +rustls-pemfile = "1.0" schemars = "0.8" serde = "1.0" serde_json = "1.0" +session = { path = "../session" } snafu = { version = "0.7", features = ["backtraces"] } snap = "1" table = { path = "../table" } tokio = { version = "1.20", features = ["full"] } +tokio-rustls = "0.23" tokio-stream = { version = "0.1", features = ["net"] } tonic = "0.8" tonic-reflection = "0.5" @@ -51,10 +56,14 @@ tower-http = { version = "0.3", features = ["full"] } axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } catalog = { path = "../catalog" } common-base = { path = "../common/base" } -mysql_async = { git = "https://github.com/Morranto/mysql_async.git", rev = "127b538" } +mysql_async = { version = "0.31", default-features = false, features = [ + "default-rustls", +] } query = { path = "../query" } rand = "0.8" script = { path = "../script", features = ["python"] } +serde_json = "1.0" table = { path = "../table" } tokio-postgres = "0.7" +tokio-postgres-rustls = "0.9" tokio-test = "0.4" diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 62b3d7a814..7f6d46d1e5 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -192,6 +192,9 @@ pub enum Error { err_msg: String, backtrace: Backtrace, }, + + #[snafu(display("Tls is required for {}, plain connection is rejected", server))] + TlsRequired { server: String }, } pub type Result = std::result::Result; @@ -234,6 +237,7 @@ impl ErrorExt for Error { InfluxdbLinesWrite { source, .. } => source.status_code(), Hyper { .. } => StatusCode::Unknown, + TlsRequired { .. } => StatusCode::Unknown, StartFrontend { source, .. 
} => source.status_code(), } } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 1b90286737..c12403bff2 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -59,6 +59,7 @@ const HTTP_API_VERSION: &str = "v1"; pub struct HttpServer { sql_handler: SqlQueryHandlerRef, + options: HttpOptions, influxdb_handler: Option, opentsdb_handler: Option, prom_handler: Option, @@ -66,6 +67,22 @@ pub struct HttpServer { shutdown_tx: Mutex>>, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct HttpOptions { + pub addr: String, + #[serde(with = "humantime_serde")] + pub timeout: Duration, +} + +impl Default for HttpOptions { + fn default() -> Self { + Self { + addr: "127.0.0.1:4000".to_string(), + timeout: Duration::from_secs(30), + } + } +} + #[derive(Debug, Serialize, Deserialize, JsonSchema, Eq, PartialEq)] pub struct ColumnSchema { name: String, @@ -168,7 +185,7 @@ impl TryFrom> for HttpRecordsOutput { } } -#[derive(Serialize, Deserialize, Debug, JsonSchema)] +#[derive(Serialize, Deserialize, Debug, JsonSchema, Eq, PartialEq)] #[serde(rename_all = "lowercase")] pub enum JsonOutput { AffectedRows(usize), @@ -271,9 +288,10 @@ pub struct ApiState { } impl HttpServer { - pub fn new(sql_handler: SqlQueryHandlerRef) -> Self { + pub fn new(sql_handler: SqlQueryHandlerRef, options: HttpOptions) -> Self { Self { sql_handler, + options, opentsdb_handler: None, influxdb_handler: None, prom_handler: None, @@ -326,71 +344,91 @@ impl HttpServer { url: format!("/{}", HTTP_API_VERSION), ..OpenAPIServer::default() }], - ..OpenApi::default() }; - // TODO(LFC): Use released Axum. - // Axum version 0.6 introduces state within router, making router methods far more elegant - // to write. Though version 0.6 is rc, I think it's worth to upgrade. - // Prior to version 0.6, we only have a single "Extension" to share all query - // handlers amongst router methods. That requires us to pack all query handlers in a shared - // state, and check-then-get the desired query handler in different router methods, which - // is a lot of tedious work. 
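The removed comment above is resolved by the axum 0.6 upgrade in this diff: state now lives on each nested router instead of a single shared `Extension`. A rough sketch of the pattern, with a placeholder handler type standing in for the real protocol handler refs:

```rust
use std::sync::Arc;

use axum::{extract::State, routing::post, Router};

// Placeholder for e.g. InfluxdbLineProtocolHandlerRef.
struct WriteHandler;

async fn write(State(_handler): State<Arc<WriteHandler>>) -> &'static str {
    "ok"
}

fn route_write(handler: Arc<WriteHandler>) -> Router {
    // Each protocol router owns exactly the state it needs...
    Router::new().route("/write", post(write)).with_state(handler)
}

fn app(handler: Arc<WriteHandler>) -> Router {
    // ...and is nested under a versioned prefix, mirroring route_influxdb
    // and friends in the change above.
    Router::new().nest("/v1/influxdb", route_write(handler))
}
```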
- let sql_router = ApiRouter::with_state(ApiState { - sql_handler: self.sql_handler.clone(), - script_handler: self.script_handler.clone(), - }) - .api_route( - "/sql", - apirouting::get_with(handler::sql, handler::sql_docs) - .post_with(handler::sql, handler::sql_docs), - ) - .api_route("/scripts", apirouting::post(script::scripts)) - .api_route("/run-script", apirouting::post(script::run_script)) - .route("/private/api.json", apirouting::get(serve_api)) - .route("/private/docs", apirouting::get(serve_docs)) - .finish_api(&mut api) - .layer(Extension(Arc::new(api))); + let sql_router = self + .route_sql(ApiState { + sql_handler: self.sql_handler.clone(), + script_handler: self.script_handler.clone(), + }) + .finish_api(&mut api) + .layer(Extension(Arc::new(api))); let mut router = Router::new().nest(&format!("/{}", HTTP_API_VERSION), sql_router); if let Some(opentsdb_handler) = self.opentsdb_handler.clone() { - let opentsdb_router = Router::with_state(opentsdb_handler) - .route("/api/put", routing::post(opentsdb::put)); - - router = router.nest(&format!("/{}/opentsdb", HTTP_API_VERSION), opentsdb_router); + router = router.nest( + &format!("/{}/opentsdb", HTTP_API_VERSION), + self.route_opentsdb(opentsdb_handler), + ); } if let Some(influxdb_handler) = self.influxdb_handler.clone() { - let influxdb_router = - Router::with_state(influxdb_handler).route("/write", routing::post(influxdb_write)); - - router = router.nest(&format!("/{}/influxdb", HTTP_API_VERSION), influxdb_router); + router = router.nest( + &format!("/{}/influxdb", HTTP_API_VERSION), + self.route_influxdb(influxdb_handler), + ); } if let Some(prom_handler) = self.prom_handler.clone() { - let prom_router = Router::with_state(prom_handler) - .route("/write", routing::post(prometheus::remote_write)) - .route("/read", routing::post(prometheus::remote_read)); - - router = router.nest(&format!("/{}/prometheus", HTTP_API_VERSION), prom_router); + router = router.nest( + &format!("/{}/prometheus", HTTP_API_VERSION), + self.route_prom(prom_handler), + ); } router = router.route("/metrics", routing::get(handler::metrics)); + router = router.route( + "/health", + routing::get(handler::health).post(handler::health), + ); + router // middlewares .layer( ServiceBuilder::new() .layer(HandleErrorLayer::new(handle_error)) .layer(TraceLayer::new_for_http()) - // TODO(LFC): make timeout configurable - .layer(TimeoutLayer::new(Duration::from_secs(30))) + .layer(TimeoutLayer::new(self.options.timeout)) // custom layer .layer(middleware::from_fn(context::build_ctx)), ) } + + fn route_sql(&self, api_state: ApiState) -> ApiRouter { + ApiRouter::new() + .api_route( + "/sql", + apirouting::get_with(handler::sql, handler::sql_docs) + .post_with(handler::sql, handler::sql_docs), + ) + .api_route("/scripts", apirouting::post(script::scripts)) + .api_route("/run-script", apirouting::post(script::run_script)) + .route("/private/api.json", apirouting::get(serve_api)) + .route("/private/docs", apirouting::get(serve_docs)) + .with_state(api_state) + } + + fn route_prom(&self, prom_handler: PrometheusProtocolHandlerRef) -> Router { + Router::new() + .route("/write", routing::post(prometheus::remote_write)) + .route("/read", routing::post(prometheus::remote_read)) + .with_state(prom_handler) + } + + fn route_influxdb(&self, influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router { + Router::new() + .route("/write", routing::post(influxdb_write)) + .with_state(influxdb_handler) + } + + fn route_opentsdb(&self, opentsdb_handler: OpentsdbProtocolHandlerRef) -> 
Router { + Router::new() + .route("/api/put", routing::post(opentsdb::put)) + .with_state(opentsdb_handler) + } } #[async_trait] @@ -443,14 +481,72 @@ async fn handle_error(err: BoxError) -> Json { #[cfg(test)] mod test { + use std::future::pending; use std::sync::Arc; + use axum::handler::Handler; + use axum::http::StatusCode; + use axum::routing::get; + use axum_test_helper::TestClient; use common_recordbatch::RecordBatches; use datatypes::prelude::*; use datatypes::schema::{ColumnSchema, Schema}; use datatypes::vectors::{StringVector, UInt32Vector}; + use session::context::QueryContextRef; + use tokio::sync::mpsc; use super::*; + use crate::query_handler::SqlQueryHandler; + + struct DummyInstance { + _tx: mpsc::Sender<(String, Vec)>, + } + + #[async_trait] + impl SqlQueryHandler for DummyInstance { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { + unimplemented!() + } + } + + fn timeout() -> TimeoutLayer { + TimeoutLayer::new(Duration::from_millis(10)) + } + + async fn forever() { + pending().await + } + + fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { + let instance = Arc::new(DummyInstance { _tx: tx }); + let server = HttpServer::new(instance, HttpOptions::default()); + server.make_app().route( + "/test/timeout", + get(forever.layer( + ServiceBuilder::new() + .layer(HandleErrorLayer::new(|_: BoxError| async { + StatusCode::REQUEST_TIMEOUT + })) + .layer(timeout()), + )), + ) + } + + #[test] + fn test_http_options_default() { + let default = HttpOptions::default(); + assert_eq!("127.0.0.1:4000".to_string(), default.addr); + assert_eq!(Duration::from_secs(30), default.timeout) + } + + #[tokio::test] + async fn test_http_server_request_timeout() { + let (tx, _rx) = mpsc::channel(100); + let app = make_test_app(tx); + let client = TestClient::new(app); + let res = client.get("/test/timeout").send().await; + assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); + } #[tokio::test] async fn test_recordbatches_conversion() { diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 69b41edd1b..a730d59ff4 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::collections::HashMap; +use std::sync::Arc; use std::time::Instant; use aide::transform::TransformOperation; @@ -21,6 +22,7 @@ use common_error::status_code::StatusCode; use common_telemetry::metric; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use session::context::QueryContext; use crate::http::{ApiState, JsonResponse}; @@ -39,7 +41,9 @@ pub async fn sql( let sql_handler = &state.sql_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - JsonResponse::from_output(sql_handler.do_query(sql).await).await + // TODO(LFC): Sessions in http server. 
+ let query_ctx = Arc::new(QueryContext::new()); + JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } else { JsonResponse::with_error( "sql parameter is required.".to_string(), @@ -63,3 +67,17 @@ pub async fn metrics(Query(_params): Query>) -> String { "Prometheus handle not initialized.".to_owned() } } + +#[derive(Debug, Serialize, Deserialize, JsonSchema)] +pub struct HealthQuery {} + +#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] +pub struct HealthResponse {} + +/// Handler to export healthy check +/// +/// Currently simply return status "200 OK" (default) with an empty json payload "{}" +#[axum_macros::debug_handler] +pub async fn health(Query(_params): Query) -> Json { + Json(HealthResponse {}) +} diff --git a/src/servers/src/influxdb.rs b/src/servers/src/influxdb.rs index df1835efcd..0766d65843 100644 --- a/src/servers/src/influxdb.rs +++ b/src/servers/src/influxdb.rs @@ -14,7 +14,6 @@ use std::collections::HashMap; -use api::v1::insert_expr::{self, Expr}; use api::v1::InsertExpr; use common_grpc::writer::{LinesWriter, Precision}; use influxdb_line_protocol::{parse_lines, FieldValue}; @@ -165,14 +164,15 @@ impl TryFrom<&InfluxdbRequest> for Vec { Ok(writers .into_iter() - .map(|(table_name, writer)| InsertExpr { - schema_name: schema_name.clone(), - table_name, - expr: Some(Expr::Values(insert_expr::Values { - values: vec![writer.finish().into()], - })), - options: HashMap::default(), - region_number: 0, + .map(|(table_name, writer)| { + let (columns, row_count) = writer.finish(); + InsertExpr { + schema_name: schema_name.clone(), + table_name, + region_number: 0, + columns, + row_count, + } }) .collect()) } @@ -180,12 +180,9 @@ impl TryFrom<&InfluxdbRequest> for Vec { #[cfg(test)] mod tests { - use std::ops::Deref; use std::sync::Arc; - use api::v1::codec::InsertBatch; use api::v1::column::{SemanticType, Values}; - use api::v1::insert_expr::Expr; use api::v1::{Column, ColumnDataType, InsertExpr}; use common_base::BitVec; use common_time::timestamp::TimeUnit; @@ -242,15 +239,9 @@ monitor2,host=host4 cpu=66.3,memory=1029 1663840496400340003"; for expr in insert_exprs { assert_eq!("public", expr.schema_name); - let values = match expr.expr.unwrap() { - Expr::Values(vals) => vals, - Expr::Sql(_) => panic!(), - }; - let raw_batch = values.values.get(0).unwrap(); - let batch: InsertBatch = raw_batch.deref().try_into().unwrap(); match &expr.table_name[..] 
{ - "monitor1" => assert_monitor_1(&batch), - "monitor2" => assert_monitor_2(&batch), + "monitor1" => assert_monitor_1(&expr.columns), + "monitor2" => assert_monitor_2(&expr.columns), _ => panic!(), } } @@ -327,8 +318,7 @@ monitor2,host=host4 cpu=66.3,memory=1029 1663840496400340003"; } } - fn assert_monitor_1(insert_batch: &InsertBatch) { - let columns = &insert_batch.columns; + fn assert_monitor_1(columns: &[Column]) { assert_eq!(4, columns.len()); verify_column( &columns[0], @@ -379,8 +369,7 @@ monitor2,host=host4 cpu=66.3,memory=1029 1663840496400340003"; ); } - fn assert_monitor_2(insert_batch: &InsertBatch) { - let columns = &insert_batch.columns; + fn assert_monitor_2(columns: &[Column]) { assert_eq!(4, columns.len()); verify_column( &columns[0], diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 34b4f367a8..da6e8306eb 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -28,6 +28,8 @@ pub mod postgres; pub mod prometheus; pub mod query_handler; pub mod server; +pub mod tls; + mod shutdown; #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index 8aa3f369fe..f2f1a8caed 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -26,21 +26,25 @@ use datatypes::vectors::StringVector; use once_cell::sync::Lazy; use regex::bytes::RegexSet; use regex::Regex; +use session::context::QueryContextRef; // TODO(LFC): Include GreptimeDB's version and git commit tag etc. const MYSQL_VERSION: &str = "8.0.26"; static SELECT_VAR_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap()); static MYSQL_CONN_JAVA_PATTERN: Lazy = - Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-java(.*))").unwrap()); + Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap()); static SHOW_LOWER_CASE_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap()); static SHOW_COLLATION_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap()); static SHOW_VARIABLES_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap()); + static SELECT_VERSION_PATTERN: Lazy = Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap()); +static SELECT_DATABASE_PATTERN: Lazy = + Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap()); // SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP()); static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy = @@ -248,13 +252,18 @@ fn check_show_variables(query: &str) -> Option { } // Check for SET or others query, this is the final check of the federated query. -fn check_others(query: &str) -> Option { +fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) { return Some(Output::RecordBatches(RecordBatches::empty())); } let recordbatches = if SELECT_VERSION_PATTERN.is_match(query) { Some(select_function("version()", MYSQL_VERSION)) + } else if SELECT_DATABASE_PATTERN.is_match(query) { + let schema = query_ctx + .current_schema() + .unwrap_or_else(|| "NULL".to_string()); + Some(select_function("database()", &schema)) } else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) { Some(select_function( "TIMEDIFF(NOW(), UTC_TIMESTAMP())", @@ -268,7 +277,7 @@ fn check_others(query: &str) -> Option { // Check whether the query is a federated or driver setup command, // and return some faked results if there are any. 
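The new `SELECT_DATABASE_PATTERN` above lets `check_others` answer `SELECT DATABASE()` from the session's current schema instead of forwarding it to the engine. A self-contained sketch of that interception, returning a plain `String` where the real code builds `Output::RecordBatches`:

```rust
use once_cell::sync::Lazy;
use regex::Regex;

static SELECT_DATABASE_PATTERN: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap());

/// Answers `SELECT DATABASE()` from the session, falling back to "NULL"
/// when no schema has been selected yet.
fn check_select_database(query: &str, current_schema: Option<&str>) -> Option<String> {
    SELECT_DATABASE_PATTERN
        .is_match(query)
        .then(|| current_schema.unwrap_or("NULL").to_string())
}

fn main() {
    assert_eq!(
        check_select_database("select database()", Some("public")),
        Some("public".to_string())
    );
    assert_eq!(check_select_database("select 1", Some("public")), None);
}
```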
-pub fn check(query: &str) -> Option { +pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option { // First to check the query is like "select @@variables". let output = check_select_variable(query); if output.is_some() { @@ -282,25 +291,27 @@ pub fn check(query: &str) -> Option { } // Last check. - check_others(query) + check_others(query, query_ctx) } #[cfg(test)] mod test { + use session::context::QueryContext; + use super::*; #[test] fn test_check() { let query = "select 1"; - let result = check(query); + let result = check(query, Arc::new(QueryContext::new())); assert!(result.is_none()); let query = "select versiona"; - let output = check(query); + let output = check(query, Arc::new(QueryContext::new())); assert!(output.is_none()); fn test(query: &str, expected: Vec<&str>) { - let output = check(query); + let output = check(query, Arc::new(QueryContext::new())); match output.unwrap() { Output::RecordBatches(r) => { assert_eq!(r.pretty_print().lines().collect::>(), expected) diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index c1614377a7..2884b3e4bf 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -16,11 +16,13 @@ use std::sync::Arc; use std::time::Instant; use async_trait::async_trait; +use common_query::Output; use common_telemetry::{debug, error}; use opensrv_mysql::{ - AsyncMysqlShim, ErrorKind, ParamParser, QueryResultWriter, StatementMetaWriter, + AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter, }; use rand::RngCore; +use session::Session; use tokio::io::AsyncWrite; use tokio::sync::RwLock; @@ -36,7 +38,9 @@ pub struct MysqlInstanceShim { query_handler: SqlQueryHandlerRef, salt: [u8; 20], client_addr: String, + // TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose. ctx: Arc>>, + session: Arc, } impl MysqlInstanceShim { @@ -59,8 +63,33 @@ impl MysqlInstanceShim { salt: scramble, client_addr, ctx: Arc::new(RwLock::new(None)), + session: Arc::new(Session::new()), } } + + async fn do_query(&self, query: &str) -> Result { + debug!("Start executing query: '{}'", query); + let start = Instant::now(); + + // TODO(LFC): Find a better way to deal with these special federated queries: + // `check` uses regex to filter out unsupported statements emitted by MySQL's federated + // components, this is quick and dirty, there must be a better way to do it. + let output = + if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { + Ok(output) + } else { + self.query_handler + .do_query(query, self.session.context()) + .await + }; + + debug!( + "Finished executing query: '{}', total time costs in microseconds: {}", + query, + start.elapsed().as_micros() + ); + output + } } #[async_trait] @@ -144,25 +173,20 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, writer: QueryResultWriter<'a, W>, ) -> Result<()> { - debug!("Start executing query: '{}'", query); - let start = Instant::now(); - - // TODO(LFC): Find a better way: - // `check` uses regex to filter out unsupported statements emitted by MySQL's federated - // components, this is quick and dirty, there must be a better way to do it. 
- let output = if let Some(output) = crate::mysql::federated::check(query) { - Ok(output) - } else { - self.query_handler.do_query(query).await - }; - - debug!( - "Finished executing query: '{}', total time costs in microseconds: {}", - query, - start.elapsed().as_micros() - ); - + let output = self.do_query(query).await; let mut writer = MysqlResultWriter::new(writer); writer.write(query, output).await } + + async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { + let query = format!("USE {}", database.trim()); + let output = self.do_query(&query).await; + if let Err(e) = output { + w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) + .await + } else { + w.ok().await + } + .map_err(|e| e.into()) + } } diff --git a/src/servers/src/mysql/server.rs b/src/servers/src/mysql/server.rs index f66669303c..0e3104a633 100644 --- a/src/servers/src/mysql/server.rs +++ b/src/servers/src/mysql/server.rs @@ -20,15 +20,19 @@ use async_trait::async_trait; use common_runtime::Runtime; use common_telemetry::logging::{error, info}; use futures::StreamExt; -use opensrv_mysql::AsyncMysqlIntermediary; +use opensrv_mysql::{ + plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions, +}; use tokio; use tokio::io::BufWriter; use tokio::net::TcpStream; +use tokio_rustls::rustls::ServerConfig; -use crate::error::Result; +use crate::error::{Error, Result}; use crate::mysql::handler::MysqlInstanceShim; use crate::query_handler::SqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; +use crate::tls::TlsOption; // Default size of ResultSet write buffer: 100KB const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; @@ -36,16 +40,19 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024; pub struct MysqlServer { base_server: BaseTcpServer, query_handler: SqlQueryHandlerRef, + tls: Arc, } impl MysqlServer { pub fn create_server( query_handler: SqlQueryHandlerRef, io_runtime: Arc, + tls: Arc, ) -> Box { Box::new(MysqlServer { base_server: BaseTcpServer::create_server("MySQL", io_runtime), query_handler, + tls, }) } @@ -53,16 +60,22 @@ impl MysqlServer { &self, io_runtime: Arc, stream: AbortableStream, + tls_conf: Option>, ) -> impl Future { let query_handler = self.query_handler.clone(); + let force_tls = self.tls.should_force_tls(); + stream.for_each(move |tcp_stream| { let io_runtime = io_runtime.clone(); let query_handler = query_handler.clone(); + let tls_conf = tls_conf.clone(); async move { match tcp_stream { Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt. Ok(io_stream) => { - if let Err(error) = Self::handle(io_stream, io_runtime, query_handler).await + if let Err(error) = + Self::handle(io_stream, io_runtime, query_handler, tls_conf, force_tls) + .await { error!(error; "Unexpected error when handling TcpStream"); }; @@ -76,28 +89,49 @@ impl MysqlServer { stream: TcpStream, io_runtime: Arc, query_handler: SqlQueryHandlerRef, + tls_conf: Option>, + force_tls: bool, ) -> Result<()> { info!("MySQL connection coming from: {}", stream.peer_addr()?); - let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string()); - - let (r, w) = stream.into_split(); - let w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); - // TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client. 
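The `on_init` hook above maps MySQL's `COM_INIT_DB` (what a client sends for `USE db`) onto the ordinary query path, so schema switching reuses `do_query`'s federated check and error handling. A trimmed sketch of that flow, with string errors standing in for the server's error types:

```rust
async fn on_init(database: &str) -> Result<(), String> {
    // Rewrite COM_INIT_DB into a USE statement and run it like any query,
    // as the handler change above does via self.do_query.
    let query = format!("USE {}", database.trim());
    run_query(&query).await.map(|_| ())
}

// Hypothetical stand-in for MysqlInstanceShim::do_query.
async fn run_query(query: &str) -> Result<String, String> {
    if query.starts_with("USE ") {
        Ok("schema switched".to_string())
    } else {
        Err(format!("unexpected query: {query}"))
    }
}
```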
- let spawn_result = io_runtime - .spawn(AsyncMysqlIntermediary::run_on(shim, r, w)) - .await; - match spawn_result { - Ok(run_result) => { - if let Err(e) = run_result { - // TODO(LFC): Write this error and the below one to client as well, in MySQL text protocol. - // Looks like we have to expose opensrv-mysql's `PacketWriter`? - error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.") - } + io_runtime .spawn(async move { + // TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client. + if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls).await { + // TODO(LFC): Write this error to client as well, in MySQL text protocol. + // Looks like we have to expose opensrv-mysql's `PacketWriter`? + error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.") } - Err(e) => error!("IO runtime cannot execute task, error: {}", e), - } + }); + Ok(()) } + + async fn do_handle( + stream: TcpStream, + query_handler: SqlQueryHandlerRef, + tls_conf: Option>, + force_tls: bool, + ) -> Result<()> { + let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string()); + let (mut r, w) = stream.into_split(); + let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w); + let ops = IntermediaryOptions::default(); + + let (client_tls, init_params) = + AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?; + + if force_tls && !client_tls { + return Err(Error::TlsRequired { + server: "mysql".to_owned(), + }); + } + + match tls_conf { + Some(tls_conf) if client_tls => { + secure_run_with_options(shim, w, ops, tls_conf, init_params).await + } + _ => plain_run_with_options(shim, w, ops, init_params).await, + } + } } #[async_trait] @@ -110,7 +144,10 @@ impl Server for MysqlServer { let (stream, addr) = self.base_server.bind(listening).await?; let io_runtime = self.base_server.io_runtime(); - let join_handle = tokio::spawn(self.accept(io_runtime, stream)); + + let tls_conf = self.tls.setup()?.map(Arc::new); + + let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf)); self.base_server.start_with(join_handle).await?; Ok(addr) } diff --git a/src/servers/src/opentsdb/codec.rs b/src/servers/src/opentsdb/codec.rs index 2253fd7e63..260a206fe5 100644 --- a/src/servers/src/opentsdb/codec.rs +++ b/src/servers/src/opentsdb/codec.rs @@ -12,11 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
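`do_handle` above peeks at the client's SSL request before choosing between `secure_run_with_options` and `plain_run_with_options`, and rejects plaintext clients outright when TLS is forced. The decision table, condensed, with booleans in place of the real stream and `ServerConfig`:

```rust
enum TlsMode {
    Disable,
    Prefer,
    Require,
}

fn should_force_tls(mode: &TlsMode) -> bool {
    // Mirrors TlsOption::should_force_tls in the new servers::tls module.
    !matches!(mode, TlsMode::Disable | TlsMode::Prefer)
}

fn choose_channel(
    mode: &TlsMode,
    client_tls: bool,
    have_tls_conf: bool,
) -> Result<&'static str, String> {
    if should_force_tls(mode) && !client_tls {
        // Corresponds to the new Error::TlsRequired variant.
        return Err("Tls is required for mysql, plain connection is rejected".to_string());
    }
    Ok(if have_tls_conf && client_tls {
        "secure_run_with_options"
    } else {
        "plain_run_with_options"
    })
}

fn main() {
    assert!(choose_channel(&TlsMode::Require, false, true).is_err());
    assert_eq!(
        choose_channel(&TlsMode::Prefer, true, true),
        Ok("secure_run_with_options")
    );
    assert_eq!(
        choose_channel(&TlsMode::Disable, false, false),
        Ok("plain_run_with_options")
    );
}
```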
-use std::collections::HashMap; - -use api::v1::codec::InsertBatch; use api::v1::column::SemanticType; -use api::v1::{column, insert_expr, Column, ColumnDataType, InsertExpr}; +use api::v1::{column, Column, ColumnDataType, InsertExpr}; use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_grpc::writer::Precision; use table::requests::InsertRequest; @@ -189,18 +186,12 @@ impl DataPoint { }); } - let batch = InsertBatch { - columns, - row_count: 1, - }; InsertExpr { schema_name, table_name: self.metric.clone(), - expr: Some(insert_expr::Expr::Values(insert_expr::Values { - values: vec![batch.into()], - })), - options: HashMap::default(), region_number: 0, + columns, + row_count: 1, } } @@ -337,36 +328,31 @@ mod test { let grpc_insert = data_point.as_grpc_insert(); assert_eq!(grpc_insert.table_name, "my_metric_1"); - match grpc_insert.expr { - Some(insert_expr::Expr::Values(insert_expr::Values { values })) => { - assert_eq!(values.len(), 1); - let insert_batch = InsertBatch::try_from(values[0].as_slice()).unwrap(); - assert_eq!(insert_batch.row_count, 1); - let columns = insert_batch.columns; - assert_eq!(columns.len(), 4); + let columns = &grpc_insert.columns; + let row_count = grpc_insert.row_count; - assert_eq!(columns[0].column_name, OPENTSDB_TIMESTAMP_COLUMN_NAME); - assert_eq!( - columns[0].values.as_ref().unwrap().ts_millis_values, - vec![1000] - ); + assert_eq!(row_count, 1); + assert_eq!(columns.len(), 4); - assert_eq!(columns[1].column_name, OPENTSDB_VALUE_COLUMN_NAME); - assert_eq!(columns[1].values.as_ref().unwrap().f64_values, vec![1.0]); + assert_eq!(columns[0].column_name, OPENTSDB_TIMESTAMP_COLUMN_NAME); + assert_eq!( + columns[0].values.as_ref().unwrap().ts_millis_values, + vec![1000] + ); - assert_eq!(columns[2].column_name, "tagk1"); - assert_eq!( - columns[2].values.as_ref().unwrap().string_values, - vec!["tagv1"] - ); + assert_eq!(columns[1].column_name, OPENTSDB_VALUE_COLUMN_NAME); + assert_eq!(columns[1].values.as_ref().unwrap().f64_values, vec![1.0]); - assert_eq!(columns[3].column_name, "tagk2"); - assert_eq!( - columns[3].values.as_ref().unwrap().string_values, - vec!["tagv2"] - ); - } - _ => unreachable!(), - } + assert_eq!(columns[2].column_name, "tagk1"); + assert_eq!( + columns[2].values.as_ref().unwrap().string_values, + vec!["tagv1"] + ); + + assert_eq!(columns[3].column_name, "tagk2"); + assert_eq!( + columns[3].values.as_ref().unwrap().string_values, + vec!["tagv2"] + ); } } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index c17053bfa1..2fe02b45a1 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -63,14 +63,16 @@ pub struct PgAuthStartupHandler { verifier: PgPwdVerifier, param_provider: GreptimeDBStartupParameters, with_pwd: bool, + force_tls: bool, } impl PgAuthStartupHandler { - pub fn new(with_pwd: bool) -> Self { + pub fn new(with_pwd: bool, force_tls: bool) -> Self { PgAuthStartupHandler { verifier: PgPwdVerifier, param_provider: GreptimeDBStartupParameters::new(), with_pwd, + force_tls, } } } @@ -89,6 +91,20 @@ impl StartupHandler for PgAuthStartupHandler { { match message { PgWireFrontendMessage::Startup(ref startup) => { + if !client.is_secure() && self.force_tls { + let error_info = ErrorInfo::new( + "FATAL".to_owned(), + "28000".to_owned(), + "No encryption".to_owned(), + ); + let error = ErrorResponse::from(error_info); + + client + .feed(PgWireBackendMessage::ErrorResponse(error)) + .await?; + client.close().await?; + return Ok(()); + } 
auth::save_startup_parameters_to_metadata(client, startup); if self.with_pwd { client.set_state(PgWireConnectionState::AuthenticationInProgress); diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 59c6cc2ea8..6cf82465a0 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::ops::Deref; +use std::sync::Arc; use async_trait::async_trait; use common_query::Output; @@ -26,6 +27,7 @@ use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder}; use pgwire::api::{ClientInfo, Type}; use pgwire::error::{PgWireError, PgWireResult}; +use session::context::QueryContext; use crate::error::{self, Error, Result}; use crate::query_handler::SqlQueryHandlerRef; @@ -40,15 +42,30 @@ impl PostgresServerHandler { } } +const CLIENT_METADATA_DATABASE: &str = "database"; + +fn query_context_from_client_info(client: &C) -> Arc +where + C: ClientInfo, +{ + let query_context = QueryContext::new(); + if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) { + query_context.set_current_schema(current_schema); + } + + Arc::new(query_context) +} + #[async_trait] impl SimpleQueryHandler for PostgresServerHandler { - async fn do_query(&self, _client: &C, query: &str) -> PgWireResult> + async fn do_query(&self, client: &C, query: &str) -> PgWireResult> where C: ClientInfo + Unpin + Send + Sync, { + let query_ctx = query_context_from_client_info(client); let output = self .query_handler - .do_query(query) + .do_query(query, query_ctx) .await .map_err(|e| PgWireError::ApiError(Box::new(e)))?; diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 61adc65d63..0845c20904 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -22,17 +22,20 @@ use common_telemetry::logging::error; use futures::StreamExt; use pgwire::tokio::process_socket; use tokio; +use tokio_rustls::TlsAcceptor; use crate::error::Result; use crate::postgres::auth_handler::PgAuthStartupHandler; use crate::postgres::handler::PostgresServerHandler; use crate::query_handler::SqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; +use crate::tls::TlsOption; pub struct PostgresServer { base_server: BaseTcpServer, auth_handler: Arc, query_handler: Arc, + tls: Arc, } impl PostgresServer { @@ -40,14 +43,17 @@ impl PostgresServer { pub fn new( query_handler: SqlQueryHandlerRef, check_pwd: bool, + tls: Arc, io_runtime: Arc, ) -> PostgresServer { let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler)); - let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd)); + let startup_handler = + Arc::new(PgAuthStartupHandler::new(check_pwd, tls.should_force_tls())); PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), auth_handler: startup_handler, query_handler: postgres_handler, + tls, } } @@ -55,6 +61,7 @@ impl PostgresServer { &self, io_runtime: Arc, accepting_stream: AbortableStream, + tls_acceptor: Option>, ) -> impl Future { let auth_handler = self.auth_handler.clone(); let query_handler = self.query_handler.clone(); @@ -63,6 +70,7 @@ impl PostgresServer { let io_runtime = io_runtime.clone(); let auth_handler = auth_handler.clone(); let query_handler = query_handler.clone(); + let tls_acceptor = tls_acceptor.clone(); async move { match tcp_stream { @@ -70,7 
+78,7 @@ impl PostgresServer { Ok(io_stream) => { io_runtime.spawn(process_socket( io_stream, - None, + tls_acceptor.clone(), auth_handler.clone(), query_handler.clone(), query_handler.clone(), @@ -91,8 +99,14 @@ impl Server for PostgresServer { async fn start(&self, listening: SocketAddr) -> Result { let (stream, addr) = self.base_server.bind(listening).await?; + let tls_acceptor = self + .tls + .setup()? + .map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf)))); + let io_runtime = self.base_server.io_runtime(); - let join_handle = tokio::spawn(self.accept(io_runtime, stream)); + let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_acceptor)); + self.base_server.start_with(join_handle).await?; Ok(addr) } diff --git a/src/servers/src/prometheus.rs b/src/servers/src/prometheus.rs index 48045b066e..1c2b035ec0 100644 --- a/src/servers/src/prometheus.rs +++ b/src/servers/src/prometheus.rs @@ -14,14 +14,14 @@ //! prometheus protocol supportings use std::cmp::Ordering; -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; use api::prometheus::remote::label_matcher::Type as MatcherType; use api::prometheus::remote::{Label, Query, Sample, TimeSeries, WriteRequest}; -use api::v1::codec::{InsertBatch, SelectResult}; +use api::v1::codec::SelectResult; use api::v1::column::SemanticType; -use api::v1::{column, insert_expr, Column, ColumnDataType, InsertExpr}; +use api::v1::{column, Column, ColumnDataType, InsertExpr}; use common_grpc::writer::Precision::MILLISECOND; use openmetrics_parser::{MetricsExposition, PrometheusType, PrometheusValue}; use snafu::{OptionExt, ResultExt}; @@ -413,21 +413,14 @@ fn timeseries_to_insert_expr(database: &str, mut timeseries: TimeSeries) -> Resu }); } - let batch = InsertBatch { - columns, - row_count: row_count as u32, - }; Ok(InsertExpr { schema_name, table_name: table_name.context(error::InvalidPromRemoteRequestSnafu { msg: "missing '__name__' label in timeseries", })?, - - expr: Some(insert_expr::Expr::Values(insert_expr::Values { - values: vec![batch.into()], - })), - options: HashMap::default(), region_number: 0, + columns, + row_count: row_count as u32, }) } @@ -683,105 +676,93 @@ mod tests { assert_eq!("metric2", exprs[1].table_name); assert_eq!("metric3", exprs[2].table_name); - let values = exprs[0].clone().expr.unwrap(); - match values { - insert_expr::Expr::Values(insert_expr::Values { values }) => { - assert_eq!(1, values.len()); - let batch = InsertBatch::try_from(values[0].as_slice()).unwrap(); - assert_eq!(2, batch.row_count); - let columns = batch.columns; - assert_eq!(columns.len(), 3); + let expr = exprs.get(0).unwrap(); - assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); - assert_eq!( - columns[0].values.as_ref().unwrap().ts_millis_values, - vec![1000, 2000] - ); + let columns = &expr.columns; + let row_count = expr.row_count; - assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); - assert_eq!( - columns[1].values.as_ref().unwrap().f64_values, - vec![1.0, 2.0] - ); + assert_eq!(2, row_count); + assert_eq!(columns.len(), 3); - assert_eq!(columns[2].column_name, "job"); - assert_eq!( - columns[2].values.as_ref().unwrap().string_values, - vec!["spark", "spark"] - ); - } - _ => unreachable!(), - } + assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); + assert_eq!( + columns[0].values.as_ref().unwrap().ts_millis_values, + vec![1000, 2000] + ); - let values = exprs[1].clone().expr.unwrap(); - match values { - insert_expr::Expr::Values(insert_expr::Values { 
values }) => { - assert_eq!(1, values.len()); - let batch = InsertBatch::try_from(values[0].as_slice()).unwrap(); - assert_eq!(2, batch.row_count); - let columns = batch.columns; - assert_eq!(columns.len(), 4); + assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); + assert_eq!( + columns[1].values.as_ref().unwrap().f64_values, + vec![1.0, 2.0] + ); - assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); - assert_eq!( - columns[0].values.as_ref().unwrap().ts_millis_values, - vec![1000, 2000] - ); + assert_eq!(columns[2].column_name, "job"); + assert_eq!( + columns[2].values.as_ref().unwrap().string_values, + vec!["spark", "spark"] + ); - assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); - assert_eq!( - columns[1].values.as_ref().unwrap().f64_values, - vec![3.0, 4.0] - ); + let expr = exprs.get(1).unwrap(); - assert_eq!(columns[2].column_name, "instance"); - assert_eq!( - columns[2].values.as_ref().unwrap().string_values, - vec!["test_host1", "test_host1"] - ); - assert_eq!(columns[3].column_name, "idc"); - assert_eq!( - columns[3].values.as_ref().unwrap().string_values, - vec!["z001", "z001"] - ); - } - _ => unreachable!(), - } + let columns = &expr.columns; + let row_count = expr.row_count; - let values = exprs[2].clone().expr.unwrap(); - match values { - insert_expr::Expr::Values(insert_expr::Values { values }) => { - assert_eq!(1, values.len()); - let batch = InsertBatch::try_from(values[0].as_slice()).unwrap(); - assert_eq!(3, batch.row_count); - let columns = batch.columns; - assert_eq!(columns.len(), 4); + assert_eq!(2, row_count); + assert_eq!(columns.len(), 4); - assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); - assert_eq!( - columns[0].values.as_ref().unwrap().ts_millis_values, - vec![1000, 2000, 3000] - ); + assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); + assert_eq!( + columns[0].values.as_ref().unwrap().ts_millis_values, + vec![1000, 2000] + ); - assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); - assert_eq!( - columns[1].values.as_ref().unwrap().f64_values, - vec![5.0, 6.0, 7.0] - ); + assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); + assert_eq!( + columns[1].values.as_ref().unwrap().f64_values, + vec![3.0, 4.0] + ); - assert_eq!(columns[2].column_name, "idc"); - assert_eq!( - columns[2].values.as_ref().unwrap().string_values, - vec!["z002", "z002", "z002"] - ); - assert_eq!(columns[3].column_name, "app"); - assert_eq!( - columns[3].values.as_ref().unwrap().string_values, - vec!["biz", "biz", "biz"] - ); - } - _ => unreachable!(), - } + assert_eq!(columns[2].column_name, "instance"); + assert_eq!( + columns[2].values.as_ref().unwrap().string_values, + vec!["test_host1", "test_host1"] + ); + assert_eq!(columns[3].column_name, "idc"); + assert_eq!( + columns[3].values.as_ref().unwrap().string_values, + vec!["z001", "z001"] + ); + + let expr = exprs.get(2).unwrap(); + + let columns = &expr.columns; + let row_count = expr.row_count; + + assert_eq!(3, row_count); + assert_eq!(columns.len(), 4); + + assert_eq!(columns[0].column_name, TIMESTAMP_COLUMN_NAME); + assert_eq!( + columns[0].values.as_ref().unwrap().ts_millis_values, + vec![1000, 2000, 3000] + ); + + assert_eq!(columns[1].column_name, VALUE_COLUMN_NAME); + assert_eq!( + columns[1].values.as_ref().unwrap().f64_values, + vec![5.0, 6.0, 7.0] + ); + + assert_eq!(columns[2].column_name, "idc"); + assert_eq!( + columns[2].values.as_ref().unwrap().string_values, + vec!["z002", "z002", "z002"] + ); + assert_eq!(columns[3].column_name, "app"); + assert_eq!( + 
columns[3].values.as_ref().unwrap().string_values, + vec!["biz", "biz", "biz"] + ); } #[test] diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index ff76bebdc5..d9a48ba30f 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -18,6 +18,7 @@ use api::prometheus::remote::{ReadRequest, WriteRequest}; use api::v1::{AdminExpr, AdminResult, ObjectExpr, ObjectResult}; use async_trait::async_trait; use common_query::Output; +use session::context::QueryContextRef; use crate::error::Result; use crate::influxdb::InfluxdbRequest; @@ -44,7 +45,7 @@ pub type ScriptHandlerRef = Arc; #[async_trait] pub trait SqlQueryHandler { - async fn do_query(&self, query: &str) -> Result; + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result; } #[async_trait] diff --git a/src/servers/src/tls.rs b/src/servers/src/tls.rs new file mode 100644 index 0000000000..57a3c78621 --- /dev/null +++ b/src/servers/src/tls.rs @@ -0,0 +1,177 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fs::File; +use std::io::{BufReader, Error, ErrorKind}; + +use rustls::{Certificate, PrivateKey, ServerConfig}; +use rustls_pemfile::{certs, pkcs8_private_keys}; +use serde::{Deserialize, Serialize}; + +/// TlsMode is used for MySQL and Postgres server startup. +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub enum TlsMode { + #[default] + Disable, + Prefer, + Require, + // TODO(SSebo): Implement the following two TLS modes described in + // ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html) + VerifyCa, + VerifyFull, +} + +#[derive(Debug, Default, Serialize, Deserialize, Clone)] +#[serde(rename_all = "snake_case")] +pub struct TlsOption { + pub mode: TlsMode, + #[serde(default)] + pub cert_path: String, + #[serde(default)] + pub key_path: String, +} + +impl TlsOption { + pub fn setup(&self) -> Result, Error> { + if let TlsMode::Disable = self.mode { + return Ok(None); + } + let cert = certs(&mut BufReader::new(File::open(&self.cert_path)?)) + .map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect())?; + + // TODO(SSebo): support more private key types + let key = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?)) + .map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid key")) + .map(|mut keys| keys.drain(..).map(PrivateKey).next())? + .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "invalid key"))?; + + // TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
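+        // Hedged sketch of that TODO (an assumption, not this PR's code): a client-cert
+        // verifier would replace `with_no_client_auth()` below, e.g.
+        //   .with_client_cert_verifier(rustls::server::AllowAnyAuthenticatedClient::new(root_store))
+        // where `root_store` is a hypothetical `rustls::RootCertStore` preloaded with
+        // the trusted CA certificates.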
+ let config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, key) + .map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?; + + Ok(Some(config)) + } + + pub fn should_force_tls(&self) -> bool { + !matches!(self.mode, TlsMode::Disable | TlsMode::Prefer) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tls_option_disable() { + let s = r#" + { + "mode": "disable" + } + "#; + + let t: TlsOption = serde_json::from_str(s).unwrap(); + + assert!(!t.should_force_tls()); + + assert!(matches!(t.mode, TlsMode::Disable)); + assert!(t.key_path.is_empty()); + assert!(t.cert_path.is_empty()); + + let setup = t.setup(); + assert!(setup.is_ok()); + let setup = setup.unwrap(); + assert!(setup.is_none()); + } + + #[test] + fn test_tls_option_prefer() { + let s = r#" + { + "mode": "prefer", + "cert_path": "/some_dir/some.crt", + "key_path": "/some_dir/some.key" + } + "#; + + let t: TlsOption = serde_json::from_str(s).unwrap(); + + assert!(!t.should_force_tls()); + + assert!(matches!(t.mode, TlsMode::Prefer)); + assert!(!t.key_path.is_empty()); + assert!(!t.cert_path.is_empty()); + } + + #[test] + fn test_tls_option_require() { + let s = r#" + { + "mode": "require", + "cert_path": "/some_dir/some.crt", + "key_path": "/some_dir/some.key" + } + "#; + + let t: TlsOption = serde_json::from_str(s).unwrap(); + + assert!(t.should_force_tls()); + + assert!(matches!(t.mode, TlsMode::Require)); + assert!(!t.key_path.is_empty()); + assert!(!t.cert_path.is_empty()); + } + + #[test] + fn test_tls_option_verifiy_ca() { + let s = r#" + { + "mode": "verify_ca", + "cert_path": "/some_dir/some.crt", + "key_path": "/some_dir/some.key" + } + "#; + + let t: TlsOption = serde_json::from_str(s).unwrap(); + + assert!(t.should_force_tls()); + + assert!(matches!(t.mode, TlsMode::VerifyCa)); + assert!(!t.key_path.is_empty()); + assert!(!t.cert_path.is_empty()); + } + + #[test] + fn test_tls_option_verifiy_full() { + let s = r#" + { + "mode": "verify_full", + "cert_path": "/some_dir/some.crt", + "key_path": "/some_dir/some.key" + } + "#; + + let t: TlsOption = serde_json::from_str(s).unwrap(); + + assert!(t.should_force_tls()); + + assert!(matches!(t.mode, TlsMode::VerifyFull)); + assert!(!t.key_path.is_empty()); + assert!(!t.cert_path.is_empty()); + } +} diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index ddec8eca54..f15a96dac0 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -135,3 +135,18 @@ fn create_query() -> Query { database: None, }) } + +/// Currently the payload of response should be simply an empty json "{}"; +#[tokio::test] +async fn test_health() { + let expected_json = http_handler::HealthResponse {}; + let expected_json_str = "{}".to_string(); + + let query = http_handler::HealthQuery {}; + let Json(json) = http_handler::health(Query(query)).await; + assert_eq!(json, expected_json); + assert_eq!( + serde_json::ser::to_string(&json).unwrap(), + expected_json_str + ); +} diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 24a21716cf..e81df37e66 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -20,9 +20,10 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use servers::error::Result; -use servers::http::HttpServer; +use servers::http::{HttpOptions, HttpServer}; use 
servers::influxdb::InfluxdbRequest; use servers::query_handler::{InfluxdbLineProtocolHandler, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -44,14 +45,14 @@ impl InfluxdbLineProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router { let instance = Arc::new(DummyInstance { tx }); - let mut server = HttpServer::new(instance.clone()); + let mut server = HttpServer::new(instance.clone(), HttpOptions::default()); server.set_influxdb_handler(instance); server.make_app() } diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index f66281302e..3b51f66965 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -19,9 +19,10 @@ use axum::Router; use axum_test_helper::TestClient; use common_query::Output; use servers::error::{self, Result}; -use servers::http::HttpServer; +use servers::http::{HttpOptions, HttpServer}; use servers::opentsdb::codec::DataPoint; use servers::query_handler::{OpentsdbProtocolHandler, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -44,14 +45,14 @@ impl OpentsdbProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } fn make_test_app(tx: mpsc::Sender) -> Router { let instance = Arc::new(DummyInstance { tx }); - let mut server = HttpServer::new(instance.clone()); + let mut server = HttpServer::new(instance.clone(), HttpOptions::default()); server.set_opentsdb_handler(instance); server.make_app() } diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index a5a3274dc5..b7df350505 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -23,10 +23,11 @@ use axum_test_helper::TestClient; use common_query::Output; use prost::Message; use servers::error::Result; -use servers::http::HttpServer; +use servers::http::{HttpOptions, HttpServer}; use servers::prometheus; use servers::prometheus::{snappy_compress, Metrics}; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse, SqlQueryHandler}; +use session::context::QueryContextRef; use tokio::sync::mpsc; struct DummyInstance { @@ -69,14 +70,14 @@ impl PrometheusProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _query: &str) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { unimplemented!() } } fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { let instance = Arc::new(DummyInstance { tx }); - let mut server = HttpServer::new(instance.clone()); + let mut server = HttpServer::new(instance.clone(), HttpOptions::default()); server.set_prom_handler(instance); server.make_app() } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index a5663dddd1..63c8e2ebe2 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -31,6 +31,8 @@ mod http; mod mysql; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; +use 
session::context::QueryContextRef; + mod opentsdb; mod postgres; @@ -52,8 +54,8 @@ impl DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, query: &str) -> Result { - let plan = self.query_engine.sql_to_plan(query).unwrap(); + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result { + let plan = self.query_engine.sql_to_plan(query, query_ctx).unwrap(); Ok(self.query_engine.execute(&plan).await.unwrap()) } } diff --git a/src/servers/tests/mysql/mysql_server_test.rs b/src/servers/tests/mysql/mysql_server_test.rs index ba82e8d68f..56f7c8a886 100644 --- a/src/servers/tests/mysql/mysql_server_test.rs +++ b/src/servers/tests/mysql/mysql_server_test.rs @@ -20,17 +20,19 @@ use common_recordbatch::RecordBatch; use common_runtime::Builder as RuntimeBuilder; use datatypes::schema::Schema; use mysql_async::prelude::*; +use mysql_async::SslOpts; use rand::rngs::StdRng; use rand::Rng; use servers::error::Result; use servers::mysql::server::MysqlServer; use servers::server::Server; +use servers::tls::TlsOption; use table::test_util::MemTable; use crate::create_testing_sql_query_handler; use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData}; -fn create_mysql_server(table: MemTable) -> Result> { +fn create_mysql_server(table: MemTable, tls: Arc) -> Result> { let query_handler = create_testing_sql_query_handler(table); let io_runtime = Arc::new( RuntimeBuilder::default() @@ -39,14 +41,14 @@ fn create_mysql_server(table: MemTable) -> Result> { .build() .unwrap(), ); - Ok(MysqlServer::create_server(query_handler, io_runtime)) + Ok(MysqlServer::create_server(query_handler, io_runtime, tls)) } #[tokio::test] async fn test_start_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = mysql_server.start(listening).await; assert!(result.is_ok()); @@ -65,7 +67,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table, Default::default())?; let result = mysql_server.shutdown().await; assert!(result .unwrap_err() @@ -80,7 +82,7 @@ async fn test_shutdown_mysql_server() -> Result<()> { for index in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port, index == 1).await { + match create_connection(server_port, index == 1, false).await { Ok(mut connection) => { let result: u32 = connection .query_first("SELECT uint32s FROM numbers LIMIT 1") @@ -114,6 +116,63 @@ async fn test_shutdown_mysql_server() -> Result<()> { async fn test_query_all_datatypes() -> Result<()> { common_telemetry::init_default_ut_logging(); + let server_tls = Arc::new(TlsOption::default()); + let client_tls = false; + + do_test_query_all_datatypes(server_tls, client_tls, false).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_prefer_secure_client_plain() -> Result<()> { + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Prefer, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = false; + do_test_query_all_datatypes(server_tls, client_tls, false).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] 
+async fn test_server_prefer_secure_client_secure() -> Result<()> { + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Prefer, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = true; + do_test_query_all_datatypes(server_tls, client_tls, false).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_require_secure_client_secure() -> Result<()> { + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Require, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = true; + do_test_query_all_datatypes(server_tls, client_tls, false).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_required_secure_client_plain() -> Result<()> { + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Require, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = false; + + #[allow(unused)] let TestingData { column_schemas, mysql_columns_def, @@ -124,11 +183,41 @@ async fn test_query_all_datatypes() -> Result<()> { let recordbatch = RecordBatch::new(schema, columns).unwrap(); let table = MemTable::new("all_datatypes", recordbatch); - let mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table, server_tls)?; + let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); - let mut connection = create_connection(server_addr.port(), false).await.unwrap(); + let r = create_connection(server_addr.port(), client_tls, false).await; + assert!(r.is_err()); + Ok(()) +} + +async fn do_test_query_all_datatypes( + server_tls: Arc, + with_pwd: bool, + client_tls: bool, +) -> Result<()> { + common_telemetry::init_default_ut_logging(); + let TestingData { + column_schemas, + mysql_columns_def, + columns, + mysql_text_output_rows, + } = all_datatype_testing_data(); + let schema = Arc::new(Schema::new(column_schemas.clone())); + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + let table = MemTable::new("all_datatypes", recordbatch); + + let mysql_server = create_mysql_server(table, server_tls)?; + + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + + let mut connection = create_connection(server_addr.port(), client_tls, with_pwd) + .await + .unwrap(); + let mut result = connection .query_iter("SELECT * FROM all_datatypes LIMIT 3") .await @@ -155,7 +244,7 @@ async fn test_query_concurrently() -> Result<()> { let table = MemTable::default_numbers_table(); - let mysql_server = create_mysql_server(table)?; + let mysql_server = create_mysql_server(table, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); let server_addr = mysql_server.start(listening).await.unwrap(); let server_port = server_addr.port(); @@ -167,7 +256,7 @@ async fn test_query_concurrently() -> Result<()> { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut connection = create_connection(server_port, index % 2 == 0) + let mut connection = create_connection(server_port, index % 2 == 0, false) .await .unwrap(); for _ in 0..expect_executed_queries_per_worker { @@ -184,8 +273,7 @@ async fn test_query_concurrently() -> Result<()> { let should_recreate_conn = expected == 1; 
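                    // 1/100 chance to recreate the connection (same ratio as the Postgres concurrency test)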
if should_recreate_conn { - connection.disconnect().await.unwrap(); - connection = create_connection(server_port, index % 2 == 0) + connection = create_connection(server_port, index % 2 == 0, false) .await .unwrap(); } @@ -201,13 +289,24 @@ async fn test_query_concurrently() -> Result<()> { Ok(()) } -async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result { +async fn create_connection( + port: u16, + with_pwd: bool, + ssl: bool, +) -> mysql_async::Result { let mut opts = mysql_async::OptsBuilder::default() .ip_or_hostname("127.0.0.1") .tcp_port(port) .prefer_socket(false) .wait_timeout(Some(1000)); + if ssl { + let ssl_opts = SslOpts::default() + .with_danger_skip_domain_validation(true) + .with_danger_accept_invalid_certs(true); + opts = opts.ssl_opts(ssl_opts) + } + if with_pwd { opts = opts.pass(Some("default_pwd".to_string())); } diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index bbbe88ee78..8abc5ff760 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -14,20 +14,28 @@ use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, SystemTime}; +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_runtime::Builder as RuntimeBuilder; use rand::rngs::StdRng; use rand::Rng; +use rustls::client::{ServerCertVerified, ServerCertVerifier}; +use rustls::{Certificate, Error, ServerName}; use servers::error::Result; use servers::postgres::PostgresServer; use servers::server::Server; +use servers::tls::TlsOption; use table::test_util::MemTable; use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage}; use crate::create_testing_sql_query_handler; -fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result> { +fn create_postgres_server( + table: MemTable, + check_pwd: bool, + tls: Arc, +) -> Result> { let query_handler = create_testing_sql_query_handler(table); let io_runtime = Arc::new( RuntimeBuilder::default() @@ -39,6 +47,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result Result Result<()> { let table = MemTable::default_numbers_table(); - let pg_server = create_postgres_server(table, false)?; + let pg_server = create_postgres_server(table, false, Default::default())?; let listening = "127.0.0.1:0".parse::().unwrap(); let result = pg_server.start(listening).await; assert!(result.is_ok()); @@ -72,8 +81,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { common_telemetry::init_default_ut_logging(); let table = MemTable::default_numbers_table(); - - let postgres_server = create_postgres_server(table, with_pwd)?; + let postgres_server = create_postgres_server(table, with_pwd, Default::default())?; let result = postgres_server.shutdown().await; assert!(result .unwrap_err() @@ -88,7 +96,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { for _ in 0..2 { join_handles.push(tokio::spawn(async move { for _ in 0..1000 { - match create_connection(server_port, with_pwd).await { + match create_plain_connection(server_port, with_pwd).await { Ok(connection) => { match connection .simple_query("SELECT uint32s FROM numbers LIMIT 1") @@ -128,14 +136,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn test_query_pg_concurrently() -> Result<()> { - common_telemetry::init_default_ut_logging(); - - let table = MemTable::default_numbers_table(); - - let pg_server = create_postgres_server(table, false)?; - let 
listening = "127.0.0.1:0".parse::().unwrap(); - let server_addr = pg_server.start(listening).await.unwrap(); - let server_port = server_addr.port(); + let server_port = start_test_server(Default::default()).await?; let threads = 4; let expect_executed_queries_per_worker = 300; @@ -144,7 +145,7 @@ async fn test_query_pg_concurrently() -> Result<()> { join_handles.push(tokio::spawn(async move { let mut rand: StdRng = rand::SeedableRng::from_entropy(); - let mut client = create_connection(server_port, false).await.unwrap(); + let mut client = create_plain_connection(server_port, false).await.unwrap(); for _k in 0..expect_executed_queries_per_worker { let expected: u32 = rand.gen_range(0..100); @@ -165,7 +166,7 @@ async fn test_query_pg_concurrently() -> Result<()> { // 1/100 chance to reconnect let should_recreate_conn = expected == 1; if should_recreate_conn { - client = create_connection(server_port, false).await.unwrap(); + client = create_plain_connection(server_port, false).await.unwrap(); } } expect_executed_queries_per_worker @@ -179,7 +180,126 @@ async fn test_query_pg_concurrently() -> Result<()> { Ok(()) } -async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result { +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_secure_prefer_client_plain() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Prefer, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = false; + do_simple_query(server_tls, client_tls).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_secure_require_client_plain() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Require, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + let server_port = start_test_server(server_tls).await?; + let r = create_plain_connection(server_port, false).await; + assert!(r.is_err()); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_server_secure_require_client_secure() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let server_tls = Arc::new(TlsOption { + mode: servers::tls::TlsMode::Require, + cert_path: "tests/ssl/server.crt".to_owned(), + key_path: "tests/ssl/server.key".to_owned(), + }); + + let client_tls = true; + do_simple_query(server_tls, client_tls).await?; + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_using_db() -> Result<()> { + let server_port = start_test_server(Arc::new(TlsOption::default())).await?; + + let client = create_connection_with_given_db(server_port, "testdb") + .await + .unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_err()); + + let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME) + .await + .unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_ok()); + Ok(()) +} + +async fn start_test_server(server_tls: Arc) -> Result { + common_telemetry::init_default_ut_logging(); + let table = MemTable::default_numbers_table(); + let pg_server = create_postgres_server(table, false, server_tls)?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = 
pg_server.start(listening).await.unwrap(); + Ok(server_addr.port()) +} + +async fn do_simple_query(server_tls: Arc, client_tls: bool) -> Result<()> { + let server_port = start_test_server(server_tls).await?; + + if !client_tls { + let client = create_plain_connection(server_port, false).await.unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_ok()); + } else { + let client = create_secure_connection(server_port, false).await.unwrap(); + let result = client.simple_query("SELECT uint32s FROM numbers").await; + assert!(result.is_ok()); + } + + Ok(()) +} + +async fn create_secure_connection( + port: u16, + with_pwd: bool, +) -> std::result::Result { + let url = if with_pwd { + format!( + "sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2", + port + ) + } else { + format!("host=127.0.0.1 port={} connect_timeout=2", port) + }; + + let mut config = rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + config + .dangerous() + .set_certificate_verifier(Arc::new(AcceptAllVerifier {})); + + let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config); + let (client, conn) = tokio_postgres::connect(&url, tls).await.expect("connect"); + + tokio::spawn(conn); + Ok(client) +} + +async fn create_plain_connection( + port: u16, + with_pwd: bool, +) -> std::result::Result { let url = if with_pwd { format!( "host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2", @@ -193,6 +313,19 @@ async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result std::result::Result { + let url = format!( + "host=127.0.0.1 port={} connect_timeout=2 dbname={}", + port, db + ); + let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; + tokio::spawn(conn); + Ok(client) +} + fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> { match resp { &SimpleQueryMessage::Row(ref r) => r.get(col_index), @@ -203,3 +336,18 @@ fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> { fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> { resp.iter().filter_map(|m| resolve_result(m, 0)).collect() } + +struct AcceptAllVerifier {} +impl ServerCertVerifier for AcceptAllVerifier { + fn verify_server_cert( + &self, + _end_entity: &Certificate, + _intermediates: &[Certificate], + _server_name: &ServerName, + _scts: &mut dyn Iterator, + _ocsp_response: &[u8], + _now: SystemTime, + ) -> std::result::Result { + Ok(ServerCertVerified::assertion()) + } +} diff --git a/src/servers/tests/ssl/server.crt b/src/servers/tests/ssl/server.crt new file mode 100644 index 0000000000..308430c8bc --- /dev/null +++ b/src/servers/tests/ssl/server.crt @@ -0,0 +1,77 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: + 1e:a1:44:88:27:3d:5c:c8:ff:ef:06:2e:da:21:05:29:30:a5:ce:2c + Signature Algorithm: sha256WithRSAEncryption + Issuer: CN = localhost + Validity + Not Before: Oct 11 07:36:01 2022 GMT + Not After : Oct 8 07:36:01 2032 GMT + Subject: CN = localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (2048 bit) + Modulus: + 00:d5:b0:29:38:63:13:5e:1e:1d:ae:1f:47:88:b4: + 44:96:21:d8:d7:03:a3:d8:f9:03:2f:4e:79:66:e6: + db:19:55:1d:85:9b:f1:78:2d:87:f3:72:91:13:dc: + ff:00:cb:ab:fd:a1:c8:3a:56:26:e3:88:1d:ec:98: + 4a:af:eb:f9:60:80:27:e1:06:ba:c0:0d:c3:09:0e: + fe:d8:86:1e:25:b4:04:62:a5:75:46:8e:11:e8:61: + 
59:aa:97:17:ea:c7:4c:c6:13:8c:6d:54:2a:b9:78: + 86:54:a9:6f:d6:31:96:c6:41:76:a3:c7:67:40:6f: + f2:1a:4c:0d:77:05:bb:3d:0b:16:f8:c7:de:6c:de: + 7b:2e:b6:29:85:4b:a8:36:d3:f2:84:75:e0:85:17: + ce:22:84:4b:94:02:17:8a:36:2b:13:ee:2f:aa:55: + 6b:ff:8b:df:d3:e0:23:8d:fd:c3:f8:e2:c8:a7:d5: + 76:a6:73:7d:a8:5f:6a:49:02:78:a2:c5:66:14:ee: + 86:50:3b:d1:67:7f:1b:0c:27:0d:84:ec:44:0d:39: + 08:ba:69:65:e0:35:a4:67:aa:19:e7:fe:0e:4b:9f: + 23:1e:4e:38:ed:d7:93:57:6e:94:31:05:d3:ae:f7: + 6c:01:3c:30:69:19:f4:7b:b5:48:95:71:c9:9c:30: + 43:9d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + 8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31 + X509v3 Authority Key Identifier: + keyid:8E:81:0B:60:B1:F9:7D:D8:64:91:BB:30:86:E5:3D:CD:B7:82:D8:31 + + X509v3 Basic Constraints: critical + CA:TRUE + Signature Algorithm: sha256WithRSAEncryption + 6c:ae:ee:3e:e3:d4:5d:29:37:62:b0:32:ce:a4:36:c7:25:b4: + 6a:9f:ba:b4:f0:2f:0a:96:2f:dc:6d:df:7d:92:e7:f0:ee:f7: + de:44:9d:52:36:ff:0c:98:ef:8b:7f:27:df:6e:fe:64:11:7c: + 01:5d:7f:c8:73:a3:24:24:ba:81:fd:a8:ae:28:4f:93:bb:92: + ff:86:d6:48:a2:ca:a5:1f:ea:1c:0d:02:22:e8:71:23:27:22: + 4f:0f:37:58:9a:d9:fd:70:c5:4c:93:7d:47:1c:b6:ea:1b:4f: + 4e:7c:eb:9d:9a:d3:28:78:67:27:e9:b1:ea:f6:93:68:76:e5: + 2e:52:c6:29:91:ba:0a:96:2e:14:33:69:35:d7:b5:e0:c0:ef: + 05:77:09:9b:a1:cc:7b:b2:f0:6a:cb:5c:5f:a1:27:69:b0:2c: + 6e:93:eb:37:98:cd:97:8d:9e:78:a8:f5:99:12:66:86:48:cf: + b2:e0:68:6f:77:98:06:13:24:55:d1:c3:80:1d:59:53:1f:44: + 85:bc:5d:29:aa:2a:a1:06:17:6b:e7:2b:11:0b:fd:e3:f8:88: + 89:32:57:a3:70:f7:1b:6c:c1:66:c7:3c:a4:2d:e8:5f:00:1c: + 55:2f:72:ed:d4:3a:3f:d0:95:de:6c:a4:96:6e:b4:63:0e:80: + 08:b2:25:d5 +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUHqFEiCc9XMj/7wYu2iEFKTClziwwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMTAxMTA3MzYwMVoXDTMyMTAw +ODA3MzYwMVowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEA1bApOGMTXh4drh9HiLREliHY1wOj2PkDL055ZubbGVUd +hZvxeC2H83KRE9z/AMur/aHIOlYm44gd7JhKr+v5YIAn4Qa6wA3DCQ7+2IYeJbQE +YqV1Ro4R6GFZqpcX6sdMxhOMbVQquXiGVKlv1jGWxkF2o8dnQG/yGkwNdwW7PQsW ++MfebN57LrYphUuoNtPyhHXghRfOIoRLlAIXijYrE+4vqlVr/4vf0+Ajjf3D+OLI +p9V2pnN9qF9qSQJ4osVmFO6GUDvRZ38bDCcNhOxEDTkIumll4DWkZ6oZ5/4OS58j +Hk447deTV26UMQXTrvdsATwwaRn0e7VIlXHJnDBDnQIDAQABo1MwUTAdBgNVHQ4E +FgQUjoELYLH5fdhkkbswhuU9zbeC2DEwHwYDVR0jBBgwFoAUjoELYLH5fdhkkbsw +huU9zbeC2DEwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbK7u +PuPUXSk3YrAyzqQ2xyW0ap+6tPAvCpYv3G3ffZLn8O733kSdUjb/DJjvi38n327+ +ZBF8AV1/yHOjJCS6gf2orihPk7uS/4bWSKLKpR/qHA0CIuhxIyciTw83WJrZ/XDF +TJN9Rxy26htPTnzrnZrTKHhnJ+mx6vaTaHblLlLGKZG6CpYuFDNpNde14MDvBXcJ +m6HMe7LwastcX6EnabAsbpPrN5jNl42eeKj1mRJmhkjPsuBob3eYBhMkVdHDgB1Z +Ux9EhbxdKaoqoQYXa+crEQv94/iIiTJXo3D3G2zBZsc8pC3oXwAcVS9y7dQ6P9CV +3myklm60Yw6ACLIl1Q== +-----END CERTIFICATE----- diff --git a/src/servers/tests/ssl/server.key b/src/servers/tests/ssl/server.key new file mode 100644 index 0000000000..61b3c4eb90 --- /dev/null +++ b/src/servers/tests/ssl/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEwAIBADANBgkqhkiG9w0BAQEFAASCBKowggSmAgEAAoIBAQDVsCk4YxNeHh2u +H0eItESWIdjXA6PY+QMvTnlm5tsZVR2Fm/F4LYfzcpET3P8Ay6v9ocg6VibjiB3s +mEqv6/lggCfhBrrADcMJDv7Yhh4ltARipXVGjhHoYVmqlxfqx0zGE4xtVCq5eIZU +qW/WMZbGQXajx2dAb/IaTA13Bbs9Cxb4x95s3nsutimFS6g20/KEdeCFF84ihEuU +AheKNisT7i+qVWv/i9/T4CON/cP44sin1Xamc32oX2pJAniixWYU7oZQO9FnfxsM +Jw2E7EQNOQi6aWXgNaRnqhnn/g5LnyMeTjjt15NXbpQxBdOu92wBPDBpGfR7tUiV +ccmcMEOdAgMBAAECggEBAMMCIJv0zpf1o+Bja0S2PmFEQj72c3Buzxk85E2kIA7e 
+PjLQPW0PICJrSzp1U8HGHQ85tSCHvrWmYqin0oD5OHt4eOxC1+qspHB/3tJ6ksiV +n+rmVEAvJuiK7ulfOdRoTQf2jxC23saj1vMsLYOrfY0v8LVGJFQJ1UdqYF9eO6FX +8i6eQekV0n8u+DMUysYXfePDXEwpunKrlZwZtThgBY31gAIOdNo/FOAFe1yBJdPl +rUFZes1IrE0c4CNxodajuRNCjtNWoX8TK1cXQVUpPprdFLBcYG2P9mPZ7SkZWJc7 +rkyPX6Wkb7q3laUCBxuKL1iOJIwaVBYaKfv4HS7VuYECgYEA9H7VB8+whWx2cTFb +9oYbcaU3HtbKRh6KQP8eB4IWeKV/c/ceWVAxtU9Hx2QU1zZ2fLl+KkaOGeECNNqD +BP1O5qk2qmkjJcP4kzh1K+p7zkqAkrhHqB36y/gwptB8v7JbCchQq9cnBeYsXNIa +j13KvteprRSnanKu18d2aC43cNMCgYEA3746ITtqy1g6AQ0Q/MXN/axsXixKfVjf +kgN/lpjy6oeoEIWKqiNrOQpwy4NeBo6ZN+cwjUUr9SY/BKsZqMGErO8Xuu+QtJYD +ioW/My9rTrTElbpsLpSvZDLc9IRepV4k+5PpXTIRBqp7Q3BZnTjbRMc8x/owG23G +eXnfVKlWM88CgYEA5HBQuMCrzK3/qFkW9Kpun+tfKfhD++nzATGcrCU2u7jd8cr1 +1zsfhqkxhrIS6tYfNP/XSsarZLCgcCOuAQ5wFwIJaoVbaqDE80Dv8X1f+eoQYYW+ +peyE9OjLBEGOHUoW13gLL9ORyWg7EOraGBPpKBC2n1nJ5qKKjF/4WPS9pjMCgYEA +3UuUyxGtivn0RN3bk2dBWkmT1YERG/EvD4gORbF5caZDADRU9fqaLoy5C1EfSnT3 +7mbnipKD67CsW72vX04oH7NLUUVpZnOJhRTMC6A3Dl2UolMEdP3yi7QS/nV99ymq +gnnFMrw2QtWTnRweRnbZyKkW4OP/eOGWkMeNsHrcG9kCgYEAz/09cKumk349AIXV +g6Jw64gCTjWh157wnD3ZSPPEcr/09/fZwf1W0gkY/tbCVrVPJHWb3K5t2nRXjLlz +HMnQXmcMxMlY3Ufvm2H3ov1ODPKwpcBWUZqnpFTZX7rC58lO/wvgiKpgtHA3pDdw +oYDaaozVP4EnnByxhmHaM7ce07U= +-----END PRIVATE KEY----- diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml new file mode 100644 index 0000000000..0e4e0b1591 --- /dev/null +++ b/src/session/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "session" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" + +[dependencies] +arc-swap = "1.5" +common-telemetry = { path = "../common/telemetry" } diff --git a/src/session/src/context.rs b/src/session/src/context.rs new file mode 100644 index 0000000000..aec55ac941 --- /dev/null +++ b/src/session/src/context.rs @@ -0,0 +1,56 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arc_swap::ArcSwapOption; +use common_telemetry::info; + +pub type QueryContextRef = Arc; + +pub struct QueryContext { + current_schema: ArcSwapOption, +} + +impl Default for QueryContext { + fn default() -> Self { + Self::new() + } +} + +impl QueryContext { + pub fn new() -> Self { + Self { + current_schema: ArcSwapOption::new(None), + } + } + + pub fn with_current_schema(schema: String) -> Self { + Self { + current_schema: ArcSwapOption::new(Some(Arc::new(schema))), + } + } + + pub fn current_schema(&self) -> Option { + self.current_schema.load().as_deref().cloned() + } + + pub fn set_current_schema(&self, schema: &str) { + let last = self.current_schema.swap(Some(Arc::new(schema.to_string()))); + info!( + "set new session default schema: {:?}, swap old: {:?}", + schema, last + ) + } +} diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs new file mode 100644 index 0000000000..57437c3057 --- /dev/null +++ b/src/session/src/lib.rs @@ -0,0 +1,36 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod context; + +use std::sync::Arc; + +use crate::context::{QueryContext, QueryContextRef}; + +#[derive(Default)] +pub struct Session { + query_ctx: QueryContextRef, +} + +impl Session { + pub fn new() -> Self { + Session { + query_ctx: Arc::new(QueryContext::new()), + } + } + + pub fn context(&self) -> QueryContextRef { + self.query_ctx.clone() + } +} diff --git a/src/sql/Cargo.toml b/src/sql/Cargo.toml index 7b3949e043..6f7f40b017 100644 --- a/src/sql/Cargo.toml +++ b/src/sql/Cargo.toml @@ -12,7 +12,7 @@ common-error = { path = "../common/error" } common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } itertools = "0.10" +mito = { path = "../mito" } once_cell = "1.10" snafu = { version = "0.7", features = ["backtraces"] } sqlparser = "0.15.0" -mito = { path = "../mito" } diff --git a/src/sql/src/parser.rs b/src/sql/src/parser.rs index 6744cb824b..254982e88e 100644 --- a/src/sql/src/parser.rs +++ b/src/sql/src/parser.rs @@ -22,6 +22,8 @@ use crate::error::{ self, InvalidDatabaseNameSnafu, InvalidTableNameSnafu, Result, SyntaxSnafu, TokenizerSnafu, }; use crate::statements::describe::DescribeTable; +use crate::statements::drop::DropTable; +use crate::statements::explain::Explain; use crate::statements::show::{ShowCreateTable, ShowDatabases, ShowKind, ShowTables}; use crate::statements::statement::Statement; use crate::statements::table_idents_to_full_name; @@ -98,6 +100,23 @@ impl<'a> ParserContext<'a> { Keyword::ALTER => self.parse_alter(), + Keyword::DROP => self.parse_drop(), + + // TODO(LFC): Use "Keyword::USE" when we can upgrade to newer version of crate sqlparser. + Keyword::NoKeyword if w.value.to_lowercase() == "use" => { + self.parser.next_token(); + + let database_name = + self.parser + .parse_identifier() + .context(error::UnexpectedSnafu { + sql: self.sql, + expected: "a database name", + actual: self.peek_token_as_string(), + })?; + Ok(Statement::Use(database_name.value)) + } + // todo(hl) support more statements. 
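+            // For example, "USE my_db" parses to Statement::Use("my_db"); the caller
+            // can then apply it via QueryContext::set_current_schema("my_db") so that
+            // subsequent unqualified table names resolve against that schema.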
_ => self.unsupported(self.peek_token_as_string()), } @@ -258,7 +277,46 @@ impl<'a> ParserContext<'a> { } fn parse_explain(&mut self) -> Result { - todo!() + let explain_statement = + self.parser + .parse_explain(false) + .with_context(|_| error::UnexpectedSnafu { + sql: self.sql, + expected: "a query statement", + actual: self.peek_token_as_string(), + })?; + + Ok(Statement::Explain(Explain::try_from(explain_statement)?)) + } + + fn parse_drop(&mut self) -> Result { + self.parser.next_token(); + if !self.matches_keyword(Keyword::TABLE) { + return self.unsupported(self.peek_token_as_string()); + } + self.parser.next_token(); + + let table_ident = + self.parser + .parse_object_name() + .with_context(|_| error::UnexpectedSnafu { + sql: self.sql, + expected: "a table name", + actual: self.peek_token_as_string(), + })?; + ensure!( + !table_ident.0.is_empty(), + InvalidTableNameSnafu { + name: table_ident.to_string() + } + ); + + let (catalog_name, schema_name, table_name) = table_idents_to_full_name(&table_ident)?; + Ok(Statement::DropTable(DropTable { + catalog_name, + schema_name, + table_name, + })) } // Report unexpected token @@ -328,6 +386,8 @@ impl<'a> ParserContext<'a> { mod tests { use std::assert_matches::assert_matches; + use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; + use sqlparser::ast::{Query as SpQuery, Statement as SpStatement}; use sqlparser::dialect::GenericDialect; use super::*; @@ -471,4 +531,93 @@ mod tests { }) ); } + + #[test] + pub fn test_explain() { + let sql = "EXPLAIN select * from foo"; + let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let stmts = result.unwrap(); + assert_eq!(1, stmts.len()); + + let select = sqlparser::ast::Select { + distinct: false, + top: None, + projection: vec![sqlparser::ast::SelectItem::Wildcard], + from: vec![sqlparser::ast::TableWithJoins { + relation: sqlparser::ast::TableFactor::Table { + name: sqlparser::ast::ObjectName(vec![sqlparser::ast::Ident::new("foo")]), + alias: None, + args: vec![], + with_hints: vec![], + }, + joins: vec![], + }], + lateral_views: vec![], + selection: None, + group_by: vec![], + cluster_by: vec![], + distribute_by: vec![], + sort_by: vec![], + having: None, + }; + + let sp_statement = SpStatement::Query(Box::new(SpQuery { + with: None, + body: sqlparser::ast::SetExpr::Select(Box::new(select)), + order_by: vec![], + limit: None, + offset: None, + fetch: None, + lock: None, + })); + + let explain = Explain::try_from(SpStatement::Explain { + describe_alias: false, + analyze: false, + verbose: false, + statement: Box::new(sp_statement), + }) + .unwrap(); + + assert_eq!(stmts[0], Statement::Explain(explain)) + } + + #[test] + pub fn test_drop_table() { + let sql = "DROP TABLE foo"; + let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let mut stmts = result.unwrap(); + assert_eq!( + stmts.pop().unwrap(), + Statement::DropTable(DropTable { + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: DEFAULT_SCHEMA_NAME.to_string(), + table_name: "foo".to_string() + }) + ); + + let sql = "DROP TABLE my_schema.foo"; + let result = ParserContext::create_with_dialect(sql, &GenericDialect {}); + let mut stmts = result.unwrap(); + assert_eq!( + stmts.pop().unwrap(), + Statement::DropTable(DropTable { + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: "my_schema".to_string(), + table_name: "foo".to_string() + }) + ); + + let sql = "DROP TABLE my_catalog.my_schema.foo"; + let result = ParserContext::create_with_dialect(sql, 
&GenericDialect {}); + let mut stmts = result.unwrap(); + assert_eq!( + stmts.pop().unwrap(), + Statement::DropTable(DropTable { + catalog_name: "my_catalog".to_string(), + schema_name: "my_schema".to_string(), + table_name: "foo".to_string() + }) + ) + } } diff --git a/src/sql/src/statements.rs b/src/sql/src/statements.rs index 4e15d15ea4..bcdc099265 100644 --- a/src/sql/src/statements.rs +++ b/src/sql/src/statements.rs @@ -15,6 +15,8 @@ pub mod alter; pub mod create; pub mod describe; +pub mod drop; +pub mod explain; pub mod insert; pub mod query; pub mod show; @@ -40,6 +42,8 @@ use crate::error::{ SerializeColumnDefaultConstraintSnafu, UnsupportedDefaultValueSnafu, }; +// TODO(LFC): Get rid of this function, use session context aware version of "table_idents_to_full_name" instead. +// Current obstacles remain in some usage in Frontend, and other SQLs like "describe", "drop" etc. /// Converts maybe fully-qualified table name (`..
` form, i.e. `<catalog>.<schema>.<table>`, or `<table>
` when /// catalog and schema are default) to tuple. pub fn table_idents_to_full_name(obj_name: &ObjectName) -> Result<(String, String, String)> { @@ -321,11 +325,16 @@ pub fn sql_data_type_to_concrete_data_type(data_type: &SqlDataType) -> Result Self { + DropTable { + catalog_name, + schema_name, + table_name, + } + } +} diff --git a/src/sql/src/statements/explain.rs b/src/sql/src/statements/explain.rs new file mode 100644 index 0000000000..01f9330ef3 --- /dev/null +++ b/src/sql/src/statements/explain.rs @@ -0,0 +1,37 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use sqlparser::ast::Statement as SpStatement; + +use crate::error::Error; + +/// Explain statement. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Explain { + pub inner: SpStatement, +} + +impl TryFrom for Explain { + type Error = Error; + + fn try_from(value: SpStatement) -> Result { + Ok(Explain { inner: value }) + } +} + +impl ToString for Explain { + fn to_string(&self) -> String { + self.inner.to_string() + } +} diff --git a/src/sql/src/statements/insert.rs b/src/sql/src/statements/insert.rs index e94e512d15..410c0d09cb 100644 --- a/src/sql/src/statements/insert.rs +++ b/src/sql/src/statements/insert.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use sqlparser::ast::{SetExpr, Statement, UnaryOperator, Values}; +use sqlparser::ast::{ObjectName, SetExpr, Statement, UnaryOperator, Values}; use sqlparser::parser::ParserError; use crate::ast::{Expr, Value}; @@ -33,6 +33,13 @@ impl Insert { } } + pub fn table_name(&self) -> &ObjectName { + match &self.inner { + Statement::Insert { table_name, .. } => table_name, + _ => unreachable!(), + } + } + pub fn columns(&self) -> Vec<&String> { match &self.inner { Statement::Insert { columns, .. } => columns.iter().map(|ident| &ident.value).collect(), @@ -110,15 +117,6 @@ mod tests { use super::*; use crate::parser::ParserContext; - #[test] - pub fn test_insert_convert() { - let sql = r"INSERT INTO tables_0 VALUES ( 'field_0', 0) "; - let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); - assert_eq!(1, stmts.len()); - let insert = stmts.pop().unwrap(); - let _stmt: Statement = insert.try_into().unwrap(); - } - #[test] fn test_insert_value_with_unary_op() { use crate::statements::statement::Statement; diff --git a/src/sql/src/statements/statement.rs b/src/sql/src/statements/statement.rs index 6e91424cc8..e1c8d731bb 100644 --- a/src/sql/src/statements/statement.rs +++ b/src/sql/src/statements/statement.rs @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. 
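+// Illustration (hedged, based on the parser tests above): parsing
+// "DROP TABLE my_schema.foo" via ParserContext::create_with_dialect yields
+// Statement::DropTable(DropTable) with schema_name "my_schema" and table_name "foo".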
-use sqlparser::ast::Statement as SpStatement; -use sqlparser::parser::ParserError; - use crate::statements::alter::AlterTable; use crate::statements::create::{CreateDatabase, CreateTable}; use crate::statements::describe::DescribeTable; +use crate::statements::drop::DropTable; +use crate::statements::explain::Explain; use crate::statements::insert::Insert; use crate::statements::query::Query; use crate::statements::show::{ShowCreateTable, ShowDatabases, ShowTables}; /// Tokens parsed by `DFParser` are converted into these values. +#[allow(clippy::large_enum_variant)] #[derive(Debug, Clone, PartialEq, Eq)] pub enum Statement { // Query @@ -31,6 +31,8 @@ pub enum Statement { Insert(Box), /// CREATE TABLE CreateTable(CreateTable), + // DROP TABLE + DropTable(DropTable), // CREATE DATABASE CreateDatabase(CreateDatabase), /// ALTER TABLE @@ -43,33 +45,9 @@ pub enum Statement { ShowCreateTable(ShowCreateTable), // DESCRIBE TABLE DescribeTable(DescribeTable), -} - -/// Converts Statement to sqlparser statement -impl TryFrom for SpStatement { - type Error = sqlparser::parser::ParserError; - - fn try_from(value: Statement) -> Result { - match value { - Statement::ShowDatabases(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW DATABASE query.".to_string(), - )), - Statement::ShowTables(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW TABLES query.".to_string(), - )), - Statement::ShowCreateTable(_) => Err(ParserError::ParserError( - "sqlparser does not support SHOW CREATE TABLE query.".to_string(), - )), - Statement::DescribeTable(_) => Err(ParserError::ParserError( - "sqlparser does not support DESCRIBE TABLE query.".to_string(), - )), - Statement::Query(s) => Ok(SpStatement::Query(Box::new(s.inner))), - Statement::Insert(i) => Ok(i.inner), - Statement::CreateDatabase(_) | Statement::CreateTable(_) | Statement::Alter(_) => { - unimplemented!() - } - } - } + // EXPLAIN QUERY + Explain(Explain), + Use(String), } /// Comment hints from SQL. @@ -81,24 +59,3 @@ pub struct Hint { pub comment: String, pub prefix: String, } - -#[cfg(test)] -mod tests { - use std::assert_matches::assert_matches; - - use sqlparser::dialect::GenericDialect; - - use super::*; - use crate::parser::ParserContext; - - #[test] - pub fn test_statement_convert() { - let sql = "SELECT * FROM table_0"; - let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap(); - assert_eq!(1, stmts.len()); - let x = stmts.remove(0); - let statement = SpStatement::try_from(x).unwrap(); - - assert_matches!(statement, SpStatement::Query { .. 
}); - } -} diff --git a/src/storage/src/error.rs b/src/storage/src/error.rs index d18054f6fb..de6d8e002b 100644 --- a/src/storage/src/error.rs +++ b/src/storage/src/error.rs @@ -48,7 +48,7 @@ pub enum Error { #[snafu(display("Failed to write columns, source: {}", source))] FlushIo { - source: std::io::Error, + source: object_store::Error, backtrace: Backtrace, }, @@ -62,28 +62,28 @@ pub enum Error { ReadObject { path: String, backtrace: Backtrace, - source: IoError, + source: object_store::Error, }, #[snafu(display("Fail to write object into path: {}, source: {}", path, source))] WriteObject { path: String, backtrace: Backtrace, - source: IoError, + source: object_store::Error, }, #[snafu(display("Fail to delete object from path: {}, source: {}", path, source))] DeleteObject { path: String, backtrace: Backtrace, - source: IoError, + source: object_store::Error, }, #[snafu(display("Fail to list objects in path: {}, source: {}", path, source))] ListObjects { path: String, backtrace: Backtrace, - source: IoError, + source: object_store::Error, }, #[snafu(display("Fail to create str from bytes, source: {}", source))] @@ -218,7 +218,7 @@ pub enum Error { }, #[snafu(display( - "Sequence of region should increase monotonically ({} > {})", + "Sequence of region should increase monotonically (should be {} < {})", prev, given ))] @@ -457,7 +457,14 @@ mod tests { )) } - let error = throw_io_error().context(FlushIoSnafu).err().unwrap(); + let error = throw_io_error() + .map_err(|err| { + object_store::Error::new(object_store::ErrorKind::Unexpected, "writer close failed") + .set_source(err) + }) + .context(FlushIoSnafu) + .err() + .unwrap(); assert_eq!(StatusCode::StorageUnavailable, error.status_code()); assert!(error.backtrace_opt().is_some()); } diff --git a/src/storage/src/manifest/storage.rs b/src/storage/src/manifest/storage.rs index 27c924e56e..744f97a6eb 100644 --- a/src/storage/src/manifest/storage.rs +++ b/src/storage/src/manifest/storage.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use common_telemetry::logging; use futures::TryStreamExt; use lazy_static::lazy_static; -use object_store::{util, ObjectEntry, ObjectStore}; +use object_store::{util, Object, ObjectStore}; use regex::Regex; use serde::{Deserialize, Serialize}; use snafu::{ensure, ResultExt}; @@ -63,7 +63,7 @@ pub fn is_delta_file(file_name: &str) -> bool { } pub struct ObjectStoreLogIterator { - iter: Box + Send + Sync>, + iter: Box + Send + Sync>, } #[async_trait] @@ -72,8 +72,7 @@ impl LogIterator for ObjectStoreLogIterator { async fn next_log(&mut self) -> Result)>> { match self.iter.next() { - Some((v, e)) => { - let object = e.into_object(); + Some((v, object)) => { let bytes = object.read().await.context(ReadObjectSnafu { path: object.path(), })?; @@ -156,7 +155,7 @@ impl ManifestLogStorage for ManifestObjectStore { .await .context(ListObjectsSnafu { path: &self.path })?; - let mut entries: Vec<(ManifestVersion, ObjectEntry)> = streamer + let mut entries: Vec<(ManifestVersion, Object)> = streamer .try_filter_map(|e| async move { let file_name = e.name(); if is_delta_file(file_name) { diff --git a/src/storage/src/sst/parquet.rs b/src/storage/src/sst/parquet.rs index 1864cd6bcb..1244582b69 100644 --- a/src/storage/src/sst/parquet.rs +++ b/src/storage/src/sst/parquet.rs @@ -122,9 +122,19 @@ impl<'a> ParquetWriter<'a> { sink.close().await.context(error::WriteParquetSnafu)?; drop(sink); - writer.close().await.context(error::WriteObjectSnafu { - path: self.file_path, - }) + writer + .close() + .await + .map_err(|err| { + 
object_store::Error::new( + object_store::ErrorKind::Unexpected, + "writer close failed", + ) + .set_source(err) + }) + .context(error::WriteObjectSnafu { + path: self.file_path, + }) } ) .map(|_| ()) diff --git a/src/table/src/engine.rs b/src/table/src/engine.rs index d2983e547b..55f68c31cf 100644 --- a/src/table/src/engine.rs +++ b/src/table/src/engine.rs @@ -26,6 +26,27 @@ pub struct TableReference<'a> { pub table: &'a str, } +// TODO(LFC): Find a better place for `TableReference`, +// so that we can reuse the default catalog and schema consts. +// Could be done together with issue #559. +impl<'a> TableReference<'a> { + pub fn bare(table: &'a str) -> Self { + TableReference { + catalog: "greptime", + schema: "public", + table, + } + } + + pub fn full(catalog: &'a str, schema: &'a str, table: &'a str) -> Self { + TableReference { + catalog, + schema, + table, + } + } +} + impl<'a> Display for TableReference<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{}.{}.{}", self.catalog, self.schema, self.table) @@ -74,8 +95,8 @@ pub trait TableEngine: Send + Sync { /// Returns true when the given table is exists. fn table_exists<'a>(&self, ctx: &EngineContext, table_ref: &'a TableReference) -> bool; - /// Drops the given table. - async fn drop_table(&self, ctx: &EngineContext, request: DropTableRequest) -> Result<()>; + /// Drops the given table. Return true if the table is dropped, or false if the table doesn't exist. + async fn drop_table(&self, ctx: &EngineContext, request: DropTableRequest) -> Result; } pub type TableEngineRef = Arc; diff --git a/src/table/src/requests.rs b/src/table/src/requests.rs index bc0b1a8e34..9d5e877aad 100644 --- a/src/table/src/requests.rs +++ b/src/table/src/requests.rs @@ -84,4 +84,8 @@ pub enum AlterKind { /// Drop table request #[derive(Debug)] -pub struct DropTableRequest {} +pub struct DropTableRequest { + pub catalog_name: String, + pub schema_name: String, + pub table_name: String, +} diff --git a/src/table/src/table.rs b/src/table/src/table.rs index f3ba11245b..9aff8a061f 100644 --- a/src/table/src/table.rs +++ b/src/table/src/table.rs @@ -69,7 +69,8 @@ pub trait Table: Send + Sync { Ok(FilterPushDownType::Unsupported) } - async fn alter(&self, _request: AlterTableRequest) -> Result<()> { + async fn alter(&self, request: AlterTableRequest) -> Result<()> { + let _ = request; unimplemented!() } } diff --git a/src/table/src/table/numbers.rs b/src/table/src/table/numbers.rs index 26eab9a8bf..db33769c31 100644 --- a/src/table/src/table/numbers.rs +++ b/src/table/src/table/numbers.rs @@ -27,24 +27,26 @@ use futures::task::{Context, Poll}; use futures::Stream; use crate::error::Result; -use crate::metadata::{TableInfoBuilder, TableInfoRef, TableMetaBuilder, TableType}; +use crate::metadata::{TableId, TableInfoBuilder, TableInfoRef, TableMetaBuilder, TableType}; use crate::table::scan::SimpleTableScan; use crate::table::{Expr, Table}; /// numbers table for test #[derive(Debug, Clone)] pub struct NumbersTable { + table_id: TableId, schema: SchemaRef, } -impl Default for NumbersTable { - fn default() -> Self { +impl NumbersTable { + pub fn new(table_id: TableId) -> Self { let column_schemas = vec![ColumnSchema::new( "number", ConcreteDataType::uint32_datatype(), false, )]; Self { + table_id, schema: Arc::new( SchemaBuilder::try_from_columns(column_schemas) .unwrap() @@ -55,6 +57,12 @@ impl Default for NumbersTable { } } +impl Default for NumbersTable { + fn default() -> Self { + NumbersTable::new(1) + } +} + #[async_trait::async_trait] impl 
Table for NumbersTable { fn as_any(&self) -> &dyn Any { @@ -68,7 +76,7 @@ impl Table for NumbersTable { fn table_info(&self) -> TableInfoRef { Arc::new( TableInfoBuilder::default() - .table_id(1) + .table_id(self.table_id) .name("numbers") .catalog_name("greptime") .schema_name("public") diff --git a/src/table/src/test_util/mock_engine.rs b/src/table/src/test_util/mock_engine.rs index af14a6112b..2b19b1889a 100644 --- a/src/table/src/test_util/mock_engine.rs +++ b/src/table/src/test_util/mock_engine.rs @@ -97,7 +97,7 @@ impl TableEngine for MockTableEngine { unimplemented!() } - async fn drop_table(&self, _ctx: &EngineContext, _request: DropTableRequest) -> Result<()> { + async fn drop_table(&self, _ctx: &EngineContext, _request: DropTableRequest) -> Result { unimplemented!() } } diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml new file mode 100644 index 0000000000..1a7107fc8f --- /dev/null +++ b/tests-integration/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "tests-integration" +version = "0.1.0" +edition = "2021" +license = "Apache-2.0" + +[dependencies] +api = { path = "../src/api" } +axum = "0.6" +axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } +catalog = { path = "../src/catalog" } +client = { path = "../src/client" } +common-catalog = { path = "../src/common/catalog" } +common-runtime = { path = "../src/common/runtime" } +common-telemetry = { path = "../src/common/telemetry" } +datanode = { path = "../src/datanode" } +datatypes = { path = "../src/datatypes" } +dotenv = "0.15" +frontend = { path = "../src/frontend" } +mito = { path = "../src/mito", features = ["test"] } +object-store = { path = "../src/object-store" } +once_cell = "1.16" +rand = "0.8" +serde = "1.0" +serde_json = "1.0" +servers = { path = "../src/servers" } +snafu = { version = "0.7", features = ["backtraces"] } +sql = { path = "../src/sql" } +table = { path = "../src/table" } +tempdir = "0.3" +tokio = { version = "1.20", features = ["full"] } +uuid = { version = "1", features = ["serde", "v4"] } + +[dev-dependencies] +paste = "1.0" diff --git a/tests-integration/README.md b/tests-integration/README.md new file mode 100644 index 0000000000..ec0905504a --- /dev/null +++ b/tests-integration/README.md @@ -0,0 +1,26 @@ +## Setup + +To run the integration test, please copy `.env.example` to `.env` in the project root folder and change the values on need. + +Take `s3` for example. You need to set your S3 bucket, access key id and secret key: + +```sh +# Settings for s3 test +GT_S3_BUCKET=S3 bucket +GT_S3_ACCESS_KEY_ID=S3 access key id +GT_S3_ACCESS_KEY=S3 secret access key +``` + +## Run + +Execute the following command in the project root folder: + +``` +cargo test integration +``` + +Test s3 storage: + +``` +cargo test s3 +``` diff --git a/tests-integration/src/lib.rs b/tests-integration/src/lib.rs new file mode 100644 index 0000000000..1bfde512a8 --- /dev/null +++ b/tests-integration/src/lib.rs @@ -0,0 +1,15 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod test_util; diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs new file mode 100644 index 0000000000..70a3355f3d --- /dev/null +++ b/tests-integration/src/test_util.rs @@ -0,0 +1,333 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::env; +use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +use axum::Router; +use catalog::CatalogManagerRef; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, MIN_USER_TABLE_ID}; +use common_runtime::Builder as RuntimeBuilder; +use datanode::datanode::{DatanodeOptions, ObjectStoreConfig}; +use datanode::error::{CreateTableSnafu, Result}; +use datanode::instance::{Instance, InstanceRef}; +use datanode::sql::SqlHandler; +use datatypes::data_type::ConcreteDataType; +use datatypes::schema::{ColumnSchema, SchemaBuilder}; +use frontend::frontend::FrontendOptions; +use frontend::grpc::GrpcOptions; +use frontend::instance::{FrontendInstance, Instance as FeInstance}; +use object_store::backend::s3; +use object_store::test_util::TempFolder; +use object_store::ObjectStore; +use once_cell::sync::OnceCell; +use rand::Rng; +use servers::grpc::GrpcServer; +use servers::http::{HttpOptions, HttpServer}; +use servers::server::Server; +use servers::Mode; +use snafu::ResultExt; +use table::engine::{EngineContext, TableEngineRef}; +use table::requests::CreateTableRequest; +use tempdir::TempDir; + +static PORTS: OnceCell<AtomicUsize> = OnceCell::new(); + +fn get_port() -> usize { + PORTS + .get_or_init(|| AtomicUsize::new(rand::thread_rng().gen_range(3500..3900))) + .fetch_add(1, Ordering::Relaxed) +} + +pub enum StorageType { + S3, + File, +} + +impl StorageType { + pub fn test_on(&self) -> bool { + let _ = dotenv::dotenv(); + + match self { + StorageType::File => true, // always test file + StorageType::S3 => { + if let Ok(b) = env::var("GT_S3_BUCKET") { + !b.is_empty() + } else { + false + } + } + } + } +} + +fn get_test_store_config( + store_type: &StorageType, + name: &str, +) -> (ObjectStoreConfig, Option<TempDirGuard>) { + let _ = dotenv::dotenv(); + + match store_type { + StorageType::S3 => { + let root = uuid::Uuid::new_v4().to_string(); + let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap(); + let secret_key = env::var("GT_S3_ACCESS_KEY").unwrap(); + let bucket = env::var("GT_S3_BUCKET").unwrap(); + + let accessor = s3::Builder::default() + .root(&root) + .access_key_id(&key_id) + .secret_access_key(&secret_key) + .bucket(&bucket) + .build() + .unwrap(); + + let config = ObjectStoreConfig::S3 { + root, + bucket, + access_key_id: key_id, + secret_access_key: secret_key, + }; + + let store = ObjectStore::new(accessor); + + (config, Some(TempDirGuard::S3(TempFolder::new(&store, "/")))) + } + StorageType::File => { + let data_tmp_dir = TempDir::new(&format!("gt_data_{}",
name)).unwrap(); + + ( + ObjectStoreConfig::File { + data_dir: data_tmp_dir.path().to_str().unwrap().to_string(), + }, + Some(TempDirGuard::File(data_tmp_dir)), + ) + } + } +} + +enum TempDirGuard { + File(TempDir), + S3(TempFolder), +} + +/// Create a tmp dir (deleted once it goes out of scope) and a default `DatanodeOptions`. +/// Only for tests. +pub struct TestGuard { + _wal_tmp_dir: TempDir, + data_tmp_dir: Option<TempDirGuard>, +} + +impl TestGuard { + pub async fn remove_all(&mut self) { + if let Some(TempDirGuard::S3(mut guard)) = self.data_tmp_dir.take() { + guard.remove_all().await.unwrap(); + } + } +} + +pub fn create_tmp_dir_and_datanode_opts( + store_type: StorageType, + name: &str, +) -> (DatanodeOptions, TestGuard) { + let wal_tmp_dir = TempDir::new(&format!("gt_wal_{}", name)).unwrap(); + + let (storage, data_tmp_dir) = get_test_store_config(&store_type, name); + + let opts = DatanodeOptions { + wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(), + storage, + mode: Mode::Standalone, + ..Default::default() + }; + ( + opts, + TestGuard { + _wal_tmp_dir: wal_tmp_dir, + data_tmp_dir, + }, + ) +} + +pub async fn create_test_table( + catalog_manager: &CatalogManagerRef, + sql_handler: &SqlHandler, + ts_type: ConcreteDataType, +) -> Result<()> { + let column_schemas = vec![ + ColumnSchema::new("host", ConcreteDataType::string_datatype(), false), + ColumnSchema::new("cpu", ConcreteDataType::float64_datatype(), true), + ColumnSchema::new("memory", ConcreteDataType::float64_datatype(), true), + ColumnSchema::new("ts", ts_type, true).with_time_index(true), + ]; + + let table_name = "demo"; + let table_engine: TableEngineRef = sql_handler.table_engine(); + let table = table_engine + .create_table( + &EngineContext::default(), + CreateTableRequest { + id: MIN_USER_TABLE_ID, + catalog_name: "greptime".to_string(), + schema_name: "public".to_string(), + table_name: table_name.to_string(), + desc: Some("a test table".to_string()), + schema: Arc::new( + SchemaBuilder::try_from(column_schemas) + .unwrap() + .build() + .expect("ts is expected to be timestamp column"), + ), + create_if_not_exists: true, + primary_key_indices: vec![0], // "host" is in primary keys + table_options: HashMap::new(), + region_numbers: vec![0], + }, + ) + .await + .context(CreateTableSnafu { table_name })?; + + let schema_provider = catalog_manager + .catalog(DEFAULT_CATALOG_NAME) + .unwrap() + .unwrap() + .schema(DEFAULT_SCHEMA_NAME) + .unwrap() + .unwrap(); + schema_provider + .register_table(table_name.to_string(), table) + .unwrap(); + Ok(()) +} + +async fn build_frontend_instance(datanode_instance: InstanceRef) -> FeInstance { + let fe_opts = FrontendOptions::default(); + let mut frontend_instance = FeInstance::try_new(&fe_opts).await.unwrap(); + frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone()); + frontend_instance.set_script_handler(datanode_instance); + frontend_instance +} + +pub async fn setup_test_app(store_type: StorageType, name: &str) -> (Router, TestGuard) { + let (opts, guard) = create_tmp_dir_and_datanode_opts(store_type, name); + let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); + instance.start().await.unwrap(); + create_test_table( + instance.catalog_manager(), + instance.sql_handler(), + ConcreteDataType::timestamp_millis_datatype(), + ) + .await + .unwrap(); + let http_server = HttpServer::new(instance, HttpOptions::default()); + (http_server.make_app(), guard) +} + +pub async fn setup_test_app_with_frontend( + store_type:
StorageType, + name: &str, +) -> (Router, TestGuard) { + let (opts, guard) = create_tmp_dir_and_datanode_opts(store_type, name); + let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); + let mut frontend = build_frontend_instance(instance.clone()).await; + instance.start().await.unwrap(); + create_test_table( + frontend.catalog_manager().as_ref().unwrap(), + instance.sql_handler(), + ConcreteDataType::timestamp_millis_datatype(), + ) + .await + .unwrap(); + frontend.start().await.unwrap(); + let mut http_server = HttpServer::new(Arc::new(frontend), HttpOptions::default()); + http_server.set_script_handler(instance.clone()); + let app = http_server.make_app(); + (app, guard) +} + +pub async fn setup_grpc_server( + store_type: StorageType, + name: &str, +) -> (String, TestGuard, Arc<GrpcServer>, Arc<GrpcServer>) { + common_telemetry::init_default_ut_logging(); + + let datanode_port = get_port(); + let frontend_port = get_port(); + + let (mut opts, guard) = create_tmp_dir_and_datanode_opts(store_type, name); + let datanode_grpc_addr = format!("127.0.0.1:{}", datanode_port); + opts.rpc_addr = datanode_grpc_addr.clone(); + let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); + instance.start().await.unwrap(); + + let datanode_grpc_addr = datanode_grpc_addr.clone(); + let runtime = Arc::new( + RuntimeBuilder::default() + .worker_threads(2) + .thread_name("grpc-handlers") + .build() + .unwrap(), + ); + + let fe_grpc_addr = format!("127.0.0.1:{}", frontend_port); + let fe_opts = FrontendOptions { + mode: Mode::Standalone, + datanode_rpc_addr: datanode_grpc_addr.clone(), + grpc_options: Some(GrpcOptions { + addr: fe_grpc_addr.clone(), + runtime_size: 8, + }), + ..Default::default() + }; + + let datanode_grpc_server = Arc::new(GrpcServer::new( + instance.clone(), + instance.clone(), + runtime.clone(), + )); + + let mut fe_instance = frontend::instance::Instance::try_new(&fe_opts) + .await + .unwrap(); + fe_instance.set_catalog_manager(instance.catalog_manager().clone()); + + let fe_instance_ref = Arc::new(fe_instance); + let fe_grpc_server = Arc::new(GrpcServer::new( + fe_instance_ref.clone(), + fe_instance_ref, + runtime, + )); + let grpc_server_clone = fe_grpc_server.clone(); + + let fe_grpc_addr_clone = fe_grpc_addr.clone(); + tokio::spawn(async move { + let addr = fe_grpc_addr_clone.parse::<SocketAddr>().unwrap(); + grpc_server_clone.start(addr).await.unwrap() + }); + + let dn_grpc_addr_clone = datanode_grpc_addr.clone(); + let dn_grpc_server_clone = datanode_grpc_server.clone(); + tokio::spawn(async move { + let addr = dn_grpc_addr_clone.parse::<SocketAddr>().unwrap(); + dn_grpc_server_clone.start(addr).await.unwrap() + }); + + // wait for GRPC server to start + tokio::time::sleep(Duration::from_secs(1)).await; + + (fe_grpc_addr, guard, fe_grpc_server, datanode_grpc_server) +} diff --git a/src/datanode/src/tests/grpc_test.rs b/tests-integration/tests/grpc.rs similarity index 65% rename from src/datanode/src/tests/grpc_test.rs rename to tests-integration/tests/grpc.rs index 3116543c1b..cf6ba4b922 100644 --- a/src/datanode/src/tests/grpc_test.rs +++ b/tests-integration/tests/grpc.rs @@ -11,114 +11,65 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
- -use std::assert_matches::assert_matches; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; - use api::v1::alter_expr::Kind; -use api::v1::codec::InsertBatch; use api::v1::column::SemanticType; use api::v1::{ - admin_result, column, insert_expr, AddColumn, AddColumns, AlterExpr, Column, ColumnDataType, - ColumnDef, CreateExpr, InsertExpr, MutateResult, + admin_result, column, AddColumn, AddColumns, AlterExpr, Column, ColumnDataType, ColumnDef, + CreateExpr, InsertExpr, MutateResult, }; use client::admin::Admin; use client::{Client, Database, ObjectResult}; use common_catalog::consts::MIN_USER_TABLE_ID; -use common_runtime::Builder as RuntimeBuilder; -use frontend::frontend::FrontendOptions; -use frontend::grpc::GrpcOptions; -use servers::grpc::GrpcServer; use servers::server::Server; -use servers::Mode; +use tests_integration::test_util::{setup_grpc_server, StorageType}; -use crate::instance::Instance; -use crate::tests::test_util::{self, TestGuard}; +#[macro_export] +macro_rules! grpc_test { + ($service:ident, $($(#[$meta:meta])* $test:ident),*,) => { + paste::item! { + mod [<integration_grpc_ $service:lower _test>] { + $( + #[tokio::test(flavor = "multi_thread")] + $( + #[$meta] + )* + async fn [< $test >]() { + let store_type = tests_integration::test_util::StorageType::$service; + if store_type.test_on() { + let _ = $crate::grpc::$test(store_type).await; + } -async fn setup_grpc_server( - name: &str, - datanode_port: usize, - frontend_port: usize, -) -> (String, TestGuard, Arc<GrpcServer>, Arc<GrpcServer>) { - common_telemetry::init_default_ut_logging(); - - let (mut opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name); - let datanode_grpc_addr = format!("127.0.0.1:{}", datanode_port); - opts.rpc_addr = datanode_grpc_addr.clone(); - let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); - instance.start().await.unwrap(); - - let datanode_grpc_addr = datanode_grpc_addr.clone(); - let runtime = Arc::new( - RuntimeBuilder::default() - .worker_threads(2) - .thread_name("grpc-handlers") - .build() - .unwrap(), - ); - - let fe_grpc_addr = format!("127.0.0.1:{}", frontend_port); - let fe_opts = FrontendOptions { - mode: Mode::Standalone, - datanode_rpc_addr: datanode_grpc_addr.clone(), - grpc_options: Some(GrpcOptions { - addr: fe_grpc_addr.clone(), - runtime_size: 8, - }), - ..Default::default() + } + )* + } + } }; - - let datanode_grpc_server = Arc::new(GrpcServer::new( - instance.clone(), - instance.clone(), - runtime.clone(), - )); - - let mut fe_instance = frontend::instance::Instance::try_new(&fe_opts) - .await - .unwrap(); - fe_instance.set_catalog_manager(instance.catalog_manager.clone()); - - let fe_instance_ref = Arc::new(fe_instance); - let fe_grpc_server = Arc::new(GrpcServer::new( - fe_instance_ref.clone(), - fe_instance_ref, - runtime, - )); - let grpc_server_clone = fe_grpc_server.clone(); - - let fe_grpc_addr_clone = fe_grpc_addr.clone(); - tokio::spawn(async move { - let addr = fe_grpc_addr_clone.parse::<SocketAddr>().unwrap(); - grpc_server_clone.start(addr).await.unwrap() - }); - - let dn_grpc_addr_clone = datanode_grpc_addr.clone(); - let dn_grpc_server_clone = datanode_grpc_server.clone(); - tokio::spawn(async move { - let addr = dn_grpc_addr_clone.parse::<SocketAddr>().unwrap(); - dn_grpc_server_clone.start(addr).await.unwrap() - }); - - // wait for GRPC server to start - tokio::time::sleep(Duration::from_secs(1)).await; - - (fe_grpc_addr, guard, fe_grpc_server, datanode_grpc_server) } -#[tokio::test(flavor = "multi_thread")] -async fn test_auto_create_table() { -
let (addr, _guard, fe_grpc_server, dn_grpc_server) = - setup_grpc_server("auto_create_table", 3992, 3993).await; +#[macro_export] +macro_rules! grpc_tests { + ($($service:ident),*) => { + $( + grpc_test!( + $service, + + test_auto_create_table, + test_insert_and_select, + ); + )* + }; +} + +pub async fn test_auto_create_table(store_type: StorageType) { + let (addr, mut guard, fe_grpc_server, dn_grpc_server) = + setup_grpc_server(store_type, "auto_create_table").await; let grpc_client = Client::with_urls(vec![addr]); let db = Database::new("greptime", grpc_client); insert_and_assert(&db).await; let _ = fe_grpc_server.shutdown().await; let _ = dn_grpc_server.shutdown().await; + guard.remove_all().await; } fn expect_data() -> (Column, Column, Column, Column) { @@ -175,11 +126,10 @@ fn expect_data() -> (Column, Column, Column, Column) { ) } -#[tokio::test(flavor = "multi_thread")] -async fn test_insert_and_select() { +pub async fn test_insert_and_select(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (addr, _guard, fe_grpc_server, dn_grpc_server) = - setup_grpc_server("insert_and_select", 3990, 3991).await; + let (addr, mut guard, fe_grpc_server, dn_grpc_server) = + setup_grpc_server(store_type, "insert_and_select").await; let grpc_client = Client::with_urls(vec![addr]); @@ -189,13 +139,13 @@ async fn test_insert_and_select() { // create let expr = testing_create_expr(); let result = admin.create(expr).await.unwrap(); - assert_matches!( + assert!(matches!( result.result, Some(admin_result::Result::Mutate(MutateResult { success: 1, failure: 0 })) - ); + )); //alter let add_column = ColumnDef { @@ -224,13 +174,17 @@ async fn test_insert_and_select() { let _ = fe_grpc_server.shutdown().await; let _ = dn_grpc_server.shutdown().await; + guard.remove_all().await; } async fn insert_and_assert(db: &Database) { // testing data: let (expected_host_col, expected_cpu_col, expected_mem_col, expected_ts_col) = expect_data(); - let values = vec![InsertBatch { + let expr = InsertExpr { + schema_name: "public".to_string(), + table_name: "demo".to_string(), + region_number: 0, columns: vec![ expected_host_col.clone(), expected_cpu_col.clone(), @@ -238,14 +192,6 @@ async fn insert_and_assert(db: &Database) { expected_ts_col.clone(), ], row_count: 4, - } - .into()]; - let expr = InsertExpr { - schema_name: "public".to_string(), - table_name: "demo".to_string(), - expr: Some(insert_expr::Expr::Values(insert_expr::Values { values })), - options: HashMap::default(), - region_number: 0, }; let result = db.insert(expr).await; result.unwrap(); @@ -312,7 +258,7 @@ fn testing_create_expr() -> CreateExpr { desc: Some("blabla".to_string()), column_defs, time_index: "ts".to_string(), - primary_keys: vec!["ts".to_string(), "host".to_string()], + primary_keys: vec!["host".to_string()], create_if_not_exists: true, table_options: Default::default(), table_id: Some(MIN_USER_TABLE_ID), diff --git a/src/datanode/src/tests/http_test.rs b/tests-integration/tests/http.rs similarity index 52% rename from src/datanode/src/tests/http_test.rs rename to tests-integration/tests/http.rs index 7348fb6430..8d074bba67 100644 --- a/src/datanode/src/tests/http_test.rs +++ b/tests-integration/tests/http.rs @@ -12,67 +12,54 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::sync::Arc; - use axum::http::StatusCode; -use axum::Router; use axum_test_helper::TestClient; -use datatypes::prelude::ConcreteDataType; -use frontend::frontend::FrontendOptions; -use frontend::instance::{FrontendInstance, Instance as FeInstance}; use serde_json::json; -use servers::http::{ColumnSchema, HttpServer, JsonOutput, JsonResponse, Schema}; -use test_util::TestGuard; +use servers::http::handler::HealthResponse; +use servers::http::{JsonOutput, JsonResponse}; +use tests_integration::test_util::{setup_test_app, setup_test_app_with_frontend, StorageType}; -use crate::instance::{Instance, InstanceRef}; -use crate::tests::test_util; - -async fn build_frontend_instance(datanode_instance: InstanceRef) -> FeInstance { - let fe_opts = FrontendOptions::default(); - let mut frontend_instance = FeInstance::try_new(&fe_opts).await.unwrap(); - frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone()); - frontend_instance.set_script_handler(datanode_instance); - frontend_instance +#[macro_export] +macro_rules! http_test { + ($service:ident, $($(#[$meta:meta])* $test:ident),*,) => { + paste::item! { + mod [<integration_http_ $service:lower _test>] { + $( + #[tokio::test(flavor = "multi_thread")] + $( + #[$meta] + )* + async fn [< $test >]() { + let store_type = tests_integration::test_util::StorageType::$service; + if store_type.test_on() { + let _ = $crate::http::$test(store_type).await; + } + } + )* + } + } + }; } -async fn make_test_app(name: &str) -> (Router, TestGuard) { - let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name); - let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); - instance.start().await.unwrap(); - test_util::create_test_table( - instance.catalog_manager(), - instance.sql_handler(), - ConcreteDataType::timestamp_millis_datatype(), - ) - .await - .unwrap(); - let http_server = HttpServer::new(instance); - (http_server.make_app(), guard) +#[macro_export] +macro_rules!
http_tests { + ($($service:ident),*) => { + $( + http_test!( + $service, + + test_sql_api, + test_metrics_api, + test_scripts_api, + test_health_api, + ); + )* + }; } -async fn make_test_app_with_frontend(name: &str) -> (Router, TestGuard) { - let (opts, guard) = test_util::create_tmp_dir_and_datanode_opts(name); - let instance = Arc::new(Instance::with_mock_meta_client(&opts).await.unwrap()); - let mut frontend = build_frontend_instance(instance.clone()).await; - instance.start().await.unwrap(); - test_util::create_test_table( - frontend.catalog_manager().as_ref().unwrap(), - instance.sql_handler(), - ConcreteDataType::timestamp_millis_datatype(), - ) - .await - .unwrap(); - frontend.start().await.unwrap(); - let mut http_server = HttpServer::new(Arc::new(frontend)); - http_server.set_script_handler(instance.clone()); - let app = http_server.make_app(); - (app, guard) -} - -#[tokio::test(flavor = "multi_thread")] -async fn test_sql_api() { +pub async fn test_sql_api(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (app, _guard) = make_test_app("sql_api").await; + let (app, mut guard) = setup_test_app(store_type, "sql_api").await; let client = TestClient::new(app); let res = client.get("/v1/sql").send().await; assert_eq!(res.status(), StatusCode::OK); @@ -98,21 +85,12 @@ async fn test_sql_api() { let output = body.output().unwrap(); assert_eq!(output.len(), 1); - if let JsonOutput::Records(records) = &output[0] { - assert_eq!(records.num_cols(), 1); - assert_eq!(records.num_rows(), 10); - assert_eq!( - records.schema().unwrap(), - &Schema::new(vec![ColumnSchema::new( - "number".to_owned(), - "UInt32".to_owned() - )]) - ); - assert_eq!(records.rows()[0][0], json!(0)); - assert_eq!(records.rows()[9][0], json!(9)); - } else { - unreachable!() - } + assert_eq!( + output[0], + serde_json::from_value::<JsonOutput>(json!({ + "records" :{"schema":{"column_schemas":[{"name":"number","data_type":"UInt32"}]},"rows":[[0],[1],[2],[3],[4],[5],[6],[7],[8],[9]]} + })).unwrap() + ); // test insert and select let res = client @@ -134,25 +112,13 @@ assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); assert_eq!(output.len(), 1); - if let JsonOutput::Records(records) = &output[0] { - assert_eq!(records.num_cols(), 4); - assert_eq!(records.num_rows(), 1); - assert_eq!( - records.schema().unwrap(), - &Schema::new(vec![ - ColumnSchema::new("host".to_owned(), "String".to_owned()), - ColumnSchema::new("cpu".to_owned(), "Float64".to_owned()), - ColumnSchema::new("memory".to_owned(), "Float64".to_owned()), - ColumnSchema::new("ts".to_owned(), "Timestamp".to_owned()) - ]) - ); - assert_eq!( - records.rows()[0], - vec![json!("host"), json!(66.6), json!(1024.0), json!(0)] - ); - } else { - unreachable!(); - } + + assert_eq!( + output[0], + serde_json::from_value::<JsonOutput>(json!({ + "records":{"schema":{"column_schemas":[{"name":"host","data_type":"String"},{"name":"cpu","data_type":"Float64"},{"name":"memory","data_type":"Float64"},{"name":"ts","data_type":"Timestamp"}]},"rows":[["host",66.6,1024.0,0]]} + })).unwrap() + ); // select with projections let res = client @@ -168,20 +134,13 @@ assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); assert_eq!(output.len(), 1); - if let JsonOutput::Records(records) = &output[0] { - assert_eq!(records.num_cols(), 2); - assert_eq!(records.num_rows(), 1); - assert_eq!( - records.schema().unwrap(), - &Schema::new(vec![ - ColumnSchema::new("cpu".to_owned(),
"Float64".to_owned()), - ColumnSchema::new("ts".to_owned(), "Timestamp".to_owned()) - ]) - ); - assert_eq!(records.rows()[0], vec![json!(66.6), json!(0)]); - } else { - unreachable!() - } + + assert_eq!( + output[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"Timestamp"}]},"rows":[[66.6,0]]} + })).unwrap() + ); // select with column alias let res = client @@ -197,27 +156,20 @@ async fn test_sql_api() { assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); assert_eq!(output.len(), 1); - if let JsonOutput::Records(records) = &output[0] { - assert_eq!(records.num_cols(), 2); - assert_eq!(records.num_rows(), 1); - assert_eq!( - records.schema().unwrap(), - &Schema::new(vec![ - ColumnSchema::new("c".to_owned(), "Float64".to_owned()), - ColumnSchema::new("time".to_owned(), "Timestamp".to_owned()) - ]) - ); - assert_eq!(records.rows()[0], vec![json!(66.6), json!(0)]); - } else { - unreachable!() - } + assert_eq!( + output[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"c","data_type":"Float64"},{"name":"time","data_type":"Timestamp"}]},"rows":[[66.6,0]]} + })).unwrap() + ); + + guard.remove_all().await; } -#[tokio::test(flavor = "multi_thread")] -async fn test_metrics_api() { +pub async fn test_metrics_api(store_type: StorageType) { common_telemetry::init_default_ut_logging(); common_telemetry::init_default_metrics_recorder(); - let (app, _guard) = make_test_app("metrics_api").await; + let (app, mut guard) = setup_test_app(store_type, "metrics_api").await; let client = TestClient::new(app); // Send a sql @@ -232,12 +184,12 @@ async fn test_metrics_api() { assert_eq!(res.status(), StatusCode::OK); let body = res.text().await; assert!(body.contains("datanode_handle_sql_elapsed")); + guard.remove_all().await; } -#[tokio::test(flavor = "multi_thread")] -async fn test_scripts_api() { +pub async fn test_scripts_api(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (app, _guard) = make_test_app_with_frontend("script_api").await; + let (app, mut guard) = setup_test_app_with_frontend(store_type, "script_api").await; let client = TestClient::new(app); let res = client @@ -269,18 +221,34 @@ def test(n): assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); assert_eq!(output.len(), 1); - if let JsonOutput::Records(ref records) = output[0] { - assert_eq!(records.num_cols(), 1); - assert_eq!(records.num_rows(), 10); - assert_eq!( - records.schema().unwrap(), - &Schema::new(vec![ColumnSchema::new( - "n".to_owned(), - "Float64".to_owned() - )]) - ); - assert_eq!(records.rows()[0][0], json!(1.0)); - } else { - unreachable!() - } + assert_eq!( + output[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"n","data_type":"Float64"}]},"rows":[[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0],[10.0]]} + })).unwrap() + ); + + guard.remove_all().await; +} + +pub async fn test_health_api(store_type: StorageType) { + common_telemetry::init_default_ut_logging(); + let (app, _guard) = setup_test_app_with_frontend(store_type, "health_api").await; + let client = TestClient::new(app); + + // we can call health api with both `GET` and `POST` method. 
+ let res_post = client.post("/health").send().await; + assert_eq!(res_post.status(), StatusCode::OK); + let res_get = client.get("/health").send().await; + assert_eq!(res_get.status(), StatusCode::OK); + + // both `GET` and `POST` methods return the same result + let body_text = res_post.text().await; + assert_eq!(body_text, res_get.text().await); + + // currently the health API simply returns an empty JSON `{}`, which can be deserialized to an empty `HealthResponse` + assert_eq!(body_text, "{}"); + + let body = serde_json::from_str::<HealthResponse>(&body_text).unwrap(); + assert_eq!(body, HealthResponse {}); } diff --git a/tests-integration/tests/main.rs b/tests-integration/tests/main.rs new file mode 100644 index 0000000000..b664e03ad9 --- /dev/null +++ b/tests-integration/tests/main.rs @@ -0,0 +1,21 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[macro_use] +mod grpc; +#[macro_use] +mod http; + +grpc_tests!(File, S3); +http_tests!(File, S3); diff --git a/tests/cases/standalone/basic.result b/tests/cases/standalone/basic.result new file mode 100644 index 0000000000..229da9c61f --- /dev/null +++ b/tests/cases/standalone/basic.result @@ -0,0 +1,60 @@ +CREATE TABLE system_metrics ( + host STRING, + idc STRING, + cpu_util DOUBLE, + memory_util DOUBLE, + disk_util DOUBLE, + ts TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(host, idc), + TIME INDEX(ts) +); + +MutateResult { success: 1, failure: 0 } + +INSERT INTO system_metrics +VALUES + ("host1", "idc_a", 11.8, 10.3, 10.3, 1667446797450), + ("host2", "idc_a", 80.1, 70.3, 90.0, 1667446797450), + ("host1", "idc_b", 50.0, 66.7, 40.6, 1667446797450); + +MutateResult { success: 3, failure: 0 } + +SELECT * FROM system_metrics; + ++-----------------------+----------------------+----------------------------+-------------------------------+-----------------------------+----------------------------+ +| host, #Field, #String | idc, #Field, #String | cpu_util, #Field, #Float64 | memory_util, #Field, #Float64 | disk_util, #Field, #Float64 | ts, #Timestamp, #Timestamp | ++-----------------------+----------------------+----------------------------+-------------------------------+-----------------------------+----------------------------+ +| host1 | idc_a | 11.8 | 10.3 | 10.3 | 1667446797450 | +| host1 | idc_b | 50 | 66.7 | 40.6 | 1667446797450 | +| host2 | idc_a | 80.1 | 70.3 | 90 | 1667446797450 | ++-----------------------+----------------------+----------------------------+-------------------------------+-----------------------------+----------------------------+ + +SELECT count(*) FROM system_metrics; + ++----------------------------------+ +| COUNT(UInt8(1)), #Field, #Uint64 | ++----------------------------------+ +| 3 | ++----------------------------------+ + +SELECT avg(cpu_util) FROM system_metrics; + ++------------------------------------------------+ +| AVG(system_metrics.cpu_util), #Field, #Float64 | ++------------------------------------------------+ +| 47.29999999999999 |
++------------------------------------------------+ + +SELECT idc, avg(memory_util) FROM system_metrics GROUP BY idc ORDER BY idc; + ++----------------------+---------------------------------------------------+ +| idc, #Field, #String | AVG(system_metrics.memory_util), #Field, #Float64 | ++----------------------+---------------------------------------------------+ +| idc_a | 40.3 | +| idc_b | 66.7 | ++----------------------+---------------------------------------------------+ + +DROP TABLE system_metrics; + +MutateResult { success: 1, failure: 0 } + diff --git a/tests/cases/standalone/basic.sql b/tests/cases/standalone/basic.sql new file mode 100644 index 0000000000..d08a78a6a6 --- /dev/null +++ b/tests/cases/standalone/basic.sql @@ -0,0 +1,26 @@ +CREATE TABLE system_metrics ( + host STRING, + idc STRING, + cpu_util DOUBLE, + memory_util DOUBLE, + disk_util DOUBLE, + ts TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY(host, idc), + TIME INDEX(ts) +); + +INSERT INTO system_metrics +VALUES + ("host1", "idc_a", 11.8, 10.3, 10.3, 1667446797450), + ("host2", "idc_a", 80.1, 70.3, 90.0, 1667446797450), + ("host1", "idc_b", 50.0, 66.7, 40.6, 1667446797450); + +SELECT * FROM system_metrics; + +SELECT count(*) FROM system_metrics; + +SELECT avg(cpu_util) FROM system_metrics; + +SELECT idc, avg(memory_util) FROM system_metrics GROUP BY idc ORDER BY idc; + +DROP TABLE system_metrics; diff --git a/tests/cases/standalone/select/dummy.result b/tests/cases/standalone/select/dummy.result new file mode 100644 index 0000000000..24132d6b4f --- /dev/null +++ b/tests/cases/standalone/select/dummy.result @@ -0,0 +1,36 @@ +select 1; + ++--------------------------+ +| Int64(1), #Field, #Int64 | ++--------------------------+ +| 1 | ++--------------------------+ + +select 2 + 3; + ++----------------------------------------+ +| Int64(2) Plus Int64(3), #Field, #Int64 | ++----------------------------------------+ +| 5 | ++----------------------------------------+ + +select 4 + 0.5; + ++----------------------------------------------+ +| Int64(4) Plus Float64(0.5), #Field, #Float64 | ++----------------------------------------------+ +| 4.5 | ++----------------------------------------------+ + +select "a"; + +Failed to execute, error: Datanode { code: 1003, msg: "Failed to execute query: select \"a\";, source: Failed to select from table, source: Error occurred on the data node, code: 3000, msg: Failed to execute sql, source: Cannot plan SQL: SELECT \"a\", source: Error during planning: Invalid identifier '#a' for schema fields:[], metadata:{}" } + +select "A"; + +Failed to execute, error: Datanode { code: 1003, msg: "Failed to execute query: select \"A\";, source: Failed to select from table, source: Error occurred on the data node, code: 3000, msg: Failed to execute sql, source: Cannot plan SQL: SELECT \"A\", source: Error during planning: Invalid identifier '#A' for schema fields:[], metadata:{}" } + +select * where "a" = "A"; + +Failed to execute, error: Datanode { code: 1003, msg: "Failed to execute query: select * where \"a\" = \"A\";, source: Failed to select from table, source: Error occurred on the data node, code: 3000, msg: Failed to execute sql, source: Cannot plan SQL: SELECT * WHERE \"a\" = \"A\", source: Error during planning: Invalid identifier '#a' for schema fields:[], metadata:{}" } + diff --git a/tests/cases/standalone/select/dummy.sql b/tests/cases/standalone/select/dummy.sql new file mode 100644 index 0000000000..97d975b2e2 --- /dev/null +++ b/tests/cases/standalone/select/dummy.sql @@ -0,0 +1,11 
@@ +select 1; + +select 2 + 3; + +select 4 + 0.5; + +select "a"; + +select "A"; + +select * where "a" = "A"; diff --git a/tests/runner/Cargo.toml b/tests/runner/Cargo.toml new file mode 100644 index 0000000000..9728e65cad --- /dev/null +++ b/tests/runner/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "sqlness-runner" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1" +client = { path = "../../src/client" } +comfy-table = "6.1" +sqlness = { git = "https://github.com/ceresdb/sqlness.git" } +tokio = { version = "1.21", features = ["full"] } diff --git a/tests/runner/src/env.rs b/tests/runner/src/env.rs new file mode 100644 index 0000000000..5a92aa8374 --- /dev/null +++ b/tests/runner/src/env.rs @@ -0,0 +1,202 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Display; +use std::fs::OpenOptions; +use std::process::Stdio; +use std::time::Duration; + +use async_trait::async_trait; +use client::api::v1::codec::SelectResult; +use client::api::v1::column::SemanticType; +use client::api::v1::ColumnDataType; +use client::{Client, Database as DB, Error as ClientError, ObjectResult, Select}; +use comfy_table::{Cell, Table}; +use sqlness::{Database, Environment}; +use tokio::process::{Child, Command}; + +use crate::util; + +const SERVER_ADDR: &str = "127.0.0.1:4001"; +const SERVER_LOG_FILE: &str = "/tmp/greptime-sqlness.log"; + +pub struct Env {} + +#[async_trait] +impl Environment for Env { + type DB = GreptimeDB; + + async fn start(&self, mode: &str, _config: Option<String>) -> Self::DB { + match mode { + "standalone" => Self::start_standalone().await, + "distributed" => Self::start_distributed().await, + _ => panic!("Unexpected mode: {}", mode), + } + } + + /// Stop one [`Database`]. + async fn stop(&self, _mode: &str, mut database: Self::DB) { + database.server_process.kill().await.unwrap() + } +} + +impl Env { + #[allow(clippy::print_stdout)] + pub async fn start_standalone() -> GreptimeDB { + // Build the DB with `cargo build --bin greptime` + println!("Going to build the DB..."); + let cargo_build_result = Command::new("cargo") + .current_dir(util::get_workspace_root()) + .args(["build", "--bin", "greptime"]) + .stdout(Stdio::null()) + .output() + .await + .expect("Failed to build GreptimeDB") + .status; + if !cargo_build_result.success() { + panic!("Failed to build GreptimeDB (`cargo build` fails)"); + } + println!("Build finished, starting..."); + + // Open log file (build logs will be truncated).
+ let log_file = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(SERVER_LOG_FILE) + .unwrap_or_else(|_| panic!("Cannot open log file at {}", SERVER_LOG_FILE)); + // Start the DB + let server_process = Command::new("./greptime") + .current_dir(util::get_binary_dir("debug")) + .args(["standalone", "start", "-m"]) + .stdout(log_file) + .spawn() + .expect("Failed to start the DB"); + + let is_up = util::check_port(SERVER_ADDR.parse().unwrap(), Duration::from_secs(10)).await; + if !is_up { + panic!("Server didn't come up in 10 seconds, quitting.") + } + println!( + "Started, going to test. Logs will be written to {}", + SERVER_LOG_FILE + ); + + let client = Client::with_urls(vec![SERVER_ADDR]); + let db = DB::new("greptime", client.clone()); + + GreptimeDB { + server_process, + client, + db, + } + } + + pub async fn start_distributed() -> GreptimeDB { + todo!() + } +} + +pub struct GreptimeDB { + server_process: Child, + #[allow(dead_code)] + client: Client, + db: DB, +} + +#[async_trait] +impl Database for GreptimeDB { + async fn query(&self, query: String) -> Box<dyn Display> { + let sql = Select::Sql(query); + let result = self.db.select(sql).await; + Box::new(ResultDisplayer { result }) as _ + } +} + +struct ResultDisplayer { + result: Result<ObjectResult, ClientError>, +} + +impl Display for ResultDisplayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.result { + Ok(result) => match result { + ObjectResult::Select(select_result) => { + write!( + f, + "{}", + SelectResultDisplayer { + result: select_result + } + .display() + ) + } + ObjectResult::Mutate(mutate_result) => { + write!(f, "{:?}", mutate_result) + } + }, + Err(e) => write!(f, "Failed to execute, error: {:?}", e), + } + } +} + +struct SelectResultDisplayer<'a> { + result: &'a SelectResult, +} + +impl SelectResultDisplayer<'_> { + fn display(&self) -> impl Display { + let mut table = Table::new(); + table.load_preset("||--+-++| ++++++"); + + if self.result.row_count == 0 { + return table; + } + + let mut headers = vec![]; + for column in &self.result.columns { + headers.push(Cell::new(format!( + "{}, #{:?}, #{:?}", + column.column_name, + SemanticType::from_i32(column.semantic_type).unwrap(), + ColumnDataType::from_i32(column.datatype).unwrap() + ))); + } + table.set_header(headers); + + let col_count = self.result.columns.len(); + let row_count = self.result.row_count as usize; + let columns = self + .result + .columns + .iter() + .map(|col| { + util::values_to_string( + ColumnDataType::from_i32(col.datatype).unwrap(), + col.values.clone().unwrap(), + ) + }) + .collect::<Vec<_>>(); + + for row_index in 0..row_count { + let mut row = Vec::with_capacity(col_count); + for col in columns.iter() { + row.push(col[row_index].clone()); + } + table.add_row(row); + } + + table + } +} diff --git a/tests/runner/src/main.rs b/tests/runner/src/main.rs new file mode 100644 index 0000000000..696cbf1d7c --- /dev/null +++ b/tests/runner/src/main.rs @@ -0,0 +1,29 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +use env::Env; +use sqlness::{ConfigBuilder, Runner}; + +mod env; +mod util; + +#[tokio::main] +async fn main() { + let config = ConfigBuilder::default() + .case_dir(util::get_case_dir()) + .build() + .unwrap(); + let runner = Runner::new_with_config(config, Env {}).await.unwrap(); + runner.run().await.unwrap(); +} diff --git a/tests/runner/src/util.rs b/tests/runner/src/util.rs new file mode 100644 index 0000000000..a6accc9ed7 --- /dev/null +++ b/tests/runner/src/util.rs @@ -0,0 +1,162 @@ +// Copyright 2022 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::SocketAddr; +use std::path::PathBuf; +use std::time::Duration; + +use client::api::v1::column::Values; +use client::api::v1::ColumnDataType; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpSocket; +use tokio::time; + +/// Check port every 0.1 second. +const PORT_CHECK_INTERVAL: Duration = Duration::from_millis(100); + +pub fn values_to_string(data_type: ColumnDataType, values: Values) -> Vec<String> { + match data_type { + ColumnDataType::Int64 => values + .i64_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Float64 => values + .f64_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::String => values.string_values, + ColumnDataType::Boolean => values + .bool_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Int8 => values + .i8_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Int16 => values + .i16_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Int32 => values + .i32_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Uint8 => values + .u8_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Uint16 => values + .u16_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Uint32 => values + .u32_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Uint64 => values + .u64_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Float32 => values + .f32_values + .into_iter() + .map(|val| val.to_string()) + .collect(), + ColumnDataType::Binary => values + .binary_values + .into_iter() + .map(|val| format!("{:?}", val)) + .collect(), + ColumnDataType::Datetime => values + .i64_values + .into_iter() + .map(|v| v.to_string()) + .collect(), + ColumnDataType::Date => values + .i32_values + .into_iter() + .map(|v| v.to_string()) + .collect(), + ColumnDataType::Timestamp => values + .ts_millis_values + .into_iter() + .map(|v| v.to_string()) + .collect(), + } +} + +/// Get the dir of test cases. This function only works when the runner is run +/// under the project's dir because it depends on environment variables set by Cargo.
+pub fn get_case_dir() -> String { + // retrieve the runner crate's manifest dir (./tests/runner) + let mut runner_crate_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + // change directory to cases' dir from runner's (should be runner/../cases) + runner_crate_path.pop(); + runner_crate_path.push("cases"); + + runner_crate_path.into_os_string().into_string().unwrap() +} + +/// Get the dir that contains workspace manifest (the top-level Cargo.toml). +pub fn get_workspace_root() -> String { + // retrieve the runner crate's manifest dir (./tests/runner) + let mut runner_crate_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + // change directory to workspace's root (runner/../..) + runner_crate_path.pop(); + runner_crate_path.pop(); + + runner_crate_path.into_os_string().into_string().unwrap() +} + +pub fn get_binary_dir(mode: &str) -> String { + // first go to the workspace root. + let mut workspace_root = PathBuf::from(get_workspace_root()); + + // change directory to target dir (workspace/target/<mode>/) + workspace_root.push("target"); + workspace_root.push(mode); + + workspace_root.into_os_string().into_string().unwrap() +} + +/// Spin-wait until a socket address is available, or time out. +/// Returns whether the addr is up. +pub async fn check_port(ip_addr: SocketAddr, timeout: Duration) -> bool { + let check_task = async { + loop { + let socket = TcpSocket::new_v4().expect("Cannot create v4 socket"); + match socket.connect(ip_addr).await { + Ok(mut stream) => { + let _ = stream.shutdown().await; + break; + } + Err(_) => time::sleep(PORT_CHECK_INTERVAL).await, + } + } + }; + + tokio::time::timeout(timeout, check_task).await.is_ok() +} diff --git a/typos.toml b/typos.toml new file mode 100644 index 0000000000..b8bff9fe48 --- /dev/null +++ b/typos.toml @@ -0,0 +1,2 @@ +[default.extend-words] +ue = "ue"
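For reference, here is a rough sketch of what one instantiation of the `http_tests!`/`grpc_tests!` macros above expands to. The `paste`-generated module name (`integration_http_s3_test`) illustrates the `[<integration_http_ $service:lower _test>]` pattern assumed in this diff; treat the exact identifiers as an illustration rather than a verbatim compiler expansion:

```rust
// Hypothetical expansion of `http_tests!(File, S3)` for the S3 service and
// the `test_sql_api` case. Each generated test calls `test_on()` first, so
// the S3 suite is silently skipped unless GT_S3_BUCKET is set in `.env`.
mod integration_http_s3_test {
    #[tokio::test(flavor = "multi_thread")]
    async fn test_sql_api() {
        let store_type = tests_integration::test_util::StorageType::S3;
        if store_type.test_on() {
            let _ = crate::http::test_sql_api(store_type).await;
        }
    }
}
```

Baking the backend name into the generated module path is what makes the README's filters work: `cargo test integration` matches every generated module, while `cargo test s3` narrows the run to the S3 ones.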