From 0f6b8ff81595d15801d31dbc794517b5ba7df60a Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Thu, 26 Feb 2026 11:41:33 +0800 Subject: [PATCH] feat: implement postgres copy to stdout (#7709) * feat: update pgwire * feat: add special parser for copy to stdout * feat: implement copy to stdout * fix: improve code * fix: expect optional with * fix: lint * feat: correct encoder using and refactor * chore: fmt * refactor: update api * chore: use released dependencies * fix: update datafusion-pg-catalog to support schema query * fix: support for double quoted identifier * feat: update datafusion-postgres to support schema.table * refactor: use pgsqlparser container * refactor: remove unquote which is no longer needed * fix: correctly handle invalid query * fix: correct handle null in nano timestamp * test: add a new test for additional close ) --- Cargo.lock | 281 +++++++++++++-- Cargo.toml | 3 +- src/servers/Cargo.toml | 4 +- src/servers/src/postgres/handler.rs | 274 ++++++++++++-- src/servers/src/postgres/types.rs | 208 +++++++---- src/sql/src/parsers/copy_parser.rs | 334 ++++++++++++++++-- .../common/system/pg_catalog.result | 29 ++ .../standalone/common/system/pg_catalog.sql | 17 + 8 files changed, 986 insertions(+), 164 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c52a4565a2..a8af990b36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,7 +41,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -642,9 +642,9 @@ dependencies = [ [[package]] name = "arrow-pg" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87bc2eb53228ffb0cffff4a8a99d5311641b6d8ce63ec48b860dab70ec01ae1f" +checksum = "978af69bfebf96147f743e5ad50f68c058ec80593c156723d94775a362d3dc42" dependencies = [ "arrow 57.0.0", "arrow-schema 57.0.0", @@ -1816,7 +1816,18 @@ checksum = 
"c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", +] + +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.0", ] [[package]] @@ -1826,7 +1837,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -2211,7 +2222,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81" dependencies = [ "serde", "termcolor", - "unicode-width 0.2.1", + "unicode-width 0.1.14", ] [[package]] @@ -2222,7 +2233,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681" dependencies = [ "serde", "termcolor", - "unicode-width 0.2.1", + "unicode-width 0.1.14", ] [[package]] @@ -2238,7 +2249,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" dependencies = [ "lazy_static", - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] @@ -3049,7 +3060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3bb320cac8a0750d7f25280aa97b09c26edfe161164238ecbbb31092b079e735" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "proptest", "serde_core", ] @@ -3186,6 +3197,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc" version = "3.3.0" @@ -4115,7 +4135,7 @@ dependencies = [ [[package]] 
name = "datafusion-pg-catalog" version = "0.13.1" -source = "git+https://github.com/GreptimeTeam/datafusion-postgres.git?rev=74ac8e2806be6de91ff192b97f64735392539d16#74ac8e2806be6de91ff192b97f64735392539d16" +source = "git+https://github.com/GreptimeTeam/datafusion-postgres.git?rev=f675927a79cd714a8eeb438b0d3015cd54d4e60a#f675927a79cd714a8eeb438b0d3015cd54d4e60a" dependencies = [ "async-trait", "datafusion", @@ -4909,7 +4929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -5717,6 +5737,20 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "rand_core 0.10.0", + "wasip2", + "wasip3", +] + [[package]] name = "gimli" version = "0.31.1" @@ -6273,7 +6307,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.0", + "socket2 0.5.10", "tokio", "tower-service", "tracing", @@ -6344,7 +6378,7 @@ dependencies = [ "js-sys", "log", "wasm-bindgen", - "windows-core 0.61.2", + "windows-core 0.57.0", ] [[package]] @@ -6442,6 +6476,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "ident_case" version = "1.0.1" @@ -6582,6 +6622,8 @@ checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" dependencies = [ "equivalent", "hashbrown 0.16.0", + "serde", + "serde_core", ] [[package]] @@ -6758,7 +6800,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi 0.5.2", 
"libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -6987,7 +7029,7 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -7218,6 +7260,12 @@ dependencies = [ "spin", ] +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "levenshtein_automata" version = "0.2.1" @@ -7341,7 +7389,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -8731,7 +8779,7 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d" dependencies = [ - "proc-macro-crate 3.3.0", + "proc-macro-crate 1.3.1", "proc-macro2", "quote", "syn 2.0.114", @@ -9653,9 +9701,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.37.3" +version = "0.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fcd410bc6990bd8d20b3fe3cd879a3c3ec250bdb1cb12537b528818823b02c9" +checksum = "89d5e5a60d3f6e40c91f6a2a7f8d09665e636272bd5611977253559b6651aabb" dependencies = [ "async-trait", "base64 0.22.1", @@ -9668,7 +9716,7 @@ dependencies = [ "md5", "pg_interval_2", "postgres-types", - "rand 0.9.1", + "rand 0.10.0", "ring", "rust_decimal", "rustls-pki-types", @@ -9973,7 +10021,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", "opaque-debug", 
"universal-hash", ] @@ -10364,7 +10412,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck 0.5.0", + "heck 0.4.1", "itertools 0.12.1", "log", "multimap", @@ -10384,8 +10432,8 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ - "heck 0.5.0", - "itertools 0.14.0", + "heck 0.4.1", + "itertools 0.10.5", "log", "multimap", "once_cell", @@ -10402,8 +10450,8 @@ version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" dependencies = [ - "heck 0.5.0", - "itertools 0.14.0", + "heck 0.4.1", + "itertools 0.10.5", "log", "multimap", "once_cell", @@ -10451,7 +10499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.114", @@ -10464,7 +10512,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" dependencies = [ "anyhow", - "itertools 0.14.0", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.114", @@ -10851,7 +10899,7 @@ dependencies = [ "once_cell", "socket2 0.5.10", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -10942,6 +10990,17 @@ dependencies = [ "rand_core 0.9.3", ] +[[package]] +name = "rand" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +dependencies = [ + "chacha20 0.10.0", + "getrandom 0.4.1", + "rand_core 0.10.0", +] + [[package]] name = "rand_chacha" version = "0.3.1" 
@@ -10981,6 +11040,12 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rand_core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" + [[package]] name = "rand_distr" version = "0.4.3" @@ -11629,7 +11694,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -11642,7 +11707,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -12305,7 +12370,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -12316,7 +12381,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -12327,7 +12392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -12545,7 +12610,7 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1961e2ef424c1424204d3a5d6975f934f56b6d50ff5732382d84ebf460e147f7" dependencies = [ - "heck 0.5.0", + "heck 0.4.1", "proc-macro2", "quote", "syn 2.0.114", @@ -12958,7 +13023,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -13582,7 +13647,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix 1.0.7", - "windows-sys 0.61.2", + "windows-sys 0.52.0", ] [[package]] @@ -15077,6 +15142,24 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasip2" +version = 
"1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasite" version = "0.1.0" @@ -15154,6 +15237,28 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap 2.12.0", + "wasm-encoder", + "wasmparser", +] + [[package]] name = "wasm-streams" version = "0.4.2" @@ -15167,6 +15272,18 @@ dependencies = [ "web-sys", ] +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags 2.9.1", + "hashbrown 0.15.4", + "indexmap 2.12.0", + "semver", +] + [[package]] name = "web-sys" version = "0.3.77" @@ -15258,7 +15375,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.48.0", ] [[package]] @@ -15621,6 +15738,26 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck 0.5.0", + "wit-parser", +] + [[package]] name = "wit-bindgen-rt" version = "0.39.0" @@ -15630,6 +15767,74 @@ dependencies = [ "bitflags 2.9.1", ] +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck 0.5.0", + "indexmap 2.12.0", + "prettyplease", + "syn 2.0.114", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.114", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags 2.9.1", + "indexmap 2.12.0", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap 2.12.0", + "log", + "semver", + "serde", 
+ "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "wkt" version = "0.11.1" diff --git a/Cargo.toml b/Cargo.toml index ca1b30c862..2114e02863 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -340,12 +340,11 @@ datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git", datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } datafusion-physical-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } -datafusion-pg-catalog = { git = "https://github.com/GreptimeTeam/datafusion-postgres.git", rev = "74ac8e2806be6de91ff192b97f64735392539d16" } +datafusion-pg-catalog = { git = "https://github.com/GreptimeTeam/datafusion-postgres.git", rev = "f675927a79cd714a8eeb438b0d3015cd54d4e60a" } datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" } sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "d7d95a44889e099e32d78e9bad9bc00598faef28" } # on branch v0.59.x - [profile.release] debug = 1 diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 428a96e15b..0bab854dc5 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -23,7 +23,7 @@ api.workspace = true arrow.workspace = true arrow-flight.workspace = true arrow-ipc.workspace = true -arrow-pg = "0.11" +arrow-pg = "0.12" arrow-schema.workspace = true async-trait.workspace = true auth.workspace = true @@ -89,7 +89,7 
@@ operator.workspace = true otel-arrow-rust.workspace = true parking_lot.workspace = true pg_interval = { version = "0.5.2", package = "pg_interval_2" } -pgwire = { version = "0.37.3", default-features = false, features = [ +pgwire = { version = "0.38", default-features = false, features = [ "server-api-ring", "pg-ext-types", ] } diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 56ca14e85d..daf4bfc646 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -13,6 +13,7 @@ // limitations under the License. use std::fmt::Debug; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; @@ -20,6 +21,7 @@ use common_query::{Output, OutputData}; use common_recordbatch::RecordBatch; use common_recordbatch::error::Result as RecordBatchResult; use common_telemetry::{debug, tracing}; +use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement}; use datafusion_common::ParamValues; use datafusion_pg_catalog::sql::PostgresCompatibilityParser; use datatypes::prelude::ConcreteDataType; @@ -28,12 +30,15 @@ use futures::{Sink, SinkExt, Stream, StreamExt, future, stream}; use pgwire::api::portal::{Format, Portal}; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{ + CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder, DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag, }; use pgwire::api::stmt::{QueryParser, StoredStatement}; use pgwire::api::{ClientInfo, ErrorHandler, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::PgWireBackendMessage; +use pgwire::messages::copy::CopyData; +use pgwire::messages::data::DataRow; use query::planner::DfLogicalPlanner; use query::query_engine::DescribeResult; use session::Session; @@ -70,7 +75,9 @@ impl SimpleQueryHandler for PostgresServerHandlerInner { return 
Ok(vec![Response::EmptyQuery]); } - let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) { + let parsed_query = self.query_parser.compatibility_parser.parse(query); + + let query = if let Ok(statements) = &parsed_query { statements .iter() .map(|s| s.to_string()) @@ -88,9 +95,17 @@ impl SimpleQueryHandler for PostgresServerHandlerInner { let mut results = Vec::with_capacity(outputs.len()); - for output in outputs { - let resp = - output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?; + let statements = parsed_query.ok(); + for (idx, output) in outputs.into_iter().enumerate() { + let copy_format = statements + .as_ref() + .and_then(|stmts| stmts.get(idx)) + .and_then(check_copy_to_stdout); + let resp = if let Some(format) = &copy_format { + output_to_copy_response(query_ctx.clone(), output, format)? + } else { + output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)? + }; results.push(resp); } @@ -150,6 +165,8 @@ pub(crate) fn output_to_query_response( } } +type RowStream = Pin> + Send + Unpin>>; + fn recordbatches_to_query_response( query_ctx: QueryContextRef, recordbatches_stream: S, @@ -163,18 +180,24 @@ where let pg_schema = Arc::new( schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?, ); - let pg_schema_ref = pg_schema.clone(); - let data_row_stream = recordbatches_stream - .map(move |result| match result { - Ok(record_batch) => stream::iter(RecordBatchRowIterator::new( - query_ctx.clone(), - pg_schema_ref.clone(), - record_batch, - )) - .boxed(), - Err(e) => stream::once(future::err(convert_err(e))).boxed(), - }) - .flatten(); + + let encoder = DataRowEncoder::new(pg_schema.clone()); + let row_stream = RecordBatchRowStream::new( + query_ctx.clone(), + pg_schema.clone(), + schema.clone(), + recordbatches_stream, + encoder, + ); + + let data_row_stream: RowStream = Box::pin( + row_stream + .map(move |result| match result { + Ok(rows) => 
Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream, + Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream, + }) + .flatten(), + ); Ok(Response::Query(QueryResponse::new( pg_schema, @@ -182,6 +205,82 @@ where ))) } +pub(crate) fn output_to_copy_response( + query_ctx: QueryContextRef, + output: Result, + format: &str, +) -> PgWireResult { + match output { + Ok(o) => match o.data { + OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "ERROR".to_string(), + "42601".to_string(), + "COPY cannot be used with non-query statements".to_string(), + )))), + OutputData::Stream(record_stream) => { + let schema = record_stream.schema(); + recordbatches_to_copy_response(query_ctx, record_stream, schema, format) + } + OutputData::RecordBatches(recordbatches) => { + let schema = recordbatches.schema(); + recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format) + } + }, + Err(e) => Err(convert_err(e)), + } +} + +fn recordbatches_to_copy_response( + query_ctx: QueryContextRef, + recordbatches_stream: S, + schema: SchemaRef, + format: &str, +) -> PgWireResult +where + S: Stream> + Send + Unpin + 'static, +{ + let format_options = format_options_from_query_ctx(&query_ctx); + let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options)) + .map_err(convert_err)?; + + let copy_format = match format.to_lowercase().as_str() { + "binary" => 1, + _ => 0, + }; + + let pg_schema = Arc::new(pg_fields); + let num_columns = pg_schema.len(); + + let copy_encoder = match format.to_lowercase().as_str() { + "csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()), + "binary" => CopyEncoder::new_binary(pg_schema.clone()), + _ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()), + }; + + let row_stream = RecordBatchRowStream::new( + query_ctx.clone(), + pg_schema.clone(), + schema.clone(), + recordbatches_stream, + copy_encoder, + ); + + let 
copy_stream: RowStream = Box::pin( + row_stream + .map(move |result| match result { + Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream, + Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream, + }) + .flatten(), + ); + + Ok(Response::CopyOut(CopyResponse::new( + copy_format, + num_columns, + copy_stream, + ))) +} + pub struct DefaultQueryParser { query_handler: ServerSqlQueryHandlerRef, session: Arc, @@ -198,9 +297,16 @@ impl DefaultQueryParser { } } +/// A container type of parse result types +#[derive(Clone, Debug)] +pub struct PgSqlPlan { + plan: SqlPlan, + copy_to_stdout_format: Option, +} + #[async_trait] impl QueryParser for DefaultQueryParser { - type Statement = SqlPlan; + type Statement = PgSqlPlan; async fn parse_sql( &self, @@ -213,20 +319,26 @@ impl QueryParser for DefaultQueryParser { // do not parse if query is empty or matches rules if sql.is_empty() || fixtures::matches(sql) { - return Ok(SqlPlan { - query: sql.to_owned(), - statement: None, - plan: None, - schema: None, + return Ok(PgSqlPlan { + plan: SqlPlan { + query: sql.to_owned(), + statement: None, + plan: None, + schema: None, + }, + copy_to_stdout_format: None, }); } - let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) { - statements.remove(0).to_string() + let parsed_statements = self.compatibility_parser.parse(sql); + let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements { + let first_stmt = statements.remove(0); + let format = check_copy_to_stdout(&first_stmt); + (first_stmt.to_string(), format) } else { // bypass the error: it can run into error because of different // versions of sqlparser - sql.to_string() + (sql.to_string(), None) }; let mut stmts = ParserContext::create_with_dialect( @@ -258,11 +370,14 @@ impl QueryParser for DefaultQueryParser { (None, None) }; - Ok(SqlPlan { - query: sql.clone(), - statement: Some(stmt), - plan, - schema, + Ok(PgSqlPlan { + plan: SqlPlan { + query: 
sql.clone(), + statement: Some(stmt), + plan, + schema, + }, + copy_to_stdout_format, }) } } @@ -290,7 +405,7 @@ impl QueryParser for DefaultQueryParser { #[async_trait] impl ExtendedQueryHandler for PostgresServerHandlerInner { - type Statement = SqlPlan; + type Statement = PgSqlPlan; type QueryParser = DefaultQueryParser; fn query_parser(&self) -> Arc { @@ -314,7 +429,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { .with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()]) .start_timer(); - let sql_plan = &portal.statement.statement; + let pg_sql_plan = &portal.statement.statement; + let sql_plan = &pg_sql_plan.plan; if sql_plan.query.is_empty() { // early return if query is empty @@ -355,7 +471,12 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { }; send_warning_opt(client, query_ctx.clone()).await?; - output_to_query_response(query_ctx, output, &portal.result_column_format) + + if let Some(format) = &pg_sql_plan.copy_to_stdout_format { + output_to_copy_response(query_ctx, output, format) + } else { + output_to_query_response(query_ctx, output, &portal.result_column_format) + } } async fn do_describe_statement( @@ -366,7 +487,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { where C: ClientInfo + Unpin + Send + Sync, { - let sql_plan = &stmt.statement; + let sql_plan = &stmt.statement.plan; // client provided parameter types, can be empty if client doesn't try to parse statement let provided_param_types = &stmt.parameter_types; let server_inferenced_types = if let Some(plan) = &sql_plan.plan { @@ -434,7 +555,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner { where C: ClientInfo + Unpin + Send + Sync, { - let sql_plan = &portal.statement.statement; + let sql_plan = &portal.statement.statement.plan; let format = &portal.result_column_format; match sql_plan.statement.as_ref() { @@ -510,3 +631,86 @@ impl ErrorHandler for PostgresServerHandlerInner { debug!("Postgres interface error {}", 
error) } } + +fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option { + if let SqlParserStatement::Copy { + target, options, .. + } = statement + && matches!(target, CopyTarget::Stdout) + { + for opt in options { + if let CopyOption::Format(format_ident) = opt { + return Some(format_ident.value.to_lowercase()); + } + } + return Some("txt".to_string()); + } + + None +} + +#[cfg(test)] +mod tests { + use datafusion_pg_catalog::sql::PostgresCompatibilityParser; + + use super::*; + + fn parse_copy_statement(sql: &str) -> SqlParserStatement { + let parser = PostgresCompatibilityParser::new(); + let statements = parser.parse(sql).unwrap(); + statements.into_iter().next().unwrap() + } + + #[test] + fn test_check_copy_out_with_csv_format() { + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)"); + assert_eq!(check_copy_to_stdout(&statement), Some("csv".to_string())); + } + + #[test] + fn test_check_copy_out_with_txt_format() { + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)"); + assert_eq!(check_copy_to_stdout(&statement), Some("txt".to_string())); + } + + #[test] + fn test_check_copy_out_with_binary_format() { + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)"); + assert_eq!(check_copy_to_stdout(&statement), Some("binary".to_string())); + } + + #[test] + fn test_check_copy_out_without_format() { + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT"); + assert_eq!(check_copy_to_stdout(&statement), Some("txt".to_string())); + } + + #[test] + fn test_check_copy_out_to_file() { + let statement = + parse_copy_statement("COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)"); + assert_eq!(check_copy_to_stdout(&statement), None); + } + + #[test] + fn test_check_copy_out_case_insensitive() { + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)"); + assert_eq!(check_copy_to_stdout(&statement), Some("csv".to_string())); 
+ + let statement = parse_copy_statement("COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)"); + assert_eq!(check_copy_to_stdout(&statement), Some("binary".to_string())); + } + + #[test] + fn test_check_copy_out_with_multiple_options() { + let statement = parse_copy_statement( + "COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)", + ); + assert_eq!(check_copy_to_stdout(&statement), Some("csv".to_string())); + + let statement = parse_copy_statement( + "COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)", + ); + assert_eq!(check_copy_to_stdout(&statement), Some("binary".to_string())); + } +} diff --git a/src/servers/src/postgres/types.rs b/src/servers/src/postgres/types.rs index 0b76819bc9..a95890e78c 100644 --- a/src/servers/src/postgres/types.rs +++ b/src/servers/src/postgres/types.rs @@ -15,14 +15,17 @@ mod error; use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use arrow::array::{Array, AsArray}; -use arrow_pg::encoder::encode_value; +use arrow_pg::encoder::{Encoder, encode_value}; use arrow_pg::list_encoder::encode_list; use arrow_schema::{DataType, TimeUnit}; use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime}; use common_recordbatch::RecordBatch; +use common_recordbatch::error::Result as RecordBatchResult; use common_time::{IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth}; use datafusion_common::ScalarValue; use datafusion_expr::LogicalPlan; @@ -32,20 +35,20 @@ use datatypes::prelude::{ConcreteDataType, Value}; use datatypes::schema::{Schema, SchemaRef}; use datatypes::types::{IntervalType, TimestampType, jsonb_to_string}; use datatypes::value::StructValue; +use futures::Stream; use pg_interval::Interval as PgInterval; use pgwire::api::Type; use pgwire::api::portal::{Format, Portal}; -use pgwire::api::results::{DataRowEncoder, FieldInfo}; +use pgwire::api::results::FieldInfo; use pgwire::error::{PgWireError, PgWireResult}; -use 
pgwire::messages::data::DataRow; use pgwire::types::format::FormatOptions as PgFormatOptions; use query::planner::DfLogicalPlanner; use session::context::QueryContextRef; use snafu::ResultExt; pub use self::error::{PgErrorCode, PgErrorSeverity}; -use crate::SqlPlan; use crate::error::{self as server_error, InferParameterTypesSnafu, Result}; +use crate::postgres::handler::PgSqlPlan; use crate::postgres::utils::convert_err; pub(super) fn schema_to_pg( @@ -82,78 +85,125 @@ pub(super) fn schema_to_pg( /// there are alternatives like records, arrays, etc. but there are also limitations: /// records: there is no support for include keys /// arrays: element in array must be the same type -fn encode_struct( +fn encode_struct( _query_ctx: &QueryContextRef, struct_value: StructValue, - builder: &mut DataRowEncoder, + builder: &mut S, + pg_field: &FieldInfo, ) -> PgWireResult<()> { let encoding_setting = JsonStructureSettings::Structured(None); let json_value = encoding_setting .decode(Value::Struct(struct_value)) .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - builder.encode_field(&json_value) + builder.encode_field(&json_value, pg_field) } -pub(crate) struct RecordBatchRowIterator { +pub(crate) struct RecordBatchRowStream +where + S: Encoder, + B: Stream>, +{ query_ctx: QueryContextRef, pg_schema: Arc>, schema: SchemaRef, - record_batch: arrow::record_batch::RecordBatch, - i: usize, + record_batches: Pin>, + encoder: S, } -impl Iterator for RecordBatchRowIterator { - type Item = PgWireResult; +impl Stream for RecordBatchRowStream +where + S: Encoder + Unpin, + B: Stream>, +{ + type Item = PgWireResult>; - fn next(&mut self) -> Option { - let mut encoder = DataRowEncoder::new(self.pg_schema.clone()); - if self.i < self.record_batch.num_rows() { - if let Err(e) = self.encode_row(self.i, &mut encoder) { - return Some(Err(e)); + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.record_batches.as_mut().poll_next(cx) { + 
Poll::Ready(Some(Ok(batch))) => { + let record_batch = batch.into_df_record_batch(); + let num_rows = record_batch.num_rows(); + + if num_rows == 0 { + return Poll::Ready(Some(Ok(vec![]))); + } + + let arrow_schema = record_batch.schema(); + let query_ctx = self.query_ctx.clone(); + let pg_schema = self.pg_schema.clone(); + let schema = self.schema.clone(); + let mut results = Vec::with_capacity(num_rows); + + for i in 0..num_rows { + if let Err(e) = Self::encode_row( + &query_ctx, + &pg_schema, + &schema, + arrow_schema.as_ref(), + &mut self.encoder, + &record_batch, + i, + ) { + return Poll::Ready(Some(Err(e))); + } + results.push(self.encoder.take_row()); + } + + Poll::Ready(Some(Ok(results))) } - self.i += 1; - Some(Ok(encoder.take_row())) - } else { - None + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(convert_err(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } -impl RecordBatchRowIterator { +impl RecordBatchRowStream +where + S: Encoder, + B: Stream>, +{ pub(crate) fn new( query_ctx: QueryContextRef, pg_schema: Arc>, - record_batch: RecordBatch, + schema: SchemaRef, + record_batches: B, + encoder: S, ) -> Self { - let schema = record_batch.schema.clone(); - let record_batch = record_batch.into_df_record_batch(); Self { query_ctx, pg_schema, schema, - record_batch, - i: 0, + record_batches: Box::pin(record_batches), + encoder, } } - fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> { - let arrow_schema = self.record_batch.schema(); - for (j, column) in self.record_batch.columns().iter().enumerate() { + fn encode_row( + query_ctx: &QueryContextRef, + pg_schema: &Arc>, + schema: &SchemaRef, + arrow_schema: &arrow::datatypes::Schema, + encoder: &mut S, + record_batch: &arrow::record_batch::RecordBatch, + i: usize, + ) -> PgWireResult<()> { + for (j, column) in record_batch.columns().iter().enumerate() { + let pg_field = &pg_schema[j]; + if column.is_null(i) { - 
encoder.encode_field(&None::<&i8>)?; + encoder.encode_field(&None::<&i8>, pg_field)?; continue; } - let pg_field = &self.pg_schema[j]; match column.data_type() { // these types are greptimedb specific or custom DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { // jsonb - if let ConcreteDataType::Json(_) = &self.schema.column_schemas()[j].data_type { + if let ConcreteDataType::Json(_) = &schema.column_schemas()[j].data_type { let v = datatypes::arrow_array::binary_array_value(column, i); let s = jsonb_to_string(v).map_err(convert_err)?; - encoder.encode_field(&s)?; + encoder.encode_field(&s, pg_field)?; } else { // bytea let arrow_field = arrow_schema.field(j); @@ -168,7 +218,7 @@ impl RecordBatchRowIterator { encode_list(encoder, items, pg_field)?; } DataType::Struct(_) => { - encode_struct(&self.query_ctx, Default::default(), encoder)?; + encode_struct(query_ctx, Default::default(), encoder, pg_field)?; } _ => { // Encode value using arrow-pg @@ -277,7 +327,7 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result { } } -pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWireResult { +pub(super) fn parameter_to_string(portal: &Portal, idx: usize) -> PgWireResult { // the index is managed from portal's parameters count so it's safe to // unwrap here. 
let param_type = portal @@ -359,7 +409,7 @@ where pub(super) fn parameters_to_scalar_values( plan: &LogicalPlan, - portal: &Portal, + portal: &Portal, ) -> PgWireResult> { let param_count = portal.parameter_len(); let mut results = Vec::with_capacity(param_count); @@ -761,7 +811,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT2_ARRAY => { - let data = portal.parameter::>(idx, &client_type)?; + let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { let values = data.into_iter().map(|i| i.into()).collect::>(); ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true)) @@ -770,7 +820,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT4_ARRAY => { - let data = portal.parameter::>(idx, &client_type)?; + let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { let values = data.into_iter().map(|i| i.into()).collect::>(); ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true)) @@ -779,7 +829,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::INT8_ARRAY => { - let data = portal.parameter::>(idx, &client_type)?; + let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { let values = data.into_iter().map(|i| i.into()).collect::>(); ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true)) @@ -788,7 +838,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::VARCHAR_ARRAY => { - let data = portal.parameter::>(idx, &client_type)?; + let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { let values = data.into_iter().map(|i| i.into()).collect::>(); ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true)) @@ -797,7 +847,7 @@ pub(super) fn parameters_to_scalar_values( } } &Type::TIMESTAMP_ARRAY => { - let data = portal.parameter::>(idx, &client_type)?; + let data = portal.parameter::>>(idx, &client_type)?; if let Some(data) = data { if let 
Some(ConcreteDataType::List(list_type)) = &server_type { match list_type.item_type() { @@ -807,7 +857,7 @@ pub(super) fn parameters_to_scalar_values( .into_iter() .map(|ts| { ScalarValue::TimestampSecond( - Some(ts.and_utc().timestamp()), + ts.map(|ts| ts.and_utc().timestamp()), None, ) }) @@ -823,7 +873,7 @@ pub(super) fn parameters_to_scalar_values( .into_iter() .map(|ts| { ScalarValue::TimestampMillisecond( - Some(ts.and_utc().timestamp_millis()), + ts.map(|ts| ts.and_utc().timestamp_millis()), None, ) }) @@ -839,7 +889,7 @@ pub(super) fn parameters_to_scalar_values( .into_iter() .map(|ts| { ScalarValue::TimestampMicrosecond( - Some(ts.and_utc().timestamp_micros()), + ts.map(|ts| ts.and_utc().timestamp_micros()), None, ) }) @@ -854,8 +904,13 @@ pub(super) fn parameters_to_scalar_values( let values = data .into_iter() .filter_map(|ts| { - ts.and_utc().timestamp_nanos_opt().map(|nanos| { - ScalarValue::TimestampNanosecond(Some(nanos), None) + ts.and_then(|ts| { + ts.and_utc().timestamp_nanos_opt().map(|nanos| { + ScalarValue::TimestampNanosecond( + Some(nanos), + None, + ) + }) }) }) .collect::>(); @@ -878,12 +933,11 @@ pub(super) fn parameters_to_scalar_values( } } } else { - // Default to millisecond when no server type is specified let values = data .into_iter() .map(|ts| { ScalarValue::TimestampMillisecond( - Some(ts.and_utc().timestamp_millis()), + ts.map(|ts| ts.and_utc().timestamp_millis()), None, ) }) @@ -899,7 +953,8 @@ pub(super) fn parameters_to_scalar_values( } } &Type::TIMESTAMPTZ_ARRAY => { - let data = portal.parameter::>>(idx, &client_type)?; + let data = + portal.parameter::>>>(idx, &client_type)?; if let Some(data) = data { if let Some(ConcreteDataType::List(list_type)) = &server_type { match list_type.item_type() { @@ -908,7 +963,10 @@ pub(super) fn parameters_to_scalar_values( let values = data .into_iter() .map(|ts| { - ScalarValue::TimestampSecond(Some(ts.timestamp()), None) + ScalarValue::TimestampSecond( + ts.map(|ts| ts.timestamp()), + 
None, + ) }) .collect::>(); ScalarValue::List(ScalarValue::new_list( @@ -922,7 +980,7 @@ pub(super) fn parameters_to_scalar_values( .into_iter() .map(|ts| { ScalarValue::TimestampMillisecond( - Some(ts.timestamp_millis()), + ts.map(|ts| ts.timestamp_millis()), None, ) }) @@ -938,7 +996,7 @@ pub(super) fn parameters_to_scalar_values( .into_iter() .map(|ts| { ScalarValue::TimestampMicrosecond( - Some(ts.timestamp_micros()), + ts.map(|ts| ts.timestamp_micros()), None, ) }) @@ -952,10 +1010,11 @@ pub(super) fn parameters_to_scalar_values( TimestampType::Nanosecond(_) => { let values = data .into_iter() - .filter_map(|ts| { - ts.timestamp_nanos_opt().map(|nanos| { - ScalarValue::TimestampNanosecond(Some(nanos), None) - }) + .map(|ts| { + ScalarValue::TimestampNanosecond( + ts.and_then(|ts| ts.timestamp_nanos_opt()), + None, + ) }) .collect::>(); ScalarValue::List(ScalarValue::new_list( @@ -977,11 +1036,13 @@ pub(super) fn parameters_to_scalar_values( } } } else { - // Default to millisecond when no server type is specified let values = data .into_iter() .map(|ts| { - ScalarValue::TimestampMillisecond(Some(ts.timestamp_millis()), None) + ScalarValue::TimestampMillisecond( + ts.map(|ts| ts.timestamp_millis()), + None, + ) }) .collect::>(); ScalarValue::List(ScalarValue::new_list( @@ -1050,8 +1111,9 @@ mod test { IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector, TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef, }; + use futures::{StreamExt as FuturesStreamExt, stream}; use pgwire::api::Type; - use pgwire::api::results::{FieldFormat, FieldInfo}; + use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo}; use session::context::QueryContextBuilder; use super::*; @@ -1155,7 +1217,7 @@ mod test { #[test] fn test_encode_text_format_data() { - let schema = vec![ + let pg_schema = vec![ FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text), FieldInfo::new("bools".into(), None, 
None, Type::BOOL, FieldFormat::Text), FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text), @@ -1416,14 +1478,28 @@ mod test { .configuration_parameter(Default::default()) .build() .into(); - let schema = Arc::new(schema); + let schema = record_batch.schema.clone(); + let pg_schema_ref = Arc::new(pg_schema); - let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch) - .filter_map(|x| x.ok()) - .collect::>(); + let encoder = DataRowEncoder::new(pg_schema_ref.clone()); + + let row_stream = RecordBatchRowStream::new( + query_context, + pg_schema_ref.clone(), + schema, + stream::once(async { Ok(record_batch) }), + encoder, + ); + + let rows: Vec<_> = futures::executor::block_on( + row_stream + .filter_map(|x: PgWireResult<_>| async move { x.ok() }) + .flat_map(stream::iter) + .collect::>(), + ); assert_eq!(rows.len(), 3); for row in rows { - assert_eq!(row.field_count, schema.len() as i16); + assert_eq!(row.field_count, pg_schema_ref.len() as i16); } } diff --git a/src/sql/src/parsers/copy_parser.rs b/src/sql/src/parsers/copy_parser.rs index 892992d310..d975d884f6 100644 --- a/src/sql/src/parsers/copy_parser.rs +++ b/src/sql/src/parsers/copy_parser.rs @@ -32,22 +32,34 @@ impl ParserContext<'_> { pub(crate) fn parse_copy(&mut self) -> Result { let _ = self.parser.next_token(); let next = self.parser.peek_token(); - let copy = if next.token == Token::LParen { + if next.token == Token::LParen { let copy_query = self.parse_copy_query_to()?; - - crate::statements::copy::Copy::CopyQueryTo(copy_query) + // the COPY ... 
TO STDOUT is a special case for postgres wire protocol + // the logic is completely identical to query, but with an alternative data encoding on transport + // + // so at the query engine level, we simple parse the command as it's inner query + // we will deal with the encoding and its format options in server/src/postgres/handler.rs + if copy_query.arg.location == "STDOUT" { + Ok(*copy_query.query) + } else { + Ok(Statement::Copy(crate::statements::copy::Copy::CopyQueryTo( + copy_query, + ))) + } } else if let Word(word) = next.token && word.keyword == Keyword::DATABASE { let _ = self.parser.next_token(); let copy_database = self.parser_copy_database()?; - crate::statements::copy::Copy::CopyDatabase(copy_database) + Ok(Statement::Copy( + crate::statements::copy::Copy::CopyDatabase(copy_database), + )) } else { let copy_table = self.parse_copy_table()?; - crate::statements::copy::Copy::CopyTable(copy_table) - }; - - Ok(Statement::Copy(copy)) + Ok(Statement::Copy(crate::statements::copy::Copy::CopyTable( + copy_table, + ))) + } } fn parser_copy_database(&mut self) -> Result { @@ -147,21 +159,62 @@ impl ParserContext<'_> { self.parser .expect_keyword(Keyword::TO) .context(error::SyntaxSnafu)?; - let (with, connection, location, limit) = self.parse_copy_parameters()?; - if limit.is_some() { - return error::InvalidSqlSnafu { - msg: "limit is not supported", + + if self.parser.parse_keyword(Keyword::STDOUT) { + // early return without parsing options + // we will deal with copy to stdout on postgres protocol layer + // consume [WITH] (...) 
options if present (they will be ignored) + // we support both "WITH (FORMAT binary)" and "(FORMAT binary)" + // for PostgreSQL compatibility + + // Check for optional WITH keyword or direct LParen (PostgreSQL syntax) + // Both "WITH (...)" and "(...)" are valid after STDOUT + let _ = self.parser.parse_keyword(Keyword::WITH); + if self.parser.peek_token().token == Token::LParen { + let _ = self.parser.next_token(); + // consume all tokens until we find matching RParen + let mut depth = 1; + while depth > 0 { + match self.parser.next_token().token { + Token::LParen => depth += 1, + Token::RParen => depth -= 1, + Token::EOF => { + return error::UnexpectedTokenSnafu { + expected: ")", + actual: "EOF", + } + .fail(); + } + _ => {} + } + } } - .fail(); + + Ok(CopyQueryTo { + query: Box::new(query), + arg: CopyQueryToArgument { + with: OptionMap::default(), + connection: OptionMap::default(), + location: "STDOUT".to_string(), + }, + }) + } else { + let (with, connection, location, limit) = self.parse_copy_parameters()?; + if limit.is_some() { + return error::InvalidSqlSnafu { + msg: "limit is not supported", + } + .fail(); + } + Ok(CopyQueryTo { + query: Box::new(query), + arg: CopyQueryToArgument { + with, + connection, + location, + }, + }) } - Ok(CopyQueryTo { - query: Box::new(query), - arg: CopyQueryToArgument { - with, - connection, - location, - }, - }) } fn parse_copy_parameters(&mut self) -> Result<(OptionMap, OptionMap, String, Option)> { @@ -540,4 +593,243 @@ mod tests { ) } } + + #[test] + fn test_copy_query_to_stdout() { + let sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO STDOUT WITH (FORMAT = 'csv')"; + let stmt = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap() + .pop() + .unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM tbl WHERE ts > 10", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + 
} + + #[test] + fn test_copy_query_to_stdout_without_format() { + let sql = "COPY (SELECT generate_series(1, 2), generate_series(2, 3)) TO STDOUT"; + let stmt = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()) + .unwrap() + .pop() + .unwrap(); + + let query_str = "SELECT generate_series(1, 2), generate_series(2, 3)"; + let expected_query = ParserContext::create_with_dialect( + query_str, + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_with_binary_format() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT binary)"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT WITH (FORMAT binary) should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_with_csv_format() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT csv)"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT WITH (FORMAT csv) should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_with_equals_format() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT = 'binary')"; + let 
result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT WITH (FORMAT = 'binary') should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_with_multiple_options() { + let sql = + "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT WITH multiple options should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_without_with_keyword() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT binary)"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT (FORMAT binary) without WITH keyword should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_without_with_csv_format() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv)"; + let result = + 
ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT (FORMAT csv) without WITH keyword should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_copy_query_to_stdout_without_with_multiple_options() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv, DELIMITER ',', HEADER)"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + if let Err(e) = &result { + panic!( + "COPY TO STDOUT (FORMAT csv, ...) without WITH keyword should parse without error, got: {:?}", + e + ); + } + + let stmt = result.unwrap().pop().unwrap(); + + let expected_query = ParserContext::create_with_dialect( + "SELECT * FROM test_table", + &GreptimeDbDialect {}, + ParseOptions::default(), + ) + .unwrap() + .remove(0); + + assert_eq!(&expected_query, &stmt); + } + + #[test] + fn test_invalid_copy_query() { + let sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + assert!(result.is_err()); + + let sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv))"; + let result = + ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default()); + + assert!(result.is_err()); + } } diff --git a/tests/cases/standalone/common/system/pg_catalog.result b/tests/cases/standalone/common/system/pg_catalog.result index 6960de5d2f..ef0452e316 100644 --- a/tests/cases/standalone/common/system/pg_catalog.result +++ b/tests/cases/standalone/common/system/pg_catalog.result @@ -1099,3 +1099,32 @@ WHERE oid = pg_my_temp_schema(); | OID| 
public | t | | +-------+--------------------+-------------------+---------+ +-- SQLNESS PROTOCOL POSTGRES +CREATE table foo +( + ts TIMESTAMP TIME INDEX, + log_data TEXT, + count_num BIGINT, +); + +Affected Rows: 0 + +-- SQLNESS PROTOCOL POSTGRES +SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN +pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN +pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND +cls.oid = 'foo'::regclass::oid ORDER BY attr.attnum; + ++-----------+----------+ +| attname | atttypid | ++-----------+----------+ +| ts | 1114 | +| log_data | 25 | +| count_num | 20 | ++-----------+----------+ + +-- SQLNESS PROTOCOL POSTGRES +DROP TABLE foo; + +Affected Rows: 0 + diff --git a/tests/cases/standalone/common/system/pg_catalog.sql b/tests/cases/standalone/common/system/pg_catalog.sql index 979d55b480..ad59da372c 100644 --- a/tests/cases/standalone/common/system/pg_catalog.sql +++ b/tests/cases/standalone/common/system/pg_catalog.sql @@ -258,3 +258,20 @@ oid ,nspname FROM pg_namespace WHERE oid = pg_my_temp_schema(); + +-- SQLNESS PROTOCOL POSTGRES +CREATE table foo +( + ts TIMESTAMP TIME INDEX, + log_data TEXT, + count_num BIGINT, +); + +-- SQLNESS PROTOCOL POSTGRES +SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN +pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN +pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND +cls.oid = 'foo'::regclass::oid ORDER BY attr.attnum; + +-- SQLNESS PROTOCOL POSTGRES +DROP TABLE foo;