Mirror of https://github.com/neondatabase/neon.git (synced 2026-02-03 02:30:37 +00:00)

Compare commits: enable_v17...proxy-abst (35 commits)
| SHA1 |
|---|
| 77e431d80a |
| 784571eac7 |
| ffd0d875cf |
| dbf34a17ce |
| 4001e24745 |
| d9a59c5a3f |
| c7afbe55c9 |
| d0930c9d1d |
| addfff61b5 |
| cb28721eee |
| 07076e88a9 |
| 2feba8a3da |
| 323bd018cd |
| ad267d849f |
| 8cd7b5bf54 |
| 47c3c9a413 |
| eae4470bb6 |
| 2d248aea6f |
| 6c05f89f7d |
| db53f98725 |
| 04a6222418 |
| dcf7af5a16 |
| 37158d0424 |
| 60fb840e1f |
| 52232dd85c |
| 8ef0c38b23 |
| 56bb1ac458 |
| 19db9e9aad |
| 4e9b32c442 |
| 2fac0b7fac |
| e3d6ecaeee |
| d785fcb5ff |
| 552fa2b972 |
| 9d93dd4807 |
| 53b6e1a01c |
Cargo.lock (generated, 250 changes)
@@ -674,7 +674,6 @@ checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
dependencies = [
 "async-trait",
 "axum-core 0.3.4",
 "base64 0.21.1",
 "bitflags 1.3.2",
 "bytes",
 "futures-util",
@@ -689,13 +688,7 @@ dependencies = [
 "pin-project-lite",
 "rustversion",
 "serde",
 "serde_json",
 "serde_path_to_error",
 "serde_urlencoded",
 "sha1",
 "sync_wrapper 0.1.2",
 "tokio",
 "tokio-tungstenite",
 "tower",
 "tower-layer",
 "tower-service",
@@ -709,11 +702,14 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
 "async-trait",
 "axum-core 0.4.5",
 "base64 0.21.1",
 "bytes",
 "futures-util",
 "http 1.1.0",
 "http-body 1.0.0",
 "http-body-util",
 "hyper 1.4.1",
 "hyper-util",
 "itoa",
 "matchit 0.7.0",
 "memchr",
@@ -722,10 +718,17 @@ dependencies = [
 "pin-project-lite",
 "rustversion",
 "serde",
 "serde_json",
 "serde_path_to_error",
 "serde_urlencoded",
 "sha1",
 "sync_wrapper 1.0.1",
 "tokio",
 "tokio-tungstenite",
 "tower",
 "tower-layer",
 "tower-service",
 "tracing",
]

[[package]]
@@ -763,6 +766,7 @@ dependencies = [
 "sync_wrapper 1.0.1",
 "tower-layer",
 "tower-service",
 "tracing",
]

[[package]]
@@ -967,7 +971,7 @@ dependencies = [
 "clang-sys",
 "itertools 0.12.1",
 "log",
 "prettyplease",
 "prettyplease 0.2.17",
 "proc-macro2",
 "quote",
 "regex",
@@ -1261,6 +1265,7 @@ version = "0.1.0"
dependencies = [
 "anyhow",
 "bytes",
 "camino",
 "cfg-if",
 "chrono",
 "clap",
@@ -2449,6 +2454,15 @@ dependencies = [
 "digest",
]

[[package]]
name = "home"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5"
dependencies = [
 "windows-sys 0.52.0",
]

[[package]]
name = "hostname"
version = "0.4.0"
@@ -2643,15 +2657,14 @@ dependencies = [

[[package]]
name = "hyper-timeout"
version = "0.5.1"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
 "hyper 1.4.1",
 "hyper-util",
 "hyper 0.14.30",
 "pin-project-lite",
 "tokio",
 "tower-service",
 "tokio-io-timeout",
]

[[package]]
@@ -3457,7 +3470,7 @@ dependencies = [
 "opentelemetry-http",
 "opentelemetry-proto",
 "opentelemetry_sdk",
 "prost",
 "prost 0.13.3",
 "reqwest 0.12.4",
 "thiserror",
]
@@ -3470,8 +3483,8 @@ checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9"
dependencies = [
 "opentelemetry",
 "opentelemetry_sdk",
 "prost",
 "tonic",
 "prost 0.13.3",
 "tonic 0.12.3",
]

[[package]]
@@ -4165,6 +4178,16 @@ dependencies = [
 "tokio",
]

[[package]]
name = "prettyplease"
version = "0.1.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86"
dependencies = [
 "proc-macro2",
 "syn 1.0.109",
]

[[package]]
name = "prettyplease"
version = "0.2.17"
@@ -4235,6 +4258,16 @@ dependencies = [
 "thiserror",
]

[[package]]
name = "prost"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
 "bytes",
 "prost-derive 0.11.9",
]

[[package]]
name = "prost"
version = "0.13.3"
@@ -4242,28 +4275,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f"
dependencies = [
 "bytes",
 "prost-derive",
 "prost-derive 0.13.3",
]

[[package]]
name = "prost-build"
version = "0.13.3"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
dependencies = [
 "bytes",
 "heck 0.5.0",
 "itertools 0.12.1",
 "heck 0.4.1",
 "itertools 0.10.5",
 "lazy_static",
 "log",
 "multimap",
 "once_cell",
 "petgraph",
 "prettyplease",
 "prost",
 "prettyplease 0.1.25",
 "prost 0.11.9",
 "prost-types",
 "regex",
 "syn 2.0.52",
 "syn 1.0.109",
 "tempfile",
 "which",
]

[[package]]
name = "prost-derive"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4"
dependencies = [
 "anyhow",
 "itertools 0.10.5",
 "proc-macro2",
 "quote",
 "syn 1.0.109",
]

[[package]]
@@ -4281,11 +4328,11 @@ dependencies = [

[[package]]
name = "prost-types"
version = "0.13.3"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
 "prost",
 "prost 0.11.9",
]

[[package]]
@@ -5047,21 +5094,6 @@ dependencies = [
 "zeroize",
]

[[package]]
name = "rustls"
version = "0.23.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebbbdb961df0ad3f2652da8f3fdc4b36122f568f968f45ad3316f26c025c677b"
dependencies = [
 "log",
 "once_cell",
 "ring",
 "rustls-pki-types",
 "rustls-webpki 0.102.2",
 "subtle",
 "zeroize",
]

[[package]]
name = "rustls-native-certs"
version = "0.6.2"
@@ -5087,19 +5119,6 @@ dependencies = [
 "security-framework",
]

[[package]]
name = "rustls-native-certs"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a"
dependencies = [
 "openssl-probe",
 "rustls-pemfile 2.1.1",
 "rustls-pki-types",
 "schannel",
 "security-framework",
]

[[package]]
name = "rustls-pemfile"
version = "1.0.2"
@@ -5175,7 +5194,6 @@ dependencies = [
 "fail",
 "futures",
 "hex",
 "http 1.1.0",
 "humantime",
 "hyper 0.14.30",
 "metrics",
@@ -5732,22 +5750,19 @@ version = "0.1.0"
dependencies = [
 "anyhow",
 "async-stream",
 "bytes",
 "clap",
 "const_format",
 "futures",
 "futures-core",
 "futures-util",
 "http-body-util",
 "humantime",
 "hyper 1.4.1",
 "hyper-util",
 "hyper 0.14.30",
 "metrics",
 "once_cell",
 "parking_lot 0.12.1",
 "prost",
 "prost 0.11.9",
 "tokio",
 "tonic",
 "tonic 0.9.2",
 "tonic-build",
 "tracing",
 "utils",
@@ -6291,17 +6306,6 @@ dependencies = [
 "tokio",
]

[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
 "rustls 0.23.7",
 "rustls-pki-types",
 "tokio",
]

[[package]]
name = "tokio-stream"
version = "0.1.16"
@@ -6330,9 +6334,9 @@ dependencies = [

[[package]]
name = "tokio-tungstenite"
version = "0.20.0"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b2dbec703c26b00d74844519606ef15d09a7d6857860f84ad223dec002ddea2"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
 "futures-util",
 "log",
@@ -6393,30 +6397,29 @@ dependencies = [

[[package]]
name = "tonic"
version = "0.12.3"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
checksum = "3082666a3a6433f7f511c7192923fa1fe07c69332d3c6a2e6bb040b569199d5a"
dependencies = [
 "async-stream",
 "async-trait",
 "axum 0.7.5",
 "base64 0.22.1",
 "axum 0.6.20",
 "base64 0.21.1",
 "bytes",
 "h2 0.4.4",
 "http 1.1.0",
 "http-body 1.0.0",
 "http-body-util",
 "hyper 1.4.1",
 "futures-core",
 "futures-util",
 "h2 0.3.26",
 "http 0.2.9",
 "http-body 0.4.5",
 "hyper 0.14.30",
 "hyper-timeout",
 "hyper-util",
 "percent-encoding",
 "pin-project",
 "prost",
 "rustls-native-certs 0.8.0",
 "rustls-pemfile 2.1.1",
 "socket2",
 "prost 0.11.9",
 "rustls-native-certs 0.6.2",
 "rustls-pemfile 1.0.2",
 "tokio",
 "tokio-rustls 0.26.0",
 "tokio-rustls 0.24.0",
 "tokio-stream",
 "tower",
 "tower-layer",
@@ -6425,17 +6428,37 @@ dependencies = [
]

[[package]]
name = "tonic-build"
name = "tonic"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11"
checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [
 "prettyplease",
 "async-trait",
 "base64 0.22.1",
 "bytes",
 "http 1.1.0",
 "http-body 1.0.0",
 "http-body-util",
 "percent-encoding",
 "pin-project",
 "prost 0.13.3",
 "tokio-stream",
 "tower-layer",
 "tower-service",
 "tracing",
]

[[package]]
name = "tonic-build"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6fdaae4c2c638bb70fe42803a26fbd6fc6ac8c72f5c59f67ecc2a2dcabf4b07"
dependencies = [
 "prettyplease 0.1.25",
 "proc-macro2",
 "prost-build",
 "prost-types",
 "quote",
 "syn 2.0.52",
 "syn 1.0.109",
]

[[package]]
@@ -6606,14 +6629,14 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"

[[package]]
name = "tungstenite"
version = "0.20.1"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
 "byteorder",
 "bytes",
 "data-encoding",
 "http 0.2.9",
 "http 1.1.0",
 "httparse",
 "log",
 "rand 0.8.5",
@@ -6841,7 +6864,7 @@ name = "vm_monitor"
version = "0.1.0"
dependencies = [
 "anyhow",
 "axum 0.6.20",
 "axum 0.7.5",
 "cgroups-rs",
 "clap",
 "futures",
@@ -7072,6 +7095,18 @@ dependencies = [
 "rustls-pki-types",
]

[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
 "either",
 "home",
 "once_cell",
 "rustix",
]

[[package]]
name = "whoami"
version = "1.5.1"
@@ -7300,14 +7335,9 @@ version = "0.1.0"
dependencies = [
 "ahash",
 "anyhow",
 "aws-config",
 "aws-runtime",
 "aws-sigv4",
 "aws-smithy-async",
 "aws-smithy-http",
 "aws-smithy-types",
 "base64 0.21.1",
 "base64ct",
 "bitflags 2.4.1",
 "bytes",
 "camino",
 "cc",
@@ -7335,6 +7365,7 @@ dependencies = [
 "hyper 1.4.1",
 "hyper-util",
 "indexmap 1.9.3",
 "itertools 0.10.5",
 "itertools 0.12.1",
 "lazy_static",
 "libc",
@@ -7346,9 +7377,8 @@ dependencies = [
 "num-traits",
 "once_cell",
 "parquet",
 "prettyplease",
 "proc-macro2",
 "prost",
 "prost 0.11.9",
 "quote",
 "rand 0.8.5",
 "regex",
@@ -7371,15 +7401,13 @@ dependencies = [
 "time",
 "time-macros",
 "tokio",
 "tokio-stream",
 "tokio-rustls 0.24.0",
 "tokio-util",
 "toml_edit",
 "tonic",
 "tower",
 "tracing",
 "tracing-core",
 "url",
 "uuid",
 "zeroize",
 "zstd",
 "zstd-safe",
Cargo.toml (16 changes)
@@ -53,7 +53,7 @@ azure_storage_blobs = { version = "0.19", default-features = false, features = [
flate2 = "1.0.26"
async-stream = "0.3"
async-trait = "0.1"
aws-config = { version = "1.5", default-features = false, features=["rustls"] }
aws-config = { version = "1.5", default-features = false, features=["rustls", "sso"] }
aws-sdk-s3 = "1.52"
aws-sdk-iam = "1.46.0"
aws-smithy-async = { version = "1.2.1", default-features = false, features=["rt-tokio"] }
@@ -61,7 +61,7 @@ aws-smithy-types = "1.2"
aws-credential-types = "1.2.0"
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
aws-types = "1.3"
axum = { version = "0.6.20", features = ["ws"] }
axum = { version = "0.7.5", features = ["ws"] }
base64 = "0.13.0"
bincode = "1.3"
bindgen = "0.70"
@@ -99,10 +99,10 @@ http-types = { version = "2", default-features = false }
http-body-util = "0.1.2"
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper_1 = { package = "hyper", version = "1.4" }
hyper0 = { package = "hyper", version = "0.14" }
hyper = "1.4"
hyper-util = "0.1"
tokio-tungstenite = "0.20.0"
tokio-tungstenite = "0.21.0"
indexmap = "2"
indoc = "2"
ipnet = "2.9.0"
@@ -130,7 +130,7 @@ pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13"
prost = "0.11"
rand = "0.8"
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
@@ -178,7 +178,7 @@ tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = {version = "0.12.3", features = ["tls", "tls-roots"]}
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
tower-service = "0.3.2"
tracing = "0.1"
tracing-error = "0.2"
@@ -246,7 +246,7 @@ criterion = "0.5.1"
rcgen = "0.12"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.12"
tonic-build = "0.9"

[patch.crates-io]
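The `hyper0 = { package = "hyper", version = "0.14" }` line above uses Cargo's package-rename feature: the crates.io package is still `hyper`, but this workspace refers to the 0.14 series under the name `hyper0`, freeing the plain `hyper` name for 1.4. A minimal sketch of how a consuming crate keeps its existing `hyper::` paths compiling (this is exactly the trick the `extern crate hyper0 as hyper;` lines later in this diff apply):

```rust
// Crate root (e.g. lib.rs), assuming the Cargo.toml entry
//   hyper0 = { package = "hyper", version = "0.14" }
extern crate hyper0 as hyper; // alias the renamed dependency back to `hyper`

use hyper::{Body, Response}; // these resolve to hyper 0.14 types

fn empty_response() -> Response<Body> {
    Response::new(Body::empty())
}
```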
@@ -58,7 +58,7 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
1. Install XCode and dependencies
```
xcode-select --install
brew install protobuf openssl flex bison icu4c pkg-config
brew install protobuf openssl flex bison icu4c pkg-config m4

# add openssl to PATH, required for ed25519 keys generation in neon_local
echo 'export PATH="$(brew --prefix openssl)/bin:$PATH"' >> ~/.zshrc
@@ -880,9 +880,6 @@ RUN case "${PG_VERSION}" in "v17") \
    mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
    sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
    cargo pgrx install --release
    # it's needed to enable extension because it uses untrusted C language
    # sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_session_jwt.control && \
    # echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_session_jwt.control

#########################################################################################
#
@@ -1078,6 +1075,20 @@ RUN set -e \
    && make -j $(nproc) dist_man_MANS= \
    && make install dist_man_MANS=

#########################################################################################
#
# Compile the Neon-specific `local_proxy` binary
#
#########################################################################################
FROM $REPOSITORY/$IMAGE:$TAG AS local_proxy
ARG BUILD_TAG
ENV BUILD_TAG=$BUILD_TAG

USER nonroot
# Copy entire project to get Cargo.* files with proper dependencies for the whole project
COPY --chown=nonroot . .
RUN mold -run cargo build --locked --profile release-line-debug-size-lto --bin local_proxy

#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
@@ -1216,6 +1227,10 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/pgbouncer
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini

# local_proxy and its config
COPY --from=local_proxy --chown=postgres /home/nonroot/target/release-line-debug-size-lto/local_proxy /usr/local/bin/local_proxy
RUN mkdir -p /etc/local_proxy && chown postgres:postgres /etc/local_proxy

# Metrics exporter binaries and configuration files
COPY --from=postgres-exporter /bin/postgres_exporter /bin/postgres_exporter
COPY --from=sql-exporter /bin/sql_exporter /bin/sql_exporter
@@ -19,6 +19,10 @@ commands:
    user: postgres
    sysvInitAction: respawn
    shell: '/usr/local/bin/pgbouncer /etc/pgbouncer.ini'
  - name: local_proxy
    user: postgres
    sysvInitAction: respawn
    shell: '/usr/local/bin/local_proxy --config-path /etc/local_proxy/config.json --pid-path /etc/local_proxy/pid --http 0.0.0.0:10432'
  - name: postgres-exporter
    user: nobody
    sysvInitAction: respawn
@@ -11,12 +11,13 @@ testing = []

[dependencies]
anyhow.workspace = true
camino.workspace = true
chrono.workspace = true
cfg-if.workspace = true
clap.workspace = true
flate2.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
hyper0 = { workspace = true, features = ["full"] }
nix.workspace = true
notify.workspace = true
num_cpus.workspace = true
@@ -34,6 +34,7 @@ use nix::sys::signal::{kill, Signal};
use remote_storage::{DownloadError, RemotePath};

use crate::checker::create_availability_check_data;
use crate::local_proxy;
use crate::logger::inlinify;
use crate::pg_helpers::*;
use crate::spec::*;
@@ -886,6 +887,11 @@ impl ComputeNode {
        // 'Close' connection
        drop(client);

        if let Some(ref local_proxy) = spec.local_proxy_config {
            info!("configuring local_proxy");
            local_proxy::configure(local_proxy).context("apply_config local_proxy")?;
        }

        // Run migrations separately to not hold up cold starts
        thread::spawn(move || {
            let mut connstr = connstr.clone();
@@ -936,6 +942,19 @@ impl ComputeNode {
            });
        }

        if let Some(ref local_proxy) = spec.local_proxy_config {
            info!("configuring local_proxy");

            // Spawn a thread to do the configuration,
            // so that we don't block the main thread that starts Postgres.
            let local_proxy = local_proxy.clone();
            let _handle = Some(thread::spawn(move || {
                if let Err(err) = local_proxy::configure(&local_proxy) {
                    error!("error while configuring local_proxy: {err:?}");
                }
            }));
        }

        // Write new config
        let pgdata_path = Path::new(&self.pgdata);
        let postgresql_conf_path = pgdata_path.join("postgresql.conf");
@@ -1023,6 +1042,19 @@ impl ComputeNode {
            });
        }

        if let Some(local_proxy) = &pspec.spec.local_proxy_config {
            info!("configuring local_proxy");

            // Spawn a thread to do the configuration,
            // so that we don't block the main thread that starts Postgres.
            let local_proxy = local_proxy.clone();
            let _handle = thread::spawn(move || {
                if let Err(err) = local_proxy::configure(&local_proxy) {
                    error!("error while configuring local_proxy: {err:?}");
                }
            });
        }

        info!(
            "start_compute spec.remote_extensions {:?}",
            pspec.spec.remote_extensions
@@ -2,6 +2,9 @@
//! configuration.
#![deny(unsafe_code)]
#![deny(clippy::undocumented_unsafe_blocks)]

extern crate hyper0 as hyper;

pub mod checker;
pub mod config;
pub mod configurator;
@@ -12,6 +15,7 @@ pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod local_proxy;
pub mod lsn_lease;
mod migration;
pub mod monitor;
compute_tools/src/local_proxy.rs (new file, 56 lines)
@@ -0,0 +1,56 @@
//! Local Proxy is a feature of our BaaS Neon Authorize project.
//!
//! Local Proxy validates JWTs and manages the pg_session_jwt extension.
//! It also maintains a connection pool to postgres.

use anyhow::{Context, Result};
use camino::Utf8Path;
use compute_api::spec::LocalProxySpec;
use nix::sys::signal::Signal;
use utils::pid_file::{self, PidFileRead};

pub fn configure(local_proxy: &LocalProxySpec) -> Result<()> {
    write_local_proxy_conf("/etc/local_proxy/config.json".as_ref(), local_proxy)?;
    notify_local_proxy("/etc/local_proxy/pid".as_ref())?;

    Ok(())
}

/// Create or completely rewrite configuration file specified by `path`
fn write_local_proxy_conf(path: &Utf8Path, local_proxy: &LocalProxySpec) -> Result<()> {
    let config =
        serde_json::to_string_pretty(local_proxy).context("serializing LocalProxySpec to json")?;
    std::fs::write(path, config).with_context(|| format!("writing {path}"))?;

    Ok(())
}

/// Notify local proxy about a new config file.
fn notify_local_proxy(path: &Utf8Path) -> Result<()> {
    match pid_file::read(path)? {
        // if the file doesn't exist, or isn't locked, local_proxy isn't running
        // and will naturally pick up our config later
        PidFileRead::NotExist | PidFileRead::NotHeldByAnyProcess(_) => {}
        PidFileRead::LockedByOtherProcess(pid) => {
            // From the pid_file docs:
            //
            // > 1. The other process might exit at any time, turning the given PID stale.
            // > 2. There is a small window in which `claim_for_current_process` has already
            // >    locked the file but not yet updated its contents. [`read`] will return
            // >    this variant here, but with the old file contents, i.e., a stale PID.
            // >
            // > The kernel is free to recycle PID once it has been `wait(2)`ed upon by
            // > its creator. Thus, acting upon a stale PID, e.g., by issuing a `kill`
            // > system call on it, bears the risk of killing an unrelated process.
            // > This is an inherent limitation of using pidfiles.
            // > The only race-free solution is to have a supervisor-process with a lifetime
            // > that exceeds that of all of its child-processes (e.g., `runit`, `supervisord`).
            //
            // This is an ok risk as we only send a SIGHUP which likely won't actually
            // kill the process, only reload config.
            nix::sys::signal::kill(pid, Signal::SIGHUP).context("sending signal to local_proxy")?;
        }
    }

    Ok(())
}
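A hypothetical caller-side sketch of this module (a real `LocalProxySpec` arrives from the control plane as part of the compute spec; `jwks: None` is just the smallest valid value under the post-change struct shown later in this diff):

```rust
use compute_api::spec::LocalProxySpec;

fn apply_local_proxy_spec() -> anyhow::Result<()> {
    let spec = LocalProxySpec { jwks: None };
    // Rewrites /etc/local_proxy/config.json, then SIGHUPs local_proxy if it is running;
    // if it isn't, the new config is simply picked up at next start.
    crate::local_proxy::configure(&spec)?;
    Ok(())
}
```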
@@ -14,7 +14,7 @@ humantime.workspace = true
nix.workspace = true
once_cell.workspace = true
humantime-serde.workspace = true
hyper.workspace = true
hyper0.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["blocking", "json"] }
scopeguard.workspace = true
@@ -599,6 +599,7 @@ impl Endpoint {
            remote_extensions,
            pgbouncer_settings: None,
            shard_stripe_size: Some(shard_stripe_size),
            local_proxy_config: None,
        };
        let spec_path = self.endpoint_path().join("spec.json");
        std::fs::write(spec_path, serde_json::to_string_pretty(&spec)?)?;
@@ -3,7 +3,7 @@ use crate::{
    local_env::{LocalEnv, NeonStorageControllerConf},
};
use camino::{Utf8Path, Utf8PathBuf};
use hyper::Uri;
use hyper0::Uri;
use nix::unistd::Pid;
use pageserver_api::{
    controller_api::{
@@ -5,7 +5,7 @@
Currently we build two main images:

- [neondatabase/neon](https://hub.docker.com/repository/docker/neondatabase/neon) — image with pre-built `pageserver`, `safekeeper` and `proxy` binaries and all the required runtime dependencies. Built from [/Dockerfile](/Dockerfile).
- [neondatabase/compute-node-v16](https://hub.docker.com/repository/docker/neondatabase/compute-node-v16) — compute node image with pre-built Postgres binaries from [neondatabase/postgres](https://github.com/neondatabase/postgres). Similar images exist for v15 and v14.
- [neondatabase/compute-node-v16](https://hub.docker.com/repository/docker/neondatabase/compute-node-v16) — compute node image with pre-built Postgres binaries from [neondatabase/postgres](https://github.com/neondatabase/postgres). Similar images exist for v15 and v14. Built from [/compute-node/Dockerfile](/compute/Dockerfile.compute-node).

And additional intermediate image:
@@ -106,6 +106,10 @@ pub struct ComputeSpec {
    // Stripe size for pageserver sharding, in pages
    #[serde(default)]
    pub shard_stripe_size: Option<usize>,

    /// Local Proxy configuration used for JWT authentication
    #[serde(default)]
    pub local_proxy_config: Option<LocalProxySpec>,
}

/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -278,11 +282,13 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;

/// Configured the local-proxy application with the relevant JWKS and roles it should
/// Configured the local_proxy application with the relevant JWKS and roles it should
/// use for authorizing connect requests using JWT.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LocalProxySpec {
    pub jwks: Vec<JwksSettings>,
    #[serde(default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<Vec<JwksSettings>>,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
@@ -296,7 +296,14 @@ pub mod defaults {

    pub const DEFAULT_INGEST_BATCH_SIZE: u64 = 100;

    pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 128 * 1024; // 128 KiB
    /// Soft limit for the maximum size of a vectored read.
    ///
    /// This is determined by the largest NeonWalRecord that can exist (minus dbdir and reldir keys
    /// which are bounded by the blob io limits only). As of this writing, that is a `NeonWalRecord::ClogSetCommitted` record,
    /// with 32k xids. That's the max number of XIDS on a single CLOG page. The size of such a record
    /// is `sizeof(Transactionid) * 32768 + (some fixed overhead from 'timestamp`, the Vec length and whatever extra serde serialization adds)`.
    /// That is, slightly above 128 kB.
    pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB

    pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm =
        ImageCompressionAlgorithm::Zstd { level: Some(1) };
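For reference, the arithmetic behind the new constant: `TransactionId` is 4 bytes in Postgres, and 4 × 32768 = 131072 bytes = exactly 128 KiB, so rounding the limit up to 130 KiB leaves about 2 KiB of headroom for the fixed overhead the doc comment mentions. A compile-time sanity check along these lines (not part of the patch) would be:

```rust
// Not in the actual patch; just verifies the doc comment's math at compile time.
const MAX_CLOG_XIDS_PER_PAGE: usize = 32768;
const XID_SIZE: usize = 4; // sizeof(TransactionId) in Postgres
const _: () = assert!(XID_SIZE * MAX_CLOG_XIDS_PER_PAGE <= 130 * 1024);
```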
@@ -748,6 +748,16 @@ impl Key {
        self.field1 == 0x00 && self.field4 != 0 && self.field6 != 0xffffffff
    }

    #[inline(always)]
    pub fn is_rel_dir_key(&self) -> bool {
        self.field1 == 0x00
            && self.field2 != 0
            && self.field3 != 0
            && self.field4 == 0
            && self.field5 == 0
            && self.field6 == 1
    }

    /// Guaranteed to return `Ok()` if [`Self::is_rel_block_key`] returns `true` for `key`.
    #[inline(always)]
    pub fn to_rel_block(self) -> anyhow::Result<(RelTag, BlockNumber)> {
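For context on the field checks above: in this key encoding, `field2` through `field6` carry the tablespace OID, database OID, relation OID, fork number, and block number, respectively. So `is_rel_dir_key` matches the per-(tablespace, database) "relation directory" key, which has its relnode and fork zeroed and its block number fixed at 1, while `is_rel_block_key` (just above it) only requires a non-zero relnode and a non-sentinel block number.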
@@ -16,7 +16,7 @@ aws-sdk-s3.workspace = true
bytes.workspace = true
camino = { workspace = true, features = ["serde1"] }
humantime-serde.workspace = true
hyper = { workspace = true, features = ["stream"] }
hyper0 = { workspace = true, features = ["stream"] }
futures.workspace = true
serde.workspace = true
serde_json.workspace = true
@@ -14,7 +14,7 @@ use std::time::SystemTime;

use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, RetryOptions};
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
@@ -33,10 +33,10 @@ use tracing::debug;
use utils::backoff;

use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind};
use crate::ListingObject;
use crate::{
    config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError, Listing,
    ListingMode, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
    config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError,
    DownloadOpts, Listing, ListingMode, ListingObject, RemotePath, RemoteStorage, StorageMetadata,
    TimeTravelError, TimeoutOrCancel,
};

pub struct AzureBlobStorage {
@@ -259,6 +259,7 @@ fn to_download_error(error: azure_core::Error) -> DownloadError {
    if let Some(http_err) = error.as_http_error() {
        match http_err.status() {
            StatusCode::NotFound => DownloadError::NotFound,
            StatusCode::NotModified => DownloadError::Unmodified,
            StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
            _ => DownloadError::Other(anyhow::Error::new(error)),
        }
@@ -484,11 +485,16 @@ impl RemoteStorage for AzureBlobStorage {
    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let blob_client = self.client.blob_client(self.relative_path_to_name(from));

        let builder = blob_client.get();
        let mut builder = blob_client.get();

        if let Some(ref etag) = opts.etag {
            builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()))
        }

        self.download_for_builder(builder, cancel).await
    }
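Despite the `if_match` method name, `IfMatchCondition::NotMatch` is how azure_core appears to express the `If-None-Match` request header (the enum models both `If-Match` and `If-None-Match` conditions), so the builder above performs a conditional GET: when the blob's ETag still matches, the service answers 304, which `to_download_error` above maps to `DownloadError::Unmodified`.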
@@ -5,6 +5,8 @@ pub enum DownloadError {
    BadInput(anyhow::Error),
    /// The file was not found in the remote storage.
    NotFound,
    /// The caller provided an ETag, and the file was not modified.
    Unmodified,
    /// A cancellation token aborted the download, typically during
    /// tenant detach or process shutdown.
    Cancelled,
@@ -24,6 +26,7 @@ impl std::fmt::Display for DownloadError {
            write!(f, "Failed to download a remote file due to user input: {e}")
        }
        DownloadError::NotFound => write!(f, "No file found for the remote object id given"),
        DownloadError::Unmodified => write!(f, "File was not modified"),
        DownloadError::Cancelled => write!(f, "Cancelled, shutting down"),
        DownloadError::Timeout => write!(f, "timeout"),
        DownloadError::Other(e) => write!(f, "Failed to download a remote file: {e:?}"),
@@ -38,7 +41,7 @@ impl DownloadError {
    pub fn is_permanent(&self) -> bool {
        use DownloadError::*;
        match self {
            BadInput(_) | NotFound | Cancelled => true,
            BadInput(_) | NotFound | Unmodified | Cancelled => true,
            Timeout | Other(_) => false,
        }
    }
@@ -161,6 +161,14 @@ pub struct Listing {
    pub keys: Vec<ListingObject>,
}

/// Options for downloads. The default value is a plain GET.
#[derive(Default)]
pub struct DownloadOpts {
    /// If given, returns [`DownloadError::Unmodified`] if the object still has
    /// the same ETag (using If-None-Match).
    pub etag: Option<Etag>,
}

/// Storage (potentially remote) API to manage its state.
/// This storage tries to be unaware of any layered repository context,
/// providing basic CRUD operations for storage files.
@@ -245,6 +253,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError>;

@@ -401,16 +410,18 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
        }
    }

    /// See [`RemoteStorage::download`]
    pub async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        match self {
            Self::LocalFs(s) => s.download(from, cancel).await,
            Self::AwsS3(s) => s.download(from, cancel).await,
            Self::AzureBlob(s) => s.download(from, cancel).await,
            Self::Unreliable(s) => s.download(from, cancel).await,
            Self::LocalFs(s) => s.download(from, opts, cancel).await,
            Self::AwsS3(s) => s.download(from, opts, cancel).await,
            Self::AzureBlob(s) => s.download(from, opts, cancel).await,
            Self::Unreliable(s) => s.download(from, opts, cancel).await,
        }
    }

@@ -572,7 +583,7 @@ impl GenericRemoteStorage {
    ) -> Result<Download, DownloadError> {
        match byte_range {
            Some((start, end)) => self.download_byte_range(from, start, end, cancel).await,
            None => self.download(from, cancel).await,
            None => self.download(from, &DownloadOpts::default(), cancel).await,
        }
    }
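A sketch of how a caller might use the new `DownloadOpts` parameter for cache revalidation (the function and variable names are illustrative, not from the patch; it assumes `Etag` is importable from the crate root, as the public `etag` field suggests):

```rust
use remote_storage::{DownloadError, DownloadOpts, Etag, GenericRemoteStorage, RemotePath};
use tokio_util::sync::CancellationToken;

/// Re-download `path` only if it changed since we last saw `cached_etag`.
async fn maybe_refresh(
    storage: &GenericRemoteStorage,
    path: &RemotePath,
    cached_etag: Option<Etag>,
    cancel: &CancellationToken,
) -> Result<(), DownloadError> {
    let opts = DownloadOpts { etag: cached_etag };
    match storage.download(path, &opts, cancel).await {
        // The remote object still has our ETag: keep the cached copy.
        Err(DownloadError::Unmodified) => Ok(()),
        // New or changed object: consume `download.download_stream`
        // and remember `download.etag` for the next call.
        Ok(_download) => Ok(()),
        Err(other) => Err(other),
    }
}
```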
@@ -23,8 +23,8 @@ use tokio_util::{io::ReaderStream, sync::CancellationToken};
use utils::crashsafe::path_with_suffix_extension;

use crate::{
    Download, DownloadError, Listing, ListingMode, ListingObject, RemotePath, TimeTravelError,
    TimeoutOrCancel, REMOTE_STORAGE_PREFIX_SEPARATOR,
    Download, DownloadError, DownloadOpts, Listing, ListingMode, ListingObject, RemotePath,
    TimeTravelError, TimeoutOrCancel, REMOTE_STORAGE_PREFIX_SEPARATOR,
};

use super::{RemoteStorage, StorageMetadata};
@@ -494,11 +494,17 @@ impl RemoteStorage for LocalFs {
    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        let target_path = from.with_base(&self.storage_root);

        let file_metadata = file_metadata(&target_path).await?;
        let etag = mock_etag(&file_metadata);

        if opts.etag.as_ref() == Some(&etag) {
            return Err(DownloadError::Unmodified);
        }

        let source = ReaderStream::new(
            fs::OpenOptions::new()
@@ -519,7 +525,6 @@ impl RemoteStorage for LocalFs {
        let cancel_or_timeout = crate::support::cancel_or_timeout(self.timeout, cancel.clone());
        let source = crate::support::DownloadStream::new(cancel_or_timeout, source);

        let etag = mock_etag(&file_metadata);
        Ok(Download {
            metadata,
            last_modified: file_metadata
@@ -692,7 +697,7 @@ mod fs_tests {
    ) -> anyhow::Result<String> {
        let cancel = CancellationToken::new();
        let download = storage
            .download(remote_storage_path, &cancel)
            .download(remote_storage_path, &DownloadOpts::default(), &cancel)
            .await
            .map_err(|e| anyhow::anyhow!("Download failed: {e}"))?;
        ensure!(
@@ -773,8 +778,8 @@ mod fs_tests {
            "We should upload and download the same contents"
        );

        let non_existing_path = "somewhere/else";
        match storage.download(&RemotePath::new(Utf8Path::new(non_existing_path))?, &cancel).await {
        let non_existing_path = RemotePath::new(Utf8Path::new("somewhere/else"))?;
        match storage.download(&non_existing_path, &DownloadOpts::default(), &cancel).await {
            Err(DownloadError::NotFound) => {} // Should get NotFound for non existing keys
            other => panic!("Should get a NotFound error when downloading non-existing storage files, but got: {other:?}"),
        }
@@ -1101,7 +1106,13 @@ mod fs_tests {
            storage.upload(body, len, &path, None, &cancel).await?;
        }

        let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?;
        let read = aggregate(
            storage
                .download(&path, &DownloadOpts::default(), &cancel)
                .await?
                .download_stream,
        )
        .await?;
        assert_eq!(body, read);

        let shorter = Bytes::from_static(b"shorter body");
@@ -1112,7 +1123,13 @@ mod fs_tests {
            storage.upload(body, len, &path, None, &cancel).await?;
        }

        let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?;
        let read = aggregate(
            storage
                .download(&path, &DownloadOpts::default(), &cancel)
                .await?
                .download_stream,
        )
        .await?;
        assert_eq!(shorter, read);
        Ok(())
    }
@@ -1145,7 +1162,13 @@ mod fs_tests {
            storage.upload(body, len, &path, None, &cancel).await?;
        }

        let read = aggregate(storage.download(&path, &cancel).await?.download_stream).await?;
        let read = aggregate(
            storage
                .download(&path, &DownloadOpts::default(), &cancel)
                .await?
                .download_stream,
        )
        .await?;
        assert_eq!(body, read);

        Ok(())
@@ -28,12 +28,13 @@ use aws_sdk_s3::{
    Client,
};
use aws_smithy_async::rt::sleep::TokioSleep;
use http_types::StatusCode;

use aws_smithy_types::{body::SdkBody, DateTime};
use aws_smithy_types::{byte_stream::ByteStream, date_time::ConversionError};
use bytes::Bytes;
use futures::stream::Stream;
use hyper::Body;
use hyper0::Body;
use scopeguard::ScopeGuard;
use tokio_util::sync::CancellationToken;
use utils::backoff;
@@ -44,8 +45,8 @@ use crate::{
    error::Cancelled,
    metrics::{start_counting_cancelled_wait, start_measuring_requests},
    support::PermitCarrying,
    ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, ListingObject, RemotePath,
    RemoteStorage, TimeTravelError, TimeoutOrCancel, MAX_KEYS_PER_DELETE,
    ConcurrencyLimiter, Download, DownloadError, DownloadOpts, Listing, ListingMode, ListingObject,
    RemotePath, RemoteStorage, TimeTravelError, TimeoutOrCancel, MAX_KEYS_PER_DELETE,
    REMOTE_STORAGE_PREFIX_SEPARATOR,
};

@@ -67,6 +68,7 @@ pub struct S3Bucket {
struct GetObjectRequest {
    bucket: String,
    key: String,
    etag: Option<String>,
    range: Option<String>,
}
impl S3Bucket {
@@ -248,13 +250,18 @@ impl S3Bucket {

        let started_at = start_measuring_requests(kind);

        let get_object = self
        let mut builder = self
            .client
            .get_object()
            .bucket(request.bucket)
            .key(request.key)
            .set_range(request.range)
            .send();
            .set_range(request.range);

        if let Some(etag) = request.etag {
            builder = builder.if_none_match(etag);
        }

        let get_object = builder.send();

        let get_object = tokio::select! {
            res = get_object => res,
@@ -277,6 +284,20 @@ impl S3Bucket {
                );
                return Err(DownloadError::NotFound);
            }
            Err(SdkError::ServiceError(e))
                // aws_smithy_runtime_api::http::response::StatusCode isn't
                // re-exported by any aws crates, so just check the numeric
                // status against http_types::StatusCode instead of pulling it.
                if e.raw().status().as_u16() == StatusCode::NotModified =>
            {
                // Count an unmodified file as a success.
                crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
                    kind,
                    AttemptOutcome::Ok,
                    started_at,
                );
                return Err(DownloadError::Unmodified);
            }
            Err(e) => {
                crate::metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
                    kind,
@@ -773,6 +794,7 @@ impl RemoteStorage for S3Bucket {
    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        // if prefix is not none then download file `prefix/from`
@@ -781,6 +803,7 @@ impl RemoteStorage for S3Bucket {
            GetObjectRequest {
                bucket: self.bucket_name.clone(),
                key: self.relative_path_to_s3_object(from),
                etag: opts.etag.as_ref().map(|e| e.to_string()),
                range: None,
            },
            cancel,
@@ -807,6 +830,7 @@ impl RemoteStorage for S3Bucket {
            GetObjectRequest {
                bucket: self.bucket_name.clone(),
                key: self.relative_path_to_s3_object(from),
                etag: None,
                range,
            },
            cancel,
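For reference, this is standard HTTP conditional-GET behavior: the client sends `If-None-Match: <etag>`, and the server replies `304 Not Modified` with no body when the object's current ETag still matches. The code above surfaces that as `DownloadError::Unmodified` while still recording the request as a success (`AttemptOutcome::Ok`) in the bucket metrics, since an up-to-date cache is not a failure.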
@@ -12,8 +12,8 @@ use std::{collections::hash_map::Entry, sync::Arc};
use tokio_util::sync::CancellationToken;

use crate::{
    Download, DownloadError, GenericRemoteStorage, Listing, ListingMode, RemotePath, RemoteStorage,
    StorageMetadata, TimeTravelError,
    Download, DownloadError, DownloadOpts, GenericRemoteStorage, Listing, ListingMode, RemotePath,
    RemoteStorage, StorageMetadata, TimeTravelError,
};

pub struct UnreliableWrapper {
@@ -167,11 +167,12 @@ impl RemoteStorage for UnreliableWrapper {
    async fn download(
        &self,
        from: &RemotePath,
        opts: &DownloadOpts,
        cancel: &CancellationToken,
    ) -> Result<Download, DownloadError> {
        self.attempt(RemoteOp::Download(from.clone()))
            .map_err(DownloadError::Other)?;
        self.inner.download(from, cancel).await
        self.inner.download(from, opts, cancel).await
    }

    async fn download_byte_range(
@@ -1,8 +1,7 @@
use anyhow::Context;
use camino::Utf8Path;
use futures::StreamExt;
use remote_storage::ListingMode;
use remote_storage::RemotePath;
use remote_storage::{DownloadError, DownloadOpts, ListingMode, ListingObject, RemotePath};
use std::sync::Arc;
use std::{collections::HashSet, num::NonZeroU32};
use test_context::test_context;
@@ -284,7 +283,10 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
    ctx.client.upload(data, len, &path, None, &cancel).await?;

    // Normal download request
    let dl = ctx.client.download(&path, &cancel).await?;
    let dl = ctx
        .client
        .download(&path, &DownloadOpts::default(), &cancel)
        .await?;
    let buf = download_to_vec(dl).await?;
    assert_eq!(&buf, &orig);

@@ -337,6 +339,54 @@ async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<
    Ok(())
}

/// Tests that conditional downloads work properly, by returning
/// DownloadError::Unmodified when the object ETag matches the given ETag.
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn download_conditional(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
    let MaybeEnabledStorage::Enabled(ctx) = ctx else {
        return Ok(());
    };
    let cancel = CancellationToken::new();

    // Create a file.
    let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))?;
    let data = bytes::Bytes::from_static("foo".as_bytes());
    let (stream, len) = wrap_stream(data);
    ctx.client.upload(stream, len, &path, None, &cancel).await?;

    // Download it to obtain its etag.
    let mut opts = DownloadOpts::default();
    let download = ctx.client.download(&path, &opts, &cancel).await?;

    // Download with the etag yields DownloadError::Unmodified.
    opts.etag = Some(download.etag);
    let result = ctx.client.download(&path, &opts, &cancel).await;
    assert!(
        matches!(result, Err(DownloadError::Unmodified)),
        "expected DownloadError::Unmodified, got {result:?}"
    );

    // Replace the file contents.
    let data = bytes::Bytes::from_static("bar".as_bytes());
    let (stream, len) = wrap_stream(data);
    ctx.client.upload(stream, len, &path, None, &cancel).await?;

    // A download with the old etag should yield the new file.
    let download = ctx.client.download(&path, &opts, &cancel).await?;
    assert_ne!(download.etag, opts.etag.unwrap(), "ETag did not change");

    // A download with the new etag should yield Unmodified again.
    opts.etag = Some(download.etag);
    let result = ctx.client.download(&path, &opts, &cancel).await;
    assert!(
        matches!(result, Err(DownloadError::Unmodified)),
        "expected DownloadError::Unmodified, got {result:?}"
    );

    Ok(())
}

#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
@@ -364,7 +414,10 @@ async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
    // Normal download request
    ctx.client.copy_object(&path, &path_dest, &cancel).await?;

    let dl = ctx.client.download(&path_dest, &cancel).await?;
    let dl = ctx
        .client
        .download(&path_dest, &DownloadOpts::default(), &cancel)
        .await?;
    let buf = download_to_vec(dl).await?;
    assert_eq!(&buf, &orig);

@@ -376,3 +429,56 @@ async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {

    Ok(())
}

/// Tests that head_object works properly.
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn head_object(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
    let MaybeEnabledStorage::Enabled(ctx) = ctx else {
        return Ok(());
    };
    let cancel = CancellationToken::new();

    let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))?;

    // Errors on missing file.
    let result = ctx.client.head_object(&path, &cancel).await;
    assert!(
        matches!(result, Err(DownloadError::NotFound)),
        "expected NotFound, got {result:?}"
    );

    // Create the file.
    let data = bytes::Bytes::from_static("foo".as_bytes());
    let (stream, len) = wrap_stream(data);
    ctx.client.upload(stream, len, &path, None, &cancel).await?;

    // Fetch the head metadata.
    let object = ctx.client.head_object(&path, &cancel).await?;
    assert_eq!(
        object,
        ListingObject {
            key: path.clone(),
            last_modified: object.last_modified, // ignore
            size: 3
        }
    );

    // Wait for a couple of seconds, and then update the file to check the last
    // modified timestamp.
    tokio::time::sleep(std::time::Duration::from_secs(2)).await;

    let data = bytes::Bytes::from_static("bar".as_bytes());
    let (stream, len) = wrap_stream(data);
    ctx.client.upload(stream, len, &path, None, &cancel).await?;
    let new = ctx.client.head_object(&path, &cancel).await?;

    assert!(
        !new.last_modified
            .duration_since(object.last_modified)?
            .is_zero(),
        "last_modified did not advance"
    );

    Ok(())
}
@@ -12,8 +12,8 @@ use anyhow::Context;
use camino::Utf8Path;
use futures_util::StreamExt;
use remote_storage::{
    DownloadError, GenericRemoteStorage, ListingMode, RemotePath, RemoteStorageConfig,
    RemoteStorageKind, S3Config,
    DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath,
    RemoteStorageConfig, RemoteStorageKind, S3Config,
};
use test_context::test_context;
use test_context::AsyncTestContext;
@@ -121,7 +121,8 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:

    // A little check to ensure that our clock is not too far off from the S3 clock
    {
        let dl = retry(|| ctx.client.download(&path2, &cancel)).await?;
        let opts = DownloadOpts::default();
        let dl = retry(|| ctx.client.download(&path2, &opts, &cancel)).await?;
        let last_modified = dl.last_modified;
        let half_wt = WAIT_TIME.mul_f32(0.5);
        let t0_hwt = t0 + half_wt;
@@ -159,7 +160,12 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
    let t2_files_recovered = list_files(&ctx.client, &cancel).await?;
    println!("after recovery to t2: {t2_files_recovered:?}");
    assert_eq!(t2_files, t2_files_recovered);
    let path2_recovered_t2 = download_to_vec(ctx.client.download(&path2, &cancel).await?).await?;
    let path2_recovered_t2 = download_to_vec(
        ctx.client
            .download(&path2, &DownloadOpts::default(), &cancel)
            .await?,
    )
    .await?;
    assert_eq!(path2_recovered_t2, new_data.as_bytes());

    // after recovery to t1: path1 is back, path2 has the old content
@@ -170,7 +176,12 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow:
    let t1_files_recovered = list_files(&ctx.client, &cancel).await?;
    println!("after recovery to t1: {t1_files_recovered:?}");
    assert_eq!(t1_files, t1_files_recovered);
    let path2_recovered_t1 = download_to_vec(ctx.client.download(&path2, &cancel).await?).await?;
    let path2_recovered_t1 = download_to_vec(
        ctx.client
            .download(&path2, &DownloadOpts::default(), &cancel)
            .await?,
    )
    .await?;
    assert_eq!(path2_recovered_t1, old_data.as_bytes());

    // after recovery to t0: everything is gone except for path1
@@ -416,7 +427,7 @@ async fn download_is_timeouted(ctx: &mut MaybeEnabledStorage) {
    let started_at = std::time::Instant::now();
    let mut stream = ctx
        .client
        .download(&path, &cancel)
        .download(&path, &DownloadOpts::default(), &cancel)
        .await
        .expect("download succeeds")
        .download_stream;
@@ -491,7 +502,7 @@ async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) {
    {
        let stream = ctx
            .client
            .download(&path, &cancel)
            .download(&path, &DownloadOpts::default(), &cancel)
            .await
            .expect("download succeeds")
            .download_stream;
@@ -5,7 +5,7 @@ edition.workspace = true
license.workspace = true

[dependencies]
hyper.workspace = true
hyper0.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
@@ -1,7 +1,7 @@
//! Tracing wrapper for Hyper HTTP server

use hyper::HeaderMap;
use hyper::{Body, Request, Response};
use hyper0::HeaderMap;
use hyper0::{Body, Request, Response};
use std::future::Future;
use tracing::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;
@@ -22,7 +22,7 @@ chrono.workspace = true
git-version.workspace = true
hex = { workspace = true, features = ["serde"] }
humantime.workspace = true
hyper = { workspace = true, features = ["full"] }
hyper0 = { workspace = true, features = ["full"] }
fail.workspace = true
futures = { workspace = true}
jsonwebtoken.workspace = true
@@ -2,6 +2,8 @@
//! between other crates in this repository.
#![deny(clippy::undocumented_unsafe_blocks)]

extern crate hyper0 as hyper;

pub mod backoff;

/// `Lsn` type implements common tasks on Log Sequence Numbers
@@ -7,11 +7,13 @@ use axum::{
     extract::{ws::WebSocket, State, WebSocketUpgrade},
     response::Response,
 };
-use axum::{routing::get, Router, Server};
+use axum::{routing::get, Router};
 use clap::Parser;
 use futures::Future;
+use std::net::SocketAddr;
 use std::{fmt::Debug, time::Duration};
 use sysinfo::{RefreshKind, System, SystemExt};
+use tokio::net::TcpListener;
 use tokio::{sync::broadcast, task::JoinHandle};
 use tokio_util::sync::CancellationToken;
 use tracing::{error, info};

@@ -132,14 +134,14 @@ pub async fn start(args: &'static Args, token: CancellationToken) -> anyhow::Res
         args,
     });

-    let addr = args.addr();
-    let bound = Server::try_bind(&addr.parse().expect("parsing address should not fail"))
+    let addr_str = args.addr();
+    let addr: SocketAddr = addr_str.parse().expect("parsing address should not fail");
+
+    let listener = TcpListener::bind(&addr)
         .await
         .with_context(|| format!("failed to bind to {addr}"))?;

-    info!(addr, "server bound");
-
-    bound
-        .serve(app.into_make_service())
+    info!(addr_str, "server bound");
+    axum::serve(listener, app.into_make_service())
         .await
         .context("server exited")?;
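This hunk is the axum 0.6 to 0.7 migration: `axum::Server` no longer exists, so the socket is bound explicitly with tokio and handed to `axum::serve`. A minimal standalone sketch of the same pattern (the route and address are placeholders, not taken from the tree):

// --- illustrative sketch, not part of the diff ---
use std::net::SocketAddr;

use axum::{routing::get, Router};
use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let app = Router::new().route("/healthz", get(|| async { "ok" }));

    // axum 0.7: bind the listener ourselves, then pass it to `axum::serve`.
    let addr: SocketAddr = "127.0.0.1:3000".parse()?;
    let listener = TcpListener::bind(addr).await?;
    axum::serve(listener, app.into_make_service()).await?;
    Ok(())
}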
@@ -30,7 +30,7 @@ futures.workspace = true
 hex.workspace = true
 humantime.workspace = true
 humantime-serde.workspace = true
-hyper.workspace = true
+hyper0.workspace = true
 itertools.workspace = true
 md5.workspace = true
 nix.workspace = true

@@ -575,7 +575,7 @@ fn start_pageserver(
         .build()
         .map_err(|err| anyhow!(err))?;
     let service = utils::http::RouterService::new(router).unwrap();
-    let server = hyper::Server::from_tcp(http_listener)?
+    let server = hyper0::Server::from_tcp(http_listener)?
         .serve(service)
         .with_graceful_shutdown({
             let cancel = cancel.clone();

@@ -13,6 +13,8 @@ pub mod http;
 pub mod import_datadir;
 pub mod l0_flush;

+extern crate hyper0 as hyper;
+
 use futures::{stream::FuturesUnordered, StreamExt};
 pub use pageserver_api::keyspace;
 use tokio_util::sync::CancellationToken;

@@ -27,7 +27,7 @@ use crate::tenant::Generation;
 use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
 use crate::virtual_file::{on_fatal_io_error, MaybeFatalIo, VirtualFile};
 use crate::TEMP_FILE_SUFFIX;
-use remote_storage::{DownloadError, GenericRemoteStorage, ListingMode, RemotePath};
+use remote_storage::{DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath};
 use utils::crashsafe::path_with_suffix_extension;
 use utils::id::{TenantId, TimelineId};
 use utils::pausable_failpoint;

@@ -153,7 +153,9 @@ async fn download_object<'a>(
         .with_context(|| format!("create a destination file for layer '{dst_path}'"))
         .map_err(DownloadError::Other)?;

-    let download = storage.download(src_path, cancel).await?;
+    let download = storage
+        .download(src_path, &DownloadOpts::default(), cancel)
+        .await?;

     pausable_failpoint!("before-downloading-layer-stream-pausable");

@@ -204,7 +206,9 @@ async fn download_object<'a>(
         .with_context(|| format!("create a destination file for layer '{dst_path}'"))
         .map_err(DownloadError::Other)?;

-    let mut download = storage.download(src_path, cancel).await?;
+    let mut download = storage
+        .download(src_path, &DownloadOpts::default(), cancel)
+        .await?;

     pausable_failpoint!("before-downloading-layer-stream-pausable");

@@ -344,7 +348,9 @@ async fn do_download_index_part(

     let index_part_bytes = download_retry_forever(
         || async {
-            let download = storage.download(&remote_path, cancel).await?;
+            let download = storage
+                .download(&remote_path, &DownloadOpts::default(), cancel)
+                .await?;

             let mut bytes = Vec::new();

@@ -526,10 +532,15 @@ pub(crate) async fn download_initdb_tar_zst(
         .with_context(|| format!("tempfile creation {temp_path}"))
         .map_err(DownloadError::Other)?;

-    let download = match storage.download(&remote_path, cancel).await {
+    let download = match storage
+        .download(&remote_path, &DownloadOpts::default(), cancel)
+        .await
+    {
         Ok(dl) => dl,
         Err(DownloadError::NotFound) => {
-            storage.download(&remote_preserved_path, cancel).await?
+            storage
+                .download(&remote_preserved_path, &DownloadOpts::default(), cancel)
+                .await?
         }
         Err(other) => Err(other)?,
     };

@@ -49,7 +49,7 @@ use futures::Future;
 use metrics::UIntGauge;
 use pageserver_api::models::SecondaryProgress;
 use pageserver_api::shard::TenantShardId;
-use remote_storage::{DownloadError, Etag, GenericRemoteStorage};
+use remote_storage::{DownloadError, DownloadOpts, Etag, GenericRemoteStorage};

 use tokio_util::sync::CancellationToken;
 use tracing::{info_span, instrument, warn, Instrument};

@@ -944,36 +944,34 @@ impl<'a> TenantDownloader<'a> {
     ) -> Result<HeatMapDownload, UpdateError> {
         debug_assert_current_span_has_tenant_id();
         let tenant_shard_id = self.secondary_state.get_tenant_shard_id();
-        // TODO: pull up etag check into the request, to do a conditional GET rather than
-        // issuing a GET and then maybe ignoring the response body
-        // (https://github.com/neondatabase/neon/issues/6199)
         tracing::debug!("Downloading heatmap for secondary tenant",);

         let heatmap_path = remote_heatmap_path(tenant_shard_id);
         let cancel = &self.secondary_state.cancel;
+        let opts = DownloadOpts {
+            etag: prev_etag.cloned(),
+        };

         backoff::retry(
             || async {
-                let download = self
+                let download = match self
                     .remote_storage
-                    .download(&heatmap_path, cancel)
+                    .download(&heatmap_path, &opts, cancel)
                     .await
-                    .map_err(UpdateError::from)?;
+                {
+                    Ok(download) => download,
+                    Err(DownloadError::Unmodified) => return Ok(HeatMapDownload::Unmodified),
+                    Err(err) => return Err(err.into()),
+                };

-                SECONDARY_MODE.download_heatmap.inc();
-
-                if Some(&download.etag) == prev_etag {
-                    Ok(HeatMapDownload::Unmodified)
-                } else {
-                    let mut heatmap_bytes = Vec::new();
-                    let mut body = tokio_util::io::StreamReader::new(download.download_stream);
-                    let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?;
-                    Ok(HeatMapDownload::Modified(HeatMapModified {
-                        etag: download.etag,
-                        last_modified: download.last_modified,
-                        bytes: heatmap_bytes,
-                    }))
-                }
+                let mut heatmap_bytes = Vec::new();
+                let mut body = tokio_util::io::StreamReader::new(download.download_stream);
+                let _size = tokio::io::copy_buf(&mut body, &mut heatmap_bytes).await?;
+                Ok(HeatMapDownload::Modified(HeatMapModified {
+                    etag: download.etag,
+                    last_modified: download.last_modified,
+                    bytes: heatmap_bytes,
+                }))
             },
             |e| matches!(e, UpdateError::NoData | UpdateError::Cancelled),
             FAILED_DOWNLOAD_WARN_THRESHOLD,

@@ -984,6 +982,7 @@ impl<'a> TenantDownloader<'a> {
         .await
         .ok_or_else(|| UpdateError::Cancelled)
         .and_then(|x| x)
+        .inspect(|_| SECONDARY_MODE.download_heatmap.inc())
     }

     /// Download heatmap layers that are not present on local disk, or update their
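The heatmap hunk is the interesting consumer of the new API: by passing the previously seen Etag in `DownloadOpts`, the GET becomes conditional, and an unchanged object surfaces as `DownloadError::Unmodified` instead of a body that gets downloaded and discarded. A condensed sketch of that flow, with `storage`, `path`, `prev_etag`, and `cancel` standing in for the real fields:

// --- illustrative sketch, not part of the diff ---
use remote_storage::{DownloadError, DownloadOpts, Etag, GenericRemoteStorage, RemotePath};
use tokio_util::sync::CancellationToken;

enum Fetched {
    Unmodified,
    Modified(Vec<u8>),
}

// Conditional GET: send the last-seen etag; `Unmodified` means the backend
// reported that the object has not changed since that etag.
async fn fetch_if_changed(
    storage: &GenericRemoteStorage,
    path: &RemotePath,
    prev_etag: Option<&Etag>,
    cancel: &CancellationToken,
) -> anyhow::Result<Fetched> {
    let opts = DownloadOpts {
        etag: prev_etag.cloned(),
    };
    let download = match storage.download(path, &opts, cancel).await {
        Ok(download) => download,
        Err(DownloadError::Unmodified) => return Ok(Fetched::Unmodified),
        Err(err) => return Err(err.into()),
    };
    let mut body = tokio_util::io::StreamReader::new(download.download_stream);
    let mut bytes = Vec::new();
    tokio::io::copy_buf(&mut body, &mut bytes).await?;
    Ok(Fetched::Modified(bytes))
}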
@@ -53,6 +53,7 @@ use camino::{Utf8Path, Utf8PathBuf};
 use futures::StreamExt;
 use itertools::Itertools;
 use pageserver_api::config::MaxVectoredReadBytes;
+use pageserver_api::key::DBDIR_KEY;
 use pageserver_api::keyspace::KeySpace;
 use pageserver_api::models::ImageCompressionAlgorithm;
 use pageserver_api::shard::TenantShardId;

@@ -963,14 +964,25 @@ impl DeltaLayerInner {
             .blobs_at
             .as_slice()
             .iter()
-            .map(|(_, blob_meta)| format!("{}@{}", blob_meta.key, blob_meta.lsn))
+            .filter_map(|(_, blob_meta)| {
+                if blob_meta.key.is_rel_dir_key() || blob_meta.key == DBDIR_KEY {
+                    // The size of values for these keys is unbounded and can
+                    // grow very large in pathological cases.
+                    None
+                } else {
+                    Some(format!("{}@{}", blob_meta.key, blob_meta.lsn))
+                }
+            })
             .join(", ");
-        tracing::warn!(
-            "Oversized vectored read ({} > {}) for keys {}",
-            largest_read_size,
-            read_size_soft_max,
-            offenders
-        );
+
+        if !offenders.is_empty() {
+            tracing::warn!(
+                "Oversized vectored read ({} > {}) for keys {}",
+                largest_read_size,
+                read_size_soft_max,
+                offenders
+            );
+        }
     }

     largest_read_size

@@ -49,6 +49,7 @@ use camino::{Utf8Path, Utf8PathBuf};
 use hex;
 use itertools::Itertools;
 use pageserver_api::config::MaxVectoredReadBytes;
+use pageserver_api::key::DBDIR_KEY;
 use pageserver_api::keyspace::KeySpace;
 use pageserver_api::shard::{ShardIdentity, TenantShardId};
 use rand::{distributions::Alphanumeric, Rng};

@@ -587,14 +588,25 @@ impl ImageLayerInner {
             .blobs_at
             .as_slice()
             .iter()
-            .map(|(_, blob_meta)| format!("{}@{}", blob_meta.key, blob_meta.lsn))
+            .filter_map(|(_, blob_meta)| {
+                if blob_meta.key.is_rel_dir_key() || blob_meta.key == DBDIR_KEY {
+                    // The size of values for these keys is unbounded and can
+                    // grow very large in pathological cases.
+                    None
+                } else {
+                    Some(format!("{}@{}", blob_meta.key, blob_meta.lsn))
+                }
+            })
            .join(", ");
-        tracing::warn!(
-            "Oversized vectored read ({} > {}) for keys {}",
-            buf_size,
-            max_vectored_read_bytes,
-            offenders
-        );
+
+        if !offenders.is_empty() {
+            tracing::warn!(
+                "Oversized vectored read ({} > {}) for keys {}",
+                buf_size,
+                max_vectored_read_bytes,
+                offenders
+            );
+        }
     }

     let buf = BytesMut::with_capacity(buf_size);
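Both layer hunks apply the same pattern: exempt known-unbounded keys from the oversized-read diagnostic, then suppress the warning entirely when every offending key was exempt. A freestanding sketch of the pattern with hypothetical types (nothing below is from the tree):

// --- illustrative sketch, not part of the diff ---
use itertools::Itertools;

// Hypothetical key type: some keys legitimately carry unbounded values and
// should not trip the "oversized read" warning.
#[derive(Clone, Copy)]
struct Key(u32);

impl Key {
    fn is_exempt(self) -> bool {
        self.0 == 0 // stand-in for is_rel_dir_key() / DBDIR_KEY
    }
}

fn warn_on_oversized(keys: &[Key], read_size: usize, soft_max: usize) {
    if read_size <= soft_max {
        return;
    }
    let offenders = keys
        .iter()
        .filter_map(|k| (!k.is_exempt()).then(|| format!("key#{}", k.0)))
        .join(", ");
    // Only warn when at least one non-exempt key contributed to the read.
    if !offenders.is_empty() {
        tracing::warn!("Oversized vectored read ({read_size} > {soft_max}) for keys {offenders}");
    }
}

fn main() {
    warn_on_oversized(&[Key(0), Key(7)], 4096, 1024);
}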
@@ -803,15 +803,19 @@ prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
 						  bool is_prefetch)
 {
 	uint64		min_ring_index;
-	PrefetchRequest req;
+	PrefetchRequest hashkey;
 #if USE_ASSERT_CHECKING
 	bool		any_hits = false;
 #endif
 	/* We will never read further ahead than our buffer can store. */
 	nblocks = Max(1, Min(nblocks, readahead_buffer_size));

-	/* use an intermediate PrefetchRequest struct to ensure correct alignment */
-	req.buftag = tag;
+	/*
+	 * Use an intermediate PrefetchRequest struct as the hash key to ensure
+	 * correct alignment and that the padding bytes are cleared.
+	 */
+	memset(&hashkey.buftag, 0, sizeof(BufferTag));
+	hashkey.buftag = tag;

 Retry:
 	min_ring_index = UINT64_MAX;

@@ -837,8 +841,8 @@ Retry:
 		slot = NULL;
 		entry = NULL;

-		req.buftag.blockNum = tag.blockNum + i;
-		entry = prfh_lookup(MyPState->prf_hash, (PrefetchRequest *) &req);
+		hashkey.buftag.blockNum = tag.blockNum + i;
+		entry = prfh_lookup(MyPState->prf_hash, &hashkey);

 		if (entry != NULL)
 		{

@@ -849,7 +853,7 @@ Retry:
 			Assert(slot->status != PRFS_UNUSED);
 			Assert(MyPState->ring_last <= ring_index &&
 				   ring_index < MyPState->ring_unused);
-			Assert(BUFFERTAGS_EQUAL(slot->buftag, req.buftag));
+			Assert(BUFFERTAGS_EQUAL(slot->buftag, hashkey.buftag));

 			/*
 			 * If the caller specified a request LSN to use, only accept

@@ -886,12 +890,19 @@ Retry:
 				{
 					min_ring_index = Min(min_ring_index, ring_index);
 					/* The buffered request is good enough, return that index */
-					pgBufferUsage.prefetch.duplicates++;
+					if (is_prefetch)
+						pgBufferUsage.prefetch.duplicates++;
+					else
+						pgBufferUsage.prefetch.hits++;
 					continue;
 				}
 			}
 		}
-
+		else if (!is_prefetch)
+		{
+			pgBufferUsage.prefetch.misses += 1;
+			MyNeonCounters->getpage_prefetch_misses_total++;
+		}
 		/*
 		 * We can only leave the block above by finding that there's
 		 * no entry that can satisfy this request, either because there

@@ -974,7 +985,7 @@ Retry:
 		 * We must update the slot data before insertion, because the hash
 		 * function reads the buffer tag from the slot.
 		 */
-		slot->buftag = req.buftag;
+		slot->buftag = hashkey.buftag;
 		slot->shard_no = get_shard_number(&tag);
 		slot->my_ring_index = ring_index;

@@ -2742,14 +2753,19 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block
 	uint64		ring_index;
 	PrfHashEntry *entry;
 	PrefetchRequest *slot;
-	BufferTag	buftag = {0};
+	PrefetchRequest hashkey;

 	Assert(PointerIsValid(request_lsns));
 	Assert(nblocks >= 1);

-	CopyNRelFileInfoToBufTag(buftag, rinfo);
-	buftag.forkNum = forkNum;
-	buftag.blockNum = base_blockno;
+	/*
+	 * Use an intermediate PrefetchRequest struct as the hash key to ensure
+	 * correct alignment and that the padding bytes are cleared.
+	 */
+	memset(&hashkey.buftag, 0, sizeof(BufferTag));
+	CopyNRelFileInfoToBufTag(hashkey.buftag, rinfo);
+	hashkey.buftag.forkNum = forkNum;
+	hashkey.buftag.blockNum = base_blockno;

 	/*
 	 * The redo process does not lock pages that it needs to replay but are

@@ -2767,7 +2783,7 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block
 	 * weren't for the behaviour of the LwLsn cache that uses the highest
 	 * value of the LwLsn cache when the entry is not found.
 	 */
-	prefetch_register_bufferv(buftag, request_lsns, nblocks, mask, false);
+	prefetch_register_bufferv(hashkey.buftag, request_lsns, nblocks, mask, false);

 	for (int i = 0; i < nblocks; i++)
 	{

@@ -2788,8 +2804,8 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block
 		 * Try to find prefetched page in the list of received pages.
 		 */
 Retry:
-		buftag.blockNum = blockno;
-		entry = prfh_lookup(MyPState->prf_hash, (PrefetchRequest *) &buftag);
+		hashkey.buftag.blockNum = blockno;
+		entry = prfh_lookup(MyPState->prf_hash, &hashkey);

 		if (entry != NULL)
 		{

@@ -2797,7 +2813,6 @@ Retry:
 			if (neon_prefetch_response_usable(reqlsns, slot))
 			{
 				ring_index = slot->my_ring_index;
-				pgBufferUsage.prefetch.hits += 1;
 			}
 			else
 			{

@@ -2827,10 +2842,7 @@ Retry:
 		{
 			if (entry == NULL)
 			{
-				pgBufferUsage.prefetch.misses += 1;
-				MyNeonCounters->getpage_prefetch_misses_total++;
-
-				ring_index = prefetch_register_bufferv(buftag, reqlsns, 1, NULL, false);
+				ring_index = prefetch_register_bufferv(hashkey.buftag, reqlsns, 1, NULL, false);
 				Assert(ring_index != UINT64_MAX);
 				slot = GetPrfSlot(ring_index);
 			}

@@ -2855,8 +2867,8 @@ Retry:
 		} while (!prefetch_wait_for(ring_index));

 		Assert(slot->status == PRFS_RECEIVED);
-		Assert(memcmp(&buftag, &slot->buftag, sizeof(BufferTag)) == 0);
-		Assert(buftag.blockNum == base_blockno + i);
+		Assert(memcmp(&hashkey.buftag, &slot->buftag, sizeof(BufferTag)) == 0);
+		Assert(hashkey.buftag.blockNum == base_blockno + i);

 		resp = slot->response;
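The C hunks above switch the prefetch hash lookups from a bare `BufferTag` to a zero-initialized `PrefetchRequest` key, because the hash and equality functions read raw bytes and uncleared struct padding would make identical logical keys compare unequal. The same hazard exists in any language whenever a padded struct is hashed or compared bytewise; a small Rust illustration of the fix (the types here are hypothetical stand-ins, not from the tree):

// --- illustrative sketch, not part of the diff ---
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Stand-in for the C hash key: the lookup hashes raw bytes, so every byte of
// the key image must have a defined value.
#[derive(Clone, Copy)]
struct BufferTag {
    rel: u8, // followed by 3 bytes of padding in a C layout
    block: u32,
}

// Build the exact byte image the C code would hash: start from an all-zero
// buffer (the `memset` in the hunk) and write each field at its offset.
fn key_bytes(tag: BufferTag) -> [u8; 8] {
    let mut bytes = [0u8; 8];
    bytes[0] = tag.rel;
    bytes[4..8].copy_from_slice(&tag.block.to_ne_bytes());
    bytes
}

fn key_hash(tag: BufferTag) -> u64 {
    let mut h = DefaultHasher::new();
    key_bytes(tag).hash(&mut h);
    h.finish()
}

fn main() {
    let a = BufferTag { rel: 7, block: 42 };
    let b = BufferTag { rel: 7, block: 42 };
    // Equal keys hash equally because the padding bytes are always zero.
    assert_eq!(key_hash(a), key_hash(b));
}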
@@ -3059,6 +3071,9 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
 	lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers,
 								  nblocks, read);

+	if (lfc_result > 0)
+		MyNeonCounters->file_cache_hits_total += lfc_result;
+
 	/* Read all blocks from LFC, so we're done */
 	if (lfc_result == nblocks)
 		return;

@@ -191,13 +191,14 @@ NeonOnDemandXLogReaderRoutines(XLogReaderRoutine *xlr)

 	if (!wal_reader)
 	{
-		XLogRecPtr	epochStartLsn = pg_atomic_read_u64(&GetWalpropShmemState()->propEpochStartLsn);
+		XLogRecPtr	basebackupLsn = GetRedoStartLsn();

-		if (epochStartLsn == 0)
+		/* should never happen */
+		if (basebackupLsn == 0)
 		{
-			elog(ERROR, "Unable to start walsender when propEpochStartLsn is 0!");
+			elog(ERROR, "unable to start walsender when basebackupLsn is 0");
 		}
-		wal_reader = NeonWALReaderAllocate(wal_segment_size, epochStartLsn, "[walsender] ");
+		wal_reader = NeonWALReaderAllocate(wal_segment_size, basebackupLsn, "[walsender] ");
 	}
 	xlr->page_read = NeonWALPageRead;
 	xlr->segment_open = NeonWALReadSegmentOpen;

@@ -38,7 +38,7 @@ hostname.workspace = true
 http.workspace = true
 humantime.workspace = true
 humantime-serde.workspace = true
-hyper.workspace = true
+hyper0.workspace = true
 hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
 hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
 http-body-util = { version = "0.1" }
@@ -3,8 +3,8 @@ use crate::{
     auth::{self, backend::ComputeCredentialKeys, AuthFlow},
     compute,
     config::AuthenticationConfig,
-    console::AuthSecret,
     context::RequestMonitoring,
+    control_plane::AuthSecret,
     sasl,
     stream::{PqStream, Stream},
 };

@@ -1,18 +1,24 @@
 use crate::{
-    auth, compute,
+    auth,
+    cache::Cached,
+    compute,
     config::AuthenticationConfig,
-    console::{self, provider::NodeInfo},
     context::RequestMonitoring,
+    control_plane::{self, provider::NodeInfo, CachedNodeInfo},
     error::{ReportableError, UserFacingError},
+    proxy::connect_compute::ComputeConnectBackend,
     stream::PqStream,
     waiters,
 };
+use async_trait::async_trait;
 use pq_proto::BeMessage as Be;
 use thiserror::Error;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_postgres::config::SslMode;
 use tracing::{info, info_span};

 use super::ComputeCredentialKeys;

 #[derive(Debug, Error)]
 pub(crate) enum WebAuthError {
     #[error(transparent)]

@@ -25,6 +31,11 @@ pub(crate) enum WebAuthError {
     Io(#[from] std::io::Error),
 }

+#[derive(Debug)]
+pub struct ConsoleRedirectBackend {
+    console_uri: reqwest::Url,
+}
+
 impl UserFacingError for WebAuthError {
     fn to_string_client(&self) -> String {
         "Internal error".to_string()

@@ -57,7 +68,40 @@ pub(crate) fn new_psql_session_id() -> String {
     hex::encode(rand::random::<[u8; 8]>())
 }

-pub(super) async fn authenticate(
+impl ConsoleRedirectBackend {
+    pub fn new(console_uri: reqwest::Url) -> Self {
+        Self { console_uri }
+    }
+
+    pub(crate) async fn authenticate(
+        &self,
+        ctx: &RequestMonitoring,
+        auth_config: &'static AuthenticationConfig,
+        client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
+    ) -> auth::Result<ConsoleRedirectNodeInfo> {
+        authenticate(ctx, auth_config, &self.console_uri, client)
+            .await
+            .map(ConsoleRedirectNodeInfo)
+    }
+}
+
+pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
+
+#[async_trait]
+impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
+    async fn wake_compute(
+        &self,
+        _ctx: &RequestMonitoring,
+    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
+        Ok(Cached::new_uncached(self.0.clone()))
+    }
+
+    fn get_keys(&self) -> &ComputeCredentialKeys {
+        &ComputeCredentialKeys::None
+    }
+}
+
+async fn authenticate(
     ctx: &RequestMonitoring,
     auth_config: &'static AuthenticationConfig,
     link_uri: &reqwest::Url,

@@ -70,7 +114,7 @@ pub(super) async fn authenticate(
     let (psql_session_id, waiter) = loop {
         let psql_session_id = new_psql_session_id();

-        match console::mgmt::get_waiter(&psql_session_id) {
+        match control_plane::mgmt::get_waiter(&psql_session_id) {
             Ok(waiter) => break (psql_session_id, waiter),
             Err(_e) => continue,
         }

@@ -2,8 +2,8 @@ use super::{ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint};
 use crate::{
     auth::{self, AuthFlow},
     config::AuthenticationConfig,
-    console::AuthSecret,
     context::RequestMonitoring,
+    control_plane::AuthSecret,
     intern::EndpointIdInt,
     sasl,
     stream::{self, Stream},

@@ -5,11 +5,11 @@ use arc_swap::ArcSwapOption;

 use crate::{
     compute::ConnCfg,
-    console::{
+    context::RequestMonitoring,
+    control_plane::{
         messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo},
         NodeInfo,
     },
-    context::RequestMonitoring,
     intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
     EndpointId,
 };

@@ -1,27 +1,27 @@
 mod classic;
+mod console_redirect;
 mod hacks;
 pub mod jwt;
 pub mod local;
-mod web;

 use std::net::IpAddr;
 use std::sync::Arc;
 use std::time::Duration;

+pub use console_redirect::ConsoleRedirectBackend;
+pub(crate) use console_redirect::WebAuthError;
 use ipnet::{Ipv4Net, Ipv6Net};
 use local::LocalBackend;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_postgres::config::AuthKeys;
 use tracing::{info, warn};
-pub(crate) use web::WebAuthError;

 use crate::auth::credentials::check_peer_addr_is_in_list;
 use crate::auth::{validate_password_and_exchange, AuthError};
 use crate::cache::Cached;
-use crate::console::errors::GetAuthInfoError;
-use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
-use crate::console::{AuthSecret, NodeInfo};
 use crate::context::RequestMonitoring;
+use crate::control_plane::provider::ControlPlaneBackend;
+use crate::control_plane::AuthSecret;
 use crate::intern::EndpointIdInt;
 use crate::metrics::Metrics;
 use crate::proxy::connect_compute::ComputeConnectBackend;

@@ -31,55 +31,29 @@ use crate::stream::Stream;
 use crate::{
     auth::{self, ComputeUserInfoMaybeEndpoint},
     config::AuthenticationConfig,
-    console::{
-        self,
-        provider::{CachedAllowedIps, CachedNodeInfo},
-        Api,
-    },
-    stream, url,
+    control_plane::{self, provider::CachedNodeInfo, Api},
+    stream,
 };
 use crate::{scram, EndpointCacheKey, EndpointId, RoleName};

-/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
-pub enum MaybeOwned<'a, T> {
-    Owned(T),
-    Borrowed(&'a T),
-}
-
-impl<T> std::ops::Deref for MaybeOwned<'_, T> {
-    type Target = T;
-
-    fn deref(&self) -> &Self::Target {
-        match self {
-            MaybeOwned::Owned(t) => t,
-            MaybeOwned::Borrowed(t) => t,
-        }
-    }
-}
-
-/// This type serves two purposes:
-///
-/// * When `T` is `()`, it's just a regular auth backend selector
-///   which we use in [`crate::config::ProxyConfig`].
-///
-/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
-///   this helps us provide the credentials only to those auth
-///   backends which require them for the authentication process.
-pub enum Backend<'a, T, D> {
+/// The [crate::serverless] module can authenticate either using control-plane
+/// to get authentication state, or by using JWKs stored in the filesystem.
+pub enum ServerlessBackend<'a> {
     /// Cloud API (V2).
-    Console(MaybeOwned<'a, ConsoleBackend>, T),
-    /// Authentication via a web browser.
-    Web(MaybeOwned<'a, url::ApiUrl>, D),
+    ControlPlane(&'a ControlPlaneBackend),
     /// Local proxy uses configured auth credentials and does not wake compute
-    Local(MaybeOwned<'a, LocalBackend>),
+    Local(&'a LocalBackend),
 }

+#[cfg(test)]
+use crate::control_plane::provider::{CachedAllowedIps, CachedRoleSecret};
+
 #[cfg(test)]
 pub(crate) trait TestBackend: Send + Sync + 'static {
-    fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
+    fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
     fn get_allowed_ips_and_secret(
        &self,
-    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
+    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), control_plane::errors::GetAuthInfoError>;
     fn dyn_clone(&self) -> Box<dyn TestBackend>;
 }
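With the generic `Backend<'a, T, D>` gone, serverless callers now receive an explicit `ServerlessBackend` with exactly two variants. A standalone sketch of dispatching over it; the structs below are placeholders mirroring the enum shape in the hunk, not the real definitions from `proxy/src/auth/backend.rs`:

// --- illustrative sketch, not part of the diff ---
struct ControlPlaneBackend; // placeholder
struct LocalBackend;        // placeholder

enum ServerlessBackend<'a> {
    ControlPlane(&'a ControlPlaneBackend),
    Local(&'a LocalBackend),
}

fn describe(backend: &ServerlessBackend<'_>) -> &'static str {
    match backend {
        // Control-plane mode: credentials and wake-compute come from the API.
        ServerlessBackend::ControlPlane(_) => "control-plane",
        // Local mode: JWKs on disk, fixed compute address, no wake-compute.
        ServerlessBackend::Local(_) => "local",
    }
}

fn main() {
    let local = LocalBackend;
    assert_eq!(describe(&ServerlessBackend::Local(&local)), "local");
}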
@@ -90,58 +64,20 @@ impl Clone for Box<dyn TestBackend> {
     }
 }

-impl std::fmt::Display for Backend<'_, (), ()> {
+impl std::fmt::Display for ControlPlaneBackend {
     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Self::Console(api, ()) => match &**api {
-                ConsoleBackend::Console(endpoint) => {
-                    fmt.debug_tuple("Console").field(&endpoint.url()).finish()
-                }
-                #[cfg(any(test, feature = "testing"))]
-                ConsoleBackend::Postgres(endpoint) => {
-                    fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
-                }
-                #[cfg(test)]
-                ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
-            },
-            Self::Web(url, ()) => fmt.debug_tuple("Web").field(&url.as_str()).finish(),
-            Self::Local(_) => fmt.debug_tuple("Local").finish(),
-        }
-    }
-}
-
-impl<T, D> Backend<'_, T, D> {
-    /// Very similar to [`std::option::Option::as_ref`].
-    /// This helps us pass structured config to async tasks.
-    pub(crate) fn as_ref(&self) -> Backend<'_, &T, &D> {
-        match self {
-            Self::Console(c, x) => Backend::Console(MaybeOwned::Borrowed(c), x),
-            Self::Web(c, x) => Backend::Web(MaybeOwned::Borrowed(c), x),
-            Self::Local(l) => Backend::Local(MaybeOwned::Borrowed(l)),
-        }
-    }
-}
-
-impl<'a, T, D> Backend<'a, T, D> {
-    /// Very similar to [`std::option::Option::map`].
-    /// Maps [`Backend<T>`] to [`Backend<R>`] by applying
-    /// a function to a contained value.
-    pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> Backend<'a, R, D> {
-        match self {
-            Self::Console(c, x) => Backend::Console(c, f(x)),
-            Self::Web(c, x) => Backend::Web(c, x),
-            Self::Local(l) => Backend::Local(l),
-        }
-    }
-}
-impl<'a, T, D, E> Backend<'a, Result<T, E>, D> {
-    /// Very similar to [`std::option::Option::transpose`].
-    /// This is most useful for error handling.
-    pub(crate) fn transpose(self) -> Result<Backend<'a, T, D>, E> {
-        match self {
-            Self::Console(c, x) => x.map(|x| Backend::Console(c, x)),
-            Self::Web(c, x) => Ok(Backend::Web(c, x)),
-            Self::Local(l) => Ok(Backend::Local(l)),
+            ControlPlaneBackend::Management(endpoint) => fmt
+                .debug_tuple("ControlPlane::Management")
+                .field(&endpoint.url())
+                .finish(),
+            #[cfg(any(test, feature = "testing"))]
+            ControlPlaneBackend::PostgresMock(endpoint) => fmt
+                .debug_tuple("ControlPlane::PostgresMock")
+                .field(&endpoint.url())
+                .finish(),
+            #[cfg(test)]
+            ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
         }
     }
 }

@@ -234,7 +170,6 @@ impl AuthenticationConfig {
     pub(crate) fn check_rate_limit(
         &self,
         ctx: &RequestMonitoring,
-        config: &AuthenticationConfig,
         secret: AuthSecret,
         endpoint: &EndpointId,
         is_cleartext: bool,

@@ -258,7 +193,7 @@ impl AuthenticationConfig {
         let limit_not_exceeded = self.rate_limiter.check(
             (
                 endpoint_int,
-                MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
+                MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
             ),
             password_weight,
         );

@@ -290,7 +225,7 @@ impl AuthenticationConfig {
 /// All authentication flows will emit an AuthenticationOk message if successful.
 async fn auth_quirks(
     ctx: &RequestMonitoring,
-    api: &impl console::Api,
+    api: &impl control_plane::Api,
     user_info: ComputeUserInfoMaybeEndpoint,
     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
     allow_cleartext: bool,

@@ -332,7 +267,6 @@ async fn auth_quirks(
     let secret = if let Some(secret) = secret {
         config.check_rate_limit(
             ctx,
-            config,
             secret,
             &info.endpoint,
             unauthenticated_password.is_some() || allow_cleartext,

@@ -408,131 +342,79 @@ async fn authenticate_with_secret(
     classic::authenticate(ctx, info, client, config, secret).await
 }

-impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
-    /// Get username from the credentials.
-    pub(crate) fn get_user(&self) -> &str {
-        match self {
-            Self::Console(_, user_info) => &user_info.user,
-            Self::Web(_, ()) => "web",
-            Self::Local(_) => "local",
-        }
-    }
-
-    /// Authenticate the client via the requested backend, possibly using credentials.
+impl ControlPlaneBackend {
     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
     pub(crate) async fn authenticate(
-        self,
+        &self,
         ctx: &RequestMonitoring,
         user_info: ComputeUserInfoMaybeEndpoint,
         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
         allow_cleartext: bool,
         config: &'static AuthenticationConfig,
         endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-    ) -> auth::Result<Backend<'a, ComputeCredentials, NodeInfo>> {
-        let res = match self {
-            Self::Console(api, user_info) => {
-                info!(
-                    user = &*user_info.user,
-                    project = user_info.endpoint(),
-                    "performing authentication using the console"
-                );
+    ) -> auth::Result<ControlPlaneComputeBackend> {
+        info!(
+            user = &*user_info.user,
+            project = user_info.endpoint(),
+            "performing authentication using the console"
+        );

-                let credentials = auth_quirks(
-                    ctx,
-                    &*api,
-                    user_info,
-                    client,
-                    allow_cleartext,
-                    config,
-                    endpoint_rate_limiter,
-                )
-                .await?;
-                Backend::Console(api, credentials)
-            }
-            // NOTE: this auth backend doesn't use client credentials.
-            Self::Web(url, ()) => {
-                info!("performing web authentication");
-
-                let info = web::authenticate(ctx, config, &url, client).await?;
-
-                Backend::Web(url, info)
-            }
-            Self::Local(_) => {
-                return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
-            }
-        };
+        let credentials = auth_quirks(
+            ctx,
+            self,
+            user_info,
+            client,
+            allow_cleartext,
+            config,
+            endpoint_rate_limiter,
+        )
+        .await?;

         info!("user successfully authenticated");
-        Ok(res)
+        Ok(ControlPlaneComputeBackend {
+            api: self,
+            creds: credentials,
+        })
     }

+    pub(crate) fn attach_to_credentials(
+        &self,
+        creds: ComputeCredentials,
+    ) -> ControlPlaneComputeBackend {
+        ControlPlaneComputeBackend { api: self, creds }
+    }
 }

-impl Backend<'_, ComputeUserInfo, &()> {
-    pub(crate) async fn get_role_secret(
+pub struct ControlPlaneComputeBackend<'a> {
+    api: &'a ControlPlaneBackend,
+    creds: ComputeCredentials,
+}
+
+#[async_trait::async_trait]
+impl ComputeConnectBackend for ControlPlaneComputeBackend<'_> {
+    async fn wake_compute(
         &self,
         ctx: &RequestMonitoring,
-    ) -> Result<CachedRoleSecret, GetAuthInfoError> {
-        match self {
-            Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
-            Self::Web(_, ()) => Ok(Cached::new_uncached(None)),
-            Self::Local(_) => Ok(Cached::new_uncached(None)),
-        }
+    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
+        self.api.wake_compute(ctx, &self.creds.info).await
     }

-    pub(crate) async fn get_allowed_ips_and_secret(
-        &self,
-        ctx: &RequestMonitoring,
-    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
-        match self {
-            Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
-            Self::Web(_, ()) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
-            Self::Local(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
-        }
+    fn get_keys(&self) -> &ComputeCredentialKeys {
+        &self.creds.keys
     }
 }

 #[async_trait::async_trait]
-impl ComputeConnectBackend for Backend<'_, ComputeCredentials, NodeInfo> {
+impl ComputeConnectBackend for LocalBackend {
     async fn wake_compute(
         &self,
-        ctx: &RequestMonitoring,
-    ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
-        match self {
-            Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
-            Self::Web(_, info) => Ok(Cached::new_uncached(info.clone())),
-            Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
-        }
+        _ctx: &RequestMonitoring,
+    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
+        Ok(Cached::new_uncached(self.node_info.clone()))
     }

     fn get_keys(&self) -> &ComputeCredentialKeys {
-        match self {
-            Self::Console(_, creds) => &creds.keys,
-            Self::Web(_, _) => &ComputeCredentialKeys::None,
-            Self::Local(_) => &ComputeCredentialKeys::None,
-        }
-    }
-}
-
-#[async_trait::async_trait]
-impl ComputeConnectBackend for Backend<'_, ComputeCredentials, &()> {
-    async fn wake_compute(
-        &self,
-        ctx: &RequestMonitoring,
-    ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
-        match self {
-            Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
-            Self::Web(_, ()) => {
-                unreachable!("web auth flow doesn't support waking the compute")
-            }
-            Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
-        }
-    }
-
-    fn get_keys(&self) -> &ComputeCredentialKeys {
-        match self {
-            Self::Console(_, creds) => &creds.keys,
-            Self::Web(_, ()) => &ComputeCredentialKeys::None,
-            Self::Local(_) => &ComputeCredentialKeys::None,
-        }
+        &ComputeCredentialKeys::None
     }
 }
@@ -553,12 +435,12 @@ mod tests {
     use crate::{
         auth::{backend::MaskedIp, ComputeUserInfoMaybeEndpoint, IpPattern},
         config::AuthenticationConfig,
-        console::{
+        context::RequestMonitoring,
+        control_plane::{
             self,
             provider::{self, CachedAllowedIps, CachedRoleSecret},
             CachedNodeInfo,
         },
-        context::RequestMonitoring,
         proxy::NeonOptions,
         rate_limiter::{EndpointRateLimiter, RateBucketInfo},
         scram::{threadpool::ThreadPool, ServerSecret},

@@ -572,12 +454,12 @@ mod tests {
         secret: AuthSecret,
     }

-    impl console::Api for Auth {
+    impl control_plane::Api for Auth {
         async fn get_role_secret(
             &self,
             _ctx: &RequestMonitoring,
             _user_info: &super::ComputeUserInfo,
-        ) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
+        ) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
             Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
         }

@@ -585,8 +467,10 @@ mod tests {
         &self,
         _ctx: &RequestMonitoring,
         _user_info: &super::ComputeUserInfo,
-    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
-    {
+    ) -> Result<
+        (CachedAllowedIps, Option<CachedRoleSecret>),
+        control_plane::errors::GetAuthInfoError,
+    > {
         Ok((
             CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
             Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),

@@ -605,7 +489,7 @@ mod tests {
         &self,
         _ctx: &RequestMonitoring,
         _user_info: &super::ComputeUserInfo,
-    ) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
+    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
         unimplemented!()
     }
 }

@@ -3,8 +3,8 @@
 use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
 use crate::{
     config::TlsServerEndPoint,
-    console::AuthSecret,
     context::RequestMonitoring,
+    control_plane::AuthSecret,
     intern::EndpointIdInt,
     sasl,
     scram::{self, threadpool::ThreadPool},

@@ -1,7 +1,7 @@
 //! Client authentication mechanisms.

 pub mod backend;
-pub use backend::Backend;
+pub use backend::ServerlessBackend;

 mod credentials;
 pub(crate) use credentials::{

@@ -18,7 +18,7 @@ pub(crate) use flow::*;
 use tokio::time::error::Elapsed;

 use crate::{
-    console,
+    control_plane,
     error::{ReportableError, UserFacingError},
 };
 use std::{io, net::IpAddr};

@@ -34,7 +34,7 @@ pub(crate) enum AuthErrorImpl {
     Web(#[from] backend::WebAuthError),

     #[error(transparent)]
-    GetAuthInfo(#[from] console::errors::GetAuthInfoError),
+    GetAuthInfo(#[from] control_plane::errors::GetAuthInfoError),

     /// SASL protocol errors (includes [SCRAM](crate::scram)).
     #[error(transparent)]
@@ -6,13 +6,16 @@ use compute_api::spec::LocalProxySpec;
 use dashmap::DashMap;
 use futures::future::Either;
 use proxy::{
-    auth::backend::{
-        jwt::JwkCache,
-        local::{LocalBackend, JWKS_ROLE_MAP},
+    auth::{
+        self,
+        backend::{
+            jwt::JwkCache,
+            local::{LocalBackend, JWKS_ROLE_MAP},
+        },
     },
     cancellation::CancellationHandlerMain,
     config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
-    console::{
+    control_plane::{
         locks::ApiLocks,
         messages::{EndpointJwksResponse, JwksSettings},
     },

@@ -77,10 +80,10 @@ struct LocalProxyCliArgs {
     #[clap(long, default_value = "127.0.0.1:5432")]
     compute: SocketAddr,
     /// Path of the local proxy config file
-    #[clap(long, default_value = "./localproxy.json")]
+    #[clap(long, default_value = "./local_proxy.json")]
     config_path: Utf8PathBuf,
     /// Path of the local proxy PID file
-    #[clap(long, default_value = "./localproxy.pid")]
+    #[clap(long, default_value = "./local_proxy.pid")]
     pid_path: Utf8PathBuf,
 }

@@ -109,7 +112,7 @@ struct SqlOverHttpArgs {

 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    let _logging_guard = proxy::logging::init().await?;
+    let _logging_guard = proxy::logging::init_local_proxy()?;
     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);

@@ -132,13 +135,14 @@ async fn main() -> anyhow::Result<()> {

     let args = LocalProxyCliArgs::parse();
     let config = build_config(&args)?;
+    let auth_backend = build_auth_backend(&args)?;

     // before we bind to any ports, write the process ID to a file
     // so that compute-ctl can find our process later
     // in order to trigger the appropriate SIGHUP on config change.
     //
     // This also claims a "lock" that makes sure only one instance
-    // of local-proxy runs at a time.
+    // of local_proxy runs at a time.
     let _process_guard = loop {
         match pid_file::claim_for_current_process(&args.pid_path) {
             Ok(guard) => break guard,

@@ -164,12 +168,6 @@ async fn main() -> anyhow::Result<()> {
         16,
     ));

-    // write the process ID to a file so that compute-ctl can find our process later
-    // in order to trigger the appropriate SIGHUP on config change.
-    let pid = std::process::id();
-    info!("process running in PID {pid}");
-    std::fs::write(args.pid_path, format!("{pid}\n")).context("writing PID to file")?;
-
     let mut maintenance_tasks = JoinSet::new();

     let refresh_config_notify = Arc::new(Notify::new());

@@ -182,9 +180,9 @@ async fn main() -> anyhow::Result<()> {

     // trigger the first config load **after** setting up the signal hook
     // to avoid the race condition where:
-    // 1. No config file registered when local-proxy starts up
+    // 1. No config file registered when local_proxy starts up
     // 2. The config file is written but the signal hook is not yet received
-    // 3. local-proxy completes startup but has no config loaded, despite there being a registerd config.
+    // 3. local_proxy completes startup but has no config loaded, despite there being a registerd config.
     refresh_config_notify.notify_one();
     tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));

@@ -199,6 +197,7 @@ async fn main() -> anyhow::Result<()> {

     let task = serverless::task_main(
         config,
+        auth::ServerlessBackend::Local(auth_backend),
         http_listener,
         shutdown.clone(),
         Arc::new(CancellationHandlerMain::new(

@@ -263,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig

     Ok(Box::leak(Box::new(ProxyConfig {
         tls_config: None,
-        auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
-            LocalBackend::new(args.compute),
-        )),
         metric_collection: None,
         allow_self_signed_compute: false,
         http_config,

@@ -292,6 +288,13 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
     })))
 }

+/// auth::Backend is created at proxy startup, and lives forever.
+fn build_auth_backend(args: &LocalProxyCliArgs) -> anyhow::Result<&'static LocalBackend> {
+    let auth_backend = LocalBackend::new(args.compute);
+
+    Ok(Box::leak(Box::new(auth_backend)))
+}
+
 async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
     loop {
         rx.notified().await;

@@ -311,7 +314,7 @@ async fn refresh_config_inner(path: &Utf8Path) -> anyhow::Result<()> {

     let mut jwks_set = vec![];

-    for jwks in data.jwks {
+    for jwks in data.jwks.into_iter().flatten() {
         let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;

         ensure!(
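The last hunk handles the config field becoming optional: iterating `data.jwks.into_iter().flatten()` visits zero elements for `None` without a separate branch. A minimal illustration, assuming the `Option<Vec<_>>` field shape the change implies:

// --- illustrative sketch, not part of the diff ---
fn main() {
    // `Option<Vec<T>>::into_iter()` yields the Vec zero or one time;
    // `flatten()` then yields its elements, so `None` simply loops zero times.
    let present: Option<Vec<&str>> = Some(vec!["https://a/jwks.json", "https://b/jwks.json"]);
    let absent: Option<Vec<&str>> = None;

    let urls: Vec<&str> = present.into_iter().flatten().collect();
    assert_eq!(urls.len(), 2);

    assert_eq!(absent.into_iter().flatten().count(), 0);
}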
@@ -10,7 +10,7 @@ use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::auth::backend::ConsoleRedirectBackend;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::remote_storage_from_toml;
|
||||
@@ -19,8 +19,9 @@ use proxy::config::CacheOptions;
|
||||
use proxy::config::HttpConfig;
|
||||
use proxy::config::ProjectInfoCacheOptions;
|
||||
use proxy::config::ProxyProtocolV2;
|
||||
use proxy::console;
|
||||
use proxy::context::parquet::ParquetUploadArgs;
|
||||
use proxy::control_plane;
|
||||
use proxy::control_plane::provider::ControlPlaneBackend;
|
||||
use proxy::http;
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
use proxy::metrics::Metrics;
|
||||
@@ -311,8 +312,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let auth_backend = build_auth_backend(&args)?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
|
||||
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
|
||||
};
|
||||
info!("Using region: {}", args.aws_region);
|
||||
|
||||
let region_provider =
|
||||
@@ -459,24 +464,41 @@ async fn main() -> anyhow::Result<()> {
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth::ServerlessBackend::ControlPlane(auth_backend),
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Either::Right(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::console_redirect_proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client_tasks.spawn(proxy::context::parquet::worker(
|
||||
@@ -495,7 +517,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy: proxy::metrics::Metrics::get(),
|
||||
},
|
||||
));
|
||||
maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));
|
||||
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
|
||||
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
// TODO: Add gc regardles of the metric collection being enabled.
|
||||
@@ -506,40 +528,38 @@ async fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
}
|
||||
|
||||
if let auth::Backend::Console(api, _) = &config.auth_backend {
|
||||
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
if let Either::Left(ControlPlaneBackend::Management(api)) = &auth_backend {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,73 +630,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
bail!("dynamic rate limiter should be disabled");
|
||||
}
|
||||
|
||||
let auth_backend = match &args.auth_backend {
|
||||
AuthBackendType::Console => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
|
||||
info!(
|
||||
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
|
||||
);
|
||||
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
|
||||
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.wake_compute_lock.parse()?;
|
||||
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
|
||||
let locks = Box::leak(Box::new(console::locks::ApiLocks::new(
|
||||
"wake_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().wake_compute_lock,
|
||||
)?));
|
||||
tokio::spawn(locks.garbage_collect_worker());
|
||||
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let endpoint = http::Endpoint::new(url, http::new_client());
|
||||
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = console::provider::neon::Api::new(
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
);
|
||||
let api = console::provider::ConsoleBackend::Console(api);
|
||||
auth::Backend::Console(MaybeOwned::Owned(api), ())
|
||||
}
|
||||
|
||||
AuthBackendType::Web => {
|
||||
let url = args.uri.parse()?;
|
||||
auth::Backend::Web(MaybeOwned::Owned(url), ())
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
AuthBackendType::Postgres => {
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
let api = console::provider::mock::Api::new(url, !args.is_private_access_proxy);
|
||||
let api = console::provider::ConsoleBackend::Postgres(api);
|
||||
auth::Backend::Console(MaybeOwned::Owned(api), ())
|
||||
}
|
||||
};
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
@@ -689,7 +642,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
?epoch,
|
||||
"Using NodeLocks (connect_compute)"
|
||||
);
|
||||
let connect_compute_locks = console::locks::ApiLocks::new(
|
||||
let connect_compute_locks = control_plane::locks::ApiLocks::new(
|
||||
"connect_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
@@ -726,9 +679,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
let config = ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
@@ -741,13 +693,97 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
}));
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(config));
|
||||
|
||||
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
|
||||
|
||||
Ok(config)
|
||||
}

/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
    args: &ProxyCliArgs,
) -> anyhow::Result<Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>> {
    match &args.auth_backend {
        AuthBackendType::Console => {
            let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
            let project_info_cache_config: ProjectInfoCacheOptions =
                args.project_info_cache.parse()?;
            let endpoint_cache_config: config::EndpointCacheConfig =
                args.endpoint_cache_config.parse()?;

            info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
            info!(
                "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
            );
            info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
            let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
                wake_compute_cache_config,
                project_info_cache_config,
                endpoint_cache_config,
            )));

            let config::ConcurrencyLockOptions {
                shards,
                limiter,
                epoch,
                timeout,
            } = args.wake_compute_lock.parse()?;
            info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
            let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
                "wake_compute_lock",
                limiter,
                shards,
                timeout,
                epoch,
                &Metrics::get().wake_compute_lock,
            )?));
            tokio::spawn(locks.garbage_collect_worker());

            let url = args.auth_endpoint.parse()?;
            let endpoint = http::Endpoint::new(url, http::new_client());

            let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
            RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
            let wake_compute_endpoint_rate_limiter =
                Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
            let api = control_plane::provider::neon::Api::new(
                endpoint,
                caches,
                locks,
                wake_compute_endpoint_rate_limiter,
            );
            let auth_backend = control_plane::provider::ControlPlaneBackend::Management(api);

            let config = Box::leak(Box::new(auth_backend));

            Ok(Either::Left(config))
        }

        #[cfg(feature = "testing")]
        AuthBackendType::Postgres => {
            let url = args.auth_endpoint.parse()?;
            let api = control_plane::provider::mock::Api::new(url, !args.is_private_access_proxy);
            let auth_backend = control_plane::provider::ControlPlaneBackend::PostgresMock(api);

            let config = Box::leak(Box::new(auth_backend));

            Ok(Either::Left(config))
        }

        AuthBackendType::Web => {
            let url = args.uri.parse()?;
            let backend = ConsoleRedirectBackend::new(url);

            let config = Box::leak(Box::new(backend));

            Ok(Either::Right(config))
        }
    }
}
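
// Aside (illustrative sketch, not part of the patch): how the `Either`
// returned by `build_auth_backend` is meant to be consumed at startup.
// The backend types here are empty stand-ins; assumes the `either` crate.
use either::Either;

struct ControlPlaneBackend;
struct ConsoleRedirectBackend;

fn dispatch(auth: Either<&'static ControlPlaneBackend, &'static ConsoleRedirectBackend>) {
    match auth {
        // Regular proxy path, authenticating against the control plane.
        Either::Left(_cplane) => println!("spawn proxy::task_main"),
        // Console-redirect path, authenticating via the web console.
        Either::Right(_redirect) => println!("spawn console_redirect_proxy::task_main"),
    }
}

fn main() {
    let backend: &'static ControlPlaneBackend = Box::leak(Box::new(ControlPlaneBackend));
    dispatch(Either::Left(backend));
}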

#[cfg(test)]
mod tests {
    use std::time::Duration;

proxy/src/cache/project_info.rs
@@ -16,7 +16,7 @@ use tracing::{debug, info};
use crate::{
    auth::IpPattern,
    config::ProjectInfoCacheOptions,
    console::AuthSecret,
    control_plane::AuthSecret,
    intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
    EndpointId, RoleName,
};

@@ -1,8 +1,8 @@
use crate::{
    auth::parse_endpoint_param,
    cancellation::CancelClosure,
    console::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
    context::RequestMonitoring,
    control_plane::{errors::WakeComputeError, messages::MetricsAuxInfo, provider::ApiLockError},
    error::{ReportableError, UserFacingError},
    metrics::{Metrics, NumDbConnectionsGuard},
    proxy::neon_option,
@@ -20,7 +20,7 @@ use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
use tracing::{error, info, warn};

const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";

#[derive(Debug, Error)]
pub(crate) enum ConnectionError {

@@ -1,9 +1,6 @@
use crate::{
    auth::{
        self,
        backend::{jwt::JwkCache, AuthRateLimiter},
    },
    console::locks::ApiLocks,
    auth::backend::{jwt::JwkCache, AuthRateLimiter},
    control_plane::locks::ApiLocks,
    rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
    scram::threadpool::ThreadPool,
    serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
@@ -29,7 +26,6 @@ use x509_parser::oid_registry;

pub struct ProxyConfig {
    pub tls_config: Option<TlsConfig>,
    pub auth_backend: auth::Backend<'static, (), ()>,
    pub metric_collection: Option<MetricCollectionConfig>,
    pub allow_self_signed_compute: bool,
    pub http_config: HttpConfig,
@@ -372,7 +368,7 @@ pub struct EndpointCacheConfig {
}

impl EndpointCacheConfig {
    /// Default options for [`crate::console::provider::NodeInfoCache`].
    /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
    /// Notice that by default the limiter is empty, which means that cache is disabled.
    pub const CACHE_DEFAULT_OPTIONS: &'static str =
        "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
@@ -447,7 +443,7 @@ pub struct CacheOptions {
}

impl CacheOptions {
    /// Default options for [`crate::console::provider::NodeInfoCache`].
    /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
    pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";

    /// Parse cache options passed via cmdline.
@@ -503,7 +499,7 @@ pub struct ProjectInfoCacheOptions {
}

impl ProjectInfoCacheOptions {
    /// Default options for [`crate::console::provider::NodeInfoCache`].
    /// Default options for [`crate::control_plane::provider::NodeInfoCache`].
    pub const CACHE_DEFAULT_OPTIONS: &'static str =
        "size=10000,ttl=4m,max_roles=10,gc_interval=60m";

@@ -622,9 +618,9 @@ pub struct ConcurrencyLockOptions {
}

impl ConcurrencyLockOptions {
    /// Default options for [`crate::console::provider::ApiLocks`].
    /// Default options for [`crate::control_plane::provider::ApiLocks`].
    pub const DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK: &'static str = "permits=0";
    /// Default options for [`crate::console::provider::ApiLocks`].
    /// Default options for [`crate::control_plane::provider::ApiLocks`].
    pub const DEFAULT_OPTIONS_CONNECT_COMPUTE_LOCK: &'static str =
        "shards=64,permits=100,epoch=10m,timeout=10ms";

proxy/src/console_redirect_proxy.rs (new file)
@@ -0,0 +1,161 @@
use crate::auth::backend::ConsoleRedirectBackend;
use crate::config::ProxyConfig;
use crate::metrics::Protocol;
use crate::proxy::{prepare_client_connection, transition_connection, ClientRequestError};
use crate::{
    cancellation::CancellationHandlerMain,
    context::RequestMonitoring,
    metrics::{Metrics, NumClientConnectionsGuard},
    proxy::handshake::{handshake, HandshakeData},
};
use futures::TryFutureExt;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{info, Instrument};

use crate::proxy::{
    connect_compute::{connect_to_compute, TcpMechanism},
    passthrough::ProxyPassthrough,
};

pub async fn task_main(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandlerMain>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    super::connection_loop(
        config,
        listener,
        cancellation_token,
        Protocol::Tcp,
        C {
            config,
            backend,
            cancellation_handler,
        },
    )
    .await
}

#[derive(Clone)]
struct C {
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    cancellation_handler: Arc<CancellationHandlerMain>,
}

impl super::ConnHandler for C {
    async fn handle(
        self,
        session_id: uuid::Uuid,
        peer_addr: IpAddr,
        socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
        conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
    ) {
        let ctx = RequestMonitoring::new(session_id, peer_addr, Protocol::Tcp, &self.config.region);
        let span = ctx.span();

        let startup = Box::pin(
            handle_client(
                self.config,
                self.backend,
                &ctx,
                self.cancellation_handler,
                socket,
                conn_gauge,
            )
            .instrument(span.clone()),
        );

        let res = startup.await;
        transition_connection(ctx, res).await;
    }
}

pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    backend: &'static ConsoleRedirectBackend,
    ctx: &RequestMonitoring,
    cancellation_handler: Arc<CancellationHandlerMain>,
    stream: S,
    conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
    info!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
    );

    let metrics = &Metrics::get().proxy;
    let proto = ctx.protocol();
    let request_gauge = metrics.connection_requests.guard(proto);

    let tls = config.tls_config.as_ref();

    let record_handshake_error = !ctx.has_private_peer_addr();
    let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
    let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
    let (mut stream, params) =
        match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
            HandshakeData::Startup(stream, params) => (stream, params),
            HandshakeData::Cancel(cancel_key_data) => {
                return Ok(cancellation_handler
                    .cancel_session(cancel_key_data, ctx.session_id())
                    .await
                    .map(|()| None)?)
            }
        };
    drop(pause);

    ctx.set_db_options(params.clone());

    let user_info = match backend
        .authenticate(ctx, &config.authentication_config, &mut stream)
        .await
    {
        Ok(auth_result) => auth_result,
        Err(e) => {
            return stream.throw_error(e).await?;
        }
    };

    let mut node = connect_to_compute(
        ctx,
        &TcpMechanism {
            params: &params,
            locks: &config.connect_compute_locks,
        },
        &user_info,
        config.allow_self_signed_compute,
        config.wake_compute_retry_config,
        config.connect_to_compute_retry_config,
    )
    .or_else(|e| stream.throw_error(e))
    .await?;

    let session = cancellation_handler.get_session();
    prepare_client_connection(&node, &session, &mut stream).await?;

    // Before proxy passing, forward to compute whatever data is left in the
    // PqStream input buffer. Normally there is none, but our serverless npm
    // driver in pipeline mode sends startup, password and first query
    // immediately after opening the connection.
    let (stream, read_buf) = stream.into_inner();
    node.stream.write_all(&read_buf).await?;

    Ok(Some(ProxyPassthrough {
        client: stream,
        aux: node.aux.clone(),
        compute: node,
        _req: request_gauge,
        _conn: conn_gauge,
        _cancel: session,
    }))
}
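
// Aside (illustrative sketch, not part of the patch): the `await??` in
// `handle_client` above works because `tokio::time::timeout` wraps the
// future's output in a `Result<_, Elapsed>`; the first `?` handles the
// deadline and the second the handshake's own error. Assumes the tokio
// and anyhow crates.
use std::time::Duration;

async fn fallible_step() -> std::io::Result<u32> {
    Ok(42)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let v = tokio::time::timeout(Duration::from_secs(1), fallible_step()).await??;
    assert_eq!(v, 42);
    Ok(())
}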

@@ -11,7 +11,7 @@ use try_lock::TryLock;
use uuid::Uuid;

use crate::{
    console::messages::{ColdStartInfo, MetricsAuxInfo},
    control_plane::messages::{ColdStartInfo, MetricsAuxInfo},
    error::ErrorKind,
    intern::{BranchIdInt, ProjectIdInt},
    metrics::{ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting},

@@ -10,14 +10,14 @@ use crate::proxy::retry::CouldRetry;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize, Clone)]
pub(crate) struct ConsoleError {
pub(crate) struct ControlPlaneError {
    pub(crate) error: Box<str>,
    #[serde(skip)]
    pub(crate) http_status_code: http::StatusCode,
    pub(crate) status: Option<Status>,
}

impl ConsoleError {
impl ControlPlaneError {
    pub(crate) fn get_reason(&self) -> Reason {
        self.status
            .as_ref()
@@ -51,7 +51,7 @@ impl ConsoleError {
    }
}

impl Display for ConsoleError {
impl Display for ControlPlaneError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg: &str = self
            .status
@@ -62,7 +62,7 @@ impl Display for ConsoleError {
    }
}

impl CouldRetry for ConsoleError {
impl CouldRetry for ControlPlaneError {
    fn could_retry(&self) -> bool {
        // If the error message does not have a status,
        // the error is unknown and probably should not retry automatically
@@ -1,5 +1,5 @@
use crate::{
    console::messages::{DatabaseInfo, KickSession},
    control_plane::messages::{DatabaseInfo, KickSession},
    waiters::{self, Waiter, Waiters},
};
use anyhow::Context;
@@ -10,7 +10,7 @@ use crate::{
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
    console::{
    control_plane::{
        messages::MetricsAuxInfo,
        provider::{CachedAllowedIps, CachedRoleSecret},
    },
@@ -166,7 +166,7 @@ impl Api {
            endpoint_id: (&EndpointId::from("endpoint")).into(),
            project_id: (&ProjectId::from("project")).into(),
            branch_id: (&BranchId::from("branch")).into(),
            cold_start_info: crate::console::messages::ColdStartInfo::Warm,
            cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
        },
        allow_self_signed_compute: false,
    };

@@ -2,7 +2,7 @@
pub mod mock;
pub mod neon;

use super::messages::{ConsoleError, MetricsAuxInfo};
use super::messages::{ControlPlaneError, MetricsAuxInfo};
use crate::{
    auth::{
        backend::{
@@ -28,7 +28,7 @@ use tracing::info;

pub(crate) mod errors {
    use crate::{
        console::messages::{self, ConsoleError, Reason},
        control_plane::messages::{self, ControlPlaneError, Reason},
        error::{io_error, ErrorKind, ReportableError, UserFacingError},
        proxy::retry::CouldRetry,
    };
@@ -44,7 +44,7 @@ pub(crate) mod errors {
    pub(crate) enum ApiError {
        /// Error returned by the console itself.
        #[error("{REQUEST_FAILED} with {0}")]
        Console(ConsoleError),
        ControlPlane(ControlPlaneError),

        /// Various IO errors like broken pipe or malformed payload.
        #[error("{REQUEST_FAILED}: {0}")]
@@ -55,7 +55,7 @@ pub(crate) mod errors {
        /// Returns HTTP status code if it's the reason for failure.
        pub(crate) fn get_reason(&self) -> messages::Reason {
            match self {
                ApiError::Console(e) => e.get_reason(),
                ApiError::ControlPlane(e) => e.get_reason(),
                ApiError::Transport(_) => messages::Reason::Unknown,
            }
        }
@@ -65,7 +65,7 @@ pub(crate) mod errors {
        fn to_string_client(&self) -> String {
            match self {
                // To minimize risks, only select errors are forwarded to users.
                ApiError::Console(c) => c.get_user_facing_message(),
                ApiError::ControlPlane(c) => c.get_user_facing_message(),
                ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
            }
        }
@@ -74,7 +74,7 @@ pub(crate) mod errors {
    impl ReportableError for ApiError {
        fn get_error_kind(&self) -> crate::error::ErrorKind {
            match self {
                ApiError::Console(e) => match e.get_reason() {
                ApiError::ControlPlane(e) => match e.get_reason() {
                    Reason::RoleProtected => ErrorKind::User,
                    Reason::ResourceNotFound => ErrorKind::User,
                    Reason::ProjectNotFound => ErrorKind::User,
@@ -91,12 +91,12 @@ pub(crate) mod errors {
                    Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
                    Reason::RunningOperations => ErrorKind::ControlPlane,
                    Reason::Unknown => match &e {
                        ConsoleError {
                        ControlPlaneError {
                            http_status_code:
                                http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
                            ..
                        } => crate::error::ErrorKind::User,
                        ConsoleError {
                        ControlPlaneError {
                            http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
                            error,
                            ..
@@ -105,7 +105,7 @@ pub(crate) mod errors {
                        {
                            crate::error::ErrorKind::User
                        }
                        ConsoleError {
                        ControlPlaneError {
                            http_status_code: http::StatusCode::LOCKED,
                            error,
                            ..
@@ -114,11 +114,11 @@ pub(crate) mod errors {
                        {
                            crate::error::ErrorKind::User
                        }
                        ConsoleError {
                        ControlPlaneError {
                            http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
                            ..
                        } => crate::error::ErrorKind::ServiceRateLimit,
                        ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
                        ControlPlaneError { .. } => crate::error::ErrorKind::ControlPlane,
                    },
                },
                ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
@@ -131,7 +131,7 @@ pub(crate) mod errors {
            match self {
                // retry some transport errors
                Self::Transport(io) => io.could_retry(),
                Self::Console(e) => e.could_retry(),
                Self::ControlPlane(e) => e.could_retry(),
            }
        }
    }
@@ -314,7 +314,8 @@ impl NodeInfo {
    }
}

pub(crate) type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
pub(crate) type NodeInfoCache =
    TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneError>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
@@ -353,28 +354,28 @@ pub(crate) trait Api {

#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
    /// Current Cloud API (V2).
    Console(neon::Api),
    /// Local mock of Cloud API (V2).
pub enum ControlPlaneBackend {
    /// Current Management API (V2).
    Management(neon::Api),
    /// Local mock control plane.
    #[cfg(any(test, feature = "testing"))]
    Postgres(mock::Api),
    PostgresMock(mock::Api),
    /// Internal testing
    #[cfg(test)]
    #[allow(private_interfaces)]
    Test(Box<dyn crate::auth::backend::TestBackend>),
}

impl Api for ConsoleBackend {
impl Api for ControlPlaneBackend {
    async fn get_role_secret(
        &self,
        ctx: &RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
        match self {
            Self::Console(api) => api.get_role_secret(ctx, user_info).await,
            Self::Management(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Self::Postgres(api) => api.get_role_secret(ctx, user_info).await,
            Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
            #[cfg(test)]
            Self::Test(_) => {
                unreachable!("this function should never be called in the test backend")
@@ -388,9 +389,9 @@ impl Api for ConsoleBackend {
        user_info: &ComputeUserInfo,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
        match self {
            Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            Self::Management(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
            #[cfg(test)]
            Self::Test(api) => api.get_allowed_ips_and_secret(),
        }
@@ -402,9 +403,9 @@ impl Api for ConsoleBackend {
        endpoint: EndpointId,
    ) -> anyhow::Result<Vec<AuthRule>> {
        match self {
            Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
            Self::Management(api) => api.get_endpoint_jwks(ctx, endpoint).await,
            #[cfg(any(test, feature = "testing"))]
            Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
            Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
            #[cfg(test)]
            Self::Test(_api) => Ok(vec![]),
        }
@@ -416,16 +417,16 @@ impl Api for ConsoleBackend {
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError> {
        match self {
            Self::Console(api) => api.wake_compute(ctx, user_info).await,
            Self::Management(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(any(test, feature = "testing"))]
            Self::Postgres(api) => api.wake_compute(ctx, user_info).await,
            Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
            #[cfg(test)]
            Self::Test(api) => api.wake_compute(),
        }
    }
}

/// Various caches for [`console`](super).
/// Various caches for [`control_plane`](super).
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub(crate) node_info: NodeInfoCache,
@@ -454,7 +455,7 @@ impl ApiCaches {
    }
}

/// Various caches for [`console`](super).
/// Various caches for [`control_plane`](super).
pub struct ApiLocks<K> {
    name: &'static str,
    node_locks: DashMap<K, Arc<DynamicLimiter>>,
@@ -577,7 +578,7 @@ impl WakeComputePermit {
    }
}

impl FetchAuthRules for ConsoleBackend {
impl FetchAuthRules for ControlPlaneBackend {
    async fn fetch_auth_rules(
        &self,
        ctx: &RequestMonitoring,
@@ -1,7 +1,7 @@
//! Production console backend.

use super::{
    super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
    super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute},
    errors::{ApiError, GetAuthInfoError, WakeComputeError},
    ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
    NodeInfo,
@@ -9,7 +9,7 @@ use super::{
use crate::{
    auth::backend::{jwt::AuthRule, ComputeUserInfo},
    compute,
    console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
    control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
    http,
    metrics::{CacheOutcome, Metrics},
    rate_limiter::WakeComputeRateLimiter,
@@ -348,7 +348,7 @@ impl super::Api for Api {
        let (cached, info) = cached.take_value();
        let info = info.map_err(|c| {
            info!(key = &*key, "found cached wake_compute error");
            WakeComputeError::ApiError(ApiError::Console(*c))
            WakeComputeError::ApiError(ApiError::ControlPlane(*c))
        })?;

        debug!(key = &*key, "found cached compute node info");
@@ -395,9 +395,9 @@ impl super::Api for Api {
                Ok(cached.map(|()| node))
            }
            Err(err) => match err {
                WakeComputeError::ApiError(ApiError::Console(err)) => {
                WakeComputeError::ApiError(ApiError::ControlPlane(err)) => {
                    let Some(status) = &err.status else {
                        return Err(WakeComputeError::ApiError(ApiError::Console(err)));
                        return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
                    };

                    let reason = status
@@ -407,7 +407,7 @@ impl super::Api for Api {

                    // if we can retry this error, do not cache it.
                    if reason.can_retry() {
                        return Err(WakeComputeError::ApiError(ApiError::Console(err)));
                        return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
                    }

                    // at this point, we should only have quota errors.
@@ -422,7 +422,7 @@ impl super::Api for Api {
                        Duration::from_secs(30),
                    );

                    Err(WakeComputeError::ApiError(ApiError::Console(err)))
                    Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)))
                }
                err => return Err(err),
            },
@@ -448,7 +448,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    // as the fact that the request itself has failed.
    let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
        warn!("failed to parse error body: {e}");
        ConsoleError {
        ControlPlaneError {
            error: "reason unclear (malformed error message)".into(),
            http_status_code: status,
            status: None,
@@ -457,7 +457,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
    body.http_status_code = status;

    error!("console responded with an error ({status}): {body:?}");
    Err(ApiError::Console(body))
    Err(ApiError::ControlPlane(body))
}

fn parse_host_port(input: &str) -> Option<(&str, u16)> {
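
// Aside (illustrative stand-in, not part of the patch): one way a helper
// with the `parse_host_port` signature shown above could behave. The real
// implementation in the crate may differ (e.g. IPv6 handling).
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    // Split on the last ':' so the host part may itself contain colons.
    let (host, port) = input.rsplit_once(':')?;
    Some((host, port.parse().ok()?))
}

fn main() {
    assert_eq!(parse_host_port("compute-1.local:5432"), Some(("compute-1.local", 5432)));
    assert_eq!(parse_host_port("no-port"), None);
}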

@@ -82,21 +82,27 @@
    impl_trait_overcaptures,
)]

use std::convert::Infallible;
use std::{convert::Infallible, future::Future, net::IpAddr};

use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
use tokio::task::JoinError;
use protocol2::{get_client_conn_info, ChainRW};
use proxy::run_until_cancelled;
use tokio::{net::TcpStream, task::JoinError};
use tokio_util::sync::CancellationToken;
use tracing::warn;
use tracing::{error, warn};
use uuid::Uuid;

extern crate hyper0 as hyper;

pub mod auth;
pub mod cache;
pub mod cancellation;
pub mod compute;
pub mod config;
pub mod console;
pub mod console_redirect_proxy;
pub mod context;
pub mod control_plane;
pub mod error;
pub mod http;
pub mod intern;
@@ -273,3 +279,81 @@ impl EndpointId {
        ProjectId(self.0.clone())
    }
}

pub(crate) trait ConnHandler: Clone + Send + 'static {
    fn handle(
        self,
        session_id: Uuid,
        peer_addr: IpAddr,
        stream: ChainRW<TcpStream>,
        conn_gauge: metrics::NumClientConnectionsGuard<'static>,
    ) -> impl Future<Output = ()> + Send;
}

/// Accept connections, parse the proxy-protocol v2 header and spawn a tracked connection task.
pub(crate) async fn connection_loop<C>(
    config: &'static config::ProxyConfig,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    protocol: metrics::Protocol,
    conn_handler: C,
) -> anyhow::Result<()>
where
    C: ConnHandler,
{
    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let connections = tokio_util::task::task_tracker::TaskTracker::new();

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        let (socket, peer_addr) = accept_result?;

        let conn_gauge = metrics::Metrics::get()
            .proxy
            .client_connections
            .guard(protocol);

        let session_id = uuid::Uuid::new_v4();
        let conn_handler = conn_handler.clone();

        tracing::info!(protocol = protocol.as_str(), %session_id, "accepted new TCP connection");

        connections.spawn(async move {
            let (socket, peer_addr) = match get_client_conn_info(socket, config.proxy_protocol_v2).await {
                Err(e) => {
                    error!("per-client task finished with an error: {e:#}");
                    return;
                }
                Ok((socket, Some(addr))) => (socket, addr),
                Ok((socket, None)) => (socket, peer_addr.ip()),
            };

            match socket.inner.set_nodelay(true) {
                Ok(()) => {}
                Err(e) => {
                    error!("per-client task finished with an error: failed to set socket option: {e:#}");
                    return;
                }
            };

            conn_handler.handle(
                session_id,
                peer_addr,
                socket,
                conn_gauge,
            ).await;
        });
    }

    connections.close();
    drop(listener);

    // Drain connections
    connections.wait().await;

    Ok(())
}
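
// Aside (illustrative sketch, not part of the patch): the shape of the
// `ConnHandler` abstraction above, reduced to a runnable toy. The real trait
// also receives the socket and a metrics guard; this only shows the
// async-method-in-trait shape (Rust 1.75+) and the clone-per-connection
// spawning pattern. Assumes the tokio crate.
use std::future::Future;

trait ConnHandler: Clone + Send + 'static {
    fn handle(self, session_id: u64) -> impl Future<Output = ()> + Send;
}

#[derive(Clone)]
struct Echo;

impl ConnHandler for Echo {
    async fn handle(self, session_id: u64) {
        // Per-connection work would go here.
        println!("handled session {session_id}");
    }
}

#[tokio::main]
async fn main() {
    let handler = Echo;
    let mut tasks = Vec::new();
    for session_id in 0..3 {
        // Like `connection_loop`, clone the handler for each accepted connection.
        let h = handler.clone();
        tasks.push(tokio::spawn(async move { h.handle(session_id).await }));
    }
    for t in tasks {
        t.await.unwrap();
    }
}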

@@ -1,6 +1,13 @@
use tracing::Subscriber;
use tracing_subscriber::{
    filter::{EnvFilter, LevelFilter},
    fmt::{
        format::{Format, Full},
        time::SystemTime,
        FormatEvent, FormatFields,
    },
    prelude::*,
    registry::LookupSpan,
};

/// Initialize logging and OpenTelemetry tracing and exporter.
@@ -33,6 +40,45 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
    Ok(LoggingGuard)
}

/// Initialize logging for local_proxy with log prefix and no opentelemetry.
///
/// Logging can be configured using `RUST_LOG` environment variable.
pub fn init_local_proxy() -> anyhow::Result<LoggingGuard> {
    let env_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();

    let fmt_layer = tracing_subscriber::fmt::layer()
        .with_ansi(false)
        .with_writer(std::io::stderr)
        .event_format(LocalProxyFormatter(Format::default().with_target(false)));

    tracing_subscriber::registry()
        .with(env_filter)
        .with(fmt_layer)
        .try_init()?;

    Ok(LoggingGuard)
}

pub struct LocalProxyFormatter(Format<Full, SystemTime>);

impl<S, N> FormatEvent<S, N> for LocalProxyFormatter
where
    S: Subscriber + for<'a> LookupSpan<'a>,
    N: for<'a> FormatFields<'a> + 'static,
{
    fn format_event(
        &self,
        ctx: &tracing_subscriber::fmt::FmtContext<'_, S, N>,
        mut writer: tracing_subscriber::fmt::format::Writer<'_>,
        event: &tracing::Event<'_>,
    ) -> std::fmt::Result {
        writer.write_str("[local_proxy] ")?;
        self.0.format_event(ctx, writer, event)
    }
}

pub struct LoggingGuard;

impl Drop for LoggingGuard {

@@ -11,7 +11,7 @@ use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};

use tokio::time::{self, Instant};

use crate::console::messages::ColdStartInfo;
use crate::control_plane::messages::ColdStartInfo;

#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
@@ -2,15 +2,18 @@

use std::{
    io,
    net::SocketAddr,
    net::{IpAddr, SocketAddr},
    pin::Pin,
    task::{Context, Poll},
};

use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};

use crate::config::ProxyProtocolV2;

pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
    pub(crate) struct ChainRW<T> {
@@ -60,7 +63,23 @@ const HEADER: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
pub(crate) async fn get_client_conn_info<T: AsyncRead + Unpin>(
    socket: T,
    proxy_protocol_v2: ProxyProtocolV2,
) -> anyhow::Result<(ChainRW<T>, Option<IpAddr>)> {
    match read_proxy_protocol(socket).await? {
        (_socket, None) if proxy_protocol_v2 == ProxyProtocolV2::Required => {
            bail!("missing required proxy protocol header");
        }
        (_socket, Some(_)) if proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
            bail!("proxy protocol header not supported");
        }
        (socket, Some(addr)) => Ok((socket, Some(addr.ip()))),
        (socket, None) => Ok((socket, None)),
    }
}
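
// Aside (illustrative toy, not part of the patch): the policy matrix that
// `get_client_conn_info` enforces, restated as a runnable table. The enum
// here is a local stand-in for `crate::config::ProxyProtocolV2`.
#[derive(PartialEq, Eq)]
enum ProxyProtocolV2 {
    Supported, // header optional: use it when present, else the TCP source addr
    Required,  // header must be present
    Rejected,  // header must be absent
}

fn check(policy: &ProxyProtocolV2, header_present: bool) -> Result<(), &'static str> {
    match (policy, header_present) {
        (ProxyProtocolV2::Required, false) => Err("missing required proxy protocol header"),
        (ProxyProtocolV2::Rejected, true) => Err("proxy protocol header not supported"),
        _ => Ok(()),
    }
}

fn main() {
    assert!(check(&ProxyProtocolV2::Supported, false).is_ok());
    assert!(check(&ProxyProtocolV2::Required, false).is_err());
    assert!(check(&ProxyProtocolV2::Rejected, true).is_err());
}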

async fn read_proxy_protocol<T: AsyncRead + Unpin>(
    mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
    let mut buf = BytesMut::with_capacity(128);
@@ -1,9 +1,10 @@
use crate::{
    auth::backend::ComputeCredentialKeys,
    compute::COULD_NOT_CONNECT,
    compute::{self, PostgresConnection},
    config::RetryConfig,
    console::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo},
    context::RequestMonitoring,
    control_plane::{self, errors::WakeComputeError, locks::ApiLocks, CachedNodeInfo, NodeInfo},
    error::ReportableError,
    metrics::{ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType},
    proxy::{
@@ -15,7 +16,7 @@ use crate::{
use async_trait::async_trait;
use pq_proto::StartupMessageParams;
use tokio::time;
use tracing::{error, info, warn};
use tracing::{debug, info, warn};

use super::retry::ShouldRetryWakeCompute;

@@ -25,7 +26,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub(crate) fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
    let is_cached = node_info.cached();
    if is_cached {
        warn!("invalidating stalled compute node info cache entry");
@@ -48,7 +49,7 @@ pub(crate) trait ConnectMechanism {
    async fn connect_once(
        &self,
        ctx: &RequestMonitoring,
        node_info: &console::CachedNodeInfo,
        node_info: &control_plane::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError>;

@@ -60,7 +61,7 @@ pub(crate) trait ComputeConnectBackend {
    async fn wake_compute(
        &self,
        ctx: &RequestMonitoring,
    ) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
    ) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;

    fn get_keys(&self) -> &ComputeCredentialKeys;
}
@@ -83,7 +84,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
    async fn connect_once(
        &self,
        ctx: &RequestMonitoring,
        node_info: &console::CachedNodeInfo,
        node_info: &control_plane::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<PostgresConnection, Self::Error> {
        let host = node_info.config.get_host()?;
@@ -116,7 +117,6 @@ where

    node_info.set_keys(user_info.get_keys());
    node_info.allow_self_signed_compute = allow_self_signed_compute;
    // let mut node_info = credentials.get_node_info(ctx, user_info).await?;
    mechanism.update_connect_config(&mut node_info.config);
    let retry_type = RetryType::ConnectToCompute;

@@ -139,10 +139,10 @@ where
        Err(e) => e,
    };

    error!(error = ?err, "could not connect to compute node");
    debug!(error = ?err, COULD_NOT_CONNECT);

    let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
        // If we just recieved this from cplane and dodn't get it from cache, we shouldn't retry.
        // If we just received this from cplane and didn't get it from cache, we shouldn't retry.
        // Do not need to retrieve a new node_info, just return the old one.
        if should_retry(&err, num_retries, connect_to_compute_retry_config) {
            Metrics::get().proxy.retries_metric.observe(
@@ -191,7 +191,7 @@ where
        }
        Err(e) => {
            if !should_retry(&e, num_retries, connect_to_compute_retry_config) {
                error!(error = ?e, num_retries, retriable = false, "couldn't connect to compute node");
                // Don't log an error here, caller will print the error
                Metrics::get().proxy.retries_metric.observe(
                    RetriesMetricGroup {
                        outcome: ConnectOutcome::Failed,
@@ -202,7 +202,7 @@ where
                return Err(e.into());
            }

            warn!(error = ?e, num_retries, retriable = true, "couldn't connect to compute node");
            warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
        }
    };

@@ -10,16 +10,16 @@ pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;

use crate::config::ProxyProtocolV2;
use crate::control_plane::provider::ControlPlaneBackend;
use crate::metrics::Protocol;
use crate::{
    auth,
    cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
    cancellation::{self, CancellationHandlerMain},
    compute,
    config::{ProxyConfig, TlsConfig},
    context::RequestMonitoring,
    error::ReportableError,
    metrics::{Metrics, NumClientConnectionsGuard},
    protocol2::read_proxy_protocol,
    proxy::handshake::{handshake, HandshakeData},
    rate_limiter::EndpointRateLimiter,
    stream::{PqStream, Stream},
@@ -31,6 +31,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::net::IpAddr;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
@@ -61,6 +62,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(

pub async fn task_main(
    config: &'static ProxyConfig,
    auth_backend: &'static ControlPlaneBackend,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    cancellation_handler: Arc<CancellationHandlerMain>,
@@ -70,109 +72,91 @@ pub async fn task_main(
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;
    super::connection_loop(
        config,
        listener,
        cancellation_token,
        Protocol::Tcp,
        C {
            config,
            auth_backend,
            cancellation_handler,
            endpoint_rate_limiter,
        },
    )
    .await
}

    let connections = tokio_util::task::task_tracker::TaskTracker::new();
#[derive(Clone)]
struct C {
    config: &'static ProxyConfig,
    auth_backend: &'static ControlPlaneBackend,
    cancellation_handler: Arc<CancellationHandlerMain>,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        let (socket, peer_addr) = accept_result?;
impl super::ConnHandler for C {
    async fn handle(
        self,
        session_id: uuid::Uuid,
        peer_addr: IpAddr,
        socket: crate::protocol2::ChainRW<tokio::net::TcpStream>,
        conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
    ) {
        let ctx = RequestMonitoring::new(
            session_id,
            peer_addr,
            crate::metrics::Protocol::Tcp,
            &self.config.region,
        );
        let span = ctx.span();

        let conn_gauge = Metrics::get()
            .proxy
            .client_connections
            .guard(crate::metrics::Protocol::Tcp);
        let startup = Box::pin(
            handle_client(
                self.config,
                self.auth_backend,
                &ctx,
                self.cancellation_handler,
                socket,
                ClientMode::Tcp,
                self.endpoint_rate_limiter,
                conn_gauge,
            )
            .instrument(span.clone()),
        );

        let session_id = uuid::Uuid::new_v4();
        let cancellation_handler = Arc::clone(&cancellation_handler);
        let res = startup.await;
        transition_connection(ctx, res).await;
    }
}

        tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
        let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();

        connections.spawn(async move {
            let (socket, peer_addr) = match read_proxy_protocol(socket).await {
                Err(e) => {
                    error!("per-client task finished with an error: {e:#}");
                    return;
                }
                Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
                    error!("missing required proxy protocol header");
                    return;
                }
                Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
                    error!("proxy protocol header not supported");
                    return;
                }
                Ok((socket, Some(addr))) => (socket, addr.ip()),
                Ok((socket, None)) => (socket, peer_addr.ip()),
            };

            match socket.inner.set_nodelay(true) {
pub(crate) async fn transition_connection<S: AsyncRead + AsyncWrite + Unpin>(
    ctx: RequestMonitoring,
    res: Result<Option<ProxyPassthrough<S>>, ClientRequestError>,
) {
    let span = ctx.span();
    match res {
        Err(e) => {
            ctx.set_error_kind(e.get_error_kind());
            error!(parent: &span, "per-client task finished with an error: {e:#}");
        }
        Ok(None) => {
            ctx.set_success();
        }
        Ok(Some(p)) => {
            ctx.set_success();
            ctx.log_connect();
            match p.proxy_pass().instrument(span.clone()).await {
                Ok(()) => {}
                Err(e) => {
                    error!("per-client task finished with an error: failed to set socket option: {e:#}");
                    return;
                Err(ErrorSource::Client(e)) => {
                    error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
                }
            };

            let ctx = RequestMonitoring::new(
                session_id,
                peer_addr,
                crate::metrics::Protocol::Tcp,
                &config.region,
            );
            let span = ctx.span();

            let startup = Box::pin(
                handle_client(
                    config,
                    &ctx,
                    cancellation_handler,
                    socket,
                    ClientMode::Tcp,
                    endpoint_rate_limiter2,
                    conn_gauge,
                )
                .instrument(span.clone()),
            );
            let res = startup.await;

            match res {
                Err(e) => {
                    // todo: log and push to ctx the error kind
                    ctx.set_error_kind(e.get_error_kind());
                    error!(parent: &span, "per-client task finished with an error: {e:#}");
                }
                Ok(None) => {
                    ctx.set_success();
                }
                Ok(Some(p)) => {
                    ctx.set_success();
                    ctx.log_connect();
                    match p.proxy_pass().instrument(span.clone()).await {
                        Ok(()) => {}
                        Err(ErrorSource::Client(e)) => {
                            error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
                        }
                        Err(ErrorSource::Compute(e)) => {
                            error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
                        }
                    }
                Err(ErrorSource::Compute(e)) => {
                    error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
                }
            }
        });
    }
}

    connections.close();
    drop(listener);

    // Drain connections
    connections.wait().await;

    Ok(())
}
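
// Aside (illustrative sketch, not part of the patch): the `run_until_cancelled`
// helper the accept loops rely on, reduced to a runnable form. It races a
// future against a cancellation token; `None` means cancellation won.
// Assumes the tokio and tokio-util crates.
use tokio_util::sync::CancellationToken;

pub async fn run_until_cancelled<F: std::future::Future>(
    f: F,
    cancellation_token: &CancellationToken,
) -> Option<F::Output> {
    tokio::select! {
        res = f => Some(res),
        () = cancellation_token.cancelled() => None,
    }
}

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    token.cancel();
    // A never-ready future loses the race against the already-cancelled token.
    assert!(run_until_cancelled(std::future::pending::<()>(), &token).await.is_none());
}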

pub(crate) enum ClientMode {
@@ -243,15 +227,17 @@ impl ReportableError for ClientRequestError {
    }
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    auth_backend: &'static ControlPlaneBackend,
    ctx: &RequestMonitoring,
    cancellation_handler: Arc<CancellationHandlerMain>,
    stream: S,
    mode: ClientMode,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    conn_gauge: NumClientConnectionsGuard<'static>,
) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
    info!(
        protocol = %ctx.protocol(),
        "handling interactive connection from client"
@@ -285,21 +271,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    let common_names = tls.map(|tls| &tls.common_names);

    // Extract credentials which we're going to use for auth.
    let result = config
        .auth_backend
        .as_ref()
        .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
        .transpose();

    let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names);
    let user_info = match result {
        Ok(user_info) => user_info,
        Err(e) => stream.throw_error(e).await?,
    };

    let user = user_info.get_user().to_owned();
    let user_info = match user_info
    let user = user_info.user.clone();
    let user_info = match auth_backend
        .authenticate(
            ctx,
            user_info,
            &mut stream,
            mode.allow_cleartext(),
            &config.authentication_config,
@@ -353,7 +335,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection<P>(
pub(crate) async fn prepare_client_connection<P>(
    node: &compute::PostgresConnection,
    session: &cancellation::Session<P>,
    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,

@@ -1,7 +1,7 @@
use crate::{
    cancellation,
    cancellation::{self, CancellationHandlerMainInternal},
    compute::PostgresConnection,
    console::messages::MetricsAuxInfo,
    control_plane::messages::MetricsAuxInfo,
    metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard},
    stream::Stream,
    usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS},
@@ -57,17 +57,17 @@ pub(crate) async fn proxy_pass(
    Ok(())
}

pub(crate) struct ProxyPassthrough<P, S> {
pub(crate) struct ProxyPassthrough<S> {
    pub(crate) client: Stream<S>,
    pub(crate) compute: PostgresConnection,
    pub(crate) aux: MetricsAuxInfo,

    pub(crate) _req: NumConnectionRequestsGuard<'static>,
    pub(crate) _conn: NumClientConnectionsGuard<'static>,
    pub(crate) _cancel: cancellation::Session<P>,
    pub(crate) _cancel: cancellation::Session<CancellationHandlerMainInternal>,
}

impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
    pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
        let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
        if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
@@ -8,16 +8,20 @@ use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{
    ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
    ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, TestBackend,
};
use crate::config::{CertResolver, RetryConfig};
use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeInfoCache};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::config::{CertResolver, ProxyProtocolV2, RetryConfig};
use crate::control_plane::messages::{ControlPlaneError, Details, MetricsAuxInfo, Status};
use crate::control_plane::provider::{
    CachedAllowedIps, CachedRoleSecret, ControlPlaneBackend, NodeInfoCache,
};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::protocol2::get_client_conn_info;
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use auth::backend::ControlPlaneComputeBackend;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
@@ -174,7 +178,7 @@ async fn dummy_proxy(
    tls: Option<TlsConfig>,
    auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
    let (client, _) = read_proxy_protocol(client).await?;
    let (client, _) = get_client_conn_info(client, ProxyProtocolV2::Supported).await?;
    let mut stream =
        match handshake(&RequestMonitoring::test(), client, tls.as_ref(), false).await? {
            HandshakeData::Startup(stream, _) => stream,
@@ -459,7 +463,7 @@ impl ConnectMechanism for TestConnectMechanism {
    async fn connect_once(
        &self,
        _ctx: &RequestMonitoring,
        _node_info: &console::CachedNodeInfo,
        _node_info: &control_plane::CachedNodeInfo,
        _timeout: std::time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        let mut counter = self.counter.lock().unwrap();
@@ -483,23 +487,23 @@ impl ConnectMechanism for TestConnectMechanism {
}

impl TestBackend for TestConnectMechanism {
    fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
    fn wake_compute(&self) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
        let mut counter = self.counter.lock().unwrap();
        let action = self.sequence[*counter];
        *counter += 1;
        match action {
            ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
            ConnectAction::WakeFail => {
                let err = console::errors::ApiError::Console(ConsoleError {
                let err = control_plane::errors::ApiError::ControlPlane(ControlPlaneError {
                    http_status_code: StatusCode::BAD_REQUEST,
                    error: "TEST".into(),
                    status: None,
                });
                assert!(!err.could_retry());
                Err(console::errors::WakeComputeError::ApiError(err))
                Err(control_plane::errors::WakeComputeError::ApiError(err))
            }
            ConnectAction::WakeRetry => {
                let err = console::errors::ApiError::Console(ConsoleError {
                let err = control_plane::errors::ApiError::ControlPlane(ControlPlaneError {
                    http_status_code: StatusCode::BAD_REQUEST,
                    error: "TEST".into(),
                    status: Some(Status {
@@ -507,13 +511,15 @@ impl TestBackend for TestConnectMechanism {
                        message: "error".into(),
                        details: Details {
                            error_info: None,
                            retry_info: Some(console::messages::RetryInfo { retry_delay_ms: 1 }),
                            retry_info: Some(control_plane::messages::RetryInfo {
                                retry_delay_ms: 1,
                            }),
                            user_facing_message: None,
                        },
                    }),
                });
                assert!(err.could_retry());
                Err(console::errors::WakeComputeError::ApiError(err))
                Err(control_plane::errors::WakeComputeError::ApiError(err))
            }
            x => panic!("expecting action {x:?}, wake_compute is called instead"),
        }
@@ -521,7 +527,7 @@ impl TestBackend for TestConnectMechanism {

    fn get_allowed_ips_and_secret(
        &self,
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
    ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), control_plane::errors::GetAuthInfoError>
    {
        unimplemented!("not used in tests")
    }
@@ -538,7 +544,7 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
            endpoint_id: (&EndpointId::from("endpoint")).into(),
            project_id: (&ProjectId::from("project")).into(),
            branch_id: (&BranchId::from("branch")).into(),
            cold_start_info: crate::console::messages::ColdStartInfo::Warm,
            cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
        },
        allow_self_signed_compute: false,
    };
@@ -548,19 +554,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn

fn helper_create_connect_info(
    mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeCredentials, &()> {
    let user_info = auth::Backend::Console(
        MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
        ComputeCredentials {
            info: ComputeUserInfo {
                endpoint: "endpoint".into(),
                user: "user".into(),
                options: NeonOptions::parse_options_raw(""),
            },
            keys: ComputeCredentialKeys::Password("password".into()),
) -> ControlPlaneComputeBackend<'static> {
    let api = Box::leak(Box::new(ControlPlaneBackend::Test(Box::new(
        mechanism.clone(),
    ))));

    api.attach_to_credentials(ComputeCredentials {
        info: ComputeUserInfo {
            endpoint: "endpoint".into(),
            user: "user".into(),
            options: NeonOptions::parse_options_raw(""),
        },
    );
    user_info
        keys: ComputeCredentialKeys::Password("password".into()),
    })
}

#[tokio::test]

@@ -1,7 +1,7 @@
use crate::config::RetryConfig;
use crate::console::messages::{ConsoleError, Reason};
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::context::RequestMonitoring;
use crate::control_plane::messages::{ControlPlaneError, Reason};
use crate::control_plane::{errors::WakeComputeError, provider::CachedNodeInfo};
use crate::metrics::{
    ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
    WakeupFailureKind,
@@ -59,11 +59,11 @@ pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
}

fn report_error(e: &WakeComputeError, retry: bool) {
    use crate::console::errors::ApiError;
    use crate::control_plane::errors::ApiError;
    let kind = match e {
        WakeComputeError::BadComputeAddress(_) => WakeupFailureKind::BadComputeAddress,
        WakeComputeError::ApiError(ApiError::Transport(_)) => WakeupFailureKind::ApiTransportError,
        WakeComputeError::ApiError(ApiError::Console(e)) => match e.get_reason() {
        WakeComputeError::ApiError(ApiError::ControlPlane(e)) => match e.get_reason() {
            Reason::RoleProtected => WakeupFailureKind::ApiConsoleBadRequest,
            Reason::ResourceNotFound => WakeupFailureKind::ApiConsoleBadRequest,
            Reason::ProjectNotFound => WakeupFailureKind::ApiConsoleBadRequest,
@@ -80,7 +80,7 @@ fn report_error(e: &WakeComputeError, retry: bool) {
            Reason::LockAlreadyTaken => WakeupFailureKind::ApiConsoleLocked,
            Reason::RunningOperations => WakeupFailureKind::ApiConsoleLocked,
            Reason::Unknown => match e {
                ConsoleError {
                ControlPlaneError {
                    http_status_code: StatusCode::LOCKED,
                    ref error,
                    ..
@@ -89,27 +89,27 @@ fn report_error(e: &WakeComputeError, retry: bool) {
                {
                    WakeupFailureKind::QuotaExceeded
                }
                ConsoleError {
                ControlPlaneError {
                    http_status_code: StatusCode::UNPROCESSABLE_ENTITY,
                    ref error,
                    ..
                } if error.contains("compute time quota of non-primary branches is exceeded") => {
                    WakeupFailureKind::QuotaExceeded
                }
                ConsoleError {
                ControlPlaneError {
                    http_status_code: StatusCode::LOCKED,
                    ..
                } => WakeupFailureKind::ApiConsoleLocked,
                ConsoleError {
                ControlPlaneError {
                    http_status_code: StatusCode::BAD_REQUEST,
                    ..
                } => WakeupFailureKind::ApiConsoleBadRequest,
                ConsoleError {
                ControlPlaneError {
                    http_status_code, ..
                } if http_status_code.is_server_error() => {
                    WakeupFailureKind::ApiConsoleOtherServerError
                }
                ConsoleError { .. } => WakeupFailureKind::ApiConsoleOtherError,
                ControlPlaneError { .. } => WakeupFailureKind::ApiConsoleOtherError,
            },
        },
        WakeComputeError::TooManyConnections => WakeupFailureKind::ApiConsoleLocked,
|
||||
@@ -8,17 +8,17 @@ use tracing::{field::display, info};
use crate::{
auth::{
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
check_peer_addr_is_in_list, AuthError,
check_peer_addr_is_in_list, AuthError, ServerlessBackend,
},
compute,
config::{AuthenticationConfig, ProxyConfig},
console::{
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
locks::ApiLocks,
provider::ApiLockError,
CachedNodeInfo,
Api, CachedNodeInfo,
},
context::RequestMonitoring,
error::{ErrorKind, ReportableError, UserFacingError},
intern::EndpointIdInt,
proxy::{
@@ -38,6 +38,7 @@ pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: ServerlessBackend<'static>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}

@@ -45,18 +46,20 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if config.ip_allowlist_check_enabled
let cplane = match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => cplane,
ServerlessBackend::Local(_local) => {
return Err(AuthError::bad_auth_method(
"password authentication not supported by local_proxy",
))
}
};

let (allowed_ips, maybe_secret) = cplane.get_allowed_ips_and_secret(ctx, user_info).await?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
@@ -69,13 +72,12 @@ impl PoolingBackend {
}
let cached_secret = match maybe_secret {
Some(secret) => secret,
None => backend.get_role_secret(ctx).await?,
None => cplane.get_role_secret(ctx, user_info).await?,
};

let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,
@@ -87,9 +89,13 @@ impl PoolingBackend {
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -101,7 +107,7 @@ impl PoolingBackend {
}
};
res.map(|key| ComputeCredentials {
info: user_info,
info: user_info.clone(),
keys: key,
})
}
@@ -109,13 +115,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::Console(console, ()) => {
config
match &self.auth_backend {
ServerlessBackend::ControlPlane(console) => {
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -129,11 +135,9 @@ impl PoolingBackend {

Ok(())
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
ServerlessBackend::Local(_) => {
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -176,21 +180,41 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await

match &self.auth_backend {
ServerlessBackend::ControlPlane(cplane) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&cplane.attach_to_credentials(keys),
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
ServerlessBackend::Local(local_proxy) => {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&**local_proxy,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
}

// Wake up the destination if needed
@@ -200,6 +224,13 @@ impl PoolingBackend {
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
let cplane = match &self.auth_backend {
ServerlessBackend::Local(_) => {
panic!("connect to local_proxy should not be called if we are already local_proxy")
}
ServerlessBackend::ControlPlane(cplane) => cplane,
};

info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
@@ -208,14 +239,11 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});

let backend = cplane.attach_to_credentials(ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {

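The recurring shape in this file is that `PoolingBackend` now carries an `auth_backend: ServerlessBackend` and every operation matches on it, rejecting up front what the local flavor cannot do (for example, password auth). A sketch of that dispatch shape, with simplified stand-in types rather than the proxy's real ones:

    // Simplified stand-ins for the proxy's control-plane and local backends.
    struct ControlPlaneBackend;
    struct LocalBackend;

    enum ServerlessBackend {
        ControlPlane(ControlPlaneBackend),
        Local(LocalBackend),
    }

    #[derive(Debug)]
    enum AuthError {
        BadAuthMethod(&'static str),
    }

    impl ServerlessBackend {
        // Password auth is only meaningful when a control plane can hand out
        // secrets; the local flavor rejects it immediately instead of limping along.
        fn authenticate_with_password(&self, _password: &[u8]) -> Result<(), AuthError> {
            let _cplane = match self {
                ServerlessBackend::ControlPlane(cplane) => cplane,
                ServerlessBackend::Local(_) => {
                    return Err(AuthError::BadAuthMethod(
                        "password authentication not supported by local_proxy",
                    ))
                }
            };
            // ... fetch allowed IPs and the role secret from the control plane here ...
            Ok(())
        }
    }

    fn main() {
        let local = ServerlessBackend::Local(LocalBackend);
        assert!(local.authenticate_with_password(b"hunter2").is_err());
    }

Making the backend an enum field, instead of threading it through the global config, keeps each unsupported combination an explicit match arm rather than a runtime surprise.
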
@@ -17,7 +17,7 @@ use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
use tokio_util::sync::CancellationToken;

use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{
@@ -760,7 +760,7 @@ mod tests {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
branch_id: (&BranchId::from("branch")).into(),
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
},
conn_id: uuid::Uuid::new_v4(),
}

|
||||
use std::{sync::Arc, sync::Weak};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||
|
||||
@@ -16,7 +16,6 @@ use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;

use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
@@ -32,28 +31,29 @@ use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;

use crate::auth::ServerlessBackend;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::Metrics;
use crate::protocol2::{read_proxy_protocol, ChainRW};
use crate::proxy::run_until_cancelled;
use crate::metrics::{Metrics, Protocol};
use crate::protocol2::ChainRW;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};

use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use tracing::{error, info, instrument, warn, Instrument};
use utils::http::error::ApiError;

pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";

pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: ServerlessBackend<'static>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -107,6 +107,7 @@ pub async fn task_main(
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -122,81 +123,100 @@ pub async fn task_main(
}
};

let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait` to complete
let requests = TaskTracker::new();
requests.close(); // allows `requests.wait` to complete

while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
crate::connection_loop(
config,
ws_listener,
cancellation_token.clone(),
Protocol::Http,
C {
config,
backend,
cancellation_handler,
endpoint_rate_limiter,
tls_acceptor,
requests: requests.clone(),
cancellation_token,
},
)
.await?;

requests.wait().await;

Ok(())
}

#[derive(Clone)]
struct C {
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
requests: TaskTracker,
cancellation_token: CancellationToken,
}

impl super::ConnHandler for C {
#[instrument(name = "http_conn", skip_all, fields(conn_id))]
async fn handle(
self,
conn_id: uuid::Uuid,
peer_addr: IpAddr,
stream: ChainRW<TcpStream>,
conn_gauge: crate::metrics::NumClientConnectionsGuard<'static>,
) {
// try to close an old HTTP connection,
// picked at random.
let n_connections = Metrics::get()
.proxy
.client_connections
.sample(crate::metrics::Protocol::Http);
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
if n_connections > config.http_config.client_conn_threshold {
tracing::trace!(?n_connections, threshold = ?self.config.http_config.client_conn_threshold, "check");
if n_connections > self.config.http_config.client_conn_threshold {
tracing::trace!("attempting to cancel a random connection");
if let Some(token) = config.http_config.cancel_set.take() {
if let Some(token) = self.config.http_config.cancel_set.take() {
tracing::debug!("cancelling a random connection");
token.cancel();
}
}

let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let conn_token = self.cancellation_token.child_token();
let _cancel_guard = self
.config
.http_config
.cancel_set
.insert(conn_id, conn_token.clone());

let session_id = uuid::Uuid::new_v4();
let startup_result = Box::pin(connection_startup(
self.config,
self.tls_acceptor,
conn_id,
stream,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};

let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
Box::pin(connection_handler(
self.config,
self.backend,
self.requests,
self.cancellation_handler,
self.endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
conn_id,
))
.await;

let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, peer_addr)) = startup_result else {
return;
};

Box::pin(connection_handler(
config,
backend,
connections2,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
peer_addr,
session_id,
))
.await;
}
.instrument(http_conn_span),
);
drop(conn_gauge);
}

connections.wait().await;

Ok(())
}

pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
@@ -224,26 +244,14 @@ impl MaybeTlsAcceptor for NoTls {
}
}

/// Handles the TCP startup lifecycle.
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
/// Handles the TLS startup handshake.
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn: TcpStream,
peer_addr: SocketAddr,
conn: ChainRW<TcpStream>,
peer_addr: IpAddr,
) -> Option<(AsyncRW, IpAddr)> {
// handle PROXY protocol
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
};

let peer_addr = peer.unwrap_or(peer_addr).ip();
let has_private_peer_addr = match peer_addr {
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
@@ -377,6 +385,10 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ServerlessBackend::ControlPlane(auth_backend) = backend.auth_backend else {
return json_response(StatusCode::BAD_REQUEST, "query is not supported");
};

let ctx = RequestMonitoring::new(
session_id,
peer_addr,
@@ -394,6 +406,7 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
auth_backend,
ctx,
websocket,
cancellation_handler,

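The rewritten `task_main` above no longer owns an accept loop: it hands a handler value (`C`) to a shared `crate::connection_loop`, which accepts sockets and calls back into a `ConnHandler::handle` per connection. The real loop also threads TLS, metrics, and cancellation through; the sketch below keeps only the accept-and-spawn skeleton, with a cloneable closure standing in for the handler trait (names here are illustrative, not the proxy's API):

    use std::future::Future;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};

    // Generic accept loop: owns accept + spawn, delegates per-connection work
    // to a cloneable handler, mirroring the connection_loop/ConnHandler split.
    async fn connection_loop<F, Fut>(listener: TcpListener, handler: F) -> std::io::Result<()>
    where
        F: Fn(TcpStream) -> Fut + Clone + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        loop {
            let (stream, _peer) = listener.accept().await?;
            let handler = handler.clone();
            // Each connection runs independently; a panic in one task does
            // not take down the accept loop.
            tokio::spawn(async move { handler(stream).await });
        }
    }

    #[tokio::main]
    async fn main() -> std::io::Result<()> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        connection_loop(listener, |mut stream| async move {
            // Toy handler: read a little and echo it back.
            let mut buf = [0u8; 1024];
            if let Ok(n) = stream.read(&mut buf).await {
                let _ = stream.write_all(&buf[..n]).await;
            }
        })
        .await
    }
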
@@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -552,7 +553,7 @@ async fn handle_inner(

match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
@@ -623,22 +624,12 @@ async fn handle_db_inner(
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?;

ComputeCredentials {
@@ -680,7 +671,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -698,7 +689,7 @@ async fn handle_db_inner(
}

statements
.process(config, cancel, &mut client, parsed_headers)
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -738,7 +729,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
];

async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
@@ -746,12 +736,7 @@ async fn handle_auth_broker_inner(
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;

@@ -789,7 +774,7 @@ async fn handle_auth_broker_inner(
impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -863,7 +848,7 @@ impl QueryData {
impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -933,7 +918,7 @@ impl BatchQueryData {
}

async fn query_batch(
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -972,7 +957,7 @@ async fn query_batch(
}

async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
config: &'static HttpConfig,
client: &T,
data: QueryData,
current_size: &mut usize,
@@ -993,9 +978,9 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have streaming response support yet, so this is to prevent OOM
// from a malicious query (e.g. a cross join)
if *current_size > config.http_config.max_response_size_bytes {
if *current_size > config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge(
config.http_config.max_response_size_bytes,
config.max_response_size_bytes,
));
}
}

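Several signatures in this file shrink from `&'static ProxyConfig` to `&'static HttpConfig`: the query-processing layer only reads HTTP limits, so it now receives just that sub-struct. A sketch of the idea, with illustrative field names:

    // Illustrative config shapes: the full config owns the HTTP sub-config.
    struct HttpConfig {
        max_response_size_bytes: usize,
    }

    struct ProxyConfig {
        http_config: HttpConfig,
        // ... TLS, auth, retry settings elided ...
    }

    // Before: took the whole ProxyConfig and reached through it.
    // After: takes only what it reads, which documents the dependency and
    // makes the function trivially testable with a small literal.
    fn check_response_size(config: &HttpConfig, current_size: usize) -> Result<(), String> {
        if current_size > config.max_response_size_bytes {
            return Err(format!(
                "response too large: {current_size} > {}",
                config.max_response_size_bytes
            ));
        }
        Ok(())
    }

    fn main() {
        let config = ProxyConfig {
            http_config: HttpConfig { max_response_size_bytes: 10 * 1024 * 1024 },
        };
        // Callers that hold the full config just borrow the sub-struct.
        assert!(check_response_size(&config.http_config, 1024).is_ok());
    }
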
@@ -1,3 +1,4 @@
use crate::control_plane::provider::ControlPlaneBackend;
use crate::proxy::ErrorSource;
use crate::{
cancellation::CancellationHandlerMain,
@@ -129,6 +130,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {

pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static ControlPlaneBackend,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -145,6 +147,7 @@ pub(crate) async fn serve_websocket(

let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),

@@ -23,8 +23,7 @@ crc32c.workspace = true
fail.workspace = true
hex.workspace = true
humantime.workspace = true
http.workspace = true
hyper.workspace = true
hyper0.workspace = true
futures.workspace = true
once_cell.workspace = true
parking_lot.workspace = true

@@ -253,6 +253,13 @@ pub async fn build(args: Args) -> Result<Response> {
});
}

// Tokio forbids dropping a runtime in an async context, so this is a stupid
// way to drop it in a non-async context.
tokio::task::spawn_blocking(move || {
let _r = runtime;
})
.await?;

Ok(Response {
start_time,
finish_time: Utc::now(),

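The added hunk works around a real Tokio constraint: dropping a `Runtime` from inside async context panics, so the runtime is moved into `spawn_blocking`, which runs the drop on a blocking worker thread. A standalone sketch of the same trick:

    use tokio::runtime::Runtime;

    #[tokio::main]
    async fn main() -> Result<(), tokio::task::JoinError> {
        // A secondary runtime, e.g. one that was used for some setup work.
        let runtime = Runtime::new().expect("failed to build runtime");

        // Dropping `runtime` right here would panic with
        // "Cannot drop a runtime in a context where blocking is not allowed".
        // Moving it into spawn_blocking drops it on a blocking thread instead.
        tokio::task::spawn_blocking(move || {
            drop(runtime);
        })
        .await?;

        Ok(())
    }
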
@@ -1,4 +1,7 @@
#![deny(clippy::undocumented_unsafe_blocks)]

extern crate hyper0 as hyper;

use camino::Utf8PathBuf;
use once_cell::sync::Lazy;
use remote_storage::RemoteStorageConfig;

@@ -2,21 +2,29 @@ use utils::lsn::Lsn;

use crate::timeline_manager::StateSnapshot;

/// Get oldest LSN we still need to keep. We hold WAL till it is consumed
/// by all of 1) pageserver (remote_consistent_lsn) 2) peers 3) s3
/// offloading.
/// While it is safe to use inmem values for determining horizon,
/// we use persistent values to make possible normal states less surprising.
/// All segments covering LSNs before horizon_lsn can be removed.
/// Get oldest LSN we still need to keep.
///
/// We hold WAL till it is consumed by
/// 1) pageserver (remote_consistent_lsn)
/// 2) s3 offloading.
/// 3) Additionally we must store WAL since last local commit_lsn because
/// that's where we start looking for last WAL record on start.
///
/// If some peer safekeeper misses data it will fetch it from the remote
/// storage. While it is safe to use inmem values for determining horizon, we
/// use persistent values to make possible normal states less surprising. All segments
/// covering LSNs before horizon_lsn can be removed.
pub(crate) fn calc_horizon_lsn(state: &StateSnapshot, extra_horizon_lsn: Option<Lsn>) -> Lsn {
use std::cmp::min;

let mut horizon_lsn = min(
state.cfile_remote_consistent_lsn,
state.cfile_peer_horizon_lsn,
);
let mut horizon_lsn = state.cfile_remote_consistent_lsn;
// we don't want to remove WAL that is not yet offloaded to s3
horizon_lsn = min(horizon_lsn, state.cfile_backup_lsn);
// Min by local commit_lsn to be able to begin reading WAL from somewhere on
// sk start. Technically we don't allow local commit_lsn to be higher than
// flush_lsn, but let's be double safe by including it as well.
horizon_lsn = min(horizon_lsn, state.cfile_commit_lsn);
horizon_lsn = min(horizon_lsn, state.flush_lsn);
if let Some(extra_horizon_lsn) = extra_horizon_lsn {
horizon_lsn = min(horizon_lsn, extra_horizon_lsn);
}

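The new `calc_horizon_lsn` body is just a fold of `min` over every consumer's position: the pageserver's remote_consistent_lsn, the S3 backup position, the persisted commit_lsn, the local flush_lsn, and an optional extra horizon. A self-contained sketch with plain `u64` standing in for the `Lsn` newtype:

    // Plain u64 stands in for the Lsn newtype.
    struct StateSnapshot {
        cfile_remote_consistent_lsn: u64,
        cfile_backup_lsn: u64,
        cfile_commit_lsn: u64,
        flush_lsn: u64,
    }

    /// WAL below the returned horizon is consumed by everyone and can be removed.
    fn calc_horizon_lsn(state: &StateSnapshot, extra_horizon_lsn: Option<u64>) -> u64 {
        use std::cmp::min;

        let mut horizon_lsn = state.cfile_remote_consistent_lsn;
        // Keep WAL that is not yet offloaded to s3.
        horizon_lsn = min(horizon_lsn, state.cfile_backup_lsn);
        // Keep WAL from the last persisted commit_lsn onwards: that's where the
        // safekeeper starts looking for the last WAL record after a restart.
        horizon_lsn = min(horizon_lsn, state.cfile_commit_lsn);
        horizon_lsn = min(horizon_lsn, state.flush_lsn);
        if let Some(extra) = extra_horizon_lsn {
            horizon_lsn = min(horizon_lsn, extra);
        }
        horizon_lsn
    }

    fn main() {
        let state = StateSnapshot {
            cfile_remote_consistent_lsn: 500,
            cfile_backup_lsn: 450,
            cfile_commit_lsn: 480,
            flush_lsn: 520,
        };
        // The S3 backup position is the laggard here, so it pins the horizon.
        assert_eq!(calc_horizon_lsn(&state, None), 450);
        assert_eq!(calc_horizon_lsn(&state, Some(400)), 400);
    }
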
@@ -47,7 +47,7 @@ pub(crate) struct StateSnapshot {
pub(crate) remote_consistent_lsn: Lsn,

// persistent control file values
pub(crate) cfile_peer_horizon_lsn: Lsn,
pub(crate) cfile_commit_lsn: Lsn,
pub(crate) cfile_remote_consistent_lsn: Lsn,
pub(crate) cfile_backup_lsn: Lsn,

@@ -70,7 +70,7 @@ impl StateSnapshot {
commit_lsn: state.inmem.commit_lsn,
backup_lsn: state.inmem.backup_lsn,
remote_consistent_lsn: state.inmem.remote_consistent_lsn,
cfile_peer_horizon_lsn: state.peer_horizon_lsn,
cfile_commit_lsn: state.commit_lsn,
cfile_remote_consistent_lsn: state.remote_consistent_lsn,
cfile_backup_lsn: state.backup_lsn,
flush_lsn: read_guard.sk.flush_lsn(),

@@ -13,7 +13,7 @@ use desim::{
node_os::NodeOs,
proto::{AnyMessage, NetEvent, NodeEvent},
};
use http::Uri;
use hyper0::Uri;
use safekeeper::{
safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION},
state::{TimelinePersistentState, TimelineState},

@@ -10,16 +10,13 @@ bench = []
[dependencies]
anyhow.workspace = true
async-stream.workspace = true
bytes.workspace = true
clap = { workspace = true, features = ["derive"] }
const_format.workspace = true
futures.workspace = true
futures-core.workspace = true
futures-util.workspace = true
humantime.workspace = true
hyper_1 = { workspace = true, features = ["full"] }
http-body-util.workspace = true
hyper-util = "0.1"
hyper0 = { workspace = true, features = ["full"] }
once_cell.workspace = true
parking_lot.workspace = true
prost.workspace = true

@@ -10,16 +10,16 @@
//!
//! Only the safekeeper message is supported, but it is not hard to add something
//! else with generics.

extern crate hyper0 as hyper;

use clap::{command, Parser};
use futures_core::Stream;
use futures_util::StreamExt;
use http_body_util::Full;
use hyper::header::CONTENT_TYPE;
use hyper::service::service_fn;
use hyper::{Method, StatusCode};
use hyper_1 as hyper;
use hyper_1::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, StatusCode};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::convert::Infallible;
@@ -27,11 +27,9 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time;
use tonic::body::{self, empty_body, BoxBody};
use tonic::codegen::Service;
use tonic::transport::server::Connected;
use tonic::Code;
@@ -50,7 +48,9 @@ use storage_broker::proto::{
FilterTenantTimelineId, MessageType, SafekeeperDiscoveryRequest, SafekeeperDiscoveryResponse,
SafekeeperTimelineInfo, SubscribeByFilterRequest, SubscribeSafekeeperInfoRequest, TypedMessage,
};
use storage_broker::{parse_proto_ttid, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR};
use storage_broker::{
parse_proto_ttid, EitherBody, DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_LISTEN_ADDR,
};
use utils::id::TenantTimelineId;
use utils::logging::{self, LogFormat};
use utils::sentry_init::init_sentry;
@@ -602,8 +602,8 @@ impl BrokerService for Broker {

// We serve only metrics and healthcheck through http1.
async fn http1_handler(
req: hyper::Request<Incoming>,
) -> Result<hyper::Response<BoxBody>, Infallible> {
req: hyper::Request<hyper::body::Body>,
) -> Result<hyper::Response<Body>, Infallible> {
let resp = match (req.method(), req.uri().path()) {
(&Method::GET, "/metrics") => {
let mut buffer = vec![];
@@ -614,16 +614,16 @@ async fn http1_handler(
hyper::Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(body::boxed(Full::new(bytes::Bytes::from(buffer))))
.body(Body::from(buffer))
.unwrap()
}
(&Method::GET, "/status") => hyper::Response::builder()
.status(StatusCode::OK)
.body(empty_body())
.body(Body::empty())
.unwrap(),
_ => hyper::Response::builder()
.status(StatusCode::NOT_FOUND)
.body(empty_body())
.body(Body::empty())
.unwrap(),
};
Ok(resp)
@@ -665,74 +665,52 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
};
let storage_broker_server = BrokerServiceServer::new(storage_broker_impl);

info!("listening on {}", &args.listen_addr);

// grpc is served along with http1 for metrics on a single port, hence we
// don't use tonic's Server.
let tcp_listener = TcpListener::bind(&args.listen_addr).await?;
info!("listening on {}", &args.listen_addr);
loop {
let (stream, addr) = match tcp_listener.accept().await {
Ok(v) => v,
Err(e) => {
info!("couldn't accept connection: {e}");
continue;
}
};
hyper::Server::bind(&args.listen_addr)
.http2_keep_alive_interval(Some(args.http2_keepalive_interval))
.serve(make_service_fn(move |conn: &AddrStream| {
let storage_broker_server_cloned = storage_broker_server.clone();
let connect_info = conn.connect_info();
async move {
Ok::<_, Infallible>(service_fn(move |mut req| {
// That's what tonic's MakeSvc.call does to pass conninfo to
// the request handler (and where its request.remote_addr()
// expects to find it).
req.extensions_mut().insert(connect_info.clone());

let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
builder.http1().timer(TokioTimer::new());
builder
.http2()
.timer(TokioTimer::new())
.keep_alive_interval(Some(args.http2_keepalive_interval));

let storage_broker_server_cloned = storage_broker_server.clone();
let connect_info = stream.connect_info();
let service_fn_ = async move {
service_fn(move |mut req| {
// That's what tonic's MakeSvc.call does to pass conninfo to
// the request handler (and where its request.remote_addr()
// expects to find it).
req.extensions_mut().insert(connect_info.clone());

// Technically this second clone is not needed, but consumption
// by the async block is apparently unavoidable. BTW, the error
// message is enigmatic, see
// https://github.com/rust-lang/rust/issues/68119
//
// We could get away without an async block at all, but then we
// would need to resort to futures::Either to merge the results,
// which is no easier on the eye.
let mut storage_broker_server_svc = storage_broker_server_cloned.clone();
async move {
if req.headers().get("content-type").map(|x| x.as_bytes())
== Some(b"application/grpc")
{
let res_resp = storage_broker_server_svc.call(req).await;
// Grpc and http1 handlers have slightly different
// Response types: it is UnsyncBoxBody for the
// former one (not sure why) and plain hyper::Body
// for the latter. Both implement HttpBody though,
// and `Either` is used to merge them.
res_resp.map(|resp| resp.map(http_body_util::Either::Left))
} else {
let res_resp = http1_handler(req).await;
res_resp.map(|resp| resp.map(http_body_util::Either::Right))
// Technically this second clone is not needed, but consumption
// by the async block is apparently unavoidable. BTW, the error
// message is enigmatic, see
// https://github.com/rust-lang/rust/issues/68119
//
// We could get away without an async block at all, but then we
// would need to resort to futures::Either to merge the results,
// which is no easier on the eye.
let mut storage_broker_server_svc = storage_broker_server_cloned.clone();
async move {
if req.headers().get("content-type").map(|x| x.as_bytes())
== Some(b"application/grpc")
{
let res_resp = storage_broker_server_svc.call(req).await;
// Grpc and http1 handlers have slightly different
// Response types: it is UnsyncBoxBody for the
// former one (not sure why) and plain hyper::Body
// for the latter. Both implement HttpBody though,
// and EitherBody is used to merge them.
res_resp.map(|resp| resp.map(EitherBody::Left))
} else {
let res_resp = http1_handler(req).await;
res_resp.map(|resp| resp.map(EitherBody::Right))
}
}
}
})
}
.await;

tokio::task::spawn(async move {
let res = builder
.serve_connection(TokioIo::new(stream), service_fn_)
.await;

if let Err(e) = res {
info!("error serving connection from {addr}: {e}");
}))
}
});
}
}))
.await?;
Ok(())
}

#[cfg(test)]

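The restored accept loop multiplexes gRPC and plain HTTP/1 on a single port by inspecting each request's `content-type` header: `application/grpc` goes to the tonic service, everything else to the metrics/health handler. A stripped-down sketch of just that routing decision, with a plain map standing in for real headers:

    use std::collections::HashMap;

    enum Route {
        Grpc,
        Http1,
    }

    // The same test the loop above applies to each incoming request: tonic
    // clients always send `content-type: application/grpc`.
    fn route(headers: &HashMap<String, Vec<u8>>) -> Route {
        if headers.get("content-type").map(|v| v.as_slice())
            == Some(b"application/grpc".as_slice())
        {
            Route::Grpc
        } else {
            Route::Http1
        }
    }

    fn main() {
        let mut headers = HashMap::new();
        headers.insert("content-type".to_string(), b"application/grpc".to_vec());
        assert!(matches!(route(&headers), Route::Grpc));

        headers.insert("content-type".to_string(), b"text/plain".to_vec());
        assert!(matches!(route(&headers), Route::Http1));
    }
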
@@ -1,4 +1,8 @@
use hyper_1 as hyper;
extern crate hyper0 as hyper;

use hyper::body::HttpBody;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tonic::codegen::StdError;
use tonic::transport::{ClientTlsConfig, Endpoint};
@@ -92,3 +96,56 @@ pub fn parse_proto_ttid(proto_ttid: &ProtoTenantTimelineId) -> Result<TenantTime
timeline_id,
})
}

// These several usages don't justify an anyhow dependency, though it would
// work as well.
type AnyError = Box<dyn std::error::Error + Send + Sync + 'static>;

// Provides impl HttpBody for two different types implementing it. Inspired by
// https://github.com/hyperium/tonic/blob/master/examples/src/hyper_warp/server.rs
pub enum EitherBody<A, B> {
Left(A),
Right(B),
}

impl<A, B> HttpBody for EitherBody<A, B>
where
A: HttpBody + Send + Unpin,
B: HttpBody<Data = A::Data> + Send + Unpin,
A::Error: Into<AnyError>,
B::Error: Into<AnyError>,
{
type Data = A::Data;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

fn is_end_stream(&self) -> bool {
match self {
EitherBody::Left(b) => b.is_end_stream(),
EitherBody::Right(b) => b.is_end_stream(),
}
}

fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.get_mut() {
EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err),
EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err),
}
}

fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<hyper::HeaderMap>, Self::Error>> {
match self.get_mut() {
EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
}
}
}

fn map_option_err<T, U: Into<AnyError>>(err: Option<Result<T, U>>) -> Option<Result<T, AnyError>> {
err.map(|e| e.map_err(Into::into))
}

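`EitherBody` exists so one `service_fn` closure can return bodies of two different concrete types behind a single response type. The same "merge two types behind one enum" trick in miniature, over a plain trait instead of `HttpBody` (everything here is illustrative):

    // A stand-in for HttpBody: any type that can yield its bytes.
    trait ByteSource {
        fn bytes(self) -> Vec<u8>;
    }

    struct Grpc(String);
    struct Plain(&'static [u8]);

    impl ByteSource for Grpc {
        fn bytes(self) -> Vec<u8> {
            self.0.into_bytes()
        }
    }

    impl ByteSource for Plain {
        fn bytes(self) -> Vec<u8> {
            self.0.to_vec()
        }
    }

    // The Either enum implements the trait by delegating to whichever side it
    // holds, exactly as EitherBody delegates poll_data/poll_trailers above.
    enum Either<A, B> {
        Left(A),
        Right(B),
    }

    impl<A: ByteSource, B: ByteSource> ByteSource for Either<A, B> {
        fn bytes(self) -> Vec<u8> {
            match self {
                Either::Left(a) => a.bytes(),
                Either::Right(b) => b.bytes(),
            }
        }
    }

    // Both branches now unify into one return type.
    fn respond(grpc: bool) -> Either<Grpc, Plain> {
        if grpc {
            Either::Left(Grpc("grpc frame".to_string()))
        } else {
            Either::Right(Plain(b"http/1 body"))
        }
    }

    fn main() {
        assert_eq!(respond(true).bytes(), b"grpc frame");
    }
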
@@ -21,7 +21,7 @@ clap.workspace = true
fail.workspace = true
futures.workspace = true
hex.workspace = true
hyper.workspace = true
hyper0.workspace = true
humantime.workspace = true
itertools.workspace = true
lasso.workspace = true

@@ -1,6 +1,8 @@
use serde::Serialize;
use utils::seqwait::MonotonicCounter;

extern crate hyper0 as hyper;

mod auth;
mod background_node_operations;
mod compute_hook;

@@ -1,6 +1,6 @@
use anyhow::{anyhow, Context};
use clap::Parser;
use hyper::Uri;
use hyper0::Uri;
use metrics::launch_timestamp::LaunchTimestamp;
use metrics::BuildInfo;
use std::path::PathBuf;
@@ -324,7 +324,7 @@ async fn async_main() -> anyhow::Result<()> {

// Start HTTP server
let server_shutdown = CancellationToken::new();
let server = hyper::Server::from_tcp(http_listener)?
let server = hyper0::Server::from_tcp(http_listener)?
.serve(router_service)
.with_graceful_shutdown({
let server_shutdown = server_shutdown.clone();

@@ -526,6 +526,21 @@ pub(crate) enum ReconcileResultRequest {
Stop,
}

#[derive(Clone)]
struct MutationLocation {
node: Node,
generation: Generation,
}

#[derive(Clone)]
struct ShardMutationLocations {
latest: MutationLocation,
other: Vec<MutationLocation>,
}

#[derive(Default, Clone)]
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>);

impl Service {
pub fn get_config(&self) -> &Config {
&self.config
@@ -2987,38 +3002,83 @@ impl Service {
failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock");

self.tenant_remote_mutation(tenant_id, move |mut targets| async move {
if targets.is_empty() {
if targets.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
};
let shard_zero = targets.remove(0);

let (shard_zero_tid, shard_zero_locations) =
targets.0.pop_first().expect("Must have at least one shard");
assert!(shard_zero_tid.is_shard_zero());

async fn create_one(
tenant_shard_id: TenantShardId,
node: Node,
locations: ShardMutationLocations,
jwt: Option<String>,
create_req: TimelineCreateRequest,
) -> Result<TimelineInfo, ApiError> {
let latest = locations.latest.node;

tracing::info!(
"Creating timeline on shard {}/{}, attached to node {node}",
"Creating timeline on shard {}/{}, attached to node {latest} in generation {:?}",
tenant_shard_id,
create_req.new_timeline_id,
locations.latest.generation
);
let client = PageserverClient::new(node.get_id(), node.base_url(), jwt.as_deref());

client
let client =
PageserverClient::new(latest.get_id(), latest.base_url(), jwt.as_deref());

let timeline_info = client
.timeline_create(tenant_shard_id, &create_req)
.await
.map_err(|e| passthrough_api_error(&node, e))
.map_err(|e| passthrough_api_error(&latest, e))?;

// We propagate timeline creations to all attached locations such that a compute
// for the new timeline is able to start regardless of the current state of the
// tenant shard reconciliation.
for location in locations.other {
tracing::info!(
"Creating timeline on shard {}/{}, stale attached to node {} in generation {:?}",
tenant_shard_id,
create_req.new_timeline_id,
location.node,
location.generation
);

let client = PageserverClient::new(
location.node.get_id(),
location.node.base_url(),
jwt.as_deref(),
);

let res = client
.timeline_create(tenant_shard_id, &create_req)
.await;

if let Err(e) = res {
match e {
mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _) => {
// Tenant might have been detached from the stale location,
// so ignore 404s.
},
_ => {
return Err(passthrough_api_error(&location.node, e));
}
}
}
}

Ok(timeline_info)
}

// Because the caller might not provide an explicit LSN, we must do the creation first on a single shard, and then
// use whatever LSN that shard picked when creating on subsequent shards. We arbitrarily use shard zero as the shard
// that will get the first creation request, and propagate the LSN to all the >0 shards.
let timeline_info = create_one(
shard_zero.0,
shard_zero.1,
shard_zero_tid,
shard_zero_locations,
self.config.jwt_token.clone(),
create_req.clone(),
)
@@ -3031,14 +3091,24 @@ impl Service {
}

// Create timeline on remaining shards with number >0
if !targets.is_empty() {
if !targets.0.is_empty() {
// If we had multiple shards, issue requests for the remainder now.
let jwt = &self.config.jwt_token;
self.tenant_for_shards(
targets.iter().map(|t| (t.0, t.1.clone())).collect(),
|tenant_shard_id: TenantShardId, node: Node| {
targets
.0
.iter()
.map(|t| (*t.0, t.1.latest.node.clone()))
.collect(),
|tenant_shard_id: TenantShardId, _node: Node| {
let create_req = create_req.clone();
Box::pin(create_one(tenant_shard_id, node, jwt.clone(), create_req))
let mutation_locations = targets.0.remove(&tenant_shard_id).unwrap();
Box::pin(create_one(
tenant_shard_id,
mutation_locations,
jwt.clone(),
create_req,
))
},
)
.await?;
@@ -3068,7 +3138,7 @@ impl Service {
.await;

self.tenant_remote_mutation(tenant_id, move |targets| async move {
if targets.is_empty() {
if targets.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
@@ -3099,8 +3169,9 @@ impl Service {

// no shard needs to go first/last; the operation should be idempotent
// TODO: it would be great to ensure that all shards return the same error
let locations = targets.0.iter().map(|t| (*t.0, t.1.latest.node.clone())).collect();
let results = self
.tenant_for_shards(targets, |tenant_shard_id, node| {
.tenant_for_shards(locations, |tenant_shard_id, node| {
futures::FutureExt::boxed(config_one(
tenant_shard_id,
timeline_id,
@@ -3131,7 +3202,7 @@ impl Service {
.await;

self.tenant_remote_mutation(tenant_id, move |targets| async move {
if targets.is_empty() {
if targets.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
@@ -3179,8 +3250,9 @@ impl Service {
}

// no shard needs to go first/last; the operation should be idempotent
let locations = targets.0.iter().map(|t| (*t.0, t.1.latest.node.clone())).collect();
let mut results = self
.tenant_for_shards(targets, |tenant_shard_id, node| {
.tenant_for_shards(locations, |tenant_shard_id, node| {
futures::FutureExt::boxed(detach_one(
tenant_shard_id,
timeline_id,
@@ -3227,7 +3299,7 @@ impl Service {
.await;

self.tenant_remote_mutation(tenant_id, move |targets| async move {
if targets.is_empty() {
if targets.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
@@ -3249,7 +3321,12 @@ impl Service {
}

// no shard needs to go first/last; the operation should be idempotent
self.tenant_for_shards(targets, |tenant_shard_id, node| {
let locations = targets
.0
.iter()
.map(|t| (*t.0, t.1.latest.node.clone()))
.collect();
self.tenant_for_shards(locations, |tenant_shard_id, node| {
futures::FutureExt::boxed(do_one(
tenant_shard_id,
timeline_id,
@@ -3344,11 +3421,11 @@ impl Service {
op: O,
) -> Result<R, ApiError>
where
O: FnOnce(Vec<(TenantShardId, Node)>) -> F,
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
{
let target_gens = {
let mut targets = Vec::new();
let mutation_locations = {
let mut locations = TenantMutationLocations::default();

// Load the currently attached pageservers for the latest generation of each shard. This can
// run concurrently with reconciliations, and it is not guaranteed that the node we find here
@@ -3399,14 +3476,50 @@ impl Service {
.ok_or(ApiError::Conflict(format!(
"Raced with removal of node {node_id}"
)))?;
targets.push((tenant_shard_id, node.clone(), generation));
let generation = generation.expect("Checked above");

let tenant = locked.tenants.get(&tenant_shard_id);

// TODO(vlad): Abstract the logic that finds stale attached locations
// from observed state into a [`Service`] method.
let other_locations = match tenant {
Some(tenant) => {
let mut other = tenant.attached_locations();
let latest_location_index =
other.iter().position(|&l| l == (node.get_id(), generation));
if let Some(idx) = latest_location_index {
other.remove(idx);
}

other
}
None => Vec::default(),
};

let location = ShardMutationLocations {
latest: MutationLocation {
node: node.clone(),
generation,
},
other: other_locations
.into_iter()
.filter_map(|(node_id, generation)| {
let node = locked.nodes.get(&node_id)?;

Some(MutationLocation {
node: node.clone(),
generation,
})
})
.collect(),
};
locations.0.insert(tenant_shard_id, location);
}

targets
locations
};

let targets = target_gens.iter().map(|t| (t.0, t.1.clone())).collect();
let result = op(targets).await;
let result = op(mutation_locations.clone()).await;

// Post-check: are all the generations of all the shards the same as they were initially? This proves that
// our remote operation executed on the latest generation and is therefore persistent.
@@ -3422,9 +3535,10 @@ impl Service {
}| (tenant_shard_id, generation),
)
.collect::<Vec<_>>()
!= target_gens
!= mutation_locations
.0
.into_iter()
.map(|i| (i.0, i.2))
.map(|i| (i.0, Some(i.1.latest.generation)))
.collect::<Vec<_>>()
{
// We raced with something that incremented the generation, and therefore cannot be
@@ -3454,12 +3568,14 @@ impl Service {
.await;

self.tenant_remote_mutation(tenant_id, move |mut targets| async move {
if targets.is_empty() {
if targets.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
}
let shard_zero = targets.remove(0);

let (shard_zero_tid, shard_zero_locations) = targets.0.pop_first().expect("Must have at least one shard");
assert!(shard_zero_tid.is_shard_zero());

async fn delete_one(
tenant_shard_id: TenantShardId,
@@ -3482,8 +3598,9 @@ impl Service {
})
}

let locations = targets.0.iter().map(|t| (*t.0, t.1.latest.node.clone())).collect();
let statuses = self
.tenant_for_shards(targets, |tenant_shard_id: TenantShardId, node: Node| {
.tenant_for_shards(locations, |tenant_shard_id: TenantShardId, node: Node| {
Box::pin(delete_one(
tenant_shard_id,
timeline_id,
@@ -3501,9 +3618,9 @@ impl Service {
// Delete shard zero last: this is not strictly necessary, but since a caller's GET on a timeline will be routed
// to shard zero, it gives a more obvious behavior that a GET returns 404 once the deletion is done.
let shard_zero_status = delete_one(
shard_zero.0,
shard_zero_tid,
timeline_id,
shard_zero.1,
shard_zero_locations.latest.node,
self.config.jwt_token.clone(),
)
.await?;

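`tenant_remote_mutation` brackets the callback with a generation check: it records each shard's latest generation, runs the operation, then verifies the generations are unchanged; if any moved, the mutation may have landed on a stale location and is reported as a conflict. A sketch of that optimistic-validation shape, with toy types in place of the controller's:

    use std::collections::BTreeMap;

    type ShardId = u32;
    type Generation = u32;

    #[derive(Debug)]
    enum ApiError {
        ResourceUnavailable(&'static str),
    }

    // Toy stand-in for "ask the database for the current generations".
    fn load_generations(db: &BTreeMap<ShardId, Generation>) -> BTreeMap<ShardId, Generation> {
        db.clone()
    }

    fn remote_mutation<R>(
        db: &BTreeMap<ShardId, Generation>,
        op: impl FnOnce(&BTreeMap<ShardId, Generation>) -> R,
    ) -> Result<R, ApiError> {
        // Snapshot the generations our targets were computed against.
        let before = load_generations(db);
        let result = op(&before);
        // Post-check: if any generation moved, the op may have hit a stale
        // attachment and we cannot claim it is persistent.
        if load_generations(db) != before {
            return Err(ApiError::ResourceUnavailable(
                "raced with a generation increment, please retry",
            ));
        }
        Ok(result)
    }

    fn main() {
        let db: BTreeMap<ShardId, Generation> = BTreeMap::from([(0, 7), (1, 7)]);
        let res = remote_mutation(&db, |targets| targets.len());
        assert_eq!(res.unwrap(), 2);
    }
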
@@ -17,6 +17,7 @@ use crate::{
service::ReconcileResultRequest,
};
use futures::future::{self, Either};
use itertools::Itertools;
use pageserver_api::controller_api::{
AvailabilityZone, NodeSchedulingPolicy, PlacementPolicy, ShardSchedulingPolicy,
};
@@ -1410,6 +1411,32 @@ impl TenantShard {
pub(crate) fn set_preferred_az(&mut self, preferred_az_id: AvailabilityZone) {
self.preferred_az_id = Some(preferred_az_id);
}

/// Returns all the nodes to which this tenant shard is attached according to the
/// observed state and the generations. The returned vector is sorted from latest
/// generation to earliest.
pub(crate) fn attached_locations(&self) -> Vec<(NodeId, Generation)> {
self.observed
.locations
.iter()
.filter_map(|(node_id, observed)| {
use LocationConfigMode::{AttachedMulti, AttachedSingle, AttachedStale};

let conf = observed.conf.as_ref()?;

match (conf.generation, conf.mode) {
(Some(gen), AttachedMulti | AttachedSingle | AttachedStale) => {
Some((*node_id, gen))
}
_ => None,
}
})
.sorted_by(|(_lhs_node_id, lhs_gen), (_rhs_node_id, rhs_gen)| {
lhs_gen.cmp(rhs_gen).reverse()
})
.map(|(node_id, gen)| (node_id, Generation::new(gen)))
.collect()
}
}

#[cfg(test)]

@@ -5,6 +5,7 @@ edition.workspace = true
license.workspace = true

[dependencies]
aws-config.workspace = true
aws-sdk-s3.workspace = true
either.workspace = true
anyhow.workspace = true
@@ -31,7 +32,6 @@ storage_controller_client.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
chrono = { workspace = true, default-features = false, features = ["clock", "serde"] }
reqwest = { workspace = true, default-features = false, features = ["rustls-tls", "json"] }
aws-config = { workspace = true, default-features = false, features = ["rustls", "sso"] }

pageserver = { path = "../pageserver" }
pageserver_api = { path = "../libs/pageserver_api" }

@@ -28,8 +28,9 @@ use pageserver::tenant::remote_timeline_client::{remote_tenant_path, remote_time
use pageserver::tenant::TENANTS_SEGMENT_NAME;
use pageserver_api::shard::TenantShardId;
use remote_storage::{
GenericRemoteStorage, Listing, ListingMode, RemotePath, RemoteStorageConfig, RemoteStorageKind,
S3Config, DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
DownloadOpts, GenericRemoteStorage, Listing, ListingMode, RemotePath, RemoteStorageConfig,
RemoteStorageKind, S3Config, DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
@@ -488,7 +489,10 @@ async fn download_object_with_retries(
let cancel = CancellationToken::new();
for trial in 0..MAX_RETRIES {
let mut buf = Vec::new();
let download = match remote_client.download(key, &cancel).await {
let download = match remote_client
.download(key, &DownloadOpts::default(), &cancel)
.await
{
Ok(response) => response,
Err(e) => {
error!("Failed to download object for key {key}: {e}");

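The scrubber hunk adapts to a `DownloadOpts` parameter added to the `download` call; the surrounding function retries a bounded number of times on failure. A minimal retry wrapper in the same spirit (the downloader trait here is hypothetical, not the remote_storage API):

    // Hypothetical downloader; the real remote_storage trait is richer.
    trait Download {
        fn download(&self, key: &str) -> Result<Vec<u8>, String>;
    }

    const MAX_RETRIES: usize = 5;

    fn download_with_retries(client: &impl Download, key: &str) -> Result<Vec<u8>, String> {
        let mut last_err = String::new();
        for trial in 0..MAX_RETRIES {
            match client.download(key) {
                Ok(bytes) => return Ok(bytes),
                Err(e) => {
                    eprintln!("download of {key} failed on trial {trial}: {e}");
                    last_err = e;
                }
            }
        }
        Err(format!("giving up on {key} after {MAX_RETRIES} tries: {last_err}"))
    }

    struct Flaky(std::cell::Cell<u32>);

    impl Download for Flaky {
        fn download(&self, _key: &str) -> Result<Vec<u8>, String> {
            // Fails twice, then succeeds.
            let n = self.0.get();
            self.0.set(n + 1);
            if n < 2 { Err("transient".to_string()) } else { Ok(vec![1, 2, 3]) }
        }
    }

    fn main() {
        let client = Flaky(std::cell::Cell::new(0));
        assert_eq!(download_with_retries(&client, "index_part.json").unwrap(), vec![1, 2, 3]);
    }
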
test_runner/fixtures/neon_cli.py (new file, 662 additions)
@@ -0,0 +1,662 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import textwrap
|
||||
from itertools import chain, product
|
||||
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    TypeVar,
    cast,
)

import toml

from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.log_helper import log
from fixtures.pageserver.common_types import IndexPartDump
from fixtures.pg_version import PgVersion
from fixtures.utils import AuxFileStore

T = TypeVar("T")


class AbstractNeonCli(abc.ABC):
    """
    A typed wrapper around an arbitrary Neon CLI tool.
    Supports a way to run an arbitrary command directly via the CLI.
    Do not use directly; use the specific subclasses instead.
    """

    def __init__(self, extra_env: Optional[Dict[str, str]], binpath: Path):
        self.extra_env = extra_env
        self.binpath = binpath

    COMMAND: str = cast(str, None)  # To be overwritten by the derived class.

    def raw_cli(
        self,
        arguments: List[str],
        extra_env_vars: Optional[Dict[str, str]] = None,
        check_return_code=True,
        timeout=None,
    ) -> "subprocess.CompletedProcess[str]":
        """
        Run the command with the specified arguments.

        Arguments must be in list form, e.g. ['endpoint', 'create']

        Return both stdout and stderr, which can be accessed as

        >>> result = env.neon_cli.raw_cli(...)
        >>> assert result.stderr == ""
        >>> log.info(result.stdout)

        If `check_return_code`, on non-zero exit code logs failure and raises.
        """

        assert isinstance(arguments, list)
        assert isinstance(self.COMMAND, str)

        command_path = str(self.binpath / self.COMMAND)

        args = [command_path] + arguments
        log.info('Running command "{}"'.format(" ".join(args)))

        env_vars = os.environ.copy()

        # extra env
        for extra_env_key, extra_env_value in (self.extra_env or {}).items():
            env_vars[extra_env_key] = extra_env_value
        for extra_env_key, extra_env_value in (extra_env_vars or {}).items():
            env_vars[extra_env_key] = extra_env_value

        # Pass through coverage settings
        var = "LLVM_PROFILE_FILE"
        val = os.environ.get(var)
        if val:
            env_vars[var] = val

        # Intercept CalledProcessError and print more info
        try:
            res = subprocess.run(
                args,
                env=env_vars,
                check=False,
                universal_newlines=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired as e:
            if e.stderr:
                stderr = e.stderr.decode(errors="replace")
            else:
                stderr = ""

            if e.stdout:
                stdout = e.stdout.decode(errors="replace")
            else:
                stdout = ""

            log.warn(f"CLI timeout: stderr={stderr}, stdout={stdout}")
            raise

        indent = "  "
        if not res.returncode:
            stripped = res.stdout.strip()
            lines = stripped.splitlines()
            if len(lines) < 2:
                log.debug(f"Run {res.args} success: {stripped}")
            else:
                log.debug("Run %s success:\n%s" % (res.args, textwrap.indent(stripped, indent)))
        elif check_return_code:
            # this way command output will be recorded and shown in the CI failure message
            indent = indent * 2
            msg = textwrap.dedent(
                """\
            Run %s failed:
              stdout:
            %s
              stderr:
            %s
            """
            )
            msg = msg % (
                res.args,
                textwrap.indent(res.stdout.strip(), indent),
                textwrap.indent(res.stderr.strip(), indent),
            )
            log.info(msg)
            raise RuntimeError(msg) from subprocess.CalledProcessError(
                res.returncode, res.args, res.stdout, res.stderr
            )
        return res


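# Illustrative sketch: any executable under `binpath` can be wired into this
# wrapper by subclassing and setting COMMAND. The binary name below is a
# made-up placeholder, not a real Neon tool.
#
#     >>> class MyToolCli(AbstractNeonCli):
#     ...     COMMAND = "my_tool"
#     >>> cli = MyToolCli(extra_env=None, binpath=Path("target/debug"))
#     >>> cli.raw_cli(["--version"]).returncode
#     0

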
class NeonLocalCli(AbstractNeonCli):
    """A typed wrapper around the `neon_local` CLI tool.
    Supports main commands via typed methods and a way to run arbitrary commands directly via the CLI.

    Note: The methods in this class are supposed to be faithful wrappers of the underlying
    'neon_local' commands. If you're tempted to add any logic here, please consider putting it
    in the caller instead!

    There are a few exceptions where these wrapper methods intentionally differ from the
    underlying commands, however:

    - Many 'neon_local' commands take an optional 'tenant_id' argument and use the default from
      the config file if it's omitted. The corresponding wrappers require an explicit 'tenant_id'
      argument. The idea is that we don't want to rely on the config file's default in tests,
      because NeonEnv has its own 'initial_tenant'. They are currently always the same, but we
      want to rely on NeonEnv's default instead of the config file's default in tests.

    - Similarly, the --pg_version argument is always required in the wrappers, even when it's
      optional in the 'neon_local' command. The default in 'neon_local' is a specific
      hardcoded version, but in tests we never want to accidentally rely on that; we
      always want to use the version from the test fixtures.

    - Wrappers for commands that create a new tenant or timeline ID require the new tenant
      or timeline ID to be passed by the caller, while the 'neon_local' commands will
      generate a random ID if it's not specified. This is because we don't want to have to
      parse the ID from the 'neon_local' output. Making it required ensures that the
      caller has to generate it.
    """

    COMMAND = "neon_local"

    def __init__(
        self,
        extra_env: Optional[Dict[str, str]],
        binpath: Path,
        repo_dir: Path,
        pg_distrib_dir: Path,
    ):
        if extra_env is None:
            env_vars = {}
        else:
            env_vars = extra_env.copy()
        env_vars["NEON_REPO_DIR"] = str(repo_dir)
        env_vars["POSTGRES_DISTRIB_DIR"] = str(pg_distrib_dir)

        super().__init__(env_vars, binpath)

    def raw_cli(self, *args, **kwargs) -> subprocess.CompletedProcess[str]:
        return super().raw_cli(*args, **kwargs)

    def tenant_create(
        self,
        tenant_id: TenantId,
        timeline_id: TimelineId,
        pg_version: PgVersion,
        conf: Optional[Dict[str, Any]] = None,
        shard_count: Optional[int] = None,
        shard_stripe_size: Optional[int] = None,
        placement_policy: Optional[str] = None,
        set_default: bool = False,
        aux_file_policy: Optional[AuxFileStore] = None,
    ):
        """
        Creates a new tenant with the given tenant ID, and an initial timeline with the given timeline ID.
        """
        args = [
            "tenant",
            "create",
            "--tenant-id",
            str(tenant_id),
            "--timeline-id",
            str(timeline_id),
            "--pg-version",
            pg_version,
        ]
        if conf is not None:
            args.extend(
                chain.from_iterable(
                    product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
                )
            )

        if aux_file_policy is AuxFileStore.V2:
            args.extend(["-c", "switch_aux_file_policy:v2"])
        elif aux_file_policy is AuxFileStore.V1:
            args.extend(["-c", "switch_aux_file_policy:v1"])
        elif aux_file_policy is AuxFileStore.CrossValidation:
            args.extend(["-c", "switch_aux_file_policy:cross-validation"])

        if set_default:
            args.append("--set-default")

        if shard_count is not None:
            args.extend(["--shard-count", str(shard_count)])

        if shard_stripe_size is not None:
            args.extend(["--shard-stripe-size", str(shard_stripe_size)])

        if placement_policy is not None:
            args.extend(["--placement-policy", str(placement_policy)])

        res = self.raw_cli(args)
        res.check_returncode()

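    # How the `conf` mapping above is flattened into repeated `-c` flags
    # (illustrative; the keys and values are placeholders):
    #
    #     >>> conf = {"gc_period": "1h", "pitr_interval": "0s"}
    #     >>> list(chain.from_iterable(product(["-c"], (f"{k}:{v}" for k, v in conf.items()))))
    #     ['-c', 'gc_period:1h', '-c', 'pitr_interval:0s']
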
    def tenant_import(self, tenant_id: TenantId):
        args = ["tenant", "import", "--tenant-id", str(tenant_id)]
        res = self.raw_cli(args)
        res.check_returncode()

    def tenant_set_default(self, tenant_id: TenantId):
        """
        Update default tenant for future operations that require tenant_id.
        """
        res = self.raw_cli(["tenant", "set-default", "--tenant-id", str(tenant_id)])
        res.check_returncode()

    def tenant_config(self, tenant_id: TenantId, conf: Dict[str, str]):
        """
        Update tenant config.
        """

        args = ["tenant", "config", "--tenant-id", str(tenant_id)]
        if conf is not None:
            args.extend(
                chain.from_iterable(
                    product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
                )
            )

        res = self.raw_cli(args)
        res.check_returncode()

    def tenant_list(self) -> "subprocess.CompletedProcess[str]":
        res = self.raw_cli(["tenant", "list"])
        res.check_returncode()
        return res

    def timeline_create(
        self,
        new_branch_name: str,
        tenant_id: TenantId,
        timeline_id: TimelineId,
        pg_version: PgVersion,
    ) -> TimelineId:
        if timeline_id is None:
            timeline_id = TimelineId.generate()

        cmd = [
            "timeline",
            "create",
            "--branch-name",
            new_branch_name,
            "--tenant-id",
            str(tenant_id),
            "--timeline-id",
            str(timeline_id),
            "--pg-version",
            pg_version,
        ]

        res = self.raw_cli(cmd)
        res.check_returncode()

        return timeline_id

    def timeline_branch(
        self,
        tenant_id: TenantId,
        timeline_id: TimelineId,
        new_branch_name,
        ancestor_branch_name: Optional[str] = None,
        ancestor_start_lsn: Optional[Lsn] = None,
    ):
        cmd = [
            "timeline",
            "branch",
            "--branch-name",
            new_branch_name,
            "--timeline-id",
            str(timeline_id),
            "--tenant-id",
            str(tenant_id),
        ]
        if ancestor_branch_name is not None:
            cmd.extend(["--ancestor-branch-name", ancestor_branch_name])
        if ancestor_start_lsn is not None:
            cmd.extend(["--ancestor-start-lsn", str(ancestor_start_lsn)])

        res = self.raw_cli(cmd)
        res.check_returncode()

    def timeline_import(
        self,
        tenant_id: TenantId,
        timeline_id: TimelineId,
        new_branch_name: str,
        base_lsn: Lsn,
        base_tarfile: Path,
        pg_version: PgVersion,
        end_lsn: Optional[Lsn] = None,
        wal_tarfile: Optional[Path] = None,
    ):
        cmd = [
            "timeline",
            "import",
            "--tenant-id",
            str(tenant_id),
            "--timeline-id",
            str(timeline_id),
            "--pg-version",
            pg_version,
            "--branch-name",
            new_branch_name,
            "--base-lsn",
            str(base_lsn),
            "--base-tarfile",
            str(base_tarfile),
        ]
        if end_lsn is not None:
            cmd.extend(["--end-lsn", str(end_lsn)])
        if wal_tarfile is not None:
            cmd.extend(["--wal-tarfile", str(wal_tarfile)])

        res = self.raw_cli(cmd)
        res.check_returncode()

    def timeline_list(self, tenant_id: TenantId) -> List[Tuple[str, TimelineId]]:
        """
        Returns a list of (branch_name, timeline_id) tuples, parsed from `neon_local timeline list` CLI output.
        """

        # main [b49f7954224a0ad25cc0013ea107b54b]
        # ┣━ @0/16B5A50: test_cli_branch_list_main [20f98c79111b9015d84452258b7d5540]
        TIMELINE_DATA_EXTRACTOR: re.Pattern = re.compile(  # type: ignore[type-arg]
            r"\s?(?P<branch_name>[^\s]+)\s\[(?P<timeline_id>[^\]]+)\]", re.MULTILINE
        )
        res = self.raw_cli(["timeline", "list", "--tenant-id", str(tenant_id)])
        timelines_cli = sorted(
            map(
                lambda branch_and_id: (branch_and_id[0], TimelineId(branch_and_id[1])),
                TIMELINE_DATA_EXTRACTOR.findall(res.stdout),
            )
        )
        return timelines_cli

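    # What the extractor yields on the sample output quoted above (sketch):
    #
    #     >>> TIMELINE_DATA_EXTRACTOR.findall(" main [b49f7954224a0ad25cc0013ea107b54b]")
    #     [('main', 'b49f7954224a0ad25cc0013ea107b54b')]
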
    def init(
        self,
        init_config: Dict[str, Any],
        force: Optional[str] = None,
    ) -> "subprocess.CompletedProcess[str]":
        with tempfile.NamedTemporaryFile(mode="w+") as init_config_tmpfile:
            init_config_tmpfile.write(toml.dumps(init_config))
            init_config_tmpfile.flush()

            cmd = [
                "init",
                f"--config={init_config_tmpfile.name}",
            ]

            if force is not None:
                cmd.extend(["--force", force])

            res = self.raw_cli(cmd)
            res.check_returncode()
        return res

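    # How `init_config` reaches neon_local (sketch): the dict is serialized to
    # TOML and passed by path. The key below is a placeholder, not a schema:
    #
    #     >>> toml.dumps({"pg_distrib_dir": "/usr/local"})
    #     'pg_distrib_dir = "/usr/local"\n'
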
    def storage_controller_start(
        self,
        timeout_in_seconds: Optional[int] = None,
        instance_id: Optional[int] = None,
        base_port: Optional[int] = None,
    ):
        cmd = ["storage_controller", "start"]
        if timeout_in_seconds is not None:
            cmd.append(f"--start-timeout={timeout_in_seconds}s")
        if instance_id is not None:
            cmd.append(f"--instance-id={instance_id}")
        if base_port is not None:
            cmd.append(f"--base-port={base_port}")
        return self.raw_cli(cmd)

    def storage_controller_stop(self, immediate: bool, instance_id: Optional[int] = None):
        cmd = ["storage_controller", "stop"]
        if immediate:
            cmd.extend(["-m", "immediate"])
        if instance_id is not None:
            cmd.append(f"--instance-id={instance_id}")
        return self.raw_cli(cmd)

    def pageserver_start(
        self,
        id: int,
        extra_env_vars: Optional[Dict[str, str]] = None,
        timeout_in_seconds: Optional[int] = None,
    ) -> "subprocess.CompletedProcess[str]":
        start_args = ["pageserver", "start", f"--id={id}"]
        if timeout_in_seconds is not None:
            start_args.append(f"--start-timeout={timeout_in_seconds}s")
        return self.raw_cli(start_args, extra_env_vars=extra_env_vars)

    def pageserver_stop(self, id: int, immediate=False) -> "subprocess.CompletedProcess[str]":
        cmd = ["pageserver", "stop", f"--id={id}"]
        if immediate:
            cmd.extend(["-m", "immediate"])

        log.info(f"Stopping pageserver with {cmd}")
        return self.raw_cli(cmd)

    def safekeeper_start(
        self,
        id: int,
        extra_opts: Optional[List[str]] = None,
        extra_env_vars: Optional[Dict[str, str]] = None,
        timeout_in_seconds: Optional[int] = None,
    ) -> "subprocess.CompletedProcess[str]":
        if extra_opts is not None:
            extra_opts = [f"-e={opt}" for opt in extra_opts]
        else:
            extra_opts = []
        if timeout_in_seconds is not None:
            extra_opts.append(f"--start-timeout={timeout_in_seconds}s")
        return self.raw_cli(
            ["safekeeper", "start", str(id), *extra_opts], extra_env_vars=extra_env_vars
        )

    def safekeeper_stop(
        self, id: Optional[int] = None, immediate=False
    ) -> "subprocess.CompletedProcess[str]":
        args = ["safekeeper", "stop"]
        if id is not None:
            args.append(str(id))
        if immediate:
            args.extend(["-m", "immediate"])
        return self.raw_cli(args)

    def storage_broker_start(
        self, timeout_in_seconds: Optional[int] = None
    ) -> "subprocess.CompletedProcess[str]":
        cmd = ["storage_broker", "start"]
        if timeout_in_seconds is not None:
            cmd.append(f"--start-timeout={timeout_in_seconds}s")
        return self.raw_cli(cmd)

    def storage_broker_stop(self) -> "subprocess.CompletedProcess[str]":
        cmd = ["storage_broker", "stop"]
        return self.raw_cli(cmd)

    def endpoint_create(
        self,
        branch_name: str,
        pg_port: int,
        http_port: int,
        tenant_id: TenantId,
        pg_version: PgVersion,
        endpoint_id: Optional[str] = None,
        hot_standby: bool = False,
        lsn: Optional[Lsn] = None,
        pageserver_id: Optional[int] = None,
        allow_multiple=False,
    ) -> "subprocess.CompletedProcess[str]":
        args = [
            "endpoint",
            "create",
            "--tenant-id",
            str(tenant_id),
            "--branch-name",
            branch_name,
            "--pg-version",
            pg_version,
        ]
        if lsn is not None:
            args.extend(["--lsn", str(lsn)])
        if pg_port is not None:
            args.extend(["--pg-port", str(pg_port)])
        if http_port is not None:
            args.extend(["--http-port", str(http_port)])
        if endpoint_id is not None:
            args.append(endpoint_id)
        if hot_standby:
            args.extend(["--hot-standby", "true"])
        if pageserver_id is not None:
            args.extend(["--pageserver-id", str(pageserver_id)])
        if allow_multiple:
            args.extend(["--allow-multiple"])

        res = self.raw_cli(args)
        res.check_returncode()
        return res

    def endpoint_start(
        self,
        endpoint_id: str,
        safekeepers: Optional[List[int]] = None,
        remote_ext_config: Optional[str] = None,
        pageserver_id: Optional[int] = None,
        allow_multiple=False,
        basebackup_request_tries: Optional[int] = None,
    ) -> "subprocess.CompletedProcess[str]":
        args = [
            "endpoint",
            "start",
        ]
        extra_env_vars = {}
        if basebackup_request_tries is not None:
            extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries)
        if remote_ext_config is not None:
            args.extend(["--remote-ext-config", remote_ext_config])

        if safekeepers is not None:
            args.extend(["--safekeepers", (",".join(map(str, safekeepers)))])
        if endpoint_id is not None:
            args.append(endpoint_id)
        if pageserver_id is not None:
            args.extend(["--pageserver-id", str(pageserver_id)])
        if allow_multiple:
            args.extend(["--allow-multiple"])

        res = self.raw_cli(args, extra_env_vars)
        res.check_returncode()
        return res

    def endpoint_reconfigure(
        self,
        endpoint_id: str,
        tenant_id: Optional[TenantId] = None,
        pageserver_id: Optional[int] = None,
        safekeepers: Optional[List[int]] = None,
        check_return_code=True,
    ) -> "subprocess.CompletedProcess[str]":
        args = ["endpoint", "reconfigure", endpoint_id]
        if tenant_id is not None:
            args.extend(["--tenant-id", str(tenant_id)])
        if pageserver_id is not None:
            args.extend(["--pageserver-id", str(pageserver_id)])
        if safekeepers is not None:
            args.extend(["--safekeepers", (",".join(map(str, safekeepers)))])
        return self.raw_cli(args, check_return_code=check_return_code)

    def endpoint_stop(
        self,
        endpoint_id: str,
        destroy=False,
        check_return_code=True,
        mode: Optional[str] = None,
    ) -> "subprocess.CompletedProcess[str]":
        args = [
            "endpoint",
            "stop",
        ]
        if destroy:
            args.append("--destroy")
        if mode is not None:
            args.append(f"--mode={mode}")
        if endpoint_id is not None:
            args.append(endpoint_id)

        return self.raw_cli(args, check_return_code=check_return_code)

    def mappings_map_branch(
        self, name: str, tenant_id: TenantId, timeline_id: TimelineId
    ) -> "subprocess.CompletedProcess[str]":
        """
        Map tenant id and timeline id to a neon_local branch name. They do not have to exist.
        Usually needed when creating branches via PageserverHttpClient and not neon_local.

        After creating a name mapping, you can use EndpointFactory.create_start
        with this registered branch name.
        """
        args = [
            "mappings",
            "map",
            "--branch-name",
            name,
            "--tenant-id",
            str(tenant_id),
            "--timeline-id",
            str(timeline_id),
        ]

        return self.raw_cli(args, check_return_code=True)

    def start(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
        return self.raw_cli(["start"], check_return_code=check_return_code)

    def stop(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
        return self.raw_cli(["stop"], check_return_code=check_return_code)


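# Typical wiring of NeonLocalCli in a test (sketch; the paths, IDs, ports, and
# Postgres version below are placeholders, not fixture defaults):
#
#     >>> cli = NeonLocalCli(extra_env=None, binpath=Path("target/debug"),
#     ...                    repo_dir=Path(".neon"), pg_distrib_dir=Path("pg_install"))
#     >>> cli.tenant_create(tenant_id, timeline_id, pg_version)
#     >>> cli.endpoint_create("main", pg_port=55432, http_port=55433,
#     ...                     tenant_id=tenant_id, pg_version=pg_version)
#     >>> cli.endpoint_start("ep-main")

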
class WalCraft(AbstractNeonCli):
    """
    A typed wrapper around the `wal_craft` CLI tool.
    Supports main commands via typed methods and a way to run arbitrary commands directly via the CLI.
    """

    COMMAND = "wal_craft"

    def postgres_config(self) -> List[str]:
        res = self.raw_cli(["print-postgres-config"])
        res.check_returncode()
        return res.stdout.split("\n")

    def in_existing(self, type: str, connection: str) -> None:
        res = self.raw_cli(["in-existing", type, connection])
        res.check_returncode()


class Pagectl(AbstractNeonCli):
    """
    A typed wrapper around the `pagectl` utility CLI tool.
    """

    COMMAND = "pagectl"

    def dump_index_part(self, path: Path) -> IndexPartDump:
        res = self.raw_cli(["index-part", "dump", str(path)])
        res.check_returncode()
        parsed = json.loads(res.stdout)
        return IndexPartDump.from_json(parsed)

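# Sketch of dumping a pageserver index_part.json with the Pagectl wrapper
# (the local path below is a placeholder):
#
#     >>> dump = env.pagectl.dump_index_part(Path("index_part.json"))
#     >>> log.info(dump)
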
@@ -9,8 +9,6 @@ import os
 import re
 import shutil
 import subprocess
-import tempfile
-import textwrap
 import threading
 import time
 import uuid
@@ -21,7 +19,6 @@ from datetime import datetime
 from enum import Enum
 from fcntl import LOCK_EX, LOCK_UN, flock
 from functools import cached_property
-from itertools import chain, product
 from pathlib import Path
 from types import TracebackType
 from typing import (
@@ -64,11 +61,12 @@ from fixtures.common_types import Lsn, NodeId, TenantId, TenantShardId, Timeline
 from fixtures.endpoint.http import EndpointHttpClient
 from fixtures.log_helper import log
 from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
+from fixtures.neon_cli import NeonLocalCli, Pagectl
 from fixtures.pageserver.allowed_errors import (
     DEFAULT_PAGESERVER_ALLOWED_ERRORS,
     DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS,
 )
-from fixtures.pageserver.common_types import IndexPartDump, LayerName, parse_layer_file_name
+from fixtures.pageserver.common_types import LayerName, parse_layer_file_name
 from fixtures.pageserver.http import PageserverHttpClient
 from fixtures.pageserver.utils import (
     wait_for_last_record_lsn,
@@ -491,7 +489,7 @@ class NeonEnvBuilder:
         log.debug(
             f"Services started, creating initial tenant {env.initial_tenant} and its initial timeline"
         )
-        initial_tenant, initial_timeline = env.neon_cli.create_tenant(
+        initial_tenant, initial_timeline = env.create_tenant(
             tenant_id=env.initial_tenant,
             conf=initial_tenant_conf,
             timeline_id=env.initial_timeline,
@@ -952,10 +950,16 @@ class NeonEnv:

     initial_tenant - tenant ID of the initial tenant created in the repository

-    neon_cli - can be used to run the 'neon' CLI tool
+    neon_cli - can be used to run the 'neon_local' CLI tool

-    create_tenant() - initializes a new tenant in the page server, returns
-    the tenant id
+    create_tenant() - initializes a new tenant and an initial empty timeline on it,
+    returns the tenant and timeline id

     create_branch() - branch a new timeline from an existing one, returns
     the new timeline id

     create_timeline() - initializes a new timeline by running initdb, returns
     the new timeline id
     """

     BASE_PAGESERVER_ID = 1
@@ -966,8 +970,6 @@ class NeonEnv:
         self.rust_log_override = config.rust_log_override
         self.port_distributor = config.port_distributor
         self.s3_mock_server = config.mock_s3_server
-        self.neon_cli = NeonCli(env=self)
-        self.pagectl = Pagectl(env=self)
         self.endpoints = EndpointFactory(self)
         self.safekeepers: List[Safekeeper] = []
         self.pageservers: List[NeonPageserver] = []
@@ -987,6 +989,21 @@ class NeonEnv:
         self.initial_tenant = config.initial_tenant
         self.initial_timeline = config.initial_timeline

+        neon_local_env_vars = {}
+        if self.rust_log_override is not None:
+            neon_local_env_vars["RUST_LOG"] = self.rust_log_override
+        self.neon_cli = NeonLocalCli(
+            extra_env=neon_local_env_vars,
+            binpath=self.neon_local_binpath,
+            repo_dir=self.repo_dir,
+            pg_distrib_dir=self.pg_distrib_dir,
+        )
+
+        pagectl_env_vars = {}
+        if self.rust_log_override is not None:
+            pagectl_env_vars["RUST_LOG"] = self.rust_log_override
+        self.pagectl = Pagectl(extra_env=pagectl_env_vars, binpath=self.neon_binpath)
+
         # The URL for the pageserver to use as its control_plane_api config
         if config.storage_controller_port_override is not None:
             log.info(
@@ -1310,6 +1327,74 @@ class NeonEnv:
         self.endpoint_counter += 1
         return "ep-" + str(self.endpoint_counter)

+    def create_tenant(
+        self,
+        tenant_id: Optional[TenantId] = None,
+        timeline_id: Optional[TimelineId] = None,
+        conf: Optional[Dict[str, Any]] = None,
+        shard_count: Optional[int] = None,
+        shard_stripe_size: Optional[int] = None,
+        placement_policy: Optional[str] = None,
+        set_default: bool = False,
+        aux_file_policy: Optional[AuxFileStore] = None,
+    ) -> Tuple[TenantId, TimelineId]:
+        """
+        Creates a new tenant, returns its id and its initial timeline's id.
+        """
+        tenant_id = tenant_id or TenantId.generate()
+        timeline_id = timeline_id or TimelineId.generate()
+
+        self.neon_cli.tenant_create(
+            tenant_id=tenant_id,
+            timeline_id=timeline_id,
+            pg_version=self.pg_version,
+            conf=conf,
+            shard_count=shard_count,
+            shard_stripe_size=shard_stripe_size,
+            placement_policy=placement_policy,
+            set_default=set_default,
+            aux_file_policy=aux_file_policy,
+        )
+
+        return tenant_id, timeline_id
+
+    def config_tenant(self, tenant_id: Optional[TenantId], conf: Dict[str, str]):
+        """
+        Update tenant config.
+        """
+        tenant_id = tenant_id or self.initial_tenant
+        self.neon_cli.tenant_config(tenant_id, conf)
+
+    def create_branch(
+        self,
+        new_branch_name: str = DEFAULT_BRANCH_NAME,
+        tenant_id: Optional[TenantId] = None,
+        ancestor_branch_name: Optional[str] = None,
+        ancestor_start_lsn: Optional[Lsn] = None,
+        new_timeline_id: Optional[TimelineId] = None,
+    ) -> TimelineId:
+        new_timeline_id = new_timeline_id or TimelineId.generate()
+        tenant_id = tenant_id or self.initial_tenant
+
+        self.neon_cli.timeline_branch(
+            tenant_id, new_timeline_id, new_branch_name, ancestor_branch_name, ancestor_start_lsn
+        )
+
+        return new_timeline_id
+
+    def create_timeline(
+        self,
+        new_branch_name: str,
+        tenant_id: Optional[TenantId] = None,
+        timeline_id: Optional[TimelineId] = None,
+    ) -> TimelineId:
+        timeline_id = timeline_id or TimelineId.generate()
+        tenant_id = tenant_id or self.initial_tenant
+
+        self.neon_cli.timeline_create(new_branch_name, tenant_id, timeline_id, self.pg_version)
+
+        return timeline_id
+

 @pytest.fixture(scope="function")
 def neon_simple_env(
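# Sketch of the new NeonEnv-level helpers from the hunk above, as a test would
# call them (the branch name and config key are placeholders):
#
#     >>> tenant_id, timeline_id = env.create_tenant()
#     >>> branch_timeline_id = env.create_branch("feature-x", ancestor_branch_name="main")
#     >>> env.config_tenant(tenant_id, {"gc_period": "0s"})
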
@@ -1425,597 +1510,6 @@ class PageserverPort:
     http: int


-class AbstractNeonCli(abc.ABC):
-    """
-    A typed wrapper around an arbitrary Neon CLI tool.
-    Supports a way to run arbitrary command directly via CLI.
-    Do not use directly, use specific subclasses instead.
-    """
-
-    def __init__(self, env: NeonEnv):
-        self.env = env
-
-    COMMAND: str = cast(str, None)  # To be overwritten by the derived class.
-
-    def raw_cli(
-        self,
-        arguments: List[str],
-        extra_env_vars: Optional[Dict[str, str]] = None,
-        check_return_code=True,
-        timeout=None,
-        local_binpath=False,
-    ) -> "subprocess.CompletedProcess[str]":
-        """
-        Run the command with the specified arguments.
-
-        Arguments must be in list form, e.g. ['pg', 'create']
-
-        Return both stdout and stderr, which can be accessed as
-
-        >>> result = env.neon_cli.raw_cli(...)
-        >>> assert result.stderr == ""
-        >>> log.info(result.stdout)
-
-        If `check_return_code`, on non-zero exit code logs failure and raises.
-
-        If `local_binpath` is true, then we are invoking a test utility
-        """
-
-        assert isinstance(arguments, list)
-        assert isinstance(self.COMMAND, str)
-
-        if local_binpath:
-            # Test utility
-            bin_neon = str(self.env.neon_local_binpath / self.COMMAND)
-        else:
-            # Normal binary
-            bin_neon = str(self.env.neon_binpath / self.COMMAND)
-
-        args = [bin_neon] + arguments
-        log.info('Running command "{}"'.format(" ".join(args)))
-
-        env_vars = os.environ.copy()
-        env_vars["NEON_REPO_DIR"] = str(self.env.repo_dir)
-        env_vars["POSTGRES_DISTRIB_DIR"] = str(self.env.pg_distrib_dir)
-        if self.env.rust_log_override is not None:
-            env_vars["RUST_LOG"] = self.env.rust_log_override
-        for extra_env_key, extra_env_value in (extra_env_vars or {}).items():
-            env_vars[extra_env_key] = extra_env_value
-
-        # Pass coverage settings
-        var = "LLVM_PROFILE_FILE"
-        val = os.environ.get(var)
-        if val:
-            env_vars[var] = val
-
-        # Intercept CalledProcessError and print more info
-        try:
-            res = subprocess.run(
-                args,
-                env=env_vars,
-                check=False,
-                universal_newlines=True,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE,
-                timeout=timeout,
-            )
-        except subprocess.TimeoutExpired as e:
-            if e.stderr:
-                stderr = e.stderr.decode(errors="replace")
-            else:
-                stderr = ""
-
-            if e.stdout:
-                stdout = e.stdout.decode(errors="replace")
-            else:
-                stdout = ""
-
-            log.warn(f"CLI timeout: stderr={stderr}, stdout={stdout}")
-            raise
-
-        indent = "  "
-        if not res.returncode:
-            stripped = res.stdout.strip()
-            lines = stripped.splitlines()
-            if len(lines) < 2:
-                log.debug(f"Run {res.args} success: {stripped}")
-            else:
-                log.debug("Run %s success:\n%s" % (res.args, textwrap.indent(stripped, indent)))
-        elif check_return_code:
-            # this way command output will be in recorded and shown in CI in failure message
-            indent = indent * 2
-            msg = textwrap.dedent(
-                """\
-            Run %s failed:
-              stdout:
-            %s
-              stderr:
-            %s
-            """
-            )
-            msg = msg % (
-                res.args,
-                textwrap.indent(res.stdout.strip(), indent),
-                textwrap.indent(res.stderr.strip(), indent),
-            )
-            log.info(msg)
-            raise RuntimeError(msg) from subprocess.CalledProcessError(
-                res.returncode, res.args, res.stdout, res.stderr
-            )
-        return res
-
-
-class NeonCli(AbstractNeonCli):
-    """
-    A typed wrapper around the `neon` CLI tool.
-    Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
-    """
-
-    COMMAND = "neon_local"
-
-    def raw_cli(self, *args, **kwargs) -> subprocess.CompletedProcess[str]:
-        kwargs["local_binpath"] = True
-        return super().raw_cli(*args, **kwargs)
-
-    def create_tenant(
-        self,
-        tenant_id: Optional[TenantId] = None,
-        timeline_id: Optional[TimelineId] = None,
-        conf: Optional[Dict[str, Any]] = None,
-        shard_count: Optional[int] = None,
-        shard_stripe_size: Optional[int] = None,
-        placement_policy: Optional[str] = None,
-        set_default: bool = False,
-        aux_file_policy: Optional[AuxFileStore] = None,
-    ) -> Tuple[TenantId, TimelineId]:
-        """
-        Creates a new tenant, returns its id and its initial timeline's id.
-        """
-        tenant_id = tenant_id or TenantId.generate()
-        timeline_id = timeline_id or TimelineId.generate()
-
-        args = [
-            "tenant",
-            "create",
-            "--tenant-id",
-            str(tenant_id),
-            "--timeline-id",
-            str(timeline_id),
-            "--pg-version",
-            self.env.pg_version,
-        ]
-        if conf is not None:
-            args.extend(
-                chain.from_iterable(
-                    product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
-                )
-            )
-
-        if aux_file_policy is AuxFileStore.V2:
-            args.extend(["-c", "switch_aux_file_policy:v2"])
-        elif aux_file_policy is AuxFileStore.V1:
-            args.extend(["-c", "switch_aux_file_policy:v1"])
-        elif aux_file_policy is AuxFileStore.CrossValidation:
-            args.extend(["-c", "switch_aux_file_policy:cross-validation"])
-
-        if set_default:
-            args.append("--set-default")
-
-        if shard_count is not None:
-            args.extend(["--shard-count", str(shard_count)])
-
-        if shard_stripe_size is not None:
-            args.extend(["--shard-stripe-size", str(shard_stripe_size)])
-
-        if placement_policy is not None:
-            args.extend(["--placement-policy", str(placement_policy)])
-
-        res = self.raw_cli(args)
-        res.check_returncode()
-        return tenant_id, timeline_id
-
-    def import_tenant(self, tenant_id: TenantId):
-        args = ["tenant", "import", "--tenant-id", str(tenant_id)]
-        res = self.raw_cli(args)
-        res.check_returncode()
-
-    def set_default(self, tenant_id: TenantId):
-        """
-        Update default tenant for future operations that require tenant_id.
-        """
-        res = self.raw_cli(["tenant", "set-default", "--tenant-id", str(tenant_id)])
-        res.check_returncode()
-
-    def config_tenant(self, tenant_id: TenantId, conf: Dict[str, str]):
-        """
-        Update tenant config.
-        """
-
-        args = ["tenant", "config", "--tenant-id", str(tenant_id)]
-        if conf is not None:
-            args.extend(
-                chain.from_iterable(
-                    product(["-c"], (f"{key}:{value}" for key, value in conf.items()))
-                )
-            )
-
-        res = self.raw_cli(args)
-        res.check_returncode()
-
-    def list_tenants(self) -> "subprocess.CompletedProcess[str]":
-        res = self.raw_cli(["tenant", "list"])
-        res.check_returncode()
-        return res
-
-    def create_timeline(
-        self,
-        new_branch_name: str,
-        tenant_id: Optional[TenantId] = None,
-        timeline_id: Optional[TimelineId] = None,
-    ) -> TimelineId:
-        if timeline_id is None:
-            timeline_id = TimelineId.generate()
-
-        cmd = [
-            "timeline",
-            "create",
-            "--branch-name",
-            new_branch_name,
-            "--tenant-id",
-            str(tenant_id or self.env.initial_tenant),
-            "--timeline-id",
-            str(timeline_id),
-            "--pg-version",
-            self.env.pg_version,
-        ]
-
-        res = self.raw_cli(cmd)
-        res.check_returncode()
-
-        return timeline_id
-
-    def create_branch(
-        self,
-        new_branch_name: str = DEFAULT_BRANCH_NAME,
-        ancestor_branch_name: Optional[str] = None,
-        tenant_id: Optional[TenantId] = None,
-        ancestor_start_lsn: Optional[Lsn] = None,
-        new_timeline_id: Optional[TimelineId] = None,
-    ) -> TimelineId:
-        if new_timeline_id is None:
-            new_timeline_id = TimelineId.generate()
-        cmd = [
-            "timeline",
-            "branch",
-            "--branch-name",
-            new_branch_name,
-            "--timeline-id",
-            str(new_timeline_id),
-            "--tenant-id",
-            str(tenant_id or self.env.initial_tenant),
-        ]
-        if ancestor_branch_name is not None:
-            cmd.extend(["--ancestor-branch-name", ancestor_branch_name])
-        if ancestor_start_lsn is not None:
-            cmd.extend(["--ancestor-start-lsn", str(ancestor_start_lsn)])
-
-        res = self.raw_cli(cmd)
-        res.check_returncode()
-
-        return TimelineId(str(new_timeline_id))
-
-    def list_timelines(self, tenant_id: Optional[TenantId] = None) -> List[Tuple[str, TimelineId]]:
-        """
-        Returns a list of (branch_name, timeline_id) tuples out of parsed `neon timeline list` CLI output.
-        """
-
-        # main [b49f7954224a0ad25cc0013ea107b54b]
-        # ┣━ @0/16B5A50: test_cli_branch_list_main [20f98c79111b9015d84452258b7d5540]
-        TIMELINE_DATA_EXTRACTOR: re.Pattern = re.compile(  # type: ignore[type-arg]
-            r"\s?(?P<branch_name>[^\s]+)\s\[(?P<timeline_id>[^\]]+)\]", re.MULTILINE
-        )
-        res = self.raw_cli(
-            ["timeline", "list", "--tenant-id", str(tenant_id or self.env.initial_tenant)]
-        )
-        timelines_cli = sorted(
-            map(
-                lambda branch_and_id: (branch_and_id[0], TimelineId(branch_and_id[1])),
-                TIMELINE_DATA_EXTRACTOR.findall(res.stdout),
-            )
-        )
-        return timelines_cli
-
-    def init(
-        self,
-        init_config: Dict[str, Any],
-        force: Optional[str] = None,
-    ) -> "subprocess.CompletedProcess[str]":
-        with tempfile.NamedTemporaryFile(mode="w+") as init_config_tmpfile:
-            init_config_tmpfile.write(toml.dumps(init_config))
-            init_config_tmpfile.flush()
-
-            cmd = [
-                "init",
-                f"--config={init_config_tmpfile.name}",
-            ]
-
-            if force is not None:
-                cmd.extend(["--force", force])
-
-            res = self.raw_cli(cmd)
-            res.check_returncode()
-        return res
-
-    def storage_controller_start(
-        self,
-        timeout_in_seconds: Optional[int] = None,
-        instance_id: Optional[int] = None,
-        base_port: Optional[int] = None,
-    ):
-        cmd = ["storage_controller", "start"]
-        if timeout_in_seconds is not None:
-            cmd.append(f"--start-timeout={timeout_in_seconds}s")
-        if instance_id is not None:
-            cmd.append(f"--instance-id={instance_id}")
-        if base_port is not None:
-            cmd.append(f"--base-port={base_port}")
-        return self.raw_cli(cmd)
-
-    def storage_controller_stop(self, immediate: bool, instance_id: Optional[int] = None):
-        cmd = ["storage_controller", "stop"]
-        if immediate:
-            cmd.extend(["-m", "immediate"])
-        if instance_id is not None:
-            cmd.append(f"--instance-id={instance_id}")
-        return self.raw_cli(cmd)
-
-    def pageserver_start(
-        self,
-        id: int,
-        extra_env_vars: Optional[Dict[str, str]] = None,
-        timeout_in_seconds: Optional[int] = None,
-    ) -> "subprocess.CompletedProcess[str]":
-        start_args = ["pageserver", "start", f"--id={id}"]
-        if timeout_in_seconds is not None:
-            start_args.append(f"--start-timeout={timeout_in_seconds}s")
-        storage = self.env.pageserver_remote_storage
-
-        if isinstance(storage, S3Storage):
-            s3_env_vars = storage.access_env_vars()
-            extra_env_vars = (extra_env_vars or {}) | s3_env_vars
-
-        return self.raw_cli(start_args, extra_env_vars=extra_env_vars)
-
-    def pageserver_stop(self, id: int, immediate=False) -> "subprocess.CompletedProcess[str]":
-        cmd = ["pageserver", "stop", f"--id={id}"]
-        if immediate:
-            cmd.extend(["-m", "immediate"])
-
-        log.info(f"Stopping pageserver with {cmd}")
-        return self.raw_cli(cmd)
-
-    def safekeeper_start(
-        self,
-        id: int,
-        extra_opts: Optional[List[str]] = None,
-        timeout_in_seconds: Optional[int] = None,
-    ) -> "subprocess.CompletedProcess[str]":
-        s3_env_vars = None
-        if isinstance(self.env.safekeepers_remote_storage, S3Storage):
-            s3_env_vars = self.env.safekeepers_remote_storage.access_env_vars()
-
-        if extra_opts is not None:
-            extra_opts = [f"-e={opt}" for opt in extra_opts]
-        else:
-            extra_opts = []
-        if timeout_in_seconds is not None:
-            extra_opts.append(f"--start-timeout={timeout_in_seconds}s")
-        return self.raw_cli(
-            ["safekeeper", "start", str(id), *extra_opts], extra_env_vars=s3_env_vars
-        )
-
-    def safekeeper_stop(
-        self, id: Optional[int] = None, immediate=False
-    ) -> "subprocess.CompletedProcess[str]":
-        args = ["safekeeper", "stop"]
-        if id is not None:
-            args.append(str(id))
-        if immediate:
-            args.extend(["-m", "immediate"])
-        return self.raw_cli(args)
-
-    def broker_start(
-        self, timeout_in_seconds: Optional[int] = None
-    ) -> "subprocess.CompletedProcess[str]":
-        cmd = ["storage_broker", "start"]
-        if timeout_in_seconds is not None:
-            cmd.append(f"--start-timeout={timeout_in_seconds}s")
-        return self.raw_cli(cmd)
-
-    def broker_stop(self) -> "subprocess.CompletedProcess[str]":
-        cmd = ["storage_broker", "stop"]
-        return self.raw_cli(cmd)
-
-    def endpoint_create(
-        self,
-        branch_name: str,
-        pg_port: int,
-        http_port: int,
-        endpoint_id: Optional[str] = None,
-        tenant_id: Optional[TenantId] = None,
-        hot_standby: bool = False,
-        lsn: Optional[Lsn] = None,
-        pageserver_id: Optional[int] = None,
-        allow_multiple=False,
-    ) -> "subprocess.CompletedProcess[str]":
-        args = [
-            "endpoint",
-            "create",
-            "--tenant-id",
-            str(tenant_id or self.env.initial_tenant),
-            "--branch-name",
-            branch_name,
-            "--pg-version",
-            self.env.pg_version,
-        ]
-        if lsn is not None:
-            args.extend(["--lsn", str(lsn)])
-        if pg_port is not None:
-            args.extend(["--pg-port", str(pg_port)])
-        if http_port is not None:
-            args.extend(["--http-port", str(http_port)])
-        if endpoint_id is not None:
-            args.append(endpoint_id)
-        if hot_standby:
-            args.extend(["--hot-standby", "true"])
-        if pageserver_id is not None:
-            args.extend(["--pageserver-id", str(pageserver_id)])
-        if allow_multiple:
-            args.extend(["--allow-multiple"])
-
-        res = self.raw_cli(args)
-        res.check_returncode()
-        return res
-
-    def endpoint_start(
-        self,
-        endpoint_id: str,
-        safekeepers: Optional[List[int]] = None,
-        remote_ext_config: Optional[str] = None,
-        pageserver_id: Optional[int] = None,
-        allow_multiple=False,
-        basebackup_request_tries: Optional[int] = None,
-    ) -> "subprocess.CompletedProcess[str]":
-        args = [
-            "endpoint",
-            "start",
-        ]
-        extra_env_vars = {}
-        if basebackup_request_tries is not None:
-            extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries)
-        if remote_ext_config is not None:
-            args.extend(["--remote-ext-config", remote_ext_config])
-
-        if safekeepers is not None:
-            args.extend(["--safekeepers", (",".join(map(str, safekeepers)))])
-        if endpoint_id is not None:
-            args.append(endpoint_id)
-        if pageserver_id is not None:
-            args.extend(["--pageserver-id", str(pageserver_id)])
-        if allow_multiple:
-            args.extend(["--allow-multiple"])
-
-        res = self.raw_cli(args, extra_env_vars)
-        res.check_returncode()
-        return res
-
-    def endpoint_reconfigure(
-        self,
-        endpoint_id: str,
-        tenant_id: Optional[TenantId] = None,
-        pageserver_id: Optional[int] = None,
-        safekeepers: Optional[List[int]] = None,
-        check_return_code=True,
-    ) -> "subprocess.CompletedProcess[str]":
-        args = ["endpoint", "reconfigure", endpoint_id]
-        if tenant_id is not None:
-            args.extend(["--tenant-id", str(tenant_id)])
-        if pageserver_id is not None:
-            args.extend(["--pageserver-id", str(pageserver_id)])
-        if safekeepers is not None:
-            args.extend(["--safekeepers", (",".join(map(str, safekeepers)))])
-        return self.raw_cli(args, check_return_code=check_return_code)
-
-    def endpoint_stop(
-        self,
-        endpoint_id: str,
-        destroy=False,
-        check_return_code=True,
-        mode: Optional[str] = None,
-    ) -> "subprocess.CompletedProcess[str]":
-        args = [
-            "endpoint",
-            "stop",
-        ]
-        if destroy:
-            args.append("--destroy")
-        if mode is not None:
-            args.append(f"--mode={mode}")
-        if endpoint_id is not None:
-            args.append(endpoint_id)
-
-        return self.raw_cli(args, check_return_code=check_return_code)
-
-    def map_branch(
-        self, name: str, tenant_id: TenantId, timeline_id: TimelineId
-    ) -> "subprocess.CompletedProcess[str]":
-        """
-        Map tenant id and timeline id to a neon_local branch name. They do not have to exist.
-        Usually needed when creating branches via PageserverHttpClient and not neon_local.
-
-        After creating a name mapping, you can use EndpointFactory.create_start
-        with this registered branch name.
-        """
-        args = [
-            "mappings",
-            "map",
-            "--branch-name",
-            name,
-            "--tenant-id",
-            str(tenant_id),
-            "--timeline-id",
-            str(timeline_id),
-        ]
-
-        return self.raw_cli(args, check_return_code=True)
-
-    def start(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
-        return self.raw_cli(["start"], check_return_code=check_return_code)
-
-    def stop(self, check_return_code=True) -> "subprocess.CompletedProcess[str]":
-        return self.raw_cli(["stop"], check_return_code=check_return_code)
-
-
-class WalCraft(AbstractNeonCli):
-    """
-    A typed wrapper around the `wal_craft` CLI tool.
-    Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
-    """
-
-    COMMAND = "wal_craft"
-
-    def postgres_config(self) -> List[str]:
-        res = self.raw_cli(["print-postgres-config"])
-        res.check_returncode()
-        return res.stdout.split("\n")
-
-    def in_existing(self, type: str, connection: str) -> None:
-        res = self.raw_cli(["in-existing", type, connection])
-        res.check_returncode()
-
-
-class ComputeCtl(AbstractNeonCli):
-    """
-    A typed wrapper around the `compute_ctl` CLI tool.
-    """
-
-    COMMAND = "compute_ctl"
-
-
-class Pagectl(AbstractNeonCli):
-    """
-    A typed wrapper around the `pagectl` utility CLI tool.
-    """
-
-    COMMAND = "pagectl"
-
-    def dump_index_part(self, path: Path) -> IndexPartDump:
-        res = self.raw_cli(["index-part", "dump", str(path)])
-        res.check_returncode()
-        parsed = json.loads(res.stdout)
-        return IndexPartDump.from_json(parsed)
-
-
 class LogUtils:
     """
     A mixin class which provides utilities for inspecting the logs of a service.
@@ -2933,6 +2427,10 @@ class NeonPageserver(PgProtocol, LogUtils):
         """
         assert self.running is False

+        storage = self.env.pageserver_remote_storage
+        if isinstance(storage, S3Storage):
+            s3_env_vars = storage.access_env_vars()
+            extra_env_vars = (extra_env_vars or {}) | s3_env_vars
         self.env.neon_cli.pageserver_start(
             self.id, extra_env_vars=extra_env_vars, timeout_in_seconds=timeout_in_seconds
         )
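# Note on the dict union above (Python >= 3.9): the right-hand operand wins on
# duplicate keys, so the S3 credentials override any caller-supplied values:
#
#     >>> {"AWS_REGION": "a", "X": "1"} | {"AWS_REGION": "b"}
#     {'AWS_REGION': 'b', 'X': '1'}
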
@@ -3953,6 +3451,7 @@ class Endpoint(PgProtocol, LogUtils):
             hot_standby=hot_standby,
             pg_port=self.pg_port,
             http_port=self.http_port,
+            pg_version=self.env.pg_version,
             pageserver_id=pageserver_id,
             allow_multiple=allow_multiple,
         )
@@ -4395,8 +3894,16 @@ class Safekeeper(LogUtils):
             extra_opts = self.extra_opts

         assert self.running is False
+
+        s3_env_vars = None
+        if isinstance(self.env.safekeepers_remote_storage, S3Storage):
+            s3_env_vars = self.env.safekeepers_remote_storage.access_env_vars()
+
         self.env.neon_cli.safekeeper_start(
-            self.id, extra_opts=extra_opts, timeout_in_seconds=timeout_in_seconds
+            self.id,
+            extra_opts=extra_opts,
+            timeout_in_seconds=timeout_in_seconds,
+            extra_env_vars=s3_env_vars,
         )
         self.running = True
         # wait for wal acceptor start by checking its status
@@ -4542,7 +4049,7 @@ class Safekeeper(LogUtils):
         1) wait for remote_consistent_lsn and wal_backup_lsn on safekeeper to reach it.
         2) checkpoint timeline on safekeeper, which should remove WAL before this LSN; optionally wait for that.
         """
-        cli = self.http_client()
+        client = self.http_client()

         target_segment_file = lsn.segment_name()

@@ -4554,7 +4061,7 @@ class Safekeeper(LogUtils):
             assert all(target_segment_file <= s for s in segments)

         def are_lsns_advanced():
-            stat = cli.timeline_status(tenant_id, timeline_id)
+            stat = client.timeline_status(tenant_id, timeline_id)
             log.info(
                 f"waiting for remote_consistent_lsn and backup_lsn on sk {self.id} to reach {lsn}, currently remote_consistent_lsn={stat.remote_consistent_lsn}, backup_lsn={stat.backup_lsn}"
             )
@@ -4563,7 +4070,7 @@ class Safekeeper(LogUtils):
         # xxx: max wait is long because we might be waiting for reconnection from
         # pageserver to this safekeeper
         wait_until(30, 1, are_lsns_advanced)
-        cli.checkpoint(tenant_id, timeline_id)
+        client.checkpoint(tenant_id, timeline_id)
         if wait_wal_removal:
             wait_until(30, 1, are_segments_removed)
@@ -4591,13 +4098,13 @@ class NeonBroker(LogUtils):
         timeout_in_seconds: Optional[int] = None,
     ):
         assert not self.running
-        self.env.neon_cli.broker_start(timeout_in_seconds)
+        self.env.neon_cli.storage_broker_start(timeout_in_seconds)
         self.running = True
         return self

     def stop(self):
         if self.running:
-            self.env.neon_cli.broker_stop()
+            self.env.neon_cli.storage_broker_stop()
             self.running = False
         return self
@@ -5226,10 +4733,10 @@ def flush_ep_to_pageserver(
     commit_lsn: Lsn = Lsn(0)
     # In principle, in the absence of failures, polling a single sk would be enough.
     for sk in env.safekeepers:
-        cli = sk.http_client()
+        client = sk.http_client()
         # wait until compute connections are gone
-        wait_walreceivers_absent(cli, tenant, timeline)
-        commit_lsn = max(cli.get_commit_lsn(tenant, timeline), commit_lsn)
+        wait_walreceivers_absent(client, tenant, timeline)
+        commit_lsn = max(client.get_commit_lsn(tenant, timeline), commit_lsn)

     # Note: depending on WAL filtering implementation, probably most shards
     # won't be able to reach commit_lsn (unless gaps are also ack'ed), so this
@@ -5282,7 +4789,12 @@ def fork_at_current_lsn(
     the WAL up to that LSN to arrive in the pageserver before creating the branch.
     """
     current_lsn = endpoint.safe_psql("SELECT pg_current_wal_lsn()")[0][0]
-    return env.neon_cli.create_branch(new_branch_name, ancestor_branch_name, tenant_id, current_lsn)
+    return env.create_branch(
+        new_branch_name=new_branch_name,
+        tenant_id=tenant_id,
+        ancestor_branch_name=ancestor_branch_name,
+        ancestor_start_lsn=current_lsn,
+    )


 def import_timeline_from_vanilla_postgres(
@@ -5301,9 +4813,9 @@ def import_timeline_from_vanilla_postgres(
     """

     # Take backup of the existing PostgreSQL server with pg_basebackup
-    basebackup_dir = os.path.join(test_output_dir, "basebackup")
-    base_tar = os.path.join(basebackup_dir, "base.tar")
-    wal_tar = os.path.join(basebackup_dir, "pg_wal.tar")
+    basebackup_dir = test_output_dir / "basebackup"
+    base_tar = basebackup_dir / "base.tar"
+    wal_tar = basebackup_dir / "pg_wal.tar"
     os.mkdir(basebackup_dir)
     pg_bin.run(
         [
@@ -5313,40 +4825,28 @@ def import_timeline_from_vanilla_postgres(
             "-d",
             vanilla_pg_connstr,
             "-D",
-            basebackup_dir,
+            str(basebackup_dir),
         ]
     )

     # Extract start_lsn and end_lsn from the backup manifest file
     with open(os.path.join(basebackup_dir, "backup_manifest")) as f:
         manifest = json.load(f)
-        start_lsn = manifest["WAL-Ranges"][0]["Start-LSN"]
-        end_lsn = manifest["WAL-Ranges"][0]["End-LSN"]
+        start_lsn = Lsn(manifest["WAL-Ranges"][0]["Start-LSN"])
+        end_lsn = Lsn(manifest["WAL-Ranges"][0]["End-LSN"])

     # Import the backup tarballs into the pageserver
-    env.neon_cli.raw_cli(
-        [
-            "timeline",
-            "import",
-            "--tenant-id",
-            str(tenant_id),
-            "--timeline-id",
-            str(timeline_id),
-            "--branch-name",
-            branch_name,
-            "--base-lsn",
-            start_lsn,
-            "--base-tarfile",
-            base_tar,
-            "--end-lsn",
-            end_lsn,
-            "--wal-tarfile",
-            wal_tar,
-            "--pg-version",
-            env.pg_version,
-        ]
+    env.neon_cli.timeline_import(
+        tenant_id=tenant_id,
+        timeline_id=timeline_id,
+        new_branch_name=branch_name,
+        base_lsn=start_lsn,
+        base_tarfile=base_tar,
+        end_lsn=end_lsn,
+        wal_tarfile=wal_tar,
+        pg_version=env.pg_version,
     )
-    wait_for_last_record_lsn(env.pageserver.http_client(), tenant_id, timeline_id, Lsn(end_lsn))
+    wait_for_last_record_lsn(env.pageserver.http_client(), tenant_id, timeline_id, end_lsn)


 def last_flush_lsn_upload(
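# The backup manifest stores LSNs as strings like "0/2000028"; wrapping them in
# Lsn, as the new code above does, makes them comparable and type-checked (sketch):
#
#     >>> Lsn("0/2000028") > Lsn("0/16B5A50")
#     True
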
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import Any, List, Tuple

 from fixtures.common_types import TenantId, TimelineId
-from fixtures.neon_fixtures import NeonEnv, Pagectl
+from fixtures.neon_fixtures import NeonEnv
 from fixtures.pageserver.common_types import (
     InvalidFileName,
     parse_layer_file_name,
@@ -35,7 +35,7 @@ def duplicate_one_tenant(env: NeonEnv, template_tenant: TenantId, new_tenant: Te
         for file in tl.iterdir():
             shutil.copy2(file, dst_tl_dir)
             if "__" in file.name:
-                Pagectl(env).raw_cli(
+                env.pagectl.raw_cli(
                     [
                         "layer",
                         "rewrite-summary",