Mirror of https://github.com/GreptimeTeam/greptimedb.git, synced 2026-01-03 20:02:54 +00:00

Compare commits: 25 commits, v0.1.0-alp ... v0.1.0-alp
Commits (SHA1):
4052563248, 952e1bd626, 8232015998, d82a3a7d58, 0599465685,
13d51250ba, 6127706b5b, 2e17e9c4b5, b0cbfa7ffb, 20172338e8,
9c53f9b24c, 6d24f7ebb6, 68c2de8e45, a17dcbc511, 53ab19ea5a,
84c44cf540, 020b9936cd, 75dcf2467b, eea5393f96, 3d312d389d,
fdc73fb52f, 2a36e26d19, baef640fe3, 5fddb799f7, f372229b18
.env.example — new file (4 lines)

@@ -0,0 +1,4 @@
# Settings for s3 test
GT_S3_BUCKET=S3 bucket
GT_S3_ACCESS_KEY_ID=S3 access key id
GT_S3_ACCESS_KEY=S3 secret access key
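These variables back the new S3 storage test; below is a minimal sketch of how a test might pick them up with the `dotenv` crate that this change pulls into `tests-integration`. The `S3TestConfig` struct and `s3_test_config_from_env` helper are hypothetical, not code from the repository.

```rust
// Hypothetical helper for reading the S3 test settings; names are
// illustrative, not part of the repository.
#[derive(Debug)]
struct S3TestConfig {
    bucket: String,
    access_key_id: String,
    secret_access_key: String,
}

fn s3_test_config_from_env() -> Option<S3TestConfig> {
    // Load `.env` if present; ignore the error when the file is missing.
    let _ = dotenv::dotenv();

    Some(S3TestConfig {
        bucket: std::env::var("GT_S3_BUCKET").ok()?,
        access_key_id: std::env::var("GT_S3_ACCESS_KEY_ID").ok()?,
        secret_access_key: std::env::var("GT_S3_ACCESS_KEY").ok()?,
    })
}

fn main() {
    match s3_test_config_from_env() {
        Some(cfg) => println!("S3 test settings loaded for bucket {}", cfg.bucket),
        None => println!("S3 test settings not set; skipping the s3 test"),
    }
}
```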
.github/pull_request_template.md — vendored (4 lines changed)

@@ -13,7 +13,7 @@ Please explain IN DETAIL what the changes are in this PR and why they are needed

## Checklist

- [] I have written the necessary rustdoc comments.
- [] I have added the necessary unit tests and integration tests.
- [ ] I have written the necessary rustdoc comments.
- [ ] I have added the necessary unit tests and integration tests.

## Refer to a related PR or issue link (optional)
.github/workflows/doc-issue.yml — vendored, new file (25 lines)

@@ -0,0 +1,25 @@
name: Create Issue in docs repo on doc related changes

on:
  issues:
    types:
      - labeled
  pull_request_target:
    types:
      - labeled

jobs:
  doc_issue:
    if: github.event.label.name == 'doc update required'
    runs-on: ubuntu-latest
    steps:
      - name: create an issue in doc repo
        uses: dacbd/create-issue-action@main
        with:
          owner: GreptimeTeam
          repo: docs
          token: ${{ secrets.DOCS_REPO_TOKEN }}
          title: Update docs for ${{ github.event.issue.title || github.event.pull_request.title }}
          body: |
            A document change request is generated from
            ${{ github.event.issue.html_url || github.event.pull_request.html_url }}
.gitignore — vendored (3 lines added)

@@ -32,3 +32,6 @@ logs/

# Benchmark dataset
benchmarks/data

# dotenv
.env
@@ -9,7 +9,7 @@ repos:
    rev: e6a795bc6b2c0958f9ef52af4863bbd7cc17238f
    hooks:
      - id: cargo-sort
        args: ["--workspace", "--print"]
        args: ["--workspace"]

  - repo: https://github.com/doublify/pre-commit-rust
    rev: v1.0
202
Cargo.lock
generated
202
Cargo.lock
generated
@@ -51,11 +51,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "aide"
|
||||
version = "0.6.2"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a47c350c121222a7d8cc7d2efad0856ddd3a903eb62f0f5e982efb27d811c94c"
|
||||
checksum = "befdff0b4683a0824fc8719ce639a252d9d62cd89c8d0004c39e2417128c1eb8"
|
||||
dependencies = [
|
||||
"axum 0.6.0-rc.2",
|
||||
"axum 0.6.1",
|
||||
"bytes",
|
||||
"cfg-if",
|
||||
"http",
|
||||
@@ -407,12 +407,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.6.0-rc.2"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2628a243073c55aef15a1c1fe45c87f21b84f9e89ca9e7b262a180d3d03543d"
|
||||
checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core 0.3.0-rc.2",
|
||||
"axum-core 0.3.0",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
@@ -420,13 +420,15 @@ dependencies = [
|
||||
"http-body",
|
||||
"hyper",
|
||||
"itoa 1.0.3",
|
||||
"matchit 0.6.0",
|
||||
"matchit 0.7.0",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
@@ -454,9 +456,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.3.0-rc.2"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "473bd0762170028bb6b5068be9e97de2a9f0af3bf2084498d840498f47194d3d"
|
||||
checksum = "79b8558f5a0581152dc94dcd289132a1d377494bdeafcd41869b3258e3e2ad92"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
@@ -464,15 +466,16 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.3.0-rc.1"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "247a599903eb2e02abbaf2facc6396140df7af6dcc84e64ce3b71d117702fa22"
|
||||
checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621"
|
||||
dependencies = [
|
||||
"heck 0.4.0",
|
||||
"proc-macro2",
|
||||
@@ -718,6 +721,17 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "build-data"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a94f9f7aab679acac7ce29ba5581c00d3971a861c3b501c5bb74c3ba0026d90"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"safe-lock",
|
||||
"safe-regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.11.0"
|
||||
@@ -845,7 +859,6 @@ dependencies = [
|
||||
"meta-client",
|
||||
"mito",
|
||||
"object-store",
|
||||
"opendal",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -1106,6 +1119,7 @@ dependencies = [
|
||||
name = "cmd"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"build-data",
|
||||
"clap 3.2.22",
|
||||
"common-error",
|
||||
"common-telemetry",
|
||||
@@ -1787,7 +1801,7 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"api",
|
||||
"async-trait",
|
||||
"axum 0.6.0-rc.2",
|
||||
"axum 0.6.1",
|
||||
"axum-macros",
|
||||
"axum-test-helper",
|
||||
"backon",
|
||||
@@ -1820,6 +1834,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"servers",
|
||||
"session",
|
||||
"snafu",
|
||||
"sql",
|
||||
"storage",
|
||||
@@ -1970,6 +1985,12 @@ version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
|
||||
|
||||
[[package]]
|
||||
name = "dotenv"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
|
||||
|
||||
[[package]]
|
||||
name = "dyn-clone"
|
||||
version = "1.0.9"
|
||||
@@ -2207,9 +2228,11 @@ dependencies = [
|
||||
"openmetrics-parser",
|
||||
"prost 0.11.0",
|
||||
"query",
|
||||
"rustls",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"servers",
|
||||
"session",
|
||||
"snafu",
|
||||
"sql",
|
||||
"sqlparser",
|
||||
@@ -3096,9 +3119,9 @@ checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.6.0"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3dfc802da7b1cf80aefffa0c7b2f77247c8b32206cc83c270b61264f5b360a80"
|
||||
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
@@ -3645,6 +3668,7 @@ dependencies = [
|
||||
"opendal",
|
||||
"tempdir",
|
||||
"tokio",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3706,16 +3730,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "opensrv-mysql"
|
||||
version = "0.2.0"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4c24c12fd688cb5aa5b1a54c6ccb2e30fb9b5132debb0e89fcb432b3f73db8f"
|
||||
checksum = "ac5d68ae914b1317d874ce049e52d386b1209d8835d4e6e094f2e90bfb49eccc"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"mysql_common",
|
||||
"nom",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4491,6 +4517,7 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"session",
|
||||
"snafu",
|
||||
"sql",
|
||||
"statrs",
|
||||
@@ -4912,9 +4939,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.20.6"
|
||||
version = "0.20.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033"
|
||||
checksum = "539a2bfe908f471bfa933876bd1eb6a19cf2176d375f82ef7f99530a40e48c2c"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
@@ -5191,6 +5218,59 @@ version = "1.0.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09"
|
||||
|
||||
[[package]]
|
||||
name = "safe-lock"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "077d73db7973cccf63eb4aff1e5a34dc2459baa867512088269ea5f2f4253c90"
|
||||
|
||||
[[package]]
|
||||
name = "safe-proc-macro2"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "814c536dcd27acf03296c618dab7ad62d28e70abd7ba41d3f34a2ce707a2c666"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safe-quote"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "77e530f7831f3feafcd5f1aae406ac205dd998436b4007c8e80f03eca78a88f7"
|
||||
dependencies = [
|
||||
"safe-proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safe-regex"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a15289bf322e0673d52756a18194167f2378ec1a15fe884af6e2d2cb934822b0"
|
||||
dependencies = [
|
||||
"safe-regex-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safe-regex-compiler"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fba76fae590a2aa665279deb1f57b5098cbace01a0c5e60e262fcf55f7c51542"
|
||||
dependencies = [
|
||||
"safe-proc-macro2",
|
||||
"safe-quote",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "safe-regex-macro"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96c2e96b5c03f158d1b16ba79af515137795f4ad4e8de3f790518aae91f1d127"
|
||||
dependencies = [
|
||||
"safe-proc-macro2",
|
||||
"safe-regex-compiler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
@@ -5289,6 +5369,7 @@ dependencies = [
|
||||
"rustpython-parser",
|
||||
"rustpython-vm",
|
||||
"serde",
|
||||
"session",
|
||||
"snafu",
|
||||
"sql",
|
||||
"storage",
|
||||
@@ -5393,6 +5474,15 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "184c643044780f7ceb59104cef98a5a6f12cb2288a7bc701ab93a362b49fd47d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
@@ -5412,7 +5502,7 @@ dependencies = [
|
||||
"aide",
|
||||
"api",
|
||||
"async-trait",
|
||||
"axum 0.6.0-rc.2",
|
||||
"axum 0.6.1",
|
||||
"axum-macros",
|
||||
"axum-test-helper",
|
||||
"bytes",
|
||||
@@ -5443,15 +5533,20 @@ dependencies = [
|
||||
"query",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile 1.0.1",
|
||||
"schemars",
|
||||
"script",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"session",
|
||||
"snafu",
|
||||
"snap",
|
||||
"table",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tokio-stream",
|
||||
"tokio-test",
|
||||
"tonic",
|
||||
@@ -5460,6 +5555,14 @@ dependencies = [
|
||||
"tower-http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "session"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"common-telemetry",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.0"
|
||||
@@ -5956,7 +6059,9 @@ dependencies = [
|
||||
"catalog",
|
||||
"common-catalog",
|
||||
"common-error",
|
||||
"common-telemetry",
|
||||
"datafusion",
|
||||
"datafusion-expr",
|
||||
"datatypes",
|
||||
"futures",
|
||||
"prost 0.9.0",
|
||||
@@ -6102,6 +6207,38 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tests-integration"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"api",
|
||||
"axum 0.6.1",
|
||||
"axum-test-helper",
|
||||
"catalog",
|
||||
"client",
|
||||
"common-catalog",
|
||||
"common-runtime",
|
||||
"common-telemetry",
|
||||
"datanode",
|
||||
"datatypes",
|
||||
"dotenv",
|
||||
"frontend",
|
||||
"mito",
|
||||
"object-store",
|
||||
"once_cell",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"servers",
|
||||
"snafu",
|
||||
"sql",
|
||||
"table",
|
||||
"tempdir",
|
||||
"tokio",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "textwrap"
|
||||
version = "0.11.0"
|
||||
@@ -6328,6 +6465,20 @@ dependencies = [
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-postgres-rustls"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "606f2b73660439474394432239c82249c0d45eb5f23d91f401be1e33590444a7"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"ring",
|
||||
"rustls",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.23.4"
|
||||
@@ -6500,9 +6651,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.1"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "343bc9466d3fe6b0f960ef45960509f84480bf4fd96f92901afe7ff3df9d3a62"
|
||||
checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0"
|
||||
|
||||
[[package]]
|
||||
name = "tower-service"
|
||||
@@ -6803,6 +6954,12 @@ version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
version = "0.5.1"
|
||||
@@ -6856,6 +7013,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f"
|
||||
dependencies = [
|
||||
"getrandom 0.2.7",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
Cargo.toml (10 lines changed)

@@ -11,11 +11,11 @@ members = [
    "src/common/function",
    "src/common/function-macro",
    "src/common/grpc",
    "src/common/grpc-expr",
    "src/common/query",
    "src/common/recordbatch",
    "src/common/runtime",
    "src/common/substrait",
    "src/common/grpc-expr",
    "src/common/telemetry",
    "src/common/time",
    "src/datanode",

@@ -24,17 +24,19 @@ members = [
    "src/log-store",
    "src/meta-client",
    "src/meta-srv",
    "src/mito",
    "src/object-store",
    "src/query",
    "src/script",
    "src/servers",
    "src/session",
    "src/sql",
    "src/storage",
    "src/store-api",
    "src/table",
    "src/mito",
    "tests/runner",
]
    "tests-integration"
    ,
    "tests/runner"]

[profile.release]
debug = true
@@ -56,7 +56,6 @@ To compile GreptimeDB from source, you'll need:
find installation instructions [here](https://grpc.io/docs/protoc-installation/).
**Note that the `protoc` version needs to be >= 3.15** because we have used the `optional`
keyword. You can check it with `protoc --version`.


#### Build with Docker


@@ -161,6 +160,8 @@ break things. Benchmark on development branch may not represent its potential
performance. We release pre-built binaries constantly for functional
evaluation. Do not use it in production at the moment.

For future plans, check out the [GreptimeDB roadmap](https://github.com/GreptimeTeam/greptimedb/issues/669).

## Community

Our core team is thrilled to see you participate in any way you like. When you are stuck, try to
@@ -7,8 +7,8 @@ license = "Apache-2.0"

[dependencies]
common-base = { path = "../common/base" }
common-time = { path = "../common/time" }
common-error = { path = "../common/error" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
prost = "0.11"
snafu = { version = "0.7", features = ["backtraces"] }
@@ -27,7 +27,6 @@ futures = "0.3"
futures-util = "0.3"
lazy_static = "1.4"
meta-client = { path = "../meta-client" }
opendal = "0.21"
regex = "1.6"
serde = "1.0"
serde_json = "1.0"

@@ -39,9 +38,8 @@ tokio = { version = "1.18", features = ["full"] }

[dev-dependencies]
chrono = "0.4"
log-store = { path = "../log-store" }
object-store = { path = "../object-store" }
opendal = "0.21"
storage = { path = "../storage" }
mito = { path = "../mito", features = ["test"] }
object-store = { path = "../object-store" }
storage = { path = "../storage" }
tempdir = "0.3"
tokio = { version = "1.0", features = ["full"] }
@@ -11,9 +11,9 @@ async-stream = "0.3"
common-base = { path = "../common/base" }
common-error = { path = "../common/error" }
common-grpc = { path = "../common/grpc" }
common-grpc-expr = { path = "../common/grpc-expr" }
common-query = { path = "../common/query" }
common-recordbatch = { path = "../common/recordbatch" }
common-grpc-expr = { path = "../common/grpc-expr" }
common-time = { path = "../common/time" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
    "simd",
@@ -15,13 +15,13 @@ common-error = { path = "../common/error" }
common-telemetry = { path = "../common/telemetry", features = [
    "deadlock_detection",
] }
meta-client = { path = "../meta-client" }
datanode = { path = "../datanode" }
frontend = { path = "../frontend" }
futures = "0.3"
meta-client = { path = "../meta-client" }
meta-srv = { path = "../meta-srv" }
serde = "1.0"
servers = {path = "../servers"}
servers = { path = "../servers" }
snafu = { version = "0.7", features = ["backtraces"] }
tokio = { version = "1.18", features = ["full"] }
toml = "0.5"

@@ -29,3 +29,6 @@ toml = "0.5"

[dev-dependencies]
serde = "1.0"
tempdir = "0.3"

[build-dependencies]
build-data = "0.1.3"
src/cmd/build.rs — new file (19 lines)

@@ -0,0 +1,19 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

fn main() {
    build_data::set_GIT_BRANCH();
    build_data::set_GIT_COMMIT();
    build_data::set_GIT_DIRTY();
}
@@ -20,7 +20,7 @@ use cmd::{datanode, frontend, metasrv, standalone};
use common_telemetry::logging::{error, info};

#[derive(Parser)]
#[clap(name = "greptimedb")]
#[clap(name = "greptimedb", version = print_version())]
struct Command {
    #[clap(long, default_value = "/tmp/greptimedb/logs")]
    log_dir: String,

@@ -70,6 +70,17 @@ impl fmt::Display for SubCommand {
    }
}

fn print_version() -> &'static str {
    concat!(
        "\nbranch: ",
        env!("GIT_BRANCH"),
        "\ncommit: ",
        env!("GIT_COMMIT"),
        "\ndirty: ",
        env!("GIT_DIRTY")
    )
}

#[tokio::main]
async fn main() -> Result<()> {
    let cmd = Command::parse();
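The new `build.rs` above relies on the `build-data` crate to export Git metadata as compile-time environment variables, which `print_version` then reads back with `env!`. Below is a rough, hand-rolled sketch of the same idea, not the crate's actual implementation.

```rust
// build.rs — a stand-in for what `build_data::set_GIT_*` does: ask git for
// the metadata and re-export it as compile-time environment variables that
// `env!("GIT_BRANCH")` and friends can read in the main binary.
use std::process::Command;

fn git(args: &[&str]) -> String {
    Command::new("git")
        .args(args)
        .output()
        .ok()
        .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}

fn main() {
    println!(
        "cargo:rustc-env=GIT_BRANCH={}",
        git(&["rev-parse", "--abbrev-ref", "HEAD"])
    );
    println!("cargo:rustc-env=GIT_COMMIT={}", git(&["rev-parse", "HEAD"]));
    // Mark the build dirty when there are uncommitted changes.
    let dirty = !git(&["status", "--porcelain"]).is_empty();
    println!("cargo:rustc-env=GIT_DIRTY={}", dirty);
}
```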
@@ -170,6 +170,7 @@ mod tests {
            ObjectStoreConfig::File { data_dir } => {
                assert_eq!("/tmp/greptimedb/data/".to_string(), data_dir)
            }
            ObjectStoreConfig::S3 { .. } => unreachable!(),
        };
    }
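The updated test has to account for the new S3 variant of the object-store configuration. Below is a minimal sketch of what such an enum could look like; the S3 field names are borrowed from the `.env.example` keys above and are assumptions, not the repository's actual definition.

```rust
// Hypothetical shape of the object store configuration; only the `File`
// variant and the existence of an `S3` variant are visible in this diff.
#[allow(dead_code)]
#[derive(Debug)]
enum ObjectStoreConfig {
    File {
        data_dir: String,
    },
    S3 {
        bucket: String,
        access_key_id: String,
        secret_access_key: String,
    },
}

fn main() {
    let config = ObjectStoreConfig::File {
        data_dir: "/tmp/greptimedb/data/".to_string(),
    };

    // Mirrors the match in the updated test: the file variant is expected
    // here, anything else would be a bug in this scenario.
    match config {
        ObjectStoreConfig::File { data_dir } => {
            assert_eq!("/tmp/greptimedb/data/".to_string(), data_dir)
        }
        ObjectStoreConfig::S3 { .. } => unreachable!(),
    }
}
```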
@@ -11,3 +11,4 @@ common-error = { path = "../error" }
paste = "1.0"
serde = { version = "1.0", features = ["derive"] }
snafu = { version = "0.7", features = ["backtraces"] }
@@ -28,31 +28,42 @@ use crate::error::{
    DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu,
};

const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*";

lazy_static! {
    static ref CATALOG_KEY_PATTERN: Regex =
        Regex::new(&format!("^{}-([a-zA-Z_]+)$", CATALOG_KEY_PREFIX)).unwrap();
    static ref CATALOG_KEY_PATTERN: Regex = Regex::new(&format!(
        "^{}-({})$",
        CATALOG_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();
}

lazy_static! {
    static ref SCHEMA_KEY_PATTERN: Regex = Regex::new(&format!(
        "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)$",
        SCHEMA_KEY_PREFIX
        "^{}-({})-({})$",
        SCHEMA_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN, ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();
}

lazy_static! {
    static ref TABLE_GLOBAL_KEY_PATTERN: Regex = Regex::new(&format!(
        "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)$",
        TABLE_GLOBAL_KEY_PREFIX
        "^{}-({})-({})$",
        TABLE_GLOBAL_KEY_PREFIX,
        ALPHANUMERICS_NAME_PATTERN,
        ALPHANUMERICS_NAME_PATTERN,
        ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();
}

lazy_static! {
    static ref TABLE_REGIONAL_KEY_PATTERN: Regex = Regex::new(&format!(
        "^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)-([0-9]+)$",
        TABLE_REGIONAL_KEY_PREFIX
        "^{}-({})-({})-([0-9]+)$",
        TABLE_REGIONAL_KEY_PREFIX,
        ALPHANUMERICS_NAME_PATTERN,
        ALPHANUMERICS_NAME_PATTERN,
        ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();
}
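The shared `ALPHANUMERICS_NAME_PATTERN` widens the old `[a-zA-Z_]+` groups so that catalog, schema, and table names may contain digits anywhere after the first character. Below is a small standalone sketch of the effect, assuming an illustrative value for `CATALOG_KEY_PREFIX` (the real constant lives in the catalog crate).

```rust
use regex::Regex;

const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*";
// Illustrative prefix only; not the crate's actual CATALOG_KEY_PREFIX value.
const CATALOG_KEY_PREFIX: &str = "__c";

fn main() {
    let catalog_key_pattern = Regex::new(&format!(
        "^{}-({})$",
        CATALOG_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN
    ))
    .unwrap();

    // The old group `[a-zA-Z_]+` rejected digits anywhere in the name;
    // the new pattern only forbids a leading digit.
    assert!(catalog_key_pattern.is_match("__c-greptime_v2"));
    assert!(!catalog_key_pattern.is_match("__c-2fast"));
    println!("catalog key pattern behaves as expected");
}
```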
@@ -9,8 +9,8 @@ arc-swap = "1.0"
|
||||
chrono-tz = "0.6"
|
||||
common-error = { path = "../error" }
|
||||
common-function-macro = { path = "../function-macro" }
|
||||
common-time = { path = "../time" }
|
||||
common-query = { path = "../query" }
|
||||
common-time = { path = "../time" }
|
||||
datafusion-common = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
libc = "0.2"
|
||||
|
||||
@@ -8,12 +8,12 @@ license = "Apache-2.0"
|
||||
api = { path = "../../api" }
|
||||
async-trait = "0.1"
|
||||
common-base = { path = "../base" }
|
||||
common-catalog = { path = "../catalog" }
|
||||
common-error = { path = "../error" }
|
||||
common-grpc = { path = "../grpc" }
|
||||
common-query = { path = "../query" }
|
||||
common-telemetry = { path = "../telemetry" }
|
||||
common-time = { path = "../time" }
|
||||
common-catalog = { path = "../catalog" }
|
||||
common-query = { path = "../query" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
table = { path = "../../table" }
|
||||
|
||||
@@ -12,11 +12,11 @@ common-error = { path = "../error" }
|
||||
common-query = { path = "../query" }
|
||||
common-recordbatch = { path = "../recordbatch" }
|
||||
common-runtime = { path = "../runtime" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
dashmap = "5.4"
|
||||
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
|
||||
"simd",
|
||||
] }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
tonic = "0.8"
|
||||
|
||||
@@ -9,9 +9,11 @@ bytes = "1.1"
|
||||
catalog = { path = "../../catalog" }
|
||||
common-catalog = { path = "../catalog" }
|
||||
common-error = { path = "../error" }
|
||||
common-telemetry = { path = "../telemetry" }
|
||||
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
|
||||
"simd",
|
||||
] }
|
||||
datafusion-expr = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2" }
|
||||
datatypes = { path = "../../datatypes" }
|
||||
futures = "0.3"
|
||||
prost = "0.9"
|
||||
|
||||
src/common/substrait/src/context.rs — new file (66 lines)

@@ -0,0 +1,66 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use substrait_proto::protobuf::extensions::simple_extension_declaration::{
    ExtensionFunction, MappingType,
};
use substrait_proto::protobuf::extensions::SimpleExtensionDeclaration;

#[derive(Default)]
pub struct ConvertorContext {
    scalar_fn_names: HashMap<String, u32>,
    scalar_fn_map: HashMap<u32, String>,
}

impl ConvertorContext {
    pub fn register_scalar_fn<S: AsRef<str>>(&mut self, name: S) -> u32 {
        if let Some(anchor) = self.scalar_fn_names.get(name.as_ref()) {
            return *anchor;
        }

        let next_anchor = self.scalar_fn_map.len() as _;
        self.scalar_fn_map
            .insert(next_anchor, name.as_ref().to_string());
        self.scalar_fn_names
            .insert(name.as_ref().to_string(), next_anchor);
        next_anchor
    }

    pub fn register_scalar_with_anchor<S: AsRef<str>>(&mut self, name: S, anchor: u32) {
        self.scalar_fn_map.insert(anchor, name.as_ref().to_string());
        self.scalar_fn_names
            .insert(name.as_ref().to_string(), anchor);
    }

    pub fn find_scalar_fn(&self, anchor: u32) -> Option<&str> {
        self.scalar_fn_map.get(&anchor).map(|s| s.as_str())
    }

    pub fn generate_function_extension(&self) -> Vec<SimpleExtensionDeclaration> {
        let mut result = Vec::with_capacity(self.scalar_fn_map.len());
        for (anchor, name) in &self.scalar_fn_map {
            let declaration = SimpleExtensionDeclaration {
                mapping_type: Some(MappingType::ExtensionFunction(ExtensionFunction {
                    extension_uri_reference: 0,
                    function_anchor: *anchor,
                    name: name.clone(),
                })),
            };
            result.push(declaration);
        }
        result
    }
}
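A short sketch of how `ConvertorContext` ties the two conversion directions together: the encode path registers each scalar function name and receives an anchor, while the decode path re-registers the (name, anchor) pairs carried by the plan's extensions and resolves anchors back to names. The import path below is an assumption (the module is crate-internal in the actual source), and the surrounding plan conversion is elided.

```rust
// Assumed crate/module path; in the repository the module is used via
// `crate::context::ConvertorContext`.
use substrait::context::ConvertorContext;

fn main() {
    let mut ctx = ConvertorContext::default();

    // Encoding side: register functions as they are encountered; duplicates
    // reuse the same anchor.
    let eq = ctx.register_scalar_fn("equal");
    let lt = ctx.register_scalar_fn("lt");
    assert_eq!(eq, ctx.register_scalar_fn("equal"));
    assert_ne!(eq, lt);

    // The extension declarations carry the (anchor, name) mapping inside the
    // serialized substrait Plan.
    let extensions = ctx.generate_function_extension();
    assert_eq!(extensions.len(), 2);

    // Decoding side: rebuild the mapping from the extensions and resolve
    // anchors back to function names.
    let mut decode_ctx = ConvertorContext::default();
    decode_ctx.register_scalar_with_anchor("equal", eq);
    decode_ctx.register_scalar_with_anchor("lt", lt);
    assert_eq!(decode_ctx.find_scalar_fn(eq), Some("equal"));
}
```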
742
src/common/substrait/src/df_expr.rs
Normal file
742
src/common/substrait/src/df_expr.rs
Normal file
@@ -0,0 +1,742 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::str::FromStr;
|
||||
|
||||
use datafusion::logical_plan::{Column, Expr};
|
||||
use datafusion_expr::{expr_fn, BuiltinScalarFunction, Operator};
|
||||
use datatypes::schema::Schema;
|
||||
use snafu::{ensure, OptionExt};
|
||||
use substrait_proto::protobuf::expression::field_reference::ReferenceType as FieldReferenceType;
|
||||
use substrait_proto::protobuf::expression::reference_segment::{
|
||||
ReferenceType as SegReferenceType, StructField,
|
||||
};
|
||||
use substrait_proto::protobuf::expression::{
|
||||
FieldReference, ReferenceSegment, RexType, ScalarFunction,
|
||||
};
|
||||
use substrait_proto::protobuf::function_argument::ArgType;
|
||||
use substrait_proto::protobuf::Expression;
|
||||
|
||||
use crate::context::ConvertorContext;
|
||||
use crate::error::{
|
||||
EmptyExprSnafu, InvalidParametersSnafu, MissingFieldSnafu, Result, UnsupportedExprSnafu,
|
||||
};
|
||||
|
||||
/// Convert substrait's `Expression` to DataFusion's `Expr`.
|
||||
pub fn to_df_expr(ctx: &ConvertorContext, expression: Expression, schema: &Schema) -> Result<Expr> {
|
||||
let expr_rex_type = expression.rex_type.context(EmptyExprSnafu)?;
|
||||
match expr_rex_type {
|
||||
RexType::Literal(_) => UnsupportedExprSnafu {
|
||||
name: "substrait Literal expression",
|
||||
}
|
||||
.fail()?,
|
||||
RexType::Selection(selection) => convert_selection_rex(*selection, schema),
|
||||
RexType::ScalarFunction(scalar_fn) => convert_scalar_function(ctx, scalar_fn, schema),
|
||||
RexType::WindowFunction(_)
|
||||
| RexType::IfThen(_)
|
||||
| RexType::SwitchExpression(_)
|
||||
| RexType::SingularOrList(_)
|
||||
| RexType::MultiOrList(_)
|
||||
| RexType::Cast(_)
|
||||
| RexType::Subquery(_)
|
||||
| RexType::Enum(_) => UnsupportedExprSnafu {
|
||||
name: format!("substrait expression {:?}", expr_rex_type),
|
||||
}
|
||||
.fail()?,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Substrait's `FieldReference` - `DirectReference` - `StructField` to Datafusion's
|
||||
/// `Column` expr.
|
||||
pub fn convert_selection_rex(selection: FieldReference, schema: &Schema) -> Result<Expr> {
|
||||
if let Some(FieldReferenceType::DirectReference(direct_ref)) = selection.reference_type
|
||||
&& let Some(SegReferenceType::StructField(field)) = direct_ref.reference_type {
|
||||
let column_name = schema.column_name_by_index(field.field as _).to_string();
|
||||
Ok(Expr::Column(Column {
|
||||
relation: None,
|
||||
name: column_name,
|
||||
}))
|
||||
} else {
|
||||
InvalidParametersSnafu {
|
||||
reason: "Only support direct struct reference in Selection Rex",
|
||||
}
|
||||
.fail()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_scalar_function(
|
||||
ctx: &ConvertorContext,
|
||||
scalar_fn: ScalarFunction,
|
||||
schema: &Schema,
|
||||
) -> Result<Expr> {
|
||||
// convert argument
|
||||
let mut inputs = VecDeque::with_capacity(scalar_fn.arguments.len());
|
||||
for arg in scalar_fn.arguments {
|
||||
if let Some(ArgType::Value(sub_expr)) = arg.arg_type {
|
||||
inputs.push_back(to_df_expr(ctx, sub_expr, schema)?);
|
||||
} else {
|
||||
InvalidParametersSnafu {
|
||||
reason: "Only value expression arg is supported to be function argument",
|
||||
}
|
||||
.fail()?;
|
||||
}
|
||||
}
|
||||
|
||||
// convert this scalar function
|
||||
// map function name
|
||||
let anchor = scalar_fn.function_reference;
|
||||
let fn_name = ctx
|
||||
.find_scalar_fn(anchor)
|
||||
.with_context(|| InvalidParametersSnafu {
|
||||
reason: format!("Unregistered scalar function reference: {}", anchor),
|
||||
})?;
|
||||
|
||||
// convenient util
|
||||
let ensure_arg_len = |expected: usize| -> Result<()> {
|
||||
ensure!(
|
||||
inputs.len() == expected,
|
||||
InvalidParametersSnafu {
|
||||
reason: format!(
|
||||
"Invalid number of scalar function {}, expected {} but found {}",
|
||||
fn_name,
|
||||
expected,
|
||||
inputs.len()
|
||||
)
|
||||
}
|
||||
);
|
||||
Ok(())
|
||||
};
|
||||
|
||||
// construct DataFusion expr
|
||||
let expr = match fn_name {
|
||||
// begin binary exprs, with the same order of DF `Operator`'s definition.
|
||||
"eq" | "equal" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs.pop_front().unwrap().eq(inputs.pop_front().unwrap())
|
||||
}
|
||||
"not_eq" | "not_equal" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs
|
||||
.pop_front()
|
||||
.unwrap()
|
||||
.not_eq(inputs.pop_front().unwrap())
|
||||
}
|
||||
"lt" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs.pop_front().unwrap().lt(inputs.pop_front().unwrap())
|
||||
}
|
||||
"lt_eq" | "lte" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs
|
||||
.pop_front()
|
||||
.unwrap()
|
||||
.lt_eq(inputs.pop_front().unwrap())
|
||||
}
|
||||
"gt" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs.pop_front().unwrap().gt(inputs.pop_front().unwrap())
|
||||
}
|
||||
"gt_eq" | "gte" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs
|
||||
.pop_front()
|
||||
.unwrap()
|
||||
.gt_eq(inputs.pop_front().unwrap())
|
||||
}
|
||||
"plus" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::Plus,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"minus" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::Minus,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"multiply" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::Multiply,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"divide" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::Divide,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"modulo" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::Modulo,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"and" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::and(inputs.pop_front().unwrap(), inputs.pop_front().unwrap())
|
||||
}
|
||||
"or" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::or(inputs.pop_front().unwrap(), inputs.pop_front().unwrap())
|
||||
}
|
||||
"like" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs
|
||||
.pop_front()
|
||||
.unwrap()
|
||||
.like(inputs.pop_front().unwrap())
|
||||
}
|
||||
"not_like" => {
|
||||
ensure_arg_len(2)?;
|
||||
inputs
|
||||
.pop_front()
|
||||
.unwrap()
|
||||
.not_like(inputs.pop_front().unwrap())
|
||||
}
|
||||
"is_distinct_from" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::IsDistinctFrom,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"is_not_distinct_from" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::IsNotDistinctFrom,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"regex_match" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::RegexMatch,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"regex_i_match" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::RegexIMatch,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"regex_not_match" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::RegexNotMatch,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"regex_not_i_match" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::RegexNotIMatch,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"bitwise_and" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::BitwiseAnd,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
"bitwise_or" => {
|
||||
ensure_arg_len(2)?;
|
||||
expr_fn::binary_expr(
|
||||
inputs.pop_front().unwrap(),
|
||||
Operator::BitwiseOr,
|
||||
inputs.pop_front().unwrap(),
|
||||
)
|
||||
}
|
||||
// end binary exprs
|
||||
// start other direct expr, with the same order of DF `Expr`'s definition.
|
||||
"not" => {
|
||||
ensure_arg_len(1)?;
|
||||
inputs.pop_front().unwrap().not()
|
||||
}
|
||||
"is_not_null" => {
|
||||
ensure_arg_len(1)?;
|
||||
inputs.pop_front().unwrap().is_not_null()
|
||||
}
|
||||
"is_null" => {
|
||||
ensure_arg_len(1)?;
|
||||
inputs.pop_front().unwrap().is_null()
|
||||
}
|
||||
"negative" => {
|
||||
ensure_arg_len(1)?;
|
||||
Expr::Negative(Box::new(inputs.pop_front().unwrap()))
|
||||
}
|
||||
// skip GetIndexedField, unimplemented.
|
||||
"between" => {
|
||||
ensure_arg_len(3)?;
|
||||
Expr::Between {
|
||||
expr: Box::new(inputs.pop_front().unwrap()),
|
||||
negated: false,
|
||||
low: Box::new(inputs.pop_front().unwrap()),
|
||||
high: Box::new(inputs.pop_front().unwrap()),
|
||||
}
|
||||
}
|
||||
"not_between" => {
|
||||
ensure_arg_len(3)?;
|
||||
Expr::Between {
|
||||
expr: Box::new(inputs.pop_front().unwrap()),
|
||||
negated: true,
|
||||
low: Box::new(inputs.pop_front().unwrap()),
|
||||
high: Box::new(inputs.pop_front().unwrap()),
|
||||
}
|
||||
}
|
||||
// skip Case, is covered in substrait::SwitchExpression.
|
||||
// skip Cast and TryCast, is covered in substrait::Cast.
|
||||
"sort" | "sort_des" => {
|
||||
ensure_arg_len(1)?;
|
||||
Expr::Sort {
|
||||
expr: Box::new(inputs.pop_front().unwrap()),
|
||||
asc: false,
|
||||
nulls_first: false,
|
||||
}
|
||||
}
|
||||
"sort_asc" => {
|
||||
ensure_arg_len(1)?;
|
||||
Expr::Sort {
|
||||
expr: Box::new(inputs.pop_front().unwrap()),
|
||||
asc: true,
|
||||
nulls_first: false,
|
||||
}
|
||||
}
|
||||
// those are datafusion built-in "scalar functions".
|
||||
"abs"
|
||||
| "acos"
|
||||
| "asin"
|
||||
| "atan"
|
||||
| "atan2"
|
||||
| "ceil"
|
||||
| "cos"
|
||||
| "exp"
|
||||
| "floor"
|
||||
| "ln"
|
||||
| "log"
|
||||
| "log10"
|
||||
| "log2"
|
||||
| "power"
|
||||
| "pow"
|
||||
| "round"
|
||||
| "signum"
|
||||
| "sin"
|
||||
| "sqrt"
|
||||
| "tan"
|
||||
| "trunc"
|
||||
| "coalesce"
|
||||
| "make_array"
|
||||
| "ascii"
|
||||
| "bit_length"
|
||||
| "btrim"
|
||||
| "char_length"
|
||||
| "character_length"
|
||||
| "concat"
|
||||
| "concat_ws"
|
||||
| "chr"
|
||||
| "current_date"
|
||||
| "current_time"
|
||||
| "date_part"
|
||||
| "datepart"
|
||||
| "date_trunc"
|
||||
| "datetrunc"
|
||||
| "date_bin"
|
||||
| "initcap"
|
||||
| "left"
|
||||
| "length"
|
||||
| "lower"
|
||||
| "lpad"
|
||||
| "ltrim"
|
||||
| "md5"
|
||||
| "nullif"
|
||||
| "octet_length"
|
||||
| "random"
|
||||
| "regexp_replace"
|
||||
| "repeat"
|
||||
| "replace"
|
||||
| "reverse"
|
||||
| "right"
|
||||
| "rpad"
|
||||
| "rtrim"
|
||||
| "sha224"
|
||||
| "sha256"
|
||||
| "sha384"
|
||||
| "sha512"
|
||||
| "digest"
|
||||
| "split_part"
|
||||
| "starts_with"
|
||||
| "strpos"
|
||||
| "substr"
|
||||
| "to_hex"
|
||||
| "to_timestamp"
|
||||
| "to_timestamp_millis"
|
||||
| "to_timestamp_micros"
|
||||
| "to_timestamp_seconds"
|
||||
| "now"
|
||||
| "translate"
|
||||
| "trim"
|
||||
| "upper"
|
||||
| "uuid"
|
||||
| "regexp_match"
|
||||
| "struct"
|
||||
| "from_unixtime"
|
||||
| "arrow_typeof" => Expr::ScalarFunction {
|
||||
fun: BuiltinScalarFunction::from_str(fn_name).unwrap(),
|
||||
args: inputs.into(),
|
||||
},
|
||||
// skip ScalarUDF, unimplemented.
|
||||
// skip AggregateFunction, is covered in substrait::AggregateRel
|
||||
// skip WindowFunction, is covered in substrait WindowFunction
|
||||
// skip AggregateUDF, unimplemented.
|
||||
// skip InList, unimplemented
|
||||
// skip Wildcard, unimplemented.
|
||||
// end other direct expr
|
||||
_ => UnsupportedExprSnafu {
|
||||
name: format!("scalar function {}", fn_name),
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
|
||||
Ok(expr)
|
||||
}
|
||||
|
||||
/// Convert DataFusion's `Expr` to substrait's `Expression`
|
||||
pub fn expression_from_df_expr(
|
||||
ctx: &mut ConvertorContext,
|
||||
expr: &Expr,
|
||||
schema: &Schema,
|
||||
) -> Result<Expression> {
|
||||
let expression = match expr {
|
||||
// Don't merge them with other unsupported expr arms to preserve the ordering.
|
||||
Expr::Alias(..) => UnsupportedExprSnafu {
|
||||
name: expr.to_string(),
|
||||
}
|
||||
.fail()?,
|
||||
Expr::Column(column) => {
|
||||
let field_reference = convert_column(column, schema)?;
|
||||
Expression {
|
||||
rex_type: Some(RexType::Selection(Box::new(field_reference))),
|
||||
}
|
||||
}
|
||||
// Don't merge them with other unsupported expr arms to preserve the ordering.
|
||||
Expr::ScalarVariable(..) | Expr::Literal(..) => UnsupportedExprSnafu {
|
||||
name: expr.to_string(),
|
||||
}
|
||||
.fail()?,
|
||||
Expr::BinaryExpr { left, op, right } => {
|
||||
let left = expression_from_df_expr(ctx, left, schema)?;
|
||||
let right = expression_from_df_expr(ctx, right, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![left, right]);
|
||||
let op_name = utils::name_df_operator(op);
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
Expr::Not(e) => {
|
||||
let arg = expression_from_df_expr(ctx, e, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![arg]);
|
||||
let op_name = "not";
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
Expr::IsNotNull(e) => {
|
||||
let arg = expression_from_df_expr(ctx, e, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![arg]);
|
||||
let op_name = "is_not_null";
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
Expr::IsNull(e) => {
|
||||
let arg = expression_from_df_expr(ctx, e, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![arg]);
|
||||
let op_name = "is_null";
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
Expr::Negative(e) => {
|
||||
let arg = expression_from_df_expr(ctx, e, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![arg]);
|
||||
let op_name = "negative";
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
// Don't merge them with other unsupported expr arms to preserve the ordering.
|
||||
Expr::GetIndexedField { .. } => UnsupportedExprSnafu {
|
||||
name: expr.to_string(),
|
||||
}
|
||||
.fail()?,
|
||||
Expr::Between {
|
||||
expr,
|
||||
negated,
|
||||
low,
|
||||
high,
|
||||
} => {
|
||||
let expr = expression_from_df_expr(ctx, expr, schema)?;
|
||||
let low = expression_from_df_expr(ctx, low, schema)?;
|
||||
let high = expression_from_df_expr(ctx, high, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![expr, low, high]);
|
||||
let op_name = if *negated { "not_between" } else { "between" };
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
// Don't merge them with other unsupported expr arms to preserve the ordering.
|
||||
Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } => UnsupportedExprSnafu {
|
||||
name: expr.to_string(),
|
||||
}
|
||||
.fail()?,
|
||||
Expr::Sort {
|
||||
expr,
|
||||
asc,
|
||||
nulls_first: _,
|
||||
} => {
|
||||
let expr = expression_from_df_expr(ctx, expr, schema)?;
|
||||
let arguments = utils::expression_to_argument(vec![expr]);
|
||||
let op_name = if *asc { "sort_asc" } else { "sort_des" };
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
Expr::ScalarFunction { fun, args } => {
|
||||
let arguments = utils::expression_to_argument(
|
||||
args.iter()
|
||||
.map(|e| expression_from_df_expr(ctx, e, schema))
|
||||
.collect::<Result<Vec<_>>>()?,
|
||||
);
|
||||
let op_name = utils::name_builtin_scalar_function(fun);
|
||||
let function_reference = ctx.register_scalar_fn(op_name);
|
||||
utils::build_scalar_function_expression(function_reference, arguments)
|
||||
}
|
||||
// Don't merge them with other unsupported expr arms to preserve the ordering.
|
||||
Expr::ScalarUDF { .. }
|
||||
| Expr::AggregateFunction { .. }
|
||||
| Expr::WindowFunction { .. }
|
||||
| Expr::AggregateUDF { .. }
|
||||
| Expr::InList { .. }
|
||||
| Expr::Wildcard => UnsupportedExprSnafu {
|
||||
name: expr.to_string(),
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
|
||||
Ok(expression)
|
||||
}
|
||||
|
||||
/// Convert DataFusion's `Column` expr into substrait's `FieldReference` -
|
||||
/// `DirectReference` - `StructField`.
|
||||
pub fn convert_column(column: &Column, schema: &Schema) -> Result<FieldReference> {
|
||||
let column_name = &column.name;
|
||||
let field_index =
|
||||
schema
|
||||
.column_index_by_name(column_name)
|
||||
.with_context(|| MissingFieldSnafu {
|
||||
field: format!("{:?}", column),
|
||||
plan: format!("schema: {:?}", schema),
|
||||
})?;
|
||||
|
||||
Ok(FieldReference {
|
||||
reference_type: Some(FieldReferenceType::DirectReference(ReferenceSegment {
|
||||
reference_type: Some(SegReferenceType::StructField(Box::new(StructField {
|
||||
field: field_index as _,
|
||||
child: None,
|
||||
}))),
|
||||
})),
|
||||
root_type: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Some utils special for this `DataFusion::Expr` and `Substrait::Expression` conversion.
|
||||
mod utils {
|
||||
use datafusion_expr::{BuiltinScalarFunction, Operator};
|
||||
use substrait_proto::protobuf::expression::{RexType, ScalarFunction};
|
||||
use substrait_proto::protobuf::function_argument::ArgType;
|
||||
use substrait_proto::protobuf::{Expression, FunctionArgument};
|
||||
|
||||
pub(crate) fn name_df_operator(op: &Operator) -> &str {
|
||||
match op {
|
||||
Operator::Eq => "equal",
|
||||
Operator::NotEq => "not_equal",
|
||||
Operator::Lt => "lt",
|
||||
Operator::LtEq => "lte",
|
||||
Operator::Gt => "gt",
|
||||
Operator::GtEq => "gte",
|
||||
Operator::Plus => "plus",
|
||||
Operator::Minus => "minus",
|
||||
Operator::Multiply => "multiply",
|
||||
Operator::Divide => "divide",
|
||||
Operator::Modulo => "modulo",
|
||||
Operator::And => "and",
|
||||
Operator::Or => "or",
|
||||
Operator::Like => "like",
|
||||
Operator::NotLike => "not_like",
|
||||
Operator::IsDistinctFrom => "is_distinct_from",
|
||||
Operator::IsNotDistinctFrom => "is_not_distinct_from",
|
||||
Operator::RegexMatch => "regex_match",
|
||||
Operator::RegexIMatch => "regex_i_match",
|
||||
Operator::RegexNotMatch => "regex_not_match",
|
||||
Operator::RegexNotIMatch => "regex_not_i_match",
|
||||
Operator::BitwiseAnd => "bitwise_and",
|
||||
Operator::BitwiseOr => "bitwise_or",
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert list of [Expression] to [FunctionArgument] vector.
|
||||
pub(crate) fn expression_to_argument<I: IntoIterator<Item = Expression>>(
|
||||
expressions: I,
|
||||
) -> Vec<FunctionArgument> {
|
||||
expressions
|
||||
.into_iter()
|
||||
.map(|expr| FunctionArgument {
|
||||
arg_type: Some(ArgType::Value(expr)),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Convenient builder for [Expression]
|
||||
pub(crate) fn build_scalar_function_expression(
|
||||
function_reference: u32,
|
||||
arguments: Vec<FunctionArgument>,
|
||||
) -> Expression {
|
||||
Expression {
|
||||
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
|
||||
function_reference,
|
||||
arguments,
|
||||
output_type: None,
|
||||
..Default::default()
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn name_builtin_scalar_function(fun: &BuiltinScalarFunction) -> &str {
|
||||
match fun {
|
||||
BuiltinScalarFunction::Abs => "abs",
|
||||
BuiltinScalarFunction::Acos => "acos",
|
||||
BuiltinScalarFunction::Asin => "asin",
|
||||
BuiltinScalarFunction::Atan => "atan",
|
||||
BuiltinScalarFunction::Ceil => "ceil",
|
||||
BuiltinScalarFunction::Cos => "cos",
|
||||
BuiltinScalarFunction::Digest => "digest",
|
||||
BuiltinScalarFunction::Exp => "exp",
|
||||
BuiltinScalarFunction::Floor => "floor",
|
||||
BuiltinScalarFunction::Ln => "ln",
|
||||
BuiltinScalarFunction::Log => "log",
|
||||
BuiltinScalarFunction::Log10 => "log10",
|
||||
BuiltinScalarFunction::Log2 => "log2",
|
||||
BuiltinScalarFunction::Round => "round",
|
||||
BuiltinScalarFunction::Signum => "signum",
|
||||
BuiltinScalarFunction::Sin => "sin",
|
||||
BuiltinScalarFunction::Sqrt => "sqrt",
|
||||
BuiltinScalarFunction::Tan => "tan",
|
||||
BuiltinScalarFunction::Trunc => "trunc",
|
||||
BuiltinScalarFunction::Array => "make_array",
|
||||
BuiltinScalarFunction::Ascii => "ascii",
|
||||
BuiltinScalarFunction::BitLength => "bit_length",
|
||||
BuiltinScalarFunction::Btrim => "btrim",
|
||||
BuiltinScalarFunction::CharacterLength => "character_length",
|
||||
BuiltinScalarFunction::Chr => "chr",
|
||||
BuiltinScalarFunction::Concat => "concat",
|
||||
BuiltinScalarFunction::ConcatWithSeparator => "concat_ws",
|
||||
BuiltinScalarFunction::DatePart => "date_part",
|
||||
BuiltinScalarFunction::DateTrunc => "date_trunc",
|
||||
BuiltinScalarFunction::InitCap => "initcap",
|
||||
BuiltinScalarFunction::Left => "left",
|
||||
BuiltinScalarFunction::Lpad => "lpad",
|
||||
BuiltinScalarFunction::Lower => "lower",
|
||||
BuiltinScalarFunction::Ltrim => "ltrim",
|
||||
BuiltinScalarFunction::MD5 => "md5",
|
||||
BuiltinScalarFunction::NullIf => "nullif",
|
||||
BuiltinScalarFunction::OctetLength => "octet_length",
|
||||
BuiltinScalarFunction::Random => "random",
|
||||
BuiltinScalarFunction::RegexpReplace => "regexp_replace",
|
||||
BuiltinScalarFunction::Repeat => "repeat",
|
||||
BuiltinScalarFunction::Replace => "replace",
|
||||
BuiltinScalarFunction::Reverse => "reverse",
|
||||
BuiltinScalarFunction::Right => "right",
|
||||
BuiltinScalarFunction::Rpad => "rpad",
|
||||
BuiltinScalarFunction::Rtrim => "rtrim",
|
||||
BuiltinScalarFunction::SHA224 => "sha224",
|
||||
BuiltinScalarFunction::SHA256 => "sha256",
|
||||
BuiltinScalarFunction::SHA384 => "sha384",
|
||||
BuiltinScalarFunction::SHA512 => "sha512",
|
||||
BuiltinScalarFunction::SplitPart => "split_part",
|
||||
BuiltinScalarFunction::StartsWith => "starts_with",
|
||||
BuiltinScalarFunction::Strpos => "strpos",
|
||||
BuiltinScalarFunction::Substr => "substr",
|
||||
BuiltinScalarFunction::ToHex => "to_hex",
|
||||
BuiltinScalarFunction::ToTimestamp => "to_timestamp",
|
||||
BuiltinScalarFunction::ToTimestampMillis => "to_timestamp_millis",
|
||||
BuiltinScalarFunction::ToTimestampMicros => "to_timestamp_macros",
|
||||
BuiltinScalarFunction::ToTimestampSeconds => "to_timestamp_seconds",
|
||||
BuiltinScalarFunction::Now => "now",
|
||||
BuiltinScalarFunction::Translate => "translate",
|
||||
BuiltinScalarFunction::Trim => "trim",
|
||||
BuiltinScalarFunction::Upper => "upper",
|
||||
BuiltinScalarFunction::RegexpMatch => "regexp_match",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use datatypes::schema::ColumnSchema;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn expr_round_trip() {
|
||||
let expr = expr_fn::and(
|
||||
expr_fn::col("column_a").lt_eq(expr_fn::col("column_b")),
|
||||
expr_fn::col("column_a").gt(expr_fn::col("column_b")),
|
||||
);
|
||||
|
||||
let schema = Schema::new(vec![
|
||||
ColumnSchema::new(
|
||||
"column_a",
|
||||
datatypes::data_type::ConcreteDataType::int64_datatype(),
|
||||
true,
|
||||
),
|
||||
ColumnSchema::new(
|
||||
"column_b",
|
||||
datatypes::data_type::ConcreteDataType::float64_datatype(),
|
||||
true,
|
||||
),
|
||||
]);
|
||||
|
||||
let mut ctx = ConvertorContext::default();
|
||||
let substrait_expr = expression_from_df_expr(&mut ctx, &expr, &schema).unwrap();
|
||||
let converted_expr = to_df_expr(&ctx, substrait_expr, &schema).unwrap();
|
||||
|
||||
assert_eq!(expr, converted_expr);
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ use std::sync::Arc;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use catalog::CatalogManagerRef;
|
||||
use common_error::prelude::BoxedError;
|
||||
use common_telemetry::debug;
|
||||
use datafusion::datasource::TableProvider;
|
||||
use datafusion::logical_plan::{LogicalPlan, TableScan, ToDFSchema};
|
||||
use datafusion::physical_plan::project_schema;
|
||||
@@ -24,12 +25,15 @@ use prost::Message;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use substrait_proto::protobuf::expression::mask_expression::{StructItem, StructSelect};
|
||||
use substrait_proto::protobuf::expression::MaskExpression;
|
||||
use substrait_proto::protobuf::extensions::simple_extension_declaration::MappingType;
|
||||
use substrait_proto::protobuf::plan_rel::RelType as PlanRelType;
|
||||
use substrait_proto::protobuf::read_rel::{NamedTable, ReadType};
|
||||
use substrait_proto::protobuf::rel::RelType;
|
||||
use substrait_proto::protobuf::{PlanRel, ReadRel, Rel};
|
||||
use substrait_proto::protobuf::{Plan, PlanRel, ReadRel, Rel};
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
|
||||
use crate::context::ConvertorContext;
|
||||
use crate::df_expr::{expression_from_df_expr, to_df_expr};
|
||||
use crate::error::{
|
||||
DFInternalSnafu, DecodeRelSnafu, EmptyPlanSnafu, EncodeRelSnafu, Error, InternalSnafu,
|
||||
InvalidParametersSnafu, MissingFieldSnafu, SchemaNotMatchSnafu, TableNotFoundSnafu,
|
||||
@@ -48,25 +52,15 @@ impl SubstraitPlan for DFLogicalSubstraitConvertor {
|
||||
type Plan = LogicalPlan;
|
||||
|
||||
fn decode<B: Buf + Send>(&self, message: B) -> Result<Self::Plan, Self::Error> {
|
||||
let plan_rel = PlanRel::decode(message).context(DecodeRelSnafu)?;
|
||||
let rel = match plan_rel.rel_type.context(EmptyPlanSnafu)? {
|
||||
PlanRelType::Rel(rel) => rel,
|
||||
PlanRelType::Root(_) => UnsupportedPlanSnafu {
|
||||
name: "Root Relation",
|
||||
}
|
||||
.fail()?,
|
||||
};
|
||||
self.convert_rel(rel)
|
||||
let plan = Plan::decode(message).context(DecodeRelSnafu)?;
|
||||
self.convert_plan(plan)
|
||||
}
|
||||
|
||||
fn encode(&self, plan: Self::Plan) -> Result<Bytes, Self::Error> {
|
||||
let rel = self.convert_plan(plan)?;
|
||||
let plan_rel = PlanRel {
|
||||
rel_type: Some(PlanRelType::Rel(rel)),
|
||||
};
|
||||
let plan = self.convert_df_plan(plan)?;
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
plan_rel.encode(&mut buf).context(EncodeRelSnafu)?;
|
||||
plan.encode(&mut buf).context(EncodeRelSnafu)?;
|
||||
|
||||
Ok(buf.freeze())
|
||||
}
|
||||
@@ -79,10 +73,37 @@ impl DFLogicalSubstraitConvertor {
}

impl DFLogicalSubstraitConvertor {
pub fn convert_rel(&self, rel: Rel) -> Result<LogicalPlan, Error> {
pub fn convert_plan(&self, mut plan: Plan) -> Result<LogicalPlan, Error> {
// prepare convertor context
let mut ctx = ConvertorContext::default();
for simple_ext in plan.extensions {
if let Some(MappingType::ExtensionFunction(function_extension)) =
simple_ext.mapping_type
{
ctx.register_scalar_with_anchor(
function_extension.name,
function_extension.function_anchor,
);
} else {
debug!("Encounter unsupported substrait extension {:?}", simple_ext);
}
}

// extract rel
let rel = if let Some(PlanRel { rel_type }) = plan.relations.pop()
&& let Some(PlanRelType::Rel(rel)) = rel_type {
rel
} else {
UnsupportedPlanSnafu {
name: "Empty or non-Rel relation",
}
.fail()?
};
let rel_type = rel.rel_type.context(EmptyPlanSnafu)?;

// build logical plan
let logical_plan = match rel_type {
RelType::Read(read_rel) => self.convert_read_rel(read_rel),
RelType::Read(read_rel) => self.convert_read_rel(&mut ctx, read_rel),
RelType::Filter(_filter_rel) => UnsupportedPlanSnafu {
name: "Filter Relation",
}
@@ -132,9 +153,12 @@ impl DFLogicalSubstraitConvertor {
Ok(logical_plan)
}

fn convert_read_rel(&self, read_rel: Box<ReadRel>) -> Result<LogicalPlan, Error> {
fn convert_read_rel(
&self,
ctx: &mut ConvertorContext,
read_rel: Box<ReadRel>,
) -> Result<LogicalPlan, Error> {
// Extract the catalog, schema and table name from NamedTable. Assume the first three are those names.

let read_type = read_rel.read_type.context(MissingFieldSnafu {
field: "read_type",
plan: "Read",
@@ -190,6 +214,13 @@ impl DFLogicalSubstraitConvertor {
}
);

// Convert filter
let filters = if let Some(filter) = read_rel.filter {
vec![to_df_expr(ctx, *filter, &retrieved_schema)?]
} else {
vec![]
};

// Calculate the projected schema
let projected_schema = project_schema(&stored_schema, projection.as_ref())
.context(DFInternalSnafu)?
@@ -202,7 +233,7 @@ impl DFLogicalSubstraitConvertor {
source: adapter,
projection,
projected_schema,
filters: vec![],
filters,
limit: None,
}))
}
@@ -219,8 +250,12 @@ impl DFLogicalSubstraitConvertor {
}

impl DFLogicalSubstraitConvertor {
pub fn convert_plan(&self, plan: LogicalPlan) -> Result<Rel, Error> {
match plan {
pub fn convert_df_plan(&self, plan: LogicalPlan) -> Result<Plan, Error> {
let mut ctx = ConvertorContext::default();

// TODO(ruihang): extract this translation logic into a separated function
// convert PlanRel
let rel = match plan {
LogicalPlan::Projection(_) => UnsupportedPlanSnafu {
name: "DataFusion Logical Projection",
}
@@ -258,10 +293,10 @@ impl DFLogicalSubstraitConvertor {
}
.fail()?,
LogicalPlan::TableScan(table_scan) => {
let read_rel = self.convert_table_scan_plan(table_scan)?;
Ok(Rel {
let read_rel = self.convert_table_scan_plan(&mut ctx, table_scan)?;
Rel {
rel_type: Some(RelType::Read(Box::new(read_rel))),
})
}
}
LogicalPlan::EmptyRelation(_) => UnsupportedPlanSnafu {
name: "DataFusion Logical EmptyRelation",
@@ -284,10 +319,30 @@ impl DFLogicalSubstraitConvertor {
),
}
.fail()?,
}
};

// convert extension
let extensions = ctx.generate_function_extension();

// assemble PlanRel
let plan_rel = PlanRel {
rel_type: Some(PlanRelType::Rel(rel)),
};

Ok(Plan {
extension_uris: vec![],
extensions,
relations: vec![plan_rel],
advanced_extensions: None,
expected_type_urls: vec![],
})
}

pub fn convert_table_scan_plan(&self, table_scan: TableScan) -> Result<ReadRel, Error> {
pub fn convert_table_scan_plan(
&self,
ctx: &mut ConvertorContext,
table_scan: TableScan,
) -> Result<ReadRel, Error> {
let provider = table_scan
.source
.as_any()
@@ -313,10 +368,25 @@ impl DFLogicalSubstraitConvertor {
// assemble base (unprojected) schema using Table's schema.
let base_schema = from_schema(&provider.table().schema())?;

// make conjunction over a list of filters and convert the result to substrait
let filter = if let Some(conjunction) = table_scan
.filters
.into_iter()
.reduce(|accum, expr| accum.and(expr))
{
Some(Box::new(expression_from_df_expr(
ctx,
&conjunction,
&provider.table().schema(),
)?))
} else {
None
};

let read_rel = ReadRel {
common: None,
base_schema: Some(base_schema),
filter: None,
filter,
projection,
advanced_extension: None,
read_type: Some(read_type),

@@ -23,10 +23,10 @@ use snafu::{Backtrace, ErrorCompat, Snafu};
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Unsupported physical expr: {}", name))]
#[snafu(display("Unsupported physical plan: {}", name))]
UnsupportedPlan { name: String, backtrace: Backtrace },

#[snafu(display("Unsupported physical plan: {}", name))]
#[snafu(display("Unsupported expr: {}", name))]
UnsupportedExpr { name: String, backtrace: Backtrace },

#[snafu(display("Unsupported concrete type: {:?}", ty))]

@@ -12,6 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![feature(let_chains)]

mod context;
mod df_expr;
mod df_logical;
pub mod error;
mod schema;

@@ -11,49 +11,50 @@ python = ["dep:script"]
[dependencies]
api = { path = "../api" }
async-trait = "0.1"
axum = "0.6.0-rc.2"
axum-macros = "0.3.0-rc.1"
axum = "0.6"
axum-macros = "0.3"
backon = "0.2"
catalog = { path = "../catalog" }
common-base = { path = "../common/base" }
common-catalog = { path = "../common/catalog" }
common-error = { path = "../common/error" }
common-grpc = { path = "../common/grpc" }
common-grpc-expr = { path = "../common/grpc-expr" }
common-query = { path = "../common/query" }
common-recordbatch = { path = "../common/recordbatch" }
common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
common-grpc-expr = { path = "../common/grpc-expr" }
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
"simd",
] }
datatypes = { path = "../datatypes" }
frontend = { path = "../frontend" }
futures = "0.3"
hyper = { version = "0.14", features = ["full"] }
log-store = { path = "../log-store" }
meta-client = { path = "../meta-client" }
meta-srv = { path = "../meta-srv", features = ["mock"] }
metrics = "0.20"
mito = { path = "../mito", features = ["test"] }
object-store = { path = "../object-store" }
query = { path = "../query" }
script = { path = "../script", features = ["python"], optional = true }
serde = "1.0"
serde_json = "1.0"
servers = { path = "../servers" }
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { path = "../sql" }
storage = { path = "../storage" }
store-api = { path = "../store-api" }
substrait = { path = "../common/substrait" }
table = { path = "../table" }
mito = { path = "../mito", features = ["test"] }
tokio = { version = "1.18", features = ["full"] }
tokio-stream = { version = "0.1", features = ["net"] }
tonic = "0.8"
tower = { version = "0.4", features = ["full"] }
tower-http = { version = "0.3", features = ["full"] }
frontend = { path = "../frontend" }

[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }

@@ -26,7 +26,15 @@ use crate::server::Services;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ObjectStoreConfig {
File { data_dir: String },
File {
data_dir: String,
},
S3 {
bucket: String,
root: String,
access_key_id: String,
secret_access_key: String,
},
}

impl Default for ObjectStoreConfig {

@@ -18,6 +18,8 @@ use common_error::prelude::*;
use storage::error::Error as StorageError;
use table::error::Error as TableError;

use crate::datanode::ObjectStoreConfig;

/// Business error of datanode.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
@@ -142,9 +144,9 @@ pub enum Error {
#[snafu(display("Failed to storage engine, source: {}", source))]
OpenStorageEngine { source: StorageError },

#[snafu(display("Failed to init backend, dir: {}, source: {}", dir, source))]
#[snafu(display("Failed to init backend, config: {:#?}, source: {}", config, source))]
InitBackend {
dir: String,
config: ObjectStoreConfig,
source: object_store::Error,
backtrace: Backtrace,
},

@@ -28,7 +28,8 @@ use meta_client::MetaClientOpts;
use mito::config::EngineConfig as TableEngineConfig;
use mito::engine::MitoEngine;
use object_store::layers::{LoggingLayer, MetricsLayer, RetryLayer, TracingLayer};
use object_store::services::fs::Builder;
use object_store::services::fs::Builder as FsBuilder;
use object_store::services::s3::Builder as S3Builder;
use object_store::{util, ObjectStore};
use query::query_engine::{QueryEngineFactory, QueryEngineRef};
use servers::Mode;
@@ -187,32 +188,64 @@ impl Instance {
}

pub(crate) async fn new_object_store(store_config: &ObjectStoreConfig) -> Result<ObjectStore> {
// TODO(dennis): supports other backend
let data_dir = util::normalize_dir(match store_config {
ObjectStoreConfig::File { data_dir } => data_dir,
});
let object_store = match store_config {
ObjectStoreConfig::File { data_dir } => new_fs_object_store(data_dir).await,
ObjectStoreConfig::S3 { .. } => new_s3_object_store(store_config).await,
};

object_store.map(|object_store| {
object_store
.layer(RetryLayer::new(ExponentialBackoff::default().with_jitter()))
.layer(MetricsLayer)
.layer(LoggingLayer)
.layer(TracingLayer)
})
}

pub(crate) async fn new_s3_object_store(store_config: &ObjectStoreConfig) -> Result<ObjectStore> {
let (root, secret_key, key_id, bucket) = match store_config {
ObjectStoreConfig::S3 {
bucket,
root,
access_key_id,
secret_access_key,
} => (root, secret_access_key, access_key_id, bucket),
_ => unreachable!(),
};

let root = util::normalize_dir(root);
info!("The s3 storage bucket is: {}, root is: {}", bucket, &root);

let accessor = S3Builder::default()
.root(&root)
.bucket(bucket)
.access_key_id(key_id)
.secret_access_key(secret_key)
.build()
.with_context(|_| error::InitBackendSnafu {
config: store_config.clone(),
})?;

Ok(ObjectStore::new(accessor))
}

pub(crate) async fn new_fs_object_store(data_dir: &str) -> Result<ObjectStore> {
let data_dir = util::normalize_dir(data_dir);
fs::create_dir_all(path::Path::new(&data_dir))
.context(error::CreateDirSnafu { dir: &data_dir })?;
info!("The file storage directory is: {}", &data_dir);

info!("The storage directory is: {}", &data_dir);
let atomic_write_dir = format!("{}/.tmp/", data_dir);

let accessor = Builder::default()
let accessor = FsBuilder::default()
.root(&data_dir)
.atomic_write_dir(&atomic_write_dir)
.build()
.context(error::InitBackendSnafu { dir: &data_dir })?;
.context(error::InitBackendSnafu {
config: ObjectStoreConfig::File { data_dir },
})?;

let object_store = ObjectStore::new(accessor)
// Add retry
.layer(RetryLayer::new(ExponentialBackoff::default().with_jitter()))
// Add metrics
.layer(MetricsLayer)
// Add logging
.layer(LoggingLayer)
// Add tracing
.layer(TracingLayer);

Ok(object_store)
Ok(ObjectStore::new(accessor))
}

/// Create metasrv client instance and spawn heartbeat loop.

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use api::result::{build_err_result, AdminResultBuilder, ObjectResultBuilder};
use api::v1::{
admin_expr, object_expr, select_expr, AdminExpr, AdminResult, Column, CreateDatabaseExpr,
@@ -26,6 +28,7 @@ use common_grpc_expr::insertion_expr_to_request;
use common_query::Output;
use query::plan::LogicalPlan;
use servers::query_handler::{GrpcAdminHandler, GrpcQueryHandler};
use session::context::QueryContext;
use snafu::prelude::*;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::requests::CreateDatabaseRequest;
@@ -110,7 +113,9 @@ impl Instance {
async fn do_handle_select(&self, select_expr: SelectExpr) -> Result<Output> {
let expr = select_expr.expr;
match expr {
Some(select_expr::Expr::Sql(sql)) => self.execute_sql(&sql).await,
Some(select_expr::Expr::Sql(sql)) => {
self.execute_sql(&sql, Arc::new(QueryContext::new())).await
}
Some(select_expr::Expr::LogicalPlan(plan)) => self.execute_logical(plan).await,
Some(select_expr::Expr::PhysicalPlan(api::v1::PhysicalPlan { original_ql, plan })) => {
self.physical_planner

@@ -13,25 +13,27 @@
// limitations under the License.

use async_trait::async_trait;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::BoxedError;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_telemetry::logging::{error, info};
use common_telemetry::timer;
use servers::query_handler::SqlQueryHandler;
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::ast::ObjectName;
use sql::statements::statement::Statement;
use table::engine::TableReference;
use table::requests::CreateDatabaseRequest;

use crate::error::{
BumpTableIdSnafu, CatalogNotFoundSnafu, CatalogSnafu, ExecuteSqlSnafu, ParseSqlSnafu, Result,
SchemaNotFoundSnafu, TableIdProviderNotFoundSnafu,
};
use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu};
use crate::instance::Instance;
use crate::metric;
use crate::sql::SqlRequest;

impl Instance {
pub async fn execute_sql(&self, sql: &str) -> Result<Output> {
pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result<Output> {
let stmt = self
.query_engine
.sql_to_statement(sql)
@@ -41,7 +43,7 @@ impl Instance {
Statement::Query(_) => {
let logical_plan = self
.query_engine
.statement_to_plan(stmt)
.statement_to_plan(stmt, query_ctx)
.context(ExecuteSqlSnafu)?;

self.query_engine
@@ -50,20 +52,15 @@ impl Instance {
.context(ExecuteSqlSnafu)
}
Statement::Insert(i) => {
let (catalog_name, schema_name, _table_name) =
i.full_table_name().context(ParseSqlSnafu)?;

let schema_provider = self
.catalog_manager
.catalog(&catalog_name)
.context(CatalogSnafu)?
.context(CatalogNotFoundSnafu { name: catalog_name })?
.schema(&schema_name)
.context(CatalogSnafu)?
.context(SchemaNotFoundSnafu { name: schema_name })?;

let request = self.sql_handler.insert_to_request(schema_provider, *i)?;
self.sql_handler.execute(request).await
let (catalog, schema, table) =
table_idents_to_full_name(i.table_name(), query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let request = self.sql_handler.insert_to_request(
self.catalog_manager.clone(),
*i,
table_ref,
)?;
self.sql_handler.execute(request, query_ctx).await
}

Statement::CreateDatabase(c) => {
@@ -74,7 +71,7 @@ impl Instance {
info!("Creating a new database: {}", request.db_name);

self.sql_handler
.execute(SqlRequest::CreateDatabase(request))
.execute(SqlRequest::CreateDatabase(request), query_ctx)
.await
}

@@ -89,58 +86,116 @@ impl Instance {
let _engine_name = c.engine.clone();
// TODO(hl): Select table engine by engine_name

let request = self.sql_handler.create_to_request(table_id, c)?;
let catalog_name = &request.catalog_name;
let schema_name = &request.schema_name;
let table_name = &request.table_name;
let name = c.name.clone();
let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let request = self.sql_handler.create_to_request(table_id, c, table_ref)?;
let table_id = request.id;
info!(
"Creating table, catalog: {:?}, schema: {:?}, table name: {:?}, table id: {}",
catalog_name, schema_name, table_name, table_id
catalog, schema, table, table_id
);

self.sql_handler
.execute(SqlRequest::CreateTable(request))
.execute(SqlRequest::CreateTable(request), query_ctx)
.await
}
Statement::Alter(alter_table) => {
let req = self.sql_handler.alter_to_request(alter_table)?;
self.sql_handler.execute(SqlRequest::Alter(req)).await
let name = alter_table.table_name().clone();
let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let req = self.sql_handler.alter_to_request(alter_table, table_ref)?;
self.sql_handler
.execute(SqlRequest::Alter(req), query_ctx)
.await
}
Statement::DropTable(drop_table) => {
let req = self.sql_handler.drop_table_to_request(drop_table);
self.sql_handler.execute(SqlRequest::DropTable(req)).await
self.sql_handler
.execute(SqlRequest::DropTable(req), query_ctx)
.await
}
Statement::ShowDatabases(stmt) => {
self.sql_handler
.execute(SqlRequest::ShowDatabases(stmt))
.execute(SqlRequest::ShowDatabases(stmt), query_ctx)
.await
}
Statement::ShowTables(stmt) => {
self.sql_handler.execute(SqlRequest::ShowTables(stmt)).await
self.sql_handler
.execute(SqlRequest::ShowTables(stmt), query_ctx)
.await
}
Statement::Explain(stmt) => {
self.sql_handler
.execute(SqlRequest::Explain(Box::new(stmt)))
.execute(SqlRequest::Explain(Box::new(stmt)), query_ctx)
.await
}
Statement::DescribeTable(stmt) => {
self.sql_handler
.execute(SqlRequest::DescribeTable(stmt))
.execute(SqlRequest::DescribeTable(stmt), query_ctx)
.await
}
Statement::ShowCreateTable(_stmt) => {
unimplemented!("SHOW CREATE TABLE is unimplemented yet");
}
Statement::Use(db) => {
ensure!(
self.catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { name: &db }
);

query_ctx.set_current_schema(&db);

Ok(Output::RecordBatches(RecordBatches::empty()))
}
}
}
}

// TODO(LFC): Refactor consideration: move this function to some helper mod,
// could be done together or after `TableReference`'s refactoring, when issue #559 is resolved.
/// Converts maybe fully-qualified table name (`<catalog>.<schema>.<table>`) to tuple.
fn table_idents_to_full_name(
obj_name: &ObjectName,
query_ctx: QueryContextRef,
) -> Result<(String, String, String)> {
match &obj_name.0[..] {
[table] => Ok((
DEFAULT_CATALOG_NAME.to_string(),
query_ctx.current_schema().unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
table.value.clone(),
)),
[schema, table] => Ok((
DEFAULT_CATALOG_NAME.to_string(),
schema.value.clone(),
table.value.clone(),
)),
[catalog, schema, table] => Ok((
catalog.value.clone(),
schema.value.clone(),
table.value.clone(),
)),
_ => error::InvalidSqlSnafu {
msg: format!(
"expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {}",
obj_name
),
}.fail(),
}
}

#[async_trait]
impl SqlQueryHandler for Instance {
async fn do_query(&self, query: &str) -> servers::error::Result<Output> {
async fn do_query(
&self,
query: &str,
query_ctx: QueryContextRef,
) -> servers::error::Result<Output> {
let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED);
self.execute_sql(query)
self.execute_sql(query, query_ctx)
.await
.map_err(|e| {
error!(e; "Instance failed to execute sql");
@@ -149,3 +204,78 @@ impl SqlQueryHandler for Instance {
.context(servers::error::ExecuteQuerySnafu { query })
}
}

#[cfg(test)]
mod test {
use std::sync::Arc;

use session::context::QueryContext;

use super::*;

#[test]
fn test_table_idents_to_full_name() {
let my_catalog = "my_catalog";
let my_schema = "my_schema";
let my_table = "my_table";

let full = ObjectName(vec![my_catalog.into(), my_schema.into(), my_table.into()]);
let partial = ObjectName(vec![my_schema.into(), my_table.into()]);
let bare = ObjectName(vec![my_table.into()]);

let using_schema = "foo";
let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string()));
let empty_ctx = Arc::new(QueryContext::new());

assert_eq!(
table_idents_to_full_name(&full, query_ctx.clone()).unwrap(),
(
my_catalog.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&full, empty_ctx.clone()).unwrap(),
(
my_catalog.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);

assert_eq!(
table_idents_to_full_name(&partial, query_ctx.clone()).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&partial, empty_ctx.clone()).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);

assert_eq!(
table_idents_to_full_name(&bare, query_ctx).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
using_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&bare, empty_ctx).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
my_table.to_string()
)
);
}
}

@@ -22,6 +22,6 @@ mod metric;
mod mock;
mod script;
pub mod server;
mod sql;
pub mod sql;
#[cfg(test)]
mod tests;

@@ -62,6 +62,7 @@ impl Services {
Some(MysqlServer::create_server(
instance.clone(),
mysql_io_runtime,
Default::default(),
))
}
};

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use api::result::AdminResultBuilder;
use api::v1::{AdminResult, AlterExpr, CreateExpr, DropTableExpr};
use common_error::prelude::{ErrorExt, StatusCode};
@@ -19,6 +21,7 @@ use common_grpc_expr::{alter_expr_to_request, create_expr_to_request};
use common_query::Output;
use common_telemetry::{error, info};
use futures::TryFutureExt;
use session::context::QueryContext;
use snafu::prelude::*;
use table::requests::DropTableRequest;

@@ -72,7 +75,12 @@ impl Instance {

let request = create_expr_to_request(table_id, expr).context(CreateExprToRequestSnafu);
let result = futures::future::ready(request)
.and_then(|request| self.sql_handler().execute(SqlRequest::CreateTable(request)))
.and_then(|request| {
self.sql_handler().execute(
SqlRequest::CreateTable(request),
Arc::new(QueryContext::new()),
)
})
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
@@ -103,7 +111,10 @@ impl Instance {
};

let result = futures::future::ready(request)
.and_then(|request| self.sql_handler().execute(SqlRequest::Alter(request)))
.and_then(|request| {
self.sql_handler()
.execute(SqlRequest::Alter(request), Arc::new(QueryContext::new()))
})
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
@@ -124,7 +135,10 @@ impl Instance {
schema_name: expr.schema_name,
table_name: expr.table_name,
};
let result = self.sql_handler().execute(SqlRequest::DropTable(req)).await;
let result = self
.sql_handler()
.execute(SqlRequest::DropTable(req), Arc::new(QueryContext::new()))
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
.status_code(StatusCode::Success as u32)

@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! sql handler

use catalog::CatalogManagerRef;
use common_query::Output;
use common_telemetry::error;
use query::query_engine::QueryEngineRef;
use query::sql::{describe_table, explain, show_databases, show_tables};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt};
use sql::statements::describe::DescribeTable;
use sql::statements::explain::Explain;
@@ -67,7 +66,11 @@ impl SqlHandler {
}
}

pub async fn execute(&self, request: SqlRequest) -> Result<Output> {
// TODO(LFC): Refactor consideration: a context awareness "Planner".
// Now we have some query related state (like current using database in session context), maybe
// we could create a new struct called `Planner` that stores context and handle these queries
// there, instead of executing here in a "static" fashion.
pub async fn execute(&self, request: SqlRequest, query_ctx: QueryContextRef) -> Result<Output> {
let result = match request {
SqlRequest::Insert(req) => self.insert(req).await,
SqlRequest::CreateTable(req) => self.create_table(req).await,
@@ -78,12 +81,12 @@ impl SqlHandler {
show_databases(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
}
SqlRequest::ShowTables(stmt) => {
show_tables(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
show_tables(stmt, self.catalog_manager.clone(), query_ctx).context(ExecuteSqlSnafu)
}
SqlRequest::DescribeTable(stmt) => {
describe_table(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
}
SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone())
SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone(), query_ctx)
.await
.context(ExecuteSqlSnafu),
};
@@ -114,7 +117,8 @@ mod tests {
use std::any::Any;
use std::sync::Arc;

use catalog::SchemaProvider;
use catalog::{CatalogList, SchemaProvider};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::logical_plan::Expr;
use common_query::physical_plan::PhysicalPlanRef;
use common_time::timestamp::Timestamp;
@@ -234,9 +238,17 @@ mod tests {
.await
.unwrap(),
);
let catalog_provider = catalog_list.catalog(DEFAULT_CATALOG_NAME).unwrap().unwrap();
catalog_provider
.register_schema(
DEFAULT_SCHEMA_NAME.to_string(),
Arc::new(MockSchemaProvider {}),
)
.unwrap();

let factory = QueryEngineFactory::new(catalog_list.clone());
let query_engine = factory.query_engine();
let sql_handler = SqlHandler::new(table_engine, catalog_list, query_engine.clone());
let sql_handler = SqlHandler::new(table_engine, catalog_list.clone(), query_engine.clone());

let stmt = match query_engine.sql_to_statement(sql).unwrap() {
Statement::Insert(i) => i,
@@ -244,9 +256,8 @@ mod tests {
unreachable!()
}
};
let schema_provider = Arc::new(MockSchemaProvider {});
let request = sql_handler
.insert_to_request(schema_provider, *stmt)
.insert_to_request(catalog_list.clone(), *stmt, TableReference::bare("demo"))
.unwrap();

match request {

@@ -16,7 +16,7 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use snafu::prelude::*;
use sql::statements::alter::{AlterTable, AlterTableOperation};
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
use sql::statements::column_def_to_schema;
use table::engine::{EngineContext, TableReference};
use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest};

@@ -53,10 +53,11 @@ impl SqlHandler {
Ok(Output::AffectedRows(0))
}

pub(crate) fn alter_to_request(&self, alter_table: AlterTable) -> Result<AlterTableRequest> {
let (catalog_name, schema_name, table_name) =
table_idents_to_full_name(alter_table.table_name()).context(error::ParseSqlSnafu)?;

pub(crate) fn alter_to_request(
&self,
alter_table: AlterTable,
table_ref: TableReference,
) -> Result<AlterTableRequest> {
let alter_kind = match alter_table.alter_operation() {
AlterTableOperation::AddConstraint(table_constraint) => {
return error::InvalidSqlSnafu {
@@ -77,9 +78,9 @@ impl SqlHandler {
},
};
Ok(AlterTableRequest {
catalog_name: Some(catalog_name),
schema_name: Some(schema_name),
table_name,
catalog_name: Some(table_ref.catalog.to_string()),
schema_name: Some(table_ref.schema.to_string()),
table_name: table_ref.table.to_string(),
alter_kind,
})
}
@@ -112,7 +113,9 @@ mod tests {
async fn test_alter_to_request_with_adding_column() {
let handler = create_mock_sql_handler().await;
let alter_table = parse_sql("ALTER TABLE my_metric_1 ADD tagk_i STRING Null;");
let req = handler.alter_to_request(alter_table).unwrap();
let req = handler
.alter_to_request(alter_table, TableReference::bare("my_metric_1"))
.unwrap();
assert_eq!(req.catalog_name, Some("greptime".to_string()));
assert_eq!(req.schema_name, Some("public".to_string()));
assert_eq!(req.table_name, "my_metric_1");

@@ -23,10 +23,10 @@ use common_telemetry::tracing::log::error;
use datatypes::schema::SchemaBuilder;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::TableConstraint;
use sql::statements::column_def_to_schema;
use sql::statements::create::CreateTable;
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
use store_api::storage::consts::TIME_INDEX_NAME;
use table::engine::EngineContext;
use table::engine::{EngineContext, TableReference};
use table::metadata::TableId;
use table::requests::*;

@@ -114,13 +114,11 @@ impl SqlHandler {
&self,
table_id: TableId,
stmt: CreateTable,
table_ref: TableReference,
) -> Result<CreateTableRequest> {
let mut ts_index = usize::MAX;
let mut primary_keys = vec![];

let (catalog_name, schema_name, table_name) =
table_idents_to_full_name(&stmt.name).context(error::ParseSqlSnafu)?;

let col_map = stmt
.columns
.iter()
@@ -187,8 +185,8 @@ impl SqlHandler {

if primary_keys.is_empty() {
info!(
"Creating table: {:?}.{:?}.{} but primary key not set, use time index column: {}",
catalog_name, schema_name, table_name, ts_index
"Creating table: {} with time index column: {} upon primary keys absent",
table_ref, ts_index
);
primary_keys.push(ts_index);
}
@@ -211,9 +209,9 @@ impl SqlHandler {

let request = CreateTableRequest {
id: table_id,
catalog_name,
schema_name,
table_name,
catalog_name: table_ref.catalog.to_string(),
schema_name: table_ref.schema.to_string(),
table_name: table_ref.table.to_string(),
desc: None,
schema,
region_numbers: vec![0],
@@ -261,7 +259,9 @@ mod tests {
TIME INDEX (ts),
PRIMARY KEY(host)) engine=mito with(regions=1);"#,
);
let c = handler.create_to_request(42, parsed_stmt).unwrap();
let c = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap();
assert_eq!("demo_table", c.table_name);
assert_eq!(42, c.id);
assert!(!c.create_if_not_exists);
@@ -282,7 +282,9 @@ mod tests {
memory double,
PRIMARY KEY(host)) engine=mito with(regions=1);"#,
);
let error = handler.create_to_request(42, parsed_stmt).unwrap_err();
let error = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap_err();
assert_matches!(error, Error::MissingTimestampColumn { .. });
}

@@ -299,7 +301,9 @@ mod tests {
memory double,
TIME INDEX (ts)) engine=mito with(regions=1);"#,
);
let c = handler.create_to_request(42, parsed_stmt).unwrap();
let c = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap();
assert_eq!(1, c.primary_key_indices.len());
assert_eq!(
c.schema.timestamp_index().unwrap(),
@@ -318,7 +322,9 @@ mod tests {
TIME INDEX (ts)) engine=mito with(regions=1);"#,
);

let error = handler.create_to_request(42, parsed_stmt).unwrap_err();
let error = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap_err();
assert_matches!(error, Error::KeyColumnNotFound { .. });
}

@@ -338,7 +344,9 @@ mod tests {

let handler = create_mock_sql_handler().await;

let error = handler.create_to_request(42, create_table).unwrap_err();
let error = handler
.create_to_request(42, create_table, TableReference::full("c", "s", "demo"))
.unwrap_err();
assert_matches!(error, Error::InvalidPrimaryKey { .. });
}

@@ -358,7 +366,9 @@ mod tests {

let handler = create_mock_sql_handler().await;

let request = handler.create_to_request(42, create_table).unwrap();
let request = handler
.create_to_request(42, create_table, TableReference::full("c", "s", "demo"))
.unwrap();

assert_eq!(42, request.id);
assert_eq!("c".to_string(), request.catalog_name);

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use catalog::SchemaProviderRef;
use catalog::CatalogManagerRef;
use common_query::Output;
use datatypes::prelude::{ConcreteDataType, VectorBuilder};
use snafu::{ensure, OptionExt, ResultExt};
@@ -23,7 +23,7 @@ use table::engine::TableReference;
use table::requests::*;

use crate::error::{
CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlSnafu,
CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu,
ParseSqlValueSnafu, Result, TableNotFoundSnafu,
};
use crate::sql::{SqlHandler, SqlRequest};
@@ -49,19 +49,18 @@ impl SqlHandler {

pub(crate) fn insert_to_request(
&self,
schema_provider: SchemaProviderRef,
catalog_manager: CatalogManagerRef,
stmt: Insert,
table_ref: TableReference,
) -> Result<SqlRequest> {
let columns = stmt.columns();
let values = stmt.values().context(ParseSqlValueSnafu)?;
let (catalog_name, schema_name, table_name) =
stmt.full_table_name().context(ParseSqlSnafu)?;

let table = schema_provider
.table(&table_name)
let table = catalog_manager
.table(table_ref.catalog, table_ref.schema, table_ref.table)
.context(CatalogSnafu)?
.context(TableNotFoundSnafu {
table_name: &table_name,
table_name: table_ref.table,
})?;
let schema = table.schema();
let columns_num = if columns.is_empty() {
@@ -88,7 +87,7 @@ impl SqlHandler {
let column_schema =
schema.column_schema_by_name(column_name).with_context(|| {
ColumnNotFoundSnafu {
table_name: &table_name,
table_name: table_ref.table,
column_name: column_name.to_string(),
}
})?;
@@ -119,9 +118,9 @@ impl SqlHandler {
}

Ok(SqlRequest::Insert(InsertRequest {
catalog_name,
schema_name,
table_name,
catalog_name: table_ref.catalog.to_string(),
schema_name: table_ref.schema.to_string(),
table_name: table_ref.table.to_string(),
columns_values: columns_builders
.into_iter()
.map(|(c, _, mut b)| (c.to_owned(), b.finish()))

@@ -12,7 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.

mod grpc_test;
mod http_test;
mod instance_test;
pub(crate) mod test_util;

@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_query::Output;
use common_recordbatch::util;
use datafusion::arrow_print;
@@ -19,6 +22,7 @@ use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::array::{Int64Array, UInt64Array, Utf8Array};
use datatypes::arrow_array::StringArray;
use datatypes::prelude::ConcreteDataType;
use session::context::QueryContext;

use crate::instance::Instance;
use crate::tests::test_util;
@@ -32,39 +36,33 @@ async fn test_create_database_and_insert_query() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();

let output = instance.execute_sql("create database test").await.unwrap();
let output = execute_sql(&instance, "create database test").await;
assert!(matches!(output, Output::AffectedRows(1)));

let output = instance
.execute_sql(
r#"create table greptime.test.demo(
let output = execute_sql(
&instance,
r#"create table greptime.test.demo(
host STRING,
cpu DOUBLE,
memory DOUBLE,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

let output = instance
.execute_sql(
r#"insert into test.demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into test.demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));

let query_output = instance
.execute_sql("select ts from test.demo order by ts")
.await
.unwrap();

let query_output = execute_sql(&instance, "select ts from test.demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -88,54 +86,50 @@ async fn test_issue477_same_table_name_in_different_databases() {
instance.start().await.unwrap();

// Create database a and b
let output = instance.execute_sql("create database a").await.unwrap();
let output = execute_sql(&instance, "create database a").await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance.execute_sql("create database b").await.unwrap();
let output = execute_sql(&instance, "create database b").await;
assert!(matches!(output, Output::AffectedRows(1)));

// Create table a.demo and b.demo
let output = instance
.execute_sql(
r#"create table a.demo(
let output = execute_sql(
&instance,
r#"create table a.demo(
host STRING,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

let output = instance
.execute_sql(
r#"create table b.demo(
let output = execute_sql(
&instance,
r#"create table b.demo(
host STRING,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

// Insert different data into a.demo and b.demo
let output = instance
.execute_sql(
r#"insert into a.demo(host, ts) values
let output = execute_sql(
&instance,
r#"insert into a.demo(host, ts) values
('host1', 1655276557000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
r#"insert into b.demo(host, ts) values
let output = execute_sql(
&instance,
r#"insert into b.demo(host, ts) values
('host2',1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

// Query data and assert
@@ -157,7 +151,7 @@ async fn test_issue477_same_table_name_in_different_databases() {
}

async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) {
let query_output = instance.execute_sql(sql).await.unwrap();
let query_output = execute_sql(instance, sql).await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -200,15 +194,14 @@ async fn setup_test_instance() -> Instance {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_insert() {
let instance = setup_test_instance().await;
let output = instance
.execute_sql(
r#"insert into demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
}

@@ -228,22 +221,17 @@ async fn test_execute_insert_query_with_i64_timestamp() {
.await
.unwrap();

let output = instance
.execute_sql(
r#"insert into demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));

let query_output = instance
.execute_sql("select ts from demo order by ts")
.await
.unwrap();

let query_output = execute_sql(&instance, "select ts from demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -257,11 +245,7 @@ async fn test_execute_insert_query_with_i64_timestamp() {
_ => unreachable!(),
}

let query_output = instance
.execute_sql("select ts as time from demo order by ts")
.await
.unwrap();

let query_output = execute_sql(&instance, "select ts as time from demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -282,10 +266,7 @@ async fn test_execute_query() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();

let output = instance
.execute_sql("select sum(number) from numbers limit 20")
.await
.unwrap();
let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await;
match output {
Output::Stream(recordbatch) => {
let numbers = util::collect(recordbatch).await.unwrap();
@@ -309,7 +290,7 @@ async fn test_execute_show_databases_tables() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();

let output = instance.execute_sql("show databases").await.unwrap();
let output = execute_sql(&instance, "show databases").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -325,10 +306,7 @@ async fn test_execute_show_databases_tables() {
_ => unreachable!(),
}

let output = instance
.execute_sql("show databases like '%bl%'")
.await
.unwrap();
let output = execute_sql(&instance, "show databases like '%bl%'").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -344,7 +322,7 @@ async fn test_execute_show_databases_tables() {
_ => unreachable!(),
}

let output = instance.execute_sql("show tables").await.unwrap();
let output = execute_sql(&instance, "show tables").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -364,7 +342,7 @@ async fn test_execute_show_databases_tables() {
.await
.unwrap();

let output = instance.execute_sql("show tables").await.unwrap();
let output = execute_sql(&instance, "show tables").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -376,10 +354,7 @@ async fn test_execute_show_databases_tables() {
}

// show tables like [string]
let output = instance
.execute_sql("show tables like 'de%'")
.await
.unwrap();
let output = execute_sql(&instance, "show tables like 'de%'").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -404,9 +379,9 @@ pub async fn test_execute_create() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();

let output = instance
.execute_sql(
r#"create table test_table(
let output = execute_sql(
&instance,
r#"create table test_table(
host string,
ts timestamp,
cpu double default 0,
@@ -414,56 +389,24 @@ pub async fn test_execute_create() {
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
}

#[tokio::test(flavor = "multi_thread")]
pub async fn test_create_table_illegal_timestamp_type() {
common_telemetry::init_default_ut_logging();

let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();

let output = instance
.execute_sql(
r#"create table test_table(
host string,
ts bigint,
cpu double default 0,
memory double,
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#,
)
.await
.unwrap();
match output {
Output::AffectedRows(rows) => {
assert_eq!(1, rows);
}
_ => unreachable!(),
}
}

async fn check_output_stream(output: Output, expected: Vec<&str>) {
match output {
Output::Stream(stream) => {
let recordbatches = util::collect(stream).await.unwrap();
let recordbatch = recordbatches
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatch);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!(pretty_print, expected);
}
let recordbatches = match output {
Output::Stream(stream) => util::collect(stream).await.unwrap(),
Output::RecordBatches(recordbatches) => recordbatches.take(),
_ => unreachable!(),
}
};
let recordbatches = recordbatches
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatches);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!(pretty_print, expected);
}

#[tokio::test]
@@ -479,35 +422,30 @@ async fn test_alter_table() {
.await
.unwrap();
// make sure table insertion is ok before altering table
instance
.execute_sql("insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)")
.await
.unwrap();
execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)",
)
.await;

// Add column
let output = instance
.execute_sql("alter table demo add my_tag string null")
.await
.unwrap();
let output = execute_sql(&instance, "alter table demo add my_tag string null").await;
assert!(matches!(output, Output::AffectedRows(0)));

let output = instance
.execute_sql(
"insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')",
)
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql("insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+--------+---------------------+--------+",
"| host | cpu | memory | ts | my_tag |",
@@ -520,16 +458,10 @@ async fn test_alter_table() {
check_output_stream(output, expected).await;

// Drop a column
let output = instance
.execute_sql("alter table demo drop column memory")
.await
.unwrap();
let output = execute_sql(&instance, "alter table demo drop column memory").await;
assert!(matches!(output, Output::AffectedRows(0)));

let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+---------------------+--------+",
"| host | cpu | ts | my_tag |",
@@ -542,16 +474,14 @@ async fn test_alter_table() {
check_output_stream(output, expected).await;

// insert a new row
let output = instance
.execute_sql("insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));

let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+---------------------+--------+",
"| host | cpu | ts | my_tag |",
@@ -580,27 +510,26 @@ async fn test_insert_with_default_value_for_type(type_name: &str) {
|
||||
) engine=mito with(regions=1);"#,
|
||||
type_name
|
||||
);
|
||||
let output = instance.execute_sql(&create_sql).await.unwrap();
|
||||
let output = execute_sql(&instance, &create_sql).await;
|
||||
assert!(matches!(output, Output::AffectedRows(1)));
|
||||
|
||||
// Insert with ts.
|
||||
instance
|
||||
.execute_sql("insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)")
|
||||
.await
|
||||
.unwrap();
|
||||
let output = execute_sql(
|
||||
&instance,
|
||||
"insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)",
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(output, Output::AffectedRows(1)));
|
||||
|
||||
// Insert without ts, so it should be filled by default value.
|
||||
let output = instance
|
||||
.execute_sql("insert into test_table(host, cpu) values ('host2', 2.2)")
|
||||
.await
|
||||
.unwrap();
|
||||
let output = execute_sql(
|
||||
&instance,
|
||||
"insert into test_table(host, cpu) values ('host2', 2.2)",
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(output, Output::AffectedRows(1)));
|
||||
|
||||
let output = instance
|
||||
.execute_sql("select host, cpu from test_table")
|
||||
.await
|
||||
.unwrap();
|
||||
let output = execute_sql(&instance, "select host, cpu from test_table").await;
|
||||
let expected = vec![
|
||||
"+-------+-----+",
|
||||
"| host | cpu |",
|
||||
@@ -619,3 +548,70 @@ async fn test_insert_with_default_value() {
|
||||
test_insert_with_default_value_for_type("timestamp").await;
|
||||
test_insert_with_default_value_for_type("bigint").await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
async fn test_use_database() {
    let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database");
    let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
    instance.start().await.unwrap();

    let output = execute_sql(&instance, "create database db1").await;
    assert!(matches!(output, Output::AffectedRows(1)));

    let output = execute_sql_in_db(
        &instance,
        "create table tb1(col_i32 int, ts bigint, TIME INDEX(ts))",
        "db1",
    )
    .await;
    assert!(matches!(output, Output::AffectedRows(1)));

    let output = execute_sql_in_db(&instance, "show tables", "db1").await;
    let expected = vec![
        "+--------+",
        "| Tables |",
        "+--------+",
        "| tb1 |",
        "+--------+",
    ];
    check_output_stream(output, expected).await;

    let output = execute_sql_in_db(
        &instance,
        r#"insert into tb1(col_i32, ts) values (1, 1655276557000)"#,
        "db1",
    )
    .await;
    assert!(matches!(output, Output::AffectedRows(1)));

    let output = execute_sql_in_db(&instance, "select col_i32 from tb1", "db1").await;
    let expected = vec![
        "+---------+",
        "| col_i32 |",
        "+---------+",
        "| 1 |",
        "+---------+",
    ];
    check_output_stream(output, expected).await;

    // Making a particular database the default by means of the USE statement does not preclude
    // accessing tables in other databases.
    let output = execute_sql(&instance, "select number from public.numbers limit 1").await;
    let expected = vec![
        "+--------+",
        "| number |",
        "+--------+",
        "| 0 |",
        "+--------+",
    ];
    check_output_stream(output, expected).await;
}

async fn execute_sql(instance: &Instance, sql: &str) -> Output {
    execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await
}

async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output {
    let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
    instance.execute_sql(sql, query_ctx).await.unwrap()
}
|
||||
|
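Note: the helpers above thread every statement through a session QueryContext so the chosen database follows the request. A rough sketch of what such a context could look like; the real `session` crate may differ in fields and locking, this only illustrates the current-schema mechanism assumed by the tests:

use std::sync::{Arc, RwLock};

// Hypothetical shape of the session QueryContext (illustration only).
#[derive(Debug, Default)]
pub struct QueryContext {
    current_schema: RwLock<Option<String>>,
}

impl QueryContext {
    pub fn new() -> Self {
        Self::default()
    }

    // Start a context already bound to a schema, as execute_sql_in_db does above.
    pub fn with_current_schema(schema: String) -> Self {
        Self {
            current_schema: RwLock::new(Some(schema)),
        }
    }

    // A `USE db` statement would call this to switch the default schema.
    pub fn set_current_schema(&self, schema: &str) {
        *self.current_schema.write().unwrap() = Some(schema.to_string());
    }

    pub fn current_schema(&self) -> Option<String> {
        self.current_schema.read().unwrap().clone()
    }
}

pub type QueryContextRef = Arc<QueryContext>;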
||||
@@ -12,15 +12,15 @@ catalog = { path = "../catalog" }
|
||||
chrono = "0.4"
|
||||
client = { path = "../client" }
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-grpc = { path = "../common/grpc" }
|
||||
common-grpc-expr = { path = "../common/grpc-expr" }
|
||||
common-query = { path = "../common/query" }
|
||||
common-recordbatch = { path = "../common/recordbatch" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-runtime = { path = "../common/runtime" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
common-grpc-expr = { path = "../common/grpc-expr" }
|
||||
datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch = "arrow2", features = [
|
||||
"simd",
|
||||
] }
|
||||
@@ -36,12 +36,14 @@ moka = { version = "0.9", features = ["future"] }
|
||||
openmetrics-parser = "0.4"
|
||||
prost = "0.11"
|
||||
query = { path = "../query" }
|
||||
rustls = "0.20"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
sqlparser = "0.15"
|
||||
servers = { path = "../servers" }
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { path = "../sql" }
|
||||
sqlparser = "0.15"
|
||||
store-api = { path = "../store-api" }
|
||||
table = { path = "../table" }
|
||||
tokio = { version = "1.18", features = ["full"] }
|
||||
|
||||
@@ -38,6 +38,7 @@ use common_error::prelude::{BoxedError, StatusCode};
|
||||
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
|
||||
use common_grpc::select::to_object_result;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::RecordBatches;
|
||||
use common_telemetry::{debug, error, info};
|
||||
use distributed::DistInstance;
|
||||
use meta_client::client::MetaClientBuilder;
|
||||
@@ -47,6 +48,7 @@ use servers::query_handler::{
|
||||
PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler,
|
||||
};
|
||||
use servers::{error as server_error, Mode};
|
||||
use session::context::{QueryContext, QueryContextRef};
|
||||
use snafu::prelude::*;
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::parser::ParserContext;
|
||||
@@ -211,10 +213,15 @@ impl Instance {
|
||||
self.script_handler = Some(handler);
|
||||
}
|
||||
|
||||
pub async fn handle_select(&self, expr: Select, stmt: Statement) -> Result<Output> {
|
||||
async fn handle_select(
|
||||
&self,
|
||||
expr: Select,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Output> {
|
||||
if let Some(dist_instance) = &self.dist_instance {
|
||||
let Select::Sql(sql) = expr;
|
||||
dist_instance.handle_sql(&sql, stmt).await
|
||||
dist_instance.handle_sql(&sql, stmt, query_ctx).await
|
||||
} else {
|
||||
// TODO(LFC): Refactor consideration: Datanode should directly execute statement in standalone mode to avoid parse SQL again.
|
||||
// Find a better way to execute query between Frontend and Datanode in standalone mode.
|
||||
@@ -298,10 +305,15 @@ impl Instance {
|
||||
}
|
||||
|
||||
/// Handle explain expr
|
||||
pub async fn handle_explain(&self, sql: &str, explain_stmt: Explain) -> Result<Output> {
|
||||
pub async fn handle_explain(
|
||||
&self,
|
||||
sql: &str,
|
||||
explain_stmt: Explain,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Output> {
|
||||
if let Some(dist_instance) = &self.dist_instance {
|
||||
dist_instance
|
||||
.handle_sql(sql, Statement::Explain(explain_stmt))
|
||||
.handle_sql(sql, Statement::Explain(explain_stmt), query_ctx)
|
||||
.await
|
||||
} else {
|
||||
Ok(Output::AffectedRows(0))
|
||||
@@ -505,6 +517,26 @@ impl Instance {
|
||||
let insert_request = insert_to_request(&schema_provider, *insert)?;
|
||||
insert_request_to_insert_batch(&insert_request)
|
||||
}
|
||||
|
||||
    fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
        let catalog_manager = &self.catalog_manager;
        if let Some(catalog_manager) = catalog_manager {
            ensure!(
                catalog_manager
                    .schema(DEFAULT_CATALOG_NAME, &db)
                    .context(error::CatalogSnafu)?
                    .is_some(),
                error::SchemaNotFoundSnafu { schema_info: &db }
            );

            query_ctx.set_current_schema(&db);

            Ok(Output::RecordBatches(RecordBatches::empty()))
        } else {
            // TODO(LFC): Handle "use" stmt here.
            unimplemented!()
        }
    }
}
|
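A hedged, crate-internal sketch of how a caller could drive `handle_use` (names as in the function above; the real call site is the `Statement::Use` arm of `do_query` shown further down):

// Sketch only; not the actual dispatch code.
fn switch_database(instance: &Instance, ctx: QueryContextRef) -> Result<()> {
    // Errors with SchemaNotFound if `db1` was never created in the catalog.
    let output = instance.handle_use("db1".to_string(), ctx.clone())?;
    assert!(matches!(output, Output::RecordBatches(_)));

    // Statements issued afterwards resolve bare table names under `db1`.
    assert_eq!(ctx.current_schema().as_deref(), Some("db1"));
    Ok(())
}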
||||
|
||||
#[async_trait]
|
||||
@@ -545,17 +577,23 @@ fn parse_stmt(sql: &str) -> Result<Statement> {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for Instance {
|
||||
async fn do_query(&self, query: &str) -> server_error::Result<Output> {
|
||||
async fn do_query(
|
||||
&self,
|
||||
query: &str,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> server_error::Result<Output> {
|
||||
let stmt = parse_stmt(query)
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })?;
|
||||
|
||||
match stmt {
|
||||
Statement::Query(_) => self
|
||||
.handle_select(Select::Sql(query.to_string()), stmt)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query }),
|
||||
Statement::ShowDatabases(_)
|
||||
| Statement::ShowTables(_)
|
||||
| Statement::DescribeTable(_)
|
||||
| Statement::Query(_) => {
|
||||
self.handle_select(Select::Sql(query.to_string()), stmt, query_ctx)
|
||||
.await
|
||||
}
|
||||
Statement::Insert(insert) => match self.mode {
|
||||
Mode::Standalone => {
|
||||
let (catalog_name, schema_name, table_name) = insert
|
||||
@@ -578,10 +616,7 @@ impl SqlQueryHandler for Instance {
|
||||
columns,
|
||||
row_count,
|
||||
};
|
||||
self.handle_insert(expr)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
||||
self.handle_insert(expr).await
|
||||
}
|
||||
Mode::Distributed => {
|
||||
let affected = self
|
||||
@@ -604,55 +639,36 @@ impl SqlQueryHandler for Instance {
|
||||
|
||||
self.handle_create_table(create_expr, create.partitions)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
||||
}
|
||||
|
||||
Statement::ShowDatabases(_)
|
||||
| Statement::ShowTables(_)
|
||||
| Statement::DescribeTable(_) => self
|
||||
.handle_select(Select::Sql(query.to_string()), stmt)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query }),
|
||||
|
||||
Statement::CreateDatabase(c) => {
|
||||
let expr = CreateDatabaseExpr {
|
||||
database_name: c.name.to_string(),
|
||||
};
|
||||
self.handle_create_database(expr)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
||||
self.handle_create_database(expr).await
|
||||
}
|
||||
Statement::Alter(alter_stmt) => self
|
||||
.handle_alter(
|
||||
Statement::Alter(alter_stmt) => {
|
||||
self.handle_alter(
|
||||
AlterExpr::try_from(alter_stmt)
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteAlterSnafu { query })?,
|
||||
)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query }),
|
||||
}
|
||||
Statement::DropTable(drop_stmt) => {
|
||||
let expr = DropTableExpr {
|
||||
catalog_name: drop_stmt.catalog_name,
|
||||
schema_name: drop_stmt.schema_name,
|
||||
table_name: drop_stmt.table_name,
|
||||
};
|
||||
self.handle_drop_table(expr)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
||||
self.handle_drop_table(expr).await
|
||||
}
|
||||
Statement::Explain(explain_stmt) => {
|
||||
self.handle_explain(query, explain_stmt, query_ctx).await
|
||||
}
|
||||
Statement::Explain(explain_stmt) => self
|
||||
.handle_explain(query, explain_stmt)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query }),
|
||||
Statement::ShowCreateTable(_) => {
|
||||
return server_error::NotSupportedSnafu { feat: query }.fail();
|
||||
}
|
||||
Statement::Use(db) => self.handle_use(db, query_ctx),
|
||||
}
|
||||
.map_err(BoxedError::new)
|
||||
.context(server_error::ExecuteQuerySnafu { query })
|
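The refactor above hoists the repeated `.map_err(BoxedError::new).context(...)` out of the individual match arms and applies it once to the match result. A minimal standalone sketch of that pattern with snafu 0.7, using made-up error types rather than the crate's own:

use snafu::{ResultExt, Snafu};

#[derive(Debug, Snafu)]
enum InnerError {
    #[snafu(display("boom"))]
    Boom,
}

#[derive(Debug, Snafu)]
enum OuterError {
    #[snafu(display("failed to execute query {}", query))]
    ExecuteQuery { query: String, source: InnerError },
}

fn run(kind: u8, query: &str) -> Result<u32, OuterError> {
    // Every arm yields Result<_, InnerError>; the error is wrapped exactly once
    // after the match, instead of repeating map_err/context in each arm.
    match kind {
        0 => Ok(0),
        1 => Err(InnerError::Boom),
        _ => Ok(u32::from(kind)),
    }
    .context(ExecuteQuerySnafu { query })
}

fn main() {
    assert!(run(1, "select 1").is_err());
    assert_eq!(run(0, "select 1").unwrap(), 0);
}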
||||
@@ -716,7 +732,8 @@ impl GrpcQueryHandler for Instance {
|
||||
})?;
|
||||
match select {
|
||||
select_expr::Expr::Sql(sql) => {
|
||||
let output = SqlQueryHandler::do_query(self, sql).await;
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let output = SqlQueryHandler::do_query(self, sql, query_ctx).await;
|
||||
Ok(to_object_result(output).await)
|
||||
}
|
||||
_ => {
|
||||
@@ -797,6 +814,8 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_sql() {
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
|
||||
let instance = tests::create_frontend_instance().await;
|
||||
|
||||
let sql = r#"CREATE TABLE demo(
|
||||
@@ -808,7 +827,9 @@ mod tests {
|
||||
TIME INDEX (ts),
|
||||
PRIMARY KEY(ts, host)
|
||||
) engine=mito with(regions=1);"#;
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
match output {
|
||||
Output::AffectedRows(rows) => assert_eq!(rows, 1),
|
||||
_ => unreachable!(),
|
||||
@@ -819,14 +840,18 @@ mod tests {
|
||||
('frontend.host2', null, null, 2000),
|
||||
('frontend.host3', 3.3, 300, 3000)
|
||||
"#;
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
match output {
|
||||
Output::AffectedRows(rows) => assert_eq!(rows, 3),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let sql = "select * from demo";
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
match output {
|
||||
Output::RecordBatches(recordbatches) => {
|
||||
let pretty_print = recordbatches.pretty_print();
|
||||
@@ -846,7 +871,9 @@ mod tests {
|
||||
};
|
||||
|
||||
let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
|
||||
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
match output {
|
||||
Output::RecordBatches(recordbatches) => {
|
||||
let pretty_print = recordbatches.pretty_print();
|
||||
|
||||
@@ -33,6 +33,7 @@ use meta_client::rpc::{
|
||||
};
|
||||
use query::sql::{describe_table, explain, show_databases, show_tables};
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use sql::statements::create::Partitions;
|
||||
use sql::statements::sql_value_to_value;
|
||||
@@ -128,29 +129,31 @@ impl DistInstance {
|
||||
Ok(Output::AffectedRows(region_routes.len()))
|
||||
}
|
||||
|
||||
    pub(crate) async fn handle_sql(&self, sql: &str, stmt: Statement) -> Result<Output> {
    pub(crate) async fn handle_sql(
        &self,
        sql: &str,
        stmt: Statement,
        query_ctx: QueryContextRef,
    ) -> Result<Output> {
        match stmt {
            Statement::Query(_) => {
                let plan = self
                    .query_engine
                    .statement_to_plan(stmt)
                    .statement_to_plan(stmt, query_ctx)
                    .context(error::ExecuteSqlSnafu { sql })?;
                self.query_engine
                    .execute(&plan)
                    .await
                    .context(error::ExecuteSqlSnafu { sql })
                self.query_engine.execute(&plan).await
            }
            Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()),
            Statement::ShowTables(stmt) => {
                show_tables(stmt, self.catalog_manager.clone(), query_ctx)
            }
            Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()),
            Statement::Explain(stmt) => {
                explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await
            }
            Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone())
                .context(error::ExecuteSqlSnafu { sql }),
            Statement::ShowTables(stmt) => show_tables(stmt, self.catalog_manager.clone())
                .context(error::ExecuteSqlSnafu { sql }),
            Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone())
                .context(error::ExecuteSqlSnafu { sql }),
            Statement::Explain(stmt) => explain(Box::new(stmt), self.query_engine.clone())
                .await
                .context(error::ExecuteSqlSnafu { sql }),
            _ => unreachable!(),
        }
        .context(error::ExecuteSqlSnafu { sql })
    }
|
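`show_tables` and `explain` now receive the query context, so a `SHOW TABLES` with no explicit database can fall back to the schema selected via `USE`. A small self-contained sketch of the presumed precedence (the real logic lives in `query::sql::show_tables`, shown later in this diff):

const DEFAULT_SCHEMA_NAME: &str = "public";

// Precedence when SHOW TABLES names no database (sketch only).
fn resolve_schema(stmt_database: Option<String>, session_schema: Option<String>) -> String {
    stmt_database
        // 1. an explicit `SHOW TABLES FROM db` wins,
        .or(session_schema)
        // 2. then the schema picked with `USE`,
        // 3. and finally the built-in default schema.
        .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string())
}

fn main() {
    assert_eq!(resolve_schema(None, Some("db1".into())), "db1");
    assert_eq!(resolve_schema(Some("other".into()), Some("db1".into())), "other");
    assert_eq!(resolve_schema(None, None), "public");
}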
||||
|
||||
/// Handles distributed database creation
|
||||
|
||||
@@ -60,9 +60,12 @@ impl Instance {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::Output;
|
||||
use datafusion::arrow_print;
|
||||
use servers::query_handler::SqlQueryHandler;
|
||||
use session::context::QueryContext;
|
||||
|
||||
use super::*;
|
||||
use crate::tests;
|
||||
@@ -121,7 +124,7 @@ mod tests {
|
||||
assert!(result.is_ok());
|
||||
|
||||
let output = instance
|
||||
.do_query("select * from my_metric_1")
|
||||
.do_query("select * from my_metric_1", Arc::new(QueryContext::new()))
|
||||
.await
|
||||
.unwrap();
|
||||
match output {
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::prometheus::remote::read_request::ResponseType;
|
||||
use api::prometheus::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest};
|
||||
use async_trait::async_trait;
|
||||
@@ -25,6 +27,7 @@ use servers::error::{self, Result as ServerResult};
|
||||
use servers::prometheus::{self, Metrics};
|
||||
use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse};
|
||||
use servers::Mode;
|
||||
use session::context::QueryContext;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::instance::{parse_stmt, Instance};
|
||||
@@ -93,7 +96,10 @@ impl Instance {
|
||||
|
||||
let object_result = if let Some(dist_instance) = &self.dist_instance {
|
||||
let output = futures::future::ready(parse_stmt(&sql))
|
||||
.and_then(|stmt| dist_instance.handle_sql(&sql, stmt))
|
||||
.and_then(|stmt| {
|
||||
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
|
||||
dist_instance.handle_sql(&sql, stmt, query_ctx)
|
||||
})
|
||||
.await;
|
||||
to_object_result(output).await.try_into()
|
||||
} else {
|
||||
|
||||
@@ -12,12 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::tls::TlsOption;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MysqlOptions {
    pub addr: String,
    pub runtime_size: usize,
    #[serde(default = "Default::default")]
    pub tls: Arc<TlsOption>,
}

impl Default for MysqlOptions {
@@ -25,6 +30,7 @@ impl Default for MysqlOptions {
        Self {
            addr: "127.0.0.1:4002".to_string(),
            runtime_size: 2,
            tls: Arc::new(TlsOption::default()),
        }
    }
}
|
||||
|
||||
@@ -12,13 +12,18 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use servers::tls::TlsOption;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PostgresOptions {
    pub addr: String,
    pub runtime_size: usize,
    pub check_pwd: bool,
    #[serde(default = "Default::default")]
    pub tls: Arc<TlsOption>,
}

impl Default for PostgresOptions {
@@ -27,6 +32,7 @@ impl Default for PostgresOptions {
            addr: "127.0.0.1:4003".to_string(),
            runtime_size: 2,
            check_pwd: false,
            tls: Default::default(),
        }
    }
}
|
||||
|
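Both option structs mark `tls` with a serde default, so configuration files written before this change, which have no TLS section, should keep deserializing. A self-contained sketch of that behavior; the stand-in `TlsOption` and the `toml` crate are assumptions for illustration, and the real field is an `Arc<TlsOption>`:

use serde::Deserialize;

// Stand-in for servers::tls::TlsOption, only to keep the sketch self-contained.
#[derive(Debug, Default, Deserialize)]
struct TlsOption {
    mode: Option<String>,
}

#[derive(Debug, Deserialize)]
struct PostgresOptions {
    addr: String,
    runtime_size: usize,
    check_pwd: bool,
    #[serde(default)]
    tls: TlsOption,
}

fn main() {
    // No TLS table in the config: the field falls back to TlsOption::default().
    let config = r#"
        addr = "127.0.0.1:4003"
        runtime_size = 2
        check_pwd = false
    "#;
    let opts: PostgresOptions = toml::from_str(config).unwrap();
    assert!(opts.tls.mode.is_none());
}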
||||
@@ -69,7 +69,8 @@ impl Services {
|
||||
.context(error::RuntimeResourceSnafu)?,
|
||||
);
|
||||
|
||||
let mysql_server = MysqlServer::create_server(instance.clone(), mysql_io_runtime);
|
||||
let mysql_server =
|
||||
MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone());
|
||||
|
||||
Some((mysql_server, mysql_addr))
|
||||
} else {
|
||||
@@ -90,6 +91,7 @@ impl Services {
|
||||
let pg_server = Box::new(PostgresServer::new(
|
||||
instance.clone(),
|
||||
opts.check_pwd,
|
||||
opts.tls.clone(),
|
||||
pg_io_runtime,
|
||||
)) as Box<dyn Server>;
|
||||
|
||||
|
||||
@@ -57,8 +57,8 @@ fn region_id(table_id: TableId, n: u32) -> RegionId {
|
||||
}
|
||||
|
||||
#[inline]
fn table_dir(schema_name: &str, table_name: &str) -> String {
    format!("{}/{}/", schema_name, table_name)
fn table_dir(schema_name: &str, table_name: &str, table_id: TableId) -> String {
    format!("{}/{}_{}/", schema_name, table_name, table_id)
}
|
||||
|
||||
/// [TableEngine] implementation.
|
||||
@@ -317,7 +317,7 @@ impl<S: StorageEngine> MitoEngineInner<S> {
|
||||
}
|
||||
}
|
||||
|
||||
let table_dir = table_dir(schema_name, table_name);
|
||||
let table_dir = table_dir(schema_name, table_name, table_id);
|
||||
let opts = CreateOptions {
|
||||
parent_dir: table_dir.clone(),
|
||||
};
|
||||
@@ -396,13 +396,13 @@ impl<S: StorageEngine> MitoEngineInner<S> {
|
||||
return Ok(Some(table));
|
||||
}
|
||||
|
||||
let table_id = request.table_id;
|
||||
let engine_ctx = StorageEngineContext::default();
|
||||
let table_dir = table_dir(schema_name, table_name);
|
||||
let table_dir = table_dir(schema_name, table_name, table_id);
|
||||
let opts = OpenOptions {
|
||||
parent_dir: table_dir.to_string(),
|
||||
};
|
||||
|
||||
let table_id = request.table_id;
|
||||
// TODO(dennis): supports multi regions;
|
||||
assert_eq!(request.region_numbers.len(), 1);
|
||||
let region_number = request.region_numbers[0];
|
||||
@@ -642,8 +642,14 @@ mod tests {

    #[test]
    fn test_table_dir() {
        assert_eq!("public/test_table/", table_dir("public", "test_table"));
        assert_eq!("prometheus/demo/", table_dir("prometheus", "demo"));
        assert_eq!(
            "public/test_table_1024/",
            table_dir("public", "test_table", 1024)
        );
        assert_eq!(
            "prometheus/demo_1024/",
            table_dir("prometheus", "demo", 1024)
        );
    }
|
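Suffixing the directory with the table id gives a table that is dropped and later recreated under the same name a fresh directory rather than the old path. A tiny sketch, assuming `TableId` is a `u32` as elsewhere in the engine:

type TableId = u32;

fn table_dir(schema_name: &str, table_name: &str, table_id: TableId) -> String {
    format!("{}/{}_{}/", schema_name, table_name, table_id)
}

fn main() {
    // A recreated table receives a new id, so its data lands in a different
    // directory instead of colliding with leftover files from the old table.
    let first = table_dir("public", "demo", 1024);
    let recreated = table_dir("public", "demo", 1025);
    assert_ne!(first, recreated);
}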
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -6,10 +6,11 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
futures = { version = "0.3" }
|
||||
opendal = { version = "0.21", features = ["layers-tracing", "layers-metrics"]}
|
||||
opendal = { version = "0.21", features = ["layers-tracing", "layers-metrics"] }
|
||||
tokio = { version = "1.0", features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1.0"
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
tempdir = "0.3"
|
||||
uuid = { version = "1", features = ["serde", "v4"] }
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
pub use opendal::raw::SeekableReader;
|
||||
pub use opendal::{
|
||||
layers, services, Error, ErrorKind, Layer, Object, ObjectLister, ObjectMetadata, ObjectMode,
|
||||
Operator as ObjectStore,
|
||||
Operator as ObjectStore, Result,
|
||||
};
|
||||
pub mod backend;
|
||||
pub mod test_util;
|
||||
pub mod util;
|
||||
|
||||
35
src/object-store/src/test_util.rs
Normal file
@@ -0,0 +1,35 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::{ObjectStore, Result};

pub struct TempFolder {
    store: ObjectStore,
    // The path under root.
    path: String,
}

impl TempFolder {
    pub fn new(store: &ObjectStore, path: &str) -> Self {
        Self {
            store: store.clone(),
            path: path.to_string(),
        }
    }

    pub async fn remove_all(&mut self) -> Result<()> {
        let batch = self.store.batch();
        batch.remove_all(&self.path).await
    }
}
|
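TempFolder is a small cleanup guard for object-store backend tests. A hedged usage sketch; the fs backend is shown for brevity, the builder and write calls follow the opendal 0.21 API already used in this diff, and the s3 test below applies the same new/remove_all pattern:

use object_store::backend::fs;
use object_store::test_util::TempFolder;
use object_store::ObjectStore;

async fn roundtrip() -> object_store::Result<()> {
    let store = ObjectStore::new(fs::Builder::default().root("/tmp/gt-test").build()?);

    // Everything written under "/" during the test is removed afterwards.
    let mut guard = TempFolder::new(&store, "/");
    store.object("hello.txt").write("hi").await?;
    guard.remove_all().await?;
    Ok(())
}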
||||
@@ -17,6 +17,7 @@ use std::env;
|
||||
use anyhow::Result;
|
||||
use common_telemetry::logging;
|
||||
use object_store::backend::{fs, s3};
|
||||
use object_store::test_util::TempFolder;
|
||||
use object_store::{util, Object, ObjectLister, ObjectMode, ObjectStore};
|
||||
use tempdir::TempDir;
|
||||
|
||||
@@ -88,10 +89,12 @@ async fn test_object_list(store: &ObjectStore) -> Result<()> {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_fs_backend() -> Result<()> {
|
||||
let data_dir = TempDir::new("test_fs_backend")?;
|
||||
let tmp_dir = TempDir::new("test_fs_backend")?;
|
||||
let store = ObjectStore::new(
|
||||
fs::Builder::default()
|
||||
.root(&tmp_dir.path().to_string_lossy())
|
||||
.root(&data_dir.path().to_string_lossy())
|
||||
.atomic_write_dir(&tmp_dir.path().to_string_lossy())
|
||||
.build()?,
|
||||
);
|
||||
|
||||
@@ -108,15 +111,21 @@ async fn test_s3_backend() -> Result<()> {
|
||||
if !bucket.is_empty() {
|
||||
logging::info!("Running s3 test.");
|
||||
|
||||
let root = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let accessor = s3::Builder::default()
|
||||
.root(&root)
|
||||
.access_key_id(&env::var("GT_S3_ACCESS_KEY_ID")?)
|
||||
.secret_access_key(&env::var("GT_S3_ACCESS_KEY")?)
|
||||
.bucket(&bucket)
|
||||
.build()?;
|
||||
|
||||
let store = ObjectStore::new(accessor);
|
||||
|
||||
let mut guard = TempFolder::new(&store, "/");
|
||||
test_object_crud(&store).await?;
|
||||
test_object_list(&store).await?;
|
||||
guard.remove_all().await?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ metrics = "0.20"
|
||||
once_cell = "1.10"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { path = "../sql" }
|
||||
table = { path = "../table" }
|
||||
|
||||
@@ -32,6 +32,7 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
|
||||
use common_telemetry::timer;
|
||||
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
|
||||
use datafusion::physical_plan::ExecutionPlan;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use sql::dialect::GenericDialect;
|
||||
use sql::parser::ParserContext;
|
||||
@@ -46,7 +47,7 @@ use crate::physical_optimizer::PhysicalOptimizer;
|
||||
use crate::physical_planner::PhysicalPlanner;
|
||||
use crate::plan::LogicalPlan;
|
||||
use crate::planner::Planner;
|
||||
use crate::query_engine::{QueryContext, QueryEngineState};
|
||||
use crate::query_engine::{QueryEngineContext, QueryEngineState};
|
||||
use crate::{metric, QueryEngine};
|
||||
|
||||
pub(crate) struct DatafusionQueryEngine {
|
||||
@@ -61,6 +62,7 @@ impl DatafusionQueryEngine {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(LFC): Refactor consideration: extract a "Planner" that stores query context and execute queries inside.
|
||||
#[async_trait::async_trait]
|
||||
impl QueryEngine for DatafusionQueryEngine {
|
||||
fn name(&self) -> &str {
|
||||
@@ -75,21 +77,25 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
Ok(statement.remove(0))
|
||||
}
|
||||
|
||||
fn statement_to_plan(&self, stmt: Statement) -> Result<LogicalPlan> {
|
||||
let context_provider = DfContextProviderAdapter::new(self.state.clone());
|
||||
fn statement_to_plan(
|
||||
&self,
|
||||
stmt: Statement,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<LogicalPlan> {
|
||||
let context_provider = DfContextProviderAdapter::new(self.state.clone(), query_ctx);
|
||||
let planner = DfPlanner::new(&context_provider);
|
||||
|
||||
planner.statement_to_plan(stmt)
|
||||
}
|
||||
|
||||
fn sql_to_plan(&self, sql: &str) -> Result<LogicalPlan> {
|
||||
fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
|
||||
let _timer = timer!(metric::METRIC_PARSE_SQL_ELAPSED);
|
||||
let stmt = self.sql_to_statement(sql)?;
|
||||
self.statement_to_plan(stmt)
|
||||
self.statement_to_plan(stmt, query_ctx)
|
||||
}
|
||||
|
||||
async fn execute(&self, plan: &LogicalPlan) -> Result<Output> {
|
||||
let mut ctx = QueryContext::new(self.state.clone());
|
||||
let mut ctx = QueryEngineContext::new(self.state.clone());
|
||||
let logical_plan = self.optimize_logical_plan(&mut ctx, plan)?;
|
||||
let physical_plan = self.create_physical_plan(&mut ctx, &logical_plan).await?;
|
||||
let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
|
||||
@@ -100,7 +106,7 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
}
|
||||
|
||||
async fn execute_physical(&self, plan: &Arc<dyn PhysicalPlan>) -> Result<Output> {
|
||||
let ctx = QueryContext::new(self.state.clone());
|
||||
let ctx = QueryEngineContext::new(self.state.clone());
|
||||
Ok(Output::Stream(self.execute_stream(&ctx, plan).await?))
|
||||
}
|
||||
|
||||
@@ -127,7 +133,7 @@ impl QueryEngine for DatafusionQueryEngine {
|
||||
impl LogicalOptimizer for DatafusionQueryEngine {
|
||||
fn optimize_logical_plan(
|
||||
&self,
|
||||
_ctx: &mut QueryContext,
|
||||
_: &mut QueryEngineContext,
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<LogicalPlan> {
|
||||
let _timer = timer!(metric::METRIC_OPTIMIZE_LOGICAL_ELAPSED);
|
||||
@@ -151,7 +157,7 @@ impl LogicalOptimizer for DatafusionQueryEngine {
|
||||
impl PhysicalPlanner for DatafusionQueryEngine {
|
||||
async fn create_physical_plan(
|
||||
&self,
|
||||
_ctx: &mut QueryContext,
|
||||
_: &mut QueryEngineContext,
|
||||
logical_plan: &LogicalPlan,
|
||||
) -> Result<Arc<dyn PhysicalPlan>> {
|
||||
let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED);
|
||||
@@ -183,7 +189,7 @@ impl PhysicalPlanner for DatafusionQueryEngine {
|
||||
impl PhysicalOptimizer for DatafusionQueryEngine {
|
||||
fn optimize_physical_plan(
|
||||
&self,
|
||||
_ctx: &mut QueryContext,
|
||||
_: &mut QueryEngineContext,
|
||||
plan: Arc<dyn PhysicalPlan>,
|
||||
) -> Result<Arc<dyn PhysicalPlan>> {
|
||||
let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED);
|
||||
@@ -211,7 +217,7 @@ impl PhysicalOptimizer for DatafusionQueryEngine {
|
||||
impl QueryExecutor for DatafusionQueryEngine {
|
||||
async fn execute_stream(
|
||||
&self,
|
||||
ctx: &QueryContext,
|
||||
ctx: &QueryEngineContext,
|
||||
plan: &Arc<dyn PhysicalPlan>,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let _timer = timer!(metric::METRIC_EXEC_PLAN_ELAPSED);
|
||||
@@ -250,6 +256,7 @@ mod tests {
|
||||
use common_recordbatch::util;
|
||||
use datafusion::field_util::{FieldExt, SchemaExt};
|
||||
use datatypes::arrow::array::UInt64Array;
|
||||
use session::context::QueryContext;
|
||||
use table::table::numbers::NumbersTable;
|
||||
|
||||
use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
|
||||
@@ -277,7 +284,9 @@ mod tests {
|
||||
let engine = create_test_engine();
|
||||
let sql = "select sum(number) from numbers limit 20";
|
||||
|
||||
let plan = engine.sql_to_plan(sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", plan),
|
||||
@@ -293,7 +302,9 @@ mod tests {
|
||||
let engine = create_test_engine();
|
||||
let sql = "select sum(number) from numbers limit 20";
|
||||
|
||||
let plan = engine.sql_to_plan(sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
|
||||
match output {
|
||||
|
||||
@@ -21,6 +21,7 @@ use datafusion::physical_plan::udaf::AggregateUDF;
|
||||
use datafusion::physical_plan::udf::ScalarUDF;
|
||||
use datafusion::sql::planner::{ContextProvider, SqlToRel};
|
||||
use datatypes::arrow::datatypes::DataType;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ResultExt;
|
||||
use sql::statements::explain::Explain;
|
||||
use sql::statements::query::Query;
|
||||
@@ -85,18 +86,20 @@ where
|
||||
| Statement::CreateDatabase(_)
|
||||
| Statement::Alter(_)
|
||||
| Statement::Insert(_)
|
||||
| Statement::DropTable(_) => unreachable!(),
|
||||
| Statement::DropTable(_)
|
||||
| Statement::Use(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct DfContextProviderAdapter {
|
||||
state: QueryEngineState,
|
||||
query_ctx: QueryContextRef,
|
||||
}
|
||||
|
||||
impl DfContextProviderAdapter {
|
||||
pub(crate) fn new(state: QueryEngineState) -> Self {
|
||||
Self { state }
|
||||
pub(crate) fn new(state: QueryEngineState, query_ctx: QueryContextRef) -> Self {
|
||||
Self { state, query_ctx }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,11 +107,18 @@ impl DfContextProviderAdapter {
|
||||
/// manage UDFs, UDAFs, variables by ourself in future.
|
||||
impl ContextProvider for DfContextProviderAdapter {
|
||||
fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
|
||||
self.state
|
||||
.df_context()
|
||||
.state
|
||||
.lock()
|
||||
.get_table_provider(name)
|
||||
let schema = self.query_ctx.current_schema();
|
||||
let execution_ctx = self.state.df_context().state.lock();
|
||||
match name {
|
||||
TableReference::Bare { table } if schema.is_some() => {
|
||||
execution_ctx.get_table_provider(TableReference::Partial {
|
||||
// unwrap safety: checked in this match's arm
|
||||
schema: &schema.unwrap(),
|
||||
table,
|
||||
})
|
||||
}
|
||||
_ => execution_ctx.get_table_provider(name),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
|
||||
|
||||
@@ -18,14 +18,14 @@ use common_query::physical_plan::PhysicalPlan;
|
||||
use common_recordbatch::SendableRecordBatchStream;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::query_engine::QueryContext;
|
||||
use crate::query_engine::QueryEngineContext;
|
||||
|
||||
/// Executor to run [ExecutionPlan].
|
||||
#[async_trait::async_trait]
|
||||
pub trait QueryExecutor {
|
||||
async fn execute_stream(
|
||||
&self,
|
||||
ctx: &QueryContext,
|
||||
ctx: &QueryEngineContext,
|
||||
plan: &Arc<dyn PhysicalPlan>,
|
||||
) -> Result<SendableRecordBatchStream>;
|
||||
}
|
||||
|
||||
@@ -26,4 +26,6 @@ pub mod planner;
|
||||
pub mod query_engine;
|
||||
pub mod sql;
|
||||
|
||||
pub use crate::query_engine::{QueryContext, QueryEngine, QueryEngineFactory, QueryEngineRef};
|
||||
pub use crate::query_engine::{
|
||||
QueryEngine, QueryEngineContext, QueryEngineFactory, QueryEngineRef,
|
||||
};
|
||||
|
||||
@@ -14,12 +14,12 @@
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::plan::LogicalPlan;
|
||||
use crate::query_engine::QueryContext;
|
||||
use crate::query_engine::QueryEngineContext;
|
||||
|
||||
pub trait LogicalOptimizer {
|
||||
fn optimize_logical_plan(
|
||||
&self,
|
||||
ctx: &mut QueryContext,
|
||||
ctx: &mut QueryEngineContext,
|
||||
plan: &LogicalPlan,
|
||||
) -> Result<LogicalPlan>;
|
||||
}
|
||||
|
||||
@@ -17,12 +17,12 @@ use std::sync::Arc;
|
||||
use common_query::physical_plan::PhysicalPlan;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::query_engine::QueryContext;
|
||||
use crate::query_engine::QueryEngineContext;
|
||||
|
||||
pub trait PhysicalOptimizer {
|
||||
fn optimize_physical_plan(
|
||||
&self,
|
||||
ctx: &mut QueryContext,
|
||||
ctx: &mut QueryEngineContext,
|
||||
plan: Arc<dyn PhysicalPlan>,
|
||||
) -> Result<Arc<dyn PhysicalPlan>>;
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use common_query::physical_plan::PhysicalPlan;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::plan::LogicalPlan;
|
||||
use crate::query_engine::QueryContext;
|
||||
use crate::query_engine::QueryEngineContext;
|
||||
|
||||
/// Physical query planner that converts a `LogicalPlan` to an
|
||||
/// `ExecutionPlan` suitable for execution.
|
||||
@@ -27,7 +27,7 @@ pub trait PhysicalPlanner {
|
||||
/// Create a physical plan from a logical plan
|
||||
async fn create_physical_plan(
|
||||
&self,
|
||||
ctx: &mut QueryContext,
|
||||
ctx: &mut QueryEngineContext,
|
||||
logical_plan: &LogicalPlan,
|
||||
) -> Result<Arc<dyn PhysicalPlan>>;
|
||||
}
|
||||
|
||||
@@ -23,12 +23,13 @@ use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
|
||||
use common_query::physical_plan::PhysicalPlan;
|
||||
use common_query::prelude::ScalarUdf;
|
||||
use common_query::Output;
|
||||
use session::context::QueryContextRef;
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
use crate::datafusion::DatafusionQueryEngine;
|
||||
use crate::error::Result;
|
||||
use crate::plan::LogicalPlan;
|
||||
pub use crate::query_engine::context::QueryContext;
|
||||
pub use crate::query_engine::context::QueryEngineContext;
|
||||
pub use crate::query_engine::state::QueryEngineState;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -37,9 +38,10 @@ pub trait QueryEngine: Send + Sync {
|
||||
|
||||
fn sql_to_statement(&self, sql: &str) -> Result<Statement>;
|
||||
|
||||
fn statement_to_plan(&self, stmt: Statement) -> Result<LogicalPlan>;
|
||||
fn statement_to_plan(&self, stmt: Statement, query_ctx: QueryContextRef)
|
||||
-> Result<LogicalPlan>;
|
||||
|
||||
fn sql_to_plan(&self, sql: &str) -> Result<LogicalPlan>;
|
||||
fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result<LogicalPlan>;
|
||||
|
||||
async fn execute(&self, plan: &LogicalPlan) -> Result<Output>;
|
||||
|
||||
|
||||
@@ -16,11 +16,11 @@
|
||||
use crate::query_engine::state::QueryEngineState;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct QueryContext {
|
||||
pub struct QueryEngineContext {
|
||||
state: QueryEngineState,
|
||||
}
|
||||
|
||||
impl QueryContext {
|
||||
impl QueryEngineContext {
|
||||
pub fn new(state: QueryEngineState) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{Helper, StringVector};
|
||||
use once_cell::sync::Lazy;
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use sql::statements::describe::DescribeTable;
|
||||
use sql::statements::explain::Explain;
|
||||
@@ -109,7 +110,11 @@ pub fn show_databases(stmt: ShowDatabases, catalog_manager: CatalogManagerRef) -
|
||||
Ok(Output::RecordBatches(records))
|
||||
}
|
||||
|
||||
pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Result<Output> {
|
||||
pub fn show_tables(
|
||||
stmt: ShowTables,
|
||||
catalog_manager: CatalogManagerRef,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Output> {
|
||||
// TODO(LFC): supports WHERE
|
||||
ensure!(
|
||||
matches!(stmt.kind, ShowKind::All | ShowKind::Like(_)),
|
||||
@@ -118,9 +123,15 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu
|
||||
}
|
||||
);
|
||||
|
||||
let schema = stmt.database.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME);
|
||||
let schema = if let Some(database) = stmt.database {
|
||||
database
|
||||
} else {
|
||||
query_ctx
|
||||
.current_schema()
|
||||
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string())
|
||||
};
|
||||
let schema = catalog_manager
|
||||
.schema(DEFAULT_CATALOG_NAME, schema)
|
||||
.schema(DEFAULT_CATALOG_NAME, &schema)
|
||||
.context(error::CatalogSnafu)?
|
||||
.context(error::SchemaNotFoundSnafu { schema })?;
|
||||
let tables = schema.table_names().context(error::CatalogSnafu)?;
|
||||
@@ -141,8 +152,12 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu
|
||||
Ok(Output::RecordBatches(records))
|
||||
}
|
||||
|
||||
pub async fn explain(stmt: Box<Explain>, query_engine: QueryEngineRef) -> Result<Output> {
|
||||
let plan = query_engine.statement_to_plan(Statement::Explain(*stmt))?;
|
||||
pub async fn explain(
|
||||
stmt: Box<Explain>,
|
||||
query_engine: QueryEngineRef,
|
||||
query_ctx: QueryContextRef,
|
||||
) -> Result<Output> {
|
||||
let plan = query_engine.statement_to_plan(Statement::Explain(*stmt), query_ctx)?;
|
||||
query_engine.execute(&plan).await
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_argmax_aggregator() -> Result<()> {
|
||||
@@ -95,7 +96,9 @@ async fn execute_argmax<'a>(
|
||||
"select ARGMAX({}) as argmax from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -25,6 +25,7 @@ use datatypes::types::PrimitiveElement;
|
||||
use function::{create_query_engine, get_numbers_from_table};
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_argmin_aggregator() -> Result<()> {
|
||||
@@ -96,7 +97,9 @@ async fn execute_argmin<'a>(
|
||||
"select argmin({}) as argmin from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -27,6 +27,7 @@ use datatypes::vectors::PrimitiveVector;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use query::QueryEngine;
|
||||
use rand::Rng;
|
||||
use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
pub fn create_query_engine() -> Arc<dyn QueryEngine> {
|
||||
@@ -80,7 +81,9 @@ where
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
{
|
||||
let sql = format!("SELECT {} FROM {}", column_name, table_name);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -28,6 +28,7 @@ use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mean_aggregator() -> Result<()> {
|
||||
@@ -89,7 +90,9 @@ async fn execute_mean<'a>(
|
||||
engine: Arc<dyn QueryEngine>,
|
||||
) -> RecordResult<Vec<RecordBatch>> {
|
||||
let sql = format!("select MEAN({}) as mean from {}", column_name, table_name);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -36,6 +36,7 @@ use datatypes::with_match_primitive_type_id;
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngineFactory;
|
||||
use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
@@ -228,7 +229,7 @@ where
|
||||
"select MY_SUM({}) as my_sum from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql)?;
|
||||
let plan = engine.sql_to_plan(&sql, Arc::new(QueryContext::new()))?;
|
||||
|
||||
let output = engine.execute(&plan).await?;
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -30,6 +30,7 @@ use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::{QueryEngine, QueryEngineFactory};
|
||||
use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -53,7 +54,9 @@ async fn test_percentile_aggregator() -> Result<()> {
|
||||
async fn test_percentile_correctness() -> Result<()> {
|
||||
let engine = create_correctness_engine();
|
||||
let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers");
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
@@ -113,7 +116,9 @@ async fn execute_percentile<'a>(
|
||||
"select PERCENTILE({},50.0) as percentile from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_polyval_aggregator() -> Result<()> {
|
||||
@@ -92,7 +93,9 @@ async fn execute_polyval<'a>(
|
||||
"select POLYVAL({}, 0) as polyval from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -37,6 +37,7 @@ use query::plan::LogicalPlan;
|
||||
use query::query_engine::QueryEngineFactory;
|
||||
use query::QueryEngine;
|
||||
use rand::Rng;
|
||||
use session::context::QueryContext;
|
||||
use table::table::adapter::DfTableProviderAdapter;
|
||||
use table::table::numbers::NumbersTable;
|
||||
use table::test_util::MemTable;
|
||||
@@ -134,7 +135,10 @@ async fn test_udf() -> Result<()> {
|
||||
|
||||
engine.register_udf(udf);
|
||||
|
||||
let plan = engine.sql_to_plan("select pow(number, number) as p from numbers limit 10")?;
|
||||
let plan = engine.sql_to_plan(
|
||||
"select pow(number, number) as p from numbers limit 10",
|
||||
Arc::new(QueryContext::new()),
|
||||
)?;
|
||||
|
||||
let output = engine.execute(&plan).await?;
|
||||
let recordbatch = match output {
|
||||
@@ -242,7 +246,9 @@ where
|
||||
for<'a> T: Scalar<RefType<'a> = T>,
|
||||
{
|
||||
let sql = format!("SELECT {} FROM {}", column_name, table_name);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
@@ -330,7 +336,9 @@ async fn execute_median<'a>(
|
||||
"select MEDIAN({}) as median from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
use statrs::distribution::{ContinuousCDF, Normal};
|
||||
use statrs::statistics::Statistics;
|
||||
|
||||
@@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_cdf<'a>(
|
||||
"select SCIPYSTATSNORMCDF({},2.0) as scipy_stats_norm_cdf from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
|
||||
use num_traits::AsPrimitive;
|
||||
use query::error::Result;
|
||||
use query::QueryEngine;
|
||||
use session::context::QueryContext;
|
||||
use statrs::distribution::{Continuous, Normal};
|
||||
use statrs::statistics::Statistics;
|
||||
|
||||
@@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_pdf<'a>(
|
||||
"select SCIPYSTATSNORMPDF({},2.0) as scipy_stats_norm_pdf from {}",
|
||||
column_name, table_name
|
||||
);
|
||||
let plan = engine.sql_to_plan(&sql).unwrap();
|
||||
let plan = engine
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.unwrap();
|
||||
|
||||
let output = engine.execute(&plan).await.unwrap();
|
||||
let recordbatch_stream = match output {
|
||||
|
||||
@@ -48,6 +48,7 @@ rustpython-vm = { git = "https://github.com/RustPython/RustPython", optional = t
|
||||
"default",
|
||||
"freeze-stdlib",
|
||||
] }
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
sql = { path = "../sql" }
|
||||
table = { path = "../table" }
|
||||
@@ -55,10 +56,10 @@ tokio = { version = "1.0", features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
log-store = { path = "../log-store" }
|
||||
mito = { path = "../mito", features = ["test"] }
|
||||
ron = "0.7"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
storage = { path = "../storage" }
|
||||
mito = { path = "../mito", features = ["test"] }
|
||||
tempdir = "0.3"
|
||||
tokio = { version = "1.18", features = ["full"] }
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
pub mod compile;
|
||||
pub mod parse;
|
||||
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::info;
|
||||
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
|
||||
use datatypes::arrow;
|
||||
use datatypes::arrow::array::{Array, ArrayRef};
|
||||
@@ -46,6 +48,8 @@ use crate::python::error::{
|
||||
use crate::python::utils::{format_py_error, is_instance, py_vec_obj_to_array};
|
||||
use crate::python::PyVector;
|
||||
|
||||
thread_local!(static INTERPRETER: RefCell<Option<Arc<Interpreter>>> = RefCell::new(None));
|
||||
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AnnotationInfo {
|
||||
@@ -114,11 +118,12 @@ impl Coprocessor {
|
||||
let AnnotationInfo {
|
||||
datatype: ty,
|
||||
is_nullable,
|
||||
} = anno[idx].to_owned().unwrap_or_else(||
|
||||
// default to be not nullable and use DataType inferred by PyVector itself
|
||||
AnnotationInfo{
|
||||
datatype: Some(real_ty.to_owned()),
|
||||
is_nullable: false
|
||||
} = anno[idx].to_owned().unwrap_or_else(|| {
|
||||
// default to be not nullable and use DataType inferred by PyVector itself
|
||||
AnnotationInfo {
|
||||
datatype: Some(real_ty.to_owned()),
|
||||
is_nullable: false,
|
||||
}
|
||||
});
|
||||
Field::new(
|
||||
name,
|
||||
@@ -282,7 +287,7 @@ fn check_args_anno_real_type(
|
||||
anno_ty
|
||||
.to_owned()
|
||||
.map(|v| v.datatype == None // like a vector[_]
|
||||
|| v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
|
||||
|| v.datatype == Some(real_ty.to_owned()) && v.is_nullable == is_nullable)
|
||||
.unwrap_or(true),
|
||||
OtherSnafu {
|
||||
reason: format!(
|
||||
@@ -380,7 +385,7 @@ pub(crate) fn exec_with_cached_vm(
|
||||
copr: &Coprocessor,
|
||||
rb: &DfRecordBatch,
|
||||
args: Vec<PyVector>,
|
||||
vm: &Interpreter,
|
||||
vm: &Arc<Interpreter>,
|
||||
) -> Result<RecordBatch> {
|
||||
vm.enter(|vm| -> Result<RecordBatch> {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
@@ -421,10 +426,18 @@ pub(crate) fn exec_with_cached_vm(
|
||||
}
|
||||
|
||||
/// init interpreter with type PyVector and Module: greptime
|
||||
pub(crate) fn init_interpreter() -> Interpreter {
|
||||
vm::Interpreter::with_init(Default::default(), |vm| {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
vm.add_native_module("greptime", Box::new(greptime_builtin::make_module));
|
||||
pub(crate) fn init_interpreter() -> Arc<Interpreter> {
|
||||
INTERPRETER.with(|i| {
|
||||
i.borrow_mut()
|
||||
.get_or_insert_with(|| {
|
||||
let interpreter = Arc::new(vm::Interpreter::with_init(Default::default(), |vm| {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
vm.add_native_module("greptime", Box::new(greptime_builtin::make_module));
|
||||
}));
|
||||
info!("Initialized Python interpreter.");
|
||||
interpreter
|
||||
})
|
||||
.clone()
|
||||
})
|
||||
}
|
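Caching the interpreter in a thread_local means RustPython is initialized at most once per thread and then shared via Arc. A minimal standalone sketch of the same get-or-insert-then-clone pattern, with a stand-in type instead of the real Interpreter:

use std::cell::RefCell;
use std::sync::Arc;

// Stand-in for the heavyweight interpreter; the real code stores Arc<Interpreter>.
struct Engine;

thread_local!(static ENGINE: RefCell<Option<Arc<Engine>>> = RefCell::new(None));

fn engine() -> Arc<Engine> {
    ENGINE.with(|e| {
        e.borrow_mut()
            .get_or_insert_with(|| {
                // Expensive initialization runs at most once per thread.
                Arc::new(Engine)
            })
            .clone()
    })
}

fn main() {
    let a = engine();
    let b = engine();
    assert!(Arc::ptr_eq(&a, &b));
}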
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStre
|
||||
use datatypes::schema::SchemaRef;
|
||||
use futures::Stream;
|
||||
use query::QueryEngineRef;
|
||||
use session::context::QueryContext;
|
||||
use snafu::{ensure, ResultExt};
|
||||
use sql::statements::statement::Statement;
|
||||
|
||||
@@ -93,7 +94,9 @@ impl Script for PyScript {
|
||||
matches!(stmt, Statement::Query { .. }),
|
||||
error::UnsupportedSqlSnafu { sql }
|
||||
);
|
||||
let plan = self.query_engine.statement_to_plan(stmt)?;
|
||||
let plan = self
|
||||
.query_engine
|
||||
.statement_to_plan(stmt, Arc::new(QueryContext::new()))?;
|
||||
let res = self.query_engine.execute(&plan).await?;
|
||||
let copr = self.copr.clone();
|
||||
match res {
|
||||
|
||||
@@ -1115,12 +1115,13 @@ pub mod tests {
|
||||
}
|
||||
|
||||
pub fn execute_script(
|
||||
interpreter: &rustpython_vm::Interpreter,
|
||||
script: &str,
|
||||
test_vec: Option<PyVector>,
|
||||
predicate: PredicateFn,
|
||||
) -> Result<(PyObjectRef, Option<bool>), PyRef<rustpython_vm::builtins::PyBaseException>> {
|
||||
let mut pred_res = None;
|
||||
rustpython_vm::Interpreter::without_stdlib(Default::default())
|
||||
interpreter
|
||||
.enter(|vm| {
|
||||
PyVector::make_class(&vm.ctx);
|
||||
let scope = vm.new_scope_with_builtins();
|
||||
@@ -1208,8 +1209,10 @@ pub mod tests {
|
||||
Some(|v, vm| is_eq(v, 2.0, vm)),
|
||||
),
|
||||
];
|
||||
|
||||
let interpreter = rustpython_vm::Interpreter::without_stdlib(Default::default());
|
||||
for (code, pred) in snippet {
|
||||
let result = execute_script(code, None, pred);
|
||||
let result = execute_script(&interpreter, code, None, pred);
|
||||
|
||||
println!(
|
||||
"\u{001B}[35m{code}\u{001B}[0m: {:?}{}",
|
||||
|
||||
@@ -28,6 +28,7 @@ use datatypes::prelude::{ConcreteDataType, ScalarVector};
|
||||
use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder};
|
||||
use datatypes::vectors::{StringVector, TimestampVector, VectorRef};
|
||||
use query::QueryEngineRef;
|
||||
use session::context::QueryContext;
|
||||
use snafu::{ensure, OptionExt, ResultExt};
|
||||
use table::requests::{CreateTableRequest, InsertRequest};
|
||||
|
||||
@@ -151,7 +152,7 @@ impl ScriptsTable {
|
||||
|
||||
let plan = self
|
||||
.query_engine
|
||||
.sql_to_plan(&sql)
|
||||
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
|
||||
.context(FindScriptSnafu { name })?;
|
||||
|
||||
let stream = match self
|
||||
|
||||
@@ -5,11 +5,11 @@ edition = "2021"
|
||||
license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
aide = { version = "0.6", features = ["axum"] }
|
||||
aide = { version = "0.9", features = ["axum"] }
|
||||
api = { path = "../api" }
|
||||
async-trait = "0.1"
|
||||
axum = "0.6.0-rc.2"
|
||||
axum-macros = "0.3.0-rc.1"
|
||||
axum = "0.6"
|
||||
axum-macros = "0.3"
|
||||
bytes = "1.2"
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
@@ -23,25 +23,29 @@ common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4" }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
humantime-serde = "1.1"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }
|
||||
metrics = "0.20"
|
||||
num_cpus = "1.13"
|
||||
once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = "0.2"
|
||||
opensrv-mysql = "0.3"
|
||||
pgwire = "0.5"
|
||||
prost = "0.11"
|
||||
regex = "1.6"
|
||||
rand = "0.8"
|
||||
regex = "1.6"
|
||||
rustls = "0.20"
|
||||
rustls-pemfile = "1.0"
|
||||
schemars = "0.8"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
session = { path = "../session" }
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
snap = "1"
|
||||
table = { path = "../table" }
|
||||
tokio = { version = "1.20", features = ["full"] }
|
||||
tokio-rustls = "0.23"
|
||||
tokio-stream = { version = "0.1", features = ["net"] }
|
||||
tonic = "0.8"
|
||||
tonic-reflection = "0.5"
|
||||
@@ -56,6 +60,8 @@ mysql_async = { git = "https://github.com/Morranto/mysql_async.git", rev = "127b
|
||||
query = { path = "../query" }
|
||||
rand = "0.8"
|
||||
script = { path = "../script", features = ["python"] }
|
||||
serde_json = "1.0"
|
||||
table = { path = "../table" }
|
||||
tokio-postgres = "0.7"
|
||||
tokio-postgres-rustls = "0.9"
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -192,6 +192,9 @@ pub enum Error {
err_msg: String,
backtrace: Backtrace,
},

#[snafu(display("Tls is required for {}, plain connection is rejected", server))]
TlsRequired { server: String },
}

pub type Result<T> = std::result::Result<T, Error>;
@@ -234,6 +237,7 @@ impl ErrorExt for Error {

InfluxdbLinesWrite { source, .. } => source.status_code(),
Hyper { .. } => StatusCode::Unknown,
TlsRequired { .. } => StatusCode::Unknown,
StartFrontend { source, .. } => source.status_code(),
}
}

@@ -185,7 +185,7 @@ impl TryFrom<Vec<RecordBatch>> for HttpRecordsOutput {
}
}

#[derive(Serialize, Deserialize, Debug, JsonSchema)]
#[derive(Serialize, Deserialize, Debug, JsonSchema, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum JsonOutput {
AffectedRows(usize),
@@ -344,59 +344,47 @@ impl HttpServer {
url: format!("/{}", HTTP_API_VERSION),
..OpenAPIServer::default()
}],

..OpenApi::default()
};

// TODO(LFC): Use released Axum.
// Axum version 0.6 introduces state within router, making router methods far more elegant
// to write. Though version 0.6 is rc, I think it's worth to upgrade.
// Prior to version 0.6, we only have a single "Extension" to share all query
// handlers amongst router methods. That requires us to pack all query handlers in a shared
// state, and check-then-get the desired query handler in different router methods, which
// is a lot of tedious work.
let sql_router = ApiRouter::with_state(ApiState {
sql_handler: self.sql_handler.clone(),
script_handler: self.script_handler.clone(),
})
.api_route(
"/sql",
apirouting::get_with(handler::sql, handler::sql_docs)
.post_with(handler::sql, handler::sql_docs),
)
.api_route("/scripts", apirouting::post(script::scripts))
.api_route("/run-script", apirouting::post(script::run_script))
.route("/private/api.json", apirouting::get(serve_api))
.route("/private/docs", apirouting::get(serve_docs))
.finish_api(&mut api)
.layer(Extension(Arc::new(api)));
let sql_router = self
.route_sql(ApiState {
sql_handler: self.sql_handler.clone(),
script_handler: self.script_handler.clone(),
})
.finish_api(&mut api)
.layer(Extension(Arc::new(api)));

let mut router = Router::new().nest(&format!("/{}", HTTP_API_VERSION), sql_router);

if let Some(opentsdb_handler) = self.opentsdb_handler.clone() {
let opentsdb_router = Router::with_state(opentsdb_handler)
.route("/api/put", routing::post(opentsdb::put));

router = router.nest(&format!("/{}/opentsdb", HTTP_API_VERSION), opentsdb_router);
router = router.nest(
&format!("/{}/opentsdb", HTTP_API_VERSION),
self.route_opentsdb(opentsdb_handler),
);
}

if let Some(influxdb_handler) = self.influxdb_handler.clone() {
let influxdb_router =
Router::with_state(influxdb_handler).route("/write", routing::post(influxdb_write));

router = router.nest(&format!("/{}/influxdb", HTTP_API_VERSION), influxdb_router);
router = router.nest(
&format!("/{}/influxdb", HTTP_API_VERSION),
self.route_influxdb(influxdb_handler),
);
}

if let Some(prom_handler) = self.prom_handler.clone() {
let prom_router = Router::with_state(prom_handler)
.route("/write", routing::post(prometheus::remote_write))
.route("/read", routing::post(prometheus::remote_read));

router = router.nest(&format!("/{}/prometheus", HTTP_API_VERSION), prom_router);
router = router.nest(
&format!("/{}/prometheus", HTTP_API_VERSION),
self.route_prom(prom_handler),
);
}

router = router.route("/metrics", routing::get(handler::metrics));

router = router.route(
"/health",
routing::get(handler::health).post(handler::health),
);

router
// middlewares
.layer(
@@ -408,6 +396,39 @@ impl HttpServer {
.layer(middleware::from_fn(context::build_ctx)),
)
}

fn route_sql<S>(&self, api_state: ApiState) -> ApiRouter<S> {
ApiRouter::new()
.api_route(
"/sql",
apirouting::get_with(handler::sql, handler::sql_docs)
.post_with(handler::sql, handler::sql_docs),
)
.api_route("/scripts", apirouting::post(script::scripts))
.api_route("/run-script", apirouting::post(script::run_script))
.route("/private/api.json", apirouting::get(serve_api))
.route("/private/docs", apirouting::get(serve_docs))
.with_state(api_state)
}

fn route_prom<S>(&self, prom_handler: PrometheusProtocolHandlerRef) -> Router<S> {
Router::new()
.route("/write", routing::post(prometheus::remote_write))
.route("/read", routing::post(prometheus::remote_read))
.with_state(prom_handler)
}

fn route_influxdb<S>(&self, influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
Router::new()
.route("/write", routing::post(influxdb_write))
.with_state(influxdb_handler)
}

fn route_opentsdb<S>(&self, opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
Router::new()
.route("/api/put", routing::post(opentsdb::put))
.with_state(opentsdb_handler)
}
}

#[async_trait]
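The comment removed in this hunk explains the motivation for the refactor: with Axum 0.6 the router carries typed state, so each protocol router can own just the handler it needs instead of everything being packed into one shared Extension. Below is a minimal, hedged sketch contrasting the two styles; it is not GreptimeDB code, and AppState and the handler names are invented for illustration.

// A minimal sketch (not GreptimeDB code) of the two styles the removed comment contrasts.
use std::sync::Arc;

use axum::extract::{Extension, State};
use axum::routing::get;
use axum::Router;

#[derive(Clone)]
struct AppState {
    greeting: Arc<String>,
}

// Axum 0.6 style: the state type is part of the router, so each handler
// declares exactly what it needs via the `State` extractor.
async fn hello_with_state(State(state): State<AppState>) -> String {
    format!("{} from state", state.greeting)
}

// Pre-0.6 style: one shared `Extension` layer; handlers pull the whole blob out.
async fn hello_with_extension(Extension(state): Extension<AppState>) -> String {
    format!("{} from extension", state.greeting)
}

fn build_routers(state: AppState) -> (Router, Router) {
    // Stateful router: the handler's State<AppState> is satisfied by with_state.
    let stateful = Router::new()
        .route("/hello", get(hello_with_state))
        .with_state(state.clone());
    // Extension-based router: the same data travels through request extensions.
    let extension_based = Router::new()
        .route("/hello", get(hello_with_extension))
        .layer(Extension(state));
    (stateful, extension_based)
}

This is why the hunk above can split the SQL, Prometheus, InfluxDB, and OpenTSDB routes into small route_* helpers, each calling with_state on its own handler.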
@@ -471,6 +492,7 @@ mod test {
|
||||
use datatypes::prelude::*;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{StringVector, UInt32Vector};
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::*;
|
||||
@@ -482,7 +504,7 @@ mod test {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for DummyInstance {
|
||||
async fn do_query(&self, _query: &str) -> Result<Output> {
|
||||
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

use aide::transform::TransformOperation;
@@ -21,6 +22,7 @@ use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;

use crate::http::{ApiState, JsonResponse};

@@ -39,7 +41,9 @@ pub async fn sql(
let sql_handler = &state.sql_handler;
let start = Instant::now();
let resp = if let Some(sql) = &params.sql {
JsonResponse::from_output(sql_handler.do_query(sql).await).await
// TODO(LFC): Sessions in http server.
let query_ctx = Arc::new(QueryContext::new());
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
} else {
JsonResponse::with_error(
"sql parameter is required.".to_string(),
@@ -63,3 +67,17 @@ pub async fn metrics(Query(_params): Query<HashMap<String, String>>) -> String {
"Prometheus handle not initialized.".to_owned()
}
}

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct HealthQuery {}

#[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
pub struct HealthResponse {}

/// Handler to export healthy check
///
/// Currently simply return status "200 OK" (default) with an empty json payload "{}"
#[axum_macros::debug_handler]
pub async fn health(Query(_params): Query<HealthQuery>) -> Json<HealthResponse> {
Json(HealthResponse {})
}

@@ -28,6 +28,8 @@ pub mod postgres;
pub mod prometheus;
pub mod query_handler;
pub mod server;
pub mod tls;

mod shutdown;

#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]

@@ -26,21 +26,25 @@ use datatypes::vectors::StringVector;
use once_cell::sync::Lazy;
use regex::bytes::RegexSet;
use regex::Regex;
use session::context::QueryContextRef;

// TODO(LFC): Include GreptimeDB's version and git commit tag etc.
const MYSQL_VERSION: &str = "8.0.26";

static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-java(.*))").unwrap());
Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
static SHOW_COLLATION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap());
static SHOW_VARIABLES_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap());

static SELECT_VERSION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap());
static SELECT_DATABASE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap());

// SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP());
static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
@@ -248,13 +252,18 @@ fn check_show_variables(query: &str) -> Option<Output> {
}

// Check for SET or others query, this is the final check of the federated query.
fn check_others(query: &str) -> Option<Output> {
fn check_others(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
return Some(Output::RecordBatches(RecordBatches::empty()));
}

let recordbatches = if SELECT_VERSION_PATTERN.is_match(query) {
Some(select_function("version()", MYSQL_VERSION))
} else if SELECT_DATABASE_PATTERN.is_match(query) {
let schema = query_ctx
.current_schema()
.unwrap_or_else(|| "NULL".to_string());
Some(select_function("database()", &schema))
} else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
Some(select_function(
"TIMEDIFF(NOW(), UTC_TIMESTAMP())",
@@ -268,7 +277,7 @@ fn check_others(query: &str) -> Option<Output> {

// Check whether the query is a federated or driver setup command,
// and return some faked results if there are any.
pub fn check(query: &str) -> Option<Output> {
pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
// First to check the query is like "select @@variables".
let output = check_select_variable(query);
if output.is_some() {
@@ -282,25 +291,27 @@ pub fn check(query: &str) -> Option<Output> {
}

// Last check.
check_others(query)
check_others(query, query_ctx)
}

#[cfg(test)]
mod test {
use session::context::QueryContext;

use super::*;

#[test]
fn test_check() {
let query = "select 1";
let result = check(query);
let result = check(query, Arc::new(QueryContext::new()));
assert!(result.is_none());

let query = "select versiona";
let output = check(query);
let output = check(query, Arc::new(QueryContext::new()));
assert!(output.is_none());

fn test(query: &str, expected: Vec<&str>) {
let output = check(query);
let output = check(query, Arc::new(QueryContext::new()));
match output.unwrap() {
Output::RecordBatches(r) => {
assert_eq!(r.pretty_print().lines().collect::<Vec<_>>(), expected)

@@ -16,11 +16,13 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_query::Output;
|
||||
use common_telemetry::{debug, error};
|
||||
use opensrv_mysql::{
|
||||
AsyncMysqlShim, ErrorKind, ParamParser, QueryResultWriter, StatementMetaWriter,
|
||||
AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use session::Session;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
@@ -36,7 +38,9 @@ pub struct MysqlInstanceShim {
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
salt: [u8; 20],
|
||||
client_addr: String,
|
||||
// TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose.
|
||||
ctx: Arc<RwLock<Option<Context>>>,
|
||||
session: Arc<Session>,
|
||||
}
|
||||
|
||||
impl MysqlInstanceShim {
|
||||
@@ -59,8 +63,33 @@ impl MysqlInstanceShim {
|
||||
salt: scramble,
|
||||
client_addr,
|
||||
ctx: Arc::new(RwLock::new(None)),
|
||||
session: Arc::new(Session::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_query(&self, query: &str) -> Result<Output> {
|
||||
debug!("Start executing query: '{}'", query);
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO(LFC): Find a better way to deal with these special federated queries:
|
||||
// `check` uses regex to filter out unsupported statements emitted by MySQL's federated
|
||||
// components, this is quick and dirty, there must be a better way to do it.
|
||||
let output =
|
||||
if let Some(output) = crate::mysql::federated::check(query, self.session.context()) {
|
||||
Ok(output)
|
||||
} else {
|
||||
self.query_handler
|
||||
.do_query(query, self.session.context())
|
||||
.await
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Finished executing query: '{}', total time costs in microseconds: {}",
|
||||
query,
|
||||
start.elapsed().as_micros()
|
||||
);
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -144,25 +173,20 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
|
||||
query: &'a str,
|
||||
writer: QueryResultWriter<'a, W>,
|
||||
) -> Result<()> {
|
||||
debug!("Start executing query: '{}'", query);
|
||||
let start = Instant::now();
|
||||
|
||||
// TODO(LFC): Find a better way:
|
||||
// `check` uses regex to filter out unsupported statements emitted by MySQL's federated
|
||||
// components, this is quick and dirty, there must be a better way to do it.
|
||||
let output = if let Some(output) = crate::mysql::federated::check(query) {
|
||||
Ok(output)
|
||||
} else {
|
||||
self.query_handler.do_query(query).await
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Finished executing query: '{}', total time costs in microseconds: {}",
|
||||
query,
|
||||
start.elapsed().as_micros()
|
||||
);
|
||||
|
||||
let output = self.do_query(query).await;
|
||||
let mut writer = MysqlResultWriter::new(writer);
|
||||
writer.write(query, output).await
|
||||
}
|
||||
|
||||
async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
|
||||
let query = format!("USE {}", database.trim());
|
||||
let output = self.do_query(&query).await;
|
||||
if let Err(e) = output {
|
||||
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
|
||||
.await
|
||||
} else {
|
||||
w.ok().await
|
||||
}
|
||||
.map_err(|e| e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,15 +20,19 @@ use async_trait::async_trait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::{error, info};
|
||||
use futures::StreamExt;
|
||||
use opensrv_mysql::AsyncMysqlIntermediary;
|
||||
use opensrv_mysql::{
|
||||
plain_run_with_options, secure_run_with_options, AsyncMysqlIntermediary, IntermediaryOptions,
|
||||
};
|
||||
use tokio;
|
||||
use tokio::io::BufWriter;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::ServerConfig;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::mysql::handler::MysqlInstanceShim;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
use crate::tls::TlsOption;
|
||||
|
||||
// Default size of ResultSet write buffer: 100KB
|
||||
const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
|
||||
@@ -36,16 +40,19 @@ const DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE: usize = 100 * 1024;
|
||||
pub struct MysqlServer {
|
||||
base_server: BaseTcpServer,
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
tls: Arc<TlsOption>,
|
||||
}
|
||||
|
||||
impl MysqlServer {
|
||||
pub fn create_server(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
io_runtime: Arc<Runtime>,
|
||||
tls: Arc<TlsOption>,
|
||||
) -> Box<dyn Server> {
|
||||
Box::new(MysqlServer {
|
||||
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
|
||||
query_handler,
|
||||
tls,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -53,16 +60,22 @@ impl MysqlServer {
|
||||
&self,
|
||||
io_runtime: Arc<Runtime>,
|
||||
stream: AbortableStream,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let query_handler = self.query_handler.clone();
|
||||
let force_tls = self.tls.should_force_tls();
|
||||
|
||||
stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let query_handler = query_handler.clone();
|
||||
let tls_conf = tls_conf.clone();
|
||||
async move {
|
||||
match tcp_stream {
|
||||
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
|
||||
Ok(io_stream) => {
|
||||
if let Err(error) = Self::handle(io_stream, io_runtime, query_handler).await
|
||||
if let Err(error) =
|
||||
Self::handle(io_stream, io_runtime, query_handler, tls_conf, force_tls)
|
||||
.await
|
||||
{
|
||||
error!(error; "Unexpected error when handling TcpStream");
|
||||
};
|
||||
@@ -76,28 +89,49 @@ impl MysqlServer {
|
||||
stream: TcpStream,
|
||||
io_runtime: Arc<Runtime>,
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
) -> Result<()> {
|
||||
info!("MySQL connection coming from: {}", stream.peer_addr()?);
|
||||
let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
|
||||
|
||||
let (r, w) = stream.into_split();
|
||||
let w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
|
||||
let spawn_result = io_runtime
|
||||
.spawn(AsyncMysqlIntermediary::run_on(shim, r, w))
|
||||
.await;
|
||||
match spawn_result {
|
||||
Ok(run_result) => {
|
||||
if let Err(e) = run_result {
|
||||
// TODO(LFC): Write this error and the below one to client as well, in MySQL text protocol.
|
||||
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
|
||||
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
|
||||
}
|
||||
io_runtime .spawn(async move {
|
||||
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
|
||||
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls).await {
|
||||
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
|
||||
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
|
||||
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
|
||||
}
|
||||
Err(e) => error!("IO runtime cannot execute task, error: {}", e),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_handle(
|
||||
stream: TcpStream,
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
tls_conf: Option<Arc<ServerConfig>>,
|
||||
force_tls: bool,
|
||||
) -> Result<()> {
|
||||
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
|
||||
let (mut r, w) = stream.into_split();
|
||||
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
|
||||
let ops = IntermediaryOptions::default();
|
||||
|
||||
let (client_tls, init_params) =
|
||||
AsyncMysqlIntermediary::init_before_ssl(&mut shim, &mut r, &mut w, &tls_conf).await?;
|
||||
|
||||
if force_tls && !client_tls {
|
||||
return Err(Error::TlsRequired {
|
||||
server: "mysql".to_owned(),
|
||||
});
|
||||
}
|
||||
|
||||
match tls_conf {
|
||||
Some(tls_conf) if client_tls => {
|
||||
secure_run_with_options(shim, w, ops, tls_conf, init_params).await
|
||||
}
|
||||
_ => plain_run_with_options(shim, w, ops, init_params).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -110,7 +144,10 @@ impl Server for MysqlServer {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
|
||||
let tls_conf = self.tls.setup()?.map(Arc::new);
|
||||
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_conf));
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
@@ -63,14 +63,16 @@ pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
with_pwd: bool,
|
||||
force_tls: bool,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(with_pwd: bool) -> Self {
|
||||
pub fn new(with_pwd: bool, force_tls: bool) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
with_pwd,
|
||||
force_tls,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -89,6 +91,20 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
{
|
||||
match message {
|
||||
PgWireFrontendMessage::Startup(ref startup) => {
|
||||
if !client.is_secure() && self.force_tls {
|
||||
let error_info = ErrorInfo::new(
|
||||
"FATAL".to_owned(),
|
||||
"28000".to_owned(),
|
||||
"No encryption".to_owned(),
|
||||
);
|
||||
let error = ErrorResponse::from(error_info);
|
||||
|
||||
client
|
||||
.feed(PgWireBackendMessage::ErrorResponse(error))
|
||||
.await?;
|
||||
client.close().await?;
|
||||
return Ok(());
|
||||
}
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
if self.with_pwd {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_query::Output;
|
||||
@@ -26,6 +27,7 @@ use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
|
||||
use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder};
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
use session::context::QueryContext;
|
||||
|
||||
use crate::error::{self, Error, Result};
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
@@ -40,15 +42,30 @@ impl PostgresServerHandler {
|
||||
}
|
||||
}
|
||||
|
||||
const CLIENT_METADATA_DATABASE: &str = "database";
|
||||
|
||||
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let query_context = QueryContext::new();
|
||||
if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
}
|
||||
|
||||
Arc::new(query_context)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
async fn do_query<C>(&self, client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
let query_ctx = query_context_from_client_info(client);
|
||||
let output = self
|
||||
.query_handler
|
||||
.do_query(query)
|
||||
.do_query(query, query_ctx)
|
||||
.await
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
|
||||
|
||||
|
||||
@@ -22,17 +22,20 @@ use common_telemetry::logging::error;
|
||||
use futures::StreamExt;
|
||||
use pgwire::tokio::process_socket;
|
||||
use tokio;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::postgres::auth_handler::PgAuthStartupHandler;
|
||||
use crate::postgres::handler::PostgresServerHandler;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
use crate::tls::TlsOption;
|
||||
|
||||
pub struct PostgresServer {
|
||||
base_server: BaseTcpServer,
|
||||
auth_handler: Arc<PgAuthStartupHandler>,
|
||||
query_handler: Arc<PostgresServerHandler>,
|
||||
tls: Arc<TlsOption>,
|
||||
}
|
||||
|
||||
impl PostgresServer {
|
||||
@@ -40,14 +43,17 @@ impl PostgresServer {
|
||||
pub fn new(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
check_pwd: bool,
|
||||
tls: Arc<TlsOption>,
|
||||
io_runtime: Arc<Runtime>,
|
||||
) -> PostgresServer {
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd));
|
||||
let startup_handler =
|
||||
Arc::new(PgAuthStartupHandler::new(check_pwd, tls.should_force_tls()));
|
||||
PostgresServer {
|
||||
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
|
||||
auth_handler: startup_handler,
|
||||
query_handler: postgres_handler,
|
||||
tls,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +61,7 @@ impl PostgresServer {
|
||||
&self,
|
||||
io_runtime: Arc<Runtime>,
|
||||
accepting_stream: AbortableStream,
|
||||
tls_acceptor: Option<Arc<TlsAcceptor>>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let auth_handler = self.auth_handler.clone();
|
||||
let query_handler = self.query_handler.clone();
|
||||
@@ -63,6 +70,7 @@ impl PostgresServer {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let auth_handler = auth_handler.clone();
|
||||
let query_handler = query_handler.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
@@ -70,7 +78,7 @@ impl PostgresServer {
|
||||
Ok(io_stream) => {
|
||||
io_runtime.spawn(process_socket(
|
||||
io_stream,
|
||||
None,
|
||||
tls_acceptor.clone(),
|
||||
auth_handler.clone(),
|
||||
query_handler.clone(),
|
||||
query_handler.clone(),
|
||||
@@ -91,8 +99,14 @@ impl Server for PostgresServer {
|
||||
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
|
||||
let (stream, addr) = self.base_server.bind(listening).await?;
|
||||
|
||||
let tls_acceptor = self
|
||||
.tls
|
||||
.setup()?
|
||||
.map(|server_conf| Arc::new(TlsAcceptor::from(Arc::new(server_conf))));
|
||||
|
||||
let io_runtime = self.base_server.io_runtime();
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream));
|
||||
let join_handle = tokio::spawn(self.accept(io_runtime, stream, tls_acceptor));
|
||||
|
||||
self.base_server.start_with(join_handle).await?;
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ use api::prometheus::remote::{ReadRequest, WriteRequest};
use api::v1::{AdminExpr, AdminResult, ObjectExpr, ObjectResult};
use async_trait::async_trait;
use common_query::Output;
use session::context::QueryContextRef;

use crate::error::Result;
use crate::influxdb::InfluxdbRequest;
@@ -44,7 +45,7 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;

#[async_trait]
pub trait SqlQueryHandler {
async fn do_query(&self, query: &str) -> Result<Output>;
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result<Output>;
}

#[async_trait]

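With this hunk, every implementer of SqlQueryHandler receives a QueryContextRef alongside the SQL text, which is how per-connection session state such as the current schema reaches query execution. The following is a hedged sketch of what an implementation might look like; the struct, the fallback schema name, and the returned Output variant are assumptions for illustration, not code from this change set.

// Hypothetical handler threading the session's QueryContext through do_query.
struct EchoSchemaHandler;

#[async_trait::async_trait]
impl SqlQueryHandler for EchoSchemaHandler {
    async fn do_query(&self, _query: &str, query_ctx: QueryContextRef) -> Result<Output> {
        // The context carries per-connection state such as the schema chosen
        // with `USE my_db`; fall back to a placeholder when none was set.
        let _schema = query_ctx
            .current_schema()
            .unwrap_or_else(|| "public".to_string());
        // A real implementation would plan and execute the statement here,
        // e.g. via the query engine's sql_to_plan(query, query_ctx) as in the hunks above.
        Ok(Output::AffectedRows(0))
    }
}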
177
src/servers/src/tls.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
// Copyright 2022 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Error, ErrorKind};
|
||||
|
||||
use rustls::{Certificate, PrivateKey, ServerConfig};
|
||||
use rustls_pemfile::{certs, pkcs8_private_keys};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// TlsMode is used for Mysql and Postgres server start up.
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TlsMode {
|
||||
#[default]
|
||||
Disable,
|
||||
Prefer,
|
||||
Require,
|
||||
// TODO(SSebo): Implement the following 2 TSL mode described in
|
||||
// ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
|
||||
VerifyCa,
|
||||
VerifyFull,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct TlsOption {
|
||||
pub mode: TlsMode,
|
||||
#[serde(default)]
|
||||
pub cert_path: String,
|
||||
#[serde(default)]
|
||||
pub key_path: String,
|
||||
}
|
||||
|
||||
impl TlsOption {
|
||||
pub fn setup(&self) -> Result<Option<ServerConfig>, Error> {
|
||||
if let TlsMode::Disable = self.mode {
|
||||
return Ok(None);
|
||||
}
|
||||
let cert = certs(&mut BufReader::new(File::open(&self.cert_path)?))
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid cert"))
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
|
||||
|
||||
// TODO(SSebo): support more private key types
|
||||
let key = pkcs8_private_keys(&mut BufReader::new(File::open(&self.key_path)?))
|
||||
.map_err(|_| Error::new(ErrorKind::InvalidInput, "invalid key"))
|
||||
.map(|mut keys| keys.drain(..).map(PrivateKey).next())?
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidInput, "invalid key"))?;
|
||||
|
||||
// TODO(SSebo): with_client_cert_verifier if TlsMode is Required.
|
||||
let config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert, key)
|
||||
.map_err(|err| std::io::Error::new(ErrorKind::InvalidInput, err))?;
|
||||
|
||||
Ok(Some(config))
|
||||
}
|
||||
|
||||
pub fn should_force_tls(&self) -> bool {
|
||||
!matches!(self.mode, TlsMode::Disable | TlsMode::Prefer)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tls_option_disable() {
|
||||
let s = r#"
|
||||
{
|
||||
"mode": "disable"
|
||||
}
|
||||
"#;
|
||||
|
||||
let t: TlsOption = serde_json::from_str(s).unwrap();
|
||||
|
||||
assert!(!t.should_force_tls());
|
||||
|
||||
assert!(matches!(t.mode, TlsMode::Disable));
|
||||
assert!(t.key_path.is_empty());
|
||||
assert!(t.cert_path.is_empty());
|
||||
|
||||
let setup = t.setup();
|
||||
assert!(setup.is_ok());
|
||||
let setup = setup.unwrap();
|
||||
assert!(setup.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_option_prefer() {
|
||||
let s = r#"
|
||||
{
|
||||
"mode": "prefer",
|
||||
"cert_path": "/some_dir/some.crt",
|
||||
"key_path": "/some_dir/some.key"
|
||||
}
|
||||
"#;
|
||||
|
||||
let t: TlsOption = serde_json::from_str(s).unwrap();
|
||||
|
||||
assert!(!t.should_force_tls());
|
||||
|
||||
assert!(matches!(t.mode, TlsMode::Prefer));
|
||||
assert!(!t.key_path.is_empty());
|
||||
assert!(!t.cert_path.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_option_require() {
|
||||
let s = r#"
|
||||
{
|
||||
"mode": "require",
|
||||
"cert_path": "/some_dir/some.crt",
|
||||
"key_path": "/some_dir/some.key"
|
||||
}
|
||||
"#;
|
||||
|
||||
let t: TlsOption = serde_json::from_str(s).unwrap();
|
||||
|
||||
assert!(t.should_force_tls());
|
||||
|
||||
assert!(matches!(t.mode, TlsMode::Require));
|
||||
assert!(!t.key_path.is_empty());
|
||||
assert!(!t.cert_path.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_option_verifiy_ca() {
|
||||
let s = r#"
|
||||
{
|
||||
"mode": "verify_ca",
|
||||
"cert_path": "/some_dir/some.crt",
|
||||
"key_path": "/some_dir/some.key"
|
||||
}
|
||||
"#;
|
||||
|
||||
let t: TlsOption = serde_json::from_str(s).unwrap();
|
||||
|
||||
assert!(t.should_force_tls());
|
||||
|
||||
assert!(matches!(t.mode, TlsMode::VerifyCa));
|
||||
assert!(!t.key_path.is_empty());
|
||||
assert!(!t.cert_path.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tls_option_verifiy_full() {
|
||||
let s = r#"
|
||||
{
|
||||
"mode": "verify_full",
|
||||
"cert_path": "/some_dir/some.crt",
|
||||
"key_path": "/some_dir/some.key"
|
||||
}
|
||||
"#;
|
||||
|
||||
let t: TlsOption = serde_json::from_str(s).unwrap();
|
||||
|
||||
assert!(t.should_force_tls());
|
||||
|
||||
assert!(matches!(t.mode, TlsMode::VerifyFull));
|
||||
assert!(!t.key_path.is_empty());
|
||||
assert!(!t.cert_path.is_empty());
|
||||
}
|
||||
}
|
||||
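The new TlsOption above is what both the MySQL and Postgres servers consume: setup() builds a rustls ServerConfig unless the mode is disable, and should_force_tls() tells the protocol handlers to reject plaintext clients under require and the stricter verify modes. A hedged sketch of wiring it up follows; the certificate paths are placeholders, not values from this repository (the tests above use tests/ssl/server.crt and server.key).

use std::sync::Arc;

use servers::tls::{TlsMode, TlsOption};

fn build_tls_option() -> Arc<TlsOption> {
    // Placeholder paths; setup() reads these files, so they must exist on disk.
    Arc::new(TlsOption {
        mode: TlsMode::Require,
        cert_path: "/etc/greptimedb/server.crt".to_string(),
        key_path: "/etc/greptimedb/server.key".to_string(),
    })
}

fn main() -> std::io::Result<()> {
    let tls = build_tls_option();
    // setup() returns Ok(None) for TlsMode::Disable, otherwise a rustls ServerConfig.
    let server_config = tls.setup()?;
    assert!(server_config.is_some());
    // should_force_tls() is true for Require/VerifyCa/VerifyFull, so the MySQL
    // and Postgres handlers would reject plain connections with this option.
    assert!(tls.should_force_tls());
    Ok(())
}

The Arc<TlsOption> is then passed to MysqlServer::create_server and PostgresServer::new, as the server and test changes in this diff show.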
@@ -135,3 +135,18 @@ fn create_query() -> Query<http_handler::SqlQuery> {
|
||||
database: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Currently the payload of response should be simply an empty json "{}";
|
||||
#[tokio::test]
|
||||
async fn test_health() {
|
||||
let expected_json = http_handler::HealthResponse {};
|
||||
let expected_json_str = "{}".to_string();
|
||||
|
||||
let query = http_handler::HealthQuery {};
|
||||
let Json(json) = http_handler::health(Query(query)).await;
|
||||
assert_eq!(json, expected_json);
|
||||
assert_eq!(
|
||||
serde_json::ser::to_string(&json).unwrap(),
|
||||
expected_json_str
|
||||
);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use servers::error::Result;
|
||||
use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::influxdb::InfluxdbRequest;
|
||||
use servers::query_handler::{InfluxdbLineProtocolHandler, SqlQueryHandler};
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct DummyInstance {
|
||||
@@ -44,7 +45,7 @@ impl InfluxdbLineProtocolHandler for DummyInstance {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for DummyInstance {
|
||||
async fn do_query(&self, _query: &str) -> Result<Output> {
|
||||
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ use servers::error::{self, Result};
|
||||
use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::opentsdb::codec::DataPoint;
|
||||
use servers::query_handler::{OpentsdbProtocolHandler, SqlQueryHandler};
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct DummyInstance {
|
||||
@@ -44,7 +45,7 @@ impl OpentsdbProtocolHandler for DummyInstance {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for DummyInstance {
|
||||
async fn do_query(&self, _query: &str) -> Result<Output> {
|
||||
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ use servers::http::{HttpOptions, HttpServer};
|
||||
use servers::prometheus;
|
||||
use servers::prometheus::{snappy_compress, Metrics};
|
||||
use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse, SqlQueryHandler};
|
||||
use session::context::QueryContextRef;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct DummyInstance {
|
||||
@@ -69,7 +70,7 @@ impl PrometheusProtocolHandler for DummyInstance {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for DummyInstance {
|
||||
async fn do_query(&self, _query: &str) -> Result<Output> {
|
||||
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ mod http;
|
||||
mod mysql;
|
||||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use script::python::{PyEngine, PyScript};
|
||||
use session::context::QueryContextRef;
|
||||
|
||||
mod opentsdb;
|
||||
mod postgres;
|
||||
|
||||
@@ -52,8 +54,8 @@ impl DummyInstance {
|
||||
|
||||
#[async_trait]
|
||||
impl SqlQueryHandler for DummyInstance {
|
||||
async fn do_query(&self, query: &str) -> Result<Output> {
|
||||
let plan = self.query_engine.sql_to_plan(query).unwrap();
|
||||
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
let plan = self.query_engine.sql_to_plan(query, query_ctx).unwrap();
|
||||
Ok(self.query_engine.execute(&plan).await.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,17 +20,19 @@ use common_recordbatch::RecordBatch;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use datatypes::schema::Schema;
|
||||
use mysql_async::prelude::*;
|
||||
use mysql_async::SslOpts;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use servers::error::Result;
|
||||
use servers::mysql::server::MysqlServer;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData};
|
||||
|
||||
fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -39,14 +41,14 @@ fn create_mysql_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
Ok(MysqlServer::create_server(query_handler, io_runtime))
|
||||
Ok(MysqlServer::create_server(query_handler, io_runtime, tls))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_start_mysql_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = mysql_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -65,7 +67,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let result = mysql_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -80,7 +82,7 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
for index in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, index == 1).await {
|
||||
match create_connection(server_port, index == 1, false).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -114,6 +116,63 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
async fn test_query_all_datatypes() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption::default());
|
||||
let client_tls = false;
|
||||
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_prefer_secure_client_plain() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Prefer,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_prefer_secure_client_secure() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Prefer,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_require_secure_client_secure() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_required_secure_client_plain() -> Result<()> {
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
|
||||
#[allow(unused)]
|
||||
let TestingData {
|
||||
column_schemas,
|
||||
mysql_columns_def,
|
||||
@@ -124,11 +183,41 @@ async fn test_query_all_datatypes() -> Result<()> {
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), false).await.unwrap();
|
||||
let r = create_connection(server_addr.port(), client_tls, false).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_test_query_all_datatypes(
|
||||
server_tls: Arc<TlsOption>,
|
||||
with_pwd: bool,
|
||||
client_tls: bool,
|
||||
) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let TestingData {
|
||||
column_schemas,
|
||||
mysql_columns_def,
|
||||
columns,
|
||||
mysql_text_output_rows,
|
||||
} = all_datatype_testing_data();
|
||||
let schema = Arc::new(Schema::new(column_schemas.clone()));
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
let table = MemTable::new("all_datatypes", recordbatch);
|
||||
|
||||
let mysql_server = create_mysql_server(table, server_tls)?;
|
||||
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut result = connection
|
||||
.query_iter("SELECT * FROM all_datatypes LIMIT 3")
|
||||
.await
|
||||
@@ -155,7 +244,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let mysql_server = create_mysql_server(table)?;
|
||||
let mysql_server = create_mysql_server(table, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
@@ -167,7 +256,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port, index % 2 == 0)
|
||||
let mut connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
@@ -184,8 +273,7 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection.disconnect().await.unwrap();
|
||||
connection = create_connection(server_port, index % 2 == 0)
|
||||
connection = create_connection(server_port, index % 2 == 0, false)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
@@ -201,13 +289,24 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result<mysql_async::Conn> {
|
||||
async fn create_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
ssl: bool,
|
||||
) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let mut opts = mysql_async::OptsBuilder::default()
|
||||
.ip_or_hostname("127.0.0.1")
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000));
|
||||
|
||||
if ssl {
|
||||
let ssl_opts = SslOpts::default()
|
||||
.with_danger_skip_domain_validation(true)
|
||||
.with_danger_accept_invalid_certs(true);
|
||||
opts = opts.ssl_opts(ssl_opts)
|
||||
}
|
||||
|
||||
if with_pwd {
|
||||
opts = opts.pass(Some("default_pwd".to_string()));
|
||||
}
|
||||
|
||||
@@ -14,20 +14,28 @@
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::{Certificate, Error, ServerName};
|
||||
use servers::error::Result;
|
||||
use servers::postgres::PostgresServer;
|
||||
use servers::server::Server;
|
||||
use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
|
||||
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
|
||||
fn create_postgres_server(
|
||||
table: MemTable,
|
||||
check_pwd: bool,
|
||||
tls: Arc<TlsOption>,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -39,6 +47,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
check_pwd,
|
||||
tls,
|
||||
io_runtime,
|
||||
)))
|
||||
}
|
||||
@@ -47,7 +56,7 @@ fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Se
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let pg_server = create_postgres_server(table, false, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -72,8 +81,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let postgres_server = create_postgres_server(table, with_pwd)?;
|
||||
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -88,7 +96,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port, with_pwd).await {
|
||||
match create_plain_connection(server_port, with_pwd).await {
|
||||
Ok(connection) => {
|
||||
match connection
|
||||
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -128,14 +136,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_query_pg_concurrently() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
let server_port = start_test_server(Default::default()).await?;
|
||||
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 300;
|
||||
@@ -144,7 +145,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut client = create_connection(server_port, false).await.unwrap();
|
||||
let mut client = create_plain_connection(server_port, false).await.unwrap();
|
||||
|
||||
for _k in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
@@ -165,7 +166,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
// 1/100 chance to reconnect
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
client = create_connection(server_port, false).await.unwrap();
|
||||
client = create_plain_connection(server_port, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -179,7 +180,126 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_prefer_client_plain() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Prefer,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = false;
|
||||
do_simple_query(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_require_client_plain() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
let server_port = start_test_server(server_tls).await?;
|
||||
let r = create_plain_connection(server_port, false).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_server_secure_require_client_secure() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let server_tls = Arc::new(TlsOption {
|
||||
mode: servers::tls::TlsMode::Require,
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let client_tls = true;
|
||||
do_simple_query(server_tls, client_tls).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_using_db() -> Result<()> {
|
||||
let server_port = start_test_server(Arc::new(TlsOption::default())).await?;
|
||||
|
||||
let client = create_connection_with_given_db(server_port, "testdb")
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_test_server(server_tls: Arc<TlsOption>) -> Result<u16> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
Ok(server_addr.port())
|
||||
}
|
||||
|
||||
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
let server_port = start_test_server(server_tls).await?;
|
||||
|
||||
if !client_tls {
|
||||
let client = create_plain_connection(server_port, false).await.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
} else {
|
||||
let client = create_secure_connection(server_port, false).await.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_secure_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
port
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={} connect_timeout=2", port)
|
||||
};
|
||||
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(rustls::RootCertStore::empty())
|
||||
.with_no_client_auth();
|
||||
config
|
||||
.dangerous()
|
||||
.set_certificate_verifier(Arc::new(AcceptAllVerifier {}));
|
||||
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
|
||||
let (client, conn) = tokio_postgres::connect(&url, tls).await.expect("connect");
|
||||
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_plain_connection(
|
||||
port: u16,
|
||||
with_pwd: bool,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
@@ -193,6 +313,19 @@ async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Cli
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_connection_with_given_db(
|
||||
port: u16,
|
||||
db: &str,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = format!(
|
||||
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
|
||||
port, db
|
||||
);
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
match resp {
|
||||
&SimpleQueryMessage::Row(ref r) => r.get(col_index),
|
||||
@@ -203,3 +336,18 @@ fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
fn unwrap_results(resp: &[SimpleQueryMessage]) -> Vec<&str> {
|
||||
resp.iter().filter_map(|m| resolve_result(m, 0)).collect()
|
||||
}
|
||||
|
||||
struct AcceptAllVerifier {}
|
||||
impl ServerCertVerifier for AcceptAllVerifier {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> std::result::Result<ServerCertVerified, Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.