feat: implement postgres copy to stdout (#7709)

* feat: update pgwire

* feat: add special parser for copy to stdout

* feat: implement copy to stdout

* fix: improve code

* fix: expect optional with

* fix: lint

* feat: correct encoder usage and refactor

* chore: fmt

* refactor: update api

* chore: use released dependencies

* fix: update datafusion-pg-catalog to support schema query

* fix: support for double quoted identifier

* feat: update datafusion-postgres to support schema.table

* refactor: use pgsqlparser container

* refactor: remove unquote which is no longer needed

* fix: correctly handle invalid query

* fix: correctly handle null in nano timestamp

* test: add a new test for an additional closing )
This commit is contained in:
Ning Sun
2026-02-26 11:41:33 +08:00
committed by GitHub
parent b0fb4abbdf
commit 0f6b8ff815
8 changed files with 986 additions and 164 deletions

281
Cargo.lock generated
View File

@@ -41,7 +41,7 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
"cpufeatures 0.2.17",
]
[[package]]
@@ -642,9 +642,9 @@ dependencies = [
[[package]]
name = "arrow-pg"
version = "0.11.0"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87bc2eb53228ffb0cffff4a8a99d5311641b6d8ce63ec48b860dab70ec01ae1f"
checksum = "978af69bfebf96147f743e5ad50f68c058ec80593c156723d94775a362d3dc42"
dependencies = [
"arrow 57.0.0",
"arrow-schema 57.0.0",
@@ -1816,7 +1816,18 @@ checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
"cpufeatures 0.2.17",
]
[[package]]
name = "chacha20"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
dependencies = [
"cfg-if",
"cpufeatures 0.3.0",
"rand_core 0.10.0",
]
[[package]]
@@ -1826,7 +1837,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35"
dependencies = [
"aead",
"chacha20",
"chacha20 0.9.1",
"cipher",
"poly1305",
"zeroize",
@@ -2211,7 +2222,7 @@ checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.2.1",
"unicode-width 0.1.14",
]
[[package]]
@@ -2222,7 +2233,7 @@ checksum = "af491d569909a7e4dee0ad7db7f5341fef5c614d5b8ec8cf765732aba3cff681"
dependencies = [
"serde",
"termcolor",
"unicode-width 0.2.1",
"unicode-width 0.1.14",
]
[[package]]
@@ -2238,7 +2249,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -3049,7 +3060,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3bb320cac8a0750d7f25280aa97b09c26edfe161164238ecbbb31092b079e735"
dependencies = [
"cfg-if",
"cpufeatures",
"cpufeatures 0.2.17",
"proptest",
"serde_core",
]
@@ -3186,6 +3197,15 @@ dependencies = [
"libc",
]
[[package]]
name = "cpufeatures"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
dependencies = [
"libc",
]
[[package]]
name = "crc"
version = "3.3.0"
@@ -4115,7 +4135,7 @@ dependencies = [
[[package]]
name = "datafusion-pg-catalog"
version = "0.13.1"
source = "git+https://github.com/GreptimeTeam/datafusion-postgres.git?rev=74ac8e2806be6de91ff192b97f64735392539d16#74ac8e2806be6de91ff192b97f64735392539d16"
source = "git+https://github.com/GreptimeTeam/datafusion-postgres.git?rev=f675927a79cd714a8eeb438b0d3015cd54d4e60a#f675927a79cd714a8eeb438b0d3015cd54d4e60a"
dependencies = [
"async-trait",
"datafusion",
@@ -4909,7 +4929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -5717,6 +5737,20 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"rand_core 0.10.0",
"wasip2",
"wasip3",
]
[[package]]
name = "gimli"
version = "0.31.1"
@@ -6273,7 +6307,7 @@ dependencies = [
"libc",
"percent-encoding",
"pin-project-lite",
"socket2 0.6.0",
"socket2 0.5.10",
"tokio",
"tower-service",
"tracing",
@@ -6344,7 +6378,7 @@ dependencies = [
"js-sys",
"log",
"wasm-bindgen",
"windows-core 0.61.2",
"windows-core 0.57.0",
]
[[package]]
@@ -6442,6 +6476,12 @@ dependencies = [
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
[[package]]
name = "ident_case"
version = "1.0.1"
@@ -6582,6 +6622,8 @@ checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f"
dependencies = [
"equivalent",
"hashbrown 0.16.0",
"serde",
"serde_core",
]
[[package]]
@@ -6758,7 +6800,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -6987,7 +7029,7 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb26cec98cce3a3d96cbb7bced3c4b16e3d13f27ec56dbd62cbc8f39cfb9d653"
dependencies = [
"cpufeatures",
"cpufeatures 0.2.17",
]
[[package]]
@@ -7218,6 +7260,12 @@ dependencies = [
"spin",
]
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "levenshtein_automata"
version = "0.2.1"
@@ -7341,7 +7389,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
dependencies = [
"cfg-if",
"windows-targets 0.52.6",
"windows-targets 0.48.5",
]
[[package]]
@@ -8731,7 +8779,7 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77e878c846a8abae00dd069496dbe8751b16ac1c3d6bd2a7283a938e8228f90d"
dependencies = [
"proc-macro-crate 3.3.0",
"proc-macro-crate 1.3.1",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -9653,9 +9701,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.37.3"
version = "0.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fcd410bc6990bd8d20b3fe3cd879a3c3ec250bdb1cb12537b528818823b02c9"
checksum = "89d5e5a60d3f6e40c91f6a2a7f8d09665e636272bd5611977253559b6651aabb"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -9668,7 +9716,7 @@ dependencies = [
"md5",
"pg_interval_2",
"postgres-types",
"rand 0.9.1",
"rand 0.10.0",
"ring",
"rust_decimal",
"rustls-pki-types",
@@ -9973,7 +10021,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf"
dependencies = [
"cpufeatures",
"cpufeatures 0.2.17",
"opaque-debug",
"universal-hash",
]
@@ -10364,7 +10412,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck 0.4.1",
"itertools 0.12.1",
"log",
"multimap",
@@ -10384,8 +10432,8 @@ version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck 0.5.0",
"itertools 0.14.0",
"heck 0.4.1",
"itertools 0.10.5",
"log",
"multimap",
"once_cell",
@@ -10402,8 +10450,8 @@ version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1"
dependencies = [
"heck 0.5.0",
"itertools 0.14.0",
"heck 0.4.1",
"itertools 0.10.5",
"log",
"multimap",
"once_cell",
@@ -10451,7 +10499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
dependencies = [
"anyhow",
"itertools 0.14.0",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -10464,7 +10512,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425"
dependencies = [
"anyhow",
"itertools 0.14.0",
"itertools 0.10.5",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -10851,7 +10899,7 @@ dependencies = [
"once_cell",
"socket2 0.5.10",
"tracing",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -10942,6 +10990,17 @@ dependencies = [
"rand_core 0.9.3",
]
[[package]]
name = "rand"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8"
dependencies = [
"chacha20 0.10.0",
"getrandom 0.4.1",
"rand_core 0.10.0",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
@@ -10981,6 +11040,12 @@ dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_core"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba"
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -11629,7 +11694,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -11642,7 +11707,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.9.4",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -12305,7 +12370,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
dependencies = [
"cfg-if",
"cpufeatures",
"cpufeatures 0.2.17",
"digest",
]
@@ -12316,7 +12381,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"cpufeatures 0.2.17",
"digest",
]
@@ -12327,7 +12392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"cpufeatures 0.2.17",
"digest",
]
@@ -12545,7 +12610,7 @@ version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1961e2ef424c1424204d3a5d6975f934f56b6d50ff5732382d84ebf460e147f7"
dependencies = [
"heck 0.5.0",
"heck 0.4.1",
"proc-macro2",
"quote",
"syn 2.0.114",
@@ -12958,7 +13023,7 @@ dependencies = [
"cfg-if",
"libc",
"psm",
"windows-sys 0.59.0",
"windows-sys 0.52.0",
]
[[package]]
@@ -13582,7 +13647,7 @@ dependencies = [
"getrandom 0.3.3",
"once_cell",
"rustix 1.0.7",
"windows-sys 0.61.2",
"windows-sys 0.52.0",
]
[[package]]
@@ -15077,6 +15142,24 @@ dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasip2"
version = "1.0.2+wasi-0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasip3"
version = "0.4.0+wasi-0.3.0-rc-2026-01-06"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5"
dependencies = [
"wit-bindgen",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -15154,6 +15237,28 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "wasm-encoder"
version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
dependencies = [
"anyhow",
"indexmap 2.12.0",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasm-streams"
version = "0.4.2"
@@ -15167,6 +15272,18 @@ dependencies = [
"web-sys",
]
[[package]]
name = "wasmparser"
version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe"
dependencies = [
"bitflags 2.9.1",
"hashbrown 0.15.4",
"indexmap 2.12.0",
"semver",
]
[[package]]
name = "web-sys"
version = "0.3.77"
@@ -15258,7 +15375,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -15621,6 +15738,26 @@ dependencies = [
"memchr",
]
[[package]]
name = "wit-bindgen"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
dependencies = [
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc"
dependencies = [
"anyhow",
"heck 0.5.0",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
@@ -15630,6 +15767,74 @@ dependencies = [
"bitflags 2.9.1",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21"
dependencies = [
"anyhow",
"heck 0.5.0",
"indexmap 2.12.0",
"prettyplease",
"syn 2.0.114",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn 2.0.114",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2"
dependencies = [
"anyhow",
"bitflags 2.9.1",
"indexmap 2.12.0",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736"
dependencies = [
"anyhow",
"id-arena",
"indexmap 2.12.0",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "wkt"
version = "0.11.1"

View File

@@ -340,12 +340,11 @@ datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git",
datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
datafusion-physical-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
datafusion-pg-catalog = { git = "https://github.com/GreptimeTeam/datafusion-postgres.git", rev = "74ac8e2806be6de91ff192b97f64735392539d16" }
datafusion-pg-catalog = { git = "https://github.com/GreptimeTeam/datafusion-postgres.git", rev = "f675927a79cd714a8eeb438b0d3015cd54d4e60a" }
datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "7143b2fc4492a7970774583ed0997a459f3e5c05" }
sqlparser = { git = "https://github.com/GreptimeTeam/sqlparser-rs.git", rev = "d7d95a44889e099e32d78e9bad9bc00598faef28" } # on branch v0.59.x
[profile.release]
debug = 1

View File

@@ -23,7 +23,7 @@ api.workspace = true
arrow.workspace = true
arrow-flight.workspace = true
arrow-ipc.workspace = true
arrow-pg = "0.11"
arrow-pg = "0.12"
arrow-schema.workspace = true
async-trait.workspace = true
auth.workspace = true
@@ -89,7 +89,7 @@ operator.workspace = true
otel-arrow-rust.workspace = true
parking_lot.workspace = true
pg_interval = { version = "0.5.2", package = "pg_interval_2" }
pgwire = { version = "0.37.3", default-features = false, features = [
pgwire = { version = "0.38", default-features = false, features = [
"server-api-ring",
"pg-ext-types",
] }

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
@@ -20,6 +21,7 @@ use common_query::{Output, OutputData};
use common_recordbatch::RecordBatch;
use common_recordbatch::error::Result as RecordBatchResult;
use common_telemetry::{debug, tracing};
use datafusion::sql::sqlparser::ast::{CopyOption, CopyTarget, Statement as SqlParserStatement};
use datafusion_common::ParamValues;
use datafusion_pg_catalog::sql::PostgresCompatibilityParser;
use datatypes::prelude::ConcreteDataType;
@@ -28,12 +30,15 @@ use futures::{Sink, SinkExt, Stream, StreamExt, future, stream};
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
CopyCsvOptions, CopyEncoder, CopyResponse, CopyTextOptions, DataRowEncoder,
DescribePortalResponse, DescribeStatementResponse, FieldInfo, QueryResponse, Response, Tag,
};
use pgwire::api::stmt::{QueryParser, StoredStatement};
use pgwire::api::{ClientInfo, ErrorHandler, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use pgwire::messages::copy::CopyData;
use pgwire::messages::data::DataRow;
use query::planner::DfLogicalPlanner;
use query::query_engine::DescribeResult;
use session::Session;
@@ -70,7 +75,9 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
return Ok(vec![Response::EmptyQuery]);
}
let query = if let Ok(statements) = self.query_parser.compatibility_parser.parse(query) {
let parsed_query = self.query_parser.compatibility_parser.parse(query);
let query = if let Ok(statements) = &parsed_query {
statements
.iter()
.map(|s| s.to_string())
@@ -88,9 +95,17 @@ impl SimpleQueryHandler for PostgresServerHandlerInner {
let mut results = Vec::with_capacity(outputs.len());
for output in outputs {
let resp =
output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?;
let statements = parsed_query.ok();
for (idx, output) in outputs.into_iter().enumerate() {
let copy_format = statements
.as_ref()
.and_then(|stmts| stmts.get(idx))
.and_then(check_copy_to_stdout);
let resp = if let Some(format) = &copy_format {
output_to_copy_response(query_ctx.clone(), output, format)?
} else {
output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?
};
results.push(resp);
}
@@ -150,6 +165,8 @@ pub(crate) fn output_to_query_response(
}
}
/// A pinned, boxed, type-erased stream whose items are `PgWireResult<T>`.
///
/// The trait object is additionally bounded by `Send + Unpin`.
type RowStream<T> = Pin<Box<dyn Stream<Item = PgWireResult<T>> + Send + Unpin>>;
fn recordbatches_to_query_response<S>(
query_ctx: QueryContextRef,
recordbatches_stream: S,
@@ -163,18 +180,24 @@ where
let pg_schema = Arc::new(
schema_to_pg(schema.as_ref(), field_format, Some(format_options)).map_err(convert_err)?,
);
let pg_schema_ref = pg_schema.clone();
let data_row_stream = recordbatches_stream
.map(move |result| match result {
Ok(record_batch) => stream::iter(RecordBatchRowIterator::new(
query_ctx.clone(),
pg_schema_ref.clone(),
record_batch,
))
.boxed(),
Err(e) => stream::once(future::err(convert_err(e))).boxed(),
})
.flatten();
let encoder = DataRowEncoder::new(pg_schema.clone());
let row_stream = RecordBatchRowStream::new(
query_ctx.clone(),
pg_schema.clone(),
schema.clone(),
recordbatches_stream,
encoder,
);
let data_row_stream: RowStream<DataRow> = Box::pin(
row_stream
.map(move |result| match result {
Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<DataRow>,
Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<DataRow>,
})
.flatten(),
);
Ok(Response::Query(QueryResponse::new(
pg_schema,
@@ -182,6 +205,82 @@ where
)))
}
pub(crate) fn output_to_copy_response(
query_ctx: QueryContextRef,
output: Result<Output>,
format: &str,
) -> PgWireResult<Response> {
match output {
Ok(o) => match o.data {
OutputData::AffectedRows(_) => Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"42601".to_string(),
"COPY cannot be used with non-query statements".to_string(),
)))),
OutputData::Stream(record_stream) => {
let schema = record_stream.schema();
recordbatches_to_copy_response(query_ctx, record_stream, schema, format)
}
OutputData::RecordBatches(recordbatches) => {
let schema = recordbatches.schema();
recordbatches_to_copy_response(query_ctx, recordbatches.as_stream(), schema, format)
}
},
Err(e) => Err(convert_err(e)),
}
}
fn recordbatches_to_copy_response<S>(
query_ctx: QueryContextRef,
recordbatches_stream: S,
schema: SchemaRef,
format: &str,
) -> PgWireResult<Response>
where
S: Stream<Item = RecordBatchResult<RecordBatch>> + Send + Unpin + 'static,
{
let format_options = format_options_from_query_ctx(&query_ctx);
let pg_fields = schema_to_pg(schema.as_ref(), &Format::UnifiedText, Some(format_options))
.map_err(convert_err)?;
let copy_format = match format.to_lowercase().as_str() {
"binary" => 1,
_ => 0,
};
let pg_schema = Arc::new(pg_fields);
let num_columns = pg_schema.len();
let copy_encoder = match format.to_lowercase().as_str() {
"csv" => CopyEncoder::new_csv(pg_schema.clone(), CopyCsvOptions::default()),
"binary" => CopyEncoder::new_binary(pg_schema.clone()),
_ => CopyEncoder::new_text(pg_schema.clone(), CopyTextOptions::default()),
};
let row_stream = RecordBatchRowStream::new(
query_ctx.clone(),
pg_schema.clone(),
schema.clone(),
recordbatches_stream,
copy_encoder,
);
let copy_stream: RowStream<CopyData> = Box::pin(
row_stream
.map(move |result| match result {
Ok(rows) => Box::pin(stream::iter(rows.into_iter().map(Ok))) as RowStream<CopyData>,
Err(e) => Box::pin(stream::once(future::ready(Err(e)))) as RowStream<CopyData>,
})
.flatten(),
);
Ok(Response::CopyOut(CopyResponse::new(
copy_format,
num_columns,
copy_stream,
)))
}
pub struct DefaultQueryParser {
query_handler: ServerSqlQueryHandlerRef,
session: Arc<Session>,
@@ -198,9 +297,16 @@ impl DefaultQueryParser {
}
}
/// A container type of parse result types.
///
/// Bundles the parsed `SqlPlan` with the COPY-to-stdout format of the statement,
/// if one was detected.
#[derive(Clone, Debug)]
pub struct PgSqlPlan {
// The parsed statement together with its optional plan and schema.
plan: SqlPlan,
// COPY format name when the statement is a `COPY ... TO STDOUT`; `None` otherwise.
copy_to_stdout_format: Option<String>,
}
#[async_trait]
impl QueryParser for DefaultQueryParser {
type Statement = SqlPlan;
type Statement = PgSqlPlan;
async fn parse_sql<C>(
&self,
@@ -213,20 +319,26 @@ impl QueryParser for DefaultQueryParser {
// do not parse if query is empty or matches rules
if sql.is_empty() || fixtures::matches(sql) {
return Ok(SqlPlan {
query: sql.to_owned(),
statement: None,
plan: None,
schema: None,
return Ok(PgSqlPlan {
plan: SqlPlan {
query: sql.to_owned(),
statement: None,
plan: None,
schema: None,
},
copy_to_stdout_format: None,
});
}
let sql = if let Ok(mut statements) = self.compatibility_parser.parse(sql) {
statements.remove(0).to_string()
let parsed_statements = self.compatibility_parser.parse(sql);
let (sql, copy_to_stdout_format) = if let Ok(mut statements) = parsed_statements {
let first_stmt = statements.remove(0);
let format = check_copy_to_stdout(&first_stmt);
(first_stmt.to_string(), format)
} else {
// bypass the error: it can run into error because of different
// versions of sqlparser
sql.to_string()
(sql.to_string(), None)
};
let mut stmts = ParserContext::create_with_dialect(
@@ -258,11 +370,14 @@ impl QueryParser for DefaultQueryParser {
(None, None)
};
Ok(SqlPlan {
query: sql.clone(),
statement: Some(stmt),
plan,
schema,
Ok(PgSqlPlan {
plan: SqlPlan {
query: sql.clone(),
statement: Some(stmt),
plan,
schema,
},
copy_to_stdout_format,
})
}
}
@@ -290,7 +405,7 @@ impl QueryParser for DefaultQueryParser {
#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandlerInner {
type Statement = SqlPlan;
type Statement = PgSqlPlan;
type QueryParser = DefaultQueryParser;
fn query_parser(&self) -> Arc<Self::QueryParser> {
@@ -314,7 +429,8 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
.with_label_values(&[crate::metrics::METRIC_POSTGRES_EXTENDED_QUERY, db.as_str()])
.start_timer();
let sql_plan = &portal.statement.statement;
let pg_sql_plan = &portal.statement.statement;
let sql_plan = &pg_sql_plan.plan;
if sql_plan.query.is_empty() {
// early return if query is empty
@@ -355,7 +471,12 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
};
send_warning_opt(client, query_ctx.clone()).await?;
output_to_query_response(query_ctx, output, &portal.result_column_format)
if let Some(format) = &pg_sql_plan.copy_to_stdout_format {
output_to_copy_response(query_ctx, output, format)
} else {
output_to_query_response(query_ctx, output, &portal.result_column_format)
}
}
async fn do_describe_statement<C>(
@@ -366,7 +487,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
where
C: ClientInfo + Unpin + Send + Sync,
{
let sql_plan = &stmt.statement;
let sql_plan = &stmt.statement.plan;
// client provided parameter types, can be empty if client doesn't try to parse statement
let provided_param_types = &stmt.parameter_types;
let server_inferenced_types = if let Some(plan) = &sql_plan.plan {
@@ -434,7 +555,7 @@ impl ExtendedQueryHandler for PostgresServerHandlerInner {
where
C: ClientInfo + Unpin + Send + Sync,
{
let sql_plan = &portal.statement.statement;
let sql_plan = &portal.statement.statement.plan;
let format = &portal.result_column_format;
match sql_plan.statement.as_ref() {
@@ -510,3 +631,86 @@ impl ErrorHandler for PostgresServerHandlerInner {
debug!("Postgres interface error {}", error)
}
}
/// Detects a `COPY` statement whose target is stdout and reports its format.
///
/// Returns `None` when `statement` is not a `COPY` with a `CopyTarget::Stdout`
/// target. Otherwise returns the lower-cased value of the first `FORMAT`
/// option, or `"txt"` when no `FORMAT` option is present.
fn check_copy_to_stdout(statement: &SqlParserStatement) -> Option<String> {
    let SqlParserStatement::Copy {
        target: CopyTarget::Stdout,
        options,
        ..
    } = statement
    else {
        return None;
    };

    // The first FORMAT option wins; any other option kind is ignored here.
    let explicit_format = options.iter().find_map(|option| match option {
        CopyOption::Format(format_ident) => Some(format_ident.value.to_lowercase()),
        _ => None,
    });

    Some(explicit_format.unwrap_or_else(|| "txt".to_string()))
}
#[cfg(test)]
mod tests {
    use datafusion_pg_catalog::sql::PostgresCompatibilityParser;

    use super::*;

    /// Parses `sql` with the Postgres compatibility parser and returns its first statement.
    fn parse_copy_statement(sql: &str) -> SqlParserStatement {
        PostgresCompatibilityParser::new()
            .parse(sql)
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
    }

    /// Runs `check_copy_to_stdout` on the first statement parsed from `sql`.
    fn detected_format(sql: &str) -> Option<String> {
        let statement = parse_copy_statement(sql);
        check_copy_to_stdout(&statement)
    }

    #[test]
    fn test_check_copy_out_with_csv_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT CSV)");
        assert_eq!(format.as_deref(), Some("csv"));
    }

    #[test]
    fn test_check_copy_out_with_txt_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT TXT)");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_with_binary_format() {
        let format = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT BINARY)");
        assert_eq!(format.as_deref(), Some("binary"));
    }

    #[test]
    fn test_check_copy_out_without_format() {
        // No FORMAT option: the default format name is reported.
        let format = detected_format("COPY (SELECT 1) TO STDOUT");
        assert_eq!(format.as_deref(), Some("txt"));
    }

    #[test]
    fn test_check_copy_out_to_file() {
        // A file target is not a copy-to-stdout statement.
        let format = detected_format("COPY (SELECT 1) TO '/path/to/file.csv' WITH (FORMAT CSV)");
        assert_eq!(format, None);
    }

    #[test]
    fn test_check_copy_out_case_insensitive() {
        let csv = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv)");
        assert_eq!(csv.as_deref(), Some("csv"));

        let binary = detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT binary)");
        assert_eq!(binary.as_deref(), Some("binary"));
    }

    #[test]
    fn test_check_copy_out_with_multiple_options() {
        // FORMAT is found regardless of its position in the option list.
        let format_first =
            detected_format("COPY (SELECT 1) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)");
        assert_eq!(format_first.as_deref(), Some("csv"));

        let format_last = detected_format(
            "COPY (SELECT 1) TO STDOUT WITH (DELIMITER ',', HEADER, FORMAT binary)",
        );
        assert_eq!(format_last.as_deref(), Some("binary"));
    }
}

View File

@@ -15,14 +15,17 @@
mod error;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use arrow::array::{Array, AsArray};
use arrow_pg::encoder::encode_value;
use arrow_pg::encoder::{Encoder, encode_value};
use arrow_pg::list_encoder::encode_list;
use arrow_schema::{DataType, TimeUnit};
use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime};
use common_recordbatch::RecordBatch;
use common_recordbatch::error::Result as RecordBatchResult;
use common_time::{IntervalDayTime, IntervalMonthDayNano, IntervalYearMonth};
use datafusion_common::ScalarValue;
use datafusion_expr::LogicalPlan;
@@ -32,20 +35,20 @@ use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::schema::{Schema, SchemaRef};
use datatypes::types::{IntervalType, TimestampType, jsonb_to_string};
use datatypes::value::StructValue;
use futures::Stream;
use pg_interval::Interval as PgInterval;
use pgwire::api::Type;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::results::FieldInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use pgwire::types::format::FormatOptions as PgFormatOptions;
use query::planner::DfLogicalPlanner;
use session::context::QueryContextRef;
use snafu::ResultExt;
pub use self::error::{PgErrorCode, PgErrorSeverity};
use crate::SqlPlan;
use crate::error::{self as server_error, InferParameterTypesSnafu, Result};
use crate::postgres::handler::PgSqlPlan;
use crate::postgres::utils::convert_err;
pub(super) fn schema_to_pg(
@@ -82,78 +85,125 @@ pub(super) fn schema_to_pg(
/// there are alternatives like records, arrays, etc. but there are also limitations:
/// records: there is no support for include keys
/// arrays: element in array must be the same type
fn encode_struct(
fn encode_struct<S: Encoder>(
_query_ctx: &QueryContextRef,
struct_value: StructValue,
builder: &mut DataRowEncoder,
builder: &mut S,
pg_field: &FieldInfo,
) -> PgWireResult<()> {
let encoding_setting = JsonStructureSettings::Structured(None);
let json_value = encoding_setting
.decode(Value::Struct(struct_value))
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
builder.encode_field(&json_value)
builder.encode_field(&json_value, pg_field)
}
pub(crate) struct RecordBatchRowIterator {
pub(crate) struct RecordBatchRowStream<S, B>
where
S: Encoder,
B: Stream<Item = RecordBatchResult<RecordBatch>>,
{
query_ctx: QueryContextRef,
pg_schema: Arc<Vec<FieldInfo>>,
schema: SchemaRef,
record_batch: arrow::record_batch::RecordBatch,
i: usize,
record_batches: Pin<Box<B>>,
encoder: S,
}
// Each item polled from the inner record-batch stream is turned into one `Vec` of encoded
// rows, so the stream yields one item per batch rather than one item per row.
impl Iterator for RecordBatchRowIterator {
type Item = PgWireResult<DataRow>;
impl<S, B> Stream for RecordBatchRowStream<S, B>
where
S: Encoder + Unpin,
B: Stream<Item = RecordBatchResult<RecordBatch>>,
{
type Item = PgWireResult<Vec<S::Item>>;
fn next(&mut self) -> Option<Self::Item> {
let mut encoder = DataRowEncoder::new(self.pg_schema.clone());
if self.i < self.record_batch.num_rows() {
if let Err(e) = self.encode_row(self.i, &mut encoder) {
return Some(Err(e));
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.record_batches.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(batch))) => {
let record_batch = batch.into_df_record_batch();
let num_rows = record_batch.num_rows();
// An empty batch produces an empty `Vec`; it does not end the stream.
if num_rows == 0 {
return Poll::Ready(Some(Ok(vec![])));
}
let arrow_schema = record_batch.schema();
// Clone the shared handles into locals; presumably this keeps `self` free for the
// mutable borrow of `self.encoder` below.
let query_ctx = self.query_ctx.clone();
let pg_schema = self.pg_schema.clone();
let schema = self.schema.clone();
let mut results = Vec::with_capacity(num_rows);
for i in 0..num_rows {
if let Err(e) = Self::encode_row(
&query_ctx,
&pg_schema,
&schema,
arrow_schema.as_ref(),
&mut self.encoder,
&record_batch,
i,
) {
// The first failing row aborts the whole batch: rows already collected in
// `results` are dropped and only the error is returned.
return Poll::Ready(Some(Err(e)));
}
// Collect the row that was just encoded before moving on to the next one.
results.push(self.encoder.take_row());
}
Poll::Ready(Some(Ok(results)))
}
self.i += 1;
Some(Ok(encoder.take_row()))
} else {
None
// Errors from the inner stream are mapped through `convert_err`.
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(convert_err(e)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
impl RecordBatchRowIterator {
impl<S, B> RecordBatchRowStream<S, B>
where
S: Encoder,
B: Stream<Item = RecordBatchResult<RecordBatch>>,
{
pub(crate) fn new(
query_ctx: QueryContextRef,
pg_schema: Arc<Vec<FieldInfo>>,
record_batch: RecordBatch,
schema: SchemaRef,
record_batches: B,
encoder: S,
) -> Self {
let schema = record_batch.schema.clone();
let record_batch = record_batch.into_df_record_batch();
Self {
query_ctx,
pg_schema,
schema,
record_batch,
i: 0,
record_batches: Box::pin(record_batches),
encoder,
}
}
fn encode_row(&mut self, i: usize, encoder: &mut DataRowEncoder) -> PgWireResult<()> {
let arrow_schema = self.record_batch.schema();
for (j, column) in self.record_batch.columns().iter().enumerate() {
fn encode_row(
query_ctx: &QueryContextRef,
pg_schema: &Arc<Vec<FieldInfo>>,
schema: &SchemaRef,
arrow_schema: &arrow::datatypes::Schema,
encoder: &mut S,
record_batch: &arrow::record_batch::RecordBatch,
i: usize,
) -> PgWireResult<()> {
for (j, column) in record_batch.columns().iter().enumerate() {
let pg_field = &pg_schema[j];
if column.is_null(i) {
encoder.encode_field(&None::<&i8>)?;
encoder.encode_field(&None::<&i8>, pg_field)?;
continue;
}
let pg_field = &self.pg_schema[j];
match column.data_type() {
// these types are greptimedb specific or custom
DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
// jsonb
if let ConcreteDataType::Json(_) = &self.schema.column_schemas()[j].data_type {
if let ConcreteDataType::Json(_) = &schema.column_schemas()[j].data_type {
let v = datatypes::arrow_array::binary_array_value(column, i);
let s = jsonb_to_string(v).map_err(convert_err)?;
encoder.encode_field(&s)?;
encoder.encode_field(&s, pg_field)?;
} else {
// bytea
let arrow_field = arrow_schema.field(j);
@@ -168,7 +218,7 @@ impl RecordBatchRowIterator {
encode_list(encoder, items, pg_field)?;
}
DataType::Struct(_) => {
encode_struct(&self.query_ctx, Default::default(), encoder)?;
encode_struct(query_ctx, Default::default(), encoder, pg_field)?;
}
_ => {
// Encode value using arrow-pg
@@ -277,7 +327,7 @@ pub(super) fn type_pg_to_gt(origin: &Type) -> Result<ConcreteDataType> {
}
}
pub(super) fn parameter_to_string(portal: &Portal<SqlPlan>, idx: usize) -> PgWireResult<String> {
pub(super) fn parameter_to_string(portal: &Portal<PgSqlPlan>, idx: usize) -> PgWireResult<String> {
// the index is managed from portal's parameters count so it's safe to
// unwrap here.
let param_type = portal
@@ -359,7 +409,7 @@ where
pub(super) fn parameters_to_scalar_values(
plan: &LogicalPlan,
portal: &Portal<SqlPlan>,
portal: &Portal<PgSqlPlan>,
) -> PgWireResult<Vec<ScalarValue>> {
let param_count = portal.parameter_len();
let mut results = Vec::with_capacity(param_count);
@@ -761,7 +811,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT2_ARRAY => {
let data = portal.parameter::<Vec<i16>>(idx, &client_type)?;
let data = portal.parameter::<Vec<Option<i16>>>(idx, &client_type)?;
if let Some(data) = data {
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int16, true))
@@ -770,7 +820,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT4_ARRAY => {
let data = portal.parameter::<Vec<i32>>(idx, &client_type)?;
let data = portal.parameter::<Vec<Option<i32>>>(idx, &client_type)?;
if let Some(data) = data {
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int32, true))
@@ -779,7 +829,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::INT8_ARRAY => {
let data = portal.parameter::<Vec<i64>>(idx, &client_type)?;
let data = portal.parameter::<Vec<Option<i64>>>(idx, &client_type)?;
if let Some(data) = data {
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Int64, true))
@@ -788,7 +838,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::VARCHAR_ARRAY => {
let data = portal.parameter::<Vec<String>>(idx, &client_type)?;
let data = portal.parameter::<Vec<Option<String>>>(idx, &client_type)?;
if let Some(data) = data {
let values = data.into_iter().map(|i| i.into()).collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(&values, &ArrowDataType::Utf8, true))
@@ -797,7 +847,7 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::TIMESTAMP_ARRAY => {
let data = portal.parameter::<Vec<NaiveDateTime>>(idx, &client_type)?;
let data = portal.parameter::<Vec<Option<NaiveDateTime>>>(idx, &client_type)?;
if let Some(data) = data {
if let Some(ConcreteDataType::List(list_type)) = &server_type {
match list_type.item_type() {
@@ -807,7 +857,7 @@ pub(super) fn parameters_to_scalar_values(
.into_iter()
.map(|ts| {
ScalarValue::TimestampSecond(
Some(ts.and_utc().timestamp()),
ts.map(|ts| ts.and_utc().timestamp()),
None,
)
})
@@ -823,7 +873,7 @@ pub(super) fn parameters_to_scalar_values(
.into_iter()
.map(|ts| {
ScalarValue::TimestampMillisecond(
Some(ts.and_utc().timestamp_millis()),
ts.map(|ts| ts.and_utc().timestamp_millis()),
None,
)
})
@@ -839,7 +889,7 @@ pub(super) fn parameters_to_scalar_values(
.into_iter()
.map(|ts| {
ScalarValue::TimestampMicrosecond(
Some(ts.and_utc().timestamp_micros()),
ts.map(|ts| ts.and_utc().timestamp_micros()),
None,
)
})
@@ -854,8 +904,13 @@ pub(super) fn parameters_to_scalar_values(
let values = data
.into_iter()
.filter_map(|ts| {
ts.and_utc().timestamp_nanos_opt().map(|nanos| {
ScalarValue::TimestampNanosecond(Some(nanos), None)
ts.and_then(|ts| {
ts.and_utc().timestamp_nanos_opt().map(|nanos| {
ScalarValue::TimestampNanosecond(
Some(nanos),
None,
)
})
})
})
.collect::<Vec<_>>();
@@ -878,12 +933,11 @@ pub(super) fn parameters_to_scalar_values(
}
}
} else {
// Default to millisecond when no server type is specified
let values = data
.into_iter()
.map(|ts| {
ScalarValue::TimestampMillisecond(
Some(ts.and_utc().timestamp_millis()),
ts.map(|ts| ts.and_utc().timestamp_millis()),
None,
)
})
@@ -899,7 +953,8 @@ pub(super) fn parameters_to_scalar_values(
}
}
&Type::TIMESTAMPTZ_ARRAY => {
let data = portal.parameter::<Vec<DateTime<FixedOffset>>>(idx, &client_type)?;
let data =
portal.parameter::<Vec<Option<DateTime<FixedOffset>>>>(idx, &client_type)?;
if let Some(data) = data {
if let Some(ConcreteDataType::List(list_type)) = &server_type {
match list_type.item_type() {
@@ -908,7 +963,10 @@ pub(super) fn parameters_to_scalar_values(
let values = data
.into_iter()
.map(|ts| {
ScalarValue::TimestampSecond(Some(ts.timestamp()), None)
ScalarValue::TimestampSecond(
ts.map(|ts| ts.timestamp()),
None,
)
})
.collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(
@@ -922,7 +980,7 @@ pub(super) fn parameters_to_scalar_values(
.into_iter()
.map(|ts| {
ScalarValue::TimestampMillisecond(
Some(ts.timestamp_millis()),
ts.map(|ts| ts.timestamp_millis()),
None,
)
})
@@ -938,7 +996,7 @@ pub(super) fn parameters_to_scalar_values(
.into_iter()
.map(|ts| {
ScalarValue::TimestampMicrosecond(
Some(ts.timestamp_micros()),
ts.map(|ts| ts.timestamp_micros()),
None,
)
})
@@ -952,10 +1010,11 @@ pub(super) fn parameters_to_scalar_values(
TimestampType::Nanosecond(_) => {
let values = data
.into_iter()
.filter_map(|ts| {
ts.timestamp_nanos_opt().map(|nanos| {
ScalarValue::TimestampNanosecond(Some(nanos), None)
})
.map(|ts| {
ScalarValue::TimestampNanosecond(
ts.and_then(|ts| ts.timestamp_nanos_opt()),
None,
)
})
.collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(
@@ -977,11 +1036,13 @@ pub(super) fn parameters_to_scalar_values(
}
}
} else {
// Default to millisecond when no server type is specified
let values = data
.into_iter()
.map(|ts| {
ScalarValue::TimestampMillisecond(Some(ts.timestamp_millis()), None)
ScalarValue::TimestampMillisecond(
ts.map(|ts| ts.timestamp_millis()),
None,
)
})
.collect::<Vec<_>>();
ScalarValue::List(ScalarValue::new_list(
@@ -1050,8 +1111,9 @@ mod test {
IntervalYearMonthVector, ListVector, NullVector, StringVector, TimeSecondVector,
TimestampSecondVector, UInt8Vector, UInt16Vector, UInt32Vector, UInt64Vector, VectorRef,
};
use futures::{StreamExt as FuturesStreamExt, stream};
use pgwire::api::Type;
use pgwire::api::results::{FieldFormat, FieldInfo};
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo};
use session::context::QueryContextBuilder;
use super::*;
@@ -1155,7 +1217,7 @@ mod test {
#[test]
fn test_encode_text_format_data() {
let schema = vec![
let pg_schema = vec![
FieldInfo::new("nulls".into(), None, None, Type::UNKNOWN, FieldFormat::Text),
FieldInfo::new("bools".into(), None, None, Type::BOOL, FieldFormat::Text),
FieldInfo::new("uint8s".into(), None, None, Type::INT2, FieldFormat::Text),
@@ -1416,14 +1478,28 @@ mod test {
.configuration_parameter(Default::default())
.build()
.into();
let schema = Arc::new(schema);
let schema = record_batch.schema.clone();
let pg_schema_ref = Arc::new(pg_schema);
let rows = RecordBatchRowIterator::new(query_context, schema.clone(), record_batch)
.filter_map(|x| x.ok())
.collect::<Vec<_>>();
let encoder = DataRowEncoder::new(pg_schema_ref.clone());
let row_stream = RecordBatchRowStream::new(
query_context,
pg_schema_ref.clone(),
schema,
stream::once(async { Ok(record_batch) }),
encoder,
);
let rows: Vec<_> = futures::executor::block_on(
row_stream
.filter_map(|x: PgWireResult<_>| async move { x.ok() })
.flat_map(stream::iter)
.collect::<Vec<_>>(),
);
assert_eq!(rows.len(), 3);
for row in rows {
assert_eq!(row.field_count, schema.len() as i16);
assert_eq!(row.field_count, pg_schema_ref.len() as i16);
}
}

View File

@@ -32,22 +32,34 @@ impl ParserContext<'_> {
/// Parses a `COPY` statement, choosing the variant from the token that follows the first one.
///
/// - `(`: parsed via `parse_copy_query_to`; when the target location is `STDOUT` the inner
///   query itself is returned instead of a `Statement::Copy`.
/// - `DATABASE` keyword: parsed via `parser_copy_database`.
/// - anything else: parsed via `parse_copy_table`.
pub(crate) fn parse_copy(&mut self) -> Result<Statement> {
// Consume the leading token (presumably the `COPY` keyword itself), then peek at the next.
let _ = self.parser.next_token();
let next = self.parser.peek_token();
let copy = if next.token == Token::LParen {
if next.token == Token::LParen {
let copy_query = self.parse_copy_query_to()?;
crate::statements::copy::Copy::CopyQueryTo(copy_query)
// COPY ... TO STDOUT is a special case for the postgres wire protocol:
// the logic is completely identical to a query, but with an alternative data encoding on transport.
//
// So at the query engine level, we simply parse the command as its inner query;
// the encoding and its format options are dealt with in server/src/postgres/handler.rs.
if copy_query.arg.location == "STDOUT" {
Ok(*copy_query.query)
} else {
Ok(Statement::Copy(crate::statements::copy::Copy::CopyQueryTo(
copy_query,
)))
}
} else if let Word(word) = next.token
&& word.keyword == Keyword::DATABASE
{
// Consume the `DATABASE` keyword that was only peeked so far.
let _ = self.parser.next_token();
let copy_database = self.parser_copy_database()?;
crate::statements::copy::Copy::CopyDatabase(copy_database)
Ok(Statement::Copy(
crate::statements::copy::Copy::CopyDatabase(copy_database),
))
} else {
let copy_table = self.parse_copy_table()?;
crate::statements::copy::Copy::CopyTable(copy_table)
};
Ok(Statement::Copy(copy))
Ok(Statement::Copy(crate::statements::copy::Copy::CopyTable(
copy_table,
)))
}
}
fn parser_copy_database(&mut self) -> Result<CopyDatabase> {
@@ -147,21 +159,62 @@ impl ParserContext<'_> {
self.parser
.expect_keyword(Keyword::TO)
.context(error::SyntaxSnafu)?;
let (with, connection, location, limit) = self.parse_copy_parameters()?;
if limit.is_some() {
return error::InvalidSqlSnafu {
msg: "limit is not supported",
if self.parser.parse_keyword(Keyword::STDOUT) {
// COPY ... TO STDOUT is dealt with on the postgres protocol layer, so we return
// early without interpreting any options here: an option list, if present, is
// only consumed and then ignored.
// For PostgreSQL compatibility the WITH keyword is optional, i.e. both
// "WITH (FORMAT binary)" and "(FORMAT binary)" are valid after STDOUT.
let _ = self.parser.parse_keyword(Keyword::WITH);
if self.parser.peek_token().token == Token::LParen {
let _ = self.parser.next_token();
// consume all tokens until we find matching RParen
let mut depth = 1;
while depth > 0 {
match self.parser.next_token().token {
Token::LParen => depth += 1,
Token::RParen => depth -= 1,
Token::EOF => {
return error::UnexpectedTokenSnafu {
expected: ")",
actual: "EOF",
}
.fail();
}
_ => {}
}
}
}
.fail();
Ok(CopyQueryTo {
query: Box::new(query),
arg: CopyQueryToArgument {
with: OptionMap::default(),
connection: OptionMap::default(),
location: "STDOUT".to_string(),
},
})
} else {
let (with, connection, location, limit) = self.parse_copy_parameters()?;
if limit.is_some() {
return error::InvalidSqlSnafu {
msg: "limit is not supported",
}
.fail();
}
Ok(CopyQueryTo {
query: Box::new(query),
arg: CopyQueryToArgument {
with,
connection,
location,
},
})
}
Ok(CopyQueryTo {
query: Box::new(query),
arg: CopyQueryToArgument {
with,
connection,
location,
},
})
}
fn parse_copy_parameters(&mut self) -> Result<(OptionMap, OptionMap, String, Option<u64>)> {
@@ -540,4 +593,243 @@ mod tests {
)
}
}
#[test]
fn test_copy_query_to_stdout() {
    // `COPY (<query>) TO STDOUT WITH (...)` must parse into the same statement as `<query>`.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM tbl WHERE ts > 10) TO STDOUT WITH (FORMAT = 'csv')";
    let inner_sql = "SELECT * FROM tbl WHERE ts > 10";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default()).unwrap();
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts =
        ParserContext::create_with_dialect(inner_sql, &dialect, ParseOptions::default()).unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_without_format() {
    // With no option list at all, the result must still equal the bare inner query.
    let dialect = GreptimeDbDialect {};
    let query_str = "SELECT generate_series(1, 2), generate_series(2, 3)";
    let copy_sql = "COPY (SELECT generate_series(1, 2), generate_series(2, 3)) TO STDOUT";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default()).unwrap();
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts =
        ParserContext::create_with_dialect(query_str, &dialect, ParseOptions::default()).unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_with_binary_format() {
    // `WITH (FORMAT binary)` must be accepted and the statement must equal the inner query.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT binary)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT WITH (FORMAT binary) should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_with_csv_format() {
    // `WITH (FORMAT csv)` must be accepted and the statement must equal the inner query.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT csv)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT WITH (FORMAT csv) should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_with_equals_format() {
    // The `key = 'value'` option spelling must be accepted as well.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT = 'binary')";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT WITH (FORMAT = 'binary') should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_with_multiple_options() {
    // A comma-separated option list after WITH must be accepted as a whole.
    let dialect = GreptimeDbDialect {};
    let copy_sql =
        "COPY (SELECT * FROM test_table) TO STDOUT WITH (FORMAT csv, DELIMITER ',', HEADER)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT WITH multiple options should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_without_with_keyword() {
    // The option list may follow STDOUT directly, with the WITH keyword omitted.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT binary)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT (FORMAT binary) without WITH keyword should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_without_with_csv_format() {
    // Same as the WITH-less binary case, but with the csv format option.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT (FORMAT csv) without WITH keyword should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_copy_query_to_stdout_without_with_multiple_options() {
    // Several options in a WITH-less list must be accepted as a whole.
    let dialect = GreptimeDbDialect {};
    let copy_sql = "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv, DELIMITER ',', HEADER)";

    let mut copy_stmts =
        ParserContext::create_with_dialect(copy_sql, &dialect, ParseOptions::default())
            .unwrap_or_else(|e| {
                panic!(
                    "COPY TO STDOUT (FORMAT csv, ...) without WITH keyword should parse without error, got: {:?}",
                    e
                )
            });
    let stmt = copy_stmts.pop().unwrap();

    let mut inner_stmts = ParserContext::create_with_dialect(
        "SELECT * FROM test_table",
        &dialect,
        ParseOptions::default(),
    )
    .unwrap();
    let expected_query = inner_stmts.remove(0);

    assert_eq!(&expected_query, &stmt);
}
#[test]
fn test_invalid_copy_query() {
    // Unbalanced option lists must be rejected: first a missing `)`, then a surplus `)`.
    let invalid_sqls = [
        "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv",
        "COPY (SELECT * FROM test_table) TO STDOUT (FORMAT csv))",
    ];
    for sql in invalid_sqls {
        let result =
            ParserContext::create_with_dialect(sql, &GreptimeDbDialect {}, ParseOptions::default());
        assert!(result.is_err());
    }
}
}

View File

@@ -1099,3 +1099,32 @@ WHERE oid = pg_my_temp_schema();
| OID| public | t | |
+-------+--------------------+-------------------+---------+
-- SQLNESS PROTOCOL POSTGRES
CREATE table foo
(
ts TIMESTAMP TIME INDEX,
log_data TEXT,
count_num BIGINT,
);
Affected Rows: 0
-- SQLNESS PROTOCOL POSTGRES
SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN
pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN
pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND
cls.oid = 'foo'::regclass::oid ORDER BY attr.attnum;
+-----------+----------+
| attname | atttypid |
+-----------+----------+
| ts | 1114 |
| log_data | 25 |
| count_num | 20 |
+-----------+----------+
-- SQLNESS PROTOCOL POSTGRES
DROP TABLE foo;
Affected Rows: 0

View File

@@ -258,3 +258,20 @@ oid
,nspname
FROM pg_namespace
WHERE oid = pg_my_temp_schema();
-- SQLNESS PROTOCOL POSTGRES
CREATE table foo
(
ts TIMESTAMP TIME INDEX,
log_data TEXT,
count_num BIGINT,
);
-- SQLNESS PROTOCOL POSTGRES
SELECT attname, atttypid FROM pg_catalog.pg_class AS cls INNER JOIN
pg_catalog.pg_attribute AS attr ON cls.oid = attr.attrelid INNER JOIN
pg_catalog.pg_type AS typ ON attr.atttypid = typ.oid WHERE attr.attnum >= 0 AND
cls.oid = 'foo'::regclass::oid ORDER BY attr.attnum;
-- SQLNESS PROTOCOL POSTGRES
DROP TABLE foo;