From 4098c57446a96d4ecbf42959a1fbcab31272a9dc Mon Sep 17 00:00:00 2001 From: LFC Date: Fri, 12 Aug 2022 11:41:45 +0800 Subject: [PATCH] feat: MySQL protocol server (#158) * MySQL protocol server * fix: Rustfmt check * fix: resolve PR comments Co-authored-by: luofucong --- Cargo.lock | 422 +++++++++++++++++- Cargo.toml | 1 + src/common/error/src/status_code.rs | 5 + src/common/servers/Cargo.toml | 27 ++ src/common/servers/src/error.rs | 28 ++ src/common/servers/src/lib.rs | 3 + src/common/servers/src/mysql/error.rs | 60 +++ src/common/servers/src/mysql/mod.rs | 4 + .../servers/src/mysql/mysql_instance.rs | 80 ++++ src/common/servers/src/mysql/mysql_server.rs | 137 ++++++ src/common/servers/src/mysql/mysql_writer.rs | 237 ++++++++++ src/common/servers/src/server.rs | 11 + src/common/servers/tests/mod.rs | 1 + src/common/servers/tests/mysql/mod.rs | 269 +++++++++++ .../servers/tests/mysql/mysql_server_test.rs | 221 +++++++++ .../servers/tests/mysql/mysql_writer_test.rs | 34 ++ src/query/tests/my_sum_udaf_example.rs | 10 +- src/query/tests/query_engine_test.rs | 20 +- test-util/src/memtable.rs | 34 +- 19 files changed, 1565 insertions(+), 39 deletions(-) create mode 100644 src/common/servers/Cargo.toml create mode 100644 src/common/servers/src/error.rs create mode 100644 src/common/servers/src/lib.rs create mode 100644 src/common/servers/src/mysql/error.rs create mode 100644 src/common/servers/src/mysql/mod.rs create mode 100644 src/common/servers/src/mysql/mysql_instance.rs create mode 100644 src/common/servers/src/mysql/mysql_server.rs create mode 100644 src/common/servers/src/mysql/mysql_writer.rs create mode 100644 src/common/servers/src/server.rs create mode 100644 src/common/servers/tests/mod.rs create mode 100644 src/common/servers/tests/mysql/mod.rs create mode 100644 src/common/servers/tests/mysql/mysql_server_test.rs create mode 100644 src/common/servers/tests/mysql/mysql_writer_test.rs diff --git a/Cargo.lock b/Cargo.lock index f6f851ccce..4795d02bdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -352,6 +352,36 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" +[[package]] +name = "bigdecimal" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aaf33151a6429fe9211d1b276eafdf70cdff28b071e76c0b0e1503221ea3744" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "bindgen" +version = "0.59.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "lazy_static", + "lazycell", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -536,6 +566,15 @@ dependencies = [ "jobserver", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -578,6 +617,17 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "clang-sys" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a050e2153c5be08febd6734e29298e844fdb0fa21aeddd63b4eb7baa106c69b" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "2.34.0" @@ -639,6 +689,15 @@ dependencies = [ "tonic 0.8.0", ] +[[package]] +name = "cmake" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8ad8cef104ac57b68b89df3208164d228503abbdce70f6880ffa3d970e7443a" +dependencies = [ + "cc", +] + [[package]] name = "cmd" version = "0.1.0" @@ -742,6 +801,31 @@ dependencies = [ "tokio-test", ] +[[package]] +name = "common-servers" +version = "0.1.0" +dependencies = [ + "async-trait", + "catalog", + "common-base", + "common-error", + "common-recordbatch", + "common-runtime", + "common-telemetry", + "datatypes", + "futures", + "metrics 0.20.1", + "mysql_async", + "num_cpus", + "opensrv-mysql", + "query", + "rand 0.8.5", + "snafu", + "test-util", + "tokio", + "tokio-stream", +] + [[package]] name = "common-telemetry" version = "0.1.0" @@ -1355,6 +1439,7 @@ dependencies = [ "cfg-if", "crc32fast", "libc", + "libz-sys", "miniz_oxide", ] @@ -1389,6 +1474,70 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "frunk" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cd67cf7d54b7e72d0ea76f3985c3747d74aee43e0218ad993b7903ba7a5395e" +dependencies = [ + "frunk_core", + "frunk_derives", + "frunk_proc_macros", +] + +[[package]] +name = "frunk_core" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1246cf43ec80bf8b2505b5c360b8fb999c97dabd17dbb604d85558d5cbc25482" + +[[package]] +name = "frunk_derives" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3dbc4f084ec5a3f031d24ccedeb87ab2c3189a2f33b8d070889073837d5ea09e" +dependencies = [ + "frunk_proc_macro_helpers", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macro_helpers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99f11257f106c6753f5ffcb8e601fb39c390a088017aaa55b70c526bff15f63e" +dependencies = [ + "frunk_core", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "frunk_proc_macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a078bd8459eccbb85e0b007b8f756585762a72a9efc53f359b371c3b6351dbcc" +dependencies = [ + "frunk_core", + "frunk_proc_macros_impl", + "proc-macro-hack", +] + +[[package]] +name = "frunk_proc_macros_impl" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ffba99f0fa4f57e42f57388fbb9a0ca863bc2b4261f3c5570fed579d5df6c32" +dependencies = [ + "frunk_core", + "frunk_proc_macro_helpers", + "proc-macro-hack", + "quote", + "syn", +] + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -1542,6 +1691,12 @@ version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4" +[[package]] +name = "glob" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b919933a397b79c37e33b77bb2aa3dc8eb6e165ad809e58ff75bc7db2e34574" + [[package]] name = "h2" version = "0.3.13" @@ -1557,7 +1712,7 @@ dependencies = [ "indexmap", "slab", "tokio", - "tokio-util 0.7.1", + "tokio-util 0.7.3", "tracing", ] @@ -1887,10 +2042,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] -name = "lexical-core" -version = "0.8.3" +name = "lazycell" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92912c4af2e7d9075be3e5e3122c4d7263855fa6cce34fbece4dd08e5884624d" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + +[[package]] +name = "lexical" +version = "6.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7aefb36fd43fef7003334742cbf77b243fcd36418a1d1bdd480d613a67968f6" +dependencies = [ + "lexical-core", +] + +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -1901,9 +2071,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f518eed87c3be6debe6d26b855c97358d8a11bf05acec137e5f53080f5ad2dd8" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -1912,9 +2082,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.3" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc852ec67c6538bbb2b9911116a385b24510e879a69ab516e6a151b15a79168" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" dependencies = [ "lexical-util", "static_assertions", @@ -1922,18 +2092,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c72a9d52c5c4e62fa2cdc2cb6c694a39ae1382d9c2a17a466f18e272a0930eb1" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.4" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a89ec1d062e481210c309b672f73a0567b7855f21e7d2fae636df44d12e97f9" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" dependencies = [ "lexical-util", "lexical-write-integer", @@ -1942,9 +2112,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "094060bd2a7c2ff3a16d5304a6ae82727cb3cc9d1c70f813cc73f744c319337e" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" dependencies = [ "lexical-util", "static_assertions", @@ -1956,6 +2126,16 @@ version = "0.2.125" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5916d2ae698f6de9bfb891ad7a8d65c09d232dc58cc4ac433c7da3b2fd84bc2b" +[[package]] +name = "libloading" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efbc0f03f9a775e9f6aed295c6a1ba2253c5757a9e03d55c6caa46a681abcddd" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "libnghttp2-sys" version = "0.1.7+1.45.0" @@ -2026,6 +2206,15 @@ dependencies = [ name = "logical-plans" version = "0.1.0" +[[package]] +name = "lru" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" +dependencies = [ + "hashbrown 0.12.1", +] + [[package]] name = "lz4" version = "1.23.3" @@ -2113,7 +2302,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e52eb6380b6d2a10eb3434aec0885374490f5b82c8aaf5cd487a183c98be834" dependencies = [ "ahash", - "metrics-macros", + "metrics-macros 0.5.1", ] [[package]] @@ -2123,7 +2312,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "142c53885123b68d94108295a09d4afe1a1388ed95b54d5dacd9a454753030f2" dependencies = [ "ahash", - "metrics-macros", + "metrics-macros 0.5.1", +] + +[[package]] +name = "metrics" +version = "0.20.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b9b8653cec6897f73b519a43fba5ee3d50f62fe9af80b428accdcc093b4a849" +dependencies = [ + "ahash", + "metrics-macros 0.6.0", + "portable-atomic", ] [[package]] @@ -2151,6 +2351,17 @@ dependencies = [ "syn", ] +[[package]] +name = "metrics-macros" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "731f8ecebd9f3a4aa847dfe75455e4757a45da40a7793d2f0b1f9b6ed18b23f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "metrics-util" version = "0.12.1" @@ -2276,6 +2487,76 @@ dependencies = [ "syn", ] +[[package]] +name = "mysql_async" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "456207bb9636a0fdade67a64cea7bdebe6730c3c16ee5e34f2c481838ee5a39e" +dependencies = [ + "bytes", + "crossbeam", + "flate2", + "futures-core", + "futures-sink", + "futures-util", + "lazy_static", + "lru", + "mio", + "mysql_common", + "native-tls", + "once_cell", + "pem", + "percent-encoding", + "pin-project", + "serde", + "serde_json", + "socket2", + "thiserror", + "tokio", + "tokio-native-tls", + "tokio-util 0.7.3", + "twox-hash", + "url", +] + +[[package]] +name = "mysql_common" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "522f2f30f72de409fc04f88df25a031f98cfc5c398a94e0b892cabb33a1464cb" +dependencies = [ + "base64", + "bigdecimal", + "bindgen", + "bitflags", + "bitvec", + "byteorder", + "bytes", + "cc", + "chrono", + "cmake", + "crc32fast", + "flate2", + "frunk", + "lazy_static", + "lexical", + "num-bigint", + "num-traits", + "rand 0.8.5", + "regex", + "rust_decimal", + "saturating", + "serde", + "serde_json", + "sha-1", + "sha2", + "smallvec", + "subprocess", + "thiserror", + "time 0.3.9", + "uuid", +] + [[package]] name = "native-tls" version = "0.2.10" @@ -2463,6 +2744,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "opensrv-mysql" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bcb5fc2fda7e5e5f8478cd637285bbdd6196a9601e32293d0897e469a7dd020" +dependencies = [ + "async-trait", + "byteorder", + "chrono", + "mysql_common", + "nom", + "tokio", +] + [[package]] name = "openssl" version = "0.10.40" @@ -2700,6 +2995,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c520e05135d6e763148b6426a837e239041653ba7becd2e538c076c738025fc" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem" version = "1.0.2" @@ -2852,6 +3153,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "portable-atomic" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ef3e12daa83946e79a4e22dff6ff8154138bfb34bef1769ec80c92bc3aa88e3" + [[package]] name = "ppv-lite86" version = "0.2.16" @@ -2892,6 +3199,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + [[package]] name = "proc-macro2" version = "1.0.38" @@ -3321,12 +3634,29 @@ dependencies = [ "ordered-multimap", ] +[[package]] +name = "rust_decimal" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34a3bb58e85333f1ab191bf979104b586ebd77475bc6681882825f4532dfe87c" +dependencies = [ + "arrayvec", + "num-traits", + "serde", +] + [[package]] name = "rustc-demangle" version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -3357,6 +3687,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "saturating" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ece8e78b2f38ec51c51f5d475df0a7187ba5111b2a28bdc761ee05b075d40a71" + [[package]] name = "schannel" version = "0.1.19" @@ -3456,6 +3792,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha-1" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.2" @@ -3476,6 +3823,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -3712,6 +4065,16 @@ dependencies = [ "syn", ] +[[package]] +name = "subprocess" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c2e86926081dda636c546d8c5e641661049d7562a68f5488be4a1f7f66f6086" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "subtle" version = "2.4.1" @@ -4022,9 +4385,9 @@ dependencies = [ [[package]] name = "tokio-stream" -version = "0.1.8" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50145484efff8818b5ccd256697f36863f587da82cf8b409c53adf1e840798e3" +checksum = "df54d54117d6fdc4e4fea40fe1e4e566b3505700e148a6827e59b34b0d2600d9" dependencies = [ "futures-core", "pin-project-lite", @@ -4060,9 +4423,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.1" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0edfdeb067411dba2044da6d1cb2df793dd35add7888d73c16e3381ded401764" +checksum = "cc463cd8deddc3770d20f9852143d50bf6094e640b485cb2e189a2099085ff45" dependencies = [ "bytes", "futures-core", @@ -4096,7 +4459,7 @@ dependencies = [ "prost-derive 0.10.1", "tokio", "tokio-stream", - "tokio-util 0.7.1", + "tokio-util 0.7.3", "tower", "tower-layer", "tower-service", @@ -4128,7 +4491,7 @@ dependencies = [ "prost-derive 0.11.0", "tokio", "tokio-stream", - "tokio-util 0.7.1", + "tokio-util 0.7.3", "tower", "tower-layer", "tower-service", @@ -4164,7 +4527,7 @@ dependencies = [ "rand 0.8.5", "slab", "tokio", - "tokio-util 0.7.1", + "tokio-util 0.7.3", "tower-layer", "tower-service", "tracing", @@ -4192,7 +4555,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "tokio", - "tokio-util 0.7.1", + "tokio-util 0.7.3", "tower", "tower-layer", "tower-service", @@ -4334,6 +4697,17 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "rand 0.8.5", + "static_assertions", +] + [[package]] name = "typenum" version = "1.15.0" diff --git a/Cargo.toml b/Cargo.toml index 132b677d7d..743bfde0f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "src/common/query", "src/common/recordbatch", "src/common/runtime", + "src/common/servers", "src/common/telemetry", "src/common/time", "src/cmd", diff --git a/src/common/error/src/status_code.rs b/src/common/error/src/status_code.rs index 6b38e62546..424dd93c4a 100644 --- a/src/common/error/src/status_code.rs +++ b/src/common/error/src/status_code.rs @@ -39,6 +39,11 @@ pub enum StatusCode { /// Storage is temporarily unable to handle the request StorageUnavailable, // ====== End of storage related status code ======= + + // ====== Begin of server related status code ===== + /// Runtime resources exhausted, like creating threads failed. + RuntimeResourcesExhausted, + // ====== End of server related status code ======= } impl fmt::Display for StatusCode { diff --git a/src/common/servers/Cargo.toml b/src/common/servers/Cargo.toml new file mode 100644 index 0000000000..c5b10fb2f5 --- /dev/null +++ b/src/common/servers/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "common-servers" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-trait = "0.1" +common-error = { path = "../error" } +common-recordbatch = { path = "../recordbatch" } +common-runtime = { path = "../runtime" } +common-telemetry = { path = "../telemetry" } +datatypes = { path = "../../datatypes"} +futures = "0.3" +metrics = "0.20" +num_cpus = "1.13" +opensrv-mysql = "0.1" +query = { path = "../../query" } +snafu = { version = "0.7", features = ["backtraces"] } +tokio = { version = "1.20", features = ["full"] } +tokio-stream = { version = "0.1", features = ["net"] } + +[dev-dependencies] +common-base = { path = "../base" } +catalog = { path = "../../catalog" } +mysql_async = "0.30" +rand = "0.8" +test-util = { path = "../../../test-util" } diff --git a/src/common/servers/src/error.rs b/src/common/servers/src/error.rs new file mode 100644 index 0000000000..a1ccdd1cdb --- /dev/null +++ b/src/common/servers/src/error.rs @@ -0,0 +1,28 @@ +use std::any::Any; + +use common_error::prelude::*; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub))] +pub enum Error { + #[snafu(display("MySQL server error, source: {}", source))] + MysqlServer { source: crate::mysql::error::Error }, +} + +pub type Result = std::result::Result; + +impl ErrorExt for Error { + fn status_code(&self) -> StatusCode { + match self { + Error::MysqlServer { .. } => StatusCode::Internal, + } + } + + fn backtrace_opt(&self) -> Option<&Backtrace> { + ErrorCompat::backtrace(self) + } + + fn as_any(&self) -> &dyn Any { + self + } +} diff --git a/src/common/servers/src/lib.rs b/src/common/servers/src/lib.rs new file mode 100644 index 0000000000..1274427dbd --- /dev/null +++ b/src/common/servers/src/lib.rs @@ -0,0 +1,3 @@ +mod error; +pub mod mysql; +pub mod server; diff --git a/src/common/servers/src/mysql/error.rs b/src/common/servers/src/mysql/error.rs new file mode 100644 index 0000000000..9f851be9f3 --- /dev/null +++ b/src/common/servers/src/mysql/error.rs @@ -0,0 +1,60 @@ +use std::any::Any; +use std::io; + +use common_error::prelude::*; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub))] +pub enum Error { + #[snafu(display("Internal error: {}", err_msg))] + Internal { err_msg: String }, + + #[snafu(display("Internal IO error, source: {}", source))] + InternalIo { source: io::Error }, + + #[snafu(display("Tokio IO error: {}, source: {}", err_msg, source))] + TokioIo { err_msg: String, source: io::Error }, + + #[snafu(display("Runtime resource error, source: {}", source))] + RuntimeResource { + source: common_runtime::error::Error, + }, + + #[snafu(display("Failed to convert vector, source: {}", source))] + VectorConversion { source: datatypes::error::Error }, + + #[snafu(display("Failed to collect recordbatch, source: {}", source))] + CollectRecordbatch { + source: common_recordbatch::error::Error, + }, +} + +pub type Result = std::result::Result; + +impl ErrorExt for Error { + fn status_code(&self) -> StatusCode { + match self { + Error::Internal { .. } | Error::InternalIo { .. } | Error::TokioIo { .. } => { + StatusCode::Unexpected + } + Error::VectorConversion { .. } | Error::CollectRecordbatch { .. } => { + StatusCode::Internal + } + Error::RuntimeResource { .. } => StatusCode::RuntimeResourcesExhausted, + } + } + + fn backtrace_opt(&self) -> Option<&Backtrace> { + ErrorCompat::backtrace(self) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl From for Error { + fn from(e: io::Error) -> Self { + Error::InternalIo { source: e } + } +} diff --git a/src/common/servers/src/mysql/mod.rs b/src/common/servers/src/mysql/mod.rs new file mode 100644 index 0000000000..47bc2a60a2 --- /dev/null +++ b/src/common/servers/src/mysql/mod.rs @@ -0,0 +1,4 @@ +pub mod error; +pub mod mysql_instance; +pub mod mysql_server; +pub mod mysql_writer; diff --git a/src/common/servers/src/mysql/mysql_instance.rs b/src/common/servers/src/mysql/mysql_instance.rs new file mode 100644 index 0000000000..bdf81fa248 --- /dev/null +++ b/src/common/servers/src/mysql/mysql_instance.rs @@ -0,0 +1,80 @@ +use std::io; +use std::sync::Arc; + +use async_trait::async_trait; +use opensrv_mysql::AsyncMysqlShim; +use opensrv_mysql::ErrorKind; +use opensrv_mysql::ParamParser; +use opensrv_mysql::QueryResultWriter; +use opensrv_mysql::StatementMetaWriter; +use query::query_engine::Output; + +use crate::mysql::error::{self, Result}; +use crate::mysql::mysql_writer::MysqlResultWriter; + +pub type MysqlInstanceRef = Arc; + +// TODO(LFC): Move to instance layer. +#[async_trait] +pub trait MysqlInstance { + async fn do_query(&self, query: &str) -> Result; +} + +// An intermediate shim for executing MySQL queries. +pub struct MysqlInstanceShim { + mysql_instance: MysqlInstanceRef, +} + +impl MysqlInstanceShim { + pub fn create(mysql_instance: MysqlInstanceRef) -> MysqlInstanceShim { + MysqlInstanceShim { mysql_instance } + } +} + +#[async_trait] +impl AsyncMysqlShim for MysqlInstanceShim { + type Error = error::Error; + + async fn on_prepare<'a>( + &'a mut self, + _: &'a str, + writer: StatementMetaWriter<'a, W>, + ) -> Result<()> { + writer.error( + ErrorKind::ER_UNKNOWN_ERROR, + "prepare statement is not supported yet".as_bytes(), + )?; + Ok(()) + } + + async fn on_execute<'a>( + &'a mut self, + _: u32, + _: ParamParser<'a>, + writer: QueryResultWriter<'a, W>, + ) -> Result<()> { + writer.error( + ErrorKind::ER_UNKNOWN_ERROR, + "prepare statement is not supported yet".as_bytes(), + )?; + Ok(()) + } + + async fn on_close<'a>(&'a mut self, _stmt_id: u32) + where + W: 'async_trait, + { + // do nothing because we haven't implemented prepare statement + } + + async fn on_query<'a>( + &'a mut self, + query: &'a str, + writer: QueryResultWriter<'a, W>, + ) -> Result<()> { + let output = self.mysql_instance.do_query(query).await; + + let mut writer = MysqlResultWriter::new(writer); + writer.write(output).await + } +} diff --git a/src/common/servers/src/mysql/mysql_server.rs b/src/common/servers/src/mysql/mysql_server.rs new file mode 100644 index 0000000000..b3d1976d95 --- /dev/null +++ b/src/common/servers/src/mysql/mysql_server.rs @@ -0,0 +1,137 @@ +use std::future::Future; +use std::net::SocketAddr; +use std::sync::Arc; + +use async_trait::async_trait; +use common_runtime::Runtime; +use common_telemetry::logging::{error, info}; +use futures::future::AbortHandle; +use futures::future::AbortRegistration; +use futures::future::Abortable; +use futures::StreamExt; +use opensrv_mysql::AsyncMysqlIntermediary; +use snafu::prelude::*; +use tokio; +use tokio::net::TcpStream; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::TcpListenerStream; + +use crate::error as server_error; +use crate::mysql::error::{self, Result}; +use crate::mysql::mysql_instance::{MysqlInstanceRef, MysqlInstanceShim}; +use crate::server::Server; + +pub struct MysqlServer { + // `abort_handle` and `abort_registration` are used in pairs in shutting down MySQL server. + // They work like sender and receiver for aborting stream. When the server is shutting down, + // calling `abort_handle.abort()` will "notify" `abort_registration` to stop emitting new + // elements in the stream. + abort_handle: AbortHandle, + abort_registration: Option, + + // A handle holding the TCP accepting task. + join_handle: Option>, + + mysql_handler: MysqlInstanceRef, + io_runtime: Arc, +} + +impl MysqlServer { + /// Creates a new MySQL server with provided [MysqlInstance] and [Runtime]. + pub fn create_server( + mysql_handler: MysqlInstanceRef, + io_runtime: Arc, + ) -> Box { + let (abort_handle, registration) = AbortHandle::new_pair(); + Box::new(MysqlServer { + abort_handle, + abort_registration: Some(registration), + join_handle: None, + mysql_handler, + io_runtime, + }) + } + + async fn bind(addr: SocketAddr) -> Result<(TcpListenerStream, SocketAddr)> { + let listener = tokio::net::TcpListener::bind(addr) + .await + .context(error::TokioIoSnafu { + err_msg: format!("Failed to bind addr {}", addr), + })?; + // get actually bond addr in case input addr use port 0 + let listener_addr = listener.local_addr()?; + Ok((TcpListenerStream::new(listener), listener_addr)) + } + + fn accept(&self, accepting_stream: Abortable) -> impl Future { + let io_runtime = self.io_runtime.clone(); + let mysql_handler = self.mysql_handler.clone(); + accepting_stream.for_each(move |tcp_stream| { + let io_runtime = io_runtime.clone(); + let mysql_handler = mysql_handler.clone(); + async move { + match tcp_stream { + Err(error) => error!("Broken pipe: {}", error), + Ok(io_stream) => { + if let Err(error) = Self::handle(io_stream, io_runtime, mysql_handler) { + error!("Unexpected error when handling TcpStream: {:?}", error); + }; + } + }; + } + }) + } + + pub fn handle( + stream: TcpStream, + io_runtime: Arc, + mysql_handler: MysqlInstanceRef, + ) -> Result<()> { + info!("MySQL connection coming from: {}", stream.peer_addr()?); + let shim = MysqlInstanceShim::create(mysql_handler); + // TODO(LFC): Relate "handler" with MySQL session; also deal with panics there. + let _handler = io_runtime.spawn(AsyncMysqlIntermediary::run_on(shim, stream)); + Ok(()) + } +} + +#[async_trait] +impl Server for MysqlServer { + async fn shutdown(&mut self) -> server_error::Result<()> { + match self.join_handle.take() { + Some(join_handle) => { + self.abort_handle.abort(); + + if let Err(error) = join_handle.await { + error!("Unexpected error during shutdown MySQL server: {}", error); + } else { + info!("MySQL server is shutdown.") + } + Ok(()) + } + None => error::InternalSnafu { + err_msg: "MySQL server is not started.", + } + .fail() + .context(server_error::MysqlServerSnafu), + } + } + + async fn start(&mut self, listening: SocketAddr) -> server_error::Result { + match self.abort_registration.take() { + Some(registration) => { + let (stream, listener) = Self::bind(listening) + .await + .context(server_error::MysqlServerSnafu)?; + let stream = Abortable::new(stream, registration); + self.join_handle = Some(tokio::spawn(self.accept(stream))); + Ok(listener) + } + None => error::InternalSnafu { + err_msg: "MySQL server has been started.", + } + .fail() + .context(server_error::MysqlServerSnafu), + } + } +} diff --git a/src/common/servers/src/mysql/mysql_writer.rs b/src/common/servers/src/mysql/mysql_writer.rs new file mode 100644 index 0000000000..305a7d0e04 --- /dev/null +++ b/src/common/servers/src/mysql/mysql_writer.rs @@ -0,0 +1,237 @@ +use std::io; + +use common_recordbatch::{util, RecordBatch}; +use datatypes::prelude::{ConcreteDataType, Value, VectorHelper}; +use datatypes::schema::{ColumnSchema, SchemaRef}; +use opensrv_mysql::{ + Column, ColumnFlags, ColumnType, ErrorKind, OkResponse, QueryResultWriter, RowWriter, +}; +use query::Output; +use snafu::prelude::*; + +use crate::mysql::error::{self, Error, Result}; + +struct QueryResult { + recordbatches: Vec, + schema: SchemaRef, +} + +pub struct MysqlResultWriter<'a, W: io::Write> { + // `QueryResultWriter` will be consumed when the write completed (see + // QueryResultWriter::completed), thus we use an option to wrap it. + inner: Option>, +} + +impl<'a, W: io::Write> MysqlResultWriter<'a, W> { + pub fn new(inner: QueryResultWriter<'a, W>) -> MysqlResultWriter<'a, W> { + MysqlResultWriter::<'a, W> { inner: Some(inner) } + } + + pub async fn write(&mut self, output: Result) -> Result<()> { + let writer = self.inner.take().context(error::InternalSnafu { + err_msg: "inner MySQL writer is consumed", + })?; + match output { + Ok(output) => match output { + Output::RecordBatch(stream) => { + let schema = stream.schema().clone(); + let recordbatches = util::collect(stream) + .await + .context(error::CollectRecordbatchSnafu)?; + let query_result = QueryResult { + recordbatches, + schema, + }; + Self::write_query_result(query_result, writer)? + } + Output::AffectedRows(rows) => Self::write_affected_rows(writer, rows)?, + }, + Err(error) => Self::write_query_error(error, writer)?, + } + Ok(()) + } + + fn write_affected_rows(writer: QueryResultWriter, rows: usize) -> Result<()> { + writer.completed(OkResponse { + affected_rows: rows as u64, + ..Default::default() + })?; + Ok(()) + } + + fn write_query_result( + query_result: QueryResult, + writer: QueryResultWriter<'a, W>, + ) -> Result<()> { + if query_result.recordbatches.is_empty() { + writer.completed(OkResponse::default())?; + return Ok(()); + } + + match create_mysql_column_def(&query_result.schema) { + Ok(column_def) => { + let mut row_writer = writer.start(&column_def)?; + for recordbatch in &query_result.recordbatches { + Self::write_recordbatch(&mut row_writer, recordbatch)?; + } + row_writer.finish()?; + Ok(()) + } + Err(error) => Self::write_query_error(error, writer), + } + } + + fn write_recordbatch(row_writer: &mut RowWriter, recordbatch: &RecordBatch) -> Result<()> { + let matrix = transpose(recordbatch)?; + for row in matrix.iter() { + for v in row.iter() { + match v { + Value::Null => row_writer.write_col(None::)?, + Value::Boolean(v) => row_writer.write_col(*v as i8)?, + Value::UInt8(v) => row_writer.write_col(v)?, + Value::UInt16(v) => row_writer.write_col(v)?, + Value::UInt32(v) => row_writer.write_col(v)?, + Value::UInt64(v) => row_writer.write_col(v)?, + Value::Int8(v) => row_writer.write_col(v)?, + Value::Int16(v) => row_writer.write_col(v)?, + Value::Int32(v) => row_writer.write_col(v)?, + Value::Int64(v) => row_writer.write_col(v)?, + Value::Float32(v) => row_writer.write_col(v.0)?, + Value::Float64(v) => row_writer.write_col(v.0)?, + Value::String(v) => row_writer.write_col(v.as_utf8())?, + Value::Binary(v) => row_writer.write_col(v.to_vec())?, + Value::Date(v) => row_writer.write_col(v)?, + Value::DateTime(v) => row_writer.write_col(v)?, + _ => { + return Err(Error::Internal { + err_msg: format!( + "cannot write value {:?} in mysql protocol: unimplemented", + v + ), + }) + } + } + } + row_writer.end_row()?; + } + Ok(()) + } + + fn write_query_error(error: Error, writer: QueryResultWriter<'a, W>) -> Result<()> { + writer.error(ErrorKind::ER_INTERNAL_ERROR, error.to_string().as_bytes())?; + Ok(()) + } +} + +fn create_mysql_column(column_schema: &ColumnSchema) -> Result { + let column_type = match column_schema.data_type { + ConcreteDataType::Null(_) => Ok(ColumnType::MYSQL_TYPE_NULL), + ConcreteDataType::Boolean(_) | ConcreteDataType::Int8(_) | ConcreteDataType::UInt8(_) => { + Ok(ColumnType::MYSQL_TYPE_TINY) + } + ConcreteDataType::Int16(_) | ConcreteDataType::UInt16(_) => { + Ok(ColumnType::MYSQL_TYPE_SHORT) + } + ConcreteDataType::Int32(_) | ConcreteDataType::UInt32(_) => Ok(ColumnType::MYSQL_TYPE_LONG), + ConcreteDataType::Int64(_) | ConcreteDataType::UInt64(_) => { + Ok(ColumnType::MYSQL_TYPE_LONGLONG) + } + ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => { + Ok(ColumnType::MYSQL_TYPE_FLOAT) + } + ConcreteDataType::Binary(_) | ConcreteDataType::String(_) => { + Ok(ColumnType::MYSQL_TYPE_VARCHAR) + } + _ => error::InternalSnafu { + err_msg: format!( + "not implemented for column datatype {:?}", + column_schema.data_type + ), + } + .fail(), + }; + column_type.map(|column_type| Column { + column: column_schema.name.clone(), + coltype: column_type, + + // TODO(LFC): Currently "table" and "colflags" are not relevant in MySQL server + // implementation, will revisit them again in the future. + table: "".to_string(), + colflags: ColumnFlags::empty(), + }) +} + +/// Creates MySQL columns definition from our column schema. +pub fn create_mysql_column_def(schema: &SchemaRef) -> Result> { + schema + .column_schemas() + .iter() + .map(create_mysql_column) + .collect() +} + +/// RecordBatch organizes its values in columns while MySQL needs to write row by row. +/// This function creates a view of [Value]s organized in rows from RecordBatch (just like matrix +/// transpose, hence the function name), helping us write RecordBatch to MySQL. +fn transpose(recordbatch: &RecordBatch) -> Result>> { + let recordbatch = &recordbatch.df_recordbatch; + let rows = recordbatch.num_rows(); + let columns = recordbatch.num_columns(); + let mut matrix = vec![vec![Value::Null; columns]; rows]; + for column in 0..columns { + let array = recordbatch.column(column); + let vector = VectorHelper::try_into_vector(array).context(error::VectorConversionSnafu)?; + // Clippy suggests us to use "matrix.iter_mut().enumerate().take(rows)", which is not wanted. + #[allow(clippy::needless_range_loop)] + for row in 0..rows { + matrix[row][column] = vector.get(row); + } + } + Ok(matrix) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_base::bytes::StringBytes; + use datatypes::prelude::*; + use datatypes::schema::Schema; + use datatypes::vectors::{StringVector, UInt32Vector}; + + use super::*; + + #[test] + fn test_transpose() { + let column_schemas = vec![ + ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ]; + let schema = Arc::new(Schema::new(column_schemas)); + let columns: Vec = vec![ + Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])), + Arc::new(StringVector::from(vec![ + None, + Some("hello"), + Some("greptime"), + None, + ])), + ]; + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + let matrix = transpose(&recordbatch).unwrap(); + assert_eq!(4, matrix.len()); + assert_eq!(vec![Value::UInt32(1), Value::Null], matrix[0]); + assert_eq!( + vec![Value::UInt32(2), Value::String(StringBytes::from("hello"))], + matrix[1] + ); + assert_eq!( + vec![ + Value::UInt32(3), + Value::String(StringBytes::from("greptime")) + ], + matrix[2] + ); + assert_eq!(vec![Value::UInt32(4), Value::Null], matrix[3]); + } +} diff --git a/src/common/servers/src/server.rs b/src/common/servers/src/server.rs new file mode 100644 index 0000000000..77ffd615d7 --- /dev/null +++ b/src/common/servers/src/server.rs @@ -0,0 +1,11 @@ +use std::net::SocketAddr; + +use async_trait::async_trait; + +use crate::error::Result; + +#[async_trait] +pub trait Server: Send { + async fn shutdown(&mut self) -> Result<()>; + async fn start(&mut self, listening: SocketAddr) -> Result; +} diff --git a/src/common/servers/tests/mod.rs b/src/common/servers/tests/mod.rs new file mode 100644 index 0000000000..02b5c273ef --- /dev/null +++ b/src/common/servers/tests/mod.rs @@ -0,0 +1 @@ +mod mysql; diff --git a/src/common/servers/tests/mysql/mod.rs b/src/common/servers/tests/mysql/mod.rs new file mode 100644 index 0000000000..ddaa407ded --- /dev/null +++ b/src/common/servers/tests/mysql/mod.rs @@ -0,0 +1,269 @@ +use std::sync::Arc; + +use datatypes::prelude::*; +use datatypes::schema::ColumnSchema; +use datatypes::vectors::{ + BinaryVector, BooleanVector, Float32Vector, Float64Vector, Int16Vector, Int32Vector, + Int64Vector, Int8Vector, NullVector, StringVector, UInt16Vector, UInt32Vector, UInt64Vector, + UInt8Vector, +}; +use mysql_async::prelude::FromRow; +use mysql_async::FromRowError; +use mysql_async::Value as MysqlValue; +use opensrv_mysql::ColumnType; + +mod mysql_server_test; +mod mysql_writer_test; + +pub struct TestingData { + column_schemas: Vec, + mysql_columns_def: Vec, + columns: Vec, + mysql_text_output_rows: Vec>, +} + +impl TestingData { + fn new( + column_schemas: Vec, + mysql_columns_def: Vec, + columns: Vec, + mysql_text_output_rows: Vec>, + ) -> Self { + // Check input columns have same size, + assert_eq!(column_schemas.len(), mysql_columns_def.len()); + assert_eq!(column_schemas.len(), columns.len()); + // and all columns length are equal + assert!(columns.windows(2).all(|x| x[0].len() == x[1].len())); + // and all output rows width are equal + assert!(mysql_text_output_rows + .windows(2) + .all(|x| x[0].len() == x[1].len())); + // and the rows' columns size equals to input columns size. + assert_eq!(columns.first().unwrap().len(), mysql_text_output_rows.len()); + + TestingData { + column_schemas, + mysql_columns_def, + columns, + mysql_text_output_rows, + } + } +} + +#[derive(Debug)] +struct MysqlTextRow { + values: Vec, +} + +impl FromRow for MysqlTextRow { + fn from_row_opt(row: mysql_async::Row) -> Result + where + Self: Sized, + { + let mut values = Vec::with_capacity(row.len()); + for i in 0..row.len() { + let value = if let Some(mysql_value) = row.as_ref(i) { + match mysql_value { + MysqlValue::NULL => Value::Null, + MysqlValue::Bytes(v) => Value::from(v.to_vec()), + _ => unreachable!(), + } + } else { + Value::Null + }; + values.push(value); + } + Ok(MysqlTextRow { values }) + } +} + +pub fn all_datatype_testing_data() -> TestingData { + let column_schemas = vec![ + ColumnSchema::new("nulls", ConcreteDataType::null_datatype(), true), + ColumnSchema::new("bools", ConcreteDataType::boolean_datatype(), true), + ColumnSchema::new("int8s", ConcreteDataType::int8_datatype(), true), + ColumnSchema::new("int16s", ConcreteDataType::int16_datatype(), true), + ColumnSchema::new("int32s", ConcreteDataType::int32_datatype(), true), + ColumnSchema::new("int64s", ConcreteDataType::int64_datatype(), true), + ColumnSchema::new("uint8s", ConcreteDataType::uint8_datatype(), true), + ColumnSchema::new("uint16s", ConcreteDataType::uint16_datatype(), true), + ColumnSchema::new("uint32s", ConcreteDataType::uint32_datatype(), true), + ColumnSchema::new("uint64s", ConcreteDataType::uint64_datatype(), true), + ColumnSchema::new("float32s", ConcreteDataType::float32_datatype(), true), + ColumnSchema::new("float64s", ConcreteDataType::float64_datatype(), true), + ColumnSchema::new("binaries", ConcreteDataType::binary_datatype(), true), + ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true), + ]; + let mysql_columns_def = vec![ + ColumnType::MYSQL_TYPE_NULL, + ColumnType::MYSQL_TYPE_TINY, + ColumnType::MYSQL_TYPE_TINY, + ColumnType::MYSQL_TYPE_SHORT, + ColumnType::MYSQL_TYPE_LONG, + ColumnType::MYSQL_TYPE_LONGLONG, + ColumnType::MYSQL_TYPE_TINY, + ColumnType::MYSQL_TYPE_SHORT, + ColumnType::MYSQL_TYPE_LONG, + ColumnType::MYSQL_TYPE_LONGLONG, + ColumnType::MYSQL_TYPE_FLOAT, + ColumnType::MYSQL_TYPE_FLOAT, + ColumnType::MYSQL_TYPE_VARCHAR, + ColumnType::MYSQL_TYPE_VARCHAR, + ]; + let columns: Vec = vec![ + Arc::new(NullVector::new(4)), + Arc::new(BooleanVector::from(vec![ + Some(true), + None, + Some(false), + None, + ])), + Arc::new(Int8Vector::from(vec![ + Some(i8::MIN), + None, + Some(i8::MAX), + None, + ])), + Arc::new(Int16Vector::from(vec![ + Some(i16::MIN), + None, + Some(i16::MAX), + None, + ])), + Arc::new(Int32Vector::from(vec![ + Some(i32::MIN), + None, + Some(i32::MAX), + None, + ])), + Arc::new(Int64Vector::from(vec![ + Some(i64::MIN), + None, + Some(i64::MAX), + None, + ])), + Arc::new(UInt8Vector::from(vec![ + Some(u8::MIN), + None, + Some(u8::MAX), + None, + ])), + Arc::new(UInt16Vector::from(vec![ + Some(u16::MIN), + None, + Some(u16::MAX), + None, + ])), + Arc::new(UInt32Vector::from(vec![ + Some(u32::MIN), + None, + Some(u32::MAX), + None, + ])), + Arc::new(UInt64Vector::from(vec![ + Some(u64::MIN), + None, + Some(u64::MAX), + None, + ])), + Arc::new(Float32Vector::from(vec![ + Some(-1.123456_f32), + None, + Some(1.654321), + None, + ])), + Arc::new(Float64Vector::from(vec![ + Some(-10.123456_f64), + None, + Some(10.654321), + None, + ])), + Arc::new(BinaryVector::from(vec![ + None, + Some("hello".as_bytes().to_vec()), + Some("greptime".as_bytes().to_vec()), + None, + ])), + Arc::new(StringVector::from(vec![ + Some("hola"), + None, + None, + Some("GT"), + ])), + ]; + + // Because we can only use MySQL text protocol (binary protocol requires prepared statement, + // which we are not implemented yet), every MysqlValue is of type "Bytes" + let mysql_text_output_rows = vec![ + vec![ + Value::Null, + Value::from("1".as_bytes()), + Value::from(i8::MIN.to_string().as_bytes()), + Value::from(i16::MIN.to_string().as_bytes()), + Value::from(i32::MIN.to_string().as_bytes()), + Value::from(i64::MIN.to_string().as_bytes()), + Value::from(u8::MIN.to_string().as_bytes()), + Value::from(u16::MIN.to_string().as_bytes()), + Value::from(u32::MIN.to_string().as_bytes()), + Value::from(u64::MIN.to_string().as_bytes()), + Value::from((-1.123456_f32).to_string().as_bytes()), + Value::from((-10.123456_f64).to_string().as_bytes()), + Value::Null, + Value::from("hola".as_bytes()), + ], + vec![ + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::from("hello".as_bytes()), + Value::Null, + ], + vec![ + Value::Null, + Value::from("0".as_bytes()), + Value::from(i8::MAX.to_string().as_bytes()), + Value::from(i16::MAX.to_string().as_bytes()), + Value::from(i32::MAX.to_string().as_bytes()), + Value::from(i64::MAX.to_string().as_bytes()), + Value::from(u8::MAX.to_string().as_bytes()), + Value::from(u16::MAX.to_string().as_bytes()), + Value::from(u32::MAX.to_string().as_bytes()), + Value::from(u64::MAX.to_string().as_bytes()), + Value::from(1.654321_f32.to_string().as_bytes()), + Value::from(10.654321_f64.to_string().as_bytes()), + Value::from("greptime".as_bytes()), + Value::Null, + ], + vec![ + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::Null, + Value::from("GT".as_bytes()), + ], + ]; + TestingData::new( + column_schemas, + mysql_columns_def, + columns, + mysql_text_output_rows, + ) +} diff --git a/src/common/servers/tests/mysql/mysql_server_test.rs b/src/common/servers/tests/mysql/mysql_server_test.rs new file mode 100644 index 0000000000..638e59e3b5 --- /dev/null +++ b/src/common/servers/tests/mysql/mysql_server_test.rs @@ -0,0 +1,221 @@ +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use catalog::memory::{MemoryCatalogList, MemoryCatalogProvider, MemorySchemaProvider}; +use catalog::{CatalogList, SchemaProvider, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_recordbatch::RecordBatch; +use common_runtime::Builder as RuntimeBuilder; +use common_servers::mysql::error::{Result, RuntimeResourceSnafu}; +use common_servers::mysql::mysql_instance::MysqlInstance; +use common_servers::mysql::mysql_server::MysqlServer; +use common_servers::server::Server; +use datatypes::schema::Schema; +use mysql_async::prelude::*; +use query::{Output, QueryEngineFactory, QueryEngineRef}; +use rand::rngs::StdRng; +use rand::Rng; +use snafu::prelude::*; +use test_util::MemTable; + +use crate::mysql::{all_datatype_testing_data, MysqlTextRow, TestingData}; + +fn create_mysql_server(table: MemTable) -> Result> { + let table_name = table.table_name().to_string(); + let table = Arc::new(table); + + let schema_provider = Arc::new(MemorySchemaProvider::new()); + schema_provider.register_table(table_name, table).unwrap(); + let catalog_provider = Arc::new(MemoryCatalogProvider::new()); + catalog_provider.register_schema(DEFAULT_SCHEMA_NAME, schema_provider); + let catalog_list = Arc::new(MemoryCatalogList::default()); + catalog_list.register_catalog(DEFAULT_CATALOG_NAME.to_string(), catalog_provider); + let factory = QueryEngineFactory::new(catalog_list); + let query_engine = factory.query_engine().clone(); + + let mysql_instance = Arc::new(DummyMysqlInstance { query_engine }); + let io_runtime = Arc::new( + RuntimeBuilder::default() + .worker_threads(4) + .thread_name("mysql-io-handlers") + .build() + .context(RuntimeResourceSnafu)?, + ); + Ok(MysqlServer::create_server(mysql_instance, io_runtime)) +} + +#[tokio::test] +async fn test_start_mysql_server() -> Result<()> { + let table = MemTable::default_numbers_table(); + + let mut mysql_server = create_mysql_server(table)?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let result = mysql_server.start(listening).await; + assert!(result.is_ok()); + + let result = mysql_server.start(listening).await; + assert!(result + .unwrap_err() + .to_string() + .contains("MySQL server has been started.")); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_shutdown_mysql_server() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let table = MemTable::default_numbers_table(); + + let mut mysql_server = create_mysql_server(table)?; + let result = mysql_server.shutdown().await; + assert!(result + .unwrap_err() + .to_string() + .contains("MySQL server is not started.")); + + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + let server_port = server_addr.port(); + + let mut join_handles = vec![]; + for _ in 0..2 { + join_handles.push(tokio::spawn(async move { + for _ in 0..1000 { + match create_connection(server_port).await { + Ok(mut connection) => { + let result: u32 = connection + .query_first("SELECT uint32s FROM numbers LIMIT 1") + .await + .unwrap() + .unwrap(); + assert_eq!(result, 0); + tokio::time::sleep(Duration::from_millis(10)).await; + } + Err(e) => return Err(e), + } + } + Ok(()) + })) + } + + tokio::time::sleep(Duration::from_millis(100)).await; + let result = mysql_server.shutdown().await; + assert!(result.is_ok()); + + for handle in join_handles.iter_mut() { + let result = handle.await.unwrap(); + assert!(result.is_err()); + let error = result.unwrap_err().to_string(); + assert!(error.contains("Connection refused") || error.contains("Connection reset by peer")); + } + Ok(()) +} + +#[tokio::test] +async fn test_query_all_datatypes() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let TestingData { + column_schemas, + mysql_columns_def, + columns, + mysql_text_output_rows, + } = all_datatype_testing_data(); + let schema = Arc::new(Schema::new(column_schemas.clone())); + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + let table = MemTable::new("all_datatypes", recordbatch); + + let mut mysql_server = create_mysql_server(table)?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + + let mut connection = create_connection(server_addr.port()).await.unwrap(); + let mut result = connection + .query_iter("SELECT * FROM all_datatypes LIMIT 3") + .await + .unwrap(); + let columns = result.columns().unwrap(); + assert_eq!(column_schemas.len(), columns.len()); + + for (i, column) in columns.iter().enumerate() { + assert_eq!(mysql_columns_def[i], column.column_type()); + assert_eq!(column_schemas[i].name, column.name_str()); + } + + let rows = result.collect::().await.unwrap(); + assert_eq!(3, rows.len()); + for (expected, actual) in mysql_text_output_rows.iter().take(3).zip(rows.iter()) { + assert_eq!(expected, &actual.values); + } + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_query_concurrently() -> Result<()> { + common_telemetry::init_default_ut_logging(); + + let table = MemTable::default_numbers_table(); + + let mut mysql_server = create_mysql_server(table)?; + let listening = "127.0.0.1:0".parse::().unwrap(); + let server_addr = mysql_server.start(listening).await.unwrap(); + let server_port = server_addr.port(); + + let threads = 4; + let expect_executed_queries_per_worker = 1000; + let mut join_handles = vec![]; + for _ in 0..threads { + join_handles.push(tokio::spawn(async move { + let mut rand: StdRng = rand::SeedableRng::from_entropy(); + + let mut connection = create_connection(server_port).await.unwrap(); + for _ in 0..expect_executed_queries_per_worker { + let expected: u32 = rand.gen_range(0..100); + let result: u32 = connection + .query_first(format!( + "SELECT uint32s FROM numbers WHERE uint32s = {}", + expected + )) + .await + .unwrap() + .unwrap(); + assert_eq!(result, expected); + + let should_recreate_conn = expected == 1; + if should_recreate_conn { + connection = create_connection(server_port).await.unwrap(); + } + } + expect_executed_queries_per_worker + })) + } + let mut total_pending_queries = threads * expect_executed_queries_per_worker; + for handle in join_handles.iter_mut() { + total_pending_queries -= handle.await.unwrap(); + } + assert_eq!(0, total_pending_queries); + Ok(()) +} + +async fn create_connection(port: u16) -> mysql_async::Result { + let opts = mysql_async::OptsBuilder::default() + .ip_or_hostname("127.0.0.1") + .tcp_port(port) + .prefer_socket(false) + .wait_timeout(Some(1000)); + mysql_async::Conn::new(opts).await +} + +struct DummyMysqlInstance { + query_engine: QueryEngineRef, +} + +#[async_trait] +impl MysqlInstance for DummyMysqlInstance { + async fn do_query(&self, query: &str) -> Result { + let plan = self.query_engine.sql_to_plan(query).unwrap(); + Ok(self.query_engine.execute(&plan).await.unwrap()) + } +} diff --git a/src/common/servers/tests/mysql/mysql_writer_test.rs b/src/common/servers/tests/mysql/mysql_writer_test.rs new file mode 100644 index 0000000000..064c700751 --- /dev/null +++ b/src/common/servers/tests/mysql/mysql_writer_test.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use common_servers::mysql::mysql_writer::create_mysql_column_def; +use datatypes::prelude::*; +use datatypes::schema::{ColumnSchema, Schema}; + +use crate::mysql::{all_datatype_testing_data, TestingData}; + +#[test] +fn test_create_mysql_column_def() { + let TestingData { + column_schemas, + mysql_columns_def, + .. + } = all_datatype_testing_data(); + let schema = Arc::new(Schema::new(column_schemas.clone())); + let columns_def = create_mysql_column_def(&schema).unwrap(); + assert_eq!(column_schemas.len(), columns_def.len()); + + for (i, column_def) in columns_def.iter().enumerate() { + let column_schema = &column_schemas[i]; + assert_eq!(column_schema.name, column_def.column); + let expected_coltype = mysql_columns_def[i]; + assert_eq!(column_def.coltype, expected_coltype); + } + + let column_schemas = vec![ColumnSchema::new( + "lists", + ConcreteDataType::list_datatype(ConcreteDataType::string_datatype()), + true, + )]; + let schema = Arc::new(Schema::new(column_schemas)); + assert!(create_mysql_column_def(&schema).is_err()); +} diff --git a/src/query/tests/my_sum_udaf_example.rs b/src/query/tests/my_sum_udaf_example.rs index 3c530f96b9..8f1b65f4d5 100644 --- a/src/query/tests/my_sum_udaf_example.rs +++ b/src/query/tests/my_sum_udaf_example.rs @@ -26,7 +26,6 @@ use num_traits::AsPrimitive; use query::error::Result; use query::query_engine::Output; use query::QueryEngineFactory; -use table::TableRef; use test_util::MemTable; #[derive(Debug, Default)] @@ -225,9 +224,9 @@ where let schema = Arc::new(Schema::new(column_schemas.clone())); let column: VectorRef = Arc::new(PrimitiveVector::::from_vec(numbers)); let recordbatch = RecordBatch::new(schema, vec![column]).unwrap(); - let testing_table = Arc::new(MemTable::new(recordbatch)); + let testing_table = MemTable::new(&table_name, recordbatch); - let factory = new_query_engine_factory(table_name.clone(), testing_table); + let factory = new_query_engine_factory(testing_table); let engine = factory.query_engine(); engine.register_aggregate_function(Arc::new(AggregateFunctionMeta::new( @@ -258,7 +257,10 @@ where Ok(()) } -pub fn new_query_engine_factory(table_name: String, table: TableRef) -> QueryEngineFactory { +fn new_query_engine_factory(table: MemTable) -> QueryEngineFactory { + let table_name = table.table_name().to_string(); + let table = Arc::new(table); + let schema_provider = Arc::new(MemorySchemaProvider::new()); let catalog_provider = Arc::new(MemoryCatalogProvider::new()); let catalog_list = Arc::new(MemoryCatalogList::default()); diff --git a/src/query/tests/query_engine_test.rs b/src/query/tests/query_engine_test.rs index c12a60c650..0b0f1d7c3c 100644 --- a/src/query/tests/query_engine_test.rs +++ b/src/query/tests/query_engine_test.rs @@ -47,7 +47,7 @@ async fn test_datafusion_query_engine() -> Result<()> { (0..100).collect::>(), ))]; let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let table = Arc::new(MemTable::new(recordbatch)); + let table = Arc::new(MemTable::new("numbers", recordbatch)); let limit = 10; let table_provider = Arc::new(DfTableProviderAdapter::new(table.clone())); @@ -170,9 +170,12 @@ fn create_query_engine() -> Arc { let schema = Arc::new(Schema::new(column_schemas.clone())); let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let even_number_table = Arc::new(MemTable::new(recordbatch)); + let even_number_table = Arc::new(MemTable::new("even_numbers", recordbatch)); schema_provider - .register_table("even_numbers".to_string(), even_number_table) + .register_table( + even_number_table.table_name().to_string(), + even_number_table, + ) .unwrap(); // create table with ordered primitives, and all columns' length are odd @@ -197,9 +200,9 @@ fn create_query_engine() -> Arc { let schema = Arc::new(Schema::new(column_schemas.clone())); let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let odd_number_table = Arc::new(MemTable::new(recordbatch)); + let odd_number_table = Arc::new(MemTable::new("odd_numbers", recordbatch)); schema_provider - .register_table("odd_numbers".to_string(), odd_number_table) + .register_table(odd_number_table.table_name().to_string(), odd_number_table) .unwrap(); // create table with floating numbers @@ -212,9 +215,12 @@ fn create_query_engine() -> Arc { let columns = vec![f32_numbers, f64_numbers]; let schema = Arc::new(Schema::new(column_schemas)); let recordbatch = RecordBatch::new(schema, columns).unwrap(); - let float_number_table = Arc::new(MemTable::new(recordbatch)); + let float_number_table = Arc::new(MemTable::new("float_numbers", recordbatch)); schema_provider - .register_table("float_numbers".to_string(), float_number_table) + .register_table( + float_number_table.table_name().to_string(), + float_number_table, + ) .unwrap(); catalog_provider.register_schema(DEFAULT_SCHEMA_NAME.to_string(), schema_provider); diff --git a/test-util/src/memtable.rs b/test-util/src/memtable.rs index 08066eb3ff..72a3d43340 100644 --- a/test-util/src/memtable.rs +++ b/test-util/src/memtable.rs @@ -6,7 +6,9 @@ use async_trait::async_trait; use common_query::prelude::Expr; use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStream}; -use datatypes::schema::{Schema, SchemaRef}; +use datatypes::prelude::*; +use datatypes::schema::{ColumnSchema, Schema, SchemaRef}; +use datatypes::vectors::UInt32Vector; use futures::task::{Context, Poll}; use futures::Stream; use snafu::prelude::*; @@ -15,12 +17,36 @@ use table::Table; #[derive(Debug, Clone)] pub struct MemTable { + table_name: String, recordbatch: RecordBatch, } impl MemTable { - pub fn new(recordbatch: RecordBatch) -> Self { - Self { recordbatch } + pub fn new(table_name: impl Into, recordbatch: RecordBatch) -> Self { + Self { + table_name: table_name.into(), + recordbatch, + } + } + + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Creates a 1 column 100 rows table, with table name "numbers", column name "uint32s" and + /// column type "uint32". Column data increased from 0 to 100. + pub fn default_numbers_table() -> Self { + let column_schemas = vec![ColumnSchema::new( + "uint32s", + ConcreteDataType::uint32_datatype(), + true, + )]; + let schema = Arc::new(Schema::new(column_schemas)); + let columns: Vec = vec![Arc::new(UInt32Vector::from_slice( + (0..100).collect::>(), + ))]; + let recordbatch = RecordBatch::new(schema, columns).unwrap(); + MemTable::new("numbers", recordbatch) } } @@ -167,6 +193,6 @@ mod test { ])), ]; let recordbatch = RecordBatch::new(schema, columns).unwrap(); - MemTable::new(recordbatch) + MemTable::new("", recordbatch) } }