Merge branch 'problame/2024-02-walredo-work/prespawn/split-code' into problame/2024-02-walredo-work/prespawn/broken-tenants-no-walredo

This commit is contained in:
Christian Schwarz
2024-02-02 17:16:25 +00:00
19 changed files with 898 additions and 766 deletions

242
Cargo.lock generated
View File

@@ -275,6 +275,8 @@ name = "attachment_service"
version = "0.1.0"
dependencies = [
"anyhow",
"aws-config",
"aws-sdk-secretsmanager",
"camino",
"clap",
"control_plane",
@@ -304,12 +306,11 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "aws-config"
version = "1.0.1"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80c950a809d39bc9480207cb1cfc879ace88ea7e3a4392a8e9999e45d6e5692e"
checksum = "8b30c39ebe61f75d1b3785362b1586b41991873c9ab3e317a9181c246fb71d82"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-runtime",
"aws-sdk-sso",
"aws-sdk-ssooidc",
@@ -324,7 +325,7 @@ dependencies = [
"bytes",
"fastrand 2.0.0",
"hex",
"http",
"http 0.2.9",
"hyper",
"ring 0.17.6",
"time",
@@ -335,9 +336,9 @@ dependencies = [
[[package]]
name = "aws-credential-types"
version = "1.0.1"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c1317e1a3514b103cf7d5828bbab3b4d30f56bd22d684f8568bc51b6cfbbb1c"
checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7"
dependencies = [
"aws-smithy-async",
"aws-smithy-runtime-api",
@@ -345,30 +346,13 @@ dependencies = [
"zeroize",
]
[[package]]
name = "aws-http"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "361c4310fdce94328cc2d1ca0c8a48c13f43009c61d3367585685a50ca8c66b6"
dependencies = [
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"http",
"http-body",
"pin-project-lite",
"tracing",
]
[[package]]
name = "aws-runtime"
version = "1.0.1"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ed7ef604a15fd0d4d9e43701295161ea6b504b63c44990ead352afea2bc15e9"
checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-sigv4",
"aws-smithy-async",
"aws-smithy-eventstream",
@@ -376,21 +360,23 @@ dependencies = [
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand 2.0.0",
"http",
"http 0.2.9",
"http-body",
"percent-encoding",
"pin-project-lite",
"tracing",
"uuid",
]
[[package]]
name = "aws-sdk-s3"
version = "1.4.0"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dcafc2fe52cc30b2d56685e2fa6a879ba50d79704594852112337a472ddbd24"
checksum = "951f7730f51a2155c711c85c79f337fbc02a577fa99d2a0a8059acfce5392113"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-runtime",
"aws-sigv4",
"aws-smithy-async",
@@ -404,23 +390,22 @@ dependencies = [
"aws-smithy-xml",
"aws-types",
"bytes",
"http",
"http 0.2.9",
"http-body",
"once_cell",
"percent-encoding",
"regex",
"regex-lite",
"tracing",
"url",
]
[[package]]
name = "aws-sdk-sso"
version = "1.3.0"
name = "aws-sdk-secretsmanager"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0619ab97a5ca8982e7de073cdc66f93e5f6a1b05afc09e696bec1cb3607cd4df"
checksum = "0a0b64e61e7d632d9df90a2e0f32630c68c24960cab1d27d848718180af883d3"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
@@ -430,19 +415,42 @@ dependencies = [
"aws-smithy-types",
"aws-types",
"bytes",
"http",
"regex",
"fastrand 2.0.0",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sso"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f486420a66caad72635bc2ce0ff6581646e0d32df02aa39dc983bfe794955a5b"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-ssooidc"
version = "1.3.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f04b9f5474cc0f35d829510b2ec8c21e352309b46bf9633c5a81fb9321e9b1c7"
checksum = "39ddccf01d82fce9b4a15c8ae8608211ee7db8ed13a70b514bbfe41df3d24841"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
@@ -452,19 +460,19 @@ dependencies = [
"aws-smithy-types",
"aws-types",
"bytes",
"http",
"regex",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sts"
version = "1.3.0"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5700da387716ccfc30b27f44b008f457e1baca5b0f05b6b95455778005e3432a"
checksum = "1a591f8c7e6a621a501b2b5d2e88e1697fcb6274264523a6ad4d5959889a41ce"
dependencies = [
"aws-credential-types",
"aws-http",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
@@ -475,16 +483,17 @@ dependencies = [
"aws-smithy-types",
"aws-smithy-xml",
"aws-types",
"http",
"regex",
"http 0.2.9",
"once_cell",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sigv4"
version = "1.0.1"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "380adcc8134ad8bbdfeb2ace7626a869914ee266322965276cbc54066186d236"
checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742"
dependencies = [
"aws-credential-types",
"aws-smithy-eventstream",
@@ -496,11 +505,11 @@ dependencies = [
"form_urlencoded",
"hex",
"hmac",
"http",
"http 0.2.9",
"http 1.0.0",
"once_cell",
"p256",
"percent-encoding",
"regex",
"ring 0.17.6",
"sha2",
"subtle",
@@ -511,9 +520,9 @@ dependencies = [
[[package]]
name = "aws-smithy-async"
version = "1.0.2"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e37ca17d25fe1e210b6d4bdf59b81caebfe99f986201a1228cb5061233b4b13"
checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6"
dependencies = [
"futures-util",
"pin-project-lite",
@@ -522,9 +531,9 @@ dependencies = [
[[package]]
name = "aws-smithy-checksums"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5a373ec01aede3dd066ec018c1bc4e8f5dd11b2c11c59c8eef1a5c68101f397"
checksum = "be2acd1b9c6ae5859999250ed5a62423aedc5cf69045b844432de15fa2f31f2b"
dependencies = [
"aws-smithy-http",
"aws-smithy-types",
@@ -532,7 +541,7 @@ dependencies = [
"crc32c",
"crc32fast",
"hex",
"http",
"http 0.2.9",
"http-body",
"md-5",
"pin-project-lite",
@@ -543,9 +552,9 @@ dependencies = [
[[package]]
name = "aws-smithy-eventstream"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c669e1e5fc0d79561bf7a122b118bd50c898758354fe2c53eb8f2d31507cbc3"
checksum = "e6363078f927f612b970edf9d1903ef5cef9a64d1e8423525ebb1f0a1633c858"
dependencies = [
"aws-smithy-types",
"bytes",
@@ -554,9 +563,9 @@ dependencies = [
[[package]]
name = "aws-smithy-http"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b1de8aee22f67de467b2e3d0dd0fb30859dc53f579a63bd5381766b987db644"
checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d"
dependencies = [
"aws-smithy-eventstream",
"aws-smithy-runtime-api",
@@ -564,7 +573,7 @@ dependencies = [
"bytes",
"bytes-utils",
"futures-core",
"http",
"http 0.2.9",
"http-body",
"once_cell",
"percent-encoding",
@@ -575,18 +584,18 @@ dependencies = [
[[package]]
name = "aws-smithy-json"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a46dd338dc9576d6a6a5b5a19bd678dcad018ececee11cf28ecd7588bd1a55c"
checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e"
dependencies = [
"aws-smithy-types",
]
[[package]]
name = "aws-smithy-query"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feb5b8c7a86d4b6399169670723b7e6f21a39fc833a30f5c5a2f997608178129"
checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9"
dependencies = [
"aws-smithy-types",
"urlencoding",
@@ -594,9 +603,9 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime"
version = "1.0.2"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "273479291efc55e7b0bce985b139d86b6031adb8e50f65c1f712f20ba38f6388"
checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea"
dependencies = [
"aws-smithy-async",
"aws-smithy-http",
@@ -605,7 +614,7 @@ dependencies = [
"bytes",
"fastrand 2.0.0",
"h2",
"http",
"http 0.2.9",
"http-body",
"hyper",
"hyper-rustls",
@@ -619,14 +628,14 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime-api"
version = "1.0.2"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6cebff0d977b6b6feed2fd07db52aac58ba3ccaf26cdd49f1af4add5061bef9"
checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29"
dependencies = [
"aws-smithy-async",
"aws-smithy-types",
"bytes",
"http",
"http 0.2.9",
"pin-project-lite",
"tokio",
"tracing",
@@ -635,15 +644,15 @@ dependencies = [
[[package]]
name = "aws-smithy-types"
version = "1.0.2"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7f48b3f27ddb40ab19892a5abda331f403e3cb877965e4e51171447807104af"
checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3"
dependencies = [
"base64-simd",
"bytes",
"bytes-utils",
"futures-core",
"http",
"http 0.2.9",
"http-body",
"itoa",
"num-integer",
@@ -658,24 +667,24 @@ dependencies = [
[[package]]
name = "aws-smithy-xml"
version = "0.60.0"
version = "0.60.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ec40d74a67fd395bc3f6b4ccbdf1543672622d905ef3f979689aea5b730cb95"
checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218"
dependencies = [
"xmlparser",
]
[[package]]
name = "aws-types"
version = "1.0.1"
version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8403fc56b1f3761e8efe45771ddc1165e47ec3417c68e68a4519b5cb030159ca"
checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4"
dependencies = [
"aws-credential-types",
"aws-smithy-async",
"aws-smithy-runtime-api",
"aws-smithy-types",
"http",
"http 0.2.9",
"rustc_version",
"tracing",
]
@@ -692,7 +701,7 @@ dependencies = [
"bitflags 1.3.2",
"bytes",
"futures-util",
"http",
"http 0.2.9",
"http-body",
"hyper",
"itoa",
@@ -724,7 +733,7 @@ dependencies = [
"async-trait",
"bytes",
"futures-util",
"http",
"http 0.2.9",
"http-body",
"mime",
"rustversion",
@@ -2003,9 +2012,9 @@ dependencies = [
[[package]]
name = "futures-channel"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
dependencies = [
"futures-core",
"futures-sink",
@@ -2013,9 +2022,9 @@ dependencies = [
[[package]]
name = "futures-core"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
[[package]]
name = "futures-executor"
@@ -2030,9 +2039,9 @@ dependencies = [
[[package]]
name = "futures-io"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-lite"
@@ -2051,9 +2060,9 @@ dependencies = [
[[package]]
name = "futures-macro"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
@@ -2062,15 +2071,15 @@ dependencies = [
[[package]]
name = "futures-sink"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
[[package]]
name = "futures-task"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-timer"
@@ -2080,9 +2089,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
[[package]]
name = "futures-util"
version = "0.3.28"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [
"futures-channel",
"futures-core",
@@ -2186,7 +2195,7 @@ dependencies = [
"futures-core",
"futures-sink",
"futures-util",
"http",
"http 0.2.9",
"indexmap 2.0.1",
"slab",
"tokio",
@@ -2337,6 +2346,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.5"
@@ -2344,7 +2364,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1"
dependencies = [
"bytes",
"http",
"http 0.2.9",
"pin-project-lite",
]
@@ -2407,7 +2427,7 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http 0.2.9",
"http-body",
"httparse",
"httpdate",
@@ -2426,7 +2446,7 @@ version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
dependencies = [
"http",
"http 0.2.9",
"hyper",
"log",
"rustls",
@@ -3108,7 +3128,7 @@ dependencies = [
"base64 0.13.1",
"chrono",
"getrandom 0.2.11",
"http",
"http 0.2.9",
"rand 0.8.5",
"serde",
"serde_json",
@@ -3210,7 +3230,7 @@ checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b"
dependencies = [
"async-trait",
"bytes",
"http",
"http 0.2.9",
"opentelemetry_api",
"reqwest",
]
@@ -3223,7 +3243,7 @@ checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
dependencies = [
"async-trait",
"futures-core",
"http",
"http 0.2.9",
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry-semantic-conventions",
@@ -4323,6 +4343,12 @@ dependencies = [
"regex-syntax 0.8.2",
]
[[package]]
name = "regex-lite"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e"
[[package]]
name = "regex-syntax"
version = "0.6.29"
@@ -4392,7 +4418,7 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http 0.2.9",
"http-body",
"hyper",
"hyper-rustls",
@@ -4433,7 +4459,7 @@ checksum = "4531c89d50effe1fac90d095c8b133c20c5c714204feee0bfc3fd158e784209d"
dependencies = [
"anyhow",
"async-trait",
"http",
"http 0.2.9",
"reqwest",
"serde",
"task-local-extensions",
@@ -4451,7 +4477,7 @@ dependencies = [
"chrono",
"futures",
"getrandom 0.2.11",
"http",
"http 0.2.9",
"hyper",
"parking_lot 0.11.2",
"reqwest",
@@ -4538,7 +4564,7 @@ version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
dependencies = [
"http",
"http 0.2.9",
"hyper",
"lazy_static",
"percent-encoding",
@@ -5868,7 +5894,7 @@ dependencies = [
"futures-core",
"futures-util",
"h2",
"http",
"http 0.2.9",
"http-body",
"hyper",
"hyper-timeout",
@@ -6083,7 +6109,7 @@ dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"http 0.2.9",
"httparse",
"log",
"rand 0.8.5",

View File

@@ -48,11 +48,12 @@ azure_storage_blobs = "0.18"
flate2 = "1.0.26"
async-stream = "0.3"
async-trait = "0.1"
aws-config = { version = "1.0", default-features = false, features=["rustls"] }
aws-sdk-s3 = "1.0"
aws-smithy-async = { version = "1.0", default-features = false, features=["rt-tokio"] }
aws-smithy-types = "1.0"
aws-credential-types = "1.0"
aws-config = { version = "1.1.4", default-features = false, features=["rustls"] }
aws-sdk-s3 = "1.14"
aws-sdk-secretsmanager = { version = "1.14.0" }
aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] }
aws-smithy-types = "1.1.4"
aws-credential-types = "1.1.4"
axum = { version = "0.6.20", features = ["ws"] }
base64 = "0.13.0"
bincode = "1.3"

View File

@@ -6,6 +6,8 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
aws-config.workspace = true
aws-sdk-secretsmanager.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true

View File

@@ -8,6 +8,7 @@ use anyhow::anyhow;
use attachment_service::http::make_router;
use attachment_service::persistence::Persistence;
use attachment_service::service::{Config, Service};
use aws_config::{self, BehaviorVersion, Region};
use camino::Utf8PathBuf;
use clap::Parser;
use metrics::launch_timestamp::LaunchTimestamp;
@@ -46,6 +47,100 @@ struct Cli {
database_url: String,
}
/// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this
/// type encapsulates the logic to decide which and do the loading.
struct Secrets {
database_url: String,
public_key: Option<JwtAuth>,
jwt_token: Option<String>,
}
impl Secrets {
const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url";
const JWT_TOKEN_SECRET: &'static str = "neon-storage-controller-pageserver-jwt-token";
const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key";
async fn load(args: &Cli) -> anyhow::Result<Self> {
if args.database_url.is_empty() {
Self::load_aws_sm().await
} else {
Self::load_cli(args)
}
}
async fn load_aws_sm() -> anyhow::Result<Self> {
let Ok(region) = std::env::var("AWS_REGION") else {
anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets");
};
let config = aws_config::defaults(BehaviorVersion::v2023_11_09())
.region(Region::new(region.clone()))
.load()
.await;
let asm = aws_sdk_secretsmanager::Client::new(&config);
let Some(database_url) = asm
.get_secret_value()
.secret_id(Self::DATABASE_URL_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string)
else {
anyhow::bail!(
"Database URL secret not found at {region}/{}",
Self::DATABASE_URL_SECRET
)
};
let jwt_token = asm
.get_secret_value()
.secret_id(Self::JWT_TOKEN_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string);
if jwt_token.is_none() {
tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver");
}
let public_key = asm
.get_secret_value()
.secret_id(Self::PUBLIC_KEY_SECRET)
.send()
.await?
.secret_string()
.map(str::to_string);
let public_key = match public_key {
Some(key) => Some(JwtAuth::from_key(key)?),
None => {
tracing::warn!(
"No public key set: inccoming HTTP requests will not be authenticated"
);
None
}
};
Ok(Self {
database_url,
public_key,
jwt_token,
})
}
fn load_cli(args: &Cli) -> anyhow::Result<Self> {
let public_key = match &args.public_key {
None => None,
Some(key_path) => Some(JwtAuth::from_key_path(key_path)?),
};
Ok(Self {
database_url: args.database_url.clone(),
public_key,
jwt_token: args.jwt_token.clone(),
})
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
@@ -66,23 +161,22 @@ async fn main() -> anyhow::Result<()> {
args.listen
);
let secrets = Secrets::load(&args).await?;
let config = Config {
jwt_token: args.jwt_token,
jwt_token: secrets.jwt_token,
};
let json_path = args.path;
let persistence = Arc::new(Persistence::new(args.database_url, json_path.clone()));
let persistence = Arc::new(Persistence::new(secrets.database_url, json_path.clone()));
let service = Service::spawn(config, persistence.clone()).await?;
let http_listener = tcp_listener::bind(args.listen)?;
let auth = if let Some(public_key_path) = &args.public_key {
let jwt_auth = JwtAuth::from_key_path(public_key_path)?;
Some(Arc::new(SwappableJwtAuth::new(jwt_auth)))
} else {
None
};
let auth = secrets
.public_key
.map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
let router = make_router(service, auth)
.build()
.map_err(|err| anyhow!(err))?;

View File

@@ -127,6 +127,10 @@ impl JwtAuth {
Ok(Self::new(decoding_keys))
}
pub fn from_key(key: String) -> Result<Self> {
Ok(Self::new(vec![DecodingKey::from_ed_pem(key.as_bytes())?]))
}
/// Attempt to decode the token with the internal decoding keys.
///
/// The function tries the stored decoding keys in succession,

View File

@@ -12,8 +12,7 @@ use crate::console::errors::GetAuthInfoError;
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::NeonOptions;
use crate::stream::Stream;
use crate::{
@@ -28,11 +27,26 @@ use crate::{
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use tracing::info;
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
Owned(T),
Borrowed(&'a T),
}
impl<T> std::ops::Deref for MaybeOwned<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
match self {
MaybeOwned::Owned(t) => t,
MaybeOwned::Borrowed(t) => t,
}
}
}
/// This type serves two purposes:
///
@@ -44,12 +58,9 @@ use tracing::{error, info, warn};
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
/// Cloud API (V2).
Console(Cow<'a, ConsoleBackend>, T),
Console(MaybeOwned<'a, ConsoleBackend>, T),
/// Authentication via a web browser.
Link(Cow<'a, url::ApiUrl>),
#[cfg(test)]
/// Test backend.
Test(&'a dyn TestBackend),
Link(MaybeOwned<'a, url::ApiUrl>),
}
pub trait TestBackend: Send + Sync + 'static {
@@ -67,14 +78,14 @@ impl std::fmt::Display for BackendType<'_, ()> {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
ConsoleBackend::Postgres(endpoint) => {
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
}
#[cfg(test)]
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
},
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
#[cfg(test)]
Test(_) => fmt.debug_tuple("Test").finish(),
}
}
}
@@ -85,10 +96,8 @@ impl<T> BackendType<'_, T> {
pub fn as_ref(&self) -> BackendType<'_, &T> {
use BackendType::*;
match self {
Console(c, x) => Console(Cow::Borrowed(c), x),
Link(c) => Link(Cow::Borrowed(c)),
#[cfg(test)]
Test(x) => Test(*x),
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
Link(c) => Link(MaybeOwned::Borrowed(c)),
}
}
}
@@ -102,8 +111,6 @@ impl<'a, T> BackendType<'a, T> {
match self {
Console(c, x) => Console(c, f(x)),
Link(c) => Link(c),
#[cfg(test)]
Test(x) => Test(x),
}
}
}
@@ -116,8 +123,6 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
match self {
Console(c, x) => x.map(|x| Console(c, x)),
Link(c) => Ok(Link(c)),
#[cfg(test)]
Test(x) => Ok(Test(x)),
}
}
}
@@ -147,7 +152,7 @@ impl ComputeUserInfo {
}
pub enum ComputeCredentialKeys {
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
@@ -277,42 +282,6 @@ async fn authenticate_with_secret(
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
/// wake a compute (or retrieve an existing compute session from cache)
async fn wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
compute_credentials: ComputeCredentials<ComputeCredentialKeys>,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(e.into());
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
Ok((node, compute_credentials.info))
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<EndpointId> {
@@ -321,8 +290,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => user_info.endpoint_id.clone(),
Link(_) => Some("link".into()),
#[cfg(test)]
Test(_) => Some("test".into()),
}
}
@@ -333,8 +300,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => &user_info.user,
Link(_) => "link",
#[cfg(test)]
Test(_) => "test",
}
}
@@ -359,8 +324,20 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
let compute_credentials =
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
let (cache_info, user_info) = wake_compute(ctx, &*api, compute_credentials).await?;
(cache_info, BackendType::Console(api, user_info))
let mut num_retries = 0;
let mut node =
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(any(test, feature = "testing"))]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
(node, BackendType::Console(api, compute_credentials.info))
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {
@@ -373,10 +350,6 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
BackendType::Link(url),
)
}
#[cfg(test)]
Test(_) => {
unreachable!("this function should never be called in the test backend")
}
};
info!("user successfully authenticated");
@@ -393,8 +366,6 @@ impl BackendType<'_, ComputeUserInfo> {
match self {
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
#[cfg(test)]
Test(x) => x.get_allowed_ips_and_secret(),
}
}
@@ -409,8 +380,6 @@ impl BackendType<'_, ComputeUserInfo> {
match self {
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
#[cfg(test)]
Test(x) => x.wake_compute().map(Some),
}
}
}

View File

@@ -20,7 +20,7 @@ pub(super) async fn authenticate(
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
let flow = AuthFlow::new(client);
let scram_keys = match secret {
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));

View File

@@ -172,7 +172,7 @@ pub(super) fn validate_password_and_exchange(
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => {
// test only
Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password(

View File

@@ -1,5 +1,6 @@
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::MaybeOwned;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
@@ -17,9 +18,9 @@ use proxy::usage_metrics;
use anyhow::bail;
use proxy::config::{self, ProxyConfig};
use proxy::serverless;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::{borrow::Cow, net::SocketAddr};
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -259,18 +260,13 @@ async fn main() -> anyhow::Result<()> {
}
if let auth::BackendType::Console(api, _) = &config.auth_backend {
match &**api {
proxy::console::provider::ConsoleBackend::Console(api) => {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks
.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
#[cfg(feature = "testing")]
proxy::console::provider::ConsoleBackend::Postgres(_) => {}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
@@ -369,18 +365,18 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let api = console::provider::ConsoleBackend::Console(api);
auth::BackendType::Console(Cow::Owned(api), ())
auth::BackendType::Console(MaybeOwned::Owned(api), ())
}
#[cfg(feature = "testing")]
AuthBackend::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = console::provider::mock::Api::new(url);
let api = console::provider::ConsoleBackend::Postgres(api);
auth::BackendType::Console(Cow::Owned(api), ())
auth::BackendType::Console(MaybeOwned::Owned(api), ())
}
AuthBackend::Link => {
let url = args.uri.parse()?;
auth::BackendType::Link(Cow::Owned(url))
auth::BackendType::Link(MaybeOwned::Owned(url))
}
};
let http_config = HttpConfig {

View File

@@ -1,4 +1,4 @@
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
pub mod mock;
pub mod neon;
@@ -199,7 +199,7 @@ pub mod errors {
/// Auth secret which is managed by the cloud.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
/// Md5 hash of user's password.
Md5([u8; 16]),
@@ -264,13 +264,16 @@ pub trait Api {
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
#[derive(Clone)]
#[non_exhaustive]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
/// Local mock of Cloud API (V2).
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
Postgres(mock::Api),
/// Internal testing
#[cfg(test)]
Test(Box<dyn crate::auth::backend::TestBackend>),
}
#[async_trait]
@@ -283,8 +286,10 @@ impl Api for ConsoleBackend {
use ConsoleBackend::*;
match self {
Console(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(test)]
Test(_) => unreachable!("this function should never be called in the test backend"),
}
}
@@ -296,8 +301,10 @@ impl Api for ConsoleBackend {
use ConsoleBackend::*;
match self {
Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(test)]
Test(api) => api.get_allowed_ips_and_secret(),
}
}
@@ -310,8 +317,10 @@ impl Api for ConsoleBackend {
match self {
Console(api) => api.wake_compute(ctx, user_info).await,
#[cfg(feature = "testing")]
#[cfg(any(test, feature = "testing"))]
Postgres(api) => api.wake_compute(ctx, user_info).await,
#[cfg(test)]
Test(api) => api.wake_compute(),
}
}
}

View File

@@ -19,7 +19,6 @@ use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,

View File

@@ -5,6 +5,7 @@ pub mod connect_compute;
pub mod handshake;
pub mod passthrough;
pub mod retry;
pub mod wake_compute;
use crate::{
auth,

View File

@@ -1,15 +1,16 @@
use crate::{
auth,
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, Api},
console::{self, errors::WakeComputeError},
context::RequestMonitoring,
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
proxy::retry::{retry_after, ShouldRetry},
metrics::NUM_CONNECTION_FAILURES,
proxy::{
retry::{retry_after, ShouldRetry},
wake_compute::wake_compute,
},
};
use async_trait::async_trait;
use hyper::StatusCode;
use pq_proto::StartupMessageParams;
use std::ops::ControlFlow;
use tokio::time;
use tracing::{error, info, warn};
@@ -88,39 +89,6 @@ impl ConnectMechanism for TcpMechanism<'_> {
}
}
fn report_error(e: &WakeComputeError, retry: bool) {
use crate::console::errors::ApiError;
let retry = bool_to_str(retry);
let kind = match e {
WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
ref text,
}) if text.contains("written data quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
"quota_exceeded"
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
..
}) => "api_console_locked",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::BAD_REQUEST,
..
}) => "api_console_bad_request",
WakeComputeError::ApiError(ApiError::Console { status, .. })
if status.is_server_error() =>
{
"api_console_other_server_error"
}
WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
WakeComputeError::TimeoutError => "timeout_error",
};
NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
@@ -137,7 +105,7 @@ where
mechanism.update_connect_config(&mut node_info.config);
// try once
let (config, err) = match mechanism
let err = match mechanism
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
.await
{
@@ -145,51 +113,27 @@ where
ctx.latency_timer.success();
return Ok(res);
}
Err(e) => {
error!(error = ?e, "could not connect to compute node");
(invalidate_cache(node_info), e)
}
Err(e) => e,
};
ctx.latency_timer.cache_miss();
error!(error = ?err, "could not connect to compute node");
let mut num_retries = 1;
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
let node_info = loop {
let wake_res = match user_info {
auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
// nothing to do?
auth::BackendType::Link(_) => return Err(err.into()),
// test backend
#[cfg(test)]
auth::BackendType::Test(x) => x.wake_compute(),
};
match user_info {
auth::BackendType::Console(api, info) => {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
report_error(&e, false);
return Err(e.into());
}
// failed to wake up but we can continue to retry
Ok(ControlFlow::Continue(e)) => {
report_error(&e, true);
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
// successfully woke up a compute node and can break the wakeup loop
Ok(ControlFlow::Break(mut node_info)) => {
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
break node_info;
}
ctx.latency_timer.cache_miss();
let config = invalidate_cache(node_info);
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
time::sleep(wait_duration).await;
// nothing to do?
auth::BackendType::Link(_) => {}
};
// now that we have a new node, try connect to it repeatedly.
@@ -221,23 +165,3 @@ where
time::sleep(wait_duration).await;
}
}
/// Attempts to wake up the compute node.
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
/// * Returns Ok(Break(node)) if the wakeup succeeded
/// * Returns Err(e) if there was an error
pub fn handle_try_wake(
result: Result<console::CachedNodeInfo, WakeComputeError>,
num_retries: u32,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
match result {
Err(err) => match &err {
WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
Ok(ControlFlow::Continue(err))
}
_ => Err(err),
},
// Ready to try again.
Ok(new) => Ok(ControlFlow::Break(new)),
}
}

View File

@@ -5,9 +5,9 @@ mod mitm;
use super::connect_compute::ConnectMechanism;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
use crate::config::CertResolver;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
@@ -371,6 +371,7 @@ enum ConnectAction {
Fail,
}
#[derive(Clone)]
struct TestConnectMechanism {
counter: Arc<std::sync::Mutex<usize>>,
sequence: Vec<ConnectAction>,
@@ -490,9 +491,16 @@ fn helper_create_cached_node_info() -> CachedNodeInfo {
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) {
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
let cache = helper_create_cached_node_info();
let user_info = auth::BackendType::Test(mechanism);
let user_info = auth::BackendType::Console(
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
);
(cache, user_info)
}

View File

@@ -0,0 +1,95 @@
use crate::auth::backend::ComputeUserInfo;
use crate::console::{
errors::WakeComputeError,
provider::{CachedNodeInfo, ConsoleBackend},
Api,
};
use crate::context::RequestMonitoring;
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
use crate::proxy::retry::retry_after;
use hyper::StatusCode;
use std::ops::ControlFlow;
use tracing::{error, warn};
use super::retry::ShouldRetry;
/// wake a compute (or retrieve an existing compute session from cache)
pub async fn wake_compute(
num_retries: &mut u32,
ctx: &mut RequestMonitoring,
api: &ConsoleBackend,
info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
loop {
let wake_res = api.wake_compute(ctx, info).await;
match handle_try_wake(wake_res, *num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
report_error(&e, false);
return Err(e);
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
report_error(&e, true);
}
Ok(ControlFlow::Break(n)) => return Ok(n),
}
let wait_duration = retry_after(*num_retries);
*num_retries += 1;
tokio::time::sleep(wait_duration).await;
}
}
/// Attempts to wake up the compute node.
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
/// * Returns Ok(Break(node)) if the wakeup succeeded
/// * Returns Err(e) if there was an error
pub fn handle_try_wake(
result: Result<CachedNodeInfo, WakeComputeError>,
num_retries: u32,
) -> Result<ControlFlow<CachedNodeInfo, WakeComputeError>, WakeComputeError> {
match result {
Err(err) => match &err {
WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
Ok(ControlFlow::Continue(err))
}
_ => Err(err),
},
// Ready to try again.
Ok(new) => Ok(ControlFlow::Break(new)),
}
}
fn report_error(e: &WakeComputeError, retry: bool) {
use crate::console::errors::ApiError;
let retry = bool_to_str(retry);
let kind = match e {
WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
ref text,
}) if text.contains("written data quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
"quota_exceeded"
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
..
}) => "api_console_locked",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::BAD_REQUEST,
..
}) => "api_console_bad_request",
WakeComputeError::ApiError(ApiError::Console { status, .. })
if status.is_server_error() =>
{
"api_console_other_server_error"
}
WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
WakeComputeError::TimeoutError => "timeout_error",
};
NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}

View File

@@ -3,6 +3,7 @@
//! Handles both SQL over HTTP and SQL over Websockets.
mod conn_pool;
mod json;
mod sql_over_http;
mod websocket;

View File

@@ -0,0 +1,448 @@
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter()
.map(|value| {
match value {
// special care for nulls
Value::Null => None,
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
})
.collect()
}
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{}}}", vals))
}
}
}
//
// Convert postgres row with text-encoded values to JSON object
//
pub fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, anyhow::Error> {
let iter = row
.columns()
.iter()
.zip(columns)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, typ)?
};
Ok((name.to_string(), json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
}
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
Ok(Value::Number(num))
} else {
// Pass Nan, Inf, -Inf as strings
// JS JSON.stringify() does converts them to null, but we
// want to preserve them, so we pass them as strings
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
Ok(Value::Null)
}
}
//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}
fn _pg_array_parse(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), anyhow::Error> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
let mut entries: Vec<Value> = Vec::new();
let mut entry = String::new();
// skip bounds decoration
if let Some('[') = pg_array.chars().next() {
for (_, c) in pg_array_chr.by_ref() {
if c == '=' {
break;
}
}
}
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), anyhow::Error> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
// string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
// here while we have quotation info and convert them to None.
if entry == "NULL" {
entries.push(pg_text_to_json(None, elem_type)?);
} else {
entries.push(pg_text_to_json(Some(entry), elem_type)?);
}
entry.clear();
}
Ok(())
}
while let Some((mut i, mut c)) = pg_array_chr.next() {
let mut escaped = false;
if c == '\\' {
escaped = true;
(i, c) = pg_array_chr.next().unwrap();
}
match c {
'{' if !quote => {
level += 1;
if level > 1 {
let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
entries.push(res);
for _ in 0..off - 1 {
pg_array_chr.next();
}
}
}
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
if nested {
return Ok((Value::Array(entries), i));
}
}
}
'"' if !escaped => {
if quote {
// end of quoted string, so push it manually without any checks
// for emptiness or nulls
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
entry.clear();
}
quote = !quote;
}
',' if !quote => {
push_checked(&mut entry, &mut entries, elem_type)?;
}
_ => {
entry.push(c);
}
}
}
if level != 0 {
return Err(anyhow::anyhow!("unbalanced array"));
}
Ok((Value::Array(entries), 0))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![None]);
}
#[test]
fn test_json_array_to_pg_array() {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(
"{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
)]
);
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
);
}
#[test]
fn test_atomic_types_parse() {
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
json!("-Infinity")
);
let json: Value =
serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
.unwrap();
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
&Type::JSONB
)
.unwrap(),
json
);
}
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
json!(["aa\"\\,a", "cha", "bbbb"])
);
assert_eq!(
pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
json!([["foo", "bar"], ["bee", "bop"]])
);
assert_eq!(
pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
json!([[[["foo", null, "bop", "bup"]]]])
);
assert_eq!(
pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
);
}
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
assert_eq!(
pb(r#"{{t,f},{f,t}}"#),
json!([[true, false], [false, true]])
);
assert_eq!(
pb(r#"{{t,NULL},{NULL,f}}"#),
json!([[true, null], [null, false]])
);
}
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
assert_eq!(
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
json!([1.1, 2.2, 3.3])
);
assert_eq!(
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
json!([1.1, 2.2, 3.3])
);
assert_eq!(
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
json!(["NaN", "Infinity", "-Infinity"])
);
assert_eq!(
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
json!(["NaN", "Infinity", "-Infinity"])
);
}
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
json!([[[1, 2, 3], [4, 5, 6]]])
);
}
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(
pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
json!([{"foo": 1, "bar": 2}])
);
assert_eq!(
pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
json!([{"foo": 1}, {"bar": 2}])
);
assert_eq!(
pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
json!([[{"foo": 1}, {"bar": 2}]])
);
}
}

View File

@@ -12,16 +12,12 @@ use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Row;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::instrument;
@@ -40,6 +36,7 @@ use crate::RoleName;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
use super::json::{json_to_pg_text, pg_text_row_to_json};
use super::SERVERLESS_DRIVER_SNI;
#[derive(serde::Deserialize)]
@@ -72,62 +69,6 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter()
.map(|value| {
match value {
// special care for nulls
Value::Null => None,
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
})
.collect()
}
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{}}}", vals))
}
}
}
fn get_conn_info(
ctx: &mut RequestMonitoring,
headers: &HeaderMap,
@@ -611,389 +552,3 @@ async fn query_to_json<T: GenericClient>(
}),
))
}
//
// Convert postgres row with text-encoded values to JSON object
//
pub fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, anyhow::Error> {
let iter = row
.columns()
.iter()
.zip(columns)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, typ)?
};
Ok((name.to_string(), json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
//
// Convert postgres text-encoded value to JSON value
//
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
}
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
Ok(Value::Number(num))
} else {
// Pass Nan, Inf, -Inf as strings
// JS JSON.stringify() does converts them to null, but we
// want to preserve them, so we pass them as strings
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
Ok(Value::Null)
}
}
//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, anyhow::Error> {
_pg_array_parse(pg_array, elem_type, false).map(|(v, _)| v)
}
fn _pg_array_parse(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), anyhow::Error> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
let mut entries: Vec<Value> = Vec::new();
let mut entry = String::new();
// skip bounds decoration
if let Some('[') = pg_array.chars().next() {
for (_, c) in pg_array_chr.by_ref() {
if c == '=' {
break;
}
}
}
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), anyhow::Error> {
if !entry.is_empty() {
// While in usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays we get NULL as unquoted 'NULL' string (while
// string with value 'NULL' will be represented by '"NULL"'). So catch NULLs
// here while we have quotation info and convert them to None.
if entry == "NULL" {
entries.push(pg_text_to_json(None, elem_type)?);
} else {
entries.push(pg_text_to_json(Some(entry), elem_type)?);
}
entry.clear();
}
Ok(())
}
while let Some((mut i, mut c)) = pg_array_chr.next() {
let mut escaped = false;
if c == '\\' {
escaped = true;
(i, c) = pg_array_chr.next().unwrap();
}
match c {
'{' if !quote => {
level += 1;
if level > 1 {
let (res, off) = _pg_array_parse(&pg_array[i..], elem_type, true)?;
entries.push(res);
for _ in 0..off - 1 {
pg_array_chr.next();
}
}
}
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
if nested {
return Ok((Value::Array(entries), i));
}
}
}
'"' if !escaped => {
if quote {
// end of quoted string, so push it manually without any checks
// for emptiness or nulls
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
entry.clear();
}
quote = !quote;
}
',' if !quote => {
push_checked(&mut entry, &mut entries, elem_type)?;
}
_ => {
entry.push(c);
}
}
}
if level != 0 {
return Err(anyhow::anyhow!("unbalanced array"));
}
Ok((Value::Array(entries), 0))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![None]);
}
#[test]
fn test_json_array_to_pg_array() {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(
"{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
)]
);
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
);
}
#[test]
fn test_atomic_types_parse() {
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
json!("-Infinity")
);
let json: Value =
serde_json::from_str("{\"s\":\"str\",\"n\":42,\"f\":4.2,\"a\":[null,3,\"a\"]}")
.unwrap();
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
&Type::JSONB
)
.unwrap(),
json
);
}
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
json!(["aa\"\\,a", "cha", "bbbb"])
);
assert_eq!(
pt(r#"{{"foo","bar"},{"bee","bop"}}"#),
json!([["foo", "bar"], ["bee", "bop"]])
);
assert_eq!(
pt(r#"{{{{"foo",NULL,"bop",bup}}}}"#),
json!([[[["foo", null, "bop", "bup"]]]])
);
assert_eq!(
pt(r#"{{"1",2,3},{4,NULL,6},{NULL,NULL,NULL}}"#),
json!([["1", "2", "3"], ["4", null, "6"], [null, null, null]])
);
}
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
assert_eq!(
pb(r#"{{t,f},{f,t}}"#),
json!([[true, false], [false, true]])
);
assert_eq!(
pb(r#"{{t,NULL},{NULL,f}}"#),
json!([[true, null], [null, false]])
);
}
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT8), json!(["1", "2", "3"]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT4), json!([1.0, 2.0, 3.0]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::FLOAT8), json!([1.0, 2.0, 3.0]));
assert_eq!(
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT4),
json!([1.1, 2.2, 3.3])
);
assert_eq!(
pn(r#"{1.1,2.2,3.3}"#, &Type::FLOAT8),
json!([1.1, 2.2, 3.3])
);
assert_eq!(
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT4),
json!(["NaN", "Infinity", "-Infinity"])
);
assert_eq!(
pn(r#"{NaN,Infinity,-Infinity}"#, &Type::FLOAT8),
json!(["NaN", "Infinity", "-Infinity"])
);
}
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
json!([[[1, 2, 3], [4, 5, 6]]])
);
}
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(
pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
json!([{"foo": 1, "bar": 2}])
);
assert_eq!(
pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
json!([{"foo": 1}, {"bar": 2}])
);
assert_eq!(
pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
json!([[{"foo": 1}, {"bar": 2}]])
);
}
}

View File

@@ -15,7 +15,7 @@ publish = false
[dependencies]
anyhow = { version = "1", features = ["backtrace"] }
aws-config = { version = "1", default-features = false, features = ["rustls", "sso"] }
aws-runtime = { version = "1", default-features = false, features = ["event-stream", "sigv4a"] }
aws-runtime = { version = "1", default-features = false, features = ["event-stream", "http-02x", "sigv4a"] }
aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] }
aws-smithy-async = { version = "1", default-features = false, features = ["rt-tokio"] }
aws-smithy-http = { version = "0.60", default-features = false, features = ["event-stream"] }