Merge branch 'develop' into replace-arrow2

This commit is contained in:
evenyag
2022-12-15 15:29:35 +08:00
102 changed files with 4182 additions and 1448 deletions

345
Cargo.lock generated
View File

@@ -46,7 +46,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf6ccdb167abbf410dcb915cabd428929d7f6a04980b54a11f26a39f1c7f7107"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"const-random",
"getrandom 0.2.8",
"once_cell",
@@ -70,7 +70,7 @@ checksum = "befdff0b4683a0824fc8719ce639a252d9d62cd89c8d0004c39e2417128c1eb8"
dependencies = [
"axum",
"bytes",
"cfg-if",
"cfg-if 1.0.0",
"http",
"indexmap",
"schemars",
@@ -127,6 +127,12 @@ version = "1.0.66"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6"
[[package]]
name = "anymap"
version = "1.0.0-beta.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1f8f5a6f3d50d89e3797d7593a50f96bb2aaa20ca0cc7be1fb673232c91d72"
[[package]]
name = "api"
version = "0.1.0"
@@ -543,7 +549,7 @@ checksum = "cab84319d616cfb654d03394f38ab7e6f0919e181b1b57e1fd15e7fb4077d9a7"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"cfg-if 1.0.0",
"libc",
"miniz_oxide 0.5.4",
"object",
@@ -689,7 +695,7 @@ dependencies = [
"arrayref",
"arrayvec 0.7.2",
"cc",
"cfg-if",
"cfg-if 1.0.0",
"constant_time_eq 0.2.4",
"digest",
]
@@ -861,6 +867,12 @@ dependencies = [
"pkg-config",
]
[[package]]
name = "cactus"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf034765b7d19a011c6d619e880582bf95e8186b580e6fab56589872dd87dcf5"
[[package]]
name = "camino"
version = "1.1.1"
@@ -887,7 +899,7 @@ checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa"
dependencies = [
"camino",
"cargo-platform",
"semver",
"semver 1.0.14",
"serde",
"serde_json",
]
@@ -963,12 +975,32 @@ dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cfgrammar"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf74ea341ae8905eac9a234b6a5a845e118c25bbbdecf85ec77431a8b3bfa0be"
dependencies = [
"indexmap",
"lazy_static",
"num-traits",
"regex",
"serde",
"vob",
]
[[package]]
name = "chrono"
version = "0.4.23"
@@ -1195,6 +1227,7 @@ dependencies = [
name = "cmd"
version = "0.1.0"
dependencies = [
"anymap",
"build-data",
"clap 3.2.23",
"common-error",
@@ -1566,7 +1599,7 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -1647,7 +1680,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
@@ -1661,7 +1694,7 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2dd04ddaf88237dc3b8d8f9a3c1004b506b54b3313403944054d23c0870c521"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
]
@@ -1671,7 +1704,7 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "715e8152b692bba2d374b53d4875445368fdf21a94751410af607a5ac677d1fc"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-epoch",
"crossbeam-utils",
]
@@ -1683,7 +1716,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01a9af1f4c2ef74bb8aa1f7e19706bc72d03598c8a570bb5de72243c7a9d9d5a"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
"memoffset 0.7.1",
"scopeguard",
@@ -1695,7 +1728,7 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
]
@@ -1705,7 +1738,7 @@ version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb766fa798726286dbbb842f174001dab8abc7b627a1dd86e0b7222a95d929f"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -1856,7 +1889,7 @@ version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"hashbrown 0.12.3",
"lock_api",
"once_cell",
@@ -2031,7 +2064,6 @@ dependencies = [
"datafusion",
"datafusion-common",
"datatypes",
"frontend",
"futures",
"hyper",
"log-store",
@@ -2165,7 +2197,7 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"dirs-sys-next",
]
@@ -2203,7 +2235,7 @@ version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53ecafc952c4528d9b51a458d1a8904b81783feff9fde08ab6ed2545ff396872"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"libc",
"socket2",
"winapi",
@@ -2254,7 +2286,7 @@ version = "0.8.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9852635589dc9f9ea1b6fe9f05b50ef208c85c834a562f0c6abb1c475736ec2b"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -2263,6 +2295,26 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d"
[[package]]
name = "enum-iterator"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45a0ac4aeb3a18f92eaf09c6bb9b3ac30ff61ca95514fc58cbead1c9a6bf5401"
dependencies = [
"enum-iterator-derive",
]
[[package]]
name = "enum-iterator-derive"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "828de45d0ca18782232dfb8f3ea9cc428e8ced380eb26a520baaacfc70de39ce"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "enum_dispatch"
version = "0.3.8"
@@ -2364,11 +2416,23 @@ version = "3.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb21c69b9fea5e15dbc1049e4b77145dd0ba1c84019c488102de0dc4ea4b0a27"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"rustix",
"windows-sys 0.42.0",
]
[[package]]
name = "filetime"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e884668cd0c7480504233e951174ddc3b382f7c2666e3b7310b5c4e7b0c37f9"
dependencies = [
"cfg-if 1.0.0",
"libc",
"redox_syscall 0.2.16",
"windows-sys 0.42.0",
]
[[package]]
name = "fixedbitset"
version = "0.4.2"
@@ -2430,6 +2494,7 @@ dependencies = [
name = "frontend"
version = "0.1.0"
dependencies = [
"anymap",
"api",
"async-stream",
"async-trait",
@@ -2676,13 +2741,22 @@ dependencies = [
"winapi",
]
[[package]]
name = "getopts"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
dependencies = [
"unicode-width",
]
[[package]]
name = "getrandom"
version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"libc",
"wasi 0.9.0+wasi-snapshot-preview1",
]
@@ -2693,7 +2767,7 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
@@ -3029,7 +3103,7 @@ version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -3283,7 +3357,7 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"winapi",
]
@@ -3335,7 +3409,7 @@ version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -3361,6 +3435,60 @@ dependencies = [
"store-api",
"tempdir",
"tokio",
"tokio-util",
]
[[package]]
name = "lrlex"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22b832738fbfa58ad036580929e973b3b6bd31c6d6c7f18f6b5ea7b626675c85"
dependencies = [
"getopts",
"lazy_static",
"lrpar",
"num-traits",
"regex",
"serde",
"try_from",
"vergen",
]
[[package]]
name = "lrpar"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f270b952b07995fe874b10a5ed7dd28c80aa2130e37a7de7ed667d034e0a521"
dependencies = [
"bincode 1.3.3",
"cactus",
"cfgrammar",
"filetime",
"indexmap",
"lazy_static",
"lrtable",
"num-traits",
"packedvec",
"regex",
"serde",
"static_assertions",
"vergen",
"vob",
]
[[package]]
name = "lrtable"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a854115c6a10772ac154261592b082436abc869c812575cadcf9d7ceda8eff0b"
dependencies = [
"cfgrammar",
"fnv",
"num-traits",
"serde",
"sparsevec",
"static_assertions",
"vob",
]
[[package]]
@@ -3465,6 +3593,12 @@ dependencies = [
"digest",
]
[[package]]
name = "md5"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "memchr"
version = "2.5.0"
@@ -3707,7 +3841,7 @@ dependencies = [
"once_cell",
"parking_lot",
"quanta",
"rustc_version",
"rustc_version 0.4.0",
"scheduled-thread-pool",
"skeptic",
"smallvec",
@@ -3878,7 +4012,7 @@ checksum = "8f3790c00a0150112de0f4cd161e3d7fc4b2d8a5542ffc35f099a2562aecb35c"
dependencies = [
"bitflags",
"cc",
"cfg-if",
"cfg-if 1.0.0",
"libc",
"memoffset 0.6.5",
]
@@ -3890,7 +4024,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069"
dependencies = [
"bitflags",
"cfg-if",
"cfg-if 1.0.0",
"libc",
"memoffset 0.6.5",
]
@@ -4087,9 +4221,9 @@ checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opendal"
version = "0.21.2"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27b897fbc34f29b7975b7856aaa6d2199c9df6469245146e20d19c031d4db9d2"
checksum = "e2ce68ece2dc033c0faf446fe654b0182de8e8b876aef36d733cca7c95e2601a"
dependencies = [
"anyhow",
"async-compat",
@@ -4117,6 +4251,7 @@ dependencies = [
"tokio",
"tracing",
"ureq",
"uuid",
]
[[package]]
@@ -4244,6 +4379,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "packedvec"
version = "1.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bde3c690ec20e4a2b4fb46f0289a451181eb50011a1e2acc8d85e2fde9062a45"
dependencies = [
"num-traits",
"serde",
]
[[package]]
name = "page_size"
version = "0.4.2"
@@ -4277,7 +4422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ff9f3fef3968a3ec5945535ed654cb38ff72d7495a25619e2247fb15a2ed9ba"
dependencies = [
"backtrace",
"cfg-if",
"cfg-if 1.0.0",
"libc",
"petgraph",
"redox_syscall 0.2.16",
@@ -4416,16 +4561,18 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.5.0"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dacbf864d6cb6a0e676c9a1162ab7b315b5c8e6c87fa9b6e0ba9ba0a569adb1"
checksum = "d90fd7db2eab0a1b9cdde0ef2393f99b83c6198b1c2e62595e8d269d59b8ffca"
dependencies = [
"async-trait",
"bytes",
"derive-new",
"futures",
"getset",
"hex",
"log",
"md5",
"postgres-types",
"rand 0.8.5",
"thiserror",
@@ -4611,7 +4758,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "166ca89eb77fd403230b9c156612965a81e094ec6ec3aa13663d4c8b113fa748"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 1.0.0",
"libc",
"log",
"wepoll-ffi",
@@ -4769,6 +4916,27 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "promql"
version = "0.1.0"
dependencies = [
"common-error",
"promql-parser",
"snafu",
]
[[package]]
name = "promql-parser"
version = "0.0.1"
source = "git+https://github.com/GreptimeTeam/promql-parser.git?rev=71d8a90#71d8a90979304a7f128b3125f37a209384a81051"
dependencies = [
"cfgrammar",
"lazy_static",
"lrlex",
"lrpar",
"regex",
]
[[package]]
name = "prost"
version = "0.9.0"
@@ -5367,7 +5535,7 @@ version = "0.18.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6d5f2436026b4f6e79dc829837d467cc7e9a55ee40e750d716713540715a2df"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"ordered-multimap",
]
@@ -5401,13 +5569,22 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc_version"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0dfe2087c51c460008730de8b57e6a320782fbfb312e1f4d520e6c6fae155ee"
dependencies = [
"semver 0.11.0",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
"semver 1.0.14",
]
[[package]]
@@ -5489,7 +5666,7 @@ version = "0.0.0"
source = "git+https://github.com/discord9/RustPython?rev=183e8dab#183e8dabe0027e31630368e36c6be83b5f9cb3f8"
dependencies = [
"ascii",
"cfg-if",
"cfg-if 1.0.0",
"hexf-parse",
"lexical-parse-float",
"libc",
@@ -5606,7 +5783,7 @@ dependencies = [
"ascii",
"base64",
"blake2",
"cfg-if",
"cfg-if 1.0.0",
"crc32fast",
"crossbeam-utils",
"csv-core",
@@ -5672,7 +5849,7 @@ dependencies = [
"bitflags",
"bstr",
"caseless",
"cfg-if",
"cfg-if 1.0.0",
"chrono",
"crossbeam-utils",
"exitcode",
@@ -5703,7 +5880,7 @@ dependencies = [
"paste",
"rand 0.8.5",
"result-like",
"rustc_version",
"rustc_version 0.4.0",
"rustpython-ast",
"rustpython-codegen",
"rustpython-common",
@@ -5748,7 +5925,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1cd5ae51d3f7bf65d7969d579d502168ef578f289452bd8ccc91de28fda20e"
dependencies = [
"bitflags",
"cfg-if",
"cfg-if 1.0.0",
"clipboard-win",
"dirs-next",
"fd-lock",
@@ -5978,6 +6155,15 @@ dependencies = [
"libc",
]
[[package]]
name = "semver"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6"
dependencies = [
"semver-parser",
]
[[package]]
name = "semver"
version = "1.0.14"
@@ -5987,6 +6173,15 @@ dependencies = [
"serde",
]
[[package]]
name = "semver-parser"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0bef5b7f9e0df16536d3961cfb6e84331c065b4066afb39768d0e319411f7"
dependencies = [
"pest",
]
[[package]]
name = "seq-macro"
version = "0.3.1"
@@ -6076,6 +6271,7 @@ dependencies = [
"axum",
"axum-macros",
"axum-test-helper",
"base64",
"bytes",
"catalog",
"common-base",
@@ -6088,8 +6284,10 @@ dependencies = [
"common-telemetry",
"common-time",
"datatypes",
"digest",
"futures",
"hex",
"http-body",
"humantime-serde",
"hyper",
"influxdb_line_protocol",
@@ -6111,9 +6309,12 @@ dependencies = [
"serde",
"serde_json",
"session",
"sha1",
"snafu",
"snap",
"strum",
"table",
"tempdir",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
@@ -6140,7 +6341,7 @@ version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"cpufeatures",
"digest",
]
@@ -6151,7 +6352,7 @@ version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"cpufeatures",
"digest",
]
@@ -6162,7 +6363,7 @@ version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"cpufeatures",
"digest",
]
@@ -6338,6 +6539,18 @@ dependencies = [
"winapi",
]
[[package]]
name = "sparsevec"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "928d1ef5df00aec8c5643c2ac37db4dd282763013c0fcc81efbb8e13db8dd8ec"
dependencies = [
"num-traits",
"packedvec",
"serde",
"vob",
]
[[package]]
name = "spin"
version = "0.5.2"
@@ -6577,6 +6790,9 @@ name = "strum"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
@@ -6742,7 +6958,7 @@ version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"fastrand",
"libc",
"redox_syscall 0.2.16",
@@ -7271,7 +7487,7 @@ version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -7394,13 +7610,22 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "try_from"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "283d3b89e1368717881a9d51dad843cc435380d8109c9e47d38780a324698d8b"
dependencies = [
"cfg-if 0.1.10",
]
[[package]]
name = "twox-hash"
version = "1.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"rand 0.8.5",
"static_assertions",
]
@@ -7693,6 +7918,21 @@ version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191"
[[package]]
name = "vergen"
version = "7.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447f9238a4553957277b3ee09d80babeae0811f1b3baefb093de1c0448437a37"
dependencies = [
"anyhow",
"cfg-if 1.0.0",
"enum-iterator",
"getset",
"rustversion",
"thiserror",
"time 0.3.17",
]
[[package]]
name = "version_check"
version = "0.9.4"
@@ -7705,6 +7945,17 @@ version = "0.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b60dcd6a64dd45abf9bd426970c9843726da7fc08f44cd6fcebf68c21220a63"
[[package]]
name = "vob"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cbdb3eee5dd38a27129832bca4a3171888e699a6ac36de86547975466997986f"
dependencies = [
"num-traits",
"rustc_version 0.3.3",
"serde",
]
[[package]]
name = "volatile"
version = "0.3.0"
@@ -7762,7 +8013,7 @@ version = "0.2.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"wasm-bindgen-macro",
]
@@ -7787,7 +8038,7 @@ version = "0.4.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"js-sys",
"wasm-bindgen",
"web-sys",

View File

@@ -26,6 +26,7 @@ members = [
"src/meta-srv",
"src/mito",
"src/object-store",
"src/promql",
"src/query",
"src/script",
"src/servers",

67
Makefile Normal file
View File

@@ -0,0 +1,67 @@
IMAGE_REGISTRY ?= greptimedb
IMAGE_TAG ?= latest
##@ Build
.PHONY: build
build: ## Build debug version greptime.
cargo build
.PHONY: release
release: ## Build release version greptime.
cargo build --release
.PHONY: clean
clean: ## Clean the project.
cargo clean
.PHONY: fmt
fmt: ## Format all the Rust code.
cargo fmt --all
.PHONY: docker-image
docker-image: ## Build docker image.
docker build --network host -f docker/Dockerfile -t ${IMAGE_REGISTRY}:${IMAGE_TAG} .
##@ Test
.PHONY: unit-test
unit-test: ## Run unit test.
cargo test --workspace
.PHONY: integration-test
integration-test: ## Run integation test.
cargo test integration
.PHONY: sqlness-test
sqlness-test: ## Run sqlness test.
cargo run --bin sqlness-runner
.PHONY: check
check: ## Cargo check all the targets.
cargo check --workspace --all-targets
.PHONY: clippy
clippy: ## Check clippy rules.
cargo clippy --workspace --all-targets -- -D warnings -D clippy::print_stdout -D clippy::print_stderr
.PHONY: fmt-check
fmt-check: ## Check code format.
cargo fmt --all -- --check
##@ General
# The help target prints out all targets with their descriptions organized
# beneath their categories. The categories are represented by '##@' and the
# target descriptions by '##'. The awk commands is responsible for reading the
# entire set of makefiles included in this invocation, looking for lines of the
# file as xyz: ## something, and then pretty-format the target and help. Then,
# if there's a line with ##@ something, that gets pretty-printed as a category.
# More info on the usage of ANSI control characters for terminal formatting:
# https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters
# More info on the awk command:
# https://linuxcommand.org/lc3_adv_awk.php
.PHONY: help
help: ## Display help messages.
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)

View File

@@ -1,7 +1,12 @@
<p align="center">
<img src="/docs/logo-text-padding.png" alt="GreptimeDB Logo" width="400px"></img>
<picture>
<source media="(prefers-color-scheme: light)" srcset="/docs/logo-text-padding.png">
<source media="(prefers-color-scheme: dark)" srcset="/docs/logo-text-padding-dark.png">
<img alt="GreptimeDB Logo" src="/docs/logo-text-padding.png" width="400px">
</picture>
</p>
<h3 align="center">
The next-generation hybrid timeseries/analytics processing database in the cloud
</h3>

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -5,6 +5,8 @@ package greptime.v1.meta;
import "greptime/v1/meta/common.proto";
service Router {
rpc Create(CreateRequest) returns (RouteResponse) {}
// Fetch routing information for tables. The smallest unit is the complete
// routing information(all regions) of a table.
//
@@ -26,7 +28,14 @@ service Router {
//
rpc Route(RouteRequest) returns (RouteResponse) {}
rpc Create(CreateRequest) returns (RouteResponse) {}
rpc Delete(DeleteRequest) returns (RouteResponse) {}
}
message CreateRequest {
RequestHeader header = 1;
TableName table_name = 2;
repeated Partition partitions = 3;
}
message RouteRequest {
@@ -35,6 +44,12 @@ message RouteRequest {
repeated TableName table_names = 2;
}
message DeleteRequest {
RequestHeader header = 1;
TableName table_name = 2;
}
message RouteResponse {
ResponseHeader header = 1;
@@ -42,13 +57,6 @@ message RouteResponse {
repeated TableRoute table_routes = 3;
}
message CreateRequest {
RequestHeader header = 1;
TableName table_name = 2;
repeated Partition partitions = 3;
}
message TableRoute {
Table table = 1;
repeated RegionRoute region_routes = 2;

View File

@@ -20,6 +20,9 @@ service Store {
// DeleteRange deletes the given range from the key-value store.
rpc DeleteRange(DeleteRangeRequest) returns (DeleteRangeResponse);
// MoveValue atomically renames the key to the given updated key.
rpc MoveValue(MoveValueRequest) returns (MoveValueResponse);
}
message RangeRequest {
@@ -136,3 +139,21 @@ message DeleteRangeResponse {
// returned.
repeated KeyValue prev_kvs = 3;
}
message MoveValueRequest {
RequestHeader header = 1;
// If from_key dose not exist, return the value of to_key (if it exists).
// If from_key exists, move the value of from_key to to_key (i.e. rename),
// and return the value.
bytes from_key = 2;
bytes to_key = 3;
}
message MoveValueResponse {
ResponseHeader header = 1;
// If from_key dose not exist, return the value of to_key (if it exists).
// If from_key exists, return the value of from_key.
KeyValue kv = 2;
}

View File

@@ -145,10 +145,12 @@ gen_set_header!(HeartbeatRequest);
gen_set_header!(RouteRequest);
gen_set_header!(CreateRequest);
gen_set_header!(RangeRequest);
gen_set_header!(DeleteRequest);
gen_set_header!(PutRequest);
gen_set_header!(BatchPutRequest);
gen_set_header!(CompareAndPutRequest);
gen_set_header!(DeleteRangeRequest);
gen_set_header!(MoveValueRequest);
#[cfg(test)]
mod tests {

View File

@@ -10,6 +10,7 @@ name = "greptime"
path = "src/bin/greptime.rs"
[dependencies]
anymap = "1.0.0-beta.2"
clap = { version = "3.1", features = ["derive"] }
common-error = { path = "../common/error" }
common-telemetry = { path = "../common/telemetry", features = [

View File

@@ -12,8 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
const DEFAULT_VALUE: &str = "unknown";
fn main() {
build_data::set_GIT_BRANCH();
build_data::set_GIT_COMMIT();
build_data::set_GIT_DIRTY();
println!(
"cargo:rustc-env=GIT_COMMIT={}",
build_data::get_git_commit().unwrap_or_else(|_| DEFAULT_VALUE.to_string())
);
println!(
"cargo:rustc-env=GIT_BRANCH={}",
build_data::get_git_branch().unwrap_or_else(|_| DEFAULT_VALUE.to_string())
);
println!(
"cargo:rustc-env=GIT_DIRTY={}",
build_data::get_git_dirty().map_or(DEFAULT_VALUE.to_string(), |v| v.to_string())
);
}

View File

@@ -77,7 +77,9 @@ fn print_version() -> &'static str {
"\ncommit: ",
env!("GIT_COMMIT"),
"\ndirty: ",
env!("GIT_DIRTY")
env!("GIT_DIRTY"),
"\nversion: ",
env!("CARGO_PKG_VERSION")
)
}

View File

@@ -25,12 +25,6 @@ pub enum Error {
source: datanode::error::Error,
},
#[snafu(display("Failed to build frontend, source: {}", source))]
BuildFrontend {
#[snafu(backtrace)]
source: frontend::error::Error,
},
#[snafu(display("Failed to start frontend, source: {}", source))]
StartFrontend {
#[snafu(backtrace)]
@@ -61,6 +55,12 @@ pub enum Error {
#[snafu(display("Illegal config: {}", msg))]
IllegalConfig { msg: String, backtrace: Backtrace },
#[snafu(display("Illegal auth config: {}", source))]
IllegalAuthConfig {
#[snafu(backtrace)]
source: servers::auth::Error,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -75,7 +75,7 @@ impl ErrorExt for Error {
StatusCode::InvalidArguments
}
Error::IllegalConfig { .. } => StatusCode::InvalidArguments,
Error::BuildFrontend { source, .. } => source.status_code(),
Error::IllegalAuthConfig { .. } => StatusCode::InvalidArguments,
}
}

View File

@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use anymap::AnyMap;
use clap::Parser;
use frontend::frontend::{Frontend, FrontendOptions};
use frontend::grpc::GrpcOptions;
@@ -21,11 +24,13 @@ use frontend::mysql::MysqlOptions;
use frontend::opentsdb::OpentsdbOptions;
use frontend::postgres::PostgresOptions;
use meta_client::MetaClientOpts;
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
use servers::Mode;
use servers::tls::{TlsMode, TlsOption};
use servers::{auth, Mode};
use snafu::ResultExt;
use crate::error::{self, Result};
use crate::error::{self, IllegalAuthConfigSnafu, Result};
use crate::toml_loader;
#[derive(Parser)]
@@ -71,21 +76,41 @@ pub struct StartCommand {
influxdb_enable: Option<bool>,
#[clap(long)]
metasrv_addr: Option<String>,
#[clap(long)]
tls_mode: Option<TlsMode>,
#[clap(long)]
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
}
impl StartCommand {
async fn run(self) -> Result<()> {
let plugins = load_frontend_plugins(&self.user_provider)?;
let opts: FrontendOptions = self.try_into()?;
let mut frontend = Frontend::new(
opts.clone(),
Instance::try_new(&opts)
Instance::try_new_distributed(&opts)
.await
.context(error::StartFrontendSnafu)?,
plugins,
);
frontend.start().await.context(error::StartFrontendSnafu)
}
}
pub fn load_frontend_plugins(user_provider: &Option<String>) -> Result<AnyMap> {
let mut plugins = AnyMap::new();
if let Some(provider) = user_provider {
let provider = auth::user_provider_from_option(provider).context(IllegalAuthConfigSnafu)?;
plugins.insert::<UserProviderRef>(provider);
}
Ok(plugins)
}
impl TryFrom<StartCommand> for FrontendOptions {
type Error = error::Error;
@@ -96,6 +121,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
FrontendOptions::default()
};
let tls_option = TlsOption::new(cmd.tls_mode, cmd.tls_cert_path, cmd.tls_key_path);
if let Some(addr) = cmd.http_addr {
opts.http_options = Some(HttpOptions {
addr,
@@ -111,12 +138,14 @@ impl TryFrom<StartCommand> for FrontendOptions {
if let Some(addr) = cmd.mysql_addr {
opts.mysql_options = Some(MysqlOptions {
addr,
tls: Arc::new(tls_option.clone()),
..Default::default()
});
}
if let Some(addr) = cmd.postgres_addr {
opts.postgres_options = Some(PostgresOptions {
addr,
tls: Arc::new(tls_option),
..Default::default()
});
}
@@ -147,6 +176,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests {
use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*;
#[test]
@@ -160,6 +191,10 @@ mod tests {
influxdb_enable: Some(false),
config_file: None,
metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let opts: FrontendOptions = command.try_into().unwrap();
@@ -209,11 +244,14 @@ mod tests {
std::env::current_dir().unwrap().as_path().to_str().unwrap()
)),
metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let fe_opts = FrontendOptions::try_from(command).unwrap();
assert_eq!(Mode::Distributed, fe_opts.mode);
assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr);
assert_eq!(
"127.0.0.1:4000".to_string(),
fe_opts.http_options.as_ref().unwrap().addr
@@ -223,4 +261,34 @@ mod tests {
fe_opts.http_options.as_ref().unwrap().timeout
);
}
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
grpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
influxdb_enable: None,
config_file: None,
metasrv_addr: None,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
}

View File

@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use anymap::AnyMap;
use clap::Parser;
use common_telemetry::info;
use datanode::datanode::{Datanode, DatanodeOptions, ObjectStoreConfig};
@@ -26,13 +29,12 @@ use frontend::postgres::PostgresOptions;
use frontend::prometheus::PrometheusOptions;
use serde::{Deserialize, Serialize};
use servers::http::HttpOptions;
use servers::tls::{TlsMode, TlsOption};
use servers::Mode;
use snafu::ResultExt;
use tokio::try_join;
use crate::error::{
BuildFrontendSnafu, Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu,
};
use crate::error::{Error, IllegalConfigSnafu, Result, StartDatanodeSnafu, StartFrontendSnafu};
use crate::frontend::load_frontend_plugins;
use crate::toml_loader;
#[derive(Parser)]
@@ -104,7 +106,6 @@ impl StandaloneOptions {
influxdb_options: self.influxdb_options,
prometheus_options: self.prometheus_options,
mode: self.mode,
datanode_rpc_addr: "127.0.0.1:3001".to_string(),
meta_client_opts: None,
}
}
@@ -137,12 +138,21 @@ struct StartCommand {
config_file: Option<String>,
#[clap(short = 'm', long = "memory-catalog")]
enable_memory_catalog: bool,
#[clap(long)]
tls_mode: Option<TlsMode>,
#[clap(long)]
tls_cert_path: Option<String>,
#[clap(long)]
tls_key_path: Option<String>,
#[clap(long)]
user_provider: Option<String>,
}
impl StartCommand {
async fn run(self) -> Result<()> {
let enable_memory_catalog = self.enable_memory_catalog;
let config_file = self.config_file.clone();
let plugins = load_frontend_plugins(&self.user_provider)?;
let fe_opts = FrontendOptions::try_from(self)?;
let dn_opts: DatanodeOptions = {
let mut opts: StandaloneOptions = if let Some(path) = config_file {
@@ -162,7 +172,7 @@ impl StartCommand {
let mut datanode = Datanode::new(dn_opts.clone())
.await
.context(StartDatanodeSnafu)?;
let mut frontend = build_frontend(fe_opts, &dn_opts, datanode.get_instance()).await?;
let mut frontend = build_frontend(fe_opts, plugins, datanode.get_instance()).await?;
// Start datanode instance before starting services, to avoid requests come in before internal components are started.
datanode
@@ -171,11 +181,7 @@ impl StartCommand {
.context(StartDatanodeSnafu)?;
info!("Datanode instance started");
try_join!(
async { datanode.start_services().await.context(StartDatanodeSnafu) },
async { frontend.start().await.context(StartFrontendSnafu) }
)?;
frontend.start().await.context(StartFrontendSnafu)?;
Ok(())
}
}
@@ -183,20 +189,12 @@ impl StartCommand {
/// Build frontend instance in standalone mode
async fn build_frontend(
fe_opts: FrontendOptions,
dn_opts: &DatanodeOptions,
plugins: AnyMap,
datanode_instance: InstanceRef,
) -> Result<Frontend<FeInstance>> {
let grpc_server_addr = &dn_opts.rpc_addr;
info!(
"Build frontend with datanode gRPC addr: {}",
grpc_server_addr
);
let mut frontend_instance = FeInstance::try_new(&fe_opts)
.await
.context(BuildFrontendSnafu)?;
frontend_instance.set_catalog_manager(datanode_instance.catalog_manager().clone());
let mut frontend_instance = FeInstance::new_standalone(datanode_instance.clone());
frontend_instance.set_script_handler(datanode_instance);
Ok(Frontend::new(fe_opts, frontend_instance))
Ok(Frontend::new(fe_opts, frontend_instance, plugins))
}
impl TryFrom<StartCommand> for FrontendOptions {
@@ -261,6 +259,18 @@ impl TryFrom<StartCommand> for FrontendOptions {
opts.influxdb_options = Some(InfluxdbOptions { enable: true });
}
let tls_option = TlsOption::new(cmd.tls_mode, cmd.tls_cert_path, cmd.tls_key_path);
if let Some(mut mysql_options) = opts.mysql_options {
mysql_options.tls = Arc::new(tls_option.clone());
opts.mysql_options = Some(mysql_options);
}
if let Some(mut postgres_options) = opts.postgres_options {
postgres_options.tls = Arc::new(tls_option);
opts.postgres_options = Some(postgres_options);
}
Ok(opts)
}
}
@@ -269,6 +279,8 @@ impl TryFrom<StartCommand> for FrontendOptions {
mod tests {
use std::time::Duration;
use servers::auth::{Identity, Password, UserProviderRef};
use super::*;
#[test]
@@ -285,11 +297,14 @@ mod tests {
)),
influxdb_enable: false,
enable_memory_catalog: false,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: None,
};
let fe_opts = FrontendOptions::try_from(cmd).unwrap();
assert_eq!(Mode::Standalone, fe_opts.mode);
assert_eq!("127.0.0.1:3001".to_string(), fe_opts.datanode_rpc_addr);
assert_eq!(
"127.0.0.1:4000".to_string(),
fe_opts.http_options.as_ref().unwrap().addr
@@ -309,4 +324,33 @@ mod tests {
assert_eq!(2, fe_opts.mysql_options.as_ref().unwrap().runtime_size);
assert!(fe_opts.influxdb_options.as_ref().unwrap().enable);
}
#[tokio::test]
async fn test_try_from_start_command_to_anymap() {
let command = StartCommand {
http_addr: None,
rpc_addr: None,
mysql_addr: None,
postgres_addr: None,
opentsdb_addr: None,
config_file: None,
influxdb_enable: false,
enable_memory_catalog: false,
tls_mode: None,
tls_cert_path: None,
tls_key_path: None,
user_provider: Some("static_user_provider:cmd:test=test".to_string()),
};
let plugins = load_frontend_plugins(&command.user_provider);
assert!(plugins.is_ok());
let plugins = plugins.unwrap();
let provider = plugins.get::<UserProviderRef>();
assert!(provider.is_some());
let provider = provider.unwrap();
let result = provider
.auth(Identity::UserId("test", None), Password::PlainText("test"))
.await;
assert!(result.is_ok());
}
}

View File

@@ -62,6 +62,19 @@ pub enum StatusCode {
/// Runtime resources exhausted, like creating threads failed.
RuntimeResourcesExhausted = 6000,
// ====== End of server related status code =======
// ====== Begin of auth related status code =====
/// User not exist
UserNotFound = 7000,
/// Unsupported password type
UnsupportedPasswordType = 7001,
/// Username and password does not match
UserPasswordMismatch = 7002,
/// Not found http authorization header
AuthHeaderNotFound = 7003,
/// Invalid http authorization header
InvalidAuthHeader = 7004,
// ====== End of auth related status code =====
}
impl StatusCode {

View File

@@ -21,7 +21,7 @@ use common_base::BitVec;
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use common_query::Output;
use common_recordbatch::{util, RecordBatches, SendableRecordBatchStream};
use common_recordbatch::{RecordBatches, SendableRecordBatchStream};
use datatypes::schema::SchemaRef;
use datatypes::types::{TimestampType, WrapperType};
use datatypes::vectors::{
@@ -51,13 +51,9 @@ pub async fn to_object_result(output: std::result::Result<Output, impl ErrorExt>
}
async fn collect(stream: SendableRecordBatchStream) -> Result<ObjectResult> {
let schema = stream.schema();
let recordbatches = util::collect(stream)
let recordbatches = RecordBatches::try_collect(stream)
.await
.and_then(|batches| RecordBatches::try_new(schema, batches))
.context(error::CollectRecordBatchesSnafu)?;
let object_result = build_result(recordbatches)?;
Ok(object_result)
}

View File

@@ -28,7 +28,7 @@ use datatypes::prelude::VectorRef;
use datatypes::schema::{Schema, SchemaRef};
use error::Result;
use futures::task::{Context, Poll};
use futures::Stream;
use futures::{Stream, TryStreamExt};
pub use recordbatch::RecordBatch;
use snafu::{ensure, ResultExt};
@@ -81,6 +81,12 @@ impl RecordBatches {
Ok(Self { schema, batches })
}
pub async fn try_collect(stream: SendableRecordBatchStream) -> Result<Self> {
let schema = stream.schema();
let batches = stream.try_collect::<Vec<_>>().await?;
Ok(Self { schema, batches })
}
#[inline]
pub fn empty() -> Self {
Self {

View File

@@ -27,7 +27,6 @@ common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datafusion = "14.0.0"
datatypes = { path = "../datatypes" }
frontend = { path = "../frontend" }
futures = "0.3"
hyper = { version = "0.14", features = ["full"] }
log-store = { path = "../log-store" }

View File

@@ -139,7 +139,16 @@ pub enum Error {
CreateDir { dir: String, source: std::io::Error },
#[snafu(display("Failed to open log store, source: {}", source))]
OpenLogStore { source: log_store::error::Error },
OpenLogStore {
#[snafu(backtrace)]
source: log_store::error::Error,
},
#[snafu(display("Failed to star log store gc task, source: {}", source))]
StartLogStore {
#[snafu(backtrace)]
source: log_store::error::Error,
},
#[snafu(display("Failed to storage engine, source: {}", source))]
OpenStorageEngine { source: StorageError },
@@ -358,6 +367,7 @@ impl ErrorExt for Error {
Error::BumpTableId { source, .. } => source.status_code(),
Error::MissingNodeId { .. } => StatusCode::InvalidArguments,
Error::MissingMetasrvOpts { .. } => StatusCode::InvalidArguments,
Error::StartLogStore { source, .. } => source.status_code(),
}
}

View File

@@ -36,12 +36,13 @@ use servers::Mode;
use snafu::prelude::*;
use storage::config::EngineConfig as StorageEngineConfig;
use storage::EngineImpl;
use store_api::logstore::LogStore;
use table::table::TableIdProviderRef;
use crate::datanode::{DatanodeOptions, ObjectStoreConfig};
use crate::error::{
self, CatalogSnafu, MetaClientInitSnafu, MissingMetasrvOptsSnafu, MissingNodeIdSnafu,
NewCatalogSnafu, Result,
NewCatalogSnafu, Result, StartLogStoreSnafu,
};
use crate::heartbeat::HeartbeatTask;
use crate::script::ScriptExecutor;
@@ -60,9 +61,8 @@ pub struct Instance {
pub(crate) catalog_manager: CatalogManagerRef,
pub(crate) script_executor: ScriptExecutor,
pub(crate) table_id_provider: Option<TableIdProviderRef>,
#[allow(unused)]
pub(crate) meta_client: Option<Arc<MetaClient>>,
pub(crate) heartbeat_task: Option<HeartbeatTask>,
pub(crate) logstore: Arc<LocalFileLogStore>,
}
pub type InstanceRef = Arc<Instance>;
@@ -70,7 +70,7 @@ pub type InstanceRef = Arc<Instance>;
impl Instance {
pub async fn new(opts: &DatanodeOptions) -> Result<Self> {
let object_store = new_object_store(&opts.storage).await?;
let log_store = create_local_file_log_store(opts).await?;
let logstore = Arc::new(create_local_file_log_store(&opts.wal_dir).await?);
let meta_client = match opts.mode {
Mode::Standalone => None,
@@ -90,7 +90,7 @@ impl Instance {
TableEngineConfig::default(),
EngineImpl::new(
StorageEngineConfig::default(),
Arc::new(log_store),
logstore.clone(),
object_store.clone(),
),
object_store,
@@ -158,9 +158,9 @@ impl Instance {
),
catalog_manager,
script_executor,
meta_client,
heartbeat_task,
table_id_provider,
logstore,
})
}
@@ -169,6 +169,7 @@ impl Instance {
.start()
.await
.context(NewCatalogSnafu)?;
self.logstore.start().await.context(StartLogStoreSnafu)?;
if let Some(task) = &self.heartbeat_task {
task.start().await?;
}
@@ -194,7 +195,7 @@ pub(crate) async fn new_object_store(store_config: &ObjectStoreConfig) -> Result
object_store
.layer(RetryLayer::new(ExponentialBackoff::default().with_jitter()))
.layer(MetricsLayer)
.layer(LoggingLayer)
.layer(LoggingLayer::default())
.layer(TracingLayer)
})
}
@@ -275,16 +276,16 @@ async fn new_metasrv_client(node_id: u64, meta_config: &MetaClientOpts) -> Resul
}
pub(crate) async fn create_local_file_log_store(
opts: &DatanodeOptions,
path: impl AsRef<str>,
) -> Result<LocalFileLogStore> {
let path = path.as_ref();
// create WAL directory
fs::create_dir_all(path::Path::new(&opts.wal_dir))
.context(error::CreateDirSnafu { dir: &opts.wal_dir })?;
fs::create_dir_all(path::Path::new(path)).context(error::CreateDirSnafu { dir: path })?;
info!("The WAL directory is: {}", &opts.wal_dir);
info!("The WAL directory is: {}", path);
let log_config = LogConfig {
log_file_dir: opts.wal_dir.clone(),
log_file_dir: path.to_string(),
..Default::default()
};

View File

@@ -24,7 +24,7 @@ use query::QueryEngineFactory;
use storage::config::EngineConfig as StorageEngineConfig;
use storage::EngineImpl;
use table::metadata::TableId;
use table::table::{TableIdProvider, TableIdProviderRef};
use table::table::TableIdProvider;
use crate::datanode::DatanodeOptions;
use crate::error::Result;
@@ -34,56 +34,6 @@ use crate::script::ScriptExecutor;
use crate::sql::SqlHandler;
impl Instance {
// This method is used in other crate's testing codes, so move it out of "cfg(test)".
// TODO(LFC): Delete it when callers no longer need it.
pub async fn new_mock() -> Result<Self> {
use mito::table::test_util::{new_test_object_store, MockEngine, MockMitoEngine};
let mock_info = meta_srv::mocks::mock_with_memstore().await;
let meta_client = Some(Arc::new(mock_meta_client(mock_info, 0).await));
let (_dir, object_store) = new_test_object_store("setup_mock_engine_and_table").await;
let mock_engine = Arc::new(MockMitoEngine::new(
TableEngineConfig::default(),
MockEngine::default(),
object_store,
));
let catalog_manager = Arc::new(
catalog::local::manager::LocalCatalogManager::try_new(mock_engine.clone())
.await
.unwrap(),
);
let factory = QueryEngineFactory::new(catalog_manager.clone());
let query_engine = factory.query_engine();
let sql_handler = SqlHandler::new(
mock_engine.clone(),
catalog_manager.clone(),
query_engine.clone(),
);
let script_executor = ScriptExecutor::new(catalog_manager.clone(), query_engine.clone())
.await
.unwrap();
let heartbeat_task = Some(HeartbeatTask::new(
0,
"127.0.0.1:3302".to_string(),
meta_client.as_ref().unwrap().clone(),
));
let table_id_provider = Some(catalog_manager.clone() as TableIdProviderRef);
Ok(Self {
query_engine,
sql_handler,
catalog_manager,
script_executor,
meta_client,
heartbeat_task,
table_id_provider,
})
}
pub async fn with_mock_meta_client(opts: &DatanodeOptions) -> Result<Self> {
let mock_info = meta_srv::mocks::mock_with_memstore().await;
Self::with_mock_meta_server(opts, mock_info).await
@@ -91,13 +41,13 @@ impl Instance {
pub async fn with_mock_meta_server(opts: &DatanodeOptions, meta_srv: MockInfo) -> Result<Self> {
let object_store = new_object_store(&opts.storage).await?;
let log_store = create_local_file_log_store(opts).await?;
let logstore = Arc::new(create_local_file_log_store(&opts.wal_dir).await?);
let meta_client = Arc::new(mock_meta_client(meta_srv, opts.node_id.unwrap_or(42)).await);
let table_engine = Arc::new(DefaultEngine::new(
TableEngineConfig::default(),
EngineImpl::new(
StorageEngineConfig::default(),
Arc::new(log_store),
logstore.clone(),
object_store.clone(),
),
object_store,
@@ -132,8 +82,8 @@ impl Instance {
catalog_manager,
script_executor,
table_id_provider: Some(Arc::new(LocalTableIdProvider::default())),
meta_client: Some(meta_client),
heartbeat_task: Some(heartbeat_task),
logstore,
})
}
}

View File

@@ -63,6 +63,7 @@ impl Services {
instance.clone(),
mysql_io_runtime,
Default::default(),
None,
))
}
};

View File

@@ -173,10 +173,10 @@ async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str
}
}
async fn setup_test_instance() -> Instance {
async fn setup_test_instance(test_name: &str) -> Instance {
common_telemetry::init_default_ut_logging();
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("execute_insert");
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts(test_name);
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
@@ -193,7 +193,7 @@ async fn setup_test_instance() -> Instance {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_insert() {
let instance = setup_test_instance().await;
let instance = setup_test_instance("test_execute_insert").await;
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
@@ -409,18 +409,10 @@ async fn check_output_stream(output: Output, expected: Vec<&str>) {
assert_eq!(pretty_print, expected);
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_alter_table() {
let instance = Instance::new_mock().await.unwrap();
instance.start().await.unwrap();
let instance = setup_test_instance("test_alter_table").await;
test_util::create_test_table(
instance.catalog_manager(),
instance.sql_handler(),
ConcreteDataType::timestamp_millisecond_datatype(),
)
.await
.unwrap();
// make sure table insertion is ok before altering table
execute_sql(
&instance,

View File

@@ -5,6 +5,7 @@ edition = "2021"
license = "Apache-2.0"
[dependencies]
anymap = "1.0.0-beta.2"
api = { path = "../api" }
async-stream = "0.3"
async-trait = "0.1"
@@ -24,6 +25,7 @@ common-time = { path = "../common/time" }
datafusion = "14.0.0"
datafusion-common = "14.0.0"
datafusion-expr = "14.0.0"
datanode = { path = "../datanode" }
datatypes = { path = "../datatypes" }
futures = "0.3"
futures-util = "0.3"

View File

@@ -59,6 +59,16 @@ impl FrontendCatalogManager {
pub(crate) fn backend(&self) -> KvBackendRef {
self.backend.clone()
}
#[cfg(test)]
pub(crate) fn table_routes(&self) -> Arc<TableRoutes> {
self.table_routes.clone()
}
#[cfg(test)]
pub(crate) fn datanode_clients(&self) -> Arc<DatanodeClients> {
self.datanode_clients.clone()
}
}
// FIXME(hl): Frontend only needs a CatalogList, should replace with trait upcasting

View File

@@ -245,18 +245,6 @@ pub enum Error {
source: client::Error,
},
#[snafu(display("Failed to alter table, source: {}", source))]
AlterTable {
#[snafu(backtrace)]
source: client::Error,
},
#[snafu(display("Failed to drop table, source: {}", source))]
DropTable {
#[snafu(backtrace)]
source: client::Error,
},
#[snafu(display("Failed to insert values to table, source: {}", source))]
Insert {
#[snafu(backtrace)]
@@ -399,9 +387,6 @@ pub enum Error {
source: query::error::Error,
},
#[snafu(display("Unsupported expr type: {}", name))]
UnsupportedExpr { name: String, backtrace: Backtrace },
#[snafu(display("Failed to do vector computation, source: {}", source))]
VectorComputation {
#[snafu(backtrace)]
@@ -463,6 +448,12 @@ pub enum Error {
#[snafu(backtrace)]
source: datatypes::error::Error,
},
#[snafu(display("Failed to invoke GRPC server, source: {}", source))]
InvokeGrpcServer {
#[snafu(backtrace)]
source: servers::error::Error,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -482,7 +473,9 @@ impl ErrorExt for Error {
Error::RuntimeResource { source, .. } => source.status_code(),
Error::StartServer { source, .. } => source.status_code(),
Error::StartServer { source, .. } | Error::InvokeGrpcServer { source } => {
source.status_code()
}
Error::ParseSql { source } => source.status_code(),
@@ -512,7 +505,6 @@ impl ErrorExt for Error {
| Error::FindLeaderPeer { .. }
| Error::FindRegionPartition { .. }
| Error::IllegalTableRoutesData { .. }
| Error::UnsupportedExpr { .. }
| Error::BuildDfLogicalPlan { .. } => StatusCode::Internal,
Error::IllegalFrontendState { .. } | Error::IncompleteGrpcResult { .. } => {
@@ -534,8 +526,6 @@ impl ErrorExt for Error {
Error::SchemaNotFound { .. } => StatusCode::InvalidArguments,
Error::CatalogNotFound { .. } => StatusCode::InvalidArguments,
Error::CreateTable { source, .. }
| Error::AlterTable { source, .. }
| Error::DropTable { source }
| Error::Select { source, .. }
| Error::CreateDatabase { source, .. }
| Error::CreateTableOnInsertion { source, .. }

View File

@@ -14,8 +14,10 @@
use std::sync::Arc;
use anymap::AnyMap;
use meta_client::MetaClientOpts;
use serde::{Deserialize, Serialize};
use servers::auth::UserProviderRef;
use servers::http::HttpOptions;
use servers::Mode;
use snafu::prelude::*;
@@ -40,7 +42,6 @@ pub struct FrontendOptions {
pub influxdb_options: Option<InfluxdbOptions>,
pub prometheus_options: Option<PrometheusOptions>,
pub mode: Mode,
pub datanode_rpc_addr: String,
pub meta_client_opts: Option<MetaClientOpts>,
}
@@ -55,34 +56,26 @@ impl Default for FrontendOptions {
influxdb_options: Some(InfluxdbOptions::default()),
prometheus_options: Some(PrometheusOptions::default()),
mode: Mode::Standalone,
datanode_rpc_addr: "127.0.0.1:3001".to_string(),
meta_client_opts: None,
}
}
}
impl FrontendOptions {
pub(crate) fn datanode_grpc_addr(&self) -> String {
self.datanode_rpc_addr.clone()
}
}
pub struct Frontend<T>
where
T: FrontendInstance,
{
opts: FrontendOptions,
instance: Option<T>,
plugins: AnyMap,
}
impl<T> Frontend<T>
where
T: FrontendInstance,
{
pub fn new(opts: FrontendOptions, instance: T) -> Self {
impl<T: FrontendInstance> Frontend<T> {
pub fn new(opts: FrontendOptions, instance: T, plugins: AnyMap) -> Self {
Self {
opts,
instance: Some(instance),
plugins,
}
}
@@ -96,6 +89,9 @@ where
instance.start().await?;
let instance = Arc::new(instance);
Services::start(&self.opts, instance).await
let provider = self.plugins.get::<UserProviderRef>().cloned();
Services::start(&self.opts, instance, provider).await
}
}

View File

@@ -20,50 +20,49 @@ mod prometheus;
use std::sync::Arc;
use std::time::Duration;
use api::result::ObjectResultBuilder;
use api::result::{ObjectResultBuilder, PROTOCOL_VERSION};
use api::v1::alter_expr::Kind;
use api::v1::object_expr::Expr;
use api::v1::{
admin_expr, select_expr, AddColumns, AdminExpr, AdminResult, AlterExpr, Column,
CreateDatabaseExpr, CreateExpr, DropTableExpr, InsertExpr, ObjectExpr,
admin_expr, AddColumns, AdminExpr, AdminResult, AlterExpr, Column, CreateDatabaseExpr,
CreateExpr, DropTableExpr, ExprHeader, InsertExpr, ObjectExpr,
ObjectResult as GrpcObjectResult,
};
use async_trait::async_trait;
use catalog::remote::MetaKvBackend;
use catalog::{CatalogManagerRef, CatalogProviderRef, SchemaProviderRef};
use client::admin::{admin_result_to_output, Admin};
use client::{Client, Database, Select};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::{BoxedError, StatusCode};
use client::admin::admin_result_to_output;
use client::ObjectResult;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_error::prelude::BoxedError;
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
use common_grpc::select::to_object_result;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_telemetry::{debug, error, info};
use common_telemetry::{debug, info};
use datanode::instance::InstanceRef as DnInstanceRef;
use distributed::DistInstance;
use meta_client::client::MetaClientBuilder;
use meta_client::client::{MetaClient, MetaClientBuilder};
use meta_client::MetaClientOpts;
use servers::query_handler::{
GrpcAdminHandler, GrpcQueryHandler, InfluxdbLineProtocolHandler, OpentsdbProtocolHandler,
PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler,
GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler,
ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
};
use servers::{error as server_error, Mode};
use session::context::{QueryContext, QueryContextRef};
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::create::Partitions;
use sql::statements::explain::Explain;
use sql::statements::insert::Insert;
use sql::statements::statement::Statement;
use table::TableRef;
use crate::catalog::FrontendCatalogManager;
use crate::datanode::DatanodeClients;
use crate::error::{
self, AlterTableOnInsertionSnafu, AlterTableSnafu, CatalogNotFoundSnafu, CatalogSnafu,
CreateDatabaseSnafu, CreateTableSnafu, DropTableSnafu, FindNewColumnsOnInsertionSnafu,
InsertSnafu, MissingMetasrvOptsSnafu, Result, SchemaNotFoundSnafu, SelectSnafu,
UnsupportedExprSnafu,
self, AlterTableOnInsertionSnafu, CatalogSnafu, CreateDatabaseSnafu, CreateTableSnafu,
FindNewColumnsOnInsertionSnafu, InsertSnafu, MissingMetasrvOptsSnafu, Result,
};
use crate::expr_factory::{CreateExprFactoryRef, DefaultCreateExprFactory};
use crate::frontend::FrontendOptions;
@@ -91,11 +90,7 @@ pub type FrontendInstanceRef = Arc<dyn FrontendInstance>;
#[derive(Clone)]
pub struct Instance {
// TODO(hl): In standalone mode, there is only one client.
// But in distribute mode, frontend should fetch datanodes' addresses from metasrv.
client: Client,
/// catalog manager is None in standalone mode, datanode will keep their own
catalog_manager: Option<CatalogManagerRef>,
catalog_manager: CatalogManagerRef,
/// Script handler is None in distributed mode, only works on standalone mode.
script_handler: Option<ScriptHandlerRef>,
create_expr_factory: CreateExprFactoryRef,
@@ -103,106 +98,91 @@ pub struct Instance {
// Standalone and Distributed, then the code behind it doesn't need to use so
// many match statements.
mode: Mode,
// TODO(LFC): Refactor consideration: Can we split Frontend to DistInstance and EmbedInstance?
dist_instance: Option<DistInstance>,
}
impl Default for Instance {
fn default() -> Self {
Self {
client: Client::default(),
catalog_manager: None,
script_handler: None,
create_expr_factory: Arc::new(DefaultCreateExprFactory {}),
mode: Mode::Standalone,
dist_instance: None,
}
}
// TODO(LFC): Remove `dist_instance` together with Arrow Flight adoption refactor.
dist_instance: Option<DistInstance>,
sql_handler: SqlQueryHandlerRef,
grpc_query_handler: GrpcQueryHandlerRef,
grpc_admin_handler: GrpcAdminHandlerRef,
}
impl Instance {
pub async fn try_new(opts: &FrontendOptions) -> Result<Self> {
let mut instance = Instance {
mode: opts.mode.clone(),
..Default::default()
};
pub async fn try_new_distributed(opts: &FrontendOptions) -> Result<Self> {
let meta_client = Self::create_meta_client(opts).await?;
let addr = opts.datanode_grpc_addr();
instance.client.start(vec![addr]);
let meta_backend = Arc::new(MetaKvBackend {
client: meta_client.clone(),
});
let table_routes = Arc::new(TableRoutes::new(meta_client.clone()));
let datanode_clients = Arc::new(DatanodeClients::new());
let catalog_manager = Arc::new(FrontendCatalogManager::new(
meta_backend,
table_routes,
datanode_clients.clone(),
));
instance.dist_instance = match &opts.mode {
Mode::Standalone => None,
Mode::Distributed => {
let metasrv_addr = &opts
.meta_client_opts
.as_ref()
.context(MissingMetasrvOptsSnafu)?
.metasrv_addrs;
info!(
"Creating Frontend instance in distributed mode with Meta server addr {:?}",
metasrv_addr
);
let dist_instance =
DistInstance::new(meta_client, catalog_manager.clone(), datanode_clients);
let dist_instance_ref = Arc::new(dist_instance.clone());
let meta_config = MetaClientOpts::default();
let channel_config = ChannelConfig::new()
.timeout(Duration::from_millis(meta_config.timeout_millis))
.connect_timeout(Duration::from_millis(meta_config.connect_timeout_millis))
.tcp_nodelay(meta_config.tcp_nodelay);
let channel_manager = ChannelManager::with_config(channel_config);
let mut meta_client = MetaClientBuilder::new(0, 0)
.enable_router()
.enable_store()
.channel_manager(channel_manager)
.build();
meta_client
.start(metasrv_addr)
.await
.context(error::StartMetaClientSnafu)?;
let meta_client = Arc::new(meta_client);
let meta_backend = Arc::new(MetaKvBackend {
client: meta_client.clone(),
});
let table_routes = Arc::new(TableRoutes::new(meta_client.clone()));
let datanode_clients = Arc::new(DatanodeClients::new());
let catalog_manager = Arc::new(FrontendCatalogManager::new(
meta_backend,
table_routes,
datanode_clients.clone(),
));
instance.catalog_manager = Some(catalog_manager.clone());
Some(DistInstance::new(
meta_client,
catalog_manager,
datanode_clients,
))
}
};
Ok(instance)
Ok(Instance {
catalog_manager,
script_handler: None,
create_expr_factory: Arc::new(DefaultCreateExprFactory),
mode: Mode::Distributed,
dist_instance: Some(dist_instance),
sql_handler: dist_instance_ref.clone(),
grpc_query_handler: dist_instance_ref.clone(),
grpc_admin_handler: dist_instance_ref,
})
}
pub fn database(&self, database: &str) -> Database {
Database::new(database, self.client.clone())
}
pub fn admin(&self, database: &str) -> Admin {
Admin::new(database, self.client.clone())
}
pub fn catalog_manager(&self) -> &Option<CatalogManagerRef> {
&self.catalog_manager
}
pub fn set_catalog_manager(&mut self, catalog_manager: CatalogManagerRef) {
debug_assert!(
self.catalog_manager.is_none(),
"Catalog manager can be set only once!"
async fn create_meta_client(opts: &FrontendOptions) -> Result<Arc<MetaClient>> {
let metasrv_addr = &opts
.meta_client_opts
.as_ref()
.context(MissingMetasrvOptsSnafu)?
.metasrv_addrs;
info!(
"Creating Frontend instance in distributed mode with Meta server addr {:?}",
metasrv_addr
);
self.catalog_manager = Some(catalog_manager);
let meta_config = MetaClientOpts::default();
let channel_config = ChannelConfig::new()
.timeout(Duration::from_millis(meta_config.timeout_millis))
.connect_timeout(Duration::from_millis(meta_config.connect_timeout_millis))
.tcp_nodelay(meta_config.tcp_nodelay);
let channel_manager = ChannelManager::with_config(channel_config);
let mut meta_client = MetaClientBuilder::new(0, 0)
.enable_router()
.enable_store()
.channel_manager(channel_manager)
.build();
meta_client
.start(metasrv_addr)
.await
.context(error::StartMetaClientSnafu)?;
Ok(Arc::new(meta_client))
}
pub fn new_standalone(dn_instance: DnInstanceRef) -> Self {
Instance {
catalog_manager: dn_instance.catalog_manager().clone(),
script_handler: None,
create_expr_factory: Arc::new(DefaultCreateExprFactory),
mode: Mode::Standalone,
dist_instance: None,
sql_handler: dn_instance.clone(),
grpc_query_handler: dn_instance.clone(),
grpc_admin_handler: dn_instance,
}
}
pub fn catalog_manager(&self) -> &CatalogManagerRef {
&self.catalog_manager
}
pub fn set_script_handler(&mut self, handler: ScriptHandlerRef) {
@@ -213,27 +193,6 @@ impl Instance {
self.script_handler = Some(handler);
}
async fn handle_select(
&self,
expr: Select,
stmt: Statement,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(dist_instance) = &self.dist_instance {
let Select::Sql(sql) = expr;
dist_instance.handle_sql(&sql, stmt, query_ctx).await
} else {
// TODO(LFC): Refactor consideration: Datanode should directly execute statement in standalone mode to avoid parse SQL again.
// Find a better way to execute query between Frontend and Datanode in standalone mode.
// Otherwise we have to parse SQL first to get schema name. Maybe not GRPC.
self.database(DEFAULT_SCHEMA_NAME)
.select(expr)
.await
.and_then(Output::try_from)
.context(SelectSnafu)
}
}
/// Handle create expr.
pub async fn handle_create_table(
&self,
@@ -243,81 +202,38 @@ impl Instance {
if let Some(v) = &self.dist_instance {
v.create_table(&mut expr, partitions).await
} else {
// Currently standalone mode does not support multi partitions/regions.
let expr = AdminExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(admin_expr::Expr::Create(expr)),
};
let result = self
.admin(expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME))
.create(expr.clone())
.await;
if let Err(e) = &result {
error!(e; "Failed to create table by expr: {:?}", expr);
}
result
.and_then(admin_result_to_output)
.context(CreateTableSnafu)
.grpc_admin_handler
.exec_admin_request(expr)
.await
.context(error::InvokeGrpcServerSnafu)?;
admin_result_to_output(result).context(CreateTableSnafu)
}
}
/// Handle create database expr.
pub async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
let database_name = expr.database_name.clone();
if let Some(dist_instance) = &self.dist_instance {
dist_instance.handle_create_database(expr).await
} else {
// FIXME(hl): In order to get admin client to create schema, we need to use the default schema admin
self.admin(DEFAULT_SCHEMA_NAME)
.create_database(expr)
.await
.and_then(admin_result_to_output)
.context(CreateDatabaseSnafu {
name: database_name,
})
}
}
/// Handle alter expr
pub async fn handle_alter(&self, expr: AlterExpr) -> Result<Output> {
match &self.dist_instance {
Some(dist_instance) => dist_instance.handle_alter_table(expr).await,
None => self
.admin(expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME))
.alter(expr)
.await
.and_then(admin_result_to_output)
.context(AlterTableSnafu),
}
}
/// Handle drop table expr
pub async fn handle_drop_table(&self, expr: DropTableExpr) -> Result<Output> {
match self.mode {
Mode::Standalone => self
.admin(&expr.schema_name)
.drop_table(expr)
.await
.and_then(admin_result_to_output)
.context(DropTableSnafu),
// TODO(ruihang): support drop table in distributed mode
Mode::Distributed => UnsupportedExprSnafu {
name: "Distributed DROP TABLE",
}
.fail(),
}
}
/// Handle explain expr
pub async fn handle_explain(
&self,
sql: &str,
explain_stmt: Explain,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(dist_instance) = &self.dist_instance {
dist_instance
.handle_sql(sql, Statement::Explain(explain_stmt), query_ctx)
.await
} else {
Ok(Output::AffectedRows(0))
}
let expr = AdminExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(admin_expr::Expr::CreateDatabase(expr)),
};
let result = self
.grpc_admin_handler
.exec_admin_request(expr)
.await
.context(error::InvokeGrpcServerSnafu)?;
admin_result_to_output(result).context(CreateDatabaseSnafu {
name: database_name,
})
}
/// Handle batch inserts
@@ -333,7 +249,7 @@ impl Instance {
}
/// Handle insert. for 'values' insertion, create/alter the destination table on demand.
pub async fn handle_insert(&self, mut insert_expr: InsertExpr) -> Result<Output> {
async fn handle_insert(&self, mut insert_expr: InsertExpr) -> Result<Output> {
let table_name = &insert_expr.table_name;
let catalog_name = DEFAULT_CATALOG_NAME;
let schema_name = &insert_expr.schema_name;
@@ -345,11 +261,17 @@ impl Instance {
insert_expr.region_number = 0;
self.database(schema_name)
.insert(insert_expr)
let query = ObjectExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(Expr::Insert(insert_expr)),
};
let result = GrpcQueryHandler::do_query(&*self.grpc_query_handler, query)
.await
.and_then(Output::try_from)
.context(InsertSnafu)
.context(error::InvokeGrpcServerSnafu)?;
let result: ObjectResult = result.try_into().context(InsertSnafu)?;
result.try_into().context(InsertSnafu)
}
// check if table already exist:
@@ -362,21 +284,7 @@ impl Instance {
table_name: &str,
columns: &[Column],
) -> Result<()> {
match self
.catalog_manager
.as_ref()
.expect("catalog manager cannot be None")
.catalog(catalog_name)
.context(CatalogSnafu)?
.context(CatalogNotFoundSnafu { catalog_name })?
.schema(schema_name)
.context(CatalogSnafu)?
.context(SchemaNotFoundSnafu {
schema_info: schema_name,
})?
.table(table_name)
.context(CatalogSnafu)?
{
match self.find_table(catalog_name, schema_name, table_name)? {
None => {
info!(
"Table {}.{}.{} does not exist, try create table",
@@ -455,17 +363,23 @@ impl Instance {
catalog_name: Some(catalog_name.to_string()),
kind: Some(Kind::AddColumns(add_columns)),
};
self.admin(schema_name)
.alter(expr)
let expr = AdminExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(admin_expr::Expr::Alter(expr)),
};
let result = self
.grpc_admin_handler
.exec_admin_request(expr)
.await
.and_then(admin_result_to_output)
.context(AlterTableOnInsertionSnafu)
.context(error::InvokeGrpcServerSnafu)?;
admin_result_to_output(result).context(AlterTableOnInsertionSnafu)
}
fn get_catalog(&self, catalog_name: &str) -> Result<CatalogProviderRef> {
self.catalog_manager
.as_ref()
.context(error::CatalogManagerSnafu)?
.catalog(catalog_name)
.context(error::CatalogSnafu)?
.context(error::CatalogNotFoundSnafu { catalog_name })
@@ -480,6 +394,12 @@ impl Instance {
})
}
fn find_table(&self, catalog: &str, schema: &str, table: &str) -> Result<Option<TableRef>> {
self.catalog_manager
.table(catalog, schema, table)
.context(CatalogSnafu)
}
async fn sql_dist_insert(&self, insert: Box<Insert>) -> Result<usize> {
let (catalog, schema, table) = insert.full_table_name().context(error::ParseSqlSnafu)?;
@@ -519,23 +439,17 @@ impl Instance {
}
fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
let catalog_manager = &self.catalog_manager;
if let Some(catalog_manager) = catalog_manager {
ensure!(
catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { schema_info: &db }
);
ensure!(
self.catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { schema_info: &db }
);
query_ctx.set_current_schema(&db);
query_ctx.set_current_schema(&db);
Ok(Output::RecordBatches(RecordBatches::empty()))
} else {
// TODO(LFC): Handle "use" stmt here.
unimplemented!()
}
Ok(Output::RecordBatches(RecordBatches::empty()))
}
}
@@ -547,20 +461,6 @@ impl FrontendInstance for Instance {
}
}
#[cfg(test)]
impl Instance {
pub fn with_client_and_catalog_manager(client: Client, catalog: CatalogManagerRef) -> Self {
Self {
client,
catalog_manager: Some(catalog),
script_handler: None,
create_expr_factory: Arc::new(DefaultCreateExprFactory),
mode: Mode::Standalone,
dist_instance: None,
}
}
}
fn parse_stmt(sql: &str) -> Result<Statement> {
let mut stmt = ParserContext::create_with_dialect(sql, &GenericDialect {})
.context(error::ParseSqlSnafu)?;
@@ -587,12 +487,14 @@ impl SqlQueryHandler for Instance {
.context(server_error::ExecuteQuerySnafu { query })?;
match stmt {
Statement::ShowDatabases(_)
Statement::CreateDatabase(_)
| Statement::ShowDatabases(_)
| Statement::CreateTable(_)
| Statement::ShowTables(_)
| Statement::DescribeTable(_)
| Statement::Explain(_)
| Statement::Query(_) => {
self.handle_select(Select::Sql(query.to_string()), stmt, query_ctx)
.await
return self.sql_handler.do_query(query, query_ctx).await;
}
Statement::Insert(insert) => match self.mode {
Mode::Standalone => {
@@ -629,30 +531,18 @@ impl SqlQueryHandler for Instance {
Ok(Output::AffectedRows(affected))
}
},
Statement::CreateTable(create) => {
let create_expr = self
.create_expr_factory
.create_expr_by_stmt(&create)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })?;
self.handle_create_table(create_expr, create.partitions)
.await
}
Statement::CreateDatabase(c) => {
let expr = CreateDatabaseExpr {
database_name: c.name.to_string(),
};
self.handle_create_database(expr).await
}
Statement::Alter(alter_stmt) => {
self.handle_alter(
AlterExpr::try_from(alter_stmt)
.map_err(BoxedError::new)
.context(server_error::ExecuteAlterSnafu { query })?,
)
.await
let expr = AlterExpr::try_from(alter_stmt)
.map_err(BoxedError::new)
.context(server_error::ExecuteAlterSnafu { query })?;
let expr = AdminExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(admin_expr::Expr::Alter(expr)),
};
let result = self.grpc_admin_handler.exec_admin_request(expr).await?;
admin_result_to_output(result).context(error::InvalidAdminResultSnafu)
}
Statement::DropTable(drop_stmt) => {
let expr = DropTableExpr {
@@ -660,10 +550,14 @@ impl SqlQueryHandler for Instance {
schema_name: drop_stmt.schema_name,
table_name: drop_stmt.table_name,
};
self.handle_drop_table(expr).await
}
Statement::Explain(explain_stmt) => {
self.handle_explain(query, explain_stmt, query_ctx).await
let expr = AdminExpr {
header: Some(ExprHeader {
version: PROTOCOL_VERSION,
}),
expr: Some(admin_expr::Expr::DropTable(expr)),
};
let result = self.grpc_admin_handler.exec_admin_request(expr).await?;
admin_result_to_output(result).context(error::InvalidAdminResultSnafu)
}
Statement::ShowCreateTable(_) => {
return server_error::NotSupportedSnafu { feat: query }.fail();
@@ -703,81 +597,34 @@ impl ScriptHandler for Instance {
#[async_trait]
impl GrpcQueryHandler for Instance {
async fn do_query(&self, query: ObjectExpr) -> server_error::Result<GrpcObjectResult> {
if let Some(expr) = &query.expr {
match expr {
Expr::Insert(insert) => {
// TODO(fys): refactor, avoid clone
let result = self.handle_insert(insert.clone()).await;
result
.map(|o| match o {
Output::AffectedRows(rows) => ObjectResultBuilder::new()
.status_code(StatusCode::Success as u32)
.mutate_result(rows as u32, 0u32)
.build(),
_ => {
unreachable!()
}
})
.map_err(BoxedError::new)
.with_context(|_| server_error::ExecuteQuerySnafu {
query: format!("{:?}", query),
})
}
Expr::Select(select) => {
let select = select
.expr
.as_ref()
.context(server_error::InvalidQuerySnafu {
reason: "empty query",
})?;
match select {
select_expr::Expr::Sql(sql) => {
let query_ctx = Arc::new(QueryContext::new());
let output = SqlQueryHandler::do_query(self, sql, query_ctx).await;
Ok(to_object_result(output).await)
}
_ => {
if self.dist_instance.is_some() {
return server_error::NotSupportedSnafu {
feat: "Executing plan directly in Frontend.",
}
.fail();
}
// FIXME(hl): refactor
self.database(DEFAULT_SCHEMA_NAME)
.object(query.clone())
.await
.map_err(BoxedError::new)
.with_context(|_| server_error::ExecuteQuerySnafu {
query: format!("{:?}", query),
})
}
}
}
_ => server_error::NotSupportedSnafu {
feat: "Currently only insert and select is supported in GRPC service.",
}
.fail(),
let expr = query
.clone()
.expr
.context(server_error::InvalidQuerySnafu {
reason: "empty expr",
})?;
match expr {
Expr::Insert(insert_expr) => {
let output = self
.handle_insert(insert_expr.clone())
.await
.map_err(BoxedError::new)
.with_context(|_| server_error::ExecuteQuerySnafu {
query: format!("{:?}", insert_expr),
})?;
let object_result = match output {
Output::AffectedRows(rows) => ObjectResultBuilder::default()
.mutate_result(rows as _, 0)
.build(),
_ => unreachable!(),
};
Ok(object_result)
}
} else {
server_error::InvalidQuerySnafu {
reason: "empty query",
}
.fail()
_ => GrpcQueryHandler::do_query(&*self.grpc_query_handler, query).await,
}
}
}
fn get_schema_name(expr: &AdminExpr) -> &str {
let schema_name = match &expr.expr {
Some(admin_expr::Expr::Create(expr)) => expr.schema_name.as_deref(),
Some(admin_expr::Expr::Alter(expr)) => expr.schema_name.as_deref(),
Some(admin_expr::Expr::CreateDatabase(_)) | None => Some(DEFAULT_SCHEMA_NAME),
Some(admin_expr::Expr::DropTable(expr)) => Some(expr.schema_name.as_ref()),
};
schema_name.unwrap_or(DEFAULT_SCHEMA_NAME)
}
#[async_trait]
impl GrpcAdminHandler for Instance {
async fn exec_admin_request(&self, mut expr: AdminExpr) -> server_error::Result<AdminResult> {
@@ -786,13 +633,7 @@ impl GrpcAdminHandler for Instance {
if let Some(api::v1::admin_expr::Expr::Create(create)) = &mut expr.expr {
create.table_id = None;
}
self.admin(get_schema_name(&expr))
.do_request(expr.clone())
.await
.map_err(BoxedError::new)
.with_context(|_| server_error::ExecuteQuerySnafu {
query: format!("{:?}", expr),
})
self.grpc_admin_handler.exec_admin_request(expr).await
}
}
@@ -808,15 +649,16 @@ mod tests {
};
use datatypes::schema::ColumnDefaultConstraint;
use datatypes::value::Value;
use session::context::QueryContext;
use super::*;
use crate::tests;
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_sql() {
let query_ctx = Arc::new(QueryContext::new());
let instance = tests::create_frontend_instance().await;
let (instance, _guard) = tests::create_frontend_instance("test_execute_sql").await;
let sql = r#"CREATE TABLE demo(
host STRING,
@@ -853,7 +695,8 @@ mod tests {
.await
.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
Output::Stream(stream) => {
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let pretty_print = recordbatches.pretty_print();
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
let expected = vec![
@@ -875,7 +718,8 @@ mod tests {
.await
.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
Output::Stream(stream) => {
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let pretty_print = recordbatches.pretty_print();
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
let expected = vec![
@@ -892,9 +736,9 @@ mod tests {
};
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_grpc() {
let instance = tests::create_frontend_instance().await;
let (instance, _guard) = tests::create_frontend_instance("test_execute_grpc").await;
// testing data:
let expected_host_col = Column {

View File

@@ -16,12 +16,18 @@ use std::collections::HashMap;
use std::sync::Arc;
use api::helper::ColumnDataTypeWrapper;
use api::v1::{AlterExpr, CreateDatabaseExpr, CreateExpr};
use api::result::AdminResultBuilder;
use api::v1::{
admin_expr, AdminExpr, AdminResult, AlterExpr, CreateDatabaseExpr, CreateExpr, ObjectExpr,
ObjectResult,
};
use async_trait::async_trait;
use catalog::helper::{SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue};
use catalog::CatalogList;
use chrono::DateTime;
use client::admin::{admin_result_to_output, Admin};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::BoxedError;
use common_query::Output;
use common_telemetry::{debug, error, info};
use datatypes::prelude::ConcreteDataType;
@@ -33,6 +39,8 @@ use meta_client::rpc::{
};
use query::sql::{describe_table, explain, show_databases, show_tables};
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error as server_error;
use servers::query_handler::{GrpcAdminHandler, GrpcQueryHandler, SqlQueryHandler};
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Value as SqlValue;
@@ -48,6 +56,8 @@ use crate::error::{
PrimaryKeyNotFoundSnafu, RequestMetaSnafu, Result, SchemaNotFoundSnafu, StartMetaClientSnafu,
TableNotFoundSnafu,
};
use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory};
use crate::instance::parse_stmt;
use crate::partitioning::{PartitionBound, PartitionDef};
use crate::table::DistTable;
@@ -126,15 +136,12 @@ impl DistInstance {
.context(error::InvalidAdminResultSnafu)?;
}
Ok(Output::AffectedRows(region_routes.len()))
// Checked in real MySQL, it truly returns "0 rows affected".
Ok(Output::AffectedRows(0))
}
pub(crate) async fn handle_sql(
&self,
sql: &str,
stmt: Statement,
query_ctx: QueryContextRef,
) -> Result<Output> {
async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result<Output> {
let stmt = parse_stmt(sql)?;
match stmt {
Statement::Query(_) => {
let plan = self
@@ -143,6 +150,17 @@ impl DistInstance {
.context(error::ExecuteSqlSnafu { sql })?;
self.query_engine.execute(&plan).await
}
Statement::CreateDatabase(stmt) => {
let expr = CreateDatabaseExpr {
database_name: stmt.name.to_string(),
};
self.handle_create_database(expr).await?;
Ok(Output::AffectedRows(1))
}
Statement::CreateTable(stmt) => {
let create_expr = &mut DefaultCreateExprFactory.create_expr_by_stmt(&stmt).await?;
Ok(self.create_table(create_expr, stmt.partitions).await?)
}
Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()),
Statement::ShowTables(stmt) => {
show_tables(stmt, self.catalog_manager.clone(), query_ctx)
@@ -157,7 +175,7 @@ impl DistInstance {
}
/// Handles distributed database creation
pub(crate) async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<()> {
let key = SchemaKey {
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name: expr.database_name,
@@ -172,10 +190,10 @@ impl DistInstance {
.with_key(key.to_string())
.with_value(value.as_bytes().context(CatalogEntrySerdeSnafu)?);
client.put(request.into()).await.context(RequestMetaSnafu)?;
Ok(Output::AffectedRows(1))
Ok(())
}
pub async fn handle_alter_table(&self, expr: AlterExpr) -> Result<Output> {
async fn handle_alter_table(&self, expr: AlterExpr) -> Result<AdminResult> {
let catalog_name = expr.catalog_name.as_deref().unwrap_or(DEFAULT_CATALOG_NAME);
let schema_name = expr.schema_name.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME);
let table_name = expr.table_name.as_str();
@@ -200,7 +218,7 @@ impl DistInstance {
.downcast_ref::<DistTable>()
.expect("Table impl must be DistTable in distributed mode");
dist_table.alter_by_expr(expr).await?;
Ok(Output::AffectedRows(0))
Ok(AdminResultBuilder::default().mutate_result(0, 0).build())
}
async fn create_table_in_meta(
@@ -269,6 +287,56 @@ impl DistInstance {
}
Ok(())
}
#[cfg(test)]
pub(crate) fn catalog_manager(&self) -> Arc<FrontendCatalogManager> {
self.catalog_manager.clone()
}
}
#[async_trait]
impl SqlQueryHandler for DistInstance {
async fn do_query(
&self,
query: &str,
query_ctx: QueryContextRef,
) -> server_error::Result<Output> {
self.handle_sql(query, query_ctx)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
}
}
#[async_trait]
impl GrpcQueryHandler for DistInstance {
async fn do_query(&self, _: ObjectExpr) -> server_error::Result<ObjectResult> {
unimplemented!()
}
}
#[async_trait]
impl GrpcAdminHandler for DistInstance {
async fn exec_admin_request(&self, query: AdminExpr) -> server_error::Result<AdminResult> {
let expr = query
.clone()
.expr
.context(server_error::InvalidQuerySnafu {
reason: "empty expr",
})?;
match expr {
admin_expr::Expr::CreateDatabase(create_database) => self
.handle_create_database(create_database)
.await
.map(|_| AdminResultBuilder::default().mutate_result(1, 0).build()),
admin_expr::Expr::Alter(alter) => self.handle_alter_table(alter).await,
_ => unimplemented!(),
}
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu {
query: format!("{:?}", query),
})
}
}
fn create_table_global_value(
@@ -454,12 +522,15 @@ fn find_partition_columns(
#[cfg(test)]
mod test {
use servers::query_handler::SqlQueryHandlerRef;
use session::context::QueryContext;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use super::*;
use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory};
use crate::tests::create_dist_instance;
#[tokio::test]
async fn test_parse_partitions() {
@@ -492,9 +563,10 @@ ENGINE=mito",
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
match &result[0] {
Statement::CreateTable(c) => {
common_telemetry::info!("{}", sql);
let factory = DefaultCreateExprFactory {};
let expr = factory.create_expr_by_stmt(c).await.unwrap();
let expr = DefaultCreateExprFactory
.create_expr_by_stmt(c)
.await
.unwrap();
let partitions = parse_partitions(&expr, c.partitions.clone()).unwrap();
let json = serde_json::to_string(&partitions).unwrap();
assert_eq!(json, expected);
@@ -503,4 +575,103 @@ ENGINE=mito",
}
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_show_databases() {
let (dist_instance, _) = create_dist_instance().await;
let sql = "create database test_show_databases";
let output = dist_instance
.handle_sql(sql, QueryContext::arc())
.await
.unwrap();
match output {
Output::AffectedRows(rows) => assert_eq!(rows, 1),
_ => unreachable!(),
}
let sql = "show databases";
let output = dist_instance
.handle_sql(sql, QueryContext::arc())
.await
.unwrap();
match output {
Output::RecordBatches(r) => {
let expected1 = vec![
"+---------------------+",
"| Schemas |",
"+---------------------+",
"| public |",
"| test_show_databases |",
"+---------------------+",
];
let expected2 = vec![
"+---------------------+",
"| Schemas |",
"+---------------------+",
"| test_show_databases |",
"| public |",
"+---------------------+",
];
let pretty = r.pretty_print();
let lines = pretty.lines().collect::<Vec<_>>();
assert!(lines == expected1 || lines == expected2)
}
_ => unreachable!(),
}
}
#[tokio::test(flavor = "multi_thread")]
async fn test_show_tables() {
let (dist_instance, datanode_instances) = create_dist_instance().await;
let sql = "create database test_show_tables";
dist_instance
.handle_sql(sql, QueryContext::arc())
.await
.unwrap();
let sql = "
CREATE TABLE greptime.test_show_tables.dist_numbers (
ts BIGINT,
n INT,
TIME INDEX (ts),
)
PARTITION BY RANGE COLUMNS (n) (
PARTITION r0 VALUES LESS THAN (10),
PARTITION r1 VALUES LESS THAN (20),
PARTITION r2 VALUES LESS THAN (50),
PARTITION r3 VALUES LESS THAN (MAXVALUE),
)
ENGINE=mito";
dist_instance
.handle_sql(sql, QueryContext::arc())
.await
.unwrap();
async fn assert_show_tables(instance: SqlQueryHandlerRef) {
let sql = "show tables in test_show_tables";
let output = instance.do_query(sql, QueryContext::arc()).await.unwrap();
match output {
Output::RecordBatches(r) => {
let expected = vec![
"+--------------+",
"| Tables |",
"+--------------+",
"| dist_numbers |",
"+--------------+",
];
assert_eq!(r.pretty_print().lines().collect::<Vec<_>>(), expected);
}
_ => unreachable!(),
}
}
assert_show_tables(Arc::new(dist_instance)).await;
// Asserts that new table is created in Datanode as well.
for x in datanode_instances.values() {
assert_show_tables(x.clone()).await
}
}
}

View File

@@ -19,7 +19,6 @@ use servers::query_handler::OpentsdbProtocolHandler;
use servers::{error as server_error, Mode};
use snafu::prelude::*;
use crate::error::Result;
use crate::instance::Instance;
#[async_trait]
@@ -29,12 +28,7 @@ impl OpentsdbProtocolHandler for Instance {
// metric table and tags can be created upon insertion.
match self.mode {
Mode::Standalone => {
self.insert_opentsdb_metric(data_point)
.await
.map_err(BoxedError::new)
.with_context(|_| server_error::PutOpentsdbDataPointSnafu {
data_point: format!("{:?}", data_point),
})?;
self.insert_opentsdb_metric(data_point).await?;
}
Mode::Distributed => {
self.dist_insert(vec![data_point.as_grpc_insert()])
@@ -51,9 +45,14 @@ impl OpentsdbProtocolHandler for Instance {
}
impl Instance {
async fn insert_opentsdb_metric(&self, data_point: &DataPoint) -> Result<()> {
let expr = data_point.as_grpc_insert();
self.handle_insert(expr).await?;
async fn insert_opentsdb_metric(&self, data_point: &DataPoint) -> server_error::Result<()> {
let insert_expr = data_point.as_grpc_insert();
self.handle_insert(insert_expr)
.await
.map_err(BoxedError::new)
.with_context(|_| server_error::ExecuteQuerySnafu {
query: format!("{:?}", data_point),
})?;
Ok(())
}
}
@@ -63,6 +62,7 @@ mod tests {
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::RecordBatches;
use datafusion::arrow_print;
use servers::query_handler::SqlQueryHandler;
use session::context::QueryContext;
@@ -70,9 +70,9 @@ mod tests {
use super::*;
use crate::tests;
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_exec() {
let instance = tests::create_frontend_instance().await;
let (instance, _guard) = tests::create_frontend_instance("test_exec").await;
instance
.exec(
&DataPoint::try_create(
@@ -88,9 +88,10 @@ mod tests {
.unwrap();
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_insert_opentsdb_metric() {
let instance = tests::create_frontend_instance().await;
let (instance, _guard) =
tests::create_frontend_instance("test_insert_opentsdb_metric").await;
let data_point1 = DataPoint::new(
"my_metric_1".to_string(),
@@ -124,11 +125,15 @@ mod tests {
assert!(result.is_ok());
let output = instance
.do_query("select * from my_metric_1", Arc::new(QueryContext::new()))
.do_query(
"select * from my_metric_1 order by greptime_timestamp",
Arc::new(QueryContext::new()),
)
.await
.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
Output::Stream(stream) => {
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
let recordbatches = recordbatches
.take()
.into_iter()

View File

@@ -17,11 +17,10 @@ use std::sync::Arc;
use api::prometheus::remote::read_request::ResponseType;
use api::prometheus::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest};
use async_trait::async_trait;
use client::{ObjectResult, Select};
use client::ObjectResult;
use common_error::prelude::BoxedError;
use common_grpc::select::to_object_result;
use common_telemetry::logging;
use futures_util::TryFutureExt;
use prost::Message;
use servers::error::{self, Result as ServerResult};
use servers::prometheus::{self, Metrics};
@@ -30,7 +29,7 @@ use servers::Mode;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use crate::instance::{parse_stmt, Instance};
use crate::instance::Instance;
const SAMPLES_RESPONSE_TYPE: i32 = ResponseType::Samples as i32;
@@ -94,19 +93,14 @@ impl Instance {
sql
);
let object_result = if let Some(dist_instance) = &self.dist_instance {
let output = futures::future::ready(parse_stmt(&sql))
.and_then(|stmt| {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
dist_instance.handle_sql(&sql, stmt, query_ctx)
})
.await;
to_object_result(output).await.try_into()
} else {
self.database(db).select(Select::Sql(sql.clone())).await
}
.map_err(BoxedError::new)
.context(error::ExecuteQuerySnafu { query: sql })?;
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
let output = self.sql_handler.do_query(&sql, query_ctx).await;
let object_result = to_object_result(output)
.await
.try_into()
.map_err(BoxedError::new)
.context(error::ExecuteQuerySnafu { query: sql })?;
results.push((table_name, object_result));
}
@@ -117,34 +111,25 @@ impl Instance {
#[async_trait]
impl PrometheusProtocolHandler for Instance {
async fn write(&self, database: &str, request: WriteRequest) -> ServerResult<()> {
let exprs = prometheus::write_request_to_insert_exprs(database, request.clone())?;
match self.mode {
Mode::Standalone => {
let exprs = prometheus::write_request_to_insert_exprs(database, request)?;
let futures = exprs
.into_iter()
.map(|e| self.handle_insert(e))
.collect::<Vec<_>>();
let res = futures_util::future::join_all(futures)
self.handle_inserts(exprs)
.await
.into_iter()
.collect::<Result<Vec<_>, crate::error::Error>>();
res.map_err(BoxedError::new)
.context(error::ExecuteInsertSnafu {
msg: "failed to write prometheus remote request",
.map_err(BoxedError::new)
.with_context(|_| error::ExecuteInsertSnafu {
msg: format!("{:?}", request),
})?;
}
Mode::Distributed => {
let inserts = prometheus::write_request_to_insert_exprs(database, request)?;
self.dist_insert(inserts)
self.dist_insert(exprs)
.await
.map_err(BoxedError::new)
.context(error::ExecuteInsertSnafu {
msg: "execute insert failed",
.with_context(|_| error::ExecuteInsertSnafu {
msg: format!("{:?}", request),
})?;
}
}
Ok(())
}
@@ -197,10 +182,11 @@ mod tests {
use super::*;
use crate::tests;
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn test_prometheus_remote_write_and_read() {
common_telemetry::init_default_ut_logging();
let instance = tests::create_frontend_instance().await;
let (instance, _guard) =
tests::create_frontend_instance("test_prometheus_remote_write_and_read").await;
let write_request = WriteRequest {
timeseries: prometheus::mock_timeseries(),

View File

@@ -19,7 +19,6 @@ use servers::tls::TlsOption;
pub struct PostgresOptions {
pub addr: String,
pub runtime_size: usize,
pub check_pwd: bool,
#[serde(default = "Default::default")]
pub tls: TlsOption,
}
@@ -29,7 +28,6 @@ impl Default for PostgresOptions {
Self {
addr: "127.0.0.1:4003".to_string(),
runtime_size: 2,
check_pwd: false,
tls: Default::default(),
}
}

View File

@@ -17,6 +17,7 @@ use std::sync::Arc;
use common_runtime::Builder as RuntimeBuilder;
use common_telemetry::info;
use servers::auth::UserProviderRef;
use servers::grpc::GrpcServer;
use servers::http::HttpServer;
use servers::mysql::server::MysqlServer;
@@ -35,7 +36,11 @@ use crate::prometheus::PrometheusOptions;
pub(crate) struct Services;
impl Services {
pub(crate) async fn start<T>(opts: &FrontendOptions, instance: Arc<T>) -> Result<()>
pub(crate) async fn start<T>(
opts: &FrontendOptions,
instance: Arc<T>,
user_provider: Option<UserProviderRef>,
) -> Result<()>
where
T: FrontendInstance,
{
@@ -69,8 +74,12 @@ impl Services {
.context(error::RuntimeResourceSnafu)?,
);
let mysql_server =
MysqlServer::create_server(instance.clone(), mysql_io_runtime, opts.tls.clone());
let mysql_server = MysqlServer::create_server(
instance.clone(),
mysql_io_runtime,
opts.tls.clone(),
user_provider.clone(),
);
Some((mysql_server, mysql_addr))
} else {
@@ -90,9 +99,9 @@ impl Services {
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
opts.check_pwd,
opts.tls.clone(),
pg_io_runtime,
user_provider.clone(),
)) as Box<dyn Server>;
Some((pg_server, pg_addr))
@@ -122,6 +131,10 @@ impl Services {
let http_addr = parse_addr(&http_options.addr)?;
let mut http_server = HttpServer::new(instance.clone(), http_options.clone());
if let Some(user_provider) = user_provider {
http_server.set_user_provider(user_provider);
}
if opentsdb_server_and_addr.is_some() {
http_server.set_opentsdb_handler(instance.clone());
}

View File

@@ -506,43 +506,34 @@ impl PartitionExec {
}
}
// FIXME(LFC): no allow, for clippy temporarily
#[allow(clippy::print_stdout)]
#[cfg(test)]
mod test {
use std::time::Duration;
use api::v1::column::SemanticType;
use api::v1::{column, Column, ColumnDataType};
use catalog::remote::MetaKvBackend;
use common_query::physical_plan::DfPhysicalPlanAdapter;
use common_recordbatch::adapter::RecordBatchStreamAdapter;
use common_recordbatch::util;
use datafusion::arrow_print;
use datafusion::execution::context::TaskContext;
use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::expressions::{col as physical_col, PhysicalSortExpr};
use datafusion::physical_plan::sorts::sort::SortExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_expr::expr_fn::{and, binary_expr, col, or};
use datafusion_expr::lit;
use datanode::datanode::{DatanodeOptions, ObjectStoreConfig};
use datanode::instance::Instance;
use datatypes::arrow::compute::sort::SortOptions;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use meta_client::client::{MetaClient, MetaClientBuilder};
use meta_client::client::MetaClient;
use meta_client::rpc::router::RegionRoute;
use meta_client::rpc::{Region, Table, TableRoute};
use meta_srv::metasrv::MetaSrvOptions;
use meta_srv::mocks::MockInfo;
use meta_srv::service::store::kv::KvStoreRef;
use meta_srv::service::store::memory::MemStore;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
use sql::statements::statement::Statement;
use table::metadata::{TableInfoBuilder, TableMetaBuilder};
use table::TableRef;
use tempdir::TempDir;
use super::*;
use crate::catalog::FrontendCatalogManager;
use crate::expr_factory::{CreateExprFactory, DefaultCreateExprFactory};
use crate::instance::distributed::DistInstance;
use crate::partitioning::range::RangePartitionRule;
#[tokio::test(flavor = "multi_thread")]
@@ -743,28 +734,77 @@ mod test {
#[tokio::test(flavor = "multi_thread")]
async fn test_dist_table_scan() {
common_telemetry::init_default_ut_logging();
let table = Arc::new(new_dist_table().await);
// should scan all regions
// select * from numbers
let projection = None;
// select a, row_id from numbers
let projection = Some(vec![1, 2]);
let filters = vec![];
exec_table_scan(table.clone(), projection, filters, None).await;
println!();
let expected_output = vec![
"+-----+--------+",
"| a | row_id |",
"+-----+--------+",
"| 0 | 1 |",
"| 1 | 2 |",
"| 2 | 3 |",
"| 3 | 4 |",
"| 4 | 5 |",
"| 10 | 1 |",
"| 11 | 2 |",
"| 12 | 3 |",
"| 13 | 4 |",
"| 14 | 5 |",
"| 30 | 1 |",
"| 31 | 2 |",
"| 32 | 3 |",
"| 33 | 4 |",
"| 34 | 5 |",
"| 100 | 1 |",
"| 101 | 2 |",
"| 102 | 3 |",
"| 103 | 4 |",
"| 104 | 5 |",
"+-----+--------+",
];
exec_table_scan(table.clone(), projection, filters, 4, expected_output).await;
// should scan only region 1
// select a, row_id from numbers where a < 10
let projection = Some(vec![1, 2]);
let filters = vec![binary_expr(col("a"), Operator::Lt, lit(10)).into()];
exec_table_scan(table.clone(), projection, filters, None).await;
println!();
let expected_output = vec![
"+---+--------+",
"| a | row_id |",
"+---+--------+",
"| 0 | 1 |",
"| 1 | 2 |",
"| 2 | 3 |",
"| 3 | 4 |",
"| 4 | 5 |",
"+---+--------+",
];
exec_table_scan(table.clone(), projection, filters, 1, expected_output).await;
// should scan region 1 and 2
// select a, row_id from numbers where a < 15
let projection = Some(vec![1, 2]);
let filters = vec![binary_expr(col("a"), Operator::Lt, lit(15)).into()];
exec_table_scan(table.clone(), projection, filters, None).await;
println!();
let expected_output = vec![
"+----+--------+",
"| a | row_id |",
"+----+--------+",
"| 0 | 1 |",
"| 1 | 2 |",
"| 2 | 3 |",
"| 3 | 4 |",
"| 4 | 5 |",
"| 10 | 1 |",
"| 11 | 2 |",
"| 12 | 3 |",
"| 13 | 4 |",
"| 14 | 5 |",
"+----+--------+",
];
exec_table_scan(table.clone(), projection, filters, 2, expected_output).await;
// should scan region 2 and 3
// select a, row_id from numbers where a < 40 and a >= 10
@@ -774,8 +814,23 @@ mod test {
binary_expr(col("a"), Operator::GtEq, lit(10)),
)
.into()];
exec_table_scan(table.clone(), projection, filters, None).await;
println!();
let expected_output = vec![
"+----+--------+",
"| a | row_id |",
"+----+--------+",
"| 10 | 1 |",
"| 11 | 2 |",
"| 12 | 3 |",
"| 13 | 4 |",
"| 14 | 5 |",
"| 30 | 1 |",
"| 31 | 2 |",
"| 32 | 3 |",
"| 33 | 4 |",
"| 34 | 5 |",
"+----+--------+",
];
exec_table_scan(table.clone(), projection, filters, 2, expected_output).await;
// should scan all regions
// select a, row_id from numbers where a < 1000 and row_id == 1
@@ -785,42 +840,59 @@ mod test {
binary_expr(col("row_id"), Operator::Eq, lit(1)),
)
.into()];
exec_table_scan(table.clone(), projection, filters, None).await;
let expected_output = vec![
"+-----+--------+",
"| a | row_id |",
"+-----+--------+",
"| 0 | 1 |",
"| 10 | 1 |",
"| 30 | 1 |",
"| 100 | 1 |",
"+-----+--------+",
];
exec_table_scan(table.clone(), projection, filters, 4, expected_output).await;
}
async fn exec_table_scan(
table: TableRef,
projection: Option<Vec<usize>>,
filters: Vec<Expr>,
limit: Option<usize>,
expected_partitions: usize,
expected_output: Vec<&str>,
) {
let table_scan = table
.scan(&projection, filters.as_slice(), limit)
.scan(&projection, filters.as_slice(), None)
.await
.unwrap();
assert_eq!(
table_scan.output_partitioning().partition_count(),
expected_partitions
);
let task_ctx = Arc::new(TaskContext::new(
"0".to_string(),
"0".to_string(),
HashMap::new(),
HashMap::new(),
HashMap::new(),
Arc::new(RuntimeEnv::default()),
));
for partition in 0..table_scan.output_partitioning().partition_count() {
let result = table_scan.execute(partition, task_ctx.clone()).unwrap();
let recordbatches = util::collect(result).await.unwrap();
let merge =
CoalescePartitionsExec::new(Arc::new(DfPhysicalPlanAdapter(table_scan.clone())));
let df_recordbatch = recordbatches
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let sort = SortExec::try_new(
vec![PhysicalSortExpr {
expr: physical_col("a", table_scan.schema().arrow_schema()).unwrap(),
options: SortOptions::default(),
}],
Arc::new(merge),
)
.unwrap();
assert_eq!(sort.output_partitioning().partition_count(), 1);
println!("DataFusion partition {}:", partition);
let pretty_print = arrow_print::write(&df_recordbatch);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
pretty_print.iter().for_each(|x| println!("{}", x));
}
let stream = sort
.execute(0, Arc::new(RuntimeEnv::default()))
.await
.unwrap();
let stream = Box::pin(RecordBatchStreamAdapter::try_new(stream).unwrap());
let recordbatches = RecordBatches::try_collect(stream).await.unwrap();
assert_eq!(
recordbatches.pretty_print().lines().collect::<Vec<_>>(),
expected_output
);
}
async fn new_dist_table() -> DistTable {
@@ -831,52 +903,13 @@ mod test {
];
let schema = Arc::new(Schema::new(column_schemas.clone()));
let kv_store: KvStoreRef = Arc::new(MemStore::default()) as _;
let meta_srv =
meta_srv::mocks::mock(MetaSrvOptions::default(), kv_store.clone(), None).await;
let datanode_clients = Arc::new(DatanodeClients::new());
let mut datanode_instances = HashMap::new();
for datanode_id in 1..=4 {
let dn_instance = create_datanode_instance(datanode_id, meta_srv.clone()).await;
datanode_instances.insert(datanode_id, dn_instance.clone());
let (addr, client) = crate::tests::create_datanode_client(dn_instance).await;
datanode_clients
.insert_client(Peer::new(datanode_id, addr), client)
.await;
}
let MockInfo {
server_addr,
channel_manager,
} = meta_srv.clone();
let mut meta_client = MetaClientBuilder::new(1000, 0)
.enable_router()
.enable_store()
.channel_manager(channel_manager)
.build();
meta_client.start(&[&server_addr]).await.unwrap();
let meta_client = Arc::new(meta_client);
let (dist_instance, datanode_instances) = crate::tests::create_dist_instance().await;
let catalog_manager = dist_instance.catalog_manager();
let table_routes = catalog_manager.table_routes();
let datanode_clients = catalog_manager.datanode_clients();
let table_name = TableName::new("greptime", "public", "dist_numbers");
let meta_backend = Arc::new(MetaKvBackend {
client: meta_client.clone(),
});
let table_routes = Arc::new(TableRoutes::new(meta_client.clone()));
let catalog_manager = Arc::new(FrontendCatalogManager::new(
meta_backend,
table_routes.clone(),
datanode_clients.clone(),
));
let dist_instance = DistInstance::new(
meta_client.clone(),
catalog_manager,
datanode_clients.clone(),
);
let sql = "
CREATE TABLE greptime.public.dist_numbers (
ts BIGINT,
@@ -900,17 +933,16 @@ mod test {
_ => unreachable!(),
};
wait_datanodes_alive(kv_store).await;
let factory = DefaultCreateExprFactory {};
let mut expr = factory.create_expr_by_stmt(&create_table).await.unwrap();
let mut expr = DefaultCreateExprFactory
.create_expr_by_stmt(&create_table)
.await
.unwrap();
let _result = dist_instance
.create_table(&mut expr, create_table.partitions)
.await
.unwrap();
let table_route = table_routes.get_route(&table_name).await.unwrap();
println!("{}", serde_json::to_string_pretty(&table_route).unwrap());
let mut region_to_datanode_mapping = HashMap::new();
for region_route in table_route.region_routes.iter() {
@@ -955,20 +987,6 @@ mod test {
}
}
async fn wait_datanodes_alive(kv_store: KvStoreRef) {
let wait = 10;
for _ in 0..wait {
let datanodes = meta_srv::lease::alive_datanodes(1000, &kv_store, |_, _| true)
.await
.unwrap();
if datanodes.len() >= 4 {
return;
}
tokio::time::sleep(Duration::from_secs(1)).await
}
panic!()
}
async fn insert_testing_data(
table_name: &TableName,
dn_instance: Arc<Instance>,
@@ -1020,30 +1038,6 @@ mod test {
.unwrap();
}
async fn create_datanode_instance(datanode_id: u64, meta_srv: MockInfo) -> Arc<Instance> {
let current = common_time::util::current_time_millis();
let wal_tmp_dir =
TempDir::new_in("/tmp", &format!("dist_table_test-wal-{}", current)).unwrap();
let data_tmp_dir =
TempDir::new_in("/tmp", &format!("dist_table_test-data-{}", current)).unwrap();
let opts = DatanodeOptions {
node_id: Some(datanode_id),
wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(),
storage: ObjectStoreConfig::File {
data_dir: data_tmp_dir.path().to_str().unwrap().to_string(),
},
..Default::default()
};
let instance = Arc::new(
Instance::with_mock_meta_server(&opts, meta_srv)
.await
.unwrap(),
);
instance.start().await.unwrap();
instance
}
#[tokio::test(flavor = "multi_thread")]
async fn test_find_regions() {
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(

View File

@@ -12,34 +12,70 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use catalog::remote::MetaKvBackend;
use client::Client;
use common_grpc::channel_manager::ChannelManager;
use common_runtime::Builder as RuntimeBuilder;
use datanode::datanode::{DatanodeOptions, ObjectStoreConfig};
use datanode::instance::Instance as DatanodeInstance;
use meta_client::client::MetaClientBuilder;
use meta_client::rpc::Peer;
use meta_srv::metasrv::MetaSrvOptions;
use meta_srv::mocks::MockInfo;
use meta_srv::service::store::kv::KvStoreRef;
use meta_srv::service::store::memory::MemStore;
use servers::grpc::GrpcServer;
use servers::Mode;
use tempdir::TempDir;
use tonic::transport::Server;
use tower::service_fn;
use crate::catalog::FrontendCatalogManager;
use crate::datanode::DatanodeClients;
use crate::instance::distributed::DistInstance;
use crate::instance::Instance;
use crate::table::route::TableRoutes;
async fn create_datanode_instance() -> Arc<DatanodeInstance> {
// TODO(LFC) Use real Mito engine when we can alter its region schema,
// and delete the `new_mock` method.
let instance = Arc::new(DatanodeInstance::new_mock().await.unwrap());
instance.start().await.unwrap();
instance
/// Guard against the `TempDir`s that used in unit tests.
/// (The `TempDir` will be deleted once it goes out of scope.)
pub struct TestGuard {
_wal_tmp_dir: TempDir,
_data_tmp_dir: TempDir,
}
pub(crate) async fn create_frontend_instance() -> Arc<Instance> {
let datanode_instance: Arc<DatanodeInstance> = create_datanode_instance().await;
let dn_catalog_manager = datanode_instance.catalog_manager().clone();
let (_, client) = create_datanode_client(datanode_instance).await;
Arc::new(Instance::with_client_and_catalog_manager(
client,
dn_catalog_manager,
))
pub(crate) async fn create_frontend_instance(test_name: &str) -> (Arc<Instance>, TestGuard) {
let (opts, guard) = create_tmp_dir_and_datanode_opts(test_name);
let datanode_instance = DatanodeInstance::with_mock_meta_client(&opts)
.await
.unwrap();
datanode_instance.start().await.unwrap();
let frontend_instance = Instance::new_standalone(Arc::new(datanode_instance));
(Arc::new(frontend_instance), guard)
}
fn create_tmp_dir_and_datanode_opts(name: &str) -> (DatanodeOptions, TestGuard) {
let wal_tmp_dir = TempDir::new(&format!("gt_wal_{}", name)).unwrap();
let data_tmp_dir = TempDir::new(&format!("gt_data_{}", name)).unwrap();
let opts = DatanodeOptions {
wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(),
storage: ObjectStoreConfig::File {
data_dir: data_tmp_dir.path().to_str().unwrap().to_string(),
},
mode: Mode::Standalone,
..Default::default()
};
(
opts,
TestGuard {
_wal_tmp_dir: wal_tmp_dir,
_data_tmp_dir: data_tmp_dir,
},
)
}
pub(crate) async fn create_datanode_client(
@@ -96,3 +132,91 @@ pub(crate) async fn create_datanode_client(
Client::with_manager_and_urls(channel_manager, vec![addr]),
)
}
async fn create_dist_datanode_instance(
datanode_id: u64,
meta_srv: MockInfo,
) -> Arc<DatanodeInstance> {
let current = common_time::util::current_time_millis();
let wal_tmp_dir = TempDir::new_in("/tmp", &format!("dist_datanode-wal-{}", current)).unwrap();
let data_tmp_dir = TempDir::new_in("/tmp", &format!("dist_datanode-data-{}", current)).unwrap();
let opts = DatanodeOptions {
node_id: Some(datanode_id),
wal_dir: wal_tmp_dir.path().to_str().unwrap().to_string(),
storage: ObjectStoreConfig::File {
data_dir: data_tmp_dir.path().to_str().unwrap().to_string(),
},
..Default::default()
};
let instance = Arc::new(
DatanodeInstance::with_mock_meta_server(&opts, meta_srv)
.await
.unwrap(),
);
instance.start().await.unwrap();
instance
}
async fn wait_datanodes_alive(kv_store: KvStoreRef) {
let wait = 10;
for _ in 0..wait {
let datanodes = meta_srv::lease::alive_datanodes(1000, &kv_store, |_, _| true)
.await
.unwrap();
if datanodes.len() >= 4 {
return;
}
tokio::time::sleep(Duration::from_secs(1)).await
}
panic!()
}
pub(crate) async fn create_dist_instance() -> (DistInstance, HashMap<u64, Arc<DatanodeInstance>>) {
let kv_store: KvStoreRef = Arc::new(MemStore::default()) as _;
let meta_srv = meta_srv::mocks::mock(MetaSrvOptions::default(), kv_store.clone(), None).await;
let datanode_clients = Arc::new(DatanodeClients::new());
let mut datanode_instances = HashMap::new();
for datanode_id in 1..=4 {
let dn_instance = create_dist_datanode_instance(datanode_id, meta_srv.clone()).await;
datanode_instances.insert(datanode_id, dn_instance.clone());
let (addr, client) = create_datanode_client(dn_instance).await;
datanode_clients
.insert_client(Peer::new(datanode_id, addr), client)
.await;
}
let MockInfo {
server_addr,
channel_manager,
} = meta_srv.clone();
let mut meta_client = MetaClientBuilder::new(1000, 0)
.enable_router()
.enable_store()
.channel_manager(channel_manager)
.build();
meta_client.start(&[&server_addr]).await.unwrap();
let meta_client = Arc::new(meta_client);
let meta_backend = Arc::new(MetaKvBackend {
client: meta_client.clone(),
});
let table_routes = Arc::new(TableRoutes::new(meta_client.clone()));
let catalog_manager = Arc::new(FrontendCatalogManager::new(
meta_backend,
table_routes.clone(),
datanode_clients.clone(),
));
wait_datanodes_alive(kv_store).await;
let dist_instance = DistInstance::new(
meta_client.clone(),
catalog_manager,
datanode_clients.clone(),
);
(dist_instance, datanode_instances)
}

View File

@@ -23,6 +23,7 @@ snafu = { version = "0.7", features = ["backtraces"] }
store-api = { path = "../store-api" }
tempdir = "0.3"
tokio = { version = "1.18", features = ["full"] }
tokio-util = "0.7"
[dev-dependencies]
rand = "0.8"

View File

@@ -17,6 +17,7 @@ use std::any::Any;
use common_error::ext::BoxedError;
use common_error::prelude::{ErrorExt, Snafu};
use snafu::{Backtrace, ErrorCompat};
use tokio::task::JoinError;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
@@ -89,6 +90,15 @@ pub enum Error {
#[snafu(display("Failed while waiting for write to finish, source: {}", source))]
WaitWrite { source: tokio::task::JoinError },
#[snafu(display("Invalid logstore status, msg: {}", msg))]
InvalidState { msg: String, backtrace: Backtrace },
#[snafu(display("Failed to wait for gc task to stop, source: {}", source))]
WaitGcTaskStop {
source: JoinError,
backtrace: Backtrace,
},
}
impl ErrorExt for Error {

View File

@@ -12,11 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
#[derive(Debug, Clone)]
pub struct LogConfig {
pub append_buffer_size: usize,
pub max_log_file_size: usize,
pub log_file_dir: String,
pub gc_interval: Duration,
}
impl Default for LogConfig {
@@ -27,6 +30,7 @@ impl Default for LogConfig {
append_buffer_size: 128,
max_log_file_size: 1024 * 1024 * 1024,
log_file_dir: "/tmp/greptimedb".to_string(),
gc_interval: Duration::from_secs(10 * 60),
}
}
}
@@ -44,5 +48,6 @@ mod tests {
info!("LogConfig::default(): {:?}", default);
assert_eq!(1024 * 1024 * 1024, default.max_log_file_size);
assert_eq!(128, default.append_buffer_size);
assert_eq!(Duration::from_secs(600), default.gc_interval);
}
}

View File

@@ -55,11 +55,12 @@ const LOG_WRITER_BATCH_SIZE: usize = 16;
/// Wraps File operation to get rid of `&mut self` requirements
struct FileWriter {
inner: Arc<File>,
path: String,
}
impl FileWriter {
pub fn new(file: Arc<File>) -> Self {
Self { inner: file }
pub fn new(file: Arc<File>, path: String) -> Self {
Self { inner: file, path }
}
pub async fn write(&self, data: Bytes, offset: u64) -> Result<()> {
@@ -100,6 +101,11 @@ impl FileWriter {
.await
.context(WaitWriteSnafu)?
}
pub async fn destroy(&self) -> Result<()> {
tokio::fs::remove_file(&self.path).await.context(IoSnafu)?;
Ok(())
}
}
pub type LogFileRef = Arc<LogFile>;
@@ -128,7 +134,7 @@ pub struct LogFile {
impl Drop for LogFile {
fn drop(&mut self) {
self.state.stopped.store(true, Ordering::Relaxed);
info!("Stopping log file {}", self.name);
info!("Dropping log file {}", self.name);
}
}
@@ -143,12 +149,12 @@ impl LogFile {
.open(path.clone())
.context(OpenLogSnafu { file_name: &path })?;
let file_name: FileName = path.as_str().try_into()?;
let file_name = FileName::try_from(path.as_str())?;
let start_entry_id = file_name.entry_id();
let mut log = Self {
name: file_name,
writer: Arc::new(FileWriter::new(Arc::new(file))),
writer: Arc::new(FileWriter::new(Arc::new(file), path.clone())),
start_entry_id,
pending_request_tx: None,
notify: Arc::new(Notify::new()),
@@ -243,6 +249,11 @@ impl LogFile {
res
}
pub async fn destroy(&self) -> Result<()> {
self.writer.destroy().await?;
Ok(())
}
async fn handle_batch(
mut batch: Vec<AppendRequest>,
state: &Arc<State>,
@@ -477,6 +488,11 @@ impl LogFile {
self.state.sealed.load(Ordering::Acquire)
}
#[inline]
pub fn is_stopped(&self) -> bool {
self.state.stopped.load(Ordering::Acquire)
}
#[inline]
pub fn unseal(&self) {
self.state.sealed.store(false, Ordering::Release);

View File

@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};
use std::path::Path;
use std::sync::Arc;
use arc_swap::ArcSwap;
use async_stream::stream;
use common_telemetry::{error, info, warn};
use common_telemetry::{debug, error, info, warn};
use futures::{pin_mut, StreamExt};
use snafu::{OptionExt, ResultExt};
use store_api::logstore::entry::{Encode, Entry, Id};
use store_api::logstore::entry_stream::SendableEntryStream;
use store_api::logstore::namespace::{Id as NamespaceId, Namespace};
use store_api::logstore::LogStore;
use tokio::sync::RwLock;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::error::{
CreateDirSnafu, DuplicateFileSnafu, Error, FileNameIllegalSnafu, InternalSnafu, IoSnafu,
ReadPathSnafu, Result,
CreateDirSnafu, DuplicateFileSnafu, Error, FileNameIllegalSnafu, InternalSnafu,
InvalidStateSnafu, IoSnafu, ReadPathSnafu, Result, WaitGcTaskStopSnafu,
};
use crate::fs::config::LogConfig;
use crate::fs::entry::EntryImpl;
@@ -42,9 +44,12 @@ type FileMap = BTreeMap<u64, LogFileRef>;
#[derive(Debug)]
pub struct LocalFileLogStore {
files: RwLock<FileMap>,
files: Arc<RwLock<FileMap>>,
active: ArcSwap<LogFile>,
config: LogConfig,
obsolete_ids: Arc<RwLock<HashMap<LocalNamespace, u64>>>,
cancel_token: Mutex<Option<CancellationToken>>,
gc_task_handle: Mutex<Option<JoinHandle<()>>>,
}
impl LocalFileLogStore {
@@ -101,9 +106,12 @@ impl LocalFileLogStore {
let active_file_cloned = active_file.clone();
Ok(Self {
files: RwLock::new(files),
files: Arc::new(RwLock::new(files)),
active: ArcSwap::new(active_file_cloned),
config: config.clone(),
obsolete_ids: Arc::new(Default::default()),
cancel_token: Mutex::new(None),
gc_task_handle: Mutex::new(None),
})
}
@@ -185,6 +193,60 @@ impl LocalFileLogStore {
}
}
async fn gc(
files: Arc<RwLock<FileMap>>,
obsolete_ids: Arc<RwLock<HashMap<LocalNamespace, u64>>>,
) -> Result<()> {
if let Some(lowest) = find_lowest_id(obsolete_ids).await {
gc_inner(files, lowest).await
} else {
Ok(())
}
}
async fn find_lowest_id(obsolete_ids: Arc<RwLock<HashMap<LocalNamespace, u64>>>) -> Option<u64> {
let mut lowest_obsolete = None;
{
let obsolete_ids = obsolete_ids.read().await;
for (ns, id) in obsolete_ids.iter() {
if *id <= *lowest_obsolete.get_or_insert(*id) {
lowest_obsolete = Some(*id);
debug!("Current lowest obsolete id: {}, namespace: {:?}", *id, ns);
}
}
}
lowest_obsolete
}
async fn gc_inner(files: Arc<RwLock<FileMap>>, obsolete_id: u64) -> Result<()> {
let mut files = files.write().await;
let files_to_delete = find_files_to_delete(&files, obsolete_id);
info!(
"Compacting log file up to entry id: {}, files to delete: {:?}",
obsolete_id, files_to_delete
);
for entry_id in files_to_delete {
if let Some(f) = files.remove(&entry_id) {
if !f.is_stopped() {
f.stop().await?;
}
f.destroy().await?;
info!("Destroyed log file: {}", f.file_name());
}
}
Ok(())
}
fn find_files_to_delete<T>(offset_map: &BTreeMap<u64, T>, entry_id: u64) -> Vec<u64> {
let mut res = vec![];
for (cur, next) in offset_map.keys().zip(offset_map.keys().skip(1)) {
if *cur < entry_id && *next <= entry_id {
res.push(*cur);
}
}
res
}
#[async_trait::async_trait]
impl LogStore for LocalFileLogStore {
type Error = Error;
@@ -192,6 +254,55 @@ impl LogStore for LocalFileLogStore {
type Entry = EntryImpl;
type AppendResponse = AppendResponseImpl;
async fn start(&self) -> Result<()> {
let files = self.files.clone();
let obsolete_ids = self.obsolete_ids.clone();
let interval = self.config.gc_interval;
let token = tokio_util::sync::CancellationToken::new();
let child = token.child_token();
let handle = common_runtime::spawn_bg(async move {
loop {
if let Err(e) = gc(files.clone(), obsolete_ids.clone()).await {
error!(e; "Failed to gc log store");
}
tokio::select! {
_ = tokio::time::sleep(interval) => {}
_ = child.cancelled() => {
info!("LogStore gc task has been cancelled");
return;
}
}
}
});
*self.gc_task_handle.lock().await = Some(handle);
*self.cancel_token.lock().await = Some(token);
Ok(())
}
async fn stop(&self) -> Result<()> {
let handle = self
.gc_task_handle
.lock()
.await
.take()
.context(InvalidStateSnafu {
msg: "Logstore gc task not spawned",
})?;
let token = self
.cancel_token
.lock()
.await
.take()
.context(InvalidStateSnafu {
msg: "Logstore gc task not spawned",
})?;
token.cancel();
Ok(handle.await.context(WaitGcTaskStopSnafu)?)
}
async fn append(&self, mut entry: Self::Entry) -> Result<Self::AppendResponse> {
// TODO(hl): configurable retry times
for _ in 0..3 {
@@ -280,10 +391,25 @@ impl LogStore for LocalFileLogStore {
fn namespace(&self, id: NamespaceId) -> Self::Namespace {
LocalNamespace::new(id)
}
async fn obsolete(
&self,
namespace: Self::Namespace,
id: Id,
) -> std::result::Result<(), Self::Error> {
info!("Mark namespace obsolete entry id, {:?}:{}", namespace, id);
let mut map = self.obsolete_ids.write().await;
let prev = map.insert(namespace, id);
info!("Prev: {:?}", prev);
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use std::time::Duration;
use futures_util::StreamExt;
use rand::distributions::Alphanumeric;
use rand::Rng;
@@ -300,6 +426,7 @@ mod tests {
append_buffer_size: 128,
max_log_file_size: 128,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
@@ -351,6 +478,7 @@ mod tests {
append_buffer_size: 128,
max_log_file_size: 128,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
let ns = LocalNamespace::new(42);
@@ -382,6 +510,7 @@ mod tests {
append_buffer_size: 128,
max_log_file_size: 1024 * 1024,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
assert_eq!(
@@ -426,4 +555,217 @@ mod tests {
assert_eq!(entries[0].id(), 1);
assert_eq!(43, entries[0].namespace_id);
}
#[test]
fn test_find_files_to_delete() {
let file_map = vec![(1u64, ()), (11u64, ()), (21u64, ()), (31u64, ())]
.into_iter()
.collect::<BTreeMap<u64, ()>>();
assert!(find_files_to_delete(&file_map, 0).is_empty());
assert!(find_files_to_delete(&file_map, 1).is_empty());
assert!(find_files_to_delete(&file_map, 2).is_empty());
assert!(find_files_to_delete(&file_map, 10).is_empty());
assert_eq!(vec![1], find_files_to_delete(&file_map, 11));
assert_eq!(vec![1], find_files_to_delete(&file_map, 20));
assert_eq!(vec![1, 11], find_files_to_delete(&file_map, 21));
assert_eq!(vec![1, 11, 21], find_files_to_delete(&file_map, 31));
assert_eq!(vec![1, 11, 21], find_files_to_delete(&file_map, 100));
}
#[tokio::test]
async fn test_find_lowest_id() {
common_telemetry::logging::init_default_ut_logging();
let dir = TempDir::new("greptimedb-log-compact").unwrap();
let config = LogConfig {
append_buffer_size: 128,
max_log_file_size: 4096,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
assert!(find_lowest_id(logstore.obsolete_ids.clone())
.await
.is_none());
logstore
.obsolete(LocalNamespace::new(1), 100)
.await
.unwrap();
assert_eq!(
Some(100),
find_lowest_id(logstore.obsolete_ids.clone()).await
);
logstore
.obsolete(LocalNamespace::new(2), 200)
.await
.unwrap();
assert_eq!(
Some(100),
find_lowest_id(logstore.obsolete_ids.clone()).await
);
logstore
.obsolete(LocalNamespace::new(1), 101)
.await
.unwrap();
assert_eq!(
Some(101),
find_lowest_id(logstore.obsolete_ids.clone()).await
);
logstore
.obsolete(LocalNamespace::new(2), 202)
.await
.unwrap();
assert_eq!(
Some(101),
find_lowest_id(logstore.obsolete_ids.clone()).await
);
logstore
.obsolete(LocalNamespace::new(1), 300)
.await
.unwrap();
assert_eq!(
Some(202),
find_lowest_id(logstore.obsolete_ids.clone()).await
);
}
#[tokio::test]
async fn test_compact_log_file() {
common_telemetry::logging::init_default_ut_logging();
let dir = TempDir::new("greptimedb-log-compact").unwrap();
let config = LogConfig {
append_buffer_size: 128,
max_log_file_size: 4096,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
for id in 0..50 {
logstore
.append(EntryImpl::new(
generate_data(990),
id,
LocalNamespace::new(42),
))
.await
.unwrap();
}
assert_eq!(
vec![0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48],
logstore
.files
.read()
.await
.keys()
.copied()
.collect::<Vec<_>>()
);
gc_inner(logstore.files.clone(), 10).await.unwrap();
assert_eq!(
vec![8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48],
logstore
.files
.read()
.await
.keys()
.copied()
.collect::<Vec<_>>()
);
gc_inner(logstore.files.clone(), 28).await.unwrap();
assert_eq!(
vec![28, 32, 36, 40, 44, 48],
logstore
.files
.read()
.await
.keys()
.copied()
.collect::<Vec<_>>()
);
gc_inner(logstore.files.clone(), 50).await.unwrap();
assert_eq!(
vec![48],
logstore
.files
.read()
.await
.keys()
.copied()
.collect::<Vec<_>>()
);
}
#[tokio::test]
async fn test_gc_task() {
common_telemetry::logging::init_default_ut_logging();
let dir = TempDir::new("greptimedb-log-compact").unwrap();
let config = LogConfig {
append_buffer_size: 128,
max_log_file_size: 4096,
log_file_dir: dir.path().to_str().unwrap().to_string(),
gc_interval: Duration::from_millis(100),
};
let logstore = LocalFileLogStore::open(&config).await.unwrap();
logstore.start().await.unwrap();
for id in 0..50 {
logstore
.append(EntryImpl::new(
generate_data(990),
id,
LocalNamespace::new(42),
))
.await
.unwrap();
}
logstore
.obsolete(LocalNamespace::new(42), 30)
.await
.unwrap();
tokio::time::sleep(Duration::from_millis(150)).await;
let file_ids = logstore
.files
.read()
.await
.keys()
.cloned()
.collect::<Vec<_>>();
assert_eq!(vec![28, 32, 36, 40, 44, 48], file_ids);
let mut files = vec![];
let mut readir = tokio::fs::read_dir(dir.path()).await.unwrap();
while let Some(r) = readir.next_entry().await.transpose() {
let entry = r.unwrap();
files.push(entry.file_name().to_str().unwrap().to_string());
}
assert_eq!(
vec![
"00000000000000000028.log".to_string(),
"00000000000000000048.log".to_string(),
"00000000000000000040.log".to_string(),
"00000000000000000044.log".to_string(),
"00000000000000000036.log".to_string(),
"00000000000000000032.log".to_string()
]
.into_iter()
.collect::<HashSet<String>>(),
files.into_iter().collect::<HashSet<String>>()
);
}
}

View File

@@ -14,7 +14,7 @@
use store_api::logstore::namespace::{Id, Namespace};
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LocalNamespace {
pub(crate) id: Id,
}

View File

@@ -33,6 +33,14 @@ impl LogStore for NoopLogStore {
type Entry = EntryImpl;
type AppendResponse = AppendResponseImpl;
async fn start(&self) -> Result<()> {
Ok(())
}
async fn stop(&self) -> Result<()> {
Ok(())
}
async fn append(&self, mut _e: Self::Entry) -> Result<Self::AppendResponse> {
Ok(AppendResponseImpl {
entry_id: 0,
@@ -72,4 +80,14 @@ impl LogStore for NoopLogStore {
fn namespace(&self, id: NamespaceId) -> Self::Namespace {
LocalNamespace::new(id)
}
async fn obsolete(
&self,
namespace: Self::Namespace,
id: Id,
) -> std::result::Result<(), Self::Error> {
let _ = namespace;
let _ = id;
Ok(())
}
}

View File

@@ -25,6 +25,7 @@ pub async fn create_tmp_local_file_log_store(dir: &str) -> (LocalFileLogStore, T
append_buffer_size: 128,
max_log_file_size: 128,
log_file_dir: dir.path().to_str().unwrap().to_string(),
..Default::default()
};
(LocalFileLogStore::open(&cfg).await.unwrap(), dir)

View File

@@ -27,10 +27,11 @@ use store::Client as StoreClient;
pub use self::heartbeat::{HeartbeatSender, HeartbeatStream};
use crate::error;
use crate::error::Result;
use crate::rpc::router::DeleteRequest;
use crate::rpc::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse, CreateRequest,
DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
RouteRequest, RouteResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse, RouteRequest, RouteResponse,
};
pub type Id = (u64, u64);
@@ -206,6 +207,13 @@ impl MetaClient {
self.router_client()?.route(req.into()).await?.try_into()
}
/// Can be called repeatedly, the first call will delete and return the
/// table of routing information, the nth call can still return the
/// deleted route information.
pub async fn delete_route(&self, req: DeleteRequest) -> Result<RouteResponse> {
self.router_client()?.delete(req.into()).await?.try_into()
}
/// Range gets the keys in the range from the key-value store.
pub async fn range(&self, req: RangeRequest) -> Result<RangeResponse> {
self.store_client()?.range(req.into()).await?.try_into()
@@ -241,6 +249,14 @@ impl MetaClient {
.try_into()
}
/// MoveValue atomically renames the key to the given updated key.
pub async fn move_value(&self, req: MoveValueRequest) -> Result<MoveValueResponse> {
self.store_client()?
.move_value(req.into())
.await?
.try_into()
}
#[inline]
pub fn heartbeat_client(&self) -> Result<HeartbeatClient> {
self.heartbeat.clone().context(error::NotStartedSnafu {
@@ -286,6 +302,52 @@ mod tests {
use crate::mocks;
use crate::rpc::{Partition, TableName};
const TEST_KEY_PREFIX: &str = "__unit_test__meta__";
struct TestClient {
ns: String,
client: MetaClient,
}
impl TestClient {
async fn new(ns: impl Into<String>) -> Self {
// can also test with etcd: mocks::mock_client_with_etcdstore("127.0.0.1:2379").await;
let client = mocks::mock_client_with_memstore().await;
Self {
ns: ns.into(),
client,
}
}
fn key(&self, name: &str) -> Vec<u8> {
format!("{}-{}-{}", TEST_KEY_PREFIX, self.ns, name).into_bytes()
}
async fn gen_data(&self) {
for i in 0..10 {
let req = PutRequest::new()
.with_key(self.key(&format!("key-{}", i)))
.with_value(format!("{}-{}", "value", i).into_bytes())
.with_prev_kv();
let res = self.client.put(req).await;
assert!(res.is_ok());
}
}
async fn clear_data(&self) {
let req =
DeleteRangeRequest::new().with_prefix(format!("{}-{}", TEST_KEY_PREFIX, self.ns));
let res = self.client.delete_range(req).await;
assert!(res.is_ok());
}
}
async fn new_client(ns: impl Into<String>) -> TestClient {
let client = TestClient::new(ns).await;
client.clear_data().await;
client
}
#[tokio::test]
async fn test_meta_client_builder() {
let urls = &["127.0.0.1:3001", "127.0.0.1:3002"];
@@ -373,15 +435,15 @@ mod tests {
#[tokio::test]
async fn test_ask_leader() {
let client = mocks::mock_client_with_memstore().await;
let res = client.ask_leader().await;
let tc = new_client("test_ask_leader").await;
let res = tc.client.ask_leader().await;
assert!(res.is_ok());
}
#[tokio::test]
async fn test_heartbeat() {
let client = mocks::mock_client_with_memstore().await;
let (sender, mut receiver) = client.heartbeat().await.unwrap();
let tc = new_client("test_heartbeat").await;
let (sender, mut receiver) = tc.client.heartbeat().await.unwrap();
// send heartbeats
tokio::spawn(async move {
for _ in 0..5 {
@@ -449,66 +511,58 @@ mod tests {
let res = client.create_route(req).await.unwrap();
assert_eq!(1, res.table_routes.len());
let req = RouteRequest::new().add_table_name(table_name);
let req = RouteRequest::new().add_table_name(table_name.clone());
let res = client.route(req).await.unwrap();
// empty table_routes since no TableGlobalValue is stored by datanode
assert!(res.table_routes.is_empty());
}
async fn gen_data(client: &MetaClient) {
for i in 0..10 {
let req = PutRequest::new()
.with_key(format!("{}-{}", "key", i).into_bytes())
.with_value(format!("{}-{}", "value", i).into_bytes())
.with_prev_kv();
let res = client.put(req).await;
assert!(res.is_ok());
}
let req = DeleteRequest::new(table_name.clone());
let res = client.delete_route(req).await;
// empty table_routes since no TableGlobalValue is stored by datanode
assert!(res.is_err());
}
#[tokio::test]
async fn test_range_get() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_range_get").await;
tc.gen_data().await;
gen_data(&client).await;
let req = RangeRequest::new().with_key(b"key-0".to_vec());
let res = client.range(req).await;
let key = tc.key("key-0");
let req = RangeRequest::new().with_key(key.as_slice());
let res = tc.client.range(req).await;
let mut kvs = res.unwrap().take_kvs();
assert_eq!(1, kvs.len());
let mut kv = kvs.pop().unwrap();
assert_eq!(b"key-0".to_vec(), kv.take_key());
assert_eq!(key, kv.take_key());
assert_eq!(b"value-0".to_vec(), kv.take_value());
}
#[tokio::test]
async fn test_range_get_prefix() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_range_get_prefix").await;
tc.gen_data().await;
gen_data(&client).await;
let req = RangeRequest::new().with_prefix(b"key-".to_vec());
let res = client.range(req).await;
let req = RangeRequest::new().with_prefix(tc.key("key-"));
let res = tc.client.range(req).await;
let kvs = res.unwrap().take_kvs();
assert_eq!(10, kvs.len());
for (i, mut kv) in kvs.into_iter().enumerate() {
assert_eq!(format!("{}-{}", "key", i).into_bytes(), kv.take_key());
assert_eq!(tc.key(&format!("key-{}", i)), kv.take_key());
assert_eq!(format!("{}-{}", "value", i).into_bytes(), kv.take_value());
}
}
#[tokio::test]
async fn test_range() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_range").await;
tc.gen_data().await;
gen_data(&client).await;
let req = RangeRequest::new().with_range(b"key-5".to_vec(), b"key-8".to_vec());
let res = client.range(req).await;
let req = RangeRequest::new().with_range(tc.key("key-5"), tc.key("key-8"));
let res = tc.client.range(req).await;
let kvs = res.unwrap().take_kvs();
assert_eq!(3, kvs.len());
for (i, mut kv) in kvs.into_iter().enumerate() {
assert_eq!(format!("{}-{}", "key", i + 5).into_bytes(), kv.take_key());
assert_eq!(tc.key(&format!("key-{}", i + 5)), kv.take_key());
assert_eq!(
format!("{}-{}", "value", i + 5).into_bytes(),
kv.take_value()
@@ -518,121 +572,129 @@ mod tests {
#[tokio::test]
async fn test_range_keys_only() {
let client = mocks::mock_client_with_memstore().await;
gen_data(&client).await;
let tc = new_client("test_range_keys_only").await;
tc.gen_data().await;
let req = RangeRequest::new()
.with_range(b"key-5".to_vec(), b"key-8".to_vec())
.with_range(tc.key("key-5"), tc.key("key-8"))
.with_keys_only();
let res = client.range(req).await;
let res = tc.client.range(req).await;
let kvs = res.unwrap().take_kvs();
assert_eq!(3, kvs.len());
for (i, mut kv) in kvs.into_iter().enumerate() {
assert_eq!(format!("{}-{}", "key", i + 5).into_bytes(), kv.take_key());
assert_eq!(tc.key(&format!("key-{}", i + 5)), kv.take_key());
assert!(kv.take_value().is_empty());
}
}
#[tokio::test]
async fn test_put() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_put").await;
let req = PutRequest::new()
.with_key(b"key".to_vec())
.with_key(tc.key("key"))
.with_value(b"value".to_vec());
let res = client.put(req).await;
let res = tc.client.put(req).await;
assert!(res.unwrap().take_prev_kv().is_none());
}
#[tokio::test]
async fn test_put_with_prev_kv() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_put_with_prev_kv").await;
let key = tc.key("key");
let req = PutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_value(b"value".to_vec())
.with_prev_kv();
let res = client.put(req).await;
let res = tc.client.put(req).await;
assert!(res.unwrap().take_prev_kv().is_none());
let req = PutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_value(b"value1".to_vec())
.with_prev_kv();
let res = client.put(req).await;
let res = tc.client.put(req).await;
let mut kv = res.unwrap().take_prev_kv().unwrap();
assert_eq!(b"key".to_vec(), kv.take_key());
assert_eq!(key, kv.take_key());
assert_eq!(b"value".to_vec(), kv.take_value());
}
#[tokio::test]
async fn test_batch_put() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_batch_put").await;
let req = BatchPutRequest::new()
.add_kv(b"key".to_vec(), b"value".to_vec())
.add_kv(b"key2".to_vec(), b"value2".to_vec());
let res = client.batch_put(req).await;
.add_kv(tc.key("key"), b"value".to_vec())
.add_kv(tc.key("key2"), b"value2".to_vec());
let res = tc.client.batch_put(req).await;
assert_eq!(0, res.unwrap().take_prev_kvs().len());
let req = RangeRequest::new().with_range(b"key".to_vec(), b"key3".to_vec());
let res = client.range(req).await;
let req = RangeRequest::new().with_range(tc.key("key"), tc.key("key3"));
let res = tc.client.range(req).await;
let kvs = res.unwrap().take_kvs();
assert_eq!(2, kvs.len());
}
#[tokio::test]
async fn test_batch_put_with_prev_kv() {
let client = mocks::mock_client_with_memstore().await;
let req = BatchPutRequest::new().add_kv(b"key".to_vec(), b"value".to_vec());
let res = client.batch_put(req).await;
let tc = new_client("test_batch_put_with_prev_kv").await;
let key = tc.key("key");
let key2 = tc.key("key2");
let req = BatchPutRequest::new().add_kv(key.as_slice(), b"value".to_vec());
let res = tc.client.batch_put(req).await;
assert_eq!(0, res.unwrap().take_prev_kvs().len());
let req = BatchPutRequest::new()
.add_kv(b"key".to_vec(), b"value-".to_vec())
.add_kv(b"key2".to_vec(), b"value2-".to_vec())
.add_kv(key.as_slice(), b"value-".to_vec())
.add_kv(key2.as_slice(), b"value2-".to_vec())
.with_prev_kv();
let res = client.batch_put(req).await;
let res = tc.client.batch_put(req).await;
let mut kvs = res.unwrap().take_prev_kvs();
assert_eq!(1, kvs.len());
let mut kv = kvs.pop().unwrap();
assert_eq!(b"key".to_vec(), kv.take_key());
assert_eq!(key, kv.take_key());
assert_eq!(b"value".to_vec(), kv.take_value());
}
#[tokio::test]
async fn test_compare_and_put() {
let client = mocks::mock_client_with_memstore().await;
let tc = new_client("test_compare_and_put").await;
let key = tc.key("key");
let req = CompareAndPutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_expect(b"expect".to_vec())
.with_value(b"value".to_vec());
let res = client.compare_and_put(req).await;
let res = tc.client.compare_and_put(req).await;
assert!(!res.unwrap().is_success());
// create if absent
let req = CompareAndPutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_value(b"value".to_vec());
let res = client.compare_and_put(req).await;
let res = tc.client.compare_and_put(req).await;
let mut res = res.unwrap();
assert!(res.is_success());
assert!(res.take_prev_kv().is_none());
// compare and put fail
let req = CompareAndPutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_expect(b"not_eq".to_vec())
.with_value(b"value2".to_vec());
let res = client.compare_and_put(req).await;
let res = tc.client.compare_and_put(req).await;
let mut res = res.unwrap();
assert!(!res.is_success());
assert_eq!(b"value".to_vec(), res.take_prev_kv().unwrap().take_value());
// compare and put success
let req = CompareAndPutRequest::new()
.with_key(b"key".to_vec())
.with_key(key.as_slice())
.with_expect(b"value".to_vec())
.with_value(b"value2".to_vec());
let res = client.compare_and_put(req).await;
let res = tc.client.compare_and_put(req).await;
let mut res = res.unwrap();
assert!(res.is_success());
assert_eq!(b"value".to_vec(), res.take_prev_kv().unwrap().take_value());
@@ -640,14 +702,13 @@ mod tests {
#[tokio::test]
async fn test_delete_with_key() {
let client = mocks::mock_client_with_memstore().await;
gen_data(&client).await;
let tc = new_client("test_delete_with_key").await;
tc.gen_data().await;
let req = DeleteRangeRequest::new()
.with_key(b"key-0".to_vec())
.with_key(tc.key("key-0"))
.with_prev_kv();
let res = client.delete_range(req).await;
let res = tc.client.delete_range(req).await;
let mut res = res.unwrap();
assert_eq!(1, res.deleted());
let mut kvs = res.take_prev_kvs();
@@ -658,14 +719,13 @@ mod tests {
#[tokio::test]
async fn test_delete_with_prefix() {
let client = mocks::mock_client_with_memstore().await;
gen_data(&client).await;
let tc = new_client("test_delete_with_prefix").await;
tc.gen_data().await;
let req = DeleteRangeRequest::new()
.with_prefix(b"key-".to_vec())
.with_prefix(tc.key("key-"))
.with_prev_kv();
let res = client.delete_range(req).await;
let res = tc.client.delete_range(req).await;
let mut res = res.unwrap();
assert_eq!(10, res.deleted());
let kvs = res.take_prev_kvs();
@@ -677,14 +737,13 @@ mod tests {
#[tokio::test]
async fn test_delete_with_range() {
let client = mocks::mock_client_with_memstore().await;
gen_data(&client).await;
let tc = new_client("test_delete_with_range").await;
tc.gen_data().await;
let req = DeleteRangeRequest::new()
.with_range(b"key-2".to_vec(), b"key-7".to_vec())
.with_range(tc.key("key-2"), tc.key("key-7"))
.with_prev_kv();
let res = client.delete_range(req).await;
let res = tc.client.delete_range(req).await;
let mut res = res.unwrap();
assert_eq!(5, res.deleted());
let kvs = res.take_prev_kvs();
@@ -696,4 +755,38 @@ mod tests {
);
}
}
#[tokio::test]
async fn test_move_value() {
let tc = new_client("test_move_value").await;
let from_key = tc.key("from_key");
let to_key = tc.key("to_key");
let req = MoveValueRequest::new(from_key.as_slice(), to_key.as_slice());
let res = tc.client.move_value(req).await;
assert!(res.unwrap().take_kv().is_none());
let req = PutRequest::new()
.with_key(to_key.as_slice())
.with_value(b"value".to_vec());
let _ = tc.client.put(req).await;
let req = MoveValueRequest::new(from_key.as_slice(), to_key.as_slice());
let res = tc.client.move_value(req).await;
let mut kv = res.unwrap().take_kv().unwrap();
assert_eq!(to_key.clone(), kv.take_key());
assert_eq!(b"value".to_vec(), kv.take_value());
let req = PutRequest::new()
.with_key(from_key.as_slice())
.with_value(b"value2".to_vec());
let _ = tc.client.put(req).await;
let req = MoveValueRequest::new(from_key.as_slice(), to_key.as_slice());
let res = tc.client.move_value(req).await;
let mut kv = res.unwrap().take_kv().unwrap();
assert_eq!(from_key, kv.take_key());
assert_eq!(b"value2".to_vec(), kv.take_value());
}
}

View File

@@ -16,7 +16,7 @@ use std::collections::HashSet;
use std::sync::Arc;
use api::v1::meta::router_client::RouterClient;
use api::v1::meta::{CreateRequest, RouteRequest, RouteResponse};
use api::v1::meta::{CreateRequest, DeleteRequest, RouteRequest, RouteResponse};
use common_grpc::channel_manager::ChannelManager;
use snafu::{ensure, OptionExt, ResultExt};
use tokio::sync::RwLock;
@@ -65,6 +65,11 @@ impl Client {
let inner = self.inner.read().await;
inner.route(req).await
}
pub async fn delete(&self, req: DeleteRequest) -> Result<RouteResponse> {
let inner = self.inner.read().await;
inner.delete(req).await
}
}
#[derive(Debug)]
@@ -98,6 +103,14 @@ impl Inner {
Ok(())
}
async fn create(&self, mut req: CreateRequest) -> Result<RouteResponse> {
let mut client = self.random_client()?;
req.set_header(self.id);
let res = client.create(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
}
async fn route(&self, mut req: RouteRequest) -> Result<RouteResponse> {
let mut client = self.random_client()?;
req.set_header(self.id);
@@ -106,10 +119,10 @@ impl Inner {
Ok(res.into_inner())
}
async fn create(&self, mut req: CreateRequest) -> Result<RouteResponse> {
async fn delete(&self, mut req: DeleteRequest) -> Result<RouteResponse> {
let mut client = self.random_client()?;
req.set_header(self.id);
let res = client.create(req).await.context(error::TonicStatusSnafu)?;
let res = client.delete(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
}

View File

@@ -18,7 +18,8 @@ use std::sync::Arc;
use api::v1::meta::store_client::StoreClient;
use api::v1::meta::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse,
};
use common_grpc::channel_manager::ChannelManager;
use snafu::{ensure, OptionExt, ResultExt};
@@ -86,6 +87,11 @@ impl Client {
let inner = self.inner.read().await;
inner.delete_range(req).await
}
pub async fn move_value(&self, req: MoveValueRequest) -> Result<MoveValueResponse> {
let inner = self.inner.read().await;
inner.move_value(req).await
}
}
#[derive(Debug)]
@@ -171,6 +177,17 @@ impl Inner {
Ok(res.into_inner())
}
async fn move_value(&self, mut req: MoveValueRequest) -> Result<MoveValueResponse> {
let mut client = self.random_client()?;
req.set_header(self.id);
let res = client
.move_value(req)
.await
.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
}
fn random_client(&self) -> Result<StoreClient<Channel>> {
let len = self.peers.len();
let peer = lb::random_get(len, |i| Some(&self.peers[i])).context(

View File

@@ -28,7 +28,8 @@ pub use router::{
use serde::{Deserialize, Serialize};
pub use store::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse,
};
#[derive(Debug, Clone)]

View File

@@ -15,8 +15,9 @@
use std::collections::HashMap;
use api::v1::meta::{
CreateRequest as PbCreateRequest, Partition as PbPartition, Region as PbRegion,
RouteRequest as PbRouteRequest, RouteResponse as PbRouteResponse, Table as PbTable,
CreateRequest as PbCreateRequest, DeleteRequest as PbDeleteRequest, Partition as PbPartition,
Region as PbRegion, RouteRequest as PbRouteRequest, RouteResponse as PbRouteResponse,
Table as PbTable,
};
use serde::{Deserialize, Serialize, Serializer};
use snafu::OptionExt;
@@ -25,6 +26,38 @@ use crate::error;
use crate::error::Result;
use crate::rpc::{util, Peer, TableName};
#[derive(Debug, Clone)]
pub struct CreateRequest {
pub table_name: TableName,
pub partitions: Vec<Partition>,
}
impl From<CreateRequest> for PbCreateRequest {
fn from(mut req: CreateRequest) -> Self {
Self {
header: None,
table_name: Some(req.table_name.into()),
partitions: req.partitions.drain(..).map(Into::into).collect(),
}
}
}
impl CreateRequest {
#[inline]
pub fn new(table_name: TableName) -> Self {
Self {
table_name,
partitions: vec![],
}
}
#[inline]
pub fn add_partition(mut self, partition: Partition) -> Self {
self.partitions.push(partition);
self
}
}
#[derive(Debug, Clone, Default)]
pub struct RouteRequest {
pub table_names: Vec<TableName>,
@@ -55,34 +88,23 @@ impl RouteRequest {
}
#[derive(Debug, Clone)]
pub struct CreateRequest {
pub struct DeleteRequest {
pub table_name: TableName,
pub partitions: Vec<Partition>,
}
impl From<CreateRequest> for PbCreateRequest {
fn from(mut req: CreateRequest) -> Self {
impl From<DeleteRequest> for PbDeleteRequest {
fn from(req: DeleteRequest) -> Self {
Self {
header: None,
table_name: Some(req.table_name.into()),
partitions: req.partitions.drain(..).map(Into::into).collect(),
}
}
}
impl CreateRequest {
impl DeleteRequest {
#[inline]
pub fn new(table_name: TableName) -> Self {
Self {
table_name,
partitions: vec![],
}
}
#[inline]
pub fn add_partition(mut self, partition: Partition) -> Self {
self.partitions.push(partition);
self
Self { table_name }
}
}
@@ -275,33 +297,14 @@ impl From<PbPartition> for Partition {
#[cfg(test)]
mod tests {
use api::v1::meta::{
Partition as PbPartition, Peer as PbPeer, Region as PbRegion, RegionRoute as PbRegionRoute,
RouteRequest as PbRouteRequest, RouteResponse as PbRouteResponse, Table as PbTable,
TableName as PbTableName, TableRoute as PbTableRoute,
DeleteRequest as PbDeleteRequest, Partition as PbPartition, Peer as PbPeer,
Region as PbRegion, RegionRoute as PbRegionRoute, RouteRequest as PbRouteRequest,
RouteResponse as PbRouteResponse, Table as PbTable, TableName as PbTableName,
TableRoute as PbTableRoute,
};
use super::*;
#[test]
fn test_route_request_trans() {
let req = RouteRequest {
table_names: vec![
TableName::new("c1", "s1", "t1"),
TableName::new("c2", "s2", "t2"),
],
};
let into_req: PbRouteRequest = req.into();
assert!(into_req.header.is_none());
assert_eq!("c1", into_req.table_names.get(0).unwrap().catalog_name);
assert_eq!("s1", into_req.table_names.get(0).unwrap().schema_name);
assert_eq!("t1", into_req.table_names.get(0).unwrap().table_name);
assert_eq!("c2", into_req.table_names.get(1).unwrap().catalog_name);
assert_eq!("s2", into_req.table_names.get(1).unwrap().schema_name);
assert_eq!("t2", into_req.table_names.get(1).unwrap().table_name);
}
#[test]
fn test_create_request_trans() {
let req = CreateRequest {
@@ -343,6 +346,40 @@ mod tests {
);
}
#[test]
fn test_route_request_trans() {
let req = RouteRequest {
table_names: vec![
TableName::new("c1", "s1", "t1"),
TableName::new("c2", "s2", "t2"),
],
};
let into_req: PbRouteRequest = req.into();
assert!(into_req.header.is_none());
assert_eq!("c1", into_req.table_names.get(0).unwrap().catalog_name);
assert_eq!("s1", into_req.table_names.get(0).unwrap().schema_name);
assert_eq!("t1", into_req.table_names.get(0).unwrap().table_name);
assert_eq!("c2", into_req.table_names.get(1).unwrap().catalog_name);
assert_eq!("s2", into_req.table_names.get(1).unwrap().schema_name);
assert_eq!("t2", into_req.table_names.get(1).unwrap().table_name);
}
#[test]
fn test_delete_request_trans() {
let req = DeleteRequest {
table_name: TableName::new("c1", "s1", "t1"),
};
let into_req: PbDeleteRequest = req.into();
assert!(into_req.header.is_none());
assert_eq!("c1", into_req.table_name.as_ref().unwrap().catalog_name);
assert_eq!("s1", into_req.table_name.as_ref().unwrap().schema_name);
assert_eq!("t1", into_req.table_name.as_ref().unwrap().table_name);
}
#[test]
fn test_route_response_trans() {
let res = PbRouteResponse {

View File

@@ -17,6 +17,7 @@ use api::v1::meta::{
CompareAndPutRequest as PbCompareAndPutRequest,
CompareAndPutResponse as PbCompareAndPutResponse, DeleteRangeRequest as PbDeleteRangeRequest,
DeleteRangeResponse as PbDeleteRangeResponse, KeyValue as PbKeyValue,
MoveValueRequest as PbMoveValueRequest, MoveValueResponse as PbMoveValueResponse,
PutRequest as PbPutRequest, PutResponse as PbPutResponse, RangeRequest as PbRangeRequest,
RangeResponse as PbRangeResponse,
};
@@ -511,6 +512,7 @@ impl DeleteRangeResponse {
self.0.header.take().map(ResponseHeader::new)
}
#[inline]
pub fn deleted(&self) -> i64 {
self.0.deleted
}
@@ -521,6 +523,65 @@ impl DeleteRangeResponse {
}
}
#[derive(Debug, Clone, Default)]
pub struct MoveValueRequest {
/// If from_key dose not exist, return the value of to_key (if it exists).
/// If from_key exists, move the value of from_key to to_key (i.e. rename),
/// and return the value.
pub from_key: Vec<u8>,
pub to_key: Vec<u8>,
}
impl From<MoveValueRequest> for PbMoveValueRequest {
fn from(req: MoveValueRequest) -> Self {
Self {
header: None,
from_key: req.from_key,
to_key: req.to_key,
}
}
}
impl MoveValueRequest {
#[inline]
pub fn new(from_key: impl Into<Vec<u8>>, to_key: impl Into<Vec<u8>>) -> Self {
Self {
from_key: from_key.into(),
to_key: to_key.into(),
}
}
}
#[derive(Debug, Clone)]
pub struct MoveValueResponse(PbMoveValueResponse);
impl TryFrom<PbMoveValueResponse> for MoveValueResponse {
type Error = error::Error;
fn try_from(pb: PbMoveValueResponse) -> Result<Self> {
util::check_response_header(pb.header.as_ref())?;
Ok(Self::new(pb))
}
}
impl MoveValueResponse {
#[inline]
pub fn new(res: PbMoveValueResponse) -> Self {
Self(res)
}
#[inline]
pub fn take_header(&mut self) -> Option<ResponseHeader> {
self.0.header.take().map(ResponseHeader::new)
}
#[inline]
pub fn take_kv(&mut self) -> Option<KeyValue> {
self.0.kv.take().map(KeyValue::new)
}
}
#[cfg(test)]
mod tests {
use api::v1::meta::{
@@ -528,8 +589,10 @@ mod tests {
CompareAndPutRequest as PbCompareAndPutRequest,
CompareAndPutResponse as PbCompareAndPutResponse,
DeleteRangeRequest as PbDeleteRangeRequest, DeleteRangeResponse as PbDeleteRangeResponse,
KeyValue as PbKeyValue, PutRequest as PbPutRequest, PutResponse as PbPutResponse,
RangeRequest as PbRangeRequest, RangeResponse as PbRangeResponse,
KeyValue as PbKeyValue, MoveValueRequest as PbMoveValueRequest,
MoveValueResponse as PbMoveValueResponse, PutRequest as PbPutRequest,
PutResponse as PbPutResponse, RangeRequest as PbRangeRequest,
RangeResponse as PbRangeResponse,
};
use super::*;
@@ -775,4 +838,35 @@ mod tests {
assert_eq!(b"v2".to_vec(), kv1.value().to_vec());
assert_eq!(b"v2".to_vec(), kv1.take_value());
}
#[test]
fn test_move_value_request_trans() {
let (from_key, to_key) = (b"test_key1".to_vec(), b"test_key2".to_vec());
let req = MoveValueRequest::new(from_key.clone(), to_key.clone());
let into_req: PbMoveValueRequest = req.into();
assert!(into_req.header.is_none());
assert_eq!(from_key, into_req.from_key);
assert_eq!(to_key, into_req.to_key);
}
#[test]
fn test_move_value_response_trans() {
let pb_res = PbMoveValueResponse {
header: None,
kv: Some(PbKeyValue {
key: b"k1".to_vec(),
value: b"v1".to_vec(),
}),
};
let mut res = MoveValueResponse::new(pb_res);
assert!(res.take_header().is_none());
let mut kv = res.take_kv().unwrap();
assert_eq!(b"k1".to_vec(), kv.key().to_vec());
assert_eq!(b"k1".to_vec(), kv.take_key());
assert_eq!(b"v1".to_vec(), kv.value().to_vec());
assert_eq!(b"v1".to_vec(), kv.take_value());
}
}

View File

@@ -123,6 +123,15 @@ pub enum Error {
#[snafu(display("MetaSrv has no leader at this moment"))]
NoLeader { backtrace: Backtrace },
#[snafu(display("Table {} not found", name))]
TableNotFound { name: String, backtrace: Backtrace },
#[snafu(display(
"Failed to move the value of {} because other clients caused a race condition",
key
))]
MoveValue { key: String, backtrace: Backtrace },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -162,7 +171,9 @@ impl ErrorExt for Error {
| Error::UnexceptedSequenceValue { .. }
| Error::TableRouteNotFound { .. }
| Error::NextSequence { .. }
| Error::MoveValue { .. }
| Error::InvalidTxnResult { .. } => StatusCode::Unexpected,
Error::TableNotFound { .. } => StatusCode::TableNotFound,
Error::InvalidCatalogValue { source, .. } => source.status_code(),
}
}

View File

@@ -24,6 +24,7 @@ use snafu::{ensure, OptionExt, ResultExt};
use crate::error;
use crate::error::Result;
pub(crate) const REMOVED_PREFIX: &str = "__removed";
pub(crate) const DN_LEASE_PREFIX: &str = "__meta_dnlease";
pub(crate) const SEQ_PREFIX: &str = "__meta_seq";
pub(crate) const TABLE_ROUTE_PREFIX: &str = "__meta_table_route";
@@ -149,6 +150,7 @@ impl<'a> TableRouteKey<'a> {
}
}
#[inline]
pub fn prefix(&self) -> String {
format!(
"{}-{}-{}-{}",
@@ -156,9 +158,15 @@ impl<'a> TableRouteKey<'a> {
)
}
#[inline]
pub fn key(&self) -> String {
format!("{}-{}", self.prefix(), self.table_id)
}
#[inline]
pub fn removed_key(&self) -> String {
format!("{}-{}", REMOVED_PREFIX, self.key())
}
}
#[cfg(test)]

View File

@@ -205,6 +205,13 @@ mod tests {
) -> Result<api::v1::meta::DeleteRangeResponse> {
unreachable!()
}
async fn move_value(
&self,
_: api::v1::meta::MoveValueRequest,
) -> Result<api::v1::meta::MoveValueResponse> {
unreachable!()
}
}
let kv_store = Arc::new(Noop {});

View File

@@ -13,8 +13,9 @@
// limitations under the License.
use api::v1::meta::{
router_server, CreateRequest, Error, PeerDict, PutRequest, RangeRequest, Region, RegionRoute,
ResponseHeader, RouteRequest, RouteResponse, Table, TableRoute, TableRouteValue,
router_server, CreateRequest, DeleteRequest, Error, MoveValueRequest, Peer, PeerDict,
PutRequest, RangeRequest, Region, RegionRoute, ResponseHeader, RouteRequest, RouteResponse,
Table, TableRoute, TableRouteValue,
};
use catalog::helper::{TableGlobalKey, TableGlobalValue};
use common_telemetry::warn;
@@ -31,14 +32,6 @@ use crate::service::GrpcResult;
#[async_trait::async_trait]
impl router_server::Router for MetaSrv {
async fn route(&self, req: Request<RouteRequest>) -> GrpcResult<RouteResponse> {
let req = req.into_inner();
let ctx = self.new_ctx();
let res = handle_route(req, ctx).await?;
Ok(Response::new(res))
}
async fn create(&self, req: Request<CreateRequest>) -> GrpcResult<RouteResponse> {
let req = req.into_inner();
let ctx = self.new_ctx();
@@ -48,56 +41,22 @@ impl router_server::Router for MetaSrv {
Ok(Response::new(res))
}
}
async fn handle_route(req: RouteRequest, ctx: Context) -> Result<RouteResponse> {
let RouteRequest {
header,
table_names,
} = req;
let cluster_id = header.as_ref().map_or(0, |h| h.cluster_id);
let table_global_keys = table_names.into_iter().map(|t| TableGlobalKey {
catalog_name: t.catalog_name,
schema_name: t.schema_name,
table_name: t.table_name,
});
let tables = fetch_tables(&ctx.kv_store, table_global_keys).await?;
async fn route(&self, req: Request<RouteRequest>) -> GrpcResult<RouteResponse> {
let req = req.into_inner();
let ctx = self.new_ctx();
let res = handle_route(req, ctx).await?;
let mut peer_dict = PeerDict::default();
let mut table_routes = vec![];
for (tg, tr) in tables {
let TableRouteValue {
peers,
mut table_route,
} = tr;
if let Some(table_route) = &mut table_route {
for rr in &mut table_route.region_routes {
if let Some(peer) = peers.get(rr.leader_peer_index as usize) {
rr.leader_peer_index = peer_dict.get_or_insert(peer.clone()) as u64;
}
for index in &mut rr.follower_peer_indexes {
if let Some(peer) = peers.get(*index as usize) {
*index = peer_dict.get_or_insert(peer.clone()) as u64;
}
}
}
if let Some(table) = &mut table_route.table {
table.table_schema = tg.as_bytes().context(error::InvalidCatalogValueSnafu)?;
}
}
if let Some(table_route) = table_route {
table_routes.push(table_route)
}
Ok(Response::new(res))
}
let peers = peer_dict.into_peers();
let header = Some(ResponseHeader::success(cluster_id));
Ok(RouteResponse {
header,
peers,
table_routes,
})
async fn delete(&self, req: Request<DeleteRequest>) -> GrpcResult<RouteResponse> {
let req = req.into_inner();
let ctx = self.new_ctx();
let res = handle_delete(req, ctx).await?;
Ok(Response::new(res))
}
}
async fn handle_create(
@@ -169,6 +128,90 @@ async fn handle_create(
})
}
async fn handle_route(req: RouteRequest, ctx: Context) -> Result<RouteResponse> {
let RouteRequest {
header,
table_names,
} = req;
let cluster_id = header.as_ref().map_or(0, |h| h.cluster_id);
let table_global_keys = table_names.into_iter().map(|t| TableGlobalKey {
catalog_name: t.catalog_name,
schema_name: t.schema_name,
table_name: t.table_name,
});
let tables = fetch_tables(&ctx.kv_store, table_global_keys).await?;
let (peers, table_routes) = fill_table_routes(tables)?;
let header = Some(ResponseHeader::success(cluster_id));
Ok(RouteResponse {
header,
peers,
table_routes,
})
}
async fn handle_delete(req: DeleteRequest, ctx: Context) -> Result<RouteResponse> {
let DeleteRequest { header, table_name } = req;
let cluster_id = header.as_ref().map_or(0, |h| h.cluster_id);
let tgk = table_name
.map(|t| TableGlobalKey {
catalog_name: t.catalog_name,
schema_name: t.schema_name,
table_name: t.table_name,
})
.context(error::EmptyTableNameSnafu)?;
let tgv = get_table_global_value(&ctx.kv_store, &tgk)
.await?
.with_context(|| error::TableNotFoundSnafu {
name: format!("{}", tgk),
})?;
let trk = TableRouteKey::with_table_global_key(tgv.table_id() as u64, &tgk);
let (_, trv) = remove_table_route_value(&ctx.kv_store, &trk).await?;
let (peers, table_routes) = fill_table_routes(vec![(tgv, trv)])?;
let header = Some(ResponseHeader::success(cluster_id));
Ok(RouteResponse {
header,
peers,
table_routes,
})
}
fn fill_table_routes(
tables: Vec<(TableGlobalValue, TableRouteValue)>,
) -> Result<(Vec<Peer>, Vec<TableRoute>)> {
let mut peer_dict = PeerDict::default();
let mut table_routes = vec![];
for (tgv, trv) in tables {
let TableRouteValue {
peers,
mut table_route,
} = trv;
if let Some(table_route) = &mut table_route {
for rr in &mut table_route.region_routes {
if let Some(peer) = peers.get(rr.leader_peer_index as usize) {
rr.leader_peer_index = peer_dict.get_or_insert(peer.clone()) as u64;
}
for index in &mut rr.follower_peer_indexes {
if let Some(peer) = peers.get(*index as usize) {
*index = peer_dict.get_or_insert(peer.clone()) as u64;
}
}
}
if let Some(table) = &mut table_route.table {
table.table_schema = tgv.as_bytes().context(error::InvalidCatalogValueSnafu)?;
}
}
if let Some(table_route) = table_route {
table_routes.push(table_route)
}
}
Ok((peer_dict.into_peers(), table_routes))
}
async fn fetch_tables(
kv_store: &KvStoreRef,
keys: impl Iterator<Item = TableGlobalKey>,
@@ -176,18 +219,18 @@ async fn fetch_tables(
let mut tables = vec![];
// Maybe we can optimize the for loop in the future, but in general,
// there won't be many keys, in fact, there is usually just one.
for tk in keys {
let tv = get_table_global_value(kv_store, &tk).await?;
if tv.is_none() {
warn!("Table global value is absent: {}", tk);
for tgk in keys {
let tgv = get_table_global_value(kv_store, &tgk).await?;
if tgv.is_none() {
warn!("Table global value is absent: {}", tgk);
continue;
}
let tv = tv.unwrap();
let tgv = tgv.unwrap();
let tr_key = TableRouteKey::with_table_global_key(tv.table_id() as u64, &tk);
let tr = get_table_route_value(kv_store, &tr_key).await?;
let trk = TableRouteKey::with_table_global_key(tgv.table_id() as u64, &tgk);
let trv = get_table_route_value(kv_store, &trk).await?;
tables.push((tv, tr));
tables.push((tgv, trv));
}
Ok(tables)
@@ -197,15 +240,32 @@ async fn get_table_route_value(
kv_store: &KvStoreRef,
key: &TableRouteKey<'_>,
) -> Result<TableRouteValue> {
let tr = get_from_store(kv_store, key.key().into_bytes())
let trv = get_from_store(kv_store, key.key().into_bytes())
.await?
.context(error::TableRouteNotFoundSnafu { key: key.key() })?;
let tr: TableRouteValue = tr
let trv: TableRouteValue = trv
.as_slice()
.try_into()
.context(error::DecodeTableRouteSnafu)?;
Ok(tr)
Ok(trv)
}
async fn remove_table_route_value(
kv_store: &KvStoreRef,
key: &TableRouteKey<'_>,
) -> Result<(Vec<u8>, TableRouteValue)> {
let from_key = key.key().into_bytes();
let to_key = key.removed_key().into_bytes();
let v = move_value(kv_store, from_key, to_key)
.await?
.context(error::TableRouteNotFoundSnafu { key: key.key() })?;
let trv: TableRouteValue =
v.1.as_slice()
.try_into()
.context(error::DecodeTableRouteSnafu)?;
Ok((v.0, trv))
}
async fn get_table_global_value(
@@ -223,6 +283,23 @@ async fn get_table_global_value(
}
}
async fn move_value(
kv_store: &KvStoreRef,
from_key: impl Into<Vec<u8>>,
to_key: impl Into<Vec<u8>>,
) -> Result<Option<(Vec<u8>, Vec<u8>)>> {
let from_key = from_key.into();
let to_key = to_key.into();
let move_req = MoveValueRequest {
from_key,
to_key,
..Default::default()
};
let res = kv_store.move_value(move_req).await?;
Ok(res.kv.map(|kv| (kv.key, kv.value)))
}
async fn put_into_store(
kv_store: &KvStoreRef,
key: impl Into<Vec<u8>>,

View File

@@ -18,7 +18,8 @@ pub mod memory;
use api::v1::meta::{
store_server, BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse,
};
use tonic::{Request, Response};
@@ -67,6 +68,13 @@ impl store_server::Store for MetaSrv {
Ok(Response::new(res))
}
async fn move_value(&self, req: Request<MoveValueRequest>) -> GrpcResult<MoveValueResponse> {
let req = req.into_inner();
let res = self.kv_store().move_value(req).await?;
Ok(Response::new(res))
}
}
#[cfg(test)]
@@ -130,4 +138,14 @@ mod tests {
assert!(res.is_ok());
}
#[tokio::test]
async fn test_move_value() {
let kv_store = Arc::new(MemStore::new());
let meta_srv = MetaSrv::new(MetaSrvOptions::default(), kv_store, None, None).await;
let req = MoveValueRequest::default();
let res = meta_srv.move_value(req.into_request()).await;
assert!(res.is_ok());
}
}

View File

@@ -16,10 +16,11 @@ use std::sync::Arc;
use api::v1::meta::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, KeyValue, PutRequest, PutResponse, RangeRequest,
RangeResponse, ResponseHeader,
DeleteRangeRequest, DeleteRangeResponse, KeyValue, MoveValueRequest, MoveValueResponse,
PutRequest, PutResponse, RangeRequest, RangeResponse, ResponseHeader,
};
use common_error::prelude::*;
use common_telemetry::warn;
use etcd_client::{
Client, Compare, CompareOp, DeleteOptions, GetOptions, PutOptions, Txn, TxnOp, TxnOpResponse,
};
@@ -63,11 +64,7 @@ impl KvStore for EtcdStore {
.await
.context(error::EtcdFailedSnafu)?;
let kvs = res
.kvs()
.iter()
.map(|kv| KvPair::new(kv).into())
.collect::<Vec<_>>();
let kvs = res.kvs().iter().map(KvPair::to_kv).collect::<Vec<_>>();
let header = Some(ResponseHeader::success(cluster_id));
Ok(RangeResponse {
@@ -92,7 +89,7 @@ impl KvStore for EtcdStore {
.await
.context(error::EtcdFailedSnafu)?;
let prev_kv = res.prev_key().map(|kv| KvPair::new(kv).into());
let prev_kv = res.prev_key().map(KvPair::to_kv);
let header = Some(ResponseHeader::success(cluster_id));
Ok(PutResponse { header, prev_kv })
@@ -123,7 +120,7 @@ impl KvStore for EtcdStore {
match op_res {
TxnOpResponse::Put(put_res) => {
if let Some(prev_kv) = put_res.prev_key() {
prev_kvs.push(KvPair::new(prev_kv).into());
prev_kvs.push(KvPair::to_kv(prev_kv));
}
}
_ => unreachable!(), // never get here
@@ -140,20 +137,23 @@ impl KvStore for EtcdStore {
key,
expect,
value,
options,
put_options,
} = req.try_into()?;
let put_op = vec![TxnOp::put(key.clone(), value, options)];
let get_op = vec![TxnOp::get(key.clone(), None)];
let mut txn = if expect.is_empty() {
let compare = if expect.is_empty() {
// create if absent
// revision 0 means key was not exist
Txn::new().when(vec![Compare::create_revision(key, CompareOp::Equal, 0)])
Compare::create_revision(key.clone(), CompareOp::Equal, 0)
} else {
// compare and put
Txn::new().when(vec![Compare::value(key, CompareOp::Equal, expect)])
Compare::value(key.clone(), CompareOp::Equal, expect)
};
txn = txn.and_then(put_op).or_else(get_op);
let put = TxnOp::put(key.clone(), value, put_options);
let get = TxnOp::get(key, None);
let txn = Txn::new()
.when(vec![compare])
.and_then(vec![put])
.or_else(vec![get]);
let txn_res = self
.client
@@ -171,23 +171,8 @@ impl KvStore for EtcdStore {
})?;
let prev_kv = match op_res {
TxnOpResponse::Put(put_res) => {
put_res.prev_key().map(|kv| KeyValue::from(KvPair::new(kv)))
}
TxnOpResponse::Get(get_res) => {
if get_res.count() == 0 {
// do not exists
None
} else {
ensure!(
get_res.count() == 1,
error::InvalidTxnResultSnafu {
err_msg: format!("expect 1 response, actual {}", get_res.count())
}
);
Some(KeyValue::from(KvPair::new(&get_res.kvs()[0])))
}
}
TxnOpResponse::Put(res) => res.prev_key().map(KvPair::to_kv),
TxnOpResponse::Get(res) => res.kvs().first().map(KvPair::to_kv),
_ => unreachable!(), // never get here
};
@@ -213,11 +198,7 @@ impl KvStore for EtcdStore {
.await
.context(error::EtcdFailedSnafu)?;
let prev_kvs = res
.prev_kvs()
.iter()
.map(|kv| KvPair::new(kv).into())
.collect::<Vec<_>>();
let prev_kvs = res.prev_kvs().iter().map(KvPair::to_kv).collect::<Vec<_>>();
let header = Some(ResponseHeader::success(cluster_id));
Ok(DeleteRangeResponse {
@@ -226,6 +207,83 @@ impl KvStore for EtcdStore {
prev_kvs,
})
}
async fn move_value(&self, req: MoveValueRequest) -> Result<MoveValueResponse> {
let MoveValue {
cluster_id,
from_key,
to_key,
delete_options,
} = req.try_into()?;
let mut client = self.client.kv_client();
let header = Some(ResponseHeader::success(cluster_id));
// TODO(jiachun): Maybe it's better to let the users control it in the request
const MAX_RETRIES: usize = 8;
for _ in 0..MAX_RETRIES {
let from_key = from_key.as_slice();
let to_key = to_key.as_slice();
let res = client
.get(from_key, None)
.await
.context(error::EtcdFailedSnafu)?;
let txn = match res.kvs().first() {
None => {
// get `to_key` if `from_key` absent
// revision 0 means key was not exist
let compare = Compare::create_revision(from_key, CompareOp::Equal, 0);
let get = TxnOp::get(to_key, None);
Txn::new().when(vec![compare]).and_then(vec![get])
}
Some(kv) => {
// compare `from_key` and move to `to_key`
let value = kv.value();
let compare = Compare::value(from_key, CompareOp::Equal, value);
let delete = TxnOp::delete(from_key, delete_options.clone());
let put = TxnOp::put(to_key, value, None);
Txn::new().when(vec![compare]).and_then(vec![delete, put])
}
};
let txn_res = client.txn(txn).await.context(error::EtcdFailedSnafu)?;
if !txn_res.succeeded() {
warn!(
"Failed to atomically move {:?} to {:?}, try again...",
String::from_utf8_lossy(from_key),
String::from_utf8_lossy(to_key)
);
continue;
}
// [`get_res'] or [`delete_res`, `put_res`], `put_res` will be ignored.
for op_res in txn_res.op_responses() {
match op_res {
TxnOpResponse::Get(res) => {
return Ok(MoveValueResponse {
header,
kv: res.kvs().first().map(KvPair::to_kv),
});
}
TxnOpResponse::Delete(res) => {
return Ok(MoveValueResponse {
header,
kv: res.prev_kvs().first().map(KvPair::to_kv),
});
}
_ => {}
}
}
}
error::MoveValueSnafu {
key: String::from_utf8_lossy(&from_key),
}
.fail()
}
}
struct Get {
@@ -333,7 +391,7 @@ struct CompareAndPut {
key: Vec<u8>,
expect: Vec<u8>,
value: Vec<u8>,
options: Option<PutOptions>,
put_options: Option<PutOptions>,
}
impl TryFrom<CompareAndPutRequest> for CompareAndPut {
@@ -352,7 +410,7 @@ impl TryFrom<CompareAndPutRequest> for CompareAndPut {
key,
expect,
value,
options: Some(PutOptions::default().with_prev_key()),
put_options: Some(PutOptions::default().with_prev_key()),
})
}
}
@@ -392,6 +450,32 @@ impl TryFrom<DeleteRangeRequest> for Delete {
}
}
struct MoveValue {
cluster_id: u64,
from_key: Vec<u8>,
to_key: Vec<u8>,
delete_options: Option<DeleteOptions>,
}
impl TryFrom<MoveValueRequest> for MoveValue {
type Error = error::Error;
fn try_from(req: MoveValueRequest) -> Result<Self> {
let MoveValueRequest {
header,
from_key,
to_key,
} = req;
Ok(MoveValue {
cluster_id: header.map_or(0, |h| h.cluster_id),
from_key,
to_key,
delete_options: Some(DeleteOptions::default().with_prev_key()),
})
}
}
struct KvPair<'a>(&'a etcd_client::KeyValue);
impl<'a> KvPair<'a> {
@@ -400,6 +484,11 @@ impl<'a> KvPair<'a> {
fn new(kv: &'a etcd_client::KeyValue) -> Self {
Self(kv)
}
#[inline]
fn to_kv(kv: &etcd_client::KeyValue) -> KeyValue {
KeyValue::from(KvPair::new(kv))
}
}
impl<'a> From<KvPair<'a>> for KeyValue {
@@ -479,7 +568,7 @@ mod tests {
assert_eq!(b"test_key".to_vec(), compare_and_put.key);
assert_eq!(b"test_expect".to_vec(), compare_and_put.expect);
assert_eq!(b"test_value".to_vec(), compare_and_put.value);
assert!(compare_and_put.options.is_some());
assert!(compare_and_put.put_options.is_some());
}
#[test]
@@ -496,4 +585,19 @@ mod tests {
assert_eq!(b"test_key".to_vec(), delete.key);
assert!(delete.options.is_some());
}
#[test]
fn test_parse_move_value() {
let req = MoveValueRequest {
from_key: b"test_from_key".to_vec(),
to_key: b"test_to_key".to_vec(),
..Default::default()
};
let move_value: MoveValue = req.try_into().unwrap();
assert_eq!(b"test_from_key".to_vec(), move_value.from_key);
assert_eq!(b"test_to_key".to_vec(), move_value.to_key);
assert!(move_value.delete_options.is_some());
}
}

View File

@@ -16,7 +16,8 @@ use std::sync::Arc;
use api::v1::meta::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse,
};
use crate::error::Result;
@@ -34,4 +35,6 @@ pub trait KvStore: Send + Sync {
async fn compare_and_put(&self, req: CompareAndPutRequest) -> Result<CompareAndPutResponse>;
async fn delete_range(&self, req: DeleteRangeRequest) -> Result<DeleteRangeResponse>;
async fn move_value(&self, req: MoveValueRequest) -> Result<MoveValueResponse>;
}

View File

@@ -19,8 +19,8 @@ use std::sync::Arc;
use api::v1::meta::{
BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, KeyValue, PutRequest, PutResponse, RangeRequest,
RangeResponse, ResponseHeader,
DeleteRangeRequest, DeleteRangeResponse, KeyValue, MoveValueRequest, MoveValueResponse,
PutRequest, PutResponse, RangeRequest, RangeResponse, ResponseHeader,
};
use parking_lot::RwLock;
@@ -219,4 +219,28 @@ impl KvStore for MemStore {
},
})
}
async fn move_value(&self, req: MoveValueRequest) -> Result<MoveValueResponse> {
let MoveValueRequest {
header,
from_key,
to_key,
} = req;
let mut memory = self.inner.write();
let kv = match memory.remove(&from_key) {
Some(v) => {
memory.insert(to_key, v.clone());
Some((from_key, v))
}
None => memory.get(&to_key).map(|v| (to_key, v.clone())),
};
let kv = kv.map(|(key, value)| KeyValue { key, value });
let cluster_id = header.map_or(0, |h| h.cluster_id);
let header = Some(ResponseHeader::success(cluster_id));
Ok(MoveValueResponse { header, kv })
}
}

View File

@@ -58,8 +58,8 @@ fn region_id(table_id: TableId, n: u32) -> RegionId {
}
#[inline]
fn table_dir(schema_name: &str, table_name: &str, table_id: TableId) -> String {
format!("{}/{}_{}/", schema_name, table_name, table_id)
fn table_dir(schema_name: &str, table_id: TableId) -> String {
format!("{}/{}/", schema_name, table_id)
}
/// [TableEngine] implementation.
@@ -341,7 +341,7 @@ impl<S: StorageEngine> MitoEngineInner<S> {
}
}
let table_dir = table_dir(schema_name, table_name, table_id);
let table_dir = table_dir(schema_name, table_id);
let opts = CreateOptions {
parent_dir: table_dir.clone(),
};
@@ -422,7 +422,7 @@ impl<S: StorageEngine> MitoEngineInner<S> {
let table_id = request.table_id;
let engine_ctx = StorageEngineContext::default();
let table_dir = table_dir(schema_name, table_name, table_id);
let table_dir = table_dir(schema_name, table_id);
let opts = OpenOptions {
parent_dir: table_dir.to_string(),
};
@@ -665,14 +665,8 @@ mod tests {
#[test]
fn test_table_dir() {
assert_eq!(
"public/test_table_1024/",
table_dir("public", "test_table", 1024)
);
assert_eq!(
"prometheus/demo_1024/",
table_dir("prometheus", "demo", 1024)
);
assert_eq!("public/1024/", table_dir("public", 1024));
assert_eq!("prometheus/1024/", table_dir("prometheus", 1024));
}
#[test]

View File

@@ -6,7 +6,7 @@ license = "Apache-2.0"
[dependencies]
futures = { version = "0.3" }
opendal = { version = "0.21", features = ["layers-tracing", "layers-metrics"] }
opendal = { version = "0.22", features = ["layers-tracing", "layers-metrics"] }
tokio = { version = "1.0", features = ["full"] }
[dev-dependencies]

9
src/promql/Cargo.toml Normal file
View File

@@ -0,0 +1,9 @@
[package]
name = "promql"
version = "0.1.0"
edition = "2021"
[dependencies]
common-error = { path = "../common/error" }
promql-parser = { git = "https://github.com/GreptimeTeam/promql-parser.git", rev = "71d8a90" }
snafu = { version = "0.7", features = ["backtraces"] }

36
src/promql/src/engine.rs Normal file
View File

@@ -0,0 +1,36 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use promql_parser::parser::Value;
use crate::error::Result;
mod evaluator;
mod functions;
pub use evaluator::*;
pub struct Context {}
pub struct Query {}
pub struct Engine {}
impl Engine {
pub fn exec(_ctx: &Context, _q: Query) -> Result<Arc<dyn Value>> {
unimplemented!();
}
}

View File

@@ -0,0 +1,29 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use promql_parser::parser::{Expr, Value};
use crate::engine::Context;
use crate::error::Result;
/// An evaluator evaluates given expressions over given fixed timestamps.
pub struct Evaluator {}
impl Evaluator {
pub fn eval(_ctx: &Context, _expr: &Expr) -> Result<Arc<dyn Value>> {
unimplemented!();
}
}

View File

@@ -0,0 +1,15 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! PromQL functions

50
src/promql/src/error.rs Normal file
View File

@@ -0,0 +1,50 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::any::Any;
use common_error::prelude::*;
common_error::define_opaque_error!(Error);
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum InnerError {
#[snafu(display("Unsupported expr type: {}", name))]
UnsupportedExpr { name: String, backtrace: Backtrace },
}
impl ErrorExt for InnerError {
fn status_code(&self) -> StatusCode {
use InnerError::*;
match self {
UnsupportedExpr { .. } => StatusCode::InvalidArguments,
}
}
fn backtrace_opt(&self) -> Option<&Backtrace> {
ErrorCompat::backtrace(self)
}
fn as_any(&self) -> &dyn Any {
self
}
}
impl From<InnerError> for Error {
fn from(e: InnerError) -> Error {
Error::new(e)
}
}
pub type Result<T> = std::result::Result<T, Error>;

16
src/promql/src/lib.rs Normal file
View File

@@ -0,0 +1,16 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod engine;
pub mod error;

View File

@@ -44,18 +44,18 @@ paste = { version = "1.0", optional = true }
query = { path = "../query" }
# TODO(discord9): This is a forked and tweaked version of RustPython, please update it to newest original RustPython After Update toolchain to 1.65
rustpython-ast = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab", features = [
"freeze-stdlib",
] }
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab", features = [
"default",
"codegen",
] }
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab" }
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "183e8dab", features = [
"freeze-stdlib",
] }
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { path = "../sql" }

View File

@@ -15,7 +15,6 @@
//! Builtin module contains GreptimeDB builtin udf/udaf
#[cfg(test)]
#[allow(clippy::print_stdout)]
mod test;
use datafusion_common::{DataFusionError, ScalarValue};

View File

@@ -18,6 +18,7 @@ use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use common_telemetry::{error, info};
use datatypes::arrow::array::{Float64Array, Int64Array};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::{DataType as ArrowDataType, Field};
@@ -308,6 +309,8 @@ impl PyValue {
#[test]
fn run_builtin_fn_testcases() {
common_telemetry::init_default_ut_logging();
let loc = Path::new("src/python/builtins/testcases.ron");
let loc = loc.to_str().expect("Fail to parse path");
let mut file = File::open(loc).expect("Fail to open file");
@@ -320,7 +323,7 @@ fn run_builtin_fn_testcases() {
PyVector::make_class(&vm.ctx);
});
for (idx, case) in testcases.into_iter().enumerate() {
print!("Testcase {idx} ...");
info!("Testcase {idx} ...");
cached_vm
.enter(|vm| {
let scope = vm.new_scope_with_builtins();
@@ -345,7 +348,7 @@ fn run_builtin_fn_testcases() {
let err_res = format_py_error(e, vm).to_string();
match case.expect{
Ok(v) => {
println!("\nError:\n{err_res}");
error!("\nError:\n{err_res}");
panic!("Expect Ok: {v:?}, found Error");
},
Err(err) => {
@@ -374,7 +377,6 @@ fn run_builtin_fn_testcases() {
}
};
});
println!(" passed!");
}
}
@@ -420,6 +422,8 @@ fn set_lst_of_vecs_in_scope(
#[allow(unused_must_use)]
#[test]
fn test_vm() {
common_telemetry::init_default_ut_logging();
rustpython_vm::Interpreter::with_init(Default::default(), |vm| {
vm.add_native_module("udf_builtins", Box::new(greptime_builtin::make_module));
// this can be in `.enter()` closure, but for clearity, put it in the `with_init()`
@@ -448,11 +452,10 @@ sin(values)"#,
.map_err(|err| vm.new_syntax_error(&err))
.unwrap();
let res = vm.run_code_obj(code_obj, scope);
println!("{:#?}", res);
match res {
Err(e) => {
let err_res = format_py_error(e, vm).to_string();
println!("Error:\n{err_res}");
error!("Error:\n{err_res}");
}
Ok(obj) => {
let _ser = PyValue::from_py_obj(&obj, vm);

View File

@@ -12,15 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#![allow(clippy::print_stdout, clippy::print_stderr)]
// for debug purpose, also this is already a
// test module so allow print_stdout shouldn't be a problem?
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
use std::sync::Arc;
use common_recordbatch::RecordBatch;
use common_telemetry::{error, info};
use console::style;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::data_type::{ConcreteDataType, DataType};
@@ -91,6 +89,8 @@ fn create_sample_recordbatch() -> RecordBatch {
/// and exec/parse (depending on the type of predicate) then decide if result is as expected
#[test]
fn run_ron_testcases() {
common_telemetry::init_default_ut_logging();
let loc = Path::new("src/python/testcases.ron");
let loc = loc.to_str().expect("Fail to parse path");
let mut file = File::open(loc).expect("Fail to open file");
@@ -98,9 +98,9 @@ fn run_ron_testcases() {
file.read_to_string(&mut buf)
.expect("Fail to read to string");
let testcases: Vec<TestCase> = from_ron_string(&buf).expect("Fail to convert to testcases");
println!("Read {} testcases from {}", testcases.len(), loc);
info!("Read {} testcases from {}", testcases.len(), loc);
for testcase in testcases {
print!(".ron test {}", testcase.name);
info!(".ron test {}", testcase.name);
match testcase.predicate {
Predicate::ParseIsOk { result } => {
let copr = parse_and_compile_copr(&testcase.code);
@@ -110,19 +110,19 @@ fn run_ron_testcases() {
}
Predicate::ParseIsErr { reason } => {
let copr = parse_and_compile_copr(&testcase.code);
if copr.is_ok() {
panic!("Expect to be err, found{copr:#?}");
}
assert!(copr.is_err(), "Expect to be err, actual {copr:#?}");
let res = &copr.unwrap_err();
println!(
error!(
"{}",
pretty_print_error_in_src(&testcase.code, res, 0, "<embedded>")
);
let (res, _) = get_error_reason_loc(res);
if !res.contains(&reason) {
eprintln!("{}", testcase.code);
panic!("Parse Error, expect \"{reason}\" in \"{res}\", but not found.");
}
assert!(
res.contains(&reason),
"{} Parse Error, expect \"{reason}\" in \"{res}\", actual not found.",
testcase.code,
);
}
Predicate::ExecIsOk { fields, columns } => {
let rb = create_sample_recordbatch();
@@ -130,58 +130,47 @@ fn run_ron_testcases() {
fields
.iter()
.zip(res.schema.column_schemas())
.map(|(anno, real)| {
if !(anno.datatype.as_ref().unwrap() == &real.data_type.as_arrow_type()
&& anno.is_nullable == real.is_nullable())
{
eprintln!("fields expect to be {anno:#?}, found to be {real:#?}.");
panic!()
}
})
.count();
columns
.iter()
.enumerate()
.map(|(i, anno)| {
let real = res.column(i);
if !(anno.ty == real.data_type().as_arrow_type() && anno.len == real.len())
{
panic!(
"Unmatch type or length!Expect [{:#?}; {}], found [{:#?}; {}]",
anno.ty,
anno.len,
real.data_type(),
real.len()
);
}
})
.count();
.for_each(|(anno, real)| {
assert!(
anno.datatype.as_ref().unwrap() == &real.data_type.as_arrow_type()
&& anno.is_nullable == real.is_nullable(),
"Fields expected to be {anno:#?}, actual {real:#?}"
);
});
columns.iter().zip(res.columns()).for_each(|(anno, real)| {
assert!(
anno.ty == real.data_type().as_arrow_type() && anno.len == real.len(),
"Type or length not match! Expect [{:#?}; {}], actual [{:#?}; {}]",
anno.ty,
anno.len,
real.data_type(),
real.len()
);
});
}
Predicate::ExecIsErr {
reason: part_reason,
} => {
let rb = create_sample_recordbatch();
let res = coprocessor::exec_coprocessor(&testcase.code, &rb);
assert!(res.is_err(), "{:#?}\nExpect Err(...), actual Ok(...)", res);
if let Err(res) = res {
println!(
error!(
"{}",
pretty_print_error_in_src(&testcase.code, &res, 1120, "<embedded>")
);
let (reason, _) = get_error_reason_loc(&res);
if !reason.contains(&part_reason) {
panic!(
"{}\nExecute error, expect \"{reason}\" in \"{res}\", but not found.",
testcase.code,
reason = style(reason).green(),
res = style(res).red()
);
}
} else {
panic!("{:#?}\nExpect Err(...), found Ok(...)", res);
assert!(
reason.contains(&part_reason),
"{}\nExecute error, expect \"{reason}\" in \"{res}\", actual not found.",
testcase.code,
reason = style(reason).green(),
res = style(res).red()
)
}
}
}
println!(" ... {}", style("ok✅").green());
info!(" ... {}", style("ok✅").green());
}
}
@@ -279,7 +268,7 @@ def calc_rvs(open_time, close):
0,
"copr.py",
);
println!("{res}");
info!("{res}");
} else if let Ok(res) = ret {
dbg!(&res);
} else {
@@ -329,7 +318,7 @@ def a(cpu, mem):
0,
"copr.py",
);
println!("{res}");
info!("{res}");
} else if let Ok(res) = ret {
dbg!(&res);
} else {

View File

@@ -997,6 +997,7 @@ pub mod tests {
use std::sync::Arc;
use common_telemetry::info;
use datatypes::vectors::{Float32Vector, Int32Vector, NullVector};
use rustpython_vm::builtins::PyList;
use rustpython_vm::class::PyClassImpl;
@@ -1128,9 +1129,10 @@ pub mod tests {
}
#[test]
#[allow(clippy::print_stdout)]
// for debug purpose, also this is already a test function so allow print_stdout shouldn't be a problem?
fn test_execute_script() {
common_telemetry::init_default_ut_logging();
fn is_eq<T: std::cmp::PartialEq + rustpython_vm::TryFromObject>(
v: PyResult,
i: T,
@@ -1179,7 +1181,7 @@ pub mod tests {
for (code, pred) in snippet {
let result = execute_script(&interpreter, code, None, pred);
println!(
info!(
"\u{001B}[35m{code}\u{001B}[0m: {:?}{}",
result.clone().map(|v| v.0),
result

View File

@@ -10,6 +10,7 @@ api = { path = "../api" }
async-trait = "0.1"
axum = "0.6"
axum-macros = "0.3"
base64 = "0.13"
bytes = "1.2"
common-base = { path = "../common/base" }
common-catalog = { path = "../common/catalog" }
@@ -21,8 +22,10 @@ common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
digest = "0.10"
futures = "0.3"
hex = { version = "0.4" }
http-body = "0.4"
humantime-serde = "1.1"
hyper = { version = "0.14", features = ["full"] }
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }
@@ -31,7 +34,7 @@ num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.3"
pgwire = "0.5"
pgwire = "0.6.1"
prost = "0.11"
rand = "0.8"
regex = "1.6"
@@ -41,8 +44,10 @@ schemars = "0.8"
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
sha1 = "0.10"
snafu = { version = "0.7", features = ["backtraces"] }
snap = "1"
strum = { version = "0.24", features = ["derive"] }
table = { path = "../table" }
tokio = { version = "1.20", features = ["full"] }
tokio-rustls = "0.23"
@@ -64,6 +69,7 @@ rand = "0.8"
script = { path = "../script", features = ["python"] }
serde_json = "1.0"
table = { path = "../table" }
tempdir = "0.3"
tokio-postgres = "0.7"
tokio-postgres-rustls = "0.9"
tokio-test = "0.4"

251
src/servers/src/auth.rs Normal file
View File

@@ -0,0 +1,251 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod user_provider;
pub const DEFAULT_USERNAME: &str = "greptime";
use std::sync::Arc;
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use snafu::{Backtrace, ErrorCompat, OptionExt, Snafu};
use crate::auth::user_provider::StaticUserProvider;
#[async_trait::async_trait]
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo, Error>;
}
pub type UserProviderRef = Arc<dyn UserProvider>;
type Username<'a> = &'a str;
type HostOrIp<'a> = &'a str;
#[derive(Debug, Clone)]
pub enum Identity<'a> {
UserId(Username<'a>, Option<HostOrIp<'a>>),
}
pub type HashedPassword<'a> = &'a [u8];
pub type Salt<'a> = &'a [u8];
/// Authentication information sent by the client.
pub enum Password<'a> {
PlainText(&'a str),
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
PgMD5(HashedPassword<'a>, Salt<'a>),
}
#[derive(Clone, Debug)]
pub struct UserInfo {
username: String,
}
impl Default for UserInfo {
fn default() -> Self {
Self {
username: DEFAULT_USERNAME.to_string(),
}
}
}
impl UserInfo {
pub fn user_name(&self) -> &str {
&self.username
}
#[cfg(test)]
pub fn new(username: impl Into<String>) -> Self {
Self {
username: username.into(),
}
}
}
pub fn user_provider_from_option(opt: &String) -> Result<UserProviderRef, Error> {
let (name, content) = opt.split_once(':').context(InvalidConfigSnafu {
value: opt.to_string(),
msg: "UserProviderOption must be in format `<option>:<value>`",
})?;
match name {
user_provider::STATIC_USER_PROVIDER => {
let provider =
StaticUserProvider::try_from(content).map(|p| Arc::new(p) as UserProviderRef)?;
Ok(provider)
}
_ => InvalidConfigSnafu {
value: name.to_string(),
msg: "Invalid UserProviderOption",
}
.fail(),
}
}
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Invalid config value: {}, {}", value, msg))]
InvalidConfig {
value: String,
msg: String,
backtrace: Backtrace,
},
#[snafu(display("Encounter IO error, source: {}", source))]
IOErr { source: std::io::Error },
#[snafu(display("User not found, username: {}", username))]
UserNotFound { username: String },
#[snafu(display("Unsupported password type: {}", password_type))]
UnsupportedPasswordType {
password_type: String,
backtrace: Backtrace,
},
#[snafu(display("Username and password does not match, username: {}", username))]
UserPasswordMismatch { username: String },
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::InvalidConfig { .. } => StatusCode::InvalidArguments,
Error::IOErr { .. } => StatusCode::Internal,
Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
}
}
fn backtrace_opt(&self) -> Option<&Backtrace> {
ErrorCompat::backtrace(self)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
pub mod test {
use super::{Identity, Password, UserInfo, UserProvider};
pub struct MockUserProvider {}
#[async_trait::async_trait]
impl UserProvider for MockUserProvider {
fn name(&self) -> &str {
"mock_user_provider"
}
async fn auth(
&self,
id: Identity<'_>,
password: Password<'_>,
) -> Result<UserInfo, super::Error> {
match id {
Identity::UserId(username, _host) => match password {
Password::PlainText(password) => {
if username == "greptime" {
if password == "greptime" {
return Ok(UserInfo {
username: "greptime".to_string(),
});
} else {
return super::UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail();
}
} else {
return super::UserNotFoundSnafu {
username: username.to_string(),
}
.fail();
}
}
_ => super::UnsupportedPasswordTypeSnafu {
password_type: "mysql_native_password",
}
.fail(),
},
}
}
}
}
#[cfg(test)]
mod tests {
use super::test::MockUserProvider;
use super::{Identity, Password, UserProvider};
use crate::auth;
#[tokio::test]
async fn test_auth_by_plain_text() {
let user_provider = MockUserProvider {};
assert_eq!("mock_user_provider", user_provider.name());
// auth success
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_ok());
assert_eq!("greptime", auth_result.unwrap().user_name());
// auth failed, unsupported password type
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::MysqlNativePassword(b"hashed_value", b"salt"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
auth::Error::UnsupportedPasswordType { .. }
);
// auth failed, err: user not exist.
let auth_result = user_provider
.auth(
Identity::UserId("not_exist_username", None),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_err());
matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. });
// auth failed, err: wrong password
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText("wrong_password"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
auth::Error::UserPasswordMismatch { .. }
);
}
}

View File

@@ -0,0 +1,253 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::fs::File;
use std::io;
use std::io::BufRead;
use std::path::Path;
use async_trait::async_trait;
use digest;
use digest::Digest;
use sha1::Sha1;
use snafu::{ensure, OptionExt, ResultExt};
use crate::auth::{
Error, HashedPassword, IOErrSnafu, Identity, InvalidConfigSnafu, Password, Salt,
UnsupportedPasswordTypeSnafu, UserInfo, UserNotFoundSnafu, UserPasswordMismatchSnafu,
UserProvider,
};
pub const STATIC_USER_PROVIDER: &str = "static_user_provider";
impl TryFrom<&str> for StaticUserProvider {
type Error = Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let (mode, content) = value.split_once(':').context(InvalidConfigSnafu {
value: value.to_string(),
msg: "StaticUserProviderOption must be in format `<option>:<value>`",
})?;
return match mode {
"file" => {
// check valid path
let path = Path::new(content);
ensure!(path.exists() && path.is_file(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must be a valid file path",
});
let file = File::open(path).context(IOErrSnafu)?;
let credential = io::BufReader::new(file)
.lines()
.filter_map(|line| line.ok())
.filter_map(|line| {
if let Some((k, v)) = line.split_once('=') {
Some((k.to_string(), v.as_bytes().to_vec()))
} else {
None
}
})
.collect::<HashMap<String, Vec<u8>>>();
ensure!(!credential.is_empty(), InvalidConfigSnafu {
value: content.to_string(),
msg: "StaticUserProviderOption file must contains at least one valid credential",
});
Ok(StaticUserProvider { users: credential, })
}
"cmd" => content
.split(',')
.map(|kv| {
let (k, v) = kv.split_once('=').context(InvalidConfigSnafu {
value: kv.to_string(),
msg: "StaticUserProviderOption cmd values must be in format `user=pwd[,user=pwd]`",
})?;
Ok((k.to_string(), v.as_bytes().to_vec()))
})
.collect::<Result<HashMap<String, Vec<u8>>, Error>>()
.map(|users| StaticUserProvider { users }),
_ => InvalidConfigSnafu {
value: mode.to_string(),
msg: "StaticUserProviderOption must be in format `file:<path>` or `cmd:<values>`",
}
.fail(),
};
}
}
pub struct StaticUserProvider {
users: HashMap<String, Vec<u8>>,
}
#[async_trait]
impl UserProvider for StaticUserProvider {
fn name(&self) -> &str {
STATIC_USER_PROVIDER
}
async fn auth(
&self,
input_id: Identity<'_>,
input_pwd: Password<'_>,
) -> Result<UserInfo, Error> {
match input_id {
Identity::UserId(username, _) => {
let save_pwd = self.users.get(username).context(UserNotFoundSnafu {
username: username.to_string(),
})?;
match input_pwd {
Password::PlainText(pwd) => {
return if save_pwd == pwd.as_bytes() {
Ok(UserInfo {
username: username.to_string(),
})
} else {
UserPasswordMismatchSnafu {
username: username.to_string(),
}
.fail()
}
}
Password::MysqlNativePassword(auth_data, salt) => {
auth_mysql(auth_data, salt, username.to_string(), save_pwd)
}
Password::PgMD5(_, _) => UnsupportedPasswordTypeSnafu {
password_type: "pg_md5",
}
.fail(),
}
}
}
}
}
fn auth_mysql(
auth_data: HashedPassword,
salt: Salt,
username: String,
save_pwd: &[u8],
) -> Result<UserInfo, Error> {
// ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
let hash_stage_2 = double_sha1(save_pwd);
let tmp = sha1_two(salt, &hash_stage_2);
// xor auth_data and tmp
let mut xor_result = [0u8; 20];
for i in 0..20 {
xor_result[i] = auth_data[i] ^ tmp[i];
}
let candidate_stage_2 = sha1_one(&xor_result);
if candidate_stage_2 == hash_stage_2 {
Ok(UserInfo { username })
} else {
UserPasswordMismatchSnafu { username }.fail()
}
}
fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(input_1);
hasher.update(input_2);
hasher.finalize().to_vec()
}
fn sha1_one(data: &[u8]) -> Vec<u8> {
let mut hasher = Sha1::new();
hasher.update(data);
hasher.finalize().to_vec()
}
fn double_sha1(data: &[u8]) -> Vec<u8> {
sha1_one(&sha1_one(data))
}
#[cfg(test)]
pub mod test {
use std::fs::File;
use std::io::{LineWriter, Write};
use tempdir::TempDir;
use crate::auth::user_provider::{double_sha1, sha1_one, sha1_two, StaticUserProvider};
use crate::auth::{Identity, Password, UserProvider};
#[test]
fn test_sha() {
let sha_1_answer: Vec<u8> = vec![
124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
27,
];
let sha_1 = sha1_one("123456".as_bytes());
assert_eq!(sha_1, sha_1_answer);
let double_sha1_answer: Vec<u8> = vec![
107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
42, 217,
];
let double_sha1 = double_sha1("123456".as_bytes());
assert_eq!(double_sha1, double_sha1_answer);
let sha1_2_answer: Vec<u8> = vec![
132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
37, 204,
];
let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
assert_eq!(sha1_2, sha1_2_answer);
}
async fn test_auth(provider: &dyn UserProvider, username: &str, password: &str) {
let re = provider
.auth(
Identity::UserId(username, None),
Password::PlainText(password),
)
.await;
assert!(re.is_ok());
}
#[tokio::test]
async fn test_inline_provider() {
let provider = StaticUserProvider::try_from("cmd:root=123456,admin=654321").unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
#[tokio::test]
async fn test_file_provider() {
let dir = TempDir::new("test_file_provider").unwrap();
let file_path = format!("{}/test_file_provider", dir.path().to_str().unwrap());
{
// write a tmp file
let file = File::create(&file_path);
assert!(file.is_ok());
let file = file.unwrap();
let mut lw = LineWriter::new(file);
assert!(lw
.write_all(
b"root=123456
admin=654321",
)
.is_ok());
assert!(lw.flush().is_ok());
}
let param = format!("file:{}", file_path);
let provider = StaticUserProvider::try_from(param.as_str()).unwrap();
test_auth(&provider, "root", "123456").await;
test_auth(&provider, "admin", "654321").await;
}
}

View File

@@ -12,23 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use snafu::OptionExt;
use crate::auth::UserInfo;
use crate::error::{BuildingContextSnafu, Result};
type CtxFnRef = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
#[derive(Serialize, Deserialize)]
pub struct Context {
pub exec_info: ExecInfo,
pub client_info: ClientInfo,
pub user_info: UserInfo,
pub quota: Quota,
#[serde(skip)]
pub predicates: Vec<CtxFnRef>,
}
@@ -41,10 +37,8 @@ impl Context {
#[derive(Default)]
pub struct CtxBuilder {
client_addr: Option<String>,
username: Option<String>,
from_channel: Option<Channel>,
auth_method: Option<AuthMethod>,
user_info: Option<UserInfo>,
}
impl CtxBuilder {
@@ -52,23 +46,18 @@ impl CtxBuilder {
CtxBuilder::default()
}
pub fn client_addr(mut self, addr: Option<String>) -> CtxBuilder {
self.client_addr = addr;
pub fn client_addr(mut self, addr: String) -> CtxBuilder {
self.client_addr = Some(addr);
self
}
pub fn set_channel(mut self, channel: Option<Channel>) -> CtxBuilder {
self.from_channel = channel;
pub fn set_channel(mut self, channel: Channel) -> CtxBuilder {
self.from_channel = Some(channel);
self
}
pub fn set_auth_method(mut self, auth_method: Option<AuthMethod>) -> CtxBuilder {
self.auth_method = auth_method;
self
}
pub fn set_username(mut self, username: Option<String>) -> CtxBuilder {
self.username = username;
pub fn set_user_info(mut self, user_info: UserInfo) -> CtxBuilder {
self.user_info = Some(user_info);
self
}
@@ -78,87 +67,32 @@ impl CtxBuilder {
client_host: self.client_addr.context(BuildingContextSnafu {
err_msg: "unknown client addr while building ctx",
})?,
},
user_info: UserInfo {
username: self.username,
from_channel: self.from_channel.context(BuildingContextSnafu {
channel: self.from_channel.context(BuildingContextSnafu {
err_msg: "unknown channel while building ctx",
})?,
auth_method: self.auth_method.context(BuildingContextSnafu {
err_msg: "unknown auth method while building ctx",
})?,
},
exec_info: ExecInfo::default(),
user_info: self.user_info.context(BuildingContextSnafu {
err_msg: "missing user info while building ctx",
})?,
quota: Quota::default(),
predicates: vec![],
})
}
}
#[derive(Serialize, Deserialize)]
pub struct ExecInfo {
pub catalog: Option<String>,
pub schema: Option<String>,
// should opts to be thread safe?
pub extra_opts: HashMap<String, String>,
pub trace_id: Option<String>,
}
impl Default for ExecInfo {
fn default() -> Self {
ExecInfo {
catalog: Some("greptime".to_string()),
schema: Some("public".to_string()),
extra_opts: HashMap::new(),
trace_id: None,
}
}
}
#[derive(Default, Serialize, Deserialize)]
pub struct ClientInfo {
pub client_host: String,
pub channel: Channel,
}
#[derive(Serialize, Deserialize)]
pub struct UserInfo {
pub username: Option<String>,
pub from_channel: Channel,
pub auth_method: AuthMethod,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Channel {
GRPC,
HTTP,
MYSQL,
Grpc,
Http,
Mysql,
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthMethod {
None,
Password {
hash_method: AuthHashMethod,
hashed_value: Vec<u8>,
salt: Vec<u8>,
},
Token(String),
}
impl Default for AuthMethod {
fn default() -> Self {
AuthMethod::None
}
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuthHashMethod {
DoubleSha1,
Sha256,
}
#[derive(Default, Serialize, Deserialize)]
#[derive(Default)]
pub struct Quota {
pub total: u64,
pub consumed: u64,
@@ -170,20 +104,18 @@ mod test {
use std::sync::Arc;
use crate::context::AuthMethod::Token;
use crate::context::Channel::HTTP;
use crate::context::{Channel, Context, CtxBuilder, UserInfo};
use crate::auth::UserInfo;
use crate::context::Channel::{self, Http};
use crate::context::{ClientInfo, Context, CtxBuilder};
#[test]
fn test_predicate() {
let mut ctx = Context {
exec_info: Default::default(),
client_info: Default::default(),
user_info: UserInfo {
username: None,
from_channel: Channel::GRPC,
auth_method: Default::default(),
client_info: ClientInfo {
client_host: Default::default(),
channel: Channel::Grpc,
},
user_info: UserInfo::new("greptime"),
quota: Default::default(),
predicates: vec![],
};
@@ -204,23 +136,14 @@ mod test {
#[test]
fn test_build() {
let ctx = CtxBuilder::new()
.client_addr(Some("127.0.0.1:4001".to_string()))
.set_channel(Some(HTTP))
.set_auth_method(Some(Token("HELLO".to_string())))
.client_addr("127.0.0.1:4001".to_string())
.set_channel(Http)
.set_user_info(UserInfo::new("greptime"))
.build()
.unwrap();
assert_eq!(ctx.exec_info.catalog.unwrap(), String::from("greptime"));
assert_eq!(ctx.exec_info.schema.unwrap(), String::from("public"));
assert_eq!(ctx.exec_info.extra_opts.len(), 0);
assert_eq!(ctx.exec_info.trace_id, None);
assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001"));
assert_eq!(ctx.user_info.username, None);
assert_eq!(ctx.user_info.from_channel, HTTP);
assert_eq!(ctx.user_info.auth_method, Token(String::from("HELLO")));
assert_eq!(ctx.quota.total, 0);
assert_eq!(ctx.quota.consumed, 0);
assert_eq!(ctx.quota.estimated, 0);

View File

@@ -14,13 +14,18 @@
use std::any::Any;
use std::net::SocketAddr;
use std::string::FromUtf8Error;
use axum::http::StatusCode as HttpStatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use base64::DecodeError;
use common_error::prelude::*;
use hyper::header::ToStrError;
use serde_json::json;
use crate::auth;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
@@ -195,6 +200,39 @@ pub enum Error {
#[snafu(display("Tls is required for {}, plain connection is rejected", server))]
TlsRequired { server: String },
#[snafu(display("Failed to get user info, source: {}", source))]
Auth {
#[snafu(backtrace)]
source: auth::Error,
},
#[snafu(display("Not found http authorization header"))]
NotFoundAuthHeader {},
#[snafu(display("Invalid visibility ASCII chars, source: {}", source))]
InvisibleASCII {
source: ToStrError,
backtrace: Backtrace,
},
#[snafu(display("Unsupported http auth scheme, name: {}", name))]
UnsupportedAuthScheme { name: String },
#[snafu(display("Invalid http authorization header"))]
InvalidAuthorizationHeader { backtrace: Backtrace },
#[snafu(display("Invalid base64 value, source: {:?}", source))]
InvalidBase64Value {
source: DecodeError,
backtrace: Backtrace,
},
#[snafu(display("Invalid utf-8 value, source: {:?}", source))]
InvalidUtf8Value {
source: FromUtf8Error,
backtrace: Backtrace,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -239,6 +277,14 @@ impl ErrorExt for Error {
Hyper { .. } => StatusCode::Unknown,
TlsRequired { .. } => StatusCode::Unknown,
StartFrontend { source, .. } => source.status_code(),
Auth { source, .. } => source.status_code(),
NotFoundAuthHeader { .. } => StatusCode::AuthHeaderNotFound,
InvisibleASCII { .. }
| UnsupportedAuthScheme { .. }
| InvalidAuthorizationHeader { .. }
| InvalidBase64Value { .. }
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod context;
mod authorize;
pub mod handler;
pub mod influxdb;
pub mod opentsdb;
@@ -26,8 +26,8 @@ use std::time::Duration;
use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse};
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
use async_trait::async_trait;
use axum::body::BoxBody;
use axum::error_handling::HandleErrorLayer;
use axum::middleware::{self};
use axum::response::{Html, Json};
use axum::{routing, BoxError, Extension, Router};
use common_error::prelude::ErrorExt;
@@ -45,9 +45,12 @@ use tokio::sync::oneshot::{self, Sender};
use tokio::sync::Mutex;
use tower::timeout::TimeoutLayer;
use tower::ServiceBuilder;
use tower_http::auth::AsyncRequireAuthorizationLayer;
use tower_http::trace::TraceLayer;
use self::authorize::HttpAuth;
use self::influxdb::influxdb_write;
use crate::auth::UserProviderRef;
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
use crate::query_handler::{
InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef,
@@ -65,6 +68,7 @@ pub struct HttpServer {
prom_handler: Option<PrometheusProtocolHandlerRef>,
script_handler: Option<ScriptHandlerRef>,
shutdown_tx: Mutex<Option<Sender<()>>>,
user_provider: Option<UserProviderRef>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -295,6 +299,7 @@ impl HttpServer {
opentsdb_handler: None,
influxdb_handler: None,
prom_handler: None,
user_provider: None,
script_handler: None,
shutdown_tx: Mutex::new(None),
}
@@ -332,6 +337,14 @@ impl HttpServer {
self.prom_handler.get_or_insert(handler);
}
pub fn set_user_provider(&mut self, user_provider: UserProviderRef) {
debug_assert!(
self.user_provider.is_none(),
"User provider can be set only once!"
);
self.user_provider.get_or_insert(user_provider);
}
pub fn make_app(&self) -> Router {
let mut api = OpenApi {
info: Info {
@@ -393,7 +406,9 @@ impl HttpServer {
.layer(TraceLayer::new_for_http())
.layer(TimeoutLayer::new(self.options.timeout))
// custom layer
.layer(middleware::from_fn(context::build_ctx)),
.layer(AsyncRequireAuthorizationLayer::new(
HttpAuth::<BoxBody>::new(self.user_provider.clone()),
)),
)
}

View File

@@ -0,0 +1,282 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::marker::PhantomData;
use axum::http::{self, Request, StatusCode};
use axum::response::Response;
use common_telemetry::error;
use futures::future::BoxFuture;
use http_body::Body;
use snafu::{OptionExt, ResultExt};
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::{Identity, UserInfo, UserProviderRef};
use crate::error::{self, Result};
pub struct HttpAuth<RespBody> {
user_provider: Option<UserProviderRef>,
_ty: PhantomData<RespBody>,
}
impl<RespBody> HttpAuth<RespBody> {
pub fn new(user_provider: Option<UserProviderRef>) -> Self {
Self {
user_provider,
_ty: PhantomData,
}
}
}
impl<RespBody> Clone for HttpAuth<RespBody> {
fn clone(&self) -> Self {
Self {
user_provider: self.user_provider.clone(),
_ty: PhantomData,
}
}
}
impl<B, RespBody> AsyncAuthorizeRequest<B> for HttpAuth<RespBody>
where
B: Send + Sync + 'static,
RespBody: Body + Default,
{
type RequestBody = B;
type ResponseBody = RespBody;
type Future = BoxFuture<'static, std::result::Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let user_provider = self.user_provider.clone();
Box::pin(async move {
let user_provider = if let Some(user_provider) = &user_provider {
user_provider
} else {
request.extensions_mut().insert(UserInfo::default());
return Ok(request);
};
let (scheme, credential) = match auth_header(&request) {
Ok(auth_header) => auth_header,
Err(e) => {
error!("failed to get http authorize header, err: {:?}", e);
return Err(unauthorized_resp());
}
};
match scheme {
AuthScheme::Basic => {
let (username, password) = match decode_basic(credential) {
Ok(basic_auth) => basic_auth,
Err(e) => {
error!("failed to decode basic authorize, err: {:?}", e);
return Err(unauthorized_resp());
}
};
match user_provider
.auth(
Identity::UserId(&username, None),
crate::auth::Password::PlainText(&password),
)
.await
{
Ok(user_info) => {
request.extensions_mut().insert(user_info);
Ok(request)
}
Err(e) => {
error!("failed to auth, err: {:?}", e);
Err(unauthorized_resp())
}
}
}
}
})
}
}
fn unauthorized_resp<RespBody>() -> Response<RespBody>
where
RespBody: Body + Default,
{
let mut res = Response::new(RespBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
}
#[derive(Debug)]
pub enum AuthScheme {
Basic,
}
impl TryFrom<&str> for AuthScheme {
type Error = error::Error;
fn try_from(value: &str) -> Result<Self> {
match value.to_lowercase().as_str() {
"basic" => Ok(AuthScheme::Basic),
other => error::UnsupportedAuthSchemeSnafu { name: other }.fail(),
}
}
}
type Credential<'a> = &'a str;
fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.context(error::NotFoundAuthHeaderSnafu)?
.to_str()
.context(error::InvisibleASCIISnafu)?;
let (auth_scheme, encoded_credentials) = auth_header
.split_once(' ')
.context(error::InvalidAuthorizationHeaderSnafu)?;
if encoded_credentials.contains(' ') {
return error::InvalidAuthorizationHeaderSnafu {}.fail();
}
Ok((auth_scheme.try_into()?, encoded_credentials))
}
type Username = String;
type Password = String;
fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?;
let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?;
if let Some((user_id, password)) = as_utf8.split_once(':') {
return Ok((user_id.to_string(), password.to_string()));
}
error::InvalidAuthorizationHeaderSnafu {}.fail()
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use std::sync::Arc;
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use tower_http::auth::AsyncAuthorizeRequest;
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
use crate::auth::test::MockUserProvider;
use crate::auth::{UserInfo, UserProvider};
use crate::error;
use crate::error::Result;
#[tokio::test]
async fn test_http_auth() {
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: None,
_ty: PhantomData,
};
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: mock_user_provider,
_ty: PhantomData,
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request("Basic Z3JlcHRpbWU6Z3JlcHRpbWU=").unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
let req = mock_http_request_no_auth().unwrap();
let auth_res = http_auth.authorize(req).await;
assert!(auth_res.is_err());
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let auth_res = http_auth.authorize(wrong_req).await;
assert!(auth_res.is_err());
}
#[test]
fn test_decode_basic() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let credential = "dXNlcm5hbWU6cGFzc3dvcmQ=";
let (username, pwd) = decode_basic(credential).unwrap();
assert_eq!("username", username);
assert_eq!("password", pwd);
let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
let result = decode_basic(wrong_credential);
matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
}
#[test]
fn test_try_into_auth_scheme() {
let auth_scheme_str = "basic";
let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
matches!(auth_scheme, AuthScheme::Basic);
let unsupported = "digest";
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
assert!(auth_scheme.is_err());
}
#[test]
fn test_auth_header() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let (auth_scheme, credential) = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic);
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6 cGFzc3dvcmQ=").unwrap();
let res = auth_header(&wrong_req);
matches!(
res.err(),
Some(error::Error::InvalidAuthorizationHeader { .. })
);
let wrong_req = mock_http_request("Digest dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let res = auth_header(&wrong_req);
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
}
fn mock_http_request(auth_header: &str) -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.header(http::header::AUTHORIZATION, auth_header)
.body(())
.unwrap())
}
fn mock_http_request_no_auth() -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.body(())
.unwrap())
}
}

View File

@@ -1,60 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::http;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
use common_telemetry::error;
use crate::context::{AuthMethod, Channel, CtxBuilder};
pub async fn build_ctx<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
let auth_option = req
.headers()
.get(http::header::AUTHORIZATION)
.map(|header| {
header
.to_str()
.map(|header_str| match header_str.split_once(' ') {
Some((name, content)) if name == "Bearer" || name == "TOKEN" => {
AuthMethod::Token(String::from(content))
}
_ => AuthMethod::None,
})
.unwrap_or(AuthMethod::None)
})
.or(Some(AuthMethod::None));
match CtxBuilder::new()
.client_addr(
req.headers()
.get(http::header::HOST)
.and_then(|h| h.to_str().ok())
.map(|h| h.to_string()),
)
.set_channel(Some(Channel::HTTP))
.set_auth_method(auth_option)
.build()
{
Ok(ctx) => {
req.extensions_mut().insert(ctx);
Ok(next.run(req).await)
}
Err(e) => {
error!(e; "fail to create context");
Err(StatusCode::INTERNAL_SERVER_ERROR)
}
}
}

View File

@@ -18,12 +18,14 @@ use std::time::Instant;
use aide::transform::TransformOperation;
use axum::extract::{Json, Query, State};
use axum::Extension;
use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use crate::auth::UserInfo;
use crate::http::{ApiState, JsonResponse};
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
@@ -37,6 +39,8 @@ pub struct SqlQuery {
pub async fn sql(
State(state): State<ApiState>,
Query(params): Query<SqlQuery>,
// TODO(fys): pass _user_info into query context
_user_info: Extension<UserInfo>,
) -> Json<JsonResponse> {
let sql_handler = &state.sql_handler;
let start = Instant::now();

View File

@@ -16,6 +16,7 @@
use serde::{Deserialize, Serialize};
pub mod auth;
pub mod context;
pub mod error;
pub mod grpc;

View File

@@ -26,9 +26,9 @@ use session::Session;
use tokio::io::AsyncWrite;
use tokio::sync::RwLock;
use crate::context::AuthHashMethod::DoubleSha1;
use crate::context::Channel::MYSQL;
use crate::context::{AuthMethod, Context, CtxBuilder};
use crate::auth::{Identity, Password, UserProviderRef};
use crate::context::Channel::Mysql;
use crate::context::{Context, CtxBuilder};
use crate::error::{self, Result};
use crate::mysql::writer::MysqlResultWriter;
use crate::query_handler::SqlQueryHandlerRef;
@@ -41,10 +41,15 @@ pub struct MysqlInstanceShim {
// TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose.
ctx: Arc<RwLock<Option<Context>>>,
session: Arc<Session>,
user_provider: Option<UserProviderRef>,
}
impl MysqlInstanceShim {
pub fn create(query_handler: SqlQueryHandlerRef, client_addr: String) -> MysqlInstanceShim {
pub fn create(
query_handler: SqlQueryHandlerRef,
client_addr: String,
user_provider: Option<UserProviderRef>,
) -> MysqlInstanceShim {
// init a random salt
let mut bs = vec![0u8; 20];
let mut rng = rand::thread_rng();
@@ -64,6 +69,7 @@ impl MysqlInstanceShim {
client_addr,
ctx: Arc::new(RwLock::new(None)),
session: Arc::new(Session::new()),
user_provider,
}
}
@@ -102,28 +108,42 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
async fn authenticate(
&self,
_auth_plugin: &str,
auth_plugin: &str,
username: &[u8],
salt: &[u8],
auth_data: &[u8],
) -> bool {
// if not specified then **root** will be used
// if not specified then **greptime** will be used
let username = String::from_utf8_lossy(username);
let client_addr = self.client_addr.clone();
let auth_method = match auth_data.len() {
0 => AuthMethod::None,
_ => AuthMethod::Password {
hash_method: DoubleSha1,
hashed_value: auth_data.to_vec(),
salt: salt.to_vec(),
},
};
let mut user_info = None;
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(&client_addr));
let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
other => {
error!("Unsupported mysql auth plugin: {}", other);
return false;
}
};
match user_provider.auth(user_id, password).await {
Ok(userinfo) => {
user_info = Some(userinfo);
}
Err(e) => {
error!("Failed to auth, err: {:?}", e);
return false;
}
};
}
let user_info = user_info.unwrap_or_default();
return match CtxBuilder::new()
.client_addr(Some(client_addr))
.set_channel(Some(MYSQL))
.set_username(Some(username.to_string()))
.set_auth_method(Some(auth_method))
.client_addr(client_addr)
.set_channel(Mysql)
.set_user_info(user_info)
.build()
{
Ok(ctx) => {

View File

@@ -28,6 +28,7 @@ use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio_rustls::rustls::ServerConfig;
use crate::auth::UserProviderRef;
use crate::error::{Error, Result};
use crate::mysql::handler::MysqlInstanceShim;
use crate::query_handler::SqlQueryHandlerRef;
@@ -41,6 +42,7 @@ pub struct MysqlServer {
base_server: BaseTcpServer,
query_handler: SqlQueryHandlerRef,
tls: TlsOption,
user_provider: Option<UserProviderRef>,
}
impl MysqlServer {
@@ -48,11 +50,13 @@ impl MysqlServer {
query_handler: SqlQueryHandlerRef,
io_runtime: Arc<Runtime>,
tls: TlsOption,
user_provider: Option<UserProviderRef>,
) -> Box<dyn Server> {
Box::new(MysqlServer {
base_server: BaseTcpServer::create_server("MySQL", io_runtime),
query_handler,
tls,
user_provider,
})
}
@@ -63,19 +67,29 @@ impl MysqlServer {
tls_conf: Option<Arc<ServerConfig>>,
) -> impl Future<Output = ()> {
let query_handler = self.query_handler.clone();
let user_provider = self.user_provider.clone();
let force_tls = self.tls.should_force_tls();
stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let query_handler = query_handler.clone();
let user_provider = user_provider.clone();
let tls_conf = tls_conf.clone();
async move {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
if let Err(error) =
Self::handle(io_stream, io_runtime, query_handler, tls_conf, force_tls)
.await
if let Err(error) = Self::handle(
io_stream,
io_runtime,
query_handler,
tls_conf,
force_tls,
user_provider,
)
.await
{
error!(error; "Unexpected error when handling TcpStream");
};
@@ -91,11 +105,12 @@ impl MysqlServer {
query_handler: SqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
) -> Result<()> {
info!("MySQL connection coming from: {}", stream.peer_addr()?);
io_runtime .spawn(async move {
// TODO(LFC): Use `output_stream` to write large MySQL ResultSet to client.
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls).await {
if let Err(e) = Self::do_handle(stream, query_handler, tls_conf, force_tls, user_provider).await {
// TODO(LFC): Write this error to client as well, in MySQL text protocol.
// Looks like we have to expose opensrv-mysql's `PacketWriter`?
error!(e; "Internal error occurred during query exec, server actively close the channel to let client try next time.")
@@ -110,8 +125,13 @@ impl MysqlServer {
query_handler: SqlQueryHandlerRef,
tls_conf: Option<Arc<ServerConfig>>,
force_tls: bool,
user_provider: Option<UserProviderRef>,
) -> Result<()> {
let mut shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
let mut shim = MysqlInstanceShim::create(
query_handler,
stream.peer_addr()?.to_string(),
user_provider,
);
let (mut r, w) = stream.into_split();
let mut w = BufWriter::with_capacity(DEFAULT_RESULT_SET_WRITE_BUFFER_SIZE, w);
let ops = IntermediaryOptions::default();

View File

@@ -23,11 +23,56 @@ use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use snafu::ResultExt;
struct PgPwdVerifier;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error;
use crate::error::Result;
struct PgPwdVerifier {
user_provider: Option<UserProviderRef>,
}
#[allow(dead_code)]
struct LoginInfo {
user: Option<String>,
database: Option<String>,
host: String,
}
impl LoginInfo {
pub fn from_client_info<C>(client: &C) -> LoginInfo
where
C: ClientInfo,
{
LoginInfo {
user: client.metadata().get(super::METADATA_USER).map(Into::into),
database: client
.metadata()
.get(super::METADATA_DATABASE)
.map(Into::into),
host: client.socket_addr().ip().to_string(),
}
}
}
impl PgPwdVerifier {
async fn verify_pwd(&self, _pwd: &str, _meta: HashMap<String, String>) -> PgWireResult<bool> {
async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result<bool> {
if let Some(user_provider) = &self.user_provider {
let user_name = match login.user {
Some(name) => name,
None => return Ok(false),
};
// TODO(fys): pass user_info to context
let _user_info = user_provider
.auth(
Identity::UserId(&user_name, None),
Password::PlainText(password),
)
.await
.context(error::AuthSnafu)?;
}
Ok(true)
}
}
@@ -62,16 +107,14 @@ impl ServerParameterProvider for GreptimeDBStartupParameters {
pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
with_pwd: bool,
force_tls: bool,
}
impl PgAuthStartupHandler {
pub fn new(with_pwd: bool, force_tls: bool) -> Self {
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier,
verifier: PgPwdVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
with_pwd,
force_tls,
}
}
@@ -106,7 +149,7 @@ impl StartupHandler for PgAuthStartupHandler {
return Ok(());
}
auth::save_startup_parameters_to_metadata(client, startup);
if self.with_pwd {
if self.verifier.user_provider.is_some() {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
.send(PgWireBackendMessage::Authentication(
@@ -118,8 +161,8 @@ impl StartupHandler for PgAuthStartupHandler {
}
}
PgWireFrontendMessage::Password(ref pwd) => {
let meta = client.metadata().clone();
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), meta).await {
let login_info = LoginInfo::from_client_info(client);
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await {
auth::finish_authentication(client, &self.param_provider).await
} else {
let error_info = ErrorInfo::new(

View File

@@ -42,14 +42,12 @@ impl PostgresServerHandler {
}
}
const CLIENT_METADATA_DATABASE: &str = "database";
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
where
C: ClientInfo,
{
let query_context = QueryContext::new();
if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) {
if let Some(current_schema) = client.metadata().get(super::METADATA_DATABASE) {
query_context.set_current_schema(current_schema);
}

View File

@@ -16,4 +16,7 @@ mod auth_handler;
mod handler;
mod server;
pub(crate) const METADATA_USER: &str = "user";
pub(crate) const METADATA_DATABASE: &str = "database";
pub use server::PostgresServer;

View File

@@ -19,11 +19,13 @@ use std::sync::Arc;
use async_trait::async_trait;
use common_runtime::Runtime;
use common_telemetry::logging::error;
use common_telemetry::{debug, warn};
use futures::StreamExt;
use pgwire::tokio::process_socket;
use tokio;
use tokio_rustls::TlsAcceptor;
use crate::auth::UserProviderRef;
use crate::error::Result;
use crate::postgres::auth_handler::PgAuthStartupHandler;
use crate::postgres::handler::PostgresServerHandler;
@@ -42,13 +44,15 @@ impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: SqlQueryHandlerRef,
check_pwd: bool,
tls: TlsOption,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
let startup_handler =
Arc::new(PgAuthStartupHandler::new(check_pwd, tls.should_force_tls()));
let startup_handler = Arc::new(PgAuthStartupHandler::new(
user_provider,
tls.should_force_tls(),
));
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
auth_handler: startup_handler,
@@ -76,6 +80,11 @@ impl PostgresServer {
match tcp_stream {
Err(error) => error!("Broken pipe: {}", error), // IoError doesn't impl ErrorExt.
Ok(io_stream) => {
match io_stream.peer_addr() {
Ok(addr) => debug!("PostgreSQL client coming from {}", addr),
Err(e) => warn!("Failed to get PostgreSQL client addr, err: {}", e),
}
io_runtime.spawn(process_socket(
io_stream,
tls_acceptor.clone(),
@@ -99,6 +108,7 @@ impl Server for PostgresServer {
async fn start(&self, listening: SocketAddr) -> Result<SocketAddr> {
let (stream, addr) = self.base_server.bind(listening).await?;
debug!("Starting PostgreSQL with TLS option: {:?}", self.tls);
let tls_acceptor = self
.tls
.setup()?

View File

@@ -18,22 +18,32 @@ use std::io::{BufReader, Error, ErrorKind};
use rustls::{Certificate, PrivateKey, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use serde::{Deserialize, Serialize};
use strum::EnumString;
/// TlsMode is used for Mysql and Postgres server start up.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq, EnumString)]
#[serde(rename_all = "snake_case")]
pub enum TlsMode {
#[default]
#[strum(to_string = "disable")]
Disable,
#[strum(to_string = "prefer")]
Prefer,
#[strum(to_string = "require")]
Require,
// TODO(SSebo): Implement the following 2 TSL mode described in
// ["34.19.3. Protection Provided in Different Modes"](https://www.postgresql.org/docs/current/libpq-ssl.html)
#[strum(to_string = "verify-ca")]
VerifyCa,
#[strum(to_string = "verify-full")]
VerifyFull,
}
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct TlsOption {
pub mode: TlsMode,
@@ -44,6 +54,24 @@ pub struct TlsOption {
}
impl TlsOption {
pub fn new(mode: Option<TlsMode>, cert_path: Option<String>, key_path: Option<String>) -> Self {
let mut tls_option = TlsOption::default();
if let Some(mode) = mode {
tls_option.mode = mode
};
if let Some(cert_path) = cert_path {
tls_option.cert_path = cert_path
};
if let Some(key_path) = key_path {
tls_option.key_path = key_path
};
tls_option
}
pub fn setup(&self) -> Result<Option<ServerConfig>, Error> {
if let TlsMode::Disable = self.mode {
return Ok(None);
@@ -76,6 +104,31 @@ impl TlsOption {
#[cfg(test)]
mod tests {
use super::*;
use crate::tls::TlsMode::Disable;
#[test]
fn test_new_tls_option() {
assert_eq!(TlsOption::default(), TlsOption::new(None, None, None));
assert_eq!(
TlsOption {
mode: Disable,
..Default::default()
},
TlsOption::new(Some(Disable), None, None)
);
assert_eq!(
TlsOption {
mode: Disable,
cert_path: "/path/to/cert_path".to_string(),
key_path: "/path/to/key_path".to_string(),
},
TlsOption::new(
Some(Disable),
Some("/path/to/cert_path".to_string()),
Some("/path/to/key_path".to_string())
)
);
}
#[test]
fn test_tls_option_disable() {

View File

@@ -18,6 +18,7 @@ use axum::body::Body;
use axum::extract::{Json, Query, RawBody, State};
use common_telemetry::metric;
use metrics::counter;
use servers::auth::UserInfo;
use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput};
use table::test_util::MemTable;
@@ -32,6 +33,7 @@ async fn test_sql_not_provided() {
script_handler: None,
}),
Query(http_handler::SqlQuery::default()),
axum::Extension(UserInfo::default()),
)
.await;
assert!(!json.success());
@@ -55,6 +57,7 @@ async fn test_sql_output_rows() {
script_handler: None,
}),
query,
axum::Extension(UserInfo::default()),
)
.await;
assert!(json.success(), "{:?}", json);

View File

@@ -16,9 +16,10 @@ use std::sync::Arc;
use api::v1::InsertExpr;
use async_trait::async_trait;
use axum::Router;
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::http::{HttpOptions, HttpServer};
use servers::influxdb::InfluxdbRequest;
@@ -53,6 +54,9 @@ impl SqlQueryHandler for DummyInstance {
fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router {
let instance = Arc::new(DummyInstance { tx });
let mut server = HttpServer::new(instance.clone(), HttpOptions::default());
let up = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
server.set_user_provider(Arc::new(up));
server.set_influxdb_handler(instance);
server.make_app()
}
@@ -68,6 +72,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -76,6 +84,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write?db=influxdb")
.body("monitor,host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 204);
@@ -85,6 +97,10 @@ async fn test_influxdb_write() {
let result = client
.post("/v1/influxdb/write")
.body("monitor, host=host1 cpu=1.2 1664370459457010101")
.header(
http::header::AUTHORIZATION,
"basic Z3JlcHRpbWU6Z3JlcHRpbWU=",
)
.send()
.await;
assert_eq!(result.status(), 400);

View File

@@ -23,6 +23,7 @@ use mysql_async::prelude::*;
use mysql_async::SslOpts;
use rand::rngs::StdRng;
use rand::Rng;
use servers::auth::user_provider::StaticUserProvider;
use servers::error::Result;
use servers::mysql::server::MysqlServer;
use servers::server::Server;
@@ -41,7 +42,15 @@ fn create_mysql_server(table: MemTable, tls: Arc<TlsOption>) -> Result<Box<dyn S
.build()
.unwrap(),
);
Ok(MysqlServer::create_server(query_handler, io_runtime, tls))
let provider = StaticUserProvider::try_from("cmd:greptime=greptime").unwrap();
Ok(MysqlServer::create_server(
query_handler,
io_runtime,
tls,
Some(Arc::new(provider)),
))
}
#[tokio::test]
@@ -79,10 +88,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
let server_port = server_addr.port();
let mut join_handles = vec![];
for index in 0..2 {
for _ in 0..2 {
join_handles.push(tokio::spawn(async move {
for _ in 0..1000 {
match create_connection(server_port, index == 1, false).await {
match create_connection(server_port, false).await {
Ok(mut connection) => {
let result: u32 = connection
.query_first("SELECT uint32s FROM numbers LIMIT 1")
@@ -119,7 +128,7 @@ async fn test_query_all_datatypes() -> Result<()> {
let server_tls = Arc::new(TlsOption::default());
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -132,7 +141,7 @@ async fn test_server_prefer_secure_client_plain() -> Result<()> {
});
let client_tls = false;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -145,7 +154,7 @@ async fn test_server_prefer_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -158,7 +167,7 @@ async fn test_server_require_secure_client_secure() -> Result<()> {
});
let client_tls = true;
do_test_query_all_datatypes(server_tls, client_tls, false).await?;
do_test_query_all_datatypes(server_tls, client_tls).await?;
Ok(())
}
@@ -188,16 +197,12 @@ async fn test_server_required_secure_client_plain() -> Result<()> {
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let r = create_connection(server_addr.port(), client_tls, false).await;
let r = create_connection(server_addr.port(), client_tls).await;
assert!(r.is_err());
Ok(())
}
async fn do_test_query_all_datatypes(
server_tls: Arc<TlsOption>,
with_pwd: bool,
client_tls: bool,
) -> Result<()> {
async fn do_test_query_all_datatypes(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let TestingData {
column_schemas,
@@ -214,7 +219,7 @@ async fn do_test_query_all_datatypes(
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = mysql_server.start(listening).await.unwrap();
let mut connection = create_connection(server_addr.port(), client_tls, with_pwd)
let mut connection = create_connection(server_addr.port(), client_tls)
.await
.unwrap();
@@ -252,13 +257,11 @@ async fn test_query_concurrently() -> Result<()> {
let threads = 4;
let expect_executed_queries_per_worker = 1000;
let mut join_handles = vec![];
for index in 0..threads {
for _ in 0..threads {
join_handles.push(tokio::spawn(async move {
let mut rand: StdRng = rand::SeedableRng::from_entropy();
let mut connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
let mut connection = create_connection(server_port, false).await.unwrap();
for _ in 0..expect_executed_queries_per_worker {
let expected: u32 = rand.gen_range(0..100);
let result: u32 = connection
@@ -273,9 +276,7 @@ async fn test_query_concurrently() -> Result<()> {
let should_recreate_conn = expected == 1;
if should_recreate_conn {
connection = create_connection(server_port, index % 2 == 0, false)
.await
.unwrap();
connection = create_connection(server_port, false).await.unwrap();
}
}
expect_executed_queries_per_worker
@@ -289,16 +290,14 @@ async fn test_query_concurrently() -> Result<()> {
Ok(())
}
async fn create_connection(
port: u16,
with_pwd: bool,
ssl: bool,
) -> mysql_async::Result<mysql_async::Conn> {
async fn create_connection(port: u16, ssl: bool) -> mysql_async::Result<mysql_async::Conn> {
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname("127.0.0.1")
.tcp_port(port)
.prefer_socket(false)
.wait_timeout(Some(1000));
.wait_timeout(Some(1000))
.user(Some("greptime".to_string()))
.pass(Some("greptime".to_string()));
if ssl {
let ssl_opts = SslOpts::default()
@@ -307,9 +306,5 @@ async fn create_connection(
opts = opts.ssl_opts(ssl_opts)
}
if with_pwd {
opts = opts.pass(Some("default_pwd".to_string()));
}
mysql_async::Conn::new(opts).await
}

View File

@@ -22,6 +22,8 @@ use rand::rngs::StdRng;
use rand::Rng;
use rustls::client::{ServerCertVerified, ServerCertVerifier};
use rustls::{Certificate, Error, ServerName};
use servers::auth::user_provider::StaticUserProvider;
use servers::auth::UserProviderRef;
use servers::error::Result;
use servers::postgres::PostgresServer;
use servers::server::Server;
@@ -44,11 +46,19 @@ fn create_postgres_server(
.build()
.unwrap(),
);
let user_provider: Option<UserProviderRef> = if check_pwd {
Some(Arc::new(
StaticUserProvider::try_from("cmd:test_user=test_pwd").unwrap(),
))
} else {
None
};
Ok(Box::new(PostgresServer::new(
query_handler,
check_pwd,
tls,
io_runtime,
user_provider,
)))
}

View File

@@ -30,6 +30,10 @@ impl Default for QueryContext {
}
impl QueryContext {
pub fn arc() -> QueryContextRef {
Arc::new(QueryContext::new())
}
pub fn new() -> Self {
Self {
current_schema: ArcSwapOption::new(None),

View File

@@ -18,7 +18,8 @@ use itertools::Itertools;
use mito::engine;
use once_cell::sync::Lazy;
use snafu::{ensure, OptionExt, ResultExt};
use sqlparser::ast::Value;
use sqlparser::ast::ColumnOption::NotNull;
use sqlparser::ast::{ColumnOptionDef, DataType, Value};
use sqlparser::dialect::keywords::Keyword;
use sqlparser::parser::IsOptional::Mandatory;
use sqlparser::tokenizer::{Token, Word};
@@ -220,11 +221,7 @@ impl<'a> ParserContext<'a> {
if let Some(constraint) = self.parse_optional_table_constraint()? {
constraints.push(constraint);
} else if let Token::Word(_) = self.parser.peek_token() {
columns.push(
self.parser
.parse_column_def()
.context(SyntaxSnafu { sql: self.sql })?,
);
self.parse_column(&mut columns, &mut constraints)?;
} else {
return self.expected(
"column name or constraint definition",
@@ -246,6 +243,75 @@ impl<'a> ParserContext<'a> {
Ok((columns, constraints))
}
fn parse_column(
&mut self,
columns: &mut Vec<ColumnDef>,
constraints: &mut Vec<TableConstraint>,
) -> Result<()> {
let column = self
.parser
.parse_column_def()
.context(SyntaxSnafu { sql: self.sql })?;
if !matches!(column.data_type, DataType::Timestamp(_))
|| matches!(self.parser.peek_token(), Token::Comma)
{
columns.push(column);
return Ok(());
}
// for supporting `ts TIMESTAMP TIME INDEX,` syntax.
self.parse_time_index(column, columns, constraints)
}
fn parse_time_index(
&mut self,
mut column: ColumnDef,
columns: &mut Vec<ColumnDef>,
constraints: &mut Vec<TableConstraint>,
) -> Result<()> {
self.parser
.expect_keywords(&[Keyword::TIME, Keyword::INDEX])
.context(error::UnexpectedSnafu {
sql: self.sql,
expected: "TIME INDEX",
actual: self.peek_token_as_string(),
})?;
let constraint = TableConstraint::Unique {
name: Some(Ident {
value: TIME_INDEX.to_owned(),
quote_style: None,
}),
columns: vec![Ident {
value: column.name.value.clone(),
quote_style: None,
}],
is_primary: false,
};
column.options = vec![ColumnOptionDef {
name: None,
option: NotNull,
}];
columns.push(column);
constraints.push(constraint);
if let Token::Comma = self.parser.peek_token() {
return Ok(());
}
self.parser
.expect_keywords(&[Keyword::NOT, Keyword::NULL])
.context(error::UnexpectedSnafu {
sql: self.sql,
expected: "NOT NULL",
actual: self.peek_token_as_string(),
})?;
Ok(())
}
// Copy from sqlparser by boyan
fn parse_optional_table_constraint(&mut self) -> Result<Option<TableConstraint>> {
let name = if self.parser.parse_keyword(Keyword::CONSTRAINT) {
@@ -705,6 +771,160 @@ ENGINE=mito";
}
}
#[test]
fn test_parse_create_table_with_timestamp_index() {
let sql1 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP TIME INDEX,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
PRIMARY KEY (host),
)
ENGINE=mito";
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
if let Statement::CreateTable(c) = &result1[0] {
assert_eq!(c.constraints.len(), 2);
let tc = c.constraints[0].clone();
match tc {
TableConstraint::Unique {
name,
columns,
is_primary,
} => {
assert_eq!(name.unwrap().to_string(), "__time_index");
assert_eq!(columns.len(), 1);
assert_eq!(&columns[0].value, "ts");
assert!(!is_primary);
}
_ => panic!("should be time index constraint"),
};
} else {
panic!("should be create_table statement");
}
// `TIME INDEX` should be in front of `PRIMARY KEY`
// in order to equal the `TIMESTAMP TIME INDEX` constraint options vector
let sql2 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP NOT NULL,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap();
assert_eq!(result1, result2);
// TIMESTAMP can be NULL which is not equal to above
let sql3 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {}).unwrap();
assert_ne!(result1, result3);
}
#[test]
fn test_parse_create_table_with_timestamp_index_not_null() {
let sql = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP TIME INDEX,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
assert_eq!(result.len(), 1);
if let Statement::CreateTable(c) = &result[0] {
let ts = c.columns[2].clone();
assert_eq!(ts.name.to_string(), "ts");
assert_eq!(ts.options[0].option, NotNull);
} else {
panic!("should be create table statement");
}
let sql1 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP NOT NULL TIME INDEX,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result1 = ParserContext::create_with_dialect(sql1, &GenericDialect {}).unwrap();
assert_eq!(result, result1);
let sql2 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP TIME INDEX NOT NULL,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result2 = ParserContext::create_with_dialect(sql2, &GenericDialect {}).unwrap();
assert_eq!(result, result2);
let sql3 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP TIME INDEX NULL NOT,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result3 = ParserContext::create_with_dialect(sql3, &GenericDialect {});
assert!(result3.is_err());
let sql4 = r"
CREATE TABLE monitor (
host_id INT,
idc STRING,
ts TIMESTAMP TIME INDEX NOT NULL NULL,
cpu DOUBLE DEFAULT 0,
memory DOUBLE,
TIME INDEX (ts),
PRIMARY KEY (host),
)
ENGINE=mito";
let result4 = ParserContext::create_with_dialect(sql4, &GenericDialect {});
assert!(result4.is_err());
}
#[test]
fn test_parse_partitions_with_error_syntax() {
let sql = r"

View File

@@ -209,6 +209,17 @@ pub enum Error {
source: BoxedError,
},
#[snafu(display(
"Failed to mark WAL as stable, region id: {}, source: {}",
region_id,
source
))]
MarkWalStable {
region_id: u64,
#[snafu(backtrace)]
source: BoxedError,
},
#[snafu(display("WAL data corrupted, region_id: {}, message: {}", region_id, message))]
WalDataCorrupted {
region_id: RegionId,
@@ -415,6 +426,7 @@ impl ErrorExt for Error {
PushBatch { source, .. } => source.status_code(),
AddDefault { source, .. } => source.status_code(),
ConvertChunk { source, .. } => source.status_code(),
MarkWalStable { source, .. } => source.status_code(),
}
}

View File

@@ -223,7 +223,8 @@ impl<S: LogStore> FlushJob<S> {
edit,
self.max_memtable_id,
)
.await
.await?;
self.wal.obsolete(self.flush_sequence).await
}
/// Generates random SST file name in format: `^[a-f\d]{8}(-[a-f\d]{4}){3}-[a-f\d]{12}.parquet$`
@@ -237,9 +238,7 @@ impl<S: LogStore> Job for FlushJob<S> {
// TODO(yingwen): [flush] Support in-job parallelism (Flush memtables concurrently)
async fn run(&mut self, ctx: &Context) -> Result<()> {
let file_metas = self.write_memtables_to_layer(ctx).await?;
self.write_manifest_and_apply(&file_metas).await?;
Ok(())
}
}

View File

@@ -225,6 +225,7 @@ impl<S: LogStore> RegionImpl<S> {
}
let wal = Wal::new(metadata.id(), store_config.log_store);
wal.obsolete(flushed_sequence).await?;
let shared = Arc::new(SharedData {
id: metadata.id(),
name,

View File

@@ -24,7 +24,7 @@ use store_api::logstore::{AppendResponse, LogStore};
use store_api::storage::{RegionId, SequenceNumber};
use crate::codec::{Decoder, Encoder};
use crate::error::{self, Error, Result};
use crate::error::{self, Error, MarkWalStableSnafu, Result};
use crate::proto::wal::{self, PayloadType, WalHeader};
use crate::write_batch::codec::{
WriteBatchArrowDecoder, WriteBatchArrowEncoder, WriteBatchProtobufDecoder,
@@ -64,6 +64,16 @@ impl<S: LogStore> Wal<S> {
}
}
pub async fn obsolete(&self, seq: SequenceNumber) -> Result<()> {
self.store
.obsolete(self.namespace.clone(), seq)
.await
.map_err(BoxedError::new)
.context(MarkWalStableSnafu {
region_id: self.region_id,
})
}
#[inline]
pub fn region_id(&self) -> RegionId {
self.region_id

View File

@@ -32,6 +32,10 @@ pub trait LogStore: Send + Sync + 'static + std::fmt::Debug {
type Entry: Entry;
type AppendResponse: AppendResponse;
async fn start(&self) -> Result<(), Self::Error>;
async fn stop(&self) -> Result<(), Self::Error>;
/// Append an `Entry` to WAL with given namespace
async fn append(&self, mut e: Self::Entry) -> Result<Self::AppendResponse, Self::Error>;
@@ -65,6 +69,11 @@ pub trait LogStore: Send + Sync + 'static + std::fmt::Debug {
/// Create a namespace of the associate Namespace type
// TODO(sunng87): confusion with `create_namespace`
fn namespace(&self, id: namespace::Id) -> Self::Namespace;
/// Mark all entry ids `<=id` of given `namespace` as obsolete so that logstore can safely delete
/// the log files if all entries inside are obsolete. This method may not delete log
/// files immediately.
async fn obsolete(&self, namespace: Self::Namespace, id: Id) -> Result<(), Self::Error>;
}
pub trait AppendResponse: Send + Sync {

View File

@@ -55,7 +55,7 @@ mod tests {
#[snafu(visibility(pub))]
pub struct Error {}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct Namespace {}
impl crate::logstore::Namespace for Namespace {

View File

@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::hash::Hash;
pub type Id = u64;
pub trait Namespace: Send + Sync + Clone + std::fmt::Debug {
pub trait Namespace: Send + Sync + Clone + std::fmt::Debug + Hash + Eq {
fn id(&self) -> Id;
}

Some files were not shown because too many files have changed in this diff Show More