Mirror of https://github.com/neondatabase/neon.git
Synced 2026-02-05 03:30:36 +00:00

Compare commits: release-45...http2 (75 commits)
| SHA1 |
|---|
| 24306129f7 |
| 3e4265d706 |
| 923017af8c |
| 80186412a9 |
| 9ab91b42eb |
| 7061c5dc76 |
| b8312a1ec7 |
| 2e6ddc94a4 |
| e8c787810a |
| 3b29bd3e4f |
| aafe79873c |
| eae74383c1 |
| 8b657a1481 |
| 42613d4c30 |
| 7f828890cf |
| 1eb30b40af |
| 8551a61014 |
| 087526b81b |
| 915fba146d |
| da7a7c867e |
| 551f0cc097 |
| a84935d266 |
| 3ee981889f |
| fc66ba43c4 |
| 544284cce0 |
| 71beabf82d |
| 76372ce002 |
| 4e1b0b84eb |
| f94abbab95 |
| 4b9b4c2c36 |
| 8186f6b6f9 |
| 90e0219b29 |
| 4b6004e8c9 |
| 9bf7664049 |
| d5e3434371 |
| 66c52a629a |
| 8a646cb750 |
| a4ac8e26e8 |
| b3a681d121 |
| b5ed6f22ae |
| d1c0232e21 |
| a41c4122e3 |
| 7de829e475 |
| 3c560d27a8 |
| d260426a14 |
| f3b5db1443 |
| 18e9208158 |
| 7662df6ca0 |
| c119af8ddd |
| a2e083ebe0 |
| 73a944205b |
| 34ebfbdd6f |
| ef7c9c2ccc |
| 6c79e12630 |
| 753d97bd77 |
| edc962f1d7 |
| 65b4e6e7d6 |
| 17b256679b |
| 673a865055 |
| fb518aea0d |
| 42f41afcbd |
| f71110383c |
| ae3eaf9995 |
| aa9f1d4b69 |
| 946c6a0006 |
| ce13281d54 |
| 4e1d16f311 |
| 091a0cda9d |
| ea9fad419e |
| e92c9f42c0 |
| aaaa39d9f5 |
| e79a19339c |
| dbd36e40dc |
| 90ef48aab8 |
| 9a43c04a19 |
Cargo.lock (generated): 213 changed lines
@@ -30,6 +30,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd7d5a2cecb58716e47d67d5703a249964b14c7be1ec3cad3affc295b2d1c35d"
dependencies = [
 "cfg-if",
 "const-random",
 "getrandom 0.2.11",
 "once_cell",
 "version_check",
 "zerocopy",

@@ -50,6 +52,12 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"

[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"

[[package]]
name = "android_system_properties"
version = "0.1.5"

@@ -247,6 +255,12 @@ dependencies = [
 "syn 2.0.32",
]

[[package]]
name = "atomic"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba"

[[package]]
name = "atomic-polyfill"
version = "1.0.2"

@@ -1011,17 +1025,17 @@ dependencies = [

[[package]]
name = "chrono"
version = "0.4.24"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e3c5919066adf22df73762e50cffcde3a758f2a848b113b586d1f86728b673b"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
 "android-tzdata",
 "iana-time-zone",
 "js-sys",
 "num-integer",
 "num-traits",
 "serde",
 "wasm-bindgen",
 "winapi",
 "windows-targets 0.48.0",
]

[[package]]

@@ -1120,6 +1134,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7"

[[package]]
name = "combine"
version = "4.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35ed6e9d84f0b51a7f52daf1c7d71dd136fd7a3f41a8462b8cdb8c78d920fad4"
dependencies = [
 "bytes",
 "futures-core",
 "memchr",
 "pin-project-lite",
 "tokio",
 "tokio-util",
]

[[package]]
name = "comfy-table"
version = "6.1.4"

@@ -2361,19 +2389,6 @@ dependencies = [
 "tokio-native-tls",
]

[[package]]
name = "hyper-tungstenite"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
dependencies = [
 "hyper",
 "pin-project-lite",
 "tokio",
 "tokio-tungstenite",
 "tungstenite",
]

[[package]]
name = "iana-time-zone"
version = "0.1.56"

@@ -2475,6 +2490,12 @@ dependencies = [
 "web-sys",
]

[[package]]
name = "integer-encoding"
version = "3.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02"

[[package]]
name = "io-lifetimes"
version = "1.0.11"

@@ -2838,6 +2859,19 @@ dependencies = [
 "winapi",
]

[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
 "num-complex",
 "num-integer",
 "num-iter",
 "num-rational",
 "num-traits",
]

[[package]]
name = "num-bigint"
version = "0.4.3"

@@ -2849,6 +2883,15 @@ dependencies = [
 "num-traits",
]

[[package]]
name = "num-complex"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
dependencies = [
 "num-traits",
]

[[package]]
name = "num-integer"
version = "0.1.45"

@@ -2859,6 +2902,28 @@ dependencies = [
 "num-traits",
]

[[package]]
name = "num-iter"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
dependencies = [
 "autocfg",
 "num-integer",
 "num-traits",
]

[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
 "autocfg",
 "num-integer",
 "num-traits",
]

[[package]]
name = "num-traits"
version = "0.2.15"

@@ -3081,6 +3146,15 @@ dependencies = [
 "tokio-stream",
]

[[package]]
name = "ordered-float"
version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c"
dependencies = [
 "num-traits",
]

[[package]]
name = "ordered-multimap"
version = "0.7.1"

@@ -3124,6 +3198,7 @@ name = "pagebench"
version = "0.1.0"
dependencies = [
 "anyhow",
 "camino",
 "clap",
 "futures",
 "hdrhistogram",

@@ -3136,6 +3211,7 @@ dependencies = [
 "serde",
 "serde_json",
 "tokio",
 "tokio-util",
 "tracing",
 "utils",
 "workspace_hack",

@@ -3339,6 +3415,35 @@ dependencies = [
 "windows-targets 0.48.0",
]

[[package]]
name = "parquet"
version = "49.0.0"
source = "git+https://github.com/neondatabase/arrow-rs?branch=neon-fix-bugs#8a0bc58aa67b98aabbd8eee7c6ca4281967ff9e9"
dependencies = [
 "ahash",
 "bytes",
 "chrono",
 "hashbrown 0.14.0",
 "num",
 "num-bigint",
 "paste",
 "seq-macro",
 "thrift",
 "twox-hash",
 "zstd",
]

[[package]]
name = "parquet_derive"
version = "49.0.0"
source = "git+https://github.com/neondatabase/arrow-rs?branch=neon-fix-bugs#8a0bc58aa67b98aabbd8eee7c6ca4281967ff9e9"
dependencies = [
 "parquet",
 "proc-macro2",
 "quote",
 "syn 2.0.32",
]

[[package]]
name = "password-hash"
version = "0.5.0"

@@ -3762,6 +3867,8 @@ dependencies = [
 "base64 0.13.1",
 "bstr",
 "bytes",
 "camino",
 "camino-tempfile",
 "chrono",
 "clap",
 "consumption_metrics",

@@ -3775,7 +3882,6 @@ dependencies = [
 "hostname",
 "humantime",
 "hyper",
 "hyper-tungstenite",
 "ipnet",
 "itertools",
 "md5",

@@ -3784,6 +3890,8 @@ dependencies = [
 "once_cell",
 "opentelemetry",
 "parking_lot 0.12.1",
 "parquet",
 "parquet_derive",
 "pbkdf2",
 "pin-project-lite",
 "postgres-native-tls",

@@ -3793,7 +3901,9 @@ dependencies = [
 "prometheus",
 "rand 0.8.5",
 "rcgen",
 "redis",
 "regex",
 "remote_storage",
 "reqwest",
 "reqwest-middleware",
 "reqwest-retry",

@@ -3817,11 +3927,13 @@ dependencies = [
 "tokio-postgres",
 "tokio-postgres-rustls",
 "tokio-rustls",
 "tokio-tungstenite",
 "tokio-util",
 "tracing",
 "tracing-opentelemetry",
 "tracing-subscriber",
 "tracing-utils",
 "tungstenite",
 "url",
 "utils",
 "uuid",

@@ -3954,6 +4066,32 @@ dependencies = [
 "yasna",
]

[[package]]
name = "redis"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd"
dependencies = [
 "async-trait",
 "bytes",
 "combine",
 "futures-util",
 "itoa",
 "percent-encoding",
 "pin-project-lite",
 "rustls",
 "rustls-native-certs",
 "rustls-pemfile",
 "rustls-webpki 0.101.7",
 "ryu",
 "sha1_smol",
 "socket2 0.4.9",
 "tokio",
 "tokio-rustls",
 "tokio-util",
 "url",
]

[[package]]
name = "redox_syscall"
version = "0.2.16"

@@ -4682,6 +4820,12 @@ dependencies = [
 "uuid",
]

[[package]]
name = "seq-macro"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"

[[package]]
name = "serde"
version = "1.0.183"

@@ -4804,6 +4948,12 @@ dependencies = [
 "digest",
]

[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"

[[package]]
name = "sha2"
version = "0.10.6"

@@ -5202,6 +5352,17 @@ dependencies = [
 "once_cell",
]

[[package]]
name = "thrift"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09"
dependencies = [
 "byteorder",
 "integer-encoding",
 "ordered-float",
]

[[package]]
name = "time"
version = "0.3.21"

@@ -5746,6 +5907,16 @@ dependencies = [
 "utf-8",
]

[[package]]
name = "twox-hash"
version = "1.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
dependencies = [
 "cfg-if",
 "static_assertions",
]

[[package]]
name = "typenum"
version = "1.16.0"

@@ -5923,10 +6094,11 @@ dependencies = [

[[package]]
name = "uuid"
version = "1.3.3"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "345444e32442451b267fc254ae85a209c64be56d2890e601a0c37ff0c3c5ecd2"
checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560"
dependencies = [
 "atomic",
 "getrandom 0.2.11",
 "serde",
]

@@ -6422,6 +6594,7 @@ dependencies = [
 "num-integer",
 "num-traits",
 "once_cell",
 "parquet",
 "prost",
 "rand 0.8.5",
 "regex",
Cargo.toml: 12 changed lines
@@ -89,7 +89,6 @@ http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"

@@ -107,11 +106,14 @@ opentelemetry = "0.19.0"
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.11.0"
parking_lot = "0.12"
parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
parquet_derive = "49.0.0"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }

@@ -153,6 +155,7 @@ tokio-rustls = "0.24"
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
tokio-tungstenite = "0.20"
toml = "0.7"
toml_edit = "0.19"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}

@@ -160,8 +163,9 @@ tracing = "0.1"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
tungstenite = "0.20"
url = "2.2"
uuid = { version = "1.2", features = ["v4", "serde"] }
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"
x509-parser = "0.15"

@@ -215,6 +219,10 @@ tonic-build = "0.9"
# TODO: we should probably fork `tokio-postgres-rustls` instead.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }

# bug fixes for UUID
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }
parquet_derive = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }

################# Binary contents sections

[profile.release]
@@ -135,7 +135,7 @@ WORKDIR /home/nonroot

# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.74.0
ENV RUSTC_VERSION=1.75.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
@@ -350,7 +350,7 @@ fn main() -> Result<()> {

    // Wait for the child Postgres process forever. In this state Ctrl+C will
    // propagate to Postgres and it will be shut down as well.
    if let Some(mut pg) = pg {
    if let Some((mut pg, logs_handle)) = pg {
        // Startup is finished, exit the startup tracing span
        drop(startup_context_guard);

@@ -358,6 +358,12 @@ fn main() -> Result<()> {
            .wait()
            .expect("failed to start waiting on Postgres process");
        PG_PID.store(0, Ordering::SeqCst);

        // Process has exited, so we can join the logs thread.
        let _ = logs_handle
            .join()
            .map_err(|e| tracing::error!("log thread panicked: {:?}", e));

        info!("Postgres exited with code {}, shutting down", ecode);
        exit_code = ecode.code()
    }
@@ -31,6 +31,7 @@ use utils::measured_stream::MeasuredReader;
use remote_storage::{DownloadError, RemotePath};

use crate::checker::create_availability_check_data;
use crate::logger::inlinify;
use crate::pg_helpers::*;
use crate::spec::*;
use crate::sync_sk::{check_if_synced, ping_safekeeper};

@@ -279,7 +280,7 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()>
        $$;"#,
        roles_decl, database_decl,
    );
    info!("Neon superuser created:\n{}", &query);
    info!("Neon superuser created:\n{}", inlinify(&query));
    client
        .simple_query(&query)
        .map_err(|e| anyhow::anyhow!(e).context(query))?;

@@ -495,7 +496,7 @@ impl ComputeNode {
    pub fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
        let start_time = Utc::now();

        let sync_handle = maybe_cgexec(&self.pgbin)
        let mut sync_handle = maybe_cgexec(&self.pgbin)
            .args(["--sync-safekeepers"])
            .env("PGDATA", &self.pgdata) // we cannot use -D in this mode
            .envs(if let Some(storage_auth_token) = &storage_auth_token {

@@ -504,18 +505,30 @@ impl ComputeNode {
                vec![]
            })
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .expect("postgres --sync-safekeepers failed to start");
        SYNC_SAFEKEEPERS_PID.store(sync_handle.id(), Ordering::SeqCst);

        // `postgres --sync-safekeepers` will print all log output to stderr and
        // final LSN to stdout. So we pipe only stdout, while stderr will be automatically
        // redirected to the caller output.
        // final LSN to stdout. So we leave stdout to collect LSN, while stderr logs
        // will be collected in a child thread.
        let stderr = sync_handle
            .stderr
            .take()
            .expect("stderr should be captured");
        let logs_handle = handle_postgres_logs(stderr);

        let sync_output = sync_handle
            .wait_with_output()
            .expect("postgres --sync-safekeepers failed");
        SYNC_SAFEKEEPERS_PID.store(0, Ordering::SeqCst);

        // Process has exited, so we can join the logs thread.
        let _ = logs_handle
            .join()
            .map_err(|e| tracing::error!("log thread panicked: {:?}", e));

        if !sync_output.status.success() {
            anyhow::bail!(
                "postgres --sync-safekeepers exited with non-zero status: {}. stdout: {}",

@@ -652,11 +665,12 @@ impl ComputeNode {

    /// Start Postgres as a child process and manage DBs/roles.
    /// After that this will hang waiting on the postmaster process to exit.
    /// Returns a handle to the child process and a handle to the logs thread.
    #[instrument(skip_all)]
    pub fn start_postgres(
        &self,
        storage_auth_token: Option<String>,
    ) -> Result<std::process::Child> {
    ) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
        let pgdata_path = Path::new(&self.pgdata);

        // Run postgres as a child process.

@@ -667,13 +681,18 @@ impl ComputeNode {
            } else {
                vec![]
            })
            .stderr(Stdio::piped())
            .spawn()
            .expect("cannot start postgres process");
        PG_PID.store(pg.id(), Ordering::SeqCst);

        // Start a thread to collect logs from stderr.
        let stderr = pg.stderr.take().expect("stderr should be captured");
        let logs_handle = handle_postgres_logs(stderr);

        wait_for_postgres(&mut pg, pgdata_path)?;

        Ok(pg)
        Ok((pg, logs_handle))
    }

    /// Do initial configuration of the already started Postgres.

@@ -818,7 +837,10 @@ impl ComputeNode {
    }

    #[instrument(skip_all)]
    pub fn start_compute(&self, extension_server_port: u16) -> Result<std::process::Child> {
    pub fn start_compute(
        &self,
        extension_server_port: u16,
    ) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
        let compute_state = self.state.lock().unwrap().clone();
        let pspec = compute_state.pspec.as_ref().expect("spec must be set");
        info!(

@@ -889,7 +911,7 @@ impl ComputeNode {
        self.prepare_pgdata(&compute_state, extension_server_port)?;

        let start_time = Utc::now();
        let pg = self.start_postgres(pspec.storage_auth_token.clone())?;
        let pg_process = self.start_postgres(pspec.storage_auth_token.clone())?;

        let config_time = Utc::now();
        if pspec.spec.mode == ComputeMode::Primary && !pspec.spec.skip_pg_catalog_updates {

@@ -939,7 +961,7 @@ impl ComputeNode {
        };
        info!(?metrics, "compute start finished");

        Ok(pg)
        Ok(pg_process)
    }

    // Look for core dumps and collect backtraces.
@@ -38,3 +38,9 @@ pub fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {

    Ok(())
}

/// Replace all newline characters with a special character to make it
/// easier to grep for log messages.
pub fn inlinify(s: &str) -> String {
    s.replace('\n', "\u{200B}")
}
@@ -6,12 +6,15 @@ use std::io::{BufRead, BufReader};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Child;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};

use anyhow::{bail, Result};
use ini::Ini;
use notify::{RecursiveMode, Watcher};
use postgres::{Client, Transaction};
use tokio::io::AsyncBufReadExt;
use tokio::time::timeout;
use tokio_postgres::NoTls;
use tracing::{debug, error, info, instrument};

@@ -426,3 +429,72 @@ pub async fn tune_pgbouncer(

    Ok(())
}

/// Spawn a thread that will read Postgres logs from `stderr`, join multiline logs
/// and send them to the logger. In the future we may also want to add context to
/// these logs.
pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<()> {
    std::thread::spawn(move || {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build tokio runtime");

        let res = runtime.block_on(async move {
            let stderr = tokio::process::ChildStderr::from_std(stderr)?;
            handle_postgres_logs_async(stderr).await
        });
        if let Err(e) = res {
            tracing::error!("error while processing postgres logs: {}", e);
        }
    })
}

/// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
/// - next line starts with timestamp
/// - EOF
/// - no new lines were written for the last second
async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
    let mut lines = tokio::io::BufReader::new(stderr).lines();
    let timeout_duration = Duration::from_secs(1);
    let ts_regex =
        regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");

    let mut buf = vec![];
    loop {
        let next_line = timeout(timeout_duration, lines.next_line()).await;

        // we should flush lines from the buffer if we cannot continue reading multiline message
        let should_flush_buf = match next_line {
            // Flushing if new line starts with timestamp
            Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
            // Flushing on EOF, timeout or error
            _ => true,
        };

        if !buf.is_empty() && should_flush_buf {
            // join multiline message into a single line, separated by unicode Zero Width Space.
            // "PG:" suffix is used to distinguish postgres logs from other logs.
            let combined = format!("PG:{}\n", buf.join("\u{200B}"));
            buf.clear();

            // sync write to stderr to avoid interleaving with other logs
            use std::io::Write;
            let res = std::io::stderr().lock().write_all(combined.as_bytes());
            if let Err(e) = res {
                tracing::error!("error while writing to stderr: {}", e);
            }
        }

        // if not timeout, append line to the buffer
        if next_line.is_ok() {
            match next_line?? {
                Some(line) => buf.push(line),
                // EOF
                None => break,
            };
        }
    }

    Ok(())
}
@@ -9,6 +9,7 @@ use reqwest::StatusCode;
use tracing::{error, info, info_span, instrument, span_enabled, warn, Level};

use crate::config;
use crate::logger::inlinify;
use crate::params::PG_HBA_ALL_MD5;
use crate::pg_helpers::*;

@@ -662,7 +663,11 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) ->
        $$;"
        .to_string();

        info!("grant query for db {} : {}", &db.name, &grant_query);
        info!(
            "grant query for db {} : {}",
            &db.name,
            inlinify(&grant_query)
        );
        db_client.simple_query(&grant_query)?;
    }
@@ -6,11 +6,11 @@
//! rely on `neon_local` to set up the environment for each test.
//!
use anyhow::{anyhow, bail, Context, Result};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command, ValueEnum};
use compute_api::spec::ComputeMode;
use control_plane::attachment_service::AttachmentService;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::local_env::LocalEnv;
use control_plane::local_env::{InitForceMode, LocalEnv};
use control_plane::pageserver::{PageServerNode, PAGESERVER_REMOTE_STORAGE_DIR};
use control_plane::safekeeper::SafekeeperNode;
use control_plane::tenant_migration::migrate_tenant;

@@ -338,7 +338,7 @@ fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {

    let mut env =
        LocalEnv::parse_config(&toml_file).context("Failed to create neon configuration")?;
    let force = init_match.get_flag("force");
    let force = init_match.get_one("force").expect("we set a default value");
    env.init(pg_version, force)
        .context("Failed to initialize neon repository")?;

@@ -1266,9 +1266,15 @@ fn cli() -> Command {
        .required(false);

    let force_arg = Arg::new("force")
        .value_parser(value_parser!(bool))
        .value_parser(value_parser!(InitForceMode))
        .long("force")
        .action(ArgAction::SetTrue)
        .default_value(
            InitForceMode::MustNotExist
                .to_possible_value()
                .unwrap()
                .get_name()
                .to_owned(),
        )
        .help("Force initialization even if the repository is not empty")
        .required(false);
@@ -5,6 +5,7 @@

use anyhow::{bail, ensure, Context};

use clap::ValueEnum;
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};

@@ -162,6 +163,31 @@ impl Default for SafekeeperConf {
    }
}

#[derive(Clone, Copy)]
pub enum InitForceMode {
    MustNotExist,
    EmptyDirOk,
    RemoveAllContents,
}

impl ValueEnum for InitForceMode {
    fn value_variants<'a>() -> &'a [Self] {
        &[
            Self::MustNotExist,
            Self::EmptyDirOk,
            Self::RemoveAllContents,
        ]
    }

    fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
        Some(clap::builder::PossibleValue::new(match self {
            InitForceMode::MustNotExist => "must-not-exist",
            InitForceMode::EmptyDirOk => "empty-dir-ok",
            InitForceMode::RemoveAllContents => "remove-all-contents",
        }))
    }
}

impl SafekeeperConf {
    /// Compute is served by port on which only tenant scoped tokens allowed, if
    /// it is configured.

@@ -384,7 +410,7 @@ impl LocalEnv {
    //
    // Initialize a new Neon repository
    //
    pub fn init(&mut self, pg_version: u32, force: bool) -> anyhow::Result<()> {
    pub fn init(&mut self, pg_version: u32, force: &InitForceMode) -> anyhow::Result<()> {
        // check if config already exists
        let base_path = &self.base_data_dir;
        ensure!(

@@ -393,25 +419,34 @@ impl LocalEnv {
        );

        if base_path.exists() {
            if force {
                println!("removing all contents of '{}'", base_path.display());
                // instead of directly calling `remove_dir_all`, we keep the original dir but removing
                // all contents inside. This helps if the developer symbol links another directory (i.e.,
                // S3 local SSD) to the `.neon` base directory.
                for entry in std::fs::read_dir(base_path)? {
                    let entry = entry?;
                    let path = entry.path();
                    if path.is_dir() {
                        fs::remove_dir_all(&path)?;
                    } else {
                        fs::remove_file(&path)?;
            match force {
                InitForceMode::MustNotExist => {
                    bail!(
                        "directory '{}' already exists. Perhaps already initialized?",
                        base_path.display()
                    );
                }
                InitForceMode::EmptyDirOk => {
                    if let Some(res) = std::fs::read_dir(base_path)?.next() {
                        res.context("check if directory is empty")?;
                        anyhow::bail!("directory not empty: {base_path:?}");
                    }
                }
                InitForceMode::RemoveAllContents => {
                    println!("removing all contents of '{}'", base_path.display());
                    // instead of directly calling `remove_dir_all`, we keep the original dir but removing
                    // all contents inside. This helps if the developer symbol links another directory (i.e.,
                    // S3 local SSD) to the `.neon` base directory.
                    for entry in std::fs::read_dir(base_path)? {
                        let entry = entry?;
                        let path = entry.path();
                        if path.is_dir() {
                            fs::remove_dir_all(&path)?;
                        } else {
                            fs::remove_file(&path)?;
                        }
                    }
                }
            } else {
                bail!(
                    "directory '{}' already exists. Perhaps already initialized? (Hint: use --force to remove all contents)",
                    base_path.display()
                );
            }
        }
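For illustration, a hypothetical caller of the new signature (the variants map to the `--force` CLI names `must-not-exist`, `empty-dir-ok`, and `remove-all-contents` defined above):

```rust
// Sketch only: initialize a repository, wiping any existing contents of the
// base directory. `env` is a LocalEnv parsed from config; `pg_version` e.g. 15.
env.init(pg_version, &InitForceMode::RemoveAllContents)
    .context("Failed to initialize neon repository")?;
```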
docs/rfcs/030-vectored-timeline-get.md (new file): 142 lines
@@ -0,0 +1,142 @@

# Vectored Timeline Get

Created on: 2024-01-02
Author: Christian Schwarz

# Summary

A brief RFC / GitHub Epic describing a vectored version of the `Timeline::get` method that is at the heart of Pageserver.

# Motivation

During basebackup, we issue many `Timeline::get` calls for SLRU pages that are *adjacent* in key space.
For an example, see
https://github.com/neondatabase/neon/blob/5c88213eaf1b1e29c610a078d0b380f69ed49a7e/pageserver/src/basebackup.rs#L281-L302.

Each of these `Timeline::get` calls must traverse the layer map to gather reconstruct data (`Timeline::get_reconstruct_data`) for the requested page number (`blknum` in the example).
For each layer visited by layer map traversal, we do a `DiskBtree` point lookup.
If it's negative (no entry), we resume layer map traversal.
If it's positive, we collect the result in our reconstruct data bag.
If the reconstruct data bag contents suffice to reconstruct the page, we're done with `get_reconstruct_data` and move on to walredo.
Otherwise, we resume layer map traversal.

Doing this many `Timeline::get` calls is quite inefficient because:

1. We do the layer map traversal repeatedly, even if, e.g., all the data sits in the same image layer at the bottom of the stack.
2. We may visit many DiskBtree inner pages multiple times for point lookups of different keys.
   This is likely particularly bad for L0s, which span the whole key space and hence must be visited by layer map traversal, but
   may not contain the data we're looking for.
3. Anecdotally, keys adjacent in keyspace and written simultaneously also end up physically adjacent in the layer files [^1].
   So, to provide the reconstruct data for N adjacent keys, we would actually only _need_ to issue a single large read to the filesystem, instead of the N reads we currently do.
   The filesystem, in turn, ideally stores the layer file physically contiguously, so our large read will turn into one IOP toward the disk.

[^1]: https://www.notion.so/neondatabase/Christian-Investigation-Slow-Basebackups-Early-2023-12-34ea5c7dcdc1485d9ac3731da4d2a6fc?pvs=4#15ee4e143392461fa64590679c8f54c9
# Solution

We should have a vectored aka batched aka scatter-gather style alternative API for `Timeline::get`. Having such an API unlocks:

* more efficient basebackup
* batched IO during compaction (useful for strides of unchanged pages)
* page_service: expose vectored get_page_at_lsn for compute (=> good for seqscan / prefetch)
* if [on-demand SLRU downloads](https://github.com/neondatabase/neon/pull/6151) land before vectored Timeline::get, on-demand SLRU downloads will still benefit from this API

# DoD

There is a new variant of `Timeline::get`, called `Timeline::get_vectored`.
It takes as arguments an `lsn: Lsn` and a `src: &[KeyVec]` where `struct KeyVec { base: Key, count: usize }`.

It is up to the implementor to figure out a suitable and efficient way to return the reconstructed page images.
It is sufficient to simply return a `Vec<Bytes>`, but likely more efficient solutions can be found after studying all the callers of `Timeline::get`.

Functionally, the behavior of `Timeline::get_vectored` is equivalent to

```rust
let mut keys_iter: impl Iterator<Item=Key>
    = src.map(|KeyVec{ base, count }| (base..base+count)).flatten();
let mut out = Vec::new();
for key in keys_iter {
    let data = Timeline::get(key, lsn)?;
    out.push(data);
}
return out;
```
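To make the intended call pattern concrete, here is a sketch of how a basebackup-style caller might batch adjacent SLRU blocks; `slru_block_to_key` and the exact return type are assumptions for illustration, not part of the proposal:

```rust
// Sketch only: batch `nblocks` adjacent SLRU blocks, starting at `first_blknum`,
// into one vectored request instead of `nblocks` separate Timeline::get calls.
// `slru_block_to_key` is a hypothetical key-mapping helper.
let base: Key = slru_block_to_key(slru_kind, segno, first_blknum);
let batch = [KeyVec { base, count: nblocks }];
// One layer-map traversal and (ideally) one large read serve the whole batch.
let pages: Vec<Bytes> = timeline.get_vectored(lsn, &batch)?;
assert_eq!(pages.len(), nblocks);
```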
However, unlike the naive loop above, an ideal solution will

* Visit each `struct Layer` at most once.
* For each visited layer, call `Layer::get_value_reconstruct_data` at most once.
  * This means, read each `DiskBtree` page at most once.
* Facilitate merging of the reads we issue to the OS and eventually NVMe.

Each of these items above represents a significant amount of work.

## Performance

Ideally, the **base performance** of a vectored get of a single page should be identical to the current `Timeline::get`.
A reasonable constant overhead over the current `Timeline::get` is acceptable.

The performance improvement for the vectored use case is demonstrated in some way, e.g., using the `pagebench` basebackup benchmark against a tenant with a lot of SLRU segments.

# Implementation

High-level set of tasks / changes to be made:

- **Get clarity on API**:
  - Define a naive `Timeline::get_vectored` implementation & adopt it across pageserver.
  - The tricky thing here will be the return type (e.g. `Vec<Bytes>` vs `impl Stream`).
  - Start with something simple to explore the different usages of the API.
    Then iterate with peers until we have something that is good enough.
- **Vectored Layer Map traversal**
  - Vectored `LayerMap::search` (take 1 LSN and N `Key`s instead of just 1 LSN and 1 `Key`).
  - Refactor `Timeline::get_reconstruct_data` to hold & return state for N `Key`s instead of 1.
    The slightly tricky part here is what to do about `cont_lsn` [after we've found some reconstruct data for some keys](https://github.com/neondatabase/neon/blob/d066dad84b076daf3781cdf9a692098889d3974e/pageserver/src/tenant/timeline.rs#L2378-L2385) but need more.
    Likely we'll need to keep track of `cont_lsn` per key and continue the next iteration at the `max(cont_lsn)` of all keys that still need data; a simplified sketch of this bookkeeping follows after this list.
- **Vectored `Layer::get_value_reconstruct_data` / `DiskBtree`**
  - Current code calls it [here](https://github.com/neondatabase/neon/blob/d066dad84b076daf3781cdf9a692098889d3974e/pageserver/src/tenant/timeline.rs#L2378-L2384).
  - Delta layers use `DiskBtreeReader::visit()` to collect the `(offset,len)` pairs for delta record blobs to load.
  - Image layers use `DiskBtreeReader::get` to get the offset of the image blob to load. Underneath, that's just a `::visit()` call.
  - What needs to happen to `DiskBtree::visit()`?
    * Minimally:
      * Take a single `KeyVec` instead of a single `Key` as argument, i.e., take a single contiguous key range to visit.
      * Change the visit code to invoke the callback for all values in the `KeyVec`'s key range.
      * This should be good enough for what we've seen when investigating basebackup slowness, because there, the key ranges are contiguous.
    * Ideally:
      * Take a `&[KeyVec]`, sort it;
      * during Btree traversal, peek at the next `KeyVec` range to determine whether we need to descend or back out.
      * NB: this should be a straight-forward extension of the minimal solution above, as we'll already be checking for "is there more key range in the requested `KeyVec`".
- **Facilitate merging of the reads we issue to the OS and eventually NVMe.**
  - The `DiskBtree::visit` produces a set of offsets which we then read from a `VirtualFile`:
    - [Delta layer reads](https://github.com/neondatabase/neon/blob/292281c9dfb24152b728b1a846cc45105dac7fe0/pageserver/src/tenant/storage_layer/delta_layer.rs#L772-L804)
      - We hit (and rely) on `PageCache` and `VirtualFile` here (not great under pressure).
    - [Image layer reads](https://github.com/neondatabase/neon/blob/292281c9dfb24152b728b1a846cc45105dac7fe0/pageserver/src/tenant/storage_layer/image_layer.rs#L429-L435)
  - What needs to happen is the **vectorization of the `blob_io` interface and then the `VirtualFile` API**.
  - That is tricky because
    - the `VirtualFile` API, which sits underneath `blob_io`, is being touched by ongoing [io_uring work](https://github.com/neondatabase/neon/pull/5824)
    - there's the question of how IO buffers will be managed; currently this area relies heavily on `PageCache`, but there's controversy around the future of `PageCache`.
      - The guiding principle here should be to avoid coupling this work to the `PageCache`.
      - I.e., treat `PageCache` as an extra hop in the I/O chain, rather than as an integral part of buffer management.

Let's see how we can improve by doing the first three items in the above list first, then revisit.
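As promised above, a simplified sketch of the per-key traversal state; the types are illustrative, not the actual pageserver definitions:

```rust
// Sketch only: one entry per requested key during vectored get_reconstruct_data.
enum KeyState {
    Complete,                    // reconstruct data bag suffices; ready for walredo
    NeedsMore { cont_lsn: Lsn }, // traversal must continue below this LSN
}

// Per the proposal: the next layer-map iteration resumes at the maximum
// cont_lsn among the keys that still need reconstruct data.
fn next_cont_lsn(states: &[KeyState]) -> Option<Lsn> {
    states
        .iter()
        .filter_map(|s| match s {
            KeyState::NeedsMore { cont_lsn } => Some(*cont_lsn),
            KeyState::Complete => None,
        })
        .max()
}
```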
## Rollout / Feature Flags

No feature flags are required for this epic.

At the end of this epic, `Timeline::get` forwards to `Timeline::get_vectored`, i.e., it's an all-or-nothing type of change.

It is encouraged to deliver this feature incrementally, i.e., do many small PRs over multiple weeks.
That will help isolate performance regressions across weekly releases.
# Interaction With Sharding

[Sharding](https://github.com/neondatabase/neon/pull/5432) splits up the key space, see functions `is_key_local` / `key_to_shard_number`.

Just as with `Timeline::get`, callers of `Timeline::get_vectored` are responsible for ensuring that they only ask for blocks of the given `struct Timeline`'s shard.

Given that this is already the case, there shouldn't be significant interaction/interference with sharding.

However, let's have a safety check for this constraint (error or assertion), because there are currently few affordances at the higher layers of Pageserver for sharding<=>keyspace interaction.
For example, `KeySpace` is not broken up by shard stripe, so if someone naively converted the compaction code to issue a vectored get for a keyspace range, it would violate this constraint.
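A sketch of what such a safety check could look like at the top of `Timeline::get_vectored`; the `shard_identity.is_key_local` wiring is an assumption based on the functions named above:

```rust
// Sketch only: reject batches containing keys this shard does not own.
// `shard_identity` / `is_key_local` as referenced above; `Key::next` steps
// to the next key in keyspace order.
for kv in src {
    let mut key = kv.base;
    for _ in 0..kv.count {
        debug_assert!(
            self.shard_identity.is_key_local(&key),
            "vectored get for key outside this shard: {key:?}"
        );
        key = key.next();
    }
}
```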
@@ -141,6 +141,7 @@ impl Key {
    }
}

#[inline(always)]
pub fn is_rel_block_key(key: &Key) -> bool {
    key.field1 == 0x00 && key.field4 != 0 && key.field6 != 0xffffffff
}
@@ -114,10 +114,12 @@ impl KeySpaceAccum {
        }
    }

    #[inline(always)]
    pub fn add_key(&mut self, key: Key) {
        self.add_range(singleton_range(key))
    }

    #[inline(always)]
    pub fn add_range(&mut self, range: Range<Key>) {
        match self.accum.as_mut() {
            Some(accum) => {
@@ -2,7 +2,7 @@ pub mod partitioning;

use std::{
    collections::HashMap,
    io::Read,
    io::{BufRead, Read},
    num::{NonZeroU64, NonZeroUsize},
    time::SystemTime,
};

@@ -813,9 +813,10 @@ impl PagestreamBeMessage {
            PagestreamBeMessage::GetPage(PagestreamGetPageResponse { page: page.into() })
        }
        Tag::Error => {
            let buf = buf.get_ref();
            let cstr = std::ffi::CStr::from_bytes_until_nul(buf)?;
            let rust_str = cstr.to_str()?;
            let mut msg = Vec::new();
            buf.read_until(0, &mut msg)?;
            let cstring = std::ffi::CString::from_vec_with_nul(msg)?;
            let rust_str = cstring.to_str()?;
            PagestreamBeMessage::Error(PagestreamErrorResponse {
                message: rust_str.to_owned(),
            })
@@ -15,6 +15,10 @@ use tracing::*;
/// specified time (in milliseconds). The main difference is that we use async
/// tokio sleep function. Another difference is that we print lines to the log,
/// which can be useful in tests to check that the failpoint was hit.
///
/// Optionally pass a cancellation token, and this failpoint will drop out of
/// its sleep when the cancellation token fires. This is useful for testing
/// cases where we would like to block something, but test its clean shutdown behavior.
#[macro_export]
macro_rules! __failpoint_sleep_millis_async {
    ($name:literal) => {{

@@ -30,6 +34,24 @@ macro_rules! __failpoint_sleep_millis_async {
            $crate::failpoint_support::failpoint_sleep_helper($name, duration_str).await
        }
    }};
    ($name:literal, $cancel:expr) => {{
        // If the failpoint is used with a "return" action, set should_sleep to the
        // returned value (as string). Otherwise it's set to None.
        let should_sleep = (|| {
            ::fail::fail_point!($name, |x| x);
            ::std::option::Option::None
        })();

        // Sleep if the action was a returned value
        if let ::std::option::Option::Some(duration_str) = should_sleep {
            $crate::failpoint_support::failpoint_sleep_cancellable_helper(
                $name,
                duration_str,
                $cancel,
            )
            .await
        }
    }};
}
pub use __failpoint_sleep_millis_async as sleep_millis_async;

@@ -45,6 +67,22 @@ pub async fn failpoint_sleep_helper(name: &'static str, duration_str: String) {
    tracing::info!("failpoint {:?}: sleep done", name);
}

// Helper function used by the macro. (A function has nicer scoping so we
// don't need to decorate everything with "::")
#[doc(hidden)]
pub async fn failpoint_sleep_cancellable_helper(
    name: &'static str,
    duration_str: String,
    cancel: &CancellationToken,
) {
    let millis = duration_str.parse::<u64>().unwrap();
    let d = std::time::Duration::from_millis(millis);

    tracing::info!("failpoint {:?}: sleeping for {:?}", name, d);
    tokio::time::timeout(d, cancel.cancelled()).await.ok();
    tracing::info!("failpoint {:?}: sleep done", name);
}

pub fn init() -> fail::FailScenario<'static> {
    // The failpoints lib provides support for parsing the `FAILPOINTS` env var.
    // We want non-default behavior for `exit`, though, so, we handle it separately.
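A sketch of how the new two-argument form might be used at a call site; the failpoint name, surrounding function, and module path are made up:

```rust
// Sketch only: sleeps for the duration configured via a `return(<millis>)`
// failpoint action, and wakes early when `cancel` fires. Module path assumed.
use utils::failpoint_support;

async fn before_upload(cancel: &tokio_util::sync::CancellationToken) {
    failpoint_support::sleep_millis_async!("before-upload-layer", cancel);
}
```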
@@ -446,12 +446,11 @@ impl Runner {
        if let Some(t) = self.last_upscale_request_at {
            let elapsed = t.elapsed();
            if elapsed < Duration::from_secs(1) {
                info!(
                    elapsed_millis = elapsed.as_millis(),
                    avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
                    threshold = bytes_to_mebibytes(cgroup.threshold),
                    "cgroup memory stats are high enough to upscale but too soon to forward the request, ignoring",
                );
                // *Ideally* we'd like to log here that we're ignoring the fact the
                // memory stats are too high, but in practice this can result in
                // spamming the logs with repetitive messages about ignoring the signal
                //
                // See https://github.com/neondatabase/neon/issues/5865 for more.
                continue;
            }
        }
@@ -28,14 +28,12 @@ pub enum Error {

pub type Result<T> = std::result::Result<T, Error>;

#[async_trait::async_trait]
pub trait ResponseErrorMessageExt: Sized {
pub(crate) trait ResponseErrorMessageExt: Sized {
    async fn error_from_body(self) -> Result<Self>;
}

#[async_trait::async_trait]
impl ResponseErrorMessageExt for reqwest::Response {
    async fn error_from_body(mut self) -> Result<Self> {
    async fn error_from_body(self) -> Result<Self> {
        let status = self.status();
        if !(status.is_client_error() || status.is_server_error()) {
            return Ok(self);

@@ -51,6 +49,11 @@ impl ResponseErrorMessageExt for reqwest::Response {
    }
}

pub enum ForceAwaitLogicalSize {
    Yes,
    No,
}

impl Client {
    pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
        Self {

@@ -94,11 +97,18 @@ impl Client {
        &self,
        tenant_id: TenantId,
        timeline_id: TimelineId,
        force_await_logical_size: ForceAwaitLogicalSize,
    ) -> Result<pageserver_api::models::TimelineInfo> {
        let uri = format!(
            "{}/v1/tenant/{tenant_id}/timeline/{timeline_id}",
            self.mgmt_api_endpoint
        );

        let uri = match force_await_logical_size {
            ForceAwaitLogicalSize::Yes => format!("{}?force-await-logical-size={}", uri, true),
            ForceAwaitLogicalSize::No => uri,
        };

        self.get(&uri)
            .await?
            .json()

@@ -211,4 +221,16 @@ impl Client {
            .await
            .map_err(Error::ReceiveBody)
    }

    pub async fn tenant_reset(&self, tenant_shard_id: TenantShardId) -> Result<()> {
        let uri = format!(
            "{}/v1/tenant/{}/reset",
            self.mgmt_api_endpoint, tenant_shard_id
        );
        self.request(Method::POST, &uri, ())
            .await?
            .json()
            .await
            .map_err(Error::ReceiveBody)
    }
}
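For example, a caller that wants the logical-size calculation forced before the response returns might now write (endpoint and ids are placeholders):

```rust
// Sketch only: force the pageserver to await the logical-size calculation.
// `tenant_id` / `timeline_id` are placeholders for real ids.
let client = Client::new("http://127.0.0.1:9898".to_string(), None);
let info = client
    .timeline_info(tenant_id, timeline_id, ForceAwaitLogicalSize::Yes)
    .await?;
println!("timeline info: {info:?}");
```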
@@ -8,6 +8,7 @@ license.workspace = true

[dependencies]
anyhow.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
hdrhistogram.workspace = true

@@ -18,6 +19,7 @@ serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-util.workspace = true

pageserver = { path = ".." }
pageserver_client.workspace = true
@@ -1,4 +1,5 @@
use anyhow::Context;
use pageserver_client::mgmt_api::ForceAwaitLogicalSize;
use pageserver_client::page_service::BasebackupRequest;

use utils::id::TenantTimelineId;

@@ -92,10 +93,12 @@ async fn main_impl(
    for timeline in &timelines {
        js.spawn({
            let timeline = *timeline;
            // FIXME: this triggers initial logical size calculation
            // https://github.com/neondatabase/neon/issues/6168
            let info = mgmt_api_client
                .timeline_info(timeline.tenant_id, timeline.timeline_id)
                .timeline_info(
                    timeline.tenant_id,
                    timeline.timeline_id,
                    ForceAwaitLogicalSize::No,
                )
                .await
                .unwrap();
            async move {
@@ -1,10 +1,13 @@
|
||||
use anyhow::Context;
|
||||
use camino::Utf8PathBuf;
|
||||
use futures::future::join_all;
|
||||
use pageserver::pgdatadir_mapping::key_to_rel_block;
|
||||
use pageserver::repository;
|
||||
use pageserver_api::key::is_rel_block_key;
|
||||
use pageserver_api::keyspace::KeySpaceAccum;
|
||||
use pageserver_api::models::PagestreamGetPageRequest;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
@@ -13,7 +16,7 @@ use tokio::sync::Barrier;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, instrument};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::future::Future;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::pin::Pin;
|
||||
@@ -44,6 +47,12 @@ pub(crate) struct Args {
|
||||
req_latest_probability: f64,
|
||||
#[clap(long)]
|
||||
limit_to_first_n_targets: Option<usize>,
|
||||
/// For large pageserver installations, enumerating the keyspace takes a lot of time.
|
||||
/// If specified, the specified path is used to maintain a cache of the keyspace enumeration result.
|
||||
/// The cache is tagged and auto-invalided by the tenant/timeline ids only.
|
||||
/// It doesn't get invalidated if the keyspace changes under the hood, e.g., due to new ingested data or compaction.
|
||||
#[clap(long)]
|
||||
keyspace_cache: Option<Utf8PathBuf>,
|
||||
targets: Option<Vec<TenantTimelineId>>,
|
||||
}
|
||||
|
||||
@@ -58,7 +67,7 @@ impl LiveStats {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct KeyRange {
|
||||
timeline: TenantTimelineId,
|
||||
timeline_lsn: Lsn,
|
||||
@@ -106,59 +115,107 @@ async fn main_impl(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut js = JoinSet::new();
|
||||
for timeline in &timelines {
|
||||
js.spawn({
|
||||
let mgmt_api_client = Arc::clone(&mgmt_api_client);
|
||||
let timeline = *timeline;
|
||||
async move {
|
||||
let partitioning = mgmt_api_client
|
||||
.keyspace(timeline.tenant_id, timeline.timeline_id)
|
||||
.await?;
|
||||
let lsn = partitioning.at_lsn;
|
||||
|
||||
let ranges = partitioning
|
||||
.keys
|
||||
.ranges
|
||||
.iter()
|
||||
.filter_map(|r| {
|
||||
let start = r.start;
|
||||
let end = r.end;
|
||||
// filter out non-relblock keys
|
||||
match (is_rel_block_key(&start), is_rel_block_key(&end)) {
|
||||
(true, true) => Some(KeyRange {
|
||||
timeline,
|
||||
timeline_lsn: lsn,
|
||||
start: start.to_i128(),
|
||||
end: end.to_i128(),
|
||||
}),
|
||||
(true, false) | (false, true) => {
|
||||
unimplemented!("split up range")
|
||||
#[derive(serde::Deserialize)]
|
||||
struct KeyspaceCacheDe {
|
||||
tag: Vec<TenantTimelineId>,
|
||||
data: Vec<KeyRange>,
|
||||
}
|
||||
#[derive(serde::Serialize)]
|
||||
struct KeyspaceCacheSer<'a> {
|
||||
tag: &'a [TenantTimelineId],
|
||||
data: &'a [KeyRange],
|
||||
}
|
||||
let cache = args
|
||||
.keyspace_cache
|
||||
.as_ref()
|
||||
.map(|keyspace_cache_file| {
|
||||
let contents = match std::fs::read(keyspace_cache_file) {
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
return anyhow::Ok(None);
|
||||
}
|
||||
x => x.context("read keyspace cache file")?,
|
||||
};
|
||||
let cache: KeyspaceCacheDe =
|
||||
serde_json::from_slice(&contents).context("deserialize cache file")?;
|
||||
let tag_ok = HashSet::<TenantTimelineId>::from_iter(cache.tag.into_iter())
|
||||
== HashSet::from_iter(timelines.iter().cloned());
|
||||
info!("keyspace cache file matches tag: {tag_ok}");
|
||||
anyhow::Ok(if tag_ok { Some(cache.data) } else { None })
|
||||
})
|
||||
.transpose()?
|
||||
.flatten();
|
||||
let all_ranges: Vec<KeyRange> = if let Some(cached) = cache {
|
||||
info!("using keyspace cache file");
|
||||
cached
|
||||
} else {
|
||||
let mut js = JoinSet::new();
|
||||
for timeline in &timelines {
|
||||
js.spawn({
|
||||
let mgmt_api_client = Arc::clone(&mgmt_api_client);
|
||||
let timeline = *timeline;
|
||||
async move {
|
||||
let partitioning = mgmt_api_client
|
||||
.keyspace(timeline.tenant_id, timeline.timeline_id)
|
||||
.await?;
|
||||
let lsn = partitioning.at_lsn;
|
||||
let start = Instant::now();
|
||||
let mut filtered = KeySpaceAccum::new();
|
||||
// let's hope this is inlined and vectorized...
|
||||
// TODO: turn this loop into a is_rel_block_range() function.
|
||||
for r in partitioning.keys.ranges.iter() {
|
||||
let mut i = r.start;
|
||||
while i != r.end {
|
||||
if is_rel_block_key(&i) {
|
||||
filtered.add_key(i);
|
||||
}
|
||||
(false, false) => None,
|
||||
i = i.next();
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
}
|
||||
let filtered = filtered.to_keyspace();
|
||||
let filter_duration = start.elapsed();
|
||||
|
||||
anyhow::Ok(ranges)
|
||||
}
|
||||
});
|
||||
}
|
||||
let mut all_ranges: Vec<KeyRange> = Vec::new();
|
||||
while let Some(res) = js.join_next().await {
|
||||
all_ranges.extend(res.unwrap().unwrap());
|
||||
}
|
||||
anyhow::Ok((
|
||||
filter_duration,
|
||||
filtered.ranges.into_iter().map(move |r| KeyRange {
|
||||
timeline,
|
||||
timeline_lsn: lsn,
|
||||
start: r.start.to_i128(),
|
||||
end: r.end.to_i128(),
|
||||
}),
|
||||
))
|
||||
}
|
||||
});
|
||||
}
|
||||
let mut total_filter_duration = Duration::from_secs(0);
|
||||
let mut all_ranges: Vec<KeyRange> = Vec::new();
|
||||
while let Some(res) = js.join_next().await {
|
||||
let (filter_duration, range) = res.unwrap().unwrap();
|
||||
all_ranges.extend(range);
|
||||
total_filter_duration += filter_duration;
|
||||
}
|
||||
info!("filter duration: {}", total_filter_duration.as_secs_f64());
|
||||
if let Some(cachefile) = args.keyspace_cache.as_ref() {
|
||||
let cache = KeyspaceCacheSer {
|
||||
tag: &timelines,
|
||||
data: &all_ranges,
|
||||
};
|
||||
let bytes = serde_json::to_vec(&cache).context("serialize keyspace for cache file")?;
|
||||
std::fs::write(cachefile, bytes).context("write keyspace cache file to disk")?;
|
||||
info!("successfully wrote keyspace cache file");
|
||||
}
|
||||
all_ranges
|
||||
};
|
||||
|
||||
let live_stats = Arc::new(LiveStats::default());
|
||||
|
||||
let num_client_tasks = timelines.len();
|
||||
let num_live_stats_dump = 1;
|
||||
let num_work_sender_tasks = 1;
|
||||
let num_main_impl = 1;
|
||||
|
||||
let start_work_barrier = Arc::new(tokio::sync::Barrier::new(
|
||||
num_client_tasks + num_live_stats_dump + num_work_sender_tasks,
|
||||
num_client_tasks + num_live_stats_dump + num_work_sender_tasks + num_main_impl,
|
||||
));
|
||||
let all_work_done_barrier = Arc::new(tokio::sync::Barrier::new(num_client_tasks));

    tokio::spawn({
        let stats = Arc::clone(&live_stats);
@@ -178,125 +235,143 @@ async fn main_impl(
        }
    });

    let mut work_senders = HashMap::new();
    let cancel = CancellationToken::new();

    let mut work_senders: HashMap<TenantTimelineId, _> = HashMap::new();
    let mut tasks = Vec::new();
    for tl in &timelines {
        let (sender, receiver) = tokio::sync::mpsc::channel(10); // TODO: not sure what the implications of this are
        work_senders.insert(tl, sender);
        work_senders.insert(*tl, sender);
        tasks.push(tokio::spawn(client(
            args,
            *tl,
            Arc::clone(&start_work_barrier),
            receiver,
            Arc::clone(&all_work_done_barrier),
            Arc::clone(&live_stats),
            cancel.clone(),
        )));
    }

    let work_sender: Pin<Box<dyn Send + Future<Output = ()>>> = match args.per_target_rate_limit {
        None => Box::pin(async move {
            let weights = rand::distributions::weighted::WeightedIndex::new(
                all_ranges.iter().map(|v| v.len()),
            )
            .unwrap();

            start_work_barrier.wait().await;

            loop {
                let (timeline, req) = {
                    let mut rng = rand::thread_rng();
                    let r = &all_ranges[weights.sample(&mut rng)];
                    let key: i128 = rng.gen_range(r.start..r.end);
                    let key = repository::Key::from_i128(key);
                    let (rel_tag, block_no) =
                        key_to_rel_block(key).expect("we filter non-rel-block keys out above");
                    (
                        r.timeline,
                        PagestreamGetPageRequest {
                            latest: rng.gen_bool(args.req_latest_probability),
                            lsn: r.timeline_lsn,
                            rel: rel_tag,
                            blkno: block_no,
                        },
                    )
                };
                let sender = work_senders.get(&timeline).unwrap();
                // TODO: what if this blocks?
                sender.send(req).await.ok().unwrap();
            }
        }),
        Some(rps_limit) => Box::pin(async move {
            let period = Duration::from_secs_f64(1.0 / (rps_limit as f64));

            let make_timeline_task: &dyn Fn(
                TenantTimelineId,
            )
                -> Pin<Box<dyn Send + Future<Output = ()>>> = &|timeline| {
                let sender = work_senders.get(&timeline).unwrap();
                let ranges: Vec<KeyRange> = all_ranges
                    .iter()
                    .filter(|r| r.timeline == timeline)
                    .cloned()
                    .collect();
    let work_sender: Pin<Box<dyn Send + Future<Output = ()>>> = {
        let start_work_barrier = start_work_barrier.clone();
        let cancel = cancel.clone();
        match args.per_target_rate_limit {
            None => Box::pin(async move {
                let weights = rand::distributions::weighted::WeightedIndex::new(
                    ranges.iter().map(|v| v.len()),
                    all_ranges.iter().map(|v| v.len()),
                )
                .unwrap();

                Box::pin(async move {
                    let mut ticker = tokio::time::interval(period);
                    ticker.set_missed_tick_behavior(
                        /* TODO review this choice */
                        tokio::time::MissedTickBehavior::Burst,
                    );
                    loop {
                        ticker.tick().await;
                        let req = {
                            let mut rng = rand::thread_rng();
                            let r = &ranges[weights.sample(&mut rng)];
                            let key: i128 = rng.gen_range(r.start..r.end);
                            let key = repository::Key::from_i128(key);
                            let (rel_tag, block_no) = key_to_rel_block(key)
                                .expect("we filter non-rel-block keys out above");
                start_work_barrier.wait().await;

                while !cancel.is_cancelled() {
                    let (timeline, req) = {
                        let mut rng = rand::thread_rng();
                        let r = &all_ranges[weights.sample(&mut rng)];
                        let key: i128 = rng.gen_range(r.start..r.end);
                        let key = repository::Key::from_i128(key);
                        let (rel_tag, block_no) =
                            key_to_rel_block(key).expect("we filter non-rel-block keys out above");
                        (
                            r.timeline,
                            PagestreamGetPageRequest {
                                latest: rng.gen_bool(args.req_latest_probability),
                                lsn: r.timeline_lsn,
                                rel: rel_tag,
                                blkno: block_no,
                            }
                        };
                        sender.send(req).await.ok().unwrap();
                    },
                )
            };
                    let sender = work_senders.get(&timeline).unwrap();
                    // TODO: what if this blocks?
                    if sender.send(req).await.is_err() {
                        assert!(cancel.is_cancelled(), "client has gone away unexpectedly");
                    }
                })
            };
        }
    }),
            Some(rps_limit) => Box::pin(async move {
                let period = Duration::from_secs_f64(1.0 / (rps_limit as f64));
                let make_timeline_task: &dyn Fn(
                    TenantTimelineId,
                )
                    -> Pin<Box<dyn Send + Future<Output = ()>>> = &|timeline| {
                    let sender = work_senders.get(&timeline).unwrap();
                    let ranges: Vec<KeyRange> = all_ranges
                        .iter()
                        .filter(|r| r.timeline == timeline)
                        .cloned()
                        .collect();
                    let weights = rand::distributions::weighted::WeightedIndex::new(
                        ranges.iter().map(|v| v.len()),
                    )
                    .unwrap();

                    let tasks: Vec<_> = work_senders
                        .keys()
                        .map(|tl| make_timeline_task(**tl))
                        .collect();
                    let cancel = cancel.clone();
                    Box::pin(async move {
                        let mut ticker = tokio::time::interval(period);
                        ticker.set_missed_tick_behavior(
                            /* TODO review this choice */
                            tokio::time::MissedTickBehavior::Burst,
                        );
                        while !cancel.is_cancelled() {
                            ticker.tick().await;
                            let req = {
                                let mut rng = rand::thread_rng();
                                let r = &ranges[weights.sample(&mut rng)];
                                let key: i128 = rng.gen_range(r.start..r.end);
                                let key = repository::Key::from_i128(key);
                                assert!(is_rel_block_key(&key));
                                let (rel_tag, block_no) = key_to_rel_block(key)
                                    .expect("we filter non-rel-block keys out above");
                                PagestreamGetPageRequest {
                                    latest: rng.gen_bool(args.req_latest_probability),
                                    lsn: r.timeline_lsn,
                                    rel: rel_tag,
                                    blkno: block_no,
                                }
                            };
                            if sender.send(req).await.is_err() {
                                assert!(cancel.is_cancelled(), "client has gone away unexpectedly");
                            }
                        }
                    })
                };

                start_work_barrier.wait().await;
                let tasks: Vec<_> = work_senders
                    .keys()
                    .map(|tl| make_timeline_task(*tl))
                    .collect();

                join_all(tasks).await;
            }),
            start_work_barrier.wait().await;

            join_all(tasks).await;
        }),
        }
    };
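
Both sender variants above share the same sampling core: pick a range with probability proportional to its length, then a key uniformly inside it, and treat a failed send as tolerable only during shutdown. A condensed sketch of that core (KeyRange here is a simplified stand-in):

use rand::distributions::{Distribution, WeightedIndex};
use rand::Rng;
use tokio_util::sync::CancellationToken;

struct KeyRange {
    start: i128,
    end: i128,
}

async fn work_sender(
    ranges: Vec<KeyRange>,
    tx: tokio::sync::mpsc::Sender<i128>,
    cancel: CancellationToken,
) {
    // Weight each range by its length so keys are (roughly) uniform over
    // the union of all ranges.
    let weights = WeightedIndex::new(ranges.iter().map(|r| r.end - r.start)).unwrap();
    while !cancel.is_cancelled() {
        let key = {
            let mut rng = rand::thread_rng();
            let r = &ranges[weights.sample(&mut rng)];
            rng.gen_range(r.start..r.end)
        };
        if tx.send(key).await.is_err() {
            // Receivers only drop during shutdown; anything else is a bug.
            assert!(cancel.is_cancelled(), "client has gone away unexpectedly");
            break;
        }
    }
}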

    let work_sender_task = tokio::spawn(work_sender);

    if let Some(runtime) = args.runtime {
        match tokio::time::timeout(runtime.into(), work_sender).await {
            Ok(()) => unreachable!("work sender never terminates"),
            Err(_timeout) => {
                // this implicitly drops the work_senders, making all the clients exit
            }
        }
        info!("waiting for everything to become ready");
        start_work_barrier.wait().await;
        info!("work started");
        tokio::time::sleep(runtime.into()).await;
        info!("runtime over, signalling cancellation");
        cancel.cancel();
        work_sender_task.await.unwrap();
        info!("work sender exited");
    } else {
        work_sender.await;
        work_sender_task.await.unwrap();
        unreachable!("work sender never terminates");
    }

    info!("joining clients");
    for t in tasks {
        t.await.unwrap();
    }

    info!("all clients stopped");

    let output = Output {
        total: {
            let mut agg_stats = request_stats::Stats::new();
@@ -320,11 +395,9 @@ async fn client(
    timeline: TenantTimelineId,
    start_work_barrier: Arc<Barrier>,
    mut work: tokio::sync::mpsc::Receiver<PagestreamGetPageRequest>,
    all_work_done_barrier: Arc<Barrier>,
    live_stats: Arc<LiveStats>,
    cancel: CancellationToken,
) {
    start_work_barrier.wait().await;

    let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
        .await
        .unwrap();
@@ -333,12 +406,18 @@ async fn client(
        .await
        .unwrap();

    while let Some(req) = work.recv().await {
    start_work_barrier.wait().await;

    while let Some(req) =
        tokio::select! { work = work.recv() => { work } , _ = cancel.cancelled() => { return; } }
    {
        let start = Instant::now();
        client
            .getpage(req)
            .await
            .with_context(|| format!("getpage for {timeline}"))

        let res = tokio::select! {
            res = client.getpage(req) => { res },
            _ = cancel.cancelled() => { return; }
        };
        res.with_context(|| format!("getpage for {timeline}"))
            .unwrap();
        let elapsed = start.elapsed();
        live_stats.inc();
@@ -346,6 +425,4 @@ async fn client(
            stats.borrow().lock().unwrap().observe(elapsed).unwrap();
        });
    }

    all_work_done_barrier.wait().await;
}
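
The two edits in the client loop above apply one rule: any await that can block indefinitely is raced against the cancellation token with tokio::select!, so shutdown stays prompt even when no work arrives. Reduced sketch with a stand-in request type:

use tokio_util::sync::CancellationToken;

async fn client_loop(mut work: tokio::sync::mpsc::Receiver<u32>, cancel: CancellationToken) {
    loop {
        // Race the receive against cancellation; either branch can end the loop.
        let req = tokio::select! {
            req = work.recv() => match req {
                Some(req) => req,
                None => return, // all senders dropped: normal end of work
            },
            _ = cancel.cancelled() => return,
        };
        // Long-running awaits inside the body get the same select! treatment.
        let _ = req;
    }
}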

@@ -4,6 +4,8 @@ use humantime::Duration;
use tokio::task::JoinSet;
use utils::id::TenantTimelineId;

use pageserver_client::mgmt_api::ForceAwaitLogicalSize;

#[derive(clap::Parser)]
pub(crate) struct Args {
    #[clap(long, default_value = "http://localhost:9898")]
@@ -56,14 +58,15 @@ async fn main_impl(args: Args) -> anyhow::Result<()> {
    for tl in timelines {
        let mgmt_api_client = Arc::clone(&mgmt_api_client);
        js.spawn(async move {
            // TODO: API to explicitly trigger initial logical size computation.
            // Should probably also avoid making it a side effect of timeline details to trigger initial logical size calculation.
            // => https://github.com/neondatabase/neon/issues/6168
            let info = mgmt_api_client
                .timeline_info(tl.tenant_id, tl.timeline_id)
                .timeline_info(tl.tenant_id, tl.timeline_id, ForceAwaitLogicalSize::Yes)
                .await
                .unwrap();

            // Polling should not be strictly required here since we await
            // for the initial logical size, however it's possible for the request
            // to land before the timeline is initialised. This results in an approximate
            // logical size.
            if let Some(period) = args.poll_for_completion {
                let mut ticker = tokio::time::interval(period.into());
                ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
@@ -71,7 +74,7 @@ async fn main_impl(args: Args) -> anyhow::Result<()> {
                while !info.current_logical_size_is_accurate {
                    ticker.tick().await;
                    info = mgmt_api_client
                        .timeline_info(tl.tenant_id, tl.timeline_id)
                        .timeline_info(tl.tenant_id, tl.timeline_id, ForceAwaitLogicalSize::Yes)
                        .await
                        .unwrap();
                }

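The polling loop above exists because even with ForceAwaitLogicalSize::Yes the request can land before the timeline is initialised, yielding an approximate size; it then re-queries on a timer until the value is exact. A generic sketch of that poll-until-accurate shape (the fetch closure stands in for the mgmt API call):

use std::time::Duration;

async fn poll_until_accurate<F, Fut>(period: Duration, mut fetch: F)
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = bool>, // true once the size is accurate
{
    let mut ticker = tokio::time::interval(period);
    // Delay, not Burst: after a slow response, don't fire a backlog of polls.
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    loop {
        ticker.tick().await;
        if fetch().await {
            break;
        }
    }
}
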
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use futures::Future;
use pageserver_api::{
    control_api::{
        ReAttachRequest, ReAttachResponse, ValidateRequest, ValidateRequestTenant, ValidateResponse,
@@ -28,13 +29,14 @@ pub enum RetryForeverError {
    ShuttingDown,
}

#[async_trait::async_trait]
pub trait ControlPlaneGenerationsApi {
    async fn re_attach(&self) -> Result<HashMap<TenantShardId, Generation>, RetryForeverError>;
    async fn validate(
    fn re_attach(
        &self,
    ) -> impl Future<Output = Result<HashMap<TenantShardId, Generation>, RetryForeverError>> + Send;
    fn validate(
        &self,
        tenants: Vec<(TenantShardId, Generation)>,
    ) -> Result<HashMap<TenantShardId, bool>, RetryForeverError>;
    ) -> impl Future<Output = Result<HashMap<TenantShardId, bool>, RetryForeverError>> + Send;
}

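The trait rewrite above swaps #[async_trait] (which boxes every returned future) for return-position impl Trait in traits, stabilised in Rust 1.75; the explicit `+ Send` keeps the futures spawnable. A reduced before/after sketch with toy types:

use std::future::Future;

trait GenerationsApi {
    // Desugared async fn: no boxing, and the Send bound is stated explicitly.
    fn re_attach(&self) -> impl Future<Output = Result<u32, ()>> + Send;
}

struct Client;

impl GenerationsApi for Client {
    // An async fn in the impl satisfies the impl-Future signature.
    async fn re_attach(&self) -> Result<u32, ()> {
        Ok(1)
    }
}
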
impl ControlPlaneClient {
@@ -123,7 +125,6 @@ impl ControlPlaneClient {
    }
}

#[async_trait::async_trait]
impl ControlPlaneGenerationsApi for ControlPlaneClient {
    /// Block until we get a successful response, or error out if we are shut down
    async fn re_attach(&self) -> Result<HashMap<TenantShardId, Generation>, RetryForeverError> {

@@ -831,7 +831,6 @@ mod test {
        }
    }

    #[async_trait::async_trait]
    impl ControlPlaneGenerationsApi for MockControlPlane {
        #[allow(clippy::diverging_sub_expression)] // False positive via async_trait
        async fn re_attach(&self) -> Result<HashMap<TenantShardId, Generation>, RetryForeverError> {

@@ -15,6 +15,7 @@ use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::models::TenantDetails;
use pageserver_api::models::TenantState;
use pageserver_api::models::{
    DownloadRemoteLayersTaskSpawnRequest, LocationConfigMode, TenantAttachRequest,
    TenantLoadRequest, TenantLocationConfigRequest,
@@ -37,6 +38,7 @@ use crate::pgdatadir_mapping::LsnForTimestamp;
use crate::task_mgr::TaskKind;
use crate::tenant::config::{LocationConf, TenantConfOpt};
use crate::tenant::mgr::GetActiveTenantError;
use crate::tenant::mgr::UpsertLocationError;
use crate::tenant::mgr::{
    GetTenantError, SetNewTenantConfigError, TenantManager, TenantMapError, TenantMapInsertError,
    TenantSlotError, TenantSlotUpsertError, TenantStateError,
@@ -46,7 +48,8 @@ use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::LayerAccessStatsReset;
use crate::tenant::timeline::CompactFlags;
use crate::tenant::timeline::Timeline;
use crate::tenant::{LogicalSizeCalculationCause, PageReconstructError, TenantSharedResources};
use crate::tenant::SpawnMode;
use crate::tenant::{LogicalSizeCalculationCause, PageReconstructError};
use crate::{config::PageServerConf, tenant::mgr};
use crate::{disk_usage_eviction_task, tenant};
use pageserver_api::models::{
@@ -112,14 +115,6 @@ impl State {
            secondary_controller,
        })
    }

    fn tenant_resources(&self) -> TenantSharedResources {
        TenantSharedResources {
            broker_client: self.broker_client.clone(),
            remote_storage: self.remote_storage.clone(),
            deletion_queue_client: self.deletion_queue_client.clone(),
        }
    }
}

#[inline(always)]
@@ -175,7 +170,7 @@ impl From<TenantSlotError> for ApiError {
            NotFound(tenant_id) => {
                ApiError::NotFound(anyhow::anyhow!("NotFound: tenant {tenant_id}").into())
            }
            e @ (AlreadyExists(_, _) | Conflict(_)) => ApiError::Conflict(format!("{e}")),
            e @ AlreadyExists(_, _) => ApiError::Conflict(format!("{e}")),
            InProgress => {
                ApiError::ResourceUnavailable("Tenant is being modified concurrently".into())
            }
@@ -194,6 +189,18 @@ impl From<TenantSlotUpsertError> for ApiError {
        }
    }
}

impl From<UpsertLocationError> for ApiError {
    fn from(e: UpsertLocationError) -> ApiError {
        use UpsertLocationError::*;
        match e {
            BadRequest(e) => ApiError::BadRequest(e),
            Unavailable(_) => ApiError::ShuttingDown,
            e @ InProgress => ApiError::Conflict(format!("{e}")),
            Flush(e) | Other(e) => ApiError::InternalServerError(e),
        }
    }
}

impl From<TenantMapError> for ApiError {
    fn from(e: TenantMapError) -> ApiError {
        use TenantMapError::*;
@@ -316,11 +323,21 @@ impl From<crate::tenant::delete::DeleteTenantError> for ApiError {
async fn build_timeline_info(
    timeline: &Arc<Timeline>,
    include_non_incremental_logical_size: bool,
    force_await_initial_logical_size: bool,
    ctx: &RequestContext,
) -> anyhow::Result<TimelineInfo> {
    crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id();

    let mut info = build_timeline_info_common(timeline, ctx).await?;
    if force_await_initial_logical_size {
        timeline.clone().await_initial_logical_size().await
    }

    let mut info = build_timeline_info_common(
        timeline,
        ctx,
        tenant::timeline::GetLogicalSizePriority::Background,
    )
    .await?;
    if include_non_incremental_logical_size {
        // XXX we should be using spawn_ondemand_logical_size_calculation here.
        // Otherwise, if someone deletes the timeline / detaches the tenant while
@@ -337,6 +354,7 @@ async fn build_timeline_info(
async fn build_timeline_info_common(
    timeline: &Arc<Timeline>,
    ctx: &RequestContext,
    logical_size_task_priority: tenant::timeline::GetLogicalSizePriority,
) -> anyhow::Result<TimelineInfo> {
    crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id();
    let initdb_lsn = timeline.initdb_lsn;
@@ -359,8 +377,7 @@ async fn build_timeline_info_common(
        Lsn(0) => None,
        lsn @ Lsn(_) => Some(lsn),
    };
    let current_logical_size =
        timeline.get_current_logical_size(tenant::timeline::GetLogicalSizePriority::User, ctx);
    let current_logical_size = timeline.get_current_logical_size(logical_size_task_priority, ctx);
    let current_physical_size = Some(timeline.layer_size_sum().await);
    let state = timeline.current_state();
    let remote_consistent_lsn_projected = timeline
@@ -471,7 +488,7 @@ async fn timeline_create_handler(
        .await {
            Ok(new_timeline) => {
                // Created. Construct a TimelineInfo for it.
                let timeline_info = build_timeline_info_common(&new_timeline, &ctx)
                let timeline_info = build_timeline_info_common(&new_timeline, &ctx, tenant::timeline::GetLogicalSizePriority::User)
                    .await
                    .map_err(ApiError::InternalServerError)?;
                json_response(StatusCode::CREATED, timeline_info)
@@ -507,6 +524,8 @@ async fn timeline_list_handler(
    let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
    let include_non_incremental_logical_size: Option<bool> =
        parse_query_param(&request, "include-non-incremental-logical-size")?;
    let force_await_initial_logical_size: Option<bool> =
        parse_query_param(&request, "force-await-initial-logical-size")?;
    check_permission(&request, Some(tenant_shard_id.tenant_id))?;

    let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
@@ -520,6 +539,7 @@ async fn timeline_list_handler(
            let timeline_info = build_timeline_info(
                &timeline,
                include_non_incremental_logical_size.unwrap_or(false),
                force_await_initial_logical_size.unwrap_or(false),
                &ctx,
            )
            .instrument(info_span!("build_timeline_info", timeline_id = %timeline.timeline_id))
@@ -547,6 +567,8 @@ async fn timeline_detail_handler(
    let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
    let include_non_incremental_logical_size: Option<bool> =
        parse_query_param(&request, "include-non-incremental-logical-size")?;
    let force_await_initial_logical_size: Option<bool> =
        parse_query_param(&request, "force-await-initial-logical-size")?;
    check_permission(&request, Some(tenant_shard_id.tenant_id))?;

    // Logical size calculation needs downloading.
@@ -562,6 +584,7 @@ async fn timeline_detail_handler(
        let timeline_info = build_timeline_info(
            &timeline,
            include_non_incremental_logical_size.unwrap_or(false),
            force_await_initial_logical_size.unwrap_or(false),
            &ctx,
        )
        .await
@@ -680,16 +703,37 @@ async fn tenant_attach_handler(
        )));
    }

    mgr::attach_tenant(
        state.conf,
        tenant_id,
        generation,
        tenant_conf,
        state.tenant_resources(),
        &ctx,
    )
    .instrument(info_span!("tenant_attach", %tenant_id))
    .await?;
    let tenant_shard_id = TenantShardId::unsharded(tenant_id);
    let location_conf = LocationConf::attached_single(tenant_conf, generation);
    let tenant = state
        .tenant_manager
        .upsert_location(
            tenant_shard_id,
            location_conf,
            None,
            SpawnMode::Normal,
            &ctx,
        )
        .await?;

    let Some(tenant) = tenant else {
        // This should never happen: indicates a bug in upsert_location
        return Err(ApiError::InternalServerError(anyhow::anyhow!(
            "Upsert succeeded but didn't return tenant!"
        )));
    };

    // We might have successfully constructed a Tenant, but it could still
    // end up in a broken state:
    if let TenantState::Broken {
        reason,
        backtrace: _,
    } = tenant.current_state()
    {
        return Err(ApiError::InternalServerError(anyhow::anyhow!(
            "Tenant state is Broken: {reason}"
        )));
    }

    json_response(StatusCode::ACCEPTED, ())
}
@@ -1148,16 +1192,25 @@ async fn tenant_create_handler(

    let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Warn);

    let new_tenant = mgr::create_tenant(
        state.conf,
        tenant_conf,
        target_tenant_id,
        generation,
        state.tenant_resources(),
        &ctx,
    )
    .instrument(info_span!("tenant_create", tenant_id = %target_tenant_id))
    .await?;
    let location_conf = LocationConf::attached_single(tenant_conf, generation);

    let new_tenant = state
        .tenant_manager
        .upsert_location(
            target_tenant_id,
            location_conf,
            None,
            SpawnMode::Create,
            &ctx,
        )
        .await?;

    let Some(new_tenant) = new_tenant else {
        // This should never happen: indicates a bug in upsert_location
        return Err(ApiError::InternalServerError(anyhow::anyhow!(
            "Upsert succeeded but didn't return tenant!"
        )));
    };

    // We created the tenant. Existing API semantics are that the tenant
    // is Active when this function returns.
@@ -1166,7 +1219,7 @@ async fn tenant_create_handler(
        .await
    {
        // This shouldn't happen because we just created the tenant directory
        // in tenant::mgr::create_tenant, and there aren't any remote timelines
        // in upsert_location, and there aren't any remote timelines
        // to load, so, nothing can really fail during load.
        // Don't do cleanup because we don't know how we got here.
        // The tenant will likely be in `Broken` state and subsequent
@@ -1267,12 +1320,14 @@ async fn put_tenant_location_config_handler(

    state
        .tenant_manager
        .upsert_location(tenant_shard_id, location_conf, flush, &ctx)
        .await
        // TODO: badrequest assumes the caller was asking for something unreasonable, but in
        // principle we might have hit something like concurrent API calls to the same tenant,
        // which is not a 400 but a 409.
        .map_err(ApiError::BadRequest)?;
        .upsert_location(
            tenant_shard_id,
            location_conf,
            flush,
            tenant::SpawnMode::Normal,
            &ctx,
        )
        .await?;

    if let Some(_flush_ms) = flush {
        match state

@@ -1500,7 +1500,8 @@ impl From<GetActiveTenantError> for QueryError {
            GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
                ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
            ),
            GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => {
            GetActiveTenantError::Cancelled
            | GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => {
                QueryError::Shutdown
            }
            e => QueryError::Other(anyhow::anyhow!(e)),

@@ -23,7 +23,7 @@ impl Statvfs {
    }

    // NB: allow() because the block count type is u32 on macOS.
    #[allow(clippy::useless_conversion)]
    #[allow(clippy::useless_conversion, clippy::unnecessary_fallible_conversions)]
    pub fn blocks(&self) -> u64 {
        match self {
            Statvfs::Real(stat) => u64::try_from(stat.blocks()).unwrap(),
@@ -32,7 +32,7 @@ impl Statvfs {
    }

    // NB: allow() because the block count type is u32 on macOS.
    #[allow(clippy::useless_conversion)]
    #[allow(clippy::useless_conversion, clippy::unnecessary_fallible_conversions)]
    pub fn blocks_available(&self) -> u64 {
        match self {
            Statvfs::Real(stat) => u64::try_from(stat.blocks_available()).unwrap(),

@@ -12,7 +12,7 @@
//!

use anyhow::{bail, Context};
use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8Path;
use enumset::EnumSet;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
@@ -130,6 +130,13 @@ macro_rules! pausable_failpoint {
            .expect("spawn_blocking");
        }
    };
    ($name:literal, $cond:expr) => {
        if cfg!(feature = "testing") {
            if $cond {
                pausable_failpoint!($name)
            }
        }
    };
}

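The new macro arm above adds a predicate-gated variant that simply delegates to the unconditional arm. The same delegation pattern in a self-contained macro (the failpoint machinery itself is replaced by a println for illustration):

macro_rules! trace_point {
    ($name:literal) => {
        println!("hit trace point {}", $name);
    };
    // Conditional arm: evaluate the predicate, then delegate to the arm above.
    ($name:literal, $cond:expr) => {
        if $cond {
            trace_point!($name)
        }
    };
}

fn main() {
    let on_slow_path = true;
    trace_point!("enter");
    trace_point!("slow", on_slow_path);
}
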
pub mod blob_io;
@@ -1003,7 +1010,7 @@ impl Tenant {
        // IndexPart is the source of truth.
        self.clean_up_timelines(&existent_timelines)?;

        failpoint_support::sleep_millis_async!("attach-before-activate");
        failpoint_support::sleep_millis_async!("attach-before-activate", &self.cancel);

        info!("Done");

@@ -2036,6 +2043,13 @@ impl Tenant {
        // It's messed up.
        // we just ignore the failure to stop

        // If we're still attaching, fire the cancellation token early to drop out: this
        // will prevent us flushing, but ensures timely shutdown if some I/O during attach
        // is very slow.
        if matches!(self.current_state(), TenantState::Attaching) {
            self.cancel.cancel();
        }

        match self.set_stopping(shutdown_progress, false, false).await {
            Ok(()) => {}
            Err(SetStoppingError::Broken) => {
@@ -2734,6 +2748,10 @@ impl Tenant {
        "#
        .to_string();

        fail::fail_point!("tenant-config-before-write", |_| {
            anyhow::bail!("tenant-config-before-write");
        });

        // Convert the config to a toml file.
        conf_content += &toml_edit::ser::to_string_pretty(&location_conf)?;

@@ -3650,140 +3668,6 @@ fn remove_timeline_and_uninit_mark(
    Ok(())
}

pub(crate) async fn create_tenant_files(
    conf: &'static PageServerConf,
    location_conf: &LocationConf,
    tenant_shard_id: &TenantShardId,
) -> anyhow::Result<Utf8PathBuf> {
    let target_tenant_directory = conf.tenant_path(tenant_shard_id);
    anyhow::ensure!(
        !target_tenant_directory
            .try_exists()
            .context("check existence of tenant directory")?,
        "tenant directory already exists",
    );

    let temporary_tenant_dir =
        path_with_suffix_extension(&target_tenant_directory, TEMP_FILE_SUFFIX);
    debug!("Creating temporary directory structure in {temporary_tenant_dir}");

    // top-level dir may exist if we are creating it through CLI
    crashsafe::create_dir_all(&temporary_tenant_dir).with_context(|| {
        format!("could not create temporary tenant directory {temporary_tenant_dir}")
    })?;

    let creation_result = try_create_target_tenant_dir(
        conf,
        location_conf,
        tenant_shard_id,
        &temporary_tenant_dir,
        &target_tenant_directory,
    )
    .await;

    if creation_result.is_err() {
        error!(
            "Failed to create directory structure for tenant {tenant_shard_id}, cleaning tmp data"
        );
        if let Err(e) = fs::remove_dir_all(&temporary_tenant_dir) {
            error!("Failed to remove temporary tenant directory {temporary_tenant_dir:?}: {e}")
        } else if let Err(e) = crashsafe::fsync(&temporary_tenant_dir) {
            error!(
                "Failed to fsync removed temporary tenant directory {temporary_tenant_dir:?}: {e}"
            )
        }
    }

    creation_result?;

    Ok(target_tenant_directory)
}

async fn try_create_target_tenant_dir(
    conf: &'static PageServerConf,
    location_conf: &LocationConf,
    tenant_shard_id: &TenantShardId,
    temporary_tenant_dir: &Utf8Path,
    target_tenant_directory: &Utf8Path,
) -> Result<(), anyhow::Error> {
    let temporary_tenant_timelines_dir = rebase_directory(
        &conf.timelines_path(tenant_shard_id),
        target_tenant_directory,
        temporary_tenant_dir,
    )
    .with_context(|| format!("resolve tenant {tenant_shard_id} temporary timelines dir"))?;
    let temporary_legacy_tenant_config_path = rebase_directory(
        &conf.tenant_config_path(tenant_shard_id),
        target_tenant_directory,
        temporary_tenant_dir,
    )
    .with_context(|| format!("resolve tenant {tenant_shard_id} temporary config path"))?;
    let temporary_tenant_config_path = rebase_directory(
        &conf.tenant_location_config_path(tenant_shard_id),
        target_tenant_directory,
        temporary_tenant_dir,
    )
    .with_context(|| format!("resolve tenant {tenant_shard_id} temporary config path"))?;

    Tenant::persist_tenant_config_at(
        tenant_shard_id,
        &temporary_tenant_config_path,
        &temporary_legacy_tenant_config_path,
        location_conf,
    )
    .await?;

    crashsafe::create_dir(&temporary_tenant_timelines_dir).with_context(|| {
        format!(
            "create tenant {} temporary timelines directory {}",
            tenant_shard_id, temporary_tenant_timelines_dir,
        )
    })?;
    fail::fail_point!("tenant-creation-before-tmp-rename", |_| {
        anyhow::bail!("failpoint tenant-creation-before-tmp-rename");
    });

    // Make sure the current tenant directory entries are durable before renaming.
    // Without this, a crash may reorder any of the directory entry creations above.
    crashsafe::fsync(temporary_tenant_dir)
        .with_context(|| format!("sync temporary tenant directory {temporary_tenant_dir:?}"))?;

    fs::rename(temporary_tenant_dir, target_tenant_directory).with_context(|| {
        format!(
            "move tenant {} temporary directory {} into the permanent one {}",
            tenant_shard_id, temporary_tenant_dir, target_tenant_directory
        )
    })?;
    let target_dir_parent = target_tenant_directory.parent().with_context(|| {
        format!(
            "get tenant {} dir parent for {}",
            tenant_shard_id, target_tenant_directory,
        )
    })?;
    crashsafe::fsync(target_dir_parent).with_context(|| {
        format!(
            "fsync renamed directory's parent {} for tenant {}",
            target_dir_parent, tenant_shard_id,
        )
    })?;

    Ok(())
}

fn rebase_directory(
    original_path: &Utf8Path,
    base: &Utf8Path,
    new_base: &Utf8Path,
) -> anyhow::Result<Utf8PathBuf> {
    let relative_path = original_path.strip_prefix(base).with_context(|| {
        format!(
            "Failed to strip base prefix '{}' off path '{}'",
            base, original_path
        )
    })?;
    Ok(new_base.join(relative_path))
}

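The removed create_tenant_files/try_create_target_tenant_dir pair above implemented the classic crash-safe creation recipe: build everything under a temporary name, fsync the contents, rename into place, then fsync the parent so the rename itself is durable. A Unix-oriented sketch of that recipe (helper names are illustrative):

use std::fs::{self, File};
use std::io;
use std::path::Path;

fn create_dir_crashsafe(
    target: &Path,
    populate: impl Fn(&Path) -> io::Result<()>,
) -> io::Result<()> {
    let tmp = target.with_extension("___temp");
    fs::create_dir_all(&tmp)?;
    populate(&tmp)?;
    // Make the directory entries durable before the rename; a crash must not
    // reorder the entry creations past it.
    File::open(&tmp)?.sync_all()?;
    fs::rename(&tmp, target)?; // atomic on POSIX filesystems
    // Persist the rename by fsyncing the parent directory.
    File::open(target.parent().unwrap())?.sync_all()?;
    Ok(())
}

fn main() -> io::Result<()> {
    let dir = std::env::temp_dir().join("demo_tenant");
    let _ = fs::remove_dir_all(&dir);
    create_dir_crashsafe(&dir, |tmp| fs::write(tmp.join("config"), b"v1"))
}
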
/// Create the cluster temporarily in 'initdbpath' directory inside the repository
/// to get bootstrap data for timeline initialization.
async fn run_initdb(
@@ -3878,6 +3762,7 @@ pub async fn dump_layerfile_from_path(
#[cfg(test)]
pub(crate) mod harness {
    use bytes::{Bytes, BytesMut};
    use camino::Utf8PathBuf;
    use once_cell::sync::OnceCell;
    use pageserver_api::shard::ShardIndex;
    use std::fs;
@@ -3945,8 +3830,6 @@ pub(crate) mod harness {
    pub struct TenantHarness {
        pub conf: &'static PageServerConf,
        pub tenant_conf: TenantConf,
        // TODO(sharding): remove duplicative `tenant_id` in favor of access to tenant_shard_id
        pub(crate) tenant_id: TenantId,
        pub tenant_shard_id: TenantShardId,
        pub generation: Generation,
        pub shard: ShardIndex,
@@ -4008,7 +3891,6 @@ pub(crate) mod harness {
            Ok(Self {
                conf,
                tenant_conf,
                tenant_id,
                tenant_shard_id,
                generation: Generation::new(0xdeadbeef),
                shard: ShardIndex::unsharded(),

@@ -35,7 +35,7 @@ use crate::tenant::config::{
};
use crate::tenant::delete::DeleteTenantFlow;
use crate::tenant::span::debug_assert_current_span_has_tenant_id;
use crate::tenant::{create_tenant_files, AttachedTenantConf, SpawnMode, Tenant, TenantState};
use crate::tenant::{AttachedTenantConf, SpawnMode, Tenant, TenantState};
use crate::{InitializationOrder, IGNORED_TENANT_FILE_NAME, TEMP_FILE_SUFFIX};

use utils::crashsafe::path_with_suffix_extension;
@@ -754,45 +754,6 @@ async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
    // caller will log how long we took
}

pub(crate) async fn create_tenant(
    conf: &'static PageServerConf,
    tenant_conf: TenantConfOpt,
    tenant_shard_id: TenantShardId,
    generation: Generation,
    resources: TenantSharedResources,
    ctx: &RequestContext,
) -> Result<Arc<Tenant>, TenantMapInsertError> {
    let location_conf = LocationConf::attached_single(tenant_conf, generation);
    info!("Creating tenant at location {location_conf:?}");

    let slot_guard =
        tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustNotExist)?;
    let tenant_path = super::create_tenant_files(conf, &location_conf, &tenant_shard_id).await?;

    let shard_identity = location_conf.shard;
    let created_tenant = tenant_spawn(
        conf,
        tenant_shard_id,
        &tenant_path,
        resources,
        AttachedTenantConf::try_from(location_conf)?,
        shard_identity,
        None,
        &TENANTS,
        SpawnMode::Create,
        ctx,
    )?;
    // TODO: tenant object & its background loops remain, untracked in tenant map, if we fail here.
    // See https://github.com/neondatabase/neon/issues/4233

    let created_tenant_id = created_tenant.tenant_id();
    debug_assert_eq!(created_tenant_id, tenant_shard_id.tenant_id);

    slot_guard.upsert(TenantSlot::Attached(created_tenant.clone()))?;

    Ok(created_tenant)
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum SetNewTenantConfigError {
    #[error(transparent)]
@@ -824,6 +785,24 @@ pub(crate) async fn set_new_tenant_config(
    Ok(())
}

#[derive(thiserror::Error, Debug)]
pub(crate) enum UpsertLocationError {
    #[error("Bad config request: {0}")]
    BadRequest(anyhow::Error),

    #[error("Cannot change config in this state: {0}")]
    Unavailable(#[from] TenantMapError),

    #[error("Tenant is already being modified")]
    InProgress,

    #[error("Failed to flush: {0}")]
    Flush(anyhow::Error),

    #[error("Internal error: {0}")]
    Other(#[from] anyhow::Error),
}

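UpsertLocationError above gives each failure mode its own variant so the HTTP layer can map them to distinct statuses (as the From impl for ApiError earlier in this diff does). A small sketch of the same thiserror shape and mapping:

#[derive(Debug, thiserror::Error)]
enum OpError {
    #[error("Bad request: {0}")]
    BadRequest(anyhow::Error),
    #[error("Operation already in progress")]
    InProgress,
    #[error("Internal error: {0}")]
    Other(#[from] anyhow::Error),
}

#[derive(Debug)]
enum ApiStatus {
    BadRequest400,
    Conflict409,
    Internal500,
}

impl From<OpError> for ApiStatus {
    fn from(e: OpError) -> ApiStatus {
        match e {
            OpError::BadRequest(_) => ApiStatus::BadRequest400,
            OpError::InProgress => ApiStatus::Conflict409,
            OpError::Other(_) => ApiStatus::Internal500,
        }
    }
}

fn main() {
    let status: ApiStatus = OpError::InProgress.into();
    println!("{status:?}");
}
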
impl TenantManager {
    /// Convenience function so that anyone with a TenantManager can get at the global configuration, without
    /// having to pass it around everywhere as a separate object.
@@ -888,8 +867,9 @@ impl TenantManager {
        tenant_shard_id: TenantShardId,
        new_location_config: LocationConf,
        flush: Option<Duration>,
        spawn_mode: SpawnMode,
        ctx: &RequestContext,
    ) -> Result<(), anyhow::Error> {
    ) -> Result<Option<Arc<Tenant>>, UpsertLocationError> {
        debug_assert_current_span_has_tenant_id();
        info!("configuring tenant location to state {new_location_config:?}");

@@ -911,9 +891,10 @@ impl TenantManager {
                    // A transition from Attached to Attached in the same generation, we may
                    // take our fast path and just provide the updated configuration
                    // to the tenant.
                    tenant.set_new_location_config(AttachedTenantConf::try_from(
                        new_location_config.clone(),
                    )?);
                    tenant.set_new_location_config(
                        AttachedTenantConf::try_from(new_location_config.clone())
                            .map_err(UpsertLocationError::BadRequest)?,
                    );

                    Some(FastPathModified::Attached(tenant.clone()))
                } else {
@@ -940,8 +921,7 @@ impl TenantManager {
        match fast_path_taken {
            Some(FastPathModified::Attached(tenant)) => {
                Tenant::persist_tenant_config(self.conf, &tenant_shard_id, &new_location_config)
                    .await
                    .map_err(SetNewTenantConfigError::Persist)?;
                    .await?;

                // Transition to AttachedStale means we may well hold a valid generation
                // still, and have been requested to go stale as part of a migration. If
@@ -954,9 +934,9 @@ impl TenantManager {
                if let Some(flush_timeout) = flush {
                    match tokio::time::timeout(flush_timeout, tenant.flush_remote()).await {
                        Ok(Err(e)) => {
                            return Err(e);
                            return Err(UpsertLocationError::Flush(e));
                        }
                        Ok(Ok(_)) => return Ok(()),
                        Ok(Ok(_)) => return Ok(Some(tenant)),
                        Err(_) => {
                            tracing::warn!(
                                timeout_ms = flush_timeout.as_millis(),
@@ -967,14 +947,13 @@ impl TenantManager {
                    }
                }

                return Ok(());
                return Ok(Some(tenant));
            }
            Some(FastPathModified::Secondary(_secondary_tenant)) => {
                Tenant::persist_tenant_config(self.conf, &tenant_shard_id, &new_location_config)
                    .await
                    .map_err(SetNewTenantConfigError::Persist)?;
                    .await?;

                return Ok(());
                return Ok(None);
            }
            None => {
                // Proceed with the general case procedure, where we will shutdown & remove any existing
@@ -987,7 +966,14 @@ impl TenantManager {
        // the tenant is inaccessible to the outside world while we are doing this, but that is sensible:
        // the state is ill-defined while we're in transition. Transitions are async, but fast: we do
        // not do significant I/O, and shutdowns should be prompt via cancellation tokens.
        let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
        let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
            .map_err(|e| match e {
                TenantSlotError::AlreadyExists(_, _) | TenantSlotError::NotFound(_) => {
                    unreachable!("Called with mode Any")
                }
                TenantSlotError::InProgress => UpsertLocationError::InProgress,
                TenantSlotError::MapState(s) => UpsertLocationError::Unavailable(s),
            })?;

        match slot_guard.get_old_value() {
            Some(TenantSlot::Attached(tenant)) => {
@@ -1025,7 +1011,9 @@ impl TenantManager {
            Some(TenantSlot::InProgress(_)) => {
                // This should never happen: acquire_slot should error out
                // if the contents of a slot were InProgress.
                anyhow::bail!("Acquired an InProgress slot, this is a bug.")
                return Err(UpsertLocationError::Other(anyhow::anyhow!(
                    "Acquired an InProgress slot, this is a bug."
                )));
            }
            None => {
                // Slot was vacant, nothing needs shutting down.
@@ -1047,9 +1035,7 @@ impl TenantManager {
        // Before activating either secondary or attached mode, persist the
        // configuration, so that on restart we will re-attach (or re-start
        // secondary) on the tenant.
        Tenant::persist_tenant_config(self.conf, &tenant_shard_id, &new_location_config)
            .await
            .map_err(SetNewTenantConfigError::Persist)?;
        Tenant::persist_tenant_config(self.conf, &tenant_shard_id, &new_location_config).await?;

        let new_slot = match &new_location_config.mode {
            LocationMode::Secondary(secondary_config) => {
@@ -1066,7 +1052,7 @@ impl TenantManager {
                    shard_identity,
                    None,
                    self.tenants,
                    SpawnMode::Normal,
                    spawn_mode,
                    ctx,
                )?;

@@ -1074,9 +1060,20 @@ impl TenantManager {
            }
        };

        slot_guard.upsert(new_slot)?;
        let attached_tenant = if let TenantSlot::Attached(tenant) = &new_slot {
            Some(tenant.clone())
        } else {
            None
        };

        Ok(())
        slot_guard.upsert(new_slot).map_err(|e| match e {
            TenantSlotUpsertError::InternalError(e) => {
                UpsertLocationError::Other(anyhow::anyhow!(e))
            }
            TenantSlotUpsertError::MapState(e) => UpsertLocationError::Unavailable(e),
        })?;

        Ok(attached_tenant)
    }

    /// Resetting a tenant is equivalent to detaching it, then attaching it again with the same
@@ -1648,55 +1645,6 @@ pub(crate) async fn list_tenants() -> Result<Vec<(TenantShardId, TenantState)>,
        .collect())
}

/// Execute Attach mgmt API command.
///
/// Downloading all the tenant data is performed in the background, this merely
/// spawns the background task and returns quickly.
pub(crate) async fn attach_tenant(
    conf: &'static PageServerConf,
    tenant_id: TenantId,
    generation: Generation,
    tenant_conf: TenantConfOpt,
    resources: TenantSharedResources,
    ctx: &RequestContext,
) -> Result<(), TenantMapInsertError> {
    // This is a legacy API (replaced by `/location_conf`). It does not support sharding
    let tenant_shard_id = TenantShardId::unsharded(tenant_id);

    let slot_guard =
        tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustNotExist)?;
    let location_conf = LocationConf::attached_single(tenant_conf, generation);
    let tenant_dir = create_tenant_files(conf, &location_conf, &tenant_shard_id).await?;
    // TODO: tenant directory remains on disk if we bail out from here on.
    // See https://github.com/neondatabase/neon/issues/4233

    let shard_identity = location_conf.shard;
    let attached_tenant = tenant_spawn(
        conf,
        tenant_shard_id,
        &tenant_dir,
        resources,
        AttachedTenantConf::try_from(location_conf)?,
        shard_identity,
        None,
        &TENANTS,
        SpawnMode::Normal,
        ctx,
    )?;
    // TODO: tenant object & its background loops remain, untracked in tenant map, if we fail here.
    // See https://github.com/neondatabase/neon/issues/4233

    let attached_tenant_id = attached_tenant.tenant_id();
    if tenant_id != attached_tenant_id {
        return Err(TenantMapInsertError::Other(anyhow::anyhow!(
            "loaded created tenant has unexpected tenant id (expect {tenant_id} != actual {attached_tenant_id})",
        )));
    }

    slot_guard.upsert(TenantSlot::Attached(attached_tenant))?;
    Ok(())
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum TenantMapInsertError {
    #[error(transparent)]
@@ -1710,7 +1658,7 @@ pub(crate) enum TenantMapInsertError {
/// Superset of TenantMapError: issues that can occur when acquiring a slot
/// for a particular tenant ID.
#[derive(Debug, thiserror::Error)]
pub enum TenantSlotError {
pub(crate) enum TenantSlotError {
    /// When acquiring a slot with the expectation that the tenant already exists.
    #[error("Tenant {0} not found")]
    NotFound(TenantShardId),
@@ -1719,9 +1667,6 @@ pub enum TenantSlotError {
    #[error("tenant {0} already exists, state: {1:?}")]
    AlreadyExists(TenantShardId, TenantState),

    #[error("tenant {0} already exists but is not attached")]
    Conflict(TenantShardId),

    // Tried to read a slot that is currently being mutated by another administrative
    // operation.
    #[error("tenant has a state change in progress, try again later")]

@@ -1903,7 +1903,7 @@ mod tests {
    fn span(&self) -> tracing::Span {
        tracing::info_span!(
            "test",
            tenant_id = %self.harness.tenant_id,
            tenant_id = %self.harness.tenant_shard_id.tenant_id,
            timeline_id = %TIMELINE_ID
        )
    }

@@ -186,7 +186,6 @@ type Scheduler = TenantBackgroundJobs<
    DownloadCommand,
>;

#[async_trait::async_trait]
impl JobGenerator<PendingDownload, RunningDownload, CompleteDownload, DownloadCommand>
    for SecondaryDownloader
{

@@ -134,7 +134,6 @@ type Scheduler = TenantBackgroundJobs<
    UploadCommand,
>;

#[async_trait::async_trait]
impl JobGenerator<UploadPending, WriteInProgress, WriteComplete, UploadCommand>
    for HeatmapUploader
{

@@ -1,4 +1,3 @@
use async_trait;
use futures::Future;
use std::{
    collections::HashMap,
@@ -65,7 +64,6 @@ where
    _phantom: PhantomData<(PJ, RJ, C, CMD)>,
}

#[async_trait::async_trait]
pub(crate) trait JobGenerator<PJ, RJ, C, CMD>
where
    C: Completion,

@@ -320,8 +320,8 @@ impl DeltaLayer {
            .metadata()
            .context("get file metadata to determine size")?;

        // TODO(sharding): we must get the TenantShardId from the path instead of reading the Summary.
        // we should also validate the path against the Summary, as both should contain the same tenant, timeline, key, lsn.
        // This function is never used for constructing layers in a running pageserver,
        // so it does not need an accurate TenantShardId.
        let tenant_shard_id = TenantShardId::unsharded(summary.tenant_id);

        Ok(DeltaLayer {

@@ -278,8 +278,8 @@ impl ImageLayer {
            .metadata()
            .context("get file metadata to determine size")?;

        // TODO(sharding): we should get TenantShardId from path.
        // OR, not at all: any layer we load from disk should also get reconciled with remote IndexPart.
        // This function is never used for constructing layers in a running pageserver,
        // so it does not need an accurate TenantShardId.
        let tenant_shard_id = TenantShardId::unsharded(summary.tenant_id);

        Ok(ImageLayer {

@@ -945,8 +945,18 @@ impl LayerInner {
            Ok((Err(e), _permit)) => {
                // sleep already happened in the spawned task, if it was not cancelled
                let consecutive_failures = self.consecutive_failures.load(Ordering::Relaxed);
                tracing::error!(consecutive_failures, "layer file download failed: {e:#}");
                Err(DownloadError::DownloadFailed)

                match e.downcast_ref::<remote_storage::DownloadError>() {
                    // If the download failed due to its cancellation token,
                    // propagate the cancellation error upstream.
                    Some(remote_storage::DownloadError::Cancelled) => {
                        Err(DownloadError::DownloadCancelled)
                    }
                    _ => {
                        tracing::error!(consecutive_failures, "layer file download failed: {e:#}");
                        Err(DownloadError::DownloadFailed)
                    }
                }
            }
            Err(_gone) => Err(DownloadError::DownloadCancelled),
        }

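The download-failure match above separates cancellation from genuine errors by downcasting the anyhow::Error back to the storage layer's concrete error type, so a cancelled download is propagated instead of logged as a failure. Sketch of downcast-based routing (error types are illustrative):

#[derive(Debug, thiserror::Error)]
enum StorageError {
    #[error("cancelled")]
    Cancelled,
    #[error("io failure")]
    Io,
}

fn classify(e: &anyhow::Error) -> &'static str {
    // downcast_ref peels the anyhow wrapper to look for a concrete type.
    match e.downcast_ref::<StorageError>() {
        Some(StorageError::Cancelled) => "propagate cancellation upstream",
        _ => "log and report a failed download",
    }
}

fn main() {
    let e = anyhow::Error::new(StorageError::Cancelled);
    println!("{}", classify(&e));
}
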
@@ -65,6 +65,11 @@ pub(crate) async fn concurrent_background_tasks_rate_limit_permit(
        .with_label_values(&[loop_kind.as_static_str()])
        .guard();

    pausable_failpoint!(
        "initial-size-calculation-permit-pause",
        loop_kind == BackgroundLoopKind::InitialLogicalSizeCalculation
    );

    match CONCURRENT_BACKGROUND_TASKS.acquire().await {
        Ok(permit) => permit,
        Err(_closed) => unreachable!("we never close the semaphore"),

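The permit-acquisition fragment above bounds global background-task concurrency with a semaphore and treats a closed semaphore as impossible. Minimal sketch of the pattern (limit and task body are illustrative):

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    // At most 2 of the 5 tasks run concurrently.
    let permits = Arc::new(Semaphore::new(2));
    let mut handles = Vec::new();
    for i in 0..5 {
        let permits = Arc::clone(&permits);
        handles.push(tokio::spawn(async move {
            let _permit = match permits.acquire().await {
                Ok(permit) => permit,
                // acquire() only fails if the semaphore is closed, which
                // this program never does.
                Err(_closed) => unreachable!("we never close the semaphore"),
            };
            println!("task {i} running");
        }));
    }
    for h in handles {
        h.await.unwrap();
    }
}
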
@@ -3131,11 +3131,13 @@ impl Timeline {
            .await
            .context("fsync of newly created layer files")?;

        par_fsync::par_fsync_async(&[self
            .conf
            .timeline_path(&self.tenant_shard_id, &self.timeline_id)])
        .await
        .context("fsync of timeline dir")?;
        if !all_paths.is_empty() {
            par_fsync::par_fsync_async(&[self
                .conf
                .timeline_path(&self.tenant_shard_id, &self.timeline_id)])
            .await
            .context("fsync of timeline dir")?;
        }

        let mut guard = self.layers.write().await;


@@ -1337,7 +1337,7 @@ mod tests {

        ConnectionManagerState {
            id: TenantTimelineId {
                tenant_id: harness.tenant_id,
                tenant_id: harness.tenant_shard_id.tenant_id,
                timeline_id: TIMELINE_ID,
            },
            timeline,

@@ -18,7 +18,8 @@ use std::fs::{self, File, OpenOptions};
use std::io::{Error, ErrorKind, Seek, SeekFrom};
use std::os::unix::fs::FileExt;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{RwLock, RwLockWriteGuard};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
use utils::fs_ext;

///
@@ -111,7 +112,7 @@ impl OpenFiles {
    ///
    /// On return, we hold a lock on the slot, its 'tag' has been updated,
    /// and recently_used has been set. It's all ready for reuse.
    fn find_victim_slot(&self) -> (SlotHandle, RwLockWriteGuard<SlotInner>) {
    async fn find_victim_slot(&self) -> (SlotHandle, RwLockWriteGuard<SlotInner>) {
        //
        // Run the clock algorithm to find a slot to replace.
        //
@@ -143,7 +144,7 @@ impl OpenFiles {
                }
                retries += 1;
            } else {
                slot_guard = slot.inner.write().unwrap();
                slot_guard = slot.inner.write().await;
                index = next;
                break;
            }
@@ -250,6 +251,29 @@ impl<T> MaybeFatalIo<T> for std::io::Result<T> {
    }
}

/// Observe duration for the given storage I/O operation
///
/// Unlike `observe_closure_duration`, this supports async,
/// where "support" means that we measure wall clock time.
macro_rules! observe_duration {
    ($op:expr, $($body:tt)*) => {{
        let instant = Instant::now();
        let result = $($body)*;
        let elapsed = instant.elapsed().as_secs_f64();
        STORAGE_IO_TIME_METRIC
            .get($op)
            .observe(elapsed);
        result
    }}
}

macro_rules! with_file {
    ($this:expr, $op:expr, | $ident:ident | $($body:tt)*) => {{
        let $ident = $this.lock_file().await?;
        observe_duration!($op, $($body)*)
    }};
}

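The two macros above compose: observe_duration! times an arbitrary expression (async included, since the body is expanded inline and may contain .await), and with_file! locks the file first and then delegates. A self-contained sketch of the timing macro with a stand-in metric sink:

use std::time::Instant;

fn record(op: &str, seconds: f64) {
    // Stand-in for a histogram observe() call.
    println!("{op}: {seconds:.6}s");
}

macro_rules! observe_duration {
    ($op:expr, $($body:tt)*) => {{
        let instant = Instant::now();
        let result = $($body)*;
        record($op, instant.elapsed().as_secs_f64());
        result
    }};
}

fn main() {
    let sum = observe_duration!("sum", (0..1_000_000u64).sum::<u64>());
    println!("{sum}");
}
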
impl VirtualFile {
|
||||
/// Open a file in read-only mode. Like File::open.
|
||||
pub async fn open(path: &Utf8Path) -> Result<VirtualFile, std::io::Error> {
|
||||
@@ -286,14 +310,12 @@ impl VirtualFile {
|
||||
tenant_id = "*".to_string();
|
||||
timeline_id = "*".to_string();
|
||||
}
|
||||
let (handle, mut slot_guard) = get_open_files().find_victim_slot();
|
||||
let (handle, mut slot_guard) = get_open_files().find_victim_slot().await;
|
||||
|
||||
// NB: there is also StorageIoOperation::OpenAfterReplace which is for the case
|
||||
// where our caller doesn't get to use the returned VirtualFile before its
|
||||
// slot gets re-used by someone else.
|
||||
let file = STORAGE_IO_TIME_METRIC
|
||||
.get(StorageIoOperation::Open)
|
||||
.observe_closure_duration(|| open_options.open(path))?;
|
||||
let file = observe_duration!(StorageIoOperation::Open, open_options.open(path))?;
|
||||
|
||||
// Strip all options other than read and write.
|
||||
//
|
||||
@@ -366,22 +388,24 @@ impl VirtualFile {
|
||||
|
||||
/// Call File::sync_all() on the underlying File.
|
||||
pub async fn sync_all(&self) -> Result<(), Error> {
|
||||
self.with_file(StorageIoOperation::Fsync, |file| file.sync_all())
|
||||
.await?
|
||||
with_file!(self, StorageIoOperation::Fsync, |file| file
|
||||
.as_ref()
|
||||
.sync_all())
|
||||
}
|
||||
|
||||
pub async fn metadata(&self) -> Result<fs::Metadata, Error> {
|
||||
self.with_file(StorageIoOperation::Metadata, |file| file.metadata())
|
||||
.await?
|
||||
with_file!(self, StorageIoOperation::Metadata, |file| file
|
||||
.as_ref()
|
||||
.metadata())
|
||||
}
|
||||
|
||||
/// Helper function that looks up the underlying File for this VirtualFile,
|
||||
/// opening it and evicting some other File if necessary. It calls 'func'
|
||||
/// with the physical File.
|
||||
async fn with_file<F, R>(&self, op: StorageIoOperation, mut func: F) -> Result<R, Error>
|
||||
where
|
||||
F: FnMut(&File) -> R,
|
||||
{
|
||||
/// Helper function internal to `VirtualFile` that looks up the underlying File,
|
||||
/// opens it and evicts some other File if necessary. The passed parameter is
|
||||
/// assumed to be a function available for the physical `File`.
|
||||
///
|
||||
/// We are doing it via a macro as Rust doesn't support async closures that
|
||||
/// take on parameters with lifetimes.
|
||||
async fn lock_file(&self) -> Result<FileGuard<'_>, Error> {
|
||||
let open_files = get_open_files();
|
||||
|
||||
let mut handle_guard = {
|
||||
@@ -391,27 +415,23 @@ impl VirtualFile {
|
||||
// We only need to hold the handle lock while we read the current handle. If
|
||||
// another thread closes the file and recycles the slot for a different file,
|
||||
// we will notice that the handle we read is no longer valid and retry.
|
||||
let mut handle = *self.handle.read().unwrap();
|
||||
let mut handle = *self.handle.read().await;
|
||||
loop {
|
||||
// Check if the slot contains our File
|
||||
{
|
||||
let slot = &open_files.slots[handle.index];
|
||||
let slot_guard = slot.inner.read().unwrap();
|
||||
if slot_guard.tag == handle.tag {
|
||||
if let Some(file) = &slot_guard.file {
|
||||
// Found a cached file descriptor.
|
||||
slot.recently_used.store(true, Ordering::Relaxed);
|
||||
return Ok(STORAGE_IO_TIME_METRIC
|
||||
.get(op)
|
||||
.observe_closure_duration(|| func(file)));
|
||||
}
|
||||
let slot_guard = slot.inner.read().await;
|
||||
if slot_guard.tag == handle.tag && slot_guard.file.is_some() {
|
||||
// Found a cached file descriptor.
|
||||
slot.recently_used.store(true, Ordering::Relaxed);
|
||||
return Ok(FileGuard { slot_guard });
|
||||
}
|
||||
}
|
||||
|
||||
// The slot didn't contain our File. We will have to open it ourselves,
|
||||
// but before that, grab a write lock on handle in the VirtualFile, so
|
||||
// that no other thread will try to concurrently open the same file.
|
||||
let handle_guard = self.handle.write().unwrap();
|
||||
let handle_guard = self.handle.write().await;
|
||||
|
||||
// If another thread changed the handle while we were not holding the lock,
|
||||
// then the handle might now be valid again. Loop back to retry.
|
||||
@@ -425,20 +445,16 @@ impl VirtualFile {
|
||||
|
||||
// We need to open the file ourselves. The handle in the VirtualFile is
|
||||
// now locked in write-mode. Find a free slot to put it in.
|
||||
let (handle, mut slot_guard) = open_files.find_victim_slot();
|
||||
let (handle, mut slot_guard) = open_files.find_victim_slot().await;
|
||||
|
||||
// Re-open the physical file.
|
||||
// NB: we use StorageIoOperation::OpenAferReplace for this to distinguish this
|
||||
// case from StorageIoOperation::Open. This helps with identifying thrashing
|
||||
// of the virtual file descriptor cache.
|
||||
let file = STORAGE_IO_TIME_METRIC
|
||||
.get(StorageIoOperation::OpenAfterReplace)
|
||||
.observe_closure_duration(|| self.open_options.open(&self.path))?;
|
||||
|
||||
// Perform the requested operation on it
|
||||
let result = STORAGE_IO_TIME_METRIC
|
||||
.get(op)
|
||||
.observe_closure_duration(|| func(&file));
|
||||
let file = observe_duration!(
|
||||
StorageIoOperation::OpenAfterReplace,
|
||||
self.open_options.open(&self.path)
|
||||
)?;
|
||||
|
||||
// Store the File in the slot and update the handle in the VirtualFile
|
||||
// to point to it.
|
||||
@@ -446,7 +462,9 @@ impl VirtualFile {
|
||||
|
||||
*handle_guard = handle;
|
||||
|
||||
Ok(result)
|
||||
return Ok(FileGuard {
|
||||
slot_guard: slot_guard.downgrade(),
|
||||
});
|
||||
}
|
||||
|
||||
pub fn remove(self) {
@@ -461,11 +479,9 @@ impl VirtualFile {
self.pos = offset;
}
SeekFrom::End(offset) => {
self.pos = self
.with_file(StorageIoOperation::Seek, |mut file| {
file.seek(SeekFrom::End(offset))
})
.await??
self.pos = with_file!(self, StorageIoOperation::Seek, |file| file
.as_ref()
.seek(SeekFrom::End(offset)))?
}
SeekFrom::Current(offset) => {
let pos = self.pos as i128 + offset as i128;
@@ -553,9 +569,9 @@ impl VirtualFile {
}

pub async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<usize, Error> {
let result = self
.with_file(StorageIoOperation::Read, |file| file.read_at(buf, offset))
.await?;
let result = with_file!(self, StorageIoOperation::Read, |file| file
.as_ref()
.read_at(buf, offset));
if let Ok(size) = result {
STORAGE_IO_SIZE
.with_label_values(&["read", &self.tenant_id, &self.timeline_id])
@@ -565,9 +581,9 @@ impl VirtualFile {
}

async fn write_at(&self, buf: &[u8], offset: u64) -> Result<usize, Error> {
let result = self
.with_file(StorageIoOperation::Write, |file| file.write_at(buf, offset))
.await?;
let result = with_file!(self, StorageIoOperation::Write, |file| file
.as_ref()
.write_at(buf, offset));
if let Ok(size) = result {
STORAGE_IO_SIZE
.with_label_values(&["write", &self.tenant_id, &self.timeline_id])
@@ -577,6 +593,18 @@ impl VirtualFile {
}
}

struct FileGuard<'a> {
slot_guard: RwLockReadGuard<'a, SlotInner>,
}

impl<'a> AsRef<File> for FileGuard<'a> {
fn as_ref(&self) -> &File {
// This unwrap is safe because we only create `FileGuard`s
// if we know that the file is Some.
self.slot_guard.file.as_ref().unwrap()
}
}

#[cfg(test)]
impl VirtualFile {
pub(crate) async fn read_blk(
@@ -609,22 +637,41 @@ impl VirtualFile {
impl Drop for VirtualFile {
/// If a VirtualFile is dropped, close the underlying file if it was open.
fn drop(&mut self) {
let handle = self.handle.get_mut().unwrap();
let handle = self.handle.get_mut();

// We could check with a read-lock first, to avoid waiting on an
// unrelated I/O.
let slot = &get_open_files().slots[handle.index];
let mut slot_guard = slot.inner.write().unwrap();
if slot_guard.tag == handle.tag {
slot.recently_used.store(false, Ordering::Relaxed);
// there is also operation "close-by-replace" for closes done on eviction for
// comparison.
if let Some(fd) = slot_guard.file.take() {
STORAGE_IO_TIME_METRIC
.get(StorageIoOperation::Close)
.observe_closure_duration(|| drop(fd));
fn clean_slot(slot: &Slot, mut slot_guard: RwLockWriteGuard<'_, SlotInner>, tag: u64) {
if slot_guard.tag == tag {
slot.recently_used.store(false, Ordering::Relaxed);
// there is also operation "close-by-replace" for closes done on eviction for
// comparison.
if let Some(fd) = slot_guard.file.take() {
STORAGE_IO_TIME_METRIC
.get(StorageIoOperation::Close)
.observe_closure_duration(|| drop(fd));
}
}
}

// We don't have async drop so we cannot directly await the lock here.
// Instead, first do a best-effort attempt at closing the underlying
// file descriptor by using `try_write`, and if that fails, spawn
// a tokio task to do it asynchronously: we just want it to be
// cleaned up eventually.
// Most of the time, the `try_lock` should succeed though,
// as we have `&mut self` access. In other words, if the slot
// is still occupied by our file, there should be no access from
// other I/O operations; the only other possible place to lock
// the slot is the lock algorithm looking for free slots.
let slot = &get_open_files().slots[handle.index];
if let Ok(slot_guard) = slot.inner.try_write() {
clean_slot(slot, slot_guard, handle.tag);
} else {
let tag = handle.tag;
tokio::spawn(async move {
let slot_guard = slot.inner.write().await;
clean_slot(slot, slot_guard, tag);
});
};
}
}
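The reworked `Drop` is the heart of this change: Rust has no async drop, so the code takes `try_write()` first and only falls back to a spawned task when the slot lock is contended. A condensed sketch of the same pattern, with the slot reduced to a single static `RwLock` (the real code reaches it through the global open-files table, and, like this sketch, it assumes a tokio runtime is available when dropping):

```rust
use std::sync::OnceLock;
use tokio::sync::RwLock;

// Stand-in for one cache slot; the real SlotInner also carries a tag.
static SLOT: OnceLock<RwLock<Option<std::fs::File>>> = OnceLock::new();

struct Cleanup;

impl Drop for Cleanup {
    fn drop(&mut self) {
        let slot = SLOT.get_or_init(|| RwLock::new(None));
        // Fast path: with `&mut self` the slot is usually uncontended,
        // so try_write almost always succeeds.
        if let Ok(mut guard) = slot.try_write() {
            guard.take(); // close the fd synchronously
        } else {
            // Slow path: someone (e.g. the victim search) holds the lock.
            // Defer the close; we only need it to happen eventually.
            tokio::spawn(async move {
                slot.write().await.take();
            });
        }
    }
}
```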
216
poetry.lock
generated
@@ -158,6 +158,28 @@ files = [
attrs = ">=16.0.0"
pluggy = ">=0.4.0"

[[package]]
name = "anyio"
version = "4.2.0"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
files = [
    {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"},
    {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"},
]

[package.dependencies]
exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""}
idna = ">=2.8"
sniffio = ">=1.1"
typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""}

[package.extras]
doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"]
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"]
trio = ["trio (>=0.23)"]

[[package]]
name = "async-timeout"
version = "4.0.3"
@@ -1064,6 +1086,100 @@ files = [
    {file = "graphql_core-3.2.1-py3-none-any.whl", hash = "sha256:f83c658e4968998eed1923a2e3e3eddd347e005ac0315fbb7ca4d70ea9156323"},
]

[[package]]
name = "h11"
version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
files = [
    {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
    {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]

[[package]]
name = "h2"
version = "4.1.0"
description = "HTTP/2 State-Machine based protocol implementation"
optional = false
python-versions = ">=3.6.1"
files = [
    {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"},
    {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"},
]

[package.dependencies]
hpack = ">=4.0,<5"
hyperframe = ">=6.0,<7"

[[package]]
name = "hpack"
version = "4.0.0"
description = "Pure-Python HPACK header compression"
optional = false
python-versions = ">=3.6.1"
files = [
    {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"},
    {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"},
]

[[package]]
name = "httpcore"
version = "1.0.2"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
    {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
    {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
]

[package.dependencies]
certifi = "*"
h11 = ">=0.13,<0.15"

[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.23.0)"]

[[package]]
name = "httpx"
version = "0.26.0"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
    {file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"},
    {file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"},
]

[package.dependencies]
anyio = "*"
certifi = "*"
h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""}
httpcore = "==1.*"
idna = "*"
sniffio = "*"

[package.extras]
brotli = ["brotli", "brotlicffi"]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]

[[package]]
name = "hyperframe"
version = "6.0.1"
description = "HTTP/2 framing layer for Python"
optional = false
python-versions = ">=3.6.1"
files = [
    {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"},
    {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"},
]

[[package]]
name = "idna"
version = "3.3"
@@ -1118,13 +1234,13 @@ files = [

[[package]]
name = "jinja2"
version = "3.1.2"
version = "3.1.3"
description = "A very fast and expressive template engine."
optional = false
python-versions = ">=3.7"
files = [
    {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
    {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
    {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"},
    {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"},
]

[package.dependencies]
@@ -2215,6 +2331,17 @@
    {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]

[[package]]
name = "sniffio"
version = "1.3.0"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
files = [
    {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"},
    {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
]

[[package]]
name = "sshpubkeys"
version = "3.3.1"
@@ -2378,6 +2505,87 @@ docs = ["Sphinx (>=3.4)", "sphinx-rtd-theme (>=0.5)"]
optional = ["python-socks", "wsaccel"]
test = ["websockets"]

[[package]]
name = "websockets"
version = "12.0"
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
optional = false
python-versions = ">=3.8"
files = [
    {file = "websockets-12.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d554236b2a2006e0ce16315c16eaa0d628dab009c33b63ea03f41c6107958374"},
    {file = "websockets-12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2d225bb6886591b1746b17c0573e29804619c8f755b5598d875bb4235ea639be"},
    {file = "websockets-12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:eb809e816916a3b210bed3c82fb88eaf16e8afcf9c115ebb2bacede1797d2547"},
    {file = "websockets-12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c588f6abc13f78a67044c6b1273a99e1cf31038ad51815b3b016ce699f0d75c2"},
    {file = "websockets-12.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5aa9348186d79a5f232115ed3fa9020eab66d6c3437d72f9d2c8ac0c6858c558"},
    {file = "websockets-12.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6350b14a40c95ddd53e775dbdbbbc59b124a5c8ecd6fbb09c2e52029f7a9f480"},
    {file = "websockets-12.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:70ec754cc2a769bcd218ed8d7209055667b30860ffecb8633a834dde27d6307c"},
    {file = "websockets-12.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6e96f5ed1b83a8ddb07909b45bd94833b0710f738115751cdaa9da1fb0cb66e8"},
    {file = "websockets-12.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4d87be612cbef86f994178d5186add3d94e9f31cc3cb499a0482b866ec477603"},
    {file = "websockets-12.0-cp310-cp310-win32.whl", hash = "sha256:befe90632d66caaf72e8b2ed4d7f02b348913813c8b0a32fae1cc5fe3730902f"},
    {file = "websockets-12.0-cp310-cp310-win_amd64.whl", hash = "sha256:363f57ca8bc8576195d0540c648aa58ac18cf85b76ad5202b9f976918f4219cf"},
    {file = "websockets-12.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5d873c7de42dea355d73f170be0f23788cf3fa9f7bed718fd2830eefedce01b4"},
    {file = "websockets-12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3f61726cae9f65b872502ff3c1496abc93ffbe31b278455c418492016e2afc8f"},
    {file = "websockets-12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed2fcf7a07334c77fc8a230755c2209223a7cc44fc27597729b8ef5425aa61a3"},
    {file = "websockets-12.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e332c210b14b57904869ca9f9bf4ca32f5427a03eeb625da9b616c85a3a506c"},
    {file = "websockets-12.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5693ef74233122f8ebab026817b1b37fe25c411ecfca084b29bc7d6efc548f45"},
    {file = "websockets-12.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e9e7db18b4539a29cc5ad8c8b252738a30e2b13f033c2d6e9d0549b45841c04"},
    {file = "websockets-12.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6e2df67b8014767d0f785baa98393725739287684b9f8d8a1001eb2839031447"},
    {file = "websockets-12.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:bea88d71630c5900690fcb03161ab18f8f244805c59e2e0dc4ffadae0a7ee0ca"},
    {file = "websockets-12.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dff6cdf35e31d1315790149fee351f9e52978130cef6c87c4b6c9b3baf78bc53"},
    {file = "websockets-12.0-cp311-cp311-win32.whl", hash = "sha256:3e3aa8c468af01d70332a382350ee95f6986db479ce7af14d5e81ec52aa2b402"},
    {file = "websockets-12.0-cp311-cp311-win_amd64.whl", hash = "sha256:25eb766c8ad27da0f79420b2af4b85d29914ba0edf69f547cc4f06ca6f1d403b"},
    {file = "websockets-12.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0e6e2711d5a8e6e482cacb927a49a3d432345dfe7dea8ace7b5790df5932e4df"},
    {file = "websockets-12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:dbcf72a37f0b3316e993e13ecf32f10c0e1259c28ffd0a85cee26e8549595fbc"},
    {file = "websockets-12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12743ab88ab2af1d17dd4acb4645677cb7063ef4db93abffbf164218a5d54c6b"},
    {file = "websockets-12.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b645f491f3c48d3f8a00d1fce07445fab7347fec54a3e65f0725d730d5b99cb"},
    {file = "websockets-12.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9893d1aa45a7f8b3bc4510f6ccf8db8c3b62120917af15e3de247f0780294b92"},
    {file = "websockets-12.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1f38a7b376117ef7aff996e737583172bdf535932c9ca021746573bce40165ed"},
    {file = "websockets-12.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f764ba54e33daf20e167915edc443b6f88956f37fb606449b4a5b10ba42235a5"},
    {file = "websockets-12.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:1e4b3f8ea6a9cfa8be8484c9221ec0257508e3a1ec43c36acdefb2a9c3b00aa2"},
    {file = "websockets-12.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9fdf06fd06c32205a07e47328ab49c40fc1407cdec801d698a7c41167ea45113"},
    {file = "websockets-12.0-cp312-cp312-win32.whl", hash = "sha256:baa386875b70cbd81798fa9f71be689c1bf484f65fd6fb08d051a0ee4e79924d"},
    {file = "websockets-12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae0a5da8f35a5be197f328d4727dbcfafa53d1824fac3d96cdd3a642fe09394f"},
    {file = "websockets-12.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5f6ffe2c6598f7f7207eef9a1228b6f5c818f9f4d53ee920aacd35cec8110438"},
    {file = "websockets-12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9edf3fc590cc2ec20dc9d7a45108b5bbaf21c0d89f9fd3fd1685e223771dc0b2"},
    {file = "websockets-12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8572132c7be52632201a35f5e08348137f658e5ffd21f51f94572ca6c05ea81d"},
    {file = "websockets-12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:604428d1b87edbf02b233e2c207d7d528460fa978f9e391bd8aaf9c8311de137"},
    {file = "websockets-12.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1a9d160fd080c6285e202327aba140fc9a0d910b09e423afff4ae5cbbf1c7205"},
    {file = "websockets-12.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87b4aafed34653e465eb77b7c93ef058516cb5acf3eb21e42f33928616172def"},
    {file = "websockets-12.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b2ee7288b85959797970114deae81ab41b731f19ebcd3bd499ae9ca0e3f1d2c8"},
    {file = "websockets-12.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:7fa3d25e81bfe6a89718e9791128398a50dec6d57faf23770787ff441d851967"},
    {file = "websockets-12.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a571f035a47212288e3b3519944f6bf4ac7bc7553243e41eac50dd48552b6df7"},
    {file = "websockets-12.0-cp38-cp38-win32.whl", hash = "sha256:3c6cc1360c10c17463aadd29dd3af332d4a1adaa8796f6b0e9f9df1fdb0bad62"},
    {file = "websockets-12.0-cp38-cp38-win_amd64.whl", hash = "sha256:1bf386089178ea69d720f8db6199a0504a406209a0fc23e603b27b300fdd6892"},
    {file = "websockets-12.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ab3d732ad50a4fbd04a4490ef08acd0517b6ae6b77eb967251f4c263011a990d"},
    {file = "websockets-12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1d9697f3337a89691e3bd8dc56dea45a6f6d975f92e7d5f773bc715c15dde28"},
    {file = "websockets-12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1df2fbd2c8a98d38a66f5238484405b8d1d16f929bb7a33ed73e4801222a6f53"},
    {file = "websockets-12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23509452b3bc38e3a057382c2e941d5ac2e01e251acce7adc74011d7d8de434c"},
    {file = "websockets-12.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2e5fc14ec6ea568200ea4ef46545073da81900a2b67b3e666f04adf53ad452ec"},
    {file = "websockets-12.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46e71dbbd12850224243f5d2aeec90f0aaa0f2dde5aeeb8fc8df21e04d99eff9"},
    {file = "websockets-12.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b81f90dcc6c85a9b7f29873beb56c94c85d6f0dac2ea8b60d995bd18bf3e2aae"},
    {file = "websockets-12.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:a02413bc474feda2849c59ed2dfb2cddb4cd3d2f03a2fedec51d6e959d9b608b"},
    {file = "websockets-12.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:bbe6013f9f791944ed31ca08b077e26249309639313fff132bfbf3ba105673b9"},
    {file = "websockets-12.0-cp39-cp39-win32.whl", hash = "sha256:cbe83a6bbdf207ff0541de01e11904827540aa069293696dd528a6640bd6a5f6"},
    {file = "websockets-12.0-cp39-cp39-win_amd64.whl", hash = "sha256:fc4e7fa5414512b481a2483775a8e8be7803a35b30ca805afa4998a84f9fd9e8"},
    {file = "websockets-12.0-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:248d8e2446e13c1d4326e0a6a4e9629cb13a11195051a73acf414812700badbd"},
    {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f44069528d45a933997a6fef143030d8ca8042f0dfaad753e2906398290e2870"},
    {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c4e37d36f0d19f0a4413d3e18c0d03d0c268ada2061868c1e6f5ab1a6d575077"},
    {file = "websockets-12.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d829f975fc2e527a3ef2f9c8f25e553eb7bc779c6665e8e1d52aa22800bb38b"},
    {file = "websockets-12.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:2c71bd45a777433dd9113847af751aae36e448bc6b8c361a566cb043eda6ec30"},
    {file = "websockets-12.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0bee75f400895aef54157b36ed6d3b308fcab62e5260703add87f44cee9c82a6"},
    {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:423fc1ed29f7512fceb727e2d2aecb952c46aa34895e9ed96071821309951123"},
    {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:27a5e9964ef509016759f2ef3f2c1e13f403725a5e6a1775555994966a66e931"},
    {file = "websockets-12.0-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3181df4583c4d3994d31fb235dc681d2aaad744fbdbf94c4802485ececdecf2"},
    {file = "websockets-12.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:b067cb952ce8bf40115f6c19f478dc71c5e719b7fbaa511359795dfd9d1a6468"},
    {file = "websockets-12.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:00700340c6c7ab788f176d118775202aadea7602c5cc6be6ae127761c16d6b0b"},
    {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e469d01137942849cff40517c97a30a93ae79917752b34029f0ec72df6b46399"},
    {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffefa1374cd508d633646d51a8e9277763a9b78ae71324183693959cf94635a7"},
    {file = "websockets-12.0-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba0cab91b3956dfa9f512147860783a1829a8d905ee218a9837c18f683239611"},
    {file = "websockets-12.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2cb388a5bfb56df4d9a406783b7f9dbefb888c09b71629351cc6b036e9259370"},
    {file = "websockets-12.0-py3-none-any.whl", hash = "sha256:dc284bbc8d7c78a6c69e0c7325ab46ee5e40bb4d50e494d8131a07ef47500e9e"},
    {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"},
]

[[package]]
name = "werkzeug"
version = "3.0.1"
@@ -2658,4 +2866,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "35c237fe6a9278b2dc65b06ed96bde5afb9e393d52c01b00c59acf1df3a8d482"
content-hash = "f750bd06f1937f0614204e0ffe9a293eb61a0d7d675a80d5849f40a22745b5f9"

@@ -5,7 +5,7 @@ edition.workspace = true
license.workspace = true

[features]
default = []
default = ["testing"]
testing = []

[dependencies]
@@ -14,6 +14,7 @@ async-trait.workspace = true
base64.workspace = true
bstr.workspace = true
bytes = { workspace = true, features = ["serde"] }
camino.workspace = true
chrono.workspace = true
clap.workspace = true
consumption_metrics.workspace = true
@@ -26,7 +27,6 @@ hex.workspace = true
hmac.workspace = true
hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
ipnet.workspace = true
itertools.workspace = true
@@ -35,6 +35,8 @@ metrics.workspace = true
once_cell.workspace = true
opentelemetry.workspace = true
parking_lot.workspace = true
parquet.workspace = true
parquet_derive.workspace = true
pbkdf2 = { workspace = true, features = ["simple", "std"] }
pin-project-lite.workspace = true
postgres_backend.workspace = true
@@ -42,6 +44,7 @@ pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true
regex.workspace = true
remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
reqwest = { workspace = true, features = ["json"] }
reqwest-middleware.workspace = true
reqwest-retry.workspace = true
@@ -62,11 +65,13 @@ tls-listener.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio-tungstenite.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
tungstenite.workspace = true
url.workspace = true
utils.workspace = true
uuid.workspace = true
@@ -75,11 +80,13 @@ x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
postgres-protocol.workspace = true
redis.workspace = true
smol_str.workspace = true

workspace_hack.workspace = true

[dev-dependencies]
camino-tempfile.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true

@@ -4,7 +4,7 @@ pub mod backend;
pub use backend::BackendType;

mod credentials;
pub use credentials::{check_peer_addr_is_in_list, ClientCredentials};
pub use credentials::{check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint};

mod password_hack;
pub use password_hack::parse_endpoint_param;

@@ -8,26 +8,27 @@ use tokio_postgres::config::AuthKeys;

use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::console::errors::GetAuthInfoError;
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::NeonOptions;
use crate::scram;
use crate::stream::Stream;
use crate::{
auth::{self, ClientCredentials},
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
metrics::LatencyTimer,
stream, url,
};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::net::IpAddr;
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -38,7 +39,7 @@ use tracing::{error, info, warn};
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ClientCredentials`],
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
@@ -56,7 +57,7 @@ pub enum BackendType<'a, T> {

pub trait TestBackend: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError>;
fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError>;
}

impl std::fmt::Display for BackendType<'_, ()> {
@@ -127,15 +128,23 @@ pub struct ComputeCredentials<T> {
pub keys: T,
}

#[derive(Debug, Clone)]
pub struct ComputeUserInfoNoEndpoint {
pub user: SmolStr,
pub peer_addr: IpAddr,
pub cache_key: SmolStr,
pub options: NeonOptions,
}

#[derive(Debug, Clone)]
pub struct ComputeUserInfo {
pub endpoint: SmolStr,
pub inner: ComputeUserInfoNoEndpoint,
pub user: SmolStr,
pub options: NeonOptions,
}

impl ComputeUserInfo {
pub fn endpoint_cache_key(&self) -> SmolStr {
self.options.get_cache_key(&self.endpoint)
}
}

pub enum ComputeCredentialKeys {
@@ -144,19 +153,21 @@ pub enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
}

impl TryFrom<ClientCredentials> for ComputeUserInfo {
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
// user name
type Error = ComputeUserInfoNoEndpoint;

fn try_from(creds: ClientCredentials) -> Result<Self, Self::Error> {
let inner = ComputeUserInfoNoEndpoint {
user: creds.user,
peer_addr: creds.peer_addr,
cache_key: creds.cache_key,
};
match creds.project {
None => Err(inner),
Some(endpoint) => Ok(ComputeUserInfo { endpoint, inner }),
fn try_from(user_info: ComputeUserInfoMaybeEndpoint) -> Result<Self, Self::Error> {
match user_info.project {
None => Err(ComputeUserInfoNoEndpoint {
user: user_info.user,
options: user_info.options,
}),
Some(endpoint) => Ok(ComputeUserInfo {
endpoint,
user: user_info.user,
options: user_info.options,
}),
}
}
}
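A detail worth noticing in the rewritten conversion: the `TryFrom` impl's error type is not a diagnostic but the other, still-usable representation, so the caller that falls back to the password-hack path receives a `ComputeUserInfoNoEndpoint` by value instead of re-deriving it. The same shape in miniature (types here are invented purely for illustration):

```rust
struct Input {
    endpoint: Option<String>,
    user: String,
}

struct WithEndpoint {
    endpoint: String,
    user: String,
}

struct WithoutEndpoint {
    user: String,
}

impl TryFrom<Input> for WithEndpoint {
    // The "error" is just the alternate, endpoint-less representation,
    // handed back by value so nothing needs to be cloned.
    type Error = WithoutEndpoint;

    fn try_from(i: Input) -> Result<Self, Self::Error> {
        match i.endpoint {
            Some(endpoint) => Ok(WithEndpoint { endpoint, user: i.user }),
            None => Err(WithoutEndpoint { user: i.user }),
        }
    }
}
```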
@@ -166,49 +177,53 @@ impl TryFrom<ClientCredentials> for ComputeUserInfo {
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
extra: &ConsoleReqExtra,
creds: ClientCredentials,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match creds.try_into() {
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, latency_timer).await?;
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
(res.info, Some(res.keys))
}
Ok(info) => (info, None),
};

info!("fetching user's authentication info");
let allowed_ips = api.get_allowed_ips(extra, &info).await?;
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;

// check allowed list
if !check_peer_addr_is_in_list(&info.inner.peer_addr, &allowed_ips) {
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed());
}
let cached_secret = api.get_role_secret(extra, &info).await?;
let maybe_secret = api.get_role_secret(ctx, &info).await?;

let secret = cached_secret.clone().unwrap_or_else(|| {
let cached_secret = maybe_secret.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(&info.inner.user, rand::random()))
Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock(
&info.user,
rand::random(),
)))
});
match authenticate_with_secret(
secret,
ctx,
cached_secret.value.clone(),
info,
client,
unauthenticated_password,
allow_cleartext,
config,
latency_timer,
)
.await
{
@@ -224,13 +239,13 @@ async fn auth_quirks(
}

async fn authenticate_with_secret(
ctx: &mut RequestMonitoring,
secret: AuthSecret,
info: ComputeUserInfo,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
unauthenticated_password: Option<Vec<u8>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
if let Some(password) = unauthenticated_password {
let auth_outcome = validate_password_and_exchange(&password, secret)?;
@@ -238,7 +253,7 @@ async fn authenticate_with_secret(
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.inner.user));
return Err(auth::AuthError::auth_failed(&*info.user));
}
};

@@ -253,38 +268,29 @@
// Perform cleartext auth if we're allowed to do that.
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
return hacks::authenticate_cleartext(info, client, latency_timer, secret).await;
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
}

// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(info, client, config, latency_timer, secret).await
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}

/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
/// only if authentication was successful.
async fn auth_and_wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
extra: &ConsoleReqExtra,
creds: ClientCredentials,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials = auth_quirks(
api,
extra,
creds,
client,
allow_cleartext,
config,
latency_timer,
)
.await?;
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;

let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(extra, &compute_credentials.info).await;
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
@@ -301,6 +307,8 @@ async fn auth_and_wake_compute(
tokio::time::sleep(wait_duration).await;
};

ctx.set_project(node.aux.clone());

match compute_credentials.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node.config.password(password),
@@ -310,15 +318,15 @@
Ok((node, compute_credentials.info))
}

impl<'a> BackendType<'a, ClientCredentials> {
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<SmolStr> {
use BackendType::*;

match self {
Console(_, creds) => creds.project.clone(),
Console(_, user_info) => user_info.project.clone(),
#[cfg(feature = "testing")]
Postgres(_, creds) => creds.project.clone(),
Postgres(_, user_info) => user_info.project.clone(),
Link(_) => Some("link".into()),
#[cfg(test)]
Test(_) => Some("test".into()),
@@ -330,9 +338,9 @@ impl<'a> BackendType<'a, ClientCredentials> {
use BackendType::*;

match self {
Console(_, creds) => &creds.user,
Console(_, user_info) => &user_info.user,
#[cfg(feature = "testing")]
Postgres(_, creds) => &creds.user,
Postgres(_, user_info) => &user_info.user,
Link(_) => "link",
#[cfg(test)]
Test(_) => "test",
@@ -343,52 +351,37 @@
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub async fn authenticate(
self,
extra: &ConsoleReqExtra,
ctx: &mut RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
latency_timer: &mut LatencyTimer,
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;

let res = match self {
Console(api, creds) => {
Console(api, user_info) => {
info!(
user = &*creds.user,
project = creds.project(),
user = &*user_info.user,
project = user_info.project(),
"performing authentication using the console"
);

let (cache_info, user_info) = auth_and_wake_compute(
&*api,
extra,
creds,
client,
allow_cleartext,
config,
latency_timer,
)
.await?;
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Console(api, user_info))
}
#[cfg(feature = "testing")]
Postgres(api, creds) => {
Postgres(api, user_info) => {
info!(
user = &*creds.user,
project = creds.project(),
user = &*user_info.user,
project = user_info.project(),
"performing authentication using a local postgres instance"
);

let (cache_info, user_info) = auth_and_wake_compute(
&*api,
extra,
creds,
client,
allow_cleartext,
config,
latency_timer,
)
.await?;
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Postgres(api, user_info))
}
// NOTE: this auth backend doesn't use client credentials.
@@ -416,16 +409,16 @@ impl<'a> BackendType<'a, ClientCredentials> {
impl BackendType<'_, ComputeUserInfo> {
pub async fn get_allowed_ips(
&self,
extra: &ConsoleReqExtra,
) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
ctx: &mut RequestMonitoring,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
use BackendType::*;
match self {
Console(api, creds) => api.get_allowed_ips(extra, creds).await,
Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api, creds) => api.get_allowed_ips(extra, creds).await,
Link(_) => Ok(Arc::new(vec![])),
Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
#[cfg(test)]
Test(x) => x.get_allowed_ips(),
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
}
}

@@ -433,14 +426,14 @@ impl BackendType<'_, ComputeUserInfo> {
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(
&self,
extra: &ConsoleReqExtra,
ctx: &mut RequestMonitoring,
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
use BackendType::*;

match self {
Console(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
#[cfg(feature = "testing")]
Postgres(api, creds) => api.wake_compute(extra, creds).map_ok(Some).await,
Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
#[cfg(test)]
Test(x) => x.wake_compute().map(Some),

@@ -54,7 +54,7 @@ pub(super) async fn authenticate(
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*creds.inner.user));
return Err(auth::AuthError::auth_failed(&*creds.user));
}
};

@@ -36,7 +36,7 @@ pub async fn authenticate_cleartext(
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.inner.user));
return Err(auth::AuthError::auth_failed(&*info.user));
}
};

@@ -67,7 +67,8 @@ pub async fn password_hack_no_authentication(
// Report tentative success; compute node will check the password anyway.
Ok(ComputeCredentials {
info: ComputeUserInfo {
inner: info,
user: info.user,
options: info.options,
endpoint: payload.endpoint,
},
keys: payload.password,

@@ -1,8 +1,8 @@
//! User credentials used in authentication.

use crate::{
auth::password_hack::parse_endpoint_param, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::neon_options_str,
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -12,7 +12,7 @@ use thiserror::Error;
use tracing::{info, warn};

#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
pub enum ComputeUserInfoParseError {
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),

@@ -33,39 +33,58 @@ pub enum ClientCredsParseError {
MalformedProjectName(SmolStr),
}

impl UserFacingError for ClientCredsParseError {}
impl UserFacingError for ComputeUserInfoParseError {}

/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials {
pub struct ComputeUserInfoMaybeEndpoint {
pub user: SmolStr,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
pub project: Option<SmolStr>,

pub cache_key: SmolStr,
pub peer_addr: IpAddr,
pub options: NeonOptions,
}

impl ClientCredentials {
impl ComputeUserInfoMaybeEndpoint {
#[inline]
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}

impl ClientCredentials {
pub fn endpoint_sni<'a>(
sni: &'a str,
common_names: &HashSet<String>,
) -> Result<&'a str, ComputeUserInfoParseError> {
let Some((subdomain, common_name)) = sni.split_once('.') else {
return Err(ComputeUserInfoParseError::UnknownCommonName { cn: sni.into() });
};
if !common_names.contains(common_name) {
return Err(ComputeUserInfoParseError::UnknownCommonName {
cn: common_name.into(),
});
}
Ok(subdomain)
}

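The new free function `endpoint_sni` collapses the old `subdomain_from_sni` lookup into a single split against the known common names: everything before the first dot is the endpoint, and the remainder must be a configured common name. A quick usage sketch, assuming the function and error type exactly as shown above:

```rust
use std::collections::HashSet;

fn endpoint_sni_demo() {
    let common_names: HashSet<String> = ["localhost".to_string()].into();

    // "foo.localhost" -> endpoint "foo", because "localhost" is a known CN.
    assert_eq!(
        endpoint_sni("foo.localhost", &common_names).ok(),
        Some("foo")
    );

    // An unknown common name is rejected rather than guessed at.
    assert!(endpoint_sni("foo.evil.example", &common_names).is_err());
}
```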
impl ComputeUserInfoMaybeEndpoint {
pub fn parse(
ctx: &mut RequestMonitoring,
params: &StartupMessageParams,
sni: Option<&str>,
common_names: Option<HashSet<String>>,
peer_addr: IpAddr,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
common_names: Option<&HashSet<String>>,
) -> Result<Self, ComputeUserInfoParseError> {
use ComputeUserInfoParseError::*;

// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let user = get_param("user")?.into();
let user: SmolStr = get_param("user")?.into();

// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));
ctx.set_user(user.clone());
ctx.set_endpoint_id(sni.map(SmolStr::from));

// Project name might be passed via PG's command-line options.
let project_option = params
@@ -83,21 +102,7 @@ impl ClientCredentials {

let project_from_domain = if let Some(sni_str) = sni {
if let Some(cn) = common_names {
let common_name_from_sni = sni_str.split_once('.').map(|(_, domain)| domain);

let project = common_name_from_sni
.and_then(|domain| {
if cn.contains(domain) {
subdomain_from_sni(sni_str, domain)
} else {
None
}
})
.ok_or_else(|| UnknownCommonName {
cn: common_name_from_sni.unwrap_or("").into(),
})?;

Some(project)
Some(SmolStr::from(endpoint_sni(sni_str, cn)?))
} else {
None
}
@@ -136,23 +141,17 @@ impl ClientCredentials {
info!("Connection with password hack");
}

let cache_key = format!(
"{}{}",
project.as_deref().unwrap_or(""),
neon_options_str(params)
)
.into();
let options = NeonOptions::parse_params(params);

Ok(Self {
user,
project,
cache_key,
peer_addr,
options,
})
}
}

pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<String>) -> bool {
pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &Vec<SmolStr>) -> bool {
if ip_list.is_empty() {
return true;
}
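Note the deliberate fail-open at the top of `check_peer_addr_is_in_list`: an empty list means "no allowlist configured", so every peer is accepted, and only a non-empty list restricts access. In sketch form (a hypothetical helper with a naive exact-match comparison standing in for the real address check):

```rust
use std::net::IpAddr;

fn peer_allowed(peer: &IpAddr, allowlist: &[String]) -> bool {
    // Empty allowlist == feature disabled, so allow everyone.
    if allowlist.is_empty() {
        return true;
    }
    allowlist.iter().any(|entry| *entry == peer.to_string())
}
```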
@@ -204,25 +203,19 @@ fn project_name_valid(name: &str) -> bool {
name.chars().all(|c| c.is_alphanumeric() || c == '-')
}

fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<SmolStr> {
sni.strip_suffix(common_name)?
.strip_suffix('.')
.map(SmolStr::from)
}

#[cfg(test)]
mod tests {
use super::*;
use ClientCredsParseError::*;
use ComputeUserInfoParseError::*;

#[test]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);

Ok(())
}
@@ -234,10 +227,10 @@ mod tests {
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project, None);

Ok(())
}
@@ -249,11 +242,12 @@ mod tests {
let sni = Some("foo.localhost");
let common_names = Some(["localhost".into()].into());

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
assert_eq!(creds.cache_key, "foo");
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("foo"));
assert_eq!(user_info.options.get_cache_key("foo"), "foo");

Ok(())
}
@@ -265,10 +259,10 @@
("options", "-ckey=1 project=bar -c geqo=off"),
]);

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));

Ok(())
}
@@ -280,10 +274,10 @@
("options", "-ckey=1 endpoint=bar -c geqo=off"),
]);

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("bar"));

Ok(())
}
@@ -298,10 +292,10 @@
),
]);

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());

Ok(())
}
@@ -313,10 +307,10 @@
("options", "-ckey=1 endpoint=bar project=foo -c geqo=off"),
]);

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, None, None, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert!(creds.project.is_none());
let mut ctx = RequestMonitoring::test();
let user_info = ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, None, None)?;
assert_eq!(user_info.user, "john_doe");
assert!(user_info.project.is_none());

Ok(())
}
@@ -328,10 +322,11 @@
let sni = Some("baz.localhost");
let common_names = Some(["localhost".into()].into());

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.user, "john_doe");
assert_eq!(user_info.project.as_deref(), Some("baz"));

Ok(())
}
@@ -342,15 +337,17 @@

let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.a.com");
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));

let common_names = Some(["a.com".into(), "b.com".into()].into());
let sni = Some("p1.b.com");
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("p1"));
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("p1"));

Ok(())
}
@@ -363,9 +360,10 @@
let sni = Some("second.localhost");
let common_names = Some(["localhost".into()].into());

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
.expect_err("should fail");
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -382,9 +380,10 @@
let sni = Some("project.localhost");
let common_names = Some(["example.com".into()].into());

let peer_addr = IpAddr::from([127, 0, 0, 1]);
let err = ClientCredentials::parse(&options, sni, common_names, peer_addr)
.expect_err("should fail");
let mut ctx = RequestMonitoring::test();
let err =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())
.expect_err("should fail");
match err {
UnknownCommonName { cn } => {
assert_eq!(cn, "localhost");
@@ -402,10 +401,14 @@

let sni = Some("project.localhost");
let common_names = Some(["localhost".into()].into());
let peer_addr = IpAddr::from([127, 0, 0, 1]);
let creds = ClientCredentials::parse(&options, sni, common_names, peer_addr)?;
assert_eq!(creds.project.as_deref(), Some("project"));
assert_eq!(creds.cache_key, "projectendpoint_type:read_write lsn:0/2");
let mut ctx = RequestMonitoring::test();
let user_info =
ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &options, sni, common_names.as_ref())?;
assert_eq!(user_info.project.as_deref(), Some("project"));
assert_eq!(
user_info.options.get_cache_key("project"),
"project endpoint_type:read_write lsn:0/2"
);

Ok(())
}

@@ -8,6 +8,7 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::proxy::run_until_cancelled;
use tokio::net::TcpListener;

@@ -170,7 +171,16 @@ async fn task_main(
.context("failed to set socket option")?;

info!(%peer_addr, "serving");
handle_client(dest_suffix, tls_config, tls_server_end_point, socket).await
let mut ctx =
    RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
handle_client(
&mut ctx,
dest_suffix,
tls_config,
tls_server_end_point,
socket,
)
.await
}
.unwrap_or_else(|e| {
// Acknowledge that the task has finished with an error.
@@ -236,6 +246,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
}

async fn handle_client(
ctx: &mut RequestMonitoring,
dest_suffix: Arc<String>,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
@@ -261,5 +272,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;

let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

@@ -3,14 +3,14 @@ use proxy::auth;
use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::console;
use proxy::console::provider::AllowedIpsCache;
use proxy::console::provider::NodeInfoCache;
use proxy::console::provider::RoleSecretCache;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
use proxy::rate_limiter::EndpointRateLimiter;
use proxy::rate_limiter::RateBucketInfo;
use proxy::rate_limiter::RateLimiterConfig;
use proxy::redis::notifications;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;

@@ -44,6 +44,9 @@ enum AuthBackend {
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
struct ProxyCliArgs {
    /// Name of the region this proxy is deployed in
    #[clap(long, default_value_t = String::new())]
    region: String,
    /// listen for incoming client connections on ip:port
    #[clap(short, long, default_value = "127.0.0.1:4432")]
    proxy: String,
@@ -133,6 +136,15 @@ struct ProxyCliArgs {
    /// disable ip check for http requests. If it is too time consuming, it could be turned off.
    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
    disable_ip_check_for_http: bool,
    /// redis url for notifications.
    #[clap(long)]
    redis_notifications: Option<String>,
    /// cache for `project_info` (use `size=0` to disable)
    #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
    project_info_cache: String,

    #[clap(flatten)]
    parquet_upload: ParquetUploadArgs,
}

#[derive(clap::Args, Clone, Copy, Debug)]
@@ -221,6 +233,11 @@ async fn main() -> anyhow::Result<()> {
        ));
    }

    client_tasks.spawn(proxy::context::parquet::worker(
        cancellation_token.clone(),
        args.parquet_upload,
    ));

    // maintenance tasks. these never return unless there's an error
    let mut maintenance_tasks = JoinSet::new();
    maintenance_tasks.spawn(proxy::handle_signals(cancellation_token));
@@ -231,6 +248,15 @@ async fn main() -> anyhow::Result<()> {
        maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
    }

    if let auth::BackendType::Console(api, _) = &config.auth_backend {
        let cache = api.caches.project_info.clone();
        if let Some(url) = args.redis_notifications {
            info!("Starting redis notifications listener ({url})");
            maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
        }
        maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
    }

    let maintenance = loop {
        // get one complete task
        match futures::future::select(
@@ -296,32 +322,17 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
    let auth_backend = match &args.auth_backend {
        AuthBackend::Console => {
            let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
            let allowed_ips_cache_config: CacheOptions = args.allowed_ips_cache.parse()?;
            let role_secret_cache_config: CacheOptions = args.role_secret_cache.parse()?;
            let project_info_cache_config: ProjectInfoCacheOptions =
                args.project_info_cache.parse()?;

            info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
            info!("Using AllowedIpsCache (wake_compute) with options={allowed_ips_cache_config:?}");
            info!("Using RoleSecretCache (wake_compute) with options={role_secret_cache_config:?}");
            let caches = Box::leak(Box::new(console::caches::ApiCaches {
                node_info: NodeInfoCache::new(
                    "node_info_cache",
                    wake_compute_cache_config.size,
                    wake_compute_cache_config.ttl,
                    true,
                ),
                allowed_ips: AllowedIpsCache::new(
                    "allowed_ips_cache",
                    allowed_ips_cache_config.size,
                    allowed_ips_cache_config.ttl,
                    false,
                ),
                role_secret: RoleSecretCache::new(
                    "role_secret_cache",
                    role_secret_cache_config.size,
                    role_secret_cache_config.ttl,
                    false,
                ),
            }));
            info!(
                "Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
            );
            let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
                wake_compute_cache_config,
                project_info_cache_config,
            )));

            let config::WakeComputeLockOptions {
                shards,
@@ -380,6 +391,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
        require_client_ip: args.require_client_ip,
        disable_ip_check_for_http: args.disable_ip_check_for_http,
        endpoint_rps_limit,
        // TODO: add this argument
        region: args.region.clone(),
    }));

    Ok(config)

@@ -1,311 +1,6 @@
use std::{
    borrow::Borrow,
    hash::Hash,
    ops::{Deref, DerefMut},
    time::{Duration, Instant},
};
use tracing::debug;

// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
//   This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};

/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
/// This is useful for [`timed_lru::Cached`].
pub trait Cache {
    /// Entry's key.
    type Key;

    /// Entry's value.
    type Value;

    /// Used for entry invalidation.
    type LookupInfo<Key>;

    /// Invalidate an entry using a lookup info.
    /// We don't have an empty default impl because it's error-prone.
    fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
}

impl<C: Cache> Cache for &C {
    type Key = C::Key;
    type Value = C::Value;
    type LookupInfo<Key> = C::LookupInfo<Key>;

    fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
        C::invalidate(self, info)
    }
}
pub mod common;
pub mod project_info;
mod timed_lru;

pub use common::{Cache, Cached};
pub use timed_lru::TimedLru;
pub mod timed_lru {
    use super::*;

    /// An implementation of timed LRU cache with fixed capacity.
    /// Key properties:
    ///
    /// * Whenever a new entry is inserted, the least recently accessed one is evicted.
    ///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
    ///
    /// * If `update_ttl_on_retrieval` is `true`, then when the entry is about to be retrieved, we check its expiration timestamp.
    ///   If the entry has expired, we remove it from the cache; otherwise we bump the
    ///   expiration timestamp (e.g. +5mins) and change its place in the LRU list to prolong
    ///   its existence.
    ///
    /// * There's an API for immediate invalidation (removal) of a cache entry;
    ///   It's useful in case we know for sure that the entry is no longer correct.
    ///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
    ///
    /// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
    ///   or by a successful lookup (i.e. the entry hasn't expired yet).
    ///   There is no background job to reap the expired records.
    ///
    /// * It's possible for an entry that has not yet expired to be evicted
    ///   before expired items. That's a bit wasteful, but probably fine in practice.
    pub struct TimedLru<K, V> {
        /// Cache's name for tracing.
        name: &'static str,

        /// The underlying cache implementation.
        cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,

        /// Default time-to-live of a single entry.
        ttl: Duration,

        update_ttl_on_retrieval: bool,
    }

    impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
        type Key = K;
        type Value = V;
        type LookupInfo<Key> = LookupInfo<Key>;

        fn invalidate(&self, info: &Self::LookupInfo<K>) {
            self.invalidate_raw(info)
        }
    }

    struct Entry<T> {
        created_at: Instant,
        expires_at: Instant,
        value: T,
    }

    impl<K: Hash + Eq, V> TimedLru<K, V> {
        /// Construct a new LRU cache with timed entries.
        pub fn new(
            name: &'static str,
            capacity: usize,
            ttl: Duration,
            update_ttl_on_retrieval: bool,
        ) -> Self {
            Self {
                name,
                cache: LruCache::new(capacity).into(),
                ttl,
                update_ttl_on_retrieval,
            }
        }

        /// Drop an entry from the cache if it's outdated.
        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
        fn invalidate_raw(&self, info: &LookupInfo<K>) {
            let now = Instant::now();

            // Do costly things before taking the lock.
            let mut cache = self.cache.lock();
            let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
                RawEntryMut::Vacant(_) => return,
                RawEntryMut::Occupied(x) => x,
            };

            // Remove the entry if it was created prior to lookup timestamp.
            let entry = raw_entry.get();
            let (created_at, expires_at) = (entry.created_at, entry.expires_at);
            let should_remove = created_at <= info.created_at || expires_at <= now;

            if should_remove {
                raw_entry.remove();
            }

            drop(cache); // drop lock before logging
            debug!(
                created_at = format_args!("{created_at:?}"),
                expires_at = format_args!("{expires_at:?}"),
                entry_removed = should_remove,
                "processed a cache entry invalidation event"
            );
        }

        /// Try retrieving an entry by its key, then execute `extract` if it exists.
        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
        fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
        where
            K: Borrow<Q>,
            Q: Hash + Eq + ?Sized,
        {
            let now = Instant::now();
            let deadline = now.checked_add(self.ttl).expect("time overflow");

            // Do costly things before taking the lock.
            let mut cache = self.cache.lock();
            let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
                RawEntryMut::Vacant(_) => return None,
                RawEntryMut::Occupied(x) => x,
            };

            // Immediately drop the entry if it has expired.
            let entry = raw_entry.get();
            if entry.expires_at <= now {
                raw_entry.remove();
                return None;
            }

            let value = extract(raw_entry.key(), entry);
            let (created_at, expires_at) = (entry.created_at, entry.expires_at);

            // Update the deadline and the entry's position in the LRU list.
            if self.update_ttl_on_retrieval {
                raw_entry.get_mut().expires_at = deadline;
            }
            raw_entry.to_back();

            drop(cache); // drop lock before logging
            debug!(
                created_at = format_args!("{created_at:?}"),
                old_expires_at = format_args!("{expires_at:?}"),
                new_expires_at = format_args!("{deadline:?}"),
                "accessed a cache entry"
            );

            Some(value)
        }

        /// Insert an entry into the cache. If an entry with the same key already
        /// existed, return the previous value and its creation timestamp.
        #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
        fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
            let created_at = Instant::now();
            let expires_at = created_at.checked_add(self.ttl).expect("time overflow");

            let entry = Entry {
                created_at,
                expires_at,
                value,
            };

            // Do costly things before taking the lock.
            let old = self
                .cache
                .lock()
                .insert(key, entry)
                .map(|entry| entry.value);

            debug!(
                created_at = format_args!("{created_at:?}"),
                expires_at = format_args!("{expires_at:?}"),
                replaced = old.is_some(),
                "created a cache entry"
            );

            (created_at, old)
        }
    }

    impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
        pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
            let (created_at, old) = self.insert_raw(key.clone(), value.clone());

            let cached = Cached {
                token: Some((self, LookupInfo { created_at, key })),
                value,
            };

            (old, cached)
        }
    }

    impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
        /// Retrieve a cached entry in a convenient wrapper.
        pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
        where
            K: Borrow<Q> + Clone,
            Q: Hash + Eq + ?Sized,
        {
            self.get_raw(key, |key, entry| {
                let info = LookupInfo {
                    created_at: entry.created_at,
                    key: key.clone(),
                };

                Cached {
                    token: Some((self, info)),
                    value: entry.value.clone(),
                }
            })
        }
    }

    /// Lookup information for key invalidation.
    pub struct LookupInfo<K> {
        /// Time of creation of a cache [`Entry`].
        /// We use this during invalidation lookups to prevent eviction of a newer
        /// entry sharing the same key (it might've been inserted by a different
        /// task after we got the entry we're trying to invalidate now).
        created_at: Instant,

        /// Search by this key.
        key: K,
    }

    /// Wrapper for convenient entry invalidation.
    pub struct Cached<C: Cache> {
        /// Cache + lookup info.
        token: Option<(C, C::LookupInfo<C::Key>)>,

        /// The value itself.
        value: C::Value,
    }

    impl<C: Cache> Cached<C> {
        /// Place any entry into this wrapper; invalidation will be a no-op.
        pub fn new_uncached(value: C::Value) -> Self {
            Self { token: None, value }
        }

        /// Drop this entry from a cache if it's still there.
        pub fn invalidate(self) -> C::Value {
            if let Some((cache, info)) = &self.token {
                cache.invalidate(info);
            }
            self.value
        }

        /// Tell if this entry is actually cached.
        pub fn cached(&self) -> bool {
            self.token.is_some()
        }
    }

    impl<C: Cache> Deref for Cached<C> {
        type Target = C::Value;

        fn deref(&self) -> &Self::Target {
            &self.value
        }
    }

    impl<C: Cache> DerefMut for Cached<C> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.value
        }
    }
}

proxy/src/cache/common.rs (vendored, new file, 72 lines)
@@ -0,0 +1,72 @@
use std::ops::{Deref, DerefMut};

/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
/// This is useful for [`Cached`].
pub trait Cache {
    /// Entry's key.
    type Key;

    /// Entry's value.
    type Value;

    /// Used for entry invalidation.
    type LookupInfo<Key>;

    /// Invalidate an entry using a lookup info.
    /// We don't have an empty default impl because it's error-prone.
    fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
}

impl<C: Cache> Cache for &C {
    type Key = C::Key;
    type Value = C::Value;
    type LookupInfo<Key> = C::LookupInfo<Key>;

    fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
        C::invalidate(self, info)
    }
}

/// Wrapper for convenient entry invalidation.
pub struct Cached<C: Cache, V = <C as Cache>::Value> {
    /// Cache + lookup info.
    pub token: Option<(C, C::LookupInfo<C::Key>)>,

    /// The value itself.
    pub value: V,
}

impl<C: Cache, V> Cached<C, V> {
    /// Place any entry into this wrapper; invalidation will be a no-op.
    pub fn new_uncached(value: V) -> Self {
        Self { token: None, value }
    }

    /// Drop this entry from a cache if it's still there.
    pub fn invalidate(self) -> V {
        if let Some((cache, info)) = &self.token {
            cache.invalidate(info);
        }
        self.value
    }

    /// Tell if this entry is actually cached.
    pub fn cached(&self) -> bool {
        self.token.is_some()
    }
}

impl<C: Cache, V> Deref for Cached<C, V> {
    type Target = V;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<C: Cache, V> DerefMut for Cached<C, V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}
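For illustration (not part of the diff): a minimal sketch of how the `Cached` wrapper above composes with a `Cache` implementation; `NoopCache` and `demo` are hypothetical names invented for this sketch.

// Hypothetical illustration only -- not part of the change set.
struct NoopCache;

impl Cache for NoopCache {
    type Key = String;
    type Value = u32;
    type LookupInfo<Key> = Key;

    // A real cache would drop the matching entry here.
    fn invalidate(&self, _info: &String) {}
}

fn demo() {
    // Wrap a value that never came from a cache; invalidation is a no-op.
    let cached = Cached::<NoopCache, u32>::new_uncached(42);
    assert!(!cached.cached());
    assert_eq!(*cached, 42); // `Deref` exposes the inner value.
    assert_eq!(cached.invalidate(), 42); // returns the value unchanged
}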
proxy/src/cache/project_info.rs (vendored, new file, 496 lines)
@@ -0,0 +1,496 @@
use std::{
    collections::HashSet,
    convert::Infallible,
    sync::{atomic::AtomicU64, Arc},
    time::Duration,
};

use dashmap::DashMap;
use rand::{thread_rng, Rng};
use smol_str::SmolStr;
use tokio::time::Instant;
use tracing::{debug, info};

use crate::{config::ProjectInfoCacheOptions, console::AuthSecret};

use super::{Cache, Cached};

pub trait ProjectInfoCache {
    fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr);
    fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr);
    fn enable_ttl(&self);
    fn disable_ttl(&self);
}

struct Entry<T> {
    created_at: Instant,
    value: T,
}

impl<T> Entry<T> {
    pub fn new(value: T) -> Self {
        Self {
            created_at: Instant::now(),
            value,
        }
    }
}

impl<T> From<T> for Entry<T> {
    fn from(value: T) -> Self {
        Self::new(value)
    }
}

#[derive(Default)]
struct EndpointInfo {
    secret: std::collections::HashMap<SmolStr, Entry<AuthSecret>>,
    allowed_ips: Option<Entry<Arc<Vec<SmolStr>>>>,
}

impl EndpointInfo {
    fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
        match ignore_cache_since {
            None => false,
            Some(t) => t < created_at,
        }
    }
    pub fn get_role_secret(
        &self,
        role_name: &SmolStr,
        valid_since: Instant,
        ignore_cache_since: Option<Instant>,
    ) -> Option<(AuthSecret, bool)> {
        if let Some(secret) = self.secret.get(role_name) {
            if valid_since < secret.created_at {
                return Some((
                    secret.value.clone(),
                    Self::check_ignore_cache(ignore_cache_since, secret.created_at),
                ));
            }
        }
        None
    }

    pub fn get_allowed_ips(
        &self,
        valid_since: Instant,
        ignore_cache_since: Option<Instant>,
    ) -> Option<(Arc<Vec<SmolStr>>, bool)> {
        if let Some(allowed_ips) = &self.allowed_ips {
            if valid_since < allowed_ips.created_at {
                return Some((
                    allowed_ips.value.clone(),
                    Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
                ));
            }
        }
        None
    }
    pub fn invalidate_allowed_ips(&mut self) {
        self.allowed_ips = None;
    }
    pub fn invalidate_role_secret(&mut self, role_name: &SmolStr) {
        self.secret.remove(role_name);
    }
}

/// Cache for project info.
/// This is used to cache auth data for endpoints.
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
///
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
/// One may ask why the data is stored per project when a user request only carries data about the endpoint.
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
    cache: DashMap<SmolStr, EndpointInfo>,

    project2ep: DashMap<SmolStr, HashSet<SmolStr>>,
    config: ProjectInfoCacheOptions,

    start_time: Instant,
    ttl_disabled_since_us: AtomicU64,
}

impl ProjectInfoCache for ProjectInfoCacheImpl {
    fn invalidate_allowed_ips_for_project(&self, project_id: &SmolStr) {
        info!("invalidating allowed ips for project `{}`", project_id);
        let endpoints = self
            .project2ep
            .get(project_id)
            .map(|kv| kv.value().clone())
            .unwrap_or_default();
        for endpoint_id in endpoints {
            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
                endpoint_info.invalidate_allowed_ips();
            }
        }
    }
    fn invalidate_role_secret_for_project(&self, project_id: &SmolStr, role_name: &SmolStr) {
        info!(
            "invalidating role secret for project_id `{}` and role_name `{}`",
            project_id, role_name
        );
        let endpoints = self
            .project2ep
            .get(project_id)
            .map(|kv| kv.value().clone())
            .unwrap_or_default();
        for endpoint_id in endpoints {
            if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
                endpoint_info.invalidate_role_secret(role_name);
            }
        }
    }
    fn enable_ttl(&self) {
        self.ttl_disabled_since_us
            .store(u64::MAX, std::sync::atomic::Ordering::Relaxed);
    }

    fn disable_ttl(&self) {
        let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
        self.ttl_disabled_since_us
            .store(new_ttl, std::sync::atomic::Ordering::Relaxed);
    }
}

impl ProjectInfoCacheImpl {
    pub fn new(config: ProjectInfoCacheOptions) -> Self {
        Self {
            cache: DashMap::new(),
            project2ep: DashMap::new(),
            config,
            ttl_disabled_since_us: AtomicU64::new(u64::MAX),
            start_time: Instant::now(),
        }
    }

    pub fn get_role_secret(
        &self,
        endpoint_id: &SmolStr,
        role_name: &SmolStr,
    ) -> Option<Cached<&Self, AuthSecret>> {
        let (valid_since, ignore_cache_since) = self.get_cache_times();
        let endpoint_info = self.cache.get(endpoint_id)?;
        let (value, ignore_cache) =
            endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
        if !ignore_cache {
            let cached = Cached {
                token: Some((
                    self,
                    CachedLookupInfo::new_role_secret(endpoint_id.clone(), role_name.clone()),
                )),
                value,
            };
            return Some(cached);
        }
        Some(Cached::new_uncached(value))
    }
    pub fn get_allowed_ips(
        &self,
        endpoint_id: &SmolStr,
    ) -> Option<Cached<&Self, Arc<Vec<SmolStr>>>> {
        let (valid_since, ignore_cache_since) = self.get_cache_times();
        let endpoint_info = self.cache.get(endpoint_id)?;
        let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
        let (value, ignore_cache) = value?;
        if !ignore_cache {
            let cached = Cached {
                token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id.clone()))),
                value,
            };
            return Some(cached);
        }
        Some(Cached::new_uncached(value))
    }
    pub fn insert_role_secret(
        &self,
        project_id: &SmolStr,
        endpoint_id: &SmolStr,
        role_name: &SmolStr,
        secret: AuthSecret,
    ) {
        if self.cache.len() >= self.config.size {
            // If there are too many entries, wait until the next gc cycle.
            return;
        }
        self.inser_project2endpoint(project_id, endpoint_id);
        let mut entry = self.cache.entry(endpoint_id.clone()).or_default();
        if entry.secret.len() < self.config.max_roles {
            entry.secret.insert(role_name.clone(), secret.into());
        }
    }
    pub fn insert_allowed_ips(
        &self,
        project_id: &SmolStr,
        endpoint_id: &SmolStr,
        allowed_ips: Arc<Vec<SmolStr>>,
    ) {
        if self.cache.len() >= self.config.size {
            // If there are too many entries, wait until the next gc cycle.
            return;
        }
        self.inser_project2endpoint(project_id, endpoint_id);
        self.cache
            .entry(endpoint_id.clone())
            .or_default()
            .allowed_ips = Some(allowed_ips.into());
    }
    fn inser_project2endpoint(&self, project_id: &SmolStr, endpoint_id: &SmolStr) {
        if let Some(mut endpoints) = self.project2ep.get_mut(project_id) {
            endpoints.insert(endpoint_id.clone());
        } else {
            self.project2ep
                .insert(project_id.clone(), HashSet::from([endpoint_id.clone()]));
        }
    }
    fn get_cache_times(&self) -> (Instant, Option<Instant>) {
        let mut valid_since = Instant::now() - self.config.ttl;
        // Only ignore cache if ttl is disabled.
        let ttl_disabled_since_us = self
            .ttl_disabled_since_us
            .load(std::sync::atomic::Ordering::Relaxed);
        let ignore_cache_since = if ttl_disabled_since_us != u64::MAX {
            let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
            // We are fine if the entry is not older than the TTL, or if it was added before we started getting notifications.
            valid_since = valid_since.min(ignore_cache_since);
            Some(ignore_cache_since)
        } else {
            None
        };
        (valid_since, ignore_cache_since)
    }

    pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
        let mut interval =
            tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
        loop {
            interval.tick().await;
            if self.cache.len() <= self.config.size {
                // If there are not too many entries, wait until the next gc cycle.
                continue;
            }
            self.gc();
        }
    }

    fn gc(&self) {
        let shard = thread_rng().gen_range(0..self.project2ep.shards().len());
        debug!(shard, "project_info_cache: performing epoch reclamation");

        // acquire a random shard lock
        let mut removed = 0;
        let shard = self.project2ep.shards()[shard].write();
        for (_, endpoints) in shard.iter() {
            for endpoint in endpoints.get().iter() {
                self.cache.remove(endpoint);
                removed += 1;
            }
        }
        // We can drop this shard only after making sure that all endpoints are removed.
        drop(shard);
        info!("project_info_cache: removed {removed} endpoints");
    }
}

/// Lookup info for project info cache.
/// This is used to invalidate cache entries.
pub struct CachedLookupInfo {
    /// Search by this key.
    endpoint_id: SmolStr,
    lookup_type: LookupType,
}

impl CachedLookupInfo {
    pub(self) fn new_role_secret(endpoint_id: SmolStr, role_name: SmolStr) -> Self {
        Self {
            endpoint_id,
            lookup_type: LookupType::RoleSecret(role_name),
        }
    }
    pub(self) fn new_allowed_ips(endpoint_id: SmolStr) -> Self {
        Self {
            endpoint_id,
            lookup_type: LookupType::AllowedIps,
        }
    }
}

enum LookupType {
    RoleSecret(SmolStr),
    AllowedIps,
}

impl Cache for ProjectInfoCacheImpl {
    type Key = SmolStr;
    // Value is not really used here, but we need to specify it.
    type Value = SmolStr;

    type LookupInfo<Key> = CachedLookupInfo;

    fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
        match &key.lookup_type {
            LookupType::RoleSecret(role_name) => {
                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
                    endpoint_info.invalidate_role_secret(role_name);
                }
            }
            LookupType::AllowedIps => {
                if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
                    endpoint_info.invalidate_allowed_ips();
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{console::AuthSecret, scram::ServerSecret};
    use smol_str::SmolStr;
    use std::{sync::Arc, time::Duration};

    #[tokio::test]
    async fn test_project_info_cache_settings() {
        tokio::time::pause();
        let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        });
        let project_id = "project".into();
        let endpoint_id = "endpoint".into();
        let user1: SmolStr = "user1".into();
        let user2: SmolStr = "user2".into();
        let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
        let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
        let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
        cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
        cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
        cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret1);
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, secret2);

        // Shouldn't add more than 2 roles.
        let user3: SmolStr = "user3".into();
        let secret3 = AuthSecret::Scram(ServerSecret::mock(user3.as_str(), [3; 32]));
        cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
        assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(cached.cached());
        assert_eq!(cached.value, allowed_ips);

        tokio::time::advance(Duration::from_secs(2)).await;
        let cached = cache.get_role_secret(&endpoint_id, &user1);
        assert!(cached.is_none());
        let cached = cache.get_role_secret(&endpoint_id, &user2);
        assert!(cached.is_none());
        let cached = cache.get_allowed_ips(&endpoint_id);
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_project_info_cache_invalidations() {
        tokio::time::pause();
        let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        }));
        cache.clone().disable_ttl();
        tokio::time::advance(Duration::from_secs(2)).await;

        let project_id = "project".into();
        let endpoint_id = "endpoint".into();
        let user1: SmolStr = "user1".into();
        let user2: SmolStr = "user2".into();
        let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
        let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
        let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
        cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
        cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
        cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());

        tokio::time::advance(Duration::from_secs(2)).await;
        // Nothing should be invalidated.

        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        // TTL is disabled, so it should be impossible to invalidate this value.
        assert!(!cached.cached());
        assert_eq!(cached.value, secret1);

        cached.invalidate(); // Shouldn't do anything.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert_eq!(cached.value, secret1);

        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, secret2);

        // The only way to invalidate this value is to invalidate via the api.
        cache.invalidate_role_secret_for_project(&project_id, &user2);
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }

    #[tokio::test]
    async fn test_disable_ttl_invalidate_added_before() {
        tokio::time::pause();
        let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
            size: 2,
            max_roles: 2,
            ttl: Duration::from_secs(1),
            gc_interval: Duration::from_secs(600),
        }));

        let project_id = "project".into();
        let endpoint_id = "endpoint".into();
        let user1: SmolStr = "user1".into();
        let user2: SmolStr = "user2".into();
        let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
        let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
        let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
        cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
        cache.clone().disable_ttl();
        tokio::time::advance(Duration::from_millis(100)).await;
        cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());

        // Added before ttl was disabled + ttl should be still cached.
        let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
        assert!(cached.cached());
        let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
        assert!(cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());

        // Added after ttl was disabled + ttl should not be cached.
        cache.insert_allowed_ips(&project_id, &endpoint_id, allowed_ips.clone());
        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());

        tokio::time::advance(Duration::from_secs(1)).await;
        // Added before ttl was disabled + ttl still should expire.
        assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
        assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
        // Shouldn't be invalidated.

        let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
        assert!(!cached.cached());
        assert_eq!(cached.value, allowed_ips);
    }
}
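For intuition (not part of the diff): the `disable_ttl`/`get_cache_times` bookkeeping above stores the TTL cutoff as microseconds elapsed since `start_time`, with `u64::MAX` meaning "TTL enabled". A standalone sketch of that arithmetic, with invented function names and std types instead of tokio's:

use std::time::{Duration, Instant};

// Value stored by `disable_ttl`: entries created before this cutoff still
// expire by TTL; entries created after it are served without a cache token.
fn cutoff_after_disable(start_time: Instant, ttl: Duration) -> u64 {
    (start_time.elapsed() + ttl).as_micros() as u64
}

// Decoding done by `get_cache_times`: map the stored micros back into an
// `Instant`, treating `u64::MAX` as "TTL is enabled, no cutoff".
fn ignore_cache_since(start_time: Instant, stored: u64) -> Option<Instant> {
    (stored != u64::MAX).then(|| start_time + Duration::from_micros(stored))
}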
proxy/src/cache/timed_lru.rs (vendored, new file, 258 lines)
@@ -0,0 +1,258 @@
use std::{
    borrow::Borrow,
    hash::Hash,
    time::{Duration, Instant},
};
use tracing::debug;

// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
//   (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their constraints (which is deliberately excluded from docs).
//   This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{linked_hash_map::RawEntryMut, LruCache};

use super::{common::Cached, *};

/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
///   The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * If `update_ttl_on_retrieval` is `true`, then when the entry is about to be retrieved, we check its expiration timestamp.
///   If the entry has expired, we remove it from the cache; otherwise we bump the
///   expiration timestamp (e.g. +5mins) and change its place in the LRU list to prolong
///   its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
///   It's useful in case we know for sure that the entry is no longer correct.
///   See [`timed_lru::LookupInfo`] & [`timed_lru::Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
///   or by a successful lookup (i.e. the entry hasn't expired yet).
///   There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired to be evicted
///   before expired items. That's a bit wasteful, but probably fine in practice.
pub struct TimedLru<K, V> {
    /// Cache's name for tracing.
    name: &'static str,

    /// The underlying cache implementation.
    cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,

    /// Default time-to-live of a single entry.
    ttl: Duration,

    update_ttl_on_retrieval: bool,
}

impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
    type Key = K;
    type Value = V;
    type LookupInfo<Key> = LookupInfo<Key>;

    fn invalidate(&self, info: &Self::LookupInfo<K>) {
        self.invalidate_raw(info)
    }
}

struct Entry<T> {
    created_at: Instant,
    expires_at: Instant,
    value: T,
}

impl<K: Hash + Eq, V> TimedLru<K, V> {
    /// Construct a new LRU cache with timed entries.
    pub fn new(
        name: &'static str,
        capacity: usize,
        ttl: Duration,
        update_ttl_on_retrieval: bool,
    ) -> Self {
        Self {
            name,
            cache: LruCache::new(capacity).into(),
            ttl,
            update_ttl_on_retrieval,
        }
    }

    /// Drop an entry from the cache if it's outdated.
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn invalidate_raw(&self, info: &LookupInfo<K>) {
        let now = Instant::now();

        // Do costly things before taking the lock.
        let mut cache = self.cache.lock();
        let raw_entry = match cache.raw_entry_mut().from_key(&info.key) {
            RawEntryMut::Vacant(_) => return,
            RawEntryMut::Occupied(x) => x,
        };

        // Remove the entry if it was created prior to lookup timestamp.
        let entry = raw_entry.get();
        let (created_at, expires_at) = (entry.created_at, entry.expires_at);
        let should_remove = created_at <= info.created_at || expires_at <= now;

        if should_remove {
            raw_entry.remove();
        }

        drop(cache); // drop lock before logging
        debug!(
            created_at = format_args!("{created_at:?}"),
            expires_at = format_args!("{expires_at:?}"),
            entry_removed = should_remove,
            "processed a cache entry invalidation event"
        );
    }

    /// Try retrieving an entry by its key, then execute `extract` if it exists.
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let now = Instant::now();
        let deadline = now.checked_add(self.ttl).expect("time overflow");

        // Do costly things before taking the lock.
        let mut cache = self.cache.lock();
        let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
            RawEntryMut::Vacant(_) => return None,
            RawEntryMut::Occupied(x) => x,
        };

        // Immediately drop the entry if it has expired.
        let entry = raw_entry.get();
        if entry.expires_at <= now {
            raw_entry.remove();
            return None;
        }

        let value = extract(raw_entry.key(), entry);
        let (created_at, expires_at) = (entry.created_at, entry.expires_at);

        // Update the deadline and the entry's position in the LRU list.
        if self.update_ttl_on_retrieval {
            raw_entry.get_mut().expires_at = deadline;
        }
        raw_entry.to_back();

        drop(cache); // drop lock before logging
        debug!(
            created_at = format_args!("{created_at:?}"),
            old_expires_at = format_args!("{expires_at:?}"),
            new_expires_at = format_args!("{deadline:?}"),
            "accessed a cache entry"
        );

        Some(value)
    }

    /// Insert an entry into the cache. If an entry with the same key already
    /// existed, return the previous value and its creation timestamp.
    #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
    fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
        let created_at = Instant::now();
        let expires_at = created_at.checked_add(self.ttl).expect("time overflow");

        let entry = Entry {
            created_at,
            expires_at,
            value,
        };

        // Do costly things before taking the lock.
        let old = self
            .cache
            .lock()
            .insert(key, entry)
            .map(|entry| entry.value);

        debug!(
            created_at = format_args!("{created_at:?}"),
            expires_at = format_args!("{expires_at:?}"),
            replaced = old.is_some(),
            "created a cache entry"
        );

        (created_at, old)
    }
}

impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
    pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
        let (created_at, old) = self.insert_raw(key.clone(), value.clone());

        let cached = Cached {
            token: Some((self, LookupInfo { created_at, key })),
            value,
        };

        (old, cached)
    }
}

impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
    /// Retrieve a cached entry in a convenient wrapper.
    pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
    where
        K: Borrow<Q> + Clone,
        Q: Hash + Eq + ?Sized,
    {
        self.get_raw(key, |key, entry| {
            let info = LookupInfo {
                created_at: entry.created_at,
                key: key.clone(),
            };

            Cached {
                token: Some((self, info)),
                value: entry.value.clone(),
            }
        })
    }

    /// Retrieve a cached entry in a convenient wrapper, ignoring its TTL.
    pub fn get_ignoring_ttl<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq + ?Sized,
    {
        let mut cache = self.cache.lock();
        cache
            .get(key)
            .map(|entry| Cached::new_uncached(entry.value.clone()))
    }

    /// Remove an entry from the cache.
    pub fn remove<Q>(&self, key: &Q) -> Option<V>
    where
        K: Borrow<Q> + Clone,
        Q: Hash + Eq + ?Sized,
    {
        let mut cache = self.cache.lock();
        cache.remove(key).map(|entry| entry.value)
    }
}

/// Lookup information for key invalidation.
pub struct LookupInfo<K> {
    /// Time of creation of a cache [`Entry`].
    /// We use this during invalidation lookups to prevent eviction of a newer
    /// entry sharing the same key (it might've been inserted by a different
    /// task after we got the entry we're trying to invalidate now).
    created_at: Instant,

    /// Search by this key.
    key: K,
}
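For illustration (not part of the diff): a typical `TimedLru` round trip under the API above; the cache name, capacity, and TTL are made-up parameters, and `demo` is an invented name.

use std::time::Duration;

fn demo() {
    // Hypothetical parameters; `true` bumps the TTL on every successful get.
    let cache: TimedLru<String, u32> =
        TimedLru::new("demo_cache", 64, Duration::from_secs(5), true);

    let (old, cached) = cache.insert("key".to_owned(), 1);
    assert!(old.is_none() && cached.cached());

    // A hit hands back a `Cached` token that can invalidate the entry later.
    if let Some(hit) = cache.get("key") {
        assert_eq!(*hit, 1);
        hit.invalidate(); // explicit removal, e.g. after a failed connection
    }
    assert!(cache.get("key").is_none());
}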
@@ -1,6 +1,7 @@
use crate::{
    auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
    error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE, proxy::neon_option,
    context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
    proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
@@ -38,7 +39,17 @@ impl UserFacingError for ConnectionError {
            // This helps us drop irrelevant library-specific prefixes.
            // TODO: propagate severity level and other parameters.
            Postgres(err) => match err.as_db_error() {
                Some(err) => err.message().to_owned(),
                Some(err) => {
                    let msg = err.message();

                    if msg.starts_with("unsupported startup parameter: ")
                        || msg.starts_with("unsupported startup parameter in options: ")
                    {
                        format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter")
                    } else {
                        msg.to_owned()
                    }
                }
                None => err.to_string(),
            },
            WakeComputeError(err) => err.to_string_client(),
@@ -232,9 +243,9 @@ impl ConnCfg {
    /// Connect to a corresponding compute node.
    pub async fn connect(
        &self,
        ctx: &mut RequestMonitoring,
        allow_self_signed_compute: bool,
        timeout: Duration,
        proto: &'static str,
    ) -> Result<PostgresConnection, ConnectionError> {
        let (socket_addr, stream, host) = self.connect_raw(timeout).await?;

@@ -268,7 +279,9 @@ impl ConnCfg {
            stream,
            params,
            cancel_closure,
            _guage: NUM_DB_CONNECTIONS_GAUGE.with_label_values(&[proto]).guard(),
            _guage: NUM_DB_CONNECTIONS_GAUGE
                .with_label_values(&[ctx.protocol])
                .guard(),
        };

        Ok(connection)

@@ -21,6 +21,7 @@ pub struct ProxyConfig {
    pub require_client_ip: bool,
    pub disable_ip_check_for_http: bool,
    pub endpoint_rps_limit: Vec<RateBucketInfo>,
    pub region: String,
}

#[derive(Debug)]
@@ -31,7 +32,7 @@ pub struct MetricCollectionConfig {

pub struct TlsConfig {
    pub config: Arc<rustls::ServerConfig>,
    pub common_names: Option<HashSet<String>>,
    pub common_names: HashSet<String>,
    pub cert_resolver: Arc<CertResolver>,
}

@@ -96,7 +97,7 @@ pub fn configure_tls(

    Ok(TlsConfig {
        config,
        common_names: Some(common_names),
        common_names,
        cert_resolver,
    })
}
@@ -351,6 +352,69 @@ impl FromStr for CacheOptions {
    }
}

/// Helper for cmdline cache options parsing.
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
    /// Max number of entries.
    pub size: usize,
    /// Entry's time-to-live.
    pub ttl: Duration,
    /// Max number of roles per endpoint.
    pub max_roles: usize,
    /// GC interval.
    pub gc_interval: Duration,
}

impl ProjectInfoCacheOptions {
    /// Default options for [`crate::console::provider::NodeInfoCache`].
    pub const CACHE_DEFAULT_OPTIONS: &'static str =
        "size=10000,ttl=4m,max_roles=10,gc_interval=60m";

    /// Parse cache options passed via cmdline.
    /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
    fn parse(options: &str) -> anyhow::Result<Self> {
        let mut size = None;
        let mut ttl = None;
        let mut max_roles = None;
        let mut gc_interval = None;

        for option in options.split(',') {
            let (key, value) = option
                .split_once('=')
                .with_context(|| format!("bad key-value pair: {option}"))?;

            match key {
                "size" => size = Some(value.parse()?),
                "ttl" => ttl = Some(humantime::parse_duration(value)?),
                "max_roles" => max_roles = Some(value.parse()?),
                "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
                unknown => bail!("unknown key: {unknown}"),
            }
        }

        // TTL doesn't matter if cache is always empty.
        if let Some(0) = size {
            ttl.get_or_insert(Duration::default());
        }

        Ok(Self {
            size: size.context("missing `size`")?,
            ttl: ttl.context("missing `ttl`")?,
            max_roles: max_roles.context("missing `max_roles`")?,
            gc_interval: gc_interval.context("missing `gc_interval`")?,
        })
    }
}

impl FromStr for ProjectInfoCacheOptions {
    type Err = anyhow::Error;

    fn from_str(options: &str) -> Result<Self, Self::Err> {
        let error = || format!("failed to parse cache options '{options}'");
        Self::parse(options).with_context(error)
    }
}
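For illustration (not part of the diff): parsing the default option string through the `FromStr` impl above; `demo` is an invented name.

fn demo() -> anyhow::Result<()> {
    // The default string from CACHE_DEFAULT_OPTIONS round-trips through the parser.
    let opts: ProjectInfoCacheOptions =
        "size=10000,ttl=4m,max_roles=10,gc_interval=60m".parse()?;
    assert_eq!(opts.size, 10_000);
    assert_eq!(opts.max_roles, 10);
    assert_eq!(opts.ttl, std::time::Duration::from_secs(4 * 60));
    assert_eq!(opts.gc_interval, std::time::Duration::from_secs(60 * 60));
    Ok(())
}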
|
||||
/// Helper for cmdline cache options parsing.
|
||||
pub struct WakeComputeLockOptions {
|
||||
/// The number of shards the lock map should have
|
||||
|
||||
@@ -6,7 +6,7 @@ pub mod messages;
|
||||
|
||||
/// Wrappers for console APIs and their mocks.
|
||||
pub mod provider;
|
||||
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
|
||||
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
|
||||
|
||||
/// Various cache-related types.
|
||||
pub mod caches {
|
||||
|
||||
@@ -15,6 +15,7 @@ pub struct ConsoleError {
|
||||
pub struct GetRoleSecret {
|
||||
pub role_secret: Box<str>,
|
||||
pub allowed_ips: Option<Vec<Box<str>>>,
|
||||
pub project_id: Option<Box<str>>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
@@ -207,12 +208,17 @@ mod tests {
|
||||
"role_secret": "secret",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
// Empty `allowed_ips` field.
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
let json = json!({
|
||||
"role_secret": "secret",
|
||||
"allowed_ips": ["8.8.8.8"],
|
||||
"project_id": "project",
|
||||
});
|
||||
let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,17 +5,18 @@ pub mod neon;
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::backend::ComputeUserInfo,
|
||||
cache::{timed_lru, TimedLru},
|
||||
compute, scram,
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
context::RequestMonitoring,
|
||||
scram,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use smol_str::SmolStr;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::{
|
||||
sync::{OwnedSemaphorePermit, Semaphore},
|
||||
time::Instant,
|
||||
};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
pub mod errors {
|
||||
@@ -196,28 +197,8 @@ pub mod errors {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extra query params we'd like to pass to the console.
|
||||
pub struct ConsoleReqExtra {
|
||||
/// A unique identifier for a connection.
|
||||
pub session_id: uuid::Uuid,
|
||||
/// Name of client application, if set.
|
||||
pub application_name: String,
|
||||
pub options: Vec<(String, String)>,
|
||||
}
|
||||
|
||||
impl ConsoleReqExtra {
|
||||
// https://swagger.io/docs/specification/serialization/ DeepObject format
|
||||
// paramName[prop1]=value1¶mName[prop2]=value2&....
|
||||
pub fn options_as_deep_object(&self) -> Vec<(String, String)> {
|
||||
self.options
|
||||
.iter()
|
||||
.map(|(k, v)| (format!("options[{}]", k), v.to_string()))
|
||||
.collect()
|
||||
}
|
||||
}

/// Auth secret which is managed by the cloud.
#[derive(Clone)]
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum AuthSecret {
    #[cfg(feature = "testing")]
    /// Md5 hash of user's password.
@@ -231,7 +212,9 @@ pub enum AuthSecret {
pub struct AuthInfo {
    pub secret: Option<AuthSecret>,
    /// List of IP addresses allowed for the authorization.
    pub allowed_ips: Vec<String>,
    pub allowed_ips: Vec<SmolStr>,
    /// Project ID. This is used for cache invalidation.
    pub project_id: Option<SmolStr>,
}

/// Info for establishing a connection to a compute node.
@@ -250,33 +233,34 @@ pub struct NodeInfo {
    pub allow_self_signed_compute: bool,
}

pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
pub type AllowedIpsCache = TimedLru<SmolStr, Arc<Vec<String>>>;
pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option<AuthSecret>>;
pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>;
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;

/// This will allocate on each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
pub trait Api {
    /// Get the client's auth secret for authentication.
    /// Returns an option because the user-not-found situation is special:
    /// we still have to mock the scram to avoid leaking information that the user doesn't exist.
    async fn get_role_secret(
        &self,
        extra: &ConsoleReqExtra,
        ctx: &mut RequestMonitoring,
        creds: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
    ) -> Result<Option<CachedRoleSecret>, errors::GetAuthInfoError>;

    async fn get_allowed_ips(
        &self,
        extra: &ConsoleReqExtra,
        ctx: &mut RequestMonitoring,
        creds: &ComputeUserInfo,
    ) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
    ) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;

    /// Wake up the compute node and return the corresponding connection info.
    async fn wake_compute(
        &self,
        extra: &ConsoleReqExtra,
        ctx: &mut RequestMonitoring,
        creds: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
@@ -285,16 +269,31 @@ pub trait Api {
pub struct ApiCaches {
    /// Cache for the `wake_compute` API method.
    pub node_info: NodeInfoCache,
    /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
    pub allowed_ips: AllowedIpsCache,
    /// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead.
    pub role_secret: RoleSecretCache,
    /// Cache which stores project_id -> endpoint_ids mapping.
    pub project_info: Arc<ProjectInfoCacheImpl>,
}

impl ApiCaches {
    pub fn new(
        wake_compute_cache_config: CacheOptions,
        project_info_cache_config: ProjectInfoCacheOptions,
    ) -> Self {
        Self {
            node_info: NodeInfoCache::new(
                "node_info_cache",
                wake_compute_cache_config.size,
                wake_compute_cache_config.ttl,
                true,
            ),
            project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
        }
    }
}

/// Various caches for [`console`](super).
pub struct ApiLocks {
    name: &'static str,
    node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
    node_locks: DashMap<SmolStr, Arc<Semaphore>>,
    permits: usize,
    timeout: Duration,
    registered: prometheus::IntCounter,
@@ -362,7 +361,7 @@ impl ApiLocks {

    pub async fn get_wake_compute_permit(
        &self,
        key: &Arc<str>,
        key: &SmolStr,
    ) -> Result<WakeComputePermit, errors::WakeComputeError> {
        if self.permits == 0 {
            return Ok(WakeComputePermit { permit: None });
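
The per-endpoint locking that `ApiLocks` implements can be sketched with just `dashmap` and `tokio`. The struct below is a simplified stand-in for the real type (no metrics, no timeout, and all names are assumptions for illustration), but it shows the core idea: one semaphore per key, with permits held across await points:

use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

/// Simplified per-key concurrency limiter: at most `permits`
/// concurrent wake-ups per endpoint key.
struct NodeLocks {
    locks: DashMap<String, Arc<Semaphore>>,
    permits: usize,
}

impl NodeLocks {
    async fn acquire(&self, key: &str) -> OwnedSemaphorePermit {
        let semaphore = self
            .locks
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(Semaphore::new(self.permits)))
            .clone();
        // `acquire_owned` ties the permit's lifetime to the Arc, so it can
        // be held across await points and dropped when the wake-up finishes.
        semaphore.acquire_owned().await.expect("semaphore not closed")
    }
}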

@@ -1,15 +1,17 @@
//! Mock console backend which relies on a user-provided postgres instance.

use std::sync::Arc;

use super::{
    errors::{ApiError, GetAuthInfoError, WakeComputeError},
    AuthInfo, AuthSecret, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
    AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::console::provider::CachedRoleSecret;
use crate::cache::Cached;
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use async_trait::async_trait;
use futures::TryFutureExt;
use smol_str::SmolStr;
use std::sync::Arc;
use thiserror::Error;
use tokio_postgres::{config::SslMode, Client};
use tracing::{error, info, info_span, warn, Instrument};
@@ -48,7 +50,7 @@ impl Api {

    async fn do_get_auth_info(
        &self,
        creds: &ComputeUserInfo,
        user_info: &ComputeUserInfo,
    ) -> Result<AuthInfo, GetAuthInfoError> {
        let (secret, allowed_ips) = async {
            // Perhaps we could persist this connection, but then we'd have to
@@ -61,7 +63,7 @@ impl Api {
            let secret = match get_execute_postgres_query(
                &client,
                "select rolpassword from pg_catalog.pg_authid where rolname = $1",
                &[&&*creds.inner.user],
                &[&&*user_info.user],
                "rolpassword",
            )
            .await?
@@ -72,14 +74,14 @@ impl Api {
                    secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
                }
                None => {
                    warn!("user '{}' does not exist", creds.inner.user);
                    warn!("user '{}' does not exist", user_info.user);
                    None
                }
            };
            let allowed_ips = match get_execute_postgres_query(
                &client,
                "select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
                &[&creds.endpoint.as_str()],
                &[&user_info.endpoint.as_str()],
                "allowed_ips",
            )
            .await?
@@ -98,7 +100,8 @@ impl Api {
            .await?;
        Ok(AuthInfo {
            secret,
            allowed_ips,
            allowed_ips: allowed_ips.iter().map(SmolStr::from).collect(),
            project_id: None,
        })
    }

@@ -145,27 +148,31 @@ impl super::Api for Api {
    #[tracing::instrument(skip_all)]
    async fn get_role_secret(
        &self,
        _extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, GetAuthInfoError> {
        Ok(CachedRoleSecret::new_uncached(
            self.do_get_auth_info(creds).await?.secret,
        ))
        _ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
        Ok(self
            .do_get_auth_info(user_info)
            .await?
            .secret
            .map(CachedRoleSecret::new_uncached))
    }

    async fn get_allowed_ips(
        &self,
        _extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
    ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
        Ok(Arc::new(self.do_get_auth_info(creds).await?.allowed_ips))
        _ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedAllowedIps, GetAuthInfoError> {
        Ok(Cached::new_uncached(Arc::new(
            self.do_get_auth_info(user_info).await?.allowed_ips,
        )))
    }

    #[tracing::instrument(skip_all)]
    async fn wake_compute(
        &self,
        _extra: &ConsoleReqExtra,
        _creds: &ComputeUserInfo,
        _ctx: &mut RequestMonitoring,
        _user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        self.do_wake_compute()
            .map_ok(CachedNodeInfo::new_uncached)

@@ -3,14 +3,19 @@
use super::{
    super::messages::{ConsoleError, GetRoleSecret, WakeCompute},
    errors::{ApiError, GetAuthInfoError, WakeComputeError},
    ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedNodeInfo, CachedRoleSecret, ConsoleReqExtra,
    ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
    NodeInfo,
};
use crate::metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER};
use crate::{auth::backend::ComputeUserInfo, compute, http, scram};
use crate::{
    cache::Cached,
    context::RequestMonitoring,
    metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use itertools::Itertools;
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
@@ -19,7 +24,7 @@ use tracing::{error, info, info_span, warn, Instrument};
#[derive(Clone)]
pub struct Api {
    endpoint: http::Endpoint,
    caches: &'static ApiCaches,
    pub caches: &'static ApiCaches,
    locks: &'static ApiLocks,
    jwt: String,
}
@@ -49,21 +54,22 @@ impl Api {

    async fn do_get_auth_info(
        &self,
        extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<AuthInfo, GetAuthInfoError> {
        let request_id = uuid::Uuid::new_v4().to_string();
        let application_name = ctx.console_application_name();
        async {
            let request = self
                .endpoint
                .get("proxy_get_role_secret")
                .header("X-Request-ID", &request_id)
                .header("Authorization", format!("Bearer {}", &self.jwt))
                .query(&[("session_id", extra.session_id)])
                .query(&[("session_id", ctx.session_id)])
                .query(&[
                    ("application_name", extra.application_name.as_str()),
                    ("project", creds.endpoint.as_str()),
                    ("role", creds.inner.user.as_str()),
                    ("application_name", application_name.as_str()),
                    ("project", user_info.endpoint.as_str()),
                    ("role", user_info.user.as_str()),
                ])
                .build()?;

@@ -87,12 +93,13 @@ impl Api {
                .allowed_ips
                .into_iter()
                .flatten()
                .map(String::from)
                .map(SmolStr::from)
                .collect_vec();
            ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
            Ok(AuthInfo {
                secret: Some(secret),
                allowed_ips,
                project_id: body.project_id.map(SmolStr::from),
            })
        }
        .map_err(crate::error::log_error)
@@ -102,27 +109,28 @@ impl Api {

    async fn do_wake_compute(
        &self,
        extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<NodeInfo, WakeComputeError> {
        let request_id = uuid::Uuid::new_v4().to_string();
        let application_name = ctx.console_application_name();
        async {
            let mut request_builder = self
                .endpoint
                .get("proxy_wake_compute")
                .header("X-Request-ID", &request_id)
                .header("Authorization", format!("Bearer {}", &self.jwt))
                .query(&[("session_id", extra.session_id)])
                .query(&[("session_id", ctx.session_id)])
                .query(&[
                    ("application_name", extra.application_name.as_str()),
                    ("project", creds.endpoint.as_str()),
                    ("application_name", application_name.as_str()),
                    ("project", user_info.endpoint.as_str()),
                ]);

            request_builder = if extra.options.is_empty() {
                request_builder
            } else {
                request_builder.query(&extra.options_as_deep_object())
            };
            let options = user_info.options.to_deep_object();
            if !options.is_empty() {
                request_builder = request_builder.query(&options);
            }

            let request = request_builder.build()?;

            info!(url = request.url().as_str(), "sending http request");
@@ -162,69 +170,77 @@ impl super::Api for Api {
    #[tracing::instrument(skip_all)]
    async fn get_role_secret(
        &self,
        extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
    ) -> Result<CachedRoleSecret, GetAuthInfoError> {
        let ep = creds.endpoint.clone();
        let user = creds.inner.user.clone();
        if let Some(role_secret) = self.caches.role_secret.get(&(ep.clone(), user.clone())) {
            return Ok(role_secret);
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
        let ep = &user_info.endpoint;
        let user = &user_info.user;
        if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
            return Ok(Some(role_secret));
        }
        let auth_info = self.do_get_auth_info(extra, creds).await?;
        let (_, secret) = self
            .caches
            .role_secret
            .insert((ep.clone(), user), auth_info.secret.clone());
        self.caches
            .allowed_ips
            .insert(ep, Arc::new(auth_info.allowed_ips));
        Ok(secret)
        let auth_info = self.do_get_auth_info(ctx, user_info).await?;
        let project_id = auth_info.project_id.unwrap_or(ep.clone());
        if let Some(secret) = &auth_info.secret {
            self.caches
                .project_info
                .insert_role_secret(&project_id, ep, user, secret.clone())
        }
        self.caches.project_info.insert_allowed_ips(
            &project_id,
            ep,
            Arc::new(auth_info.allowed_ips),
        );
        // When we just got a secret, we don't need to invalidate it.
        Ok(auth_info.secret.map(Cached::new_uncached))
    }

    async fn get_allowed_ips(
        &self,
        extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
    ) -> Result<Arc<Vec<String>>, GetAuthInfoError> {
        if let Some(allowed_ips) = self.caches.allowed_ips.get(&creds.endpoint) {
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedAllowedIps, GetAuthInfoError> {
        let ep = &user_info.endpoint;
        if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
            ALLOWED_IPS_BY_CACHE_OUTCOME
                .with_label_values(&["hit"])
                .inc();
            return Ok(Arc::new(allowed_ips.to_vec()));
            return Ok(allowed_ips);
        }
        ALLOWED_IPS_BY_CACHE_OUTCOME
            .with_label_values(&["miss"])
            .inc();
        let auth_info = self.do_get_auth_info(extra, creds).await?;
        let auth_info = self.do_get_auth_info(ctx, user_info).await?;
        let allowed_ips = Arc::new(auth_info.allowed_ips);
        let ep = creds.endpoint.clone();
        let user = creds.inner.user.clone();
        let user = &user_info.user;
        let project_id = auth_info.project_id.unwrap_or(ep.clone());
        if let Some(secret) = &auth_info.secret {
            self.caches
                .project_info
                .insert_role_secret(&project_id, ep, user, secret.clone())
        }
        self.caches
            .role_secret
            .insert((ep.clone(), user), auth_info.secret);
        self.caches.allowed_ips.insert(ep, allowed_ips.clone());
        Ok(allowed_ips)
            .project_info
            .insert_allowed_ips(&project_id, ep, allowed_ips.clone());
        Ok(Cached::new_uncached(allowed_ips))
    }

    #[tracing::instrument(skip_all)]
    async fn wake_compute(
        &self,
        extra: &ConsoleReqExtra,
        creds: &ComputeUserInfo,
        ctx: &mut RequestMonitoring,
        user_info: &ComputeUserInfo,
    ) -> Result<CachedNodeInfo, WakeComputeError> {
        let key: &str = &creds.inner.cache_key;
        let key = user_info.endpoint_cache_key();

        // Every time we do a wakeup http request, the compute node will stay up
        // for some time (highly depends on the console's scale-to-zero policy);
        // The connection info remains the same during that period of time,
        // which means that we might cache it to reduce the load and latency.
        if let Some(cached) = self.caches.node_info.get(key) {
            info!(key = key, "found cached compute node info");
        if let Some(cached) = self.caches.node_info.get(&*key) {
            info!(key = &*key, "found cached compute node info");
            return Ok(cached);
        }

        let key: Arc<str> = key.into();

        let permit = self.locks.get_wake_compute_permit(&key).await?;

        // after getting back a permit - it's possible the cache was filled
@@ -236,7 +252,7 @@ impl super::Api for Api {
        }
    }

        let node = self.do_wake_compute(extra, creds).await?;
        let node = self.do_wake_compute(ctx, user_info).await?;
        let (_, cached) = self.caches.node_info.insert(key.clone(), node);
        info!(key = &*key, "created a cache entry for compute node info");

proxy/src/context.rs (new file, 110 lines)
@@ -0,0 +1,110 @@
//! Connection request monitoring contexts

use chrono::Utc;
use once_cell::sync::OnceCell;
use smol_str::SmolStr;
use std::net::IpAddr;
use tokio::sync::mpsc;
use uuid::Uuid;

use crate::{console::messages::MetricsAuxInfo, error::ErrorKind, metrics::LatencyTimer};

pub mod parquet;

static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestMonitoring>> = OnceCell::new();

#[derive(Clone)]
/// Context data for a single request to connect to a database.
///
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
pub struct RequestMonitoring {
    pub peer_addr: IpAddr,
    pub session_id: Uuid,
    pub protocol: &'static str,
    first_packet: chrono::DateTime<Utc>,
    region: &'static str,

    // filled in as they are discovered
    project: Option<SmolStr>,
    branch: Option<SmolStr>,
    endpoint_id: Option<SmolStr>,
    user: Option<SmolStr>,
    application: Option<SmolStr>,
    error_kind: Option<ErrorKind>,

    // extra
    // This sender is here to keep the request monitoring channel open while requests are taking place.
    sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
    pub latency_timer: LatencyTimer,
}

impl RequestMonitoring {
    pub fn new(
        session_id: Uuid,
        peer_addr: IpAddr,
        protocol: &'static str,
        region: &'static str,
    ) -> Self {
        Self {
            peer_addr,
            session_id,
            protocol,
            first_packet: Utc::now(),
            region,

            project: None,
            branch: None,
            endpoint_id: None,
            user: None,
            application: None,
            error_kind: None,

            sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
            latency_timer: LatencyTimer::new(protocol),
        }
    }

    #[cfg(test)]
    pub fn test() -> Self {
        RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
    }

    pub fn console_application_name(&self) -> String {
        format!(
            "{}/{}",
            self.application.as_deref().unwrap_or_default(),
            self.protocol
        )
    }

    pub fn set_project(&mut self, x: MetricsAuxInfo) {
        self.branch = Some(x.branch_id);
        self.endpoint_id = Some(x.endpoint_id);
        self.project = Some(x.project_id);
    }

    pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
        self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
    }

    pub fn set_application(&mut self, app: Option<SmolStr>) {
        self.application = app.or_else(|| self.application.clone());
    }

    pub fn set_user(&mut self, user: SmolStr) {
        self.user = Some(user);
    }

    pub fn log(&mut self) {
        if let Some(tx) = self.sender.take() {
            let _: Result<(), _> = tx.send(self.clone());
        }
    }
}

impl Drop for RequestMonitoring {
    fn drop(&mut self) {
        self.log()
    }
}
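
The `LOG_CHAN` mechanism above, a global weak sender that each context upgrades at construction and flushes on drop, can be reduced to the following self-contained sketch (the payload type and names are placeholders for illustration):

use once_cell::sync::OnceCell;
use tokio::sync::mpsc;

static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<String>> = OnceCell::new();

struct Ctx {
    name: String,
    // Holding a strong sender keeps the channel open while this request lives.
    sender: Option<mpsc::UnboundedSender<String>>,
}

impl Drop for Ctx {
    fn drop(&mut self) {
        // Send the final snapshot exactly once; ignore the error if the
        // worker has already shut down.
        if let Some(tx) = self.sender.take() {
            let _ = tx.send(self.name.clone());
        }
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    // Store only a weak sender globally: the channel closes once the
    // worker's `tx` and all per-request clones are gone.
    LOG_CHAN.set(tx.downgrade()).unwrap();

    let ctx = Ctx {
        name: "request-1".into(),
        sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
    };
    drop(ctx); // emits the record
    drop(tx); // lets the channel close once in-flight contexts are gone

    while let Some(row) = rx.recv().await {
        println!("logged: {row}");
    }
}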

proxy/src/context/parquet.rs (new file, 641 lines)
@@ -0,0 +1,641 @@
use std::sync::Arc;

use anyhow::Context;
use bytes::BytesMut;
use futures::{Stream, StreamExt};
use parquet::{
    basic::Compression,
    file::{
        metadata::RowGroupMetaDataPtr,
        properties::{WriterProperties, WriterPropertiesPtr, DEFAULT_PAGE_SIZE},
        writer::SerializedFileWriter,
    },
    record::RecordWriter,
};
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig};
use tokio::{sync::mpsc, time};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, Span};
use utils::backoff;

use super::{RequestMonitoring, LOG_CHAN};

#[derive(clap::Args, Clone, Debug)]
pub struct ParquetUploadArgs {
    /// Storage location to upload the parquet files to.
    /// Encoded as toml (same format as pageservers), eg
    /// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
    #[clap(long, default_value = "{}", value_parser = remote_storage_from_toml)]
    parquet_upload_remote_storage: OptRemoteStorageConfig,

    /// How many rows to include in a row group
    #[clap(long, default_value_t = 8192)]
    parquet_upload_row_group_size: usize,

    /// How large each column page should be in bytes
    #[clap(long, default_value_t = DEFAULT_PAGE_SIZE)]
    parquet_upload_page_size: usize,

    /// How large the total parquet file should be in bytes
    #[clap(long, default_value_t = 100_000_000)]
    parquet_upload_size: i64,

    /// How long to wait before forcing a file upload
    #[clap(long, default_value = "20m", value_parser = humantime::parse_duration)]
    parquet_upload_maximum_duration: tokio::time::Duration,

    /// What level of compression to use
    #[clap(long, default_value_t = Compression::UNCOMPRESSED)]
    parquet_upload_compression: Compression,
}

/// Hack to avoid clap being smarter. If you don't use this type alias, clap assumes more about the optional state and you get
/// runtime type errors from the value parser we use.
type OptRemoteStorageConfig = Option<RemoteStorageConfig>;

fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfig> {
    RemoteStorageConfig::from_toml(&s.parse()?)
}

// Occasional network issues and such can cause remote operations to fail, and
// that's expected. If an upload fails, we log it at info-level, and retry.
// But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN
// level instead, as repeated failures can mean a more serious problem. If it
// fails more than FAILED_UPLOAD_MAX_RETRIES times, we give up.
pub(crate) const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3;
pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
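
A hedged sketch of what such an escalating retry policy can look like in plain tokio (the real code delegates to `utils::backoff`, whose exact signature is not shown here; the backoff schedule below is an assumption):

use std::time::Duration;

const WARN_THRESHOLD: u32 = 3;
const MAX_RETRIES: u32 = 10;

async fn retry_upload<F, Fut, T, E>(mut op: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: std::fmt::Display,
{
    let mut attempt = 0;
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            // give up entirely after MAX_RETRIES attempts
            Err(e) if attempt + 1 >= MAX_RETRIES => return Err(e),
            Err(e) => {
                // escalate the log level once failures start repeating
                if attempt >= WARN_THRESHOLD {
                    eprintln!("WARN: upload failed (attempt {attempt}): {e}");
                } else {
                    eprintln!("INFO: upload failed (attempt {attempt}): {e}");
                }
                // simple capped exponential backoff between attempts
                let delay = Duration::from_millis(100 * 2u64.pow(attempt.min(6)));
                tokio::time::sleep(delay).await;
                attempt += 1;
            }
        }
    }
}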

// the parquet crate leaves a lot to be desired...
// what follows is an attempt to write parquet files with minimal allocs.
// complication: parquet is a columnar format, while we want to write it as rows.
// design:
// * we batch up to 1024 rows, then flush them into a 'row group'
// * after each rowgroup write, we check the length of the file and upload to s3 if large enough

#[derive(parquet_derive::ParquetRecordWriter)]
struct RequestData {
    region: &'static str,
    protocol: &'static str,
    /// Must be UTC. The derive macro doesn't like the timezones
    timestamp: chrono::NaiveDateTime,
    session_id: uuid::Uuid,
    peer_addr: String,
    username: Option<String>,
    application_name: Option<String>,
    endpoint_id: Option<String>,
    project: Option<String>,
    branch: Option<String>,
    error: Option<&'static str>,
}

impl From<RequestMonitoring> for RequestData {
    fn from(value: RequestMonitoring) -> Self {
        Self {
            session_id: value.session_id,
            peer_addr: value.peer_addr.to_string(),
            timestamp: value.first_packet.naive_utc(),
            username: value.user.as_deref().map(String::from),
            application_name: value.application.as_deref().map(String::from),
            endpoint_id: value.endpoint_id.as_deref().map(String::from),
            project: value.project.as_deref().map(String::from),
            branch: value.branch.as_deref().map(String::from),
            protocol: value.protocol,
            region: value.region,
            error: value.error_kind.as_ref().map(|e| e.to_str()),
        }
    }
}

/// Parquet request context worker
///
/// It listens on a channel for all completed requests, extracts the data and writes it into a parquet file,
/// then uploads a completed batch to S3
pub async fn worker(
    cancellation_token: CancellationToken,
    config: ParquetUploadArgs,
) -> anyhow::Result<()> {
    let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
        tracing::warn!("parquet request upload: no s3 bucket configured");
        return Ok(());
    };

    let (tx, mut rx) = mpsc::unbounded_channel();
    LOG_CHAN.set(tx.downgrade()).unwrap();

    // setup row stream that will close on cancellation
    tokio::spawn(async move {
        cancellation_token.cancelled().await;
        // dropping this sender will cause the channel to close only once
        // all the remaining inflight requests have been completed.
        drop(tx);
    });
    let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
    let rx = rx.map(RequestData::from);

    let storage =
        GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?;

    let properties = WriterProperties::builder()
        .set_data_page_size_limit(config.parquet_upload_page_size)
        .set_compression(config.parquet_upload_compression);

    let parquet_config = ParquetConfig {
        properties: Arc::new(properties.build()),
        rows_per_group: config.parquet_upload_row_group_size,
        file_size: config.parquet_upload_size,
        max_duration: config.parquet_upload_maximum_duration,

        #[cfg(any(test, feature = "testing"))]
        test_remote_failures: 0,
    };

    worker_inner(storage, rx, parquet_config).await
}

struct ParquetConfig {
    properties: WriterPropertiesPtr,
    rows_per_group: usize,
    file_size: i64,

    max_duration: tokio::time::Duration,

    #[cfg(any(test, feature = "testing"))]
    test_remote_failures: u64,
}

async fn worker_inner(
    storage: GenericRemoteStorage,
    rx: impl Stream<Item = RequestData>,
    config: ParquetConfig,
) -> anyhow::Result<()> {
    #[cfg(any(test, feature = "testing"))]
    let storage = if config.test_remote_failures > 0 {
        GenericRemoteStorage::unreliable_wrapper(storage, config.test_remote_failures)
    } else {
        storage
    };

    let mut rx = std::pin::pin!(rx);

    let mut rows = Vec::with_capacity(config.rows_per_group);

    let schema = rows.as_slice().schema()?;
    let file = BytesWriter::default();
    let mut w = SerializedFileWriter::new(file, schema.clone(), config.properties.clone())?;

    let mut last_upload = time::Instant::now();

    let mut len = 0;
    while let Some(row) = rx.next().await {
        rows.push(row);
        let force = last_upload.elapsed() > config.max_duration;
        if rows.len() == config.rows_per_group || force {
            let rg_meta;
            (rows, w, rg_meta) = flush_rows(rows, w).await?;
            len += rg_meta.compressed_size();
        }
        if len > config.file_size || force {
            last_upload = time::Instant::now();
            let file = upload_parquet(w, len, &storage).await?;
            w = SerializedFileWriter::new(file, schema.clone(), config.properties.clone())?;
            len = 0;
        }
    }

    if !rows.is_empty() {
        let rg_meta;
        (_, w, rg_meta) = flush_rows(rows, w).await?;
        len += rg_meta.compressed_size();
    }

    if !w.flushed_row_groups().is_empty() {
        let _: BytesWriter = upload_parquet(w, len, &storage).await?;
    }

    Ok(())
}

async fn flush_rows(
    rows: Vec<RequestData>,
    mut w: SerializedFileWriter<BytesWriter>,
) -> anyhow::Result<(
    Vec<RequestData>,
    SerializedFileWriter<BytesWriter>,
    RowGroupMetaDataPtr,
)> {
    let span = Span::current();
    let (mut rows, w, rg_meta) = tokio::task::spawn_blocking(move || {
        let _enter = span.enter();

        let mut rg = w.next_row_group()?;
        rows.as_slice().write_to_row_group(&mut rg)?;
        let rg_meta = rg.close()?;

        let size = rg_meta.compressed_size();
        let compression = rg_meta.compressed_size() as f64 / rg_meta.total_byte_size() as f64;

        debug!(size, compression, "flushed row group to parquet file");

        Ok::<_, parquet::errors::ParquetError>((rows, w, rg_meta))
    })
    .await
    .unwrap()?;

    rows.clear();
    Ok((rows, w, rg_meta))
}
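
Note the pattern in `flush_rows`: the CPU-bound parquet encoding is moved onto the blocking thread pool, and the current tracing span is captured and re-entered inside the closure so the `debug!` line keeps its context. Stripped of the parquet details, the shape is roughly this (the workload is a stand-in):

use tracing::Span;

async fn encode_off_runtime(rows: Vec<u64>) -> anyhow::Result<u64> {
    // capture the span on the async side...
    let span = Span::current();
    let sum = tokio::task::spawn_blocking(move || {
        // ...and re-enter it on the blocking thread so logs stay attributed
        let _enter = span.enter();
        // stand-in for the expensive, synchronous encoding work
        rows.iter().sum::<u64>()
    })
    .await?; // JoinError only if the closure panicked
    Ok(sum)
}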

async fn upload_parquet(
    w: SerializedFileWriter<BytesWriter>,
    len: i64,
    storage: &GenericRemoteStorage,
) -> anyhow::Result<BytesWriter> {
    let len_uncompressed = w
        .flushed_row_groups()
        .iter()
        .map(|rg| rg.total_byte_size())
        .sum::<i64>();

    // I don't know how compute intensive this is, although it probably isn't much... better be safe than sorry.
    // finish method only available on the fork: https://github.com/apache/arrow-rs/issues/5253
    let (mut file, metadata) = tokio::task::spawn_blocking(move || w.finish())
        .await
        .unwrap()?;

    let data = file.buf.split().freeze();

    let compression = len as f64 / len_uncompressed as f64;
    let size = data.len();
    let id = uuid::Uuid::now_v7();

    info!(
        %id,
        rows = metadata.num_rows,
        size, compression, "uploading request parquet file"
    );

    let path = RemotePath::from_string(&format!("requests_{id}.parquet"))?;
    backoff::retry(
        || async {
            let stream = futures::stream::once(futures::future::ready(Ok(data.clone())));
            storage.upload(stream, data.len(), &path, None).await
        },
        |_e| false,
        FAILED_UPLOAD_WARN_THRESHOLD,
        FAILED_UPLOAD_MAX_RETRIES,
        "request_data_upload",
        // we don't want cancellation to interrupt here, so we make a dummy cancel token
        backoff::Cancel::new(CancellationToken::new(), || anyhow::anyhow!("Cancelled")),
    )
    .await
    .context("request_data_upload")?;

    Ok(file)
}

// why doesn't BytesMut impl io::Write?
#[derive(Default)]
struct BytesWriter {
    buf: BytesMut,
}

impl std::io::Write for BytesWriter {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.buf.extend_from_slice(buf);
        Ok(buf.len())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
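
Usage is what you would expect from any `io::Write` sink; a tiny sketch, independent of parquet:

use std::io::Write;

fn main() -> std::io::Result<()> {
    let mut w = BytesWriter::default();
    w.write_all(b"hello ")?;
    w.write_all(b"parquet")?;
    // `split().freeze()` hands the accumulated bytes off as an immutable
    // `Bytes`, leaving the writer empty for reuse (as upload_parquet does).
    let data = w.buf.split().freeze();
    assert_eq!(&data[..], &b"hello parquet"[..]);
    Ok(())
}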

#[cfg(test)]
mod tests {
    use std::{net::Ipv4Addr, num::NonZeroUsize, sync::Arc};

    use camino::Utf8Path;
    use clap::Parser;
    use futures::{Stream, StreamExt};
    use itertools::Itertools;
    use parquet::{
        basic::{Compression, ZstdLevel},
        file::{
            properties::{WriterProperties, DEFAULT_PAGE_SIZE},
            reader::FileReader,
            serialized_reader::SerializedFileReader,
        },
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use remote_storage::{
        GenericRemoteStorage, RemoteStorageConfig, RemoteStorageKind, S3Config,
        DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT,
    };
    use tokio::{sync::mpsc, time};

    use super::{worker_inner, ParquetConfig, ParquetUploadArgs, RequestData};

    #[derive(Parser)]
    struct ProxyCliArgs {
        #[clap(flatten)]
        parquet_upload: ParquetUploadArgs,
    }

    #[test]
    fn default_parser() {
        let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from(["proxy"]);
        assert_eq!(parquet_upload.parquet_upload_remote_storage, None);
        assert_eq!(parquet_upload.parquet_upload_row_group_size, 8192);
        assert_eq!(parquet_upload.parquet_upload_page_size, DEFAULT_PAGE_SIZE);
        assert_eq!(parquet_upload.parquet_upload_size, 100_000_000);
        assert_eq!(
            parquet_upload.parquet_upload_maximum_duration,
            time::Duration::from_secs(20 * 60)
        );
        assert_eq!(
            parquet_upload.parquet_upload_compression,
            Compression::UNCOMPRESSED
        );
    }

    #[test]
    fn full_parser() {
        let ProxyCliArgs { parquet_upload } = ProxyCliArgs::parse_from([
            "proxy",
            "--parquet-upload-remote-storage",
            "{bucket_name='default',prefix_in_bucket='proxy/',bucket_region='us-east-1',endpoint='http://minio:9000'}",
            "--parquet-upload-row-group-size",
            "100",
            "--parquet-upload-page-size",
            "10000",
            "--parquet-upload-size",
            "10000000",
            "--parquet-upload-maximum-duration",
            "10m",
            "--parquet-upload-compression",
            "zstd(5)",
        ]);
        assert_eq!(
            parquet_upload.parquet_upload_remote_storage,
            Some(RemoteStorageConfig {
                storage: RemoteStorageKind::AwsS3(S3Config {
                    bucket_name: "default".into(),
                    bucket_region: "us-east-1".into(),
                    prefix_in_bucket: Some("proxy/".into()),
                    endpoint: Some("http://minio:9000".into()),
                    concurrency_limit: NonZeroUsize::new(
                        DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
                    )
                    .unwrap(),
                    max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE,
                })
            })
        );
        assert_eq!(parquet_upload.parquet_upload_row_group_size, 100);
        assert_eq!(parquet_upload.parquet_upload_page_size, 10000);
        assert_eq!(parquet_upload.parquet_upload_size, 10_000_000);
        assert_eq!(
            parquet_upload.parquet_upload_maximum_duration,
            time::Duration::from_secs(10 * 60)
        );
        assert_eq!(
            parquet_upload.parquet_upload_compression,
            Compression::ZSTD(ZstdLevel::try_new(5).unwrap())
        );
    }

    fn generate_request_data(rng: &mut impl Rng) -> RequestData {
        RequestData {
            session_id: uuid::Builder::from_random_bytes(rng.gen()).into_uuid(),
            peer_addr: Ipv4Addr::from(rng.gen::<[u8; 4]>()).to_string(),
            timestamp: chrono::NaiveDateTime::from_timestamp_millis(
                rng.gen_range(1703862754..1803862754),
            )
            .unwrap(),
            application_name: Some("test".to_owned()),
            username: Some(hex::encode(rng.gen::<[u8; 4]>())),
            endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
            project: Some(hex::encode(rng.gen::<[u8; 16]>())),
            branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
            protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
            region: "us-east-1",
            error: None,
        }
    }

    fn random_stream(len: usize) -> impl Stream<Item = RequestData> + Unpin {
        let mut rng = StdRng::from_seed([0x39; 32]);
        futures::stream::iter(
            std::iter::repeat_with(move || generate_request_data(&mut rng)).take(len),
        )
    }

    async fn run_test(
        tmpdir: &Utf8Path,
        config: ParquetConfig,
        rx: impl Stream<Item = RequestData>,
    ) -> Vec<(u64, usize, i64)> {
        let remote_storage_config = RemoteStorageConfig {
            storage: RemoteStorageKind::LocalFs(tmpdir.to_path_buf()),
        };
        let storage = GenericRemoteStorage::from_config(&remote_storage_config).unwrap();

        worker_inner(storage, rx, config).await.unwrap();

        let mut files = std::fs::read_dir(tmpdir.as_std_path())
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .collect_vec();
        files.sort();

        files
            .into_iter()
            .map(|path| std::fs::File::open(tmpdir.as_std_path().join(path)).unwrap())
            .map(|file| {
                (
                    file.metadata().unwrap(),
                    SerializedFileReader::new(file).unwrap().metadata().clone(),
                )
            })
            .map(|(file_meta, parquet_meta)| {
                (
                    file_meta.len(),
                    parquet_meta.num_row_groups(),
                    parquet_meta.file_metadata().num_rows(),
                )
            })
            .collect()
    }

    #[tokio::test]
    async fn verify_parquet_no_compression() {
        let tmpdir = camino_tempfile::tempdir().unwrap();

        let config = ParquetConfig {
            properties: Arc::new(WriterProperties::new()),
            rows_per_group: 2_000,
            file_size: 1_000_000,
            max_duration: time::Duration::from_secs(20 * 60),
            test_remote_failures: 0,
        };

        let rx = random_stream(50_000);
        let file_stats = run_test(tmpdir.path(), config, rx).await;

        assert_eq!(
            file_stats,
            [
                (1029153, 3, 6000),
                (1029075, 3, 6000),
                (1029216, 3, 6000),
                (1029129, 3, 6000),
                (1029250, 3, 6000),
                (1029017, 3, 6000),
                (1029175, 3, 6000),
                (1029247, 3, 6000),
                (343124, 1, 2000)
            ],
        );

        tmpdir.close().unwrap();
    }

    #[tokio::test]
    async fn verify_parquet_min_compression() {
        let tmpdir = camino_tempfile::tempdir().unwrap();

        let config = ParquetConfig {
            properties: Arc::new(
                WriterProperties::builder()
                    .set_compression(parquet::basic::Compression::ZSTD(ZstdLevel::default()))
                    .build(),
            ),
            rows_per_group: 2_000,
            file_size: 1_000_000,
            max_duration: time::Duration::from_secs(20 * 60),
            test_remote_failures: 0,
        };

        let rx = random_stream(50_000);
        let file_stats = run_test(tmpdir.path(), config, rx).await;

        // with compression, there are fewer files with more rows per file
        assert_eq!(
            file_stats,
            [
                (1166201, 6, 12000),
                (1163577, 6, 12000),
                (1164641, 6, 12000),
                (1168772, 6, 12000),
                (196761, 1, 2000)
            ],
        );

        tmpdir.close().unwrap();
    }

    #[tokio::test]
    async fn verify_parquet_strong_compression() {
        let tmpdir = camino_tempfile::tempdir().unwrap();

        let config = ParquetConfig {
            properties: Arc::new(
                WriterProperties::builder()
                    .set_compression(parquet::basic::Compression::ZSTD(
                        ZstdLevel::try_new(10).unwrap(),
                    ))
                    .build(),
            ),
            rows_per_group: 2_000,
            file_size: 1_000_000,
            max_duration: time::Duration::from_secs(20 * 60),
            test_remote_failures: 0,
        };

        let rx = random_stream(50_000);
        let file_stats = run_test(tmpdir.path(), config, rx).await;

        // with strong compression, the files are smaller
        assert_eq!(
            file_stats,
            [
                (1144934, 6, 12000),
                (1144941, 6, 12000),
                (1144735, 6, 12000),
                (1144936, 6, 12000),
                (191035, 1, 2000)
            ],
        );

        tmpdir.close().unwrap();
    }

    #[tokio::test]
    async fn verify_parquet_unreliable_upload() {
        let tmpdir = camino_tempfile::tempdir().unwrap();

        let config = ParquetConfig {
            properties: Arc::new(WriterProperties::new()),
            rows_per_group: 2_000,
            file_size: 1_000_000,
            max_duration: time::Duration::from_secs(20 * 60),
            test_remote_failures: 2,
        };

        let rx = random_stream(50_000);
        let file_stats = run_test(tmpdir.path(), config, rx).await;

        assert_eq!(
            file_stats,
            [
                (1029153, 3, 6000),
                (1029075, 3, 6000),
                (1029216, 3, 6000),
                (1029129, 3, 6000),
                (1029250, 3, 6000),
                (1029017, 3, 6000),
                (1029175, 3, 6000),
                (1029247, 3, 6000),
                (343124, 1, 2000)
            ],
        );

        tmpdir.close().unwrap();
    }

    #[tokio::test(start_paused = true)]
    async fn verify_parquet_regular_upload() {
        let tmpdir = camino_tempfile::tempdir().unwrap();

        let config = ParquetConfig {
            properties: Arc::new(WriterProperties::new()),
            rows_per_group: 2_000,
            file_size: 1_000_000,
            max_duration: time::Duration::from_secs(60),
            test_remote_failures: 2,
        };

        let (tx, mut rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            for _ in 0..3 {
                let mut s = random_stream(3000);
                while let Some(r) = s.next().await {
                    tx.send(r).unwrap();
                }
                time::sleep(time::Duration::from_secs(70)).await
            }
        });

        let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
        let file_stats = run_test(tmpdir.path(), config, rx).await;

        // files are smaller than the size threshold, but they took too long to fill so were flushed early
        assert_eq!(
            file_stats,
            [(515807, 2, 3001), (515585, 2, 3000), (515425, 2, 2999)],
        );

        tmpdir.close().unwrap();
    }
}
@@ -28,3 +28,37 @@ pub trait UserFacingError: fmt::Display {
        self.to_string()
    }
}

#[derive(Clone)]
pub enum ErrorKind {
    /// Wrong password, unknown endpoint, protocol violation, etc...
    User,

    /// Network error between user and proxy. Not necessarily user error
    Disconnect,

    /// Proxy self-imposed rate limits
    RateLimit,

    /// internal errors
    Service,

    /// Error communicating with control plane
    ControlPlane,

    /// Error communicating with compute
    Compute,
}

impl ErrorKind {
    pub fn to_str(&self) -> &'static str {
        match self {
            ErrorKind::User => "request failed due to user error",
            ErrorKind::Disconnect => "client disconnected",
            ErrorKind::RateLimit => "request cancelled due to rate limit",
            ErrorKind::Service => "internal service error",
            ErrorKind::ControlPlane => "non-retryable control plane error",
            ErrorKind::Compute => "non-retryable compute error (or exhausted retry capacity)",
        }
    }
}

@@ -13,6 +13,7 @@ pub mod cancellation;
pub mod compute;
pub mod config;
pub mod console;
pub mod context;
pub mod error;
pub mod http;
pub mod logging;
@@ -21,6 +22,7 @@ pub mod parse;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
pub mod redis;
pub mod sasl;
pub mod scram;
pub mod serverless;

@@ -115,11 +115,12 @@ pub static ALLOWED_IPS_NUMBER: Lazy<Histogram> = Lazy::new(|| {
    .unwrap()
});

#[derive(Clone)]
pub struct LatencyTimer {
    // time since the stopwatch was started
    start: Option<time::Instant>,
    // accumulated time on the stopwatch
    accumulated: std::time::Duration,
    pub accumulated: std::time::Duration,
    // label data
    protocol: &'static str,
    cache_miss: bool,
@@ -160,7 +161,12 @@ impl LatencyTimer {
        self.pool_miss = false;
    }

    pub fn success(mut self) {
    pub fn success(&mut self) {
        // stop the stopwatch and record the time that we have accumulated
        let start = self.start.take().expect("latency timer should be started");
        self.accumulated += start.elapsed();

        // success
        self.outcome = "success";
    }
}
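
The change from `success(mut self)` to `success(&mut self)` reflects that the timer now lives inside `RequestMonitoring` and is stopped in place rather than consumed. The underlying stopwatch idea (an `Option<Instant>` while running, a `Duration` accumulator) can be sketched on its own; the type and method names here are illustrative, not the real API:

use std::time::{Duration, Instant};

/// Minimal pausable stopwatch in the spirit of LatencyTimer:
/// `start` is Some while the clock is running; `accumulated`
/// collects all elapsed running time.
struct Stopwatch {
    start: Option<Instant>,
    accumulated: Duration,
}

impl Stopwatch {
    fn new() -> Self {
        Stopwatch {
            start: Some(Instant::now()),
            accumulated: Duration::ZERO,
        }
    }

    /// Stop the clock without consuming the timer, mirroring the
    /// `success(&mut self)` signature above.
    fn stop(&mut self) {
        if let Some(start) = self.start.take() {
            self.accumulated += start.elapsed();
        }
    }

    fn resume(&mut self) {
        // no-op if already running
        self.start.get_or_insert_with(Instant::now);
    }
}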

@@ -9,9 +9,10 @@ use crate::{
    cancellation::{self, CancelMap},
    compute,
    config::{AuthenticationConfig, ProxyConfig, TlsConfig},
    console::{self, messages::MetricsAuxInfo},
    console::messages::MetricsAuxInfo,
    context::RequestMonitoring,
    metrics::{
        LatencyTimer, NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
        NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
        NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
    },
    protocol2::WithClientIp,
@@ -25,7 +26,8 @@ use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use regex::Regex;
use std::{net::IpAddr, sync::Arc};
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
@@ -82,14 +84,16 @@ pub async fn task_main(
        info!("accepted postgres client connection");

        let mut socket = WithClientIp::new(socket);
        let mut peer_addr = peer_addr;
        if let Some(ip) = socket.wait_for_addr().await? {
            peer_addr = ip;
            tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
        let mut peer_addr = peer_addr.ip();
        if let Some(addr) = socket.wait_for_addr().await? {
            peer_addr = addr.ip();
            tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
        } else if config.require_client_ip {
            bail!("missing required client IP");
        }

        let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);

        socket
            .inner
            .set_nodelay(true)
@@ -97,11 +101,10 @@ pub async fn task_main(

        handle_client(
            config,
            &mut ctx,
            &cancel_map,
            session_id,
            socket,
            ClientMode::Tcp,
            peer_addr.ip(),
            endpoint_rate_limiter,
        )
        .await
@@ -134,13 +137,6 @@ pub enum ClientMode {

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    fn protocol_label(&self) -> &'static str {
        match self {
            ClientMode::Tcp => "tcp",
            ClientMode::Websockets { .. } => "ws",
        }
    }

    fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
@@ -173,19 +169,18 @@ impl ClientMode {

pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    stream: S,
    mode: ClientMode,
    peer_addr: IpAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    info!(
        protocol = mode.protocol_label(),
        protocol = ctx.protocol,
        "handling interactive connection from client"
    );

    let proto = mode.protocol_label();
    let proto = ctx.protocol;
    let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
        .with_label_values(&[proto])
        .guard();
@@ -195,38 +190,46 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(

    let tls = config.tls_config.as_ref();

    let pause = ctx.latency_timer.pause();
    let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
    let (mut stream, params) = match do_handshake.await? {
        Some(x) => x,
        None => return Ok(()), // it's a cancellation request
    };
    drop(pause);

    // Extract credentials which we're going to use for auth.
    let creds = {
    let user_info = {
        let hostname = mode.hostname(stream.get_ref());
        let common_names = tls.and_then(|tls| tls.common_names.clone());

        let common_names = tls.map(|tls| &tls.common_names);
        let result = config
            .auth_backend
            .as_ref()
            .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names, peer_addr))
            .map(|_| {
                auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
            })
            .transpose();

        match result {
            Ok(creds) => creds,
            Ok(user_info) => user_info,
            Err(e) => stream.throw_error(e).await?,
        }
    };

    ctx.set_endpoint_id(user_info.get_endpoint());

    let client = Client::new(
        stream,
        creds,
        user_info,
        &params,
        session_id,
        mode.allow_self_signed_compute(config),
        endpoint_rate_limiter,
    );
    cancel_map
        .with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
        .with_session(|session| {
            client.connect_to_db(ctx, session, mode, &config.authentication_config)
        })
        .await
}

@@ -348,10 +351,13 @@ async fn prepare_client_connection(
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
    ctx: &mut RequestMonitoring,
    client: impl AsyncRead + AsyncWrite + Unpin,
    compute: impl AsyncRead + AsyncWrite + Unpin,
    aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
    ctx.log();

    let usage = USAGE_METRICS.register(Ids {
        endpoint_id: aux.endpoint_id.clone(),
        branch_id: aux.branch_id.clone(),
@@ -394,11 +400,9 @@ struct Client<'a, S> {
    /// The underlying libpq protocol stream.
    stream: PqStream<Stream<S>>,
    /// Client credentials that we care about.
    creds: auth::BackendType<'a, auth::ClientCredentials>,
    user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
    /// KV-dictionary with PostgreSQL connection params.
    params: &'a StartupMessageParams,
    /// Unique connection ID.
    session_id: uuid::Uuid,
    /// Allow self-signed certificates (for testing).
    allow_self_signed_compute: bool,
    /// Rate limiter for endpoints
@@ -409,17 +413,15 @@ impl<'a, S> Client<'a, S> {
    /// Construct a new connection context.
    fn new(
        stream: PqStream<Stream<S>>,
        creds: auth::BackendType<'a, auth::ClientCredentials>,
        user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
        params: &'a StartupMessageParams,
        session_id: uuid::Uuid,
        allow_self_signed_compute: bool,
        endpoint_rate_limiter: Arc<EndpointRateLimiter>,
    ) -> Self {
        Self {
            stream,
            creds,
            user_info,
            params,
            session_id,
            allow_self_signed_compute,
            endpoint_rate_limiter,
        }
@@ -430,24 +432,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
    /// Let the client authenticate and connect to the designated compute node.
    // Instrumentation logs endpoint name everywhere. Doesn't work for link
    // auth; strictly speaking we don't know endpoint name in its case.
    #[tracing::instrument(name = "", fields(ep = %self.creds.get_endpoint().unwrap_or_default()), skip_all)]
    #[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
    async fn connect_to_db(
        self,
        ctx: &mut RequestMonitoring,
        session: cancellation::Session<'_>,
        mode: ClientMode,
        config: &'static AuthenticationConfig,
    ) -> anyhow::Result<()> {
        let Self {
            mut stream,
            creds,
            user_info,
            params,
            session_id,
            allow_self_signed_compute,
            endpoint_rate_limiter,
        } = self;

        // check rate limit
        if let Some(ep) = creds.get_endpoint() {
        if let Some(ep) = user_info.get_endpoint() {
            if !endpoint_rate_limiter.check(ep) {
                return stream
                    .throw_error(auth::AuthError::too_many_connections())
@@ -455,27 +457,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
            }
        }

        let proto = mode.protocol_label();
        let extra = console::ConsoleReqExtra {
            session_id, // aka this connection's id
            application_name: format!(
                "{}/{}",
                params.get("application_name").unwrap_or_default(),
                proto
            ),
            options: neon_options(params),
        };
        let mut latency_timer = LatencyTimer::new(proto);

        let user = creds.get_user().to_owned();
        let auth_result = match creds
            .authenticate(
                &extra,
                &mut stream,
                mode.allow_cleartext(),
                config,
                &mut latency_timer,
            )
        let user = user_info.get_user().to_owned();
        let auth_result = match user_info
            .authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
            .await
        {
            Ok(auth_result) => auth_result,
@@ -488,20 +472,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
            }
        };

        let (mut node_info, creds) = auth_result;
        let (mut node_info, user_info) = auth_result;

        node_info.allow_self_signed_compute = allow_self_signed_compute;

        let aux = node_info.aux.clone();
        let mut node = connect_to_compute(
            &TcpMechanism { params, proto },
            node_info,
            &extra,
            &creds,
            latency_timer,
        )
        .or_else(|e| stream.throw_error(e))
        .await?;
        let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
            .or_else(|e| stream.throw_error(e))
            .await?;

        prepare_client_connection(&node, session, &mut stream).await?;
        // Before proxy passing, forward to compute whatever data is left in the
@@ -510,33 +488,56 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
        // immediately after opening the connection.
        let (stream, read_buf) = stream.into_inner();
        node.stream.write_all(&read_buf).await?;
        proxy_pass(stream, node.stream, aux).await
        proxy_pass(ctx, stream, node.stream, aux).await
    }
}

pub fn neon_options(params: &StartupMessageParams) -> Vec<(String, String)> {
    #[allow(unstable_name_collisions)]
    match params.options_raw() {
        Some(options) => options.filter_map(neon_option).collect(),
        None => vec![],
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);

impl NeonOptions {
    pub fn parse_params(params: &StartupMessageParams) -> Self {
        params
            .options_raw()
            .map(Self::parse_from_iter)
            .unwrap_or_default()
    }
    pub fn parse_options_raw(options: &str) -> Self {
        Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
    }

    fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
        let mut options = options
            .filter_map(neon_option)
            .map(|(k, v)| (k.into(), v.into()))
            .collect_vec();
        options.sort();
        Self(options)
    }

    pub fn get_cache_key(&self, prefix: &str) -> SmolStr {
        // prefix + format!(" {k}:{v}")
        // kinda jank because SmolStr is immutable
        std::iter::once(prefix)
            .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
            .collect()
    }

    /// <https://swagger.io/docs/specification/serialization/> DeepObject format
    /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
    pub fn to_deep_object(&self) -> Vec<(String, SmolStr)> {
        self.0
            .iter()
            .map(|(k, v)| (format!("options[{}]", k), v.clone()))
            .collect()
    }
}
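
For a concrete feel of `get_cache_key`: the `once(prefix).chain(...)` iterator collects straight into a `SmolStr` without building an intermediate `String`. A runnable sketch with hypothetical, already-sorted options:

use smol_str::SmolStr;

fn main() {
    // Hypothetical (key, value) pairs, sorted as parse_from_iter would leave them.
    let options: Vec<(SmolStr, SmolStr)> = vec![
        ("endpoint_type".into(), "read_only".into()),
        ("lsn".into(), "0/15E2850".into()),
    ];

    // Same shape as get_cache_key: prefix, then " k:v" per option.
    let key: SmolStr = std::iter::once("ep-example-123")
        .chain(options.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
        .collect();

    assert_eq!(key, "ep-example-123 endpoint_type:read_only lsn:0/15E2850");
}

Sorting the options first (as `parse_from_iter` does) is what makes this string usable as a stable cache key regardless of the order the client supplied them in.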

pub fn neon_options_str(params: &StartupMessageParams) -> String {
    #[allow(unstable_name_collisions)]
    neon_options(params)
        .iter()
        .map(|(k, v)| format!("{}:{}", k, v))
        .sorted() // we sort it to use as cache key
        .intersperse(" ".to_owned())
        .collect()
}

pub fn neon_option(bytes: &str) -> Option<(String, String)> {
pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
    static RE: OnceCell<Regex> = OnceCell::new();
    let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());

    let cap = re.captures(bytes)?;
    let (_, [k, v]) = cap.extract();
    Some((k.to_owned(), v.to_owned()))
    Some((k, v))
}
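
The regex keeps only `neon_`-prefixed options, strips the prefix, and now returns borrowed slices instead of owned `String`s. A quick, self-contained check of the behaviour (the option names are hypothetical):

use once_cell::sync::OnceCell;
use regex::Regex;

fn neon_option(bytes: &str) -> Option<(&str, &str)> {
    static RE: OnceCell<Regex> = OnceCell::new();
    let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
    let cap = re.captures(bytes)?;
    let (_, [k, v]) = cap.extract();
    Some((k, v))
}

fn main() {
    // "neon_" prefix is stripped; key and value are borrowed from the input.
    assert_eq!(neon_option("neon_lsn:0/15E2850"), Some(("lsn", "0/15E2850")));
    // options without the prefix are filtered out entirely
    assert_eq!(neon_option("statement_timeout:10s"), None);
}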
|
||||
|
||||
@@ -2,7 +2,8 @@ use crate::{
    auth,
    compute::{self, PostgresConnection},
    console::{self, errors::WakeComputeError, Api},
    metrics::{bool_to_str, LatencyTimer, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
    context::RequestMonitoring,
    metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
    proxy::retry::{retry_after, ShouldRetry},
};
use async_trait::async_trait;
@@ -35,15 +36,15 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute_once(
    ctx: &mut RequestMonitoring,
    node_info: &console::CachedNodeInfo,
    timeout: time::Duration,
    proto: &'static str,
) -> Result<PostgresConnection, compute::ConnectionError> {
    let allow_self_signed_compute = node_info.allow_self_signed_compute;

    node_info
        .config
        .connect(allow_self_signed_compute, timeout, proto)
        .connect(ctx, allow_self_signed_compute, timeout)
        .await
}

@@ -54,6 +55,7 @@ pub trait ConnectMechanism {
    type Error: From<Self::ConnectError>;
    async fn connect_once(
        &self,
        ctx: &mut RequestMonitoring,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError>;
@@ -64,7 +66,6 @@ pub trait ConnectMechanism {
pub struct TcpMechanism<'a> {
    /// KV-dictionary with PostgreSQL connection params.
    pub params: &'a StartupMessageParams,
    pub proto: &'static str,
}

#[async_trait]
@@ -75,10 +76,11 @@ impl ConnectMechanism for TcpMechanism<'_> {

    async fn connect_once(
        &self,
        ctx: &mut RequestMonitoring,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<PostgresConnection, Self::Error> {
        connect_to_compute_once(node_info, timeout, self.proto).await
        connect_to_compute_once(ctx, node_info, timeout).await
    }

    fn update_connect_config(&self, config: &mut compute::ConnCfg) {
@@ -123,11 +125,10 @@ fn report_error(e: &WakeComputeError, retry: bool) {
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism>(
    ctx: &mut RequestMonitoring,
    mechanism: &M,
    mut node_info: console::CachedNodeInfo,
    extra: &console::ConsoleReqExtra,
    creds: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
    mut latency_timer: LatencyTimer,
    user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
) -> Result<M::Connection, M::Error>
where
    M::ConnectError: ShouldRetry + std::fmt::Debug,
@@ -136,9 +137,12 @@ where
    mechanism.update_connect_config(&mut node_info.config);

    // try once
    let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
    let (config, err) = match mechanism
        .connect_once(ctx, &node_info, CONNECT_TIMEOUT)
        .await
    {
        Ok(res) => {
            latency_timer.success();
            ctx.latency_timer.success();
            return Ok(res);
        }
        Err(e) => {
@@ -147,17 +151,17 @@ where
        }
    };

    latency_timer.cache_miss();
    ctx.latency_timer.cache_miss();

    let mut num_retries = 1;

    // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
    info!("compute node's state has likely changed; requesting a wake-up");
    let node_info = loop {
        let wake_res = match creds {
            auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
        let wake_res = match user_info {
            auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
            #[cfg(feature = "testing")]
            auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
            auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await,
            // nothing to do?
            auth::BackendType::Link(_) => return Err(err.into()),
            // test backend
@@ -195,9 +199,12 @@ where
    // * DNS connection settings haven't quite propagated yet
    info!("wake_compute success. attempting to connect");
    loop {
        match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
        match mechanism
            .connect_once(ctx, &node_info, CONNECT_TIMEOUT)
            .await
        {
            Ok(res) => {
                latency_timer.success();
                ctx.latency_timer.success();
                return Ok(res);
            }
            Err(e) => {
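
For orientation, the control flow these hunks thread `ctx` through is roughly the following. This is a synchronous sketch with illustrative names (`connect_with_wake`, `wake_compute`, `MAX_RETRIES`), not the real async API, and the retryable-vs-fatal error classification of the real code is omitted:

const MAX_RETRIES: u32 = 16;

fn connect_with_wake<T, E>(
    mut connect_once: impl FnMut() -> Result<T, E>,
    wake_compute: impl FnOnce() -> Result<(), E>,
) -> Result<T, E> {
    // The first attempt runs against the cached node info.
    let first_err = match connect_once() {
        Ok(conn) => return Ok(conn),
        Err(e) => e,
    };
    // A failure usually means the compute was suspended: wake it, then retry.
    wake_compute()?;
    let mut last_err = first_err;
    for _ in 0..MAX_RETRIES {
        match connect_once() {
            Ok(conn) => return Ok(conn),
            Err(e) => last_err = e,
        }
    }
    Err(last_err)
}
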
@@ -7,11 +7,12 @@ use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::config::CertResolver;
use crate::console::{CachedNodeInfo, NodeInfo};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use rstest::rstest;
use smol_str::SmolStr;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::{MakeTlsConnect, NoTls};
use tokio_postgres_rustls::{MakeRustlsConnect, RustlsStream};
@@ -82,7 +83,7 @@ fn generate_tls_config<'a>(
    let mut cert_resolver = CertResolver::new();
    cert_resolver.add_cert(key, vec![cert], true)?;

    let common_names = Some(cert_resolver.get_common_names());
    let common_names = cert_resolver.get_common_names();

    TlsConfig {
        config,
@@ -425,6 +426,7 @@ impl ConnectMechanism for TestConnectMechanism {

    async fn connect_once(
        &self,
        _ctx: &mut RequestMonitoring,
        _node_info: &console::CachedNodeInfo,
        _timeout: std::time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
@@ -469,7 +471,7 @@ impl TestBackend for TestConnectMechanism {
        }
    }

    fn get_allowed_ips(&self) -> Result<Arc<Vec<String>>, console::errors::GetAuthInfoError> {
    fn get_allowed_ips(&self) -> Result<Vec<SmolStr>, console::errors::GetAuthInfoError> {
        unimplemented!("not used in tests")
    }
}
@@ -485,27 +487,19 @@ fn helper_create_cached_node_info() -> CachedNodeInfo {

fn helper_create_connect_info(
    mechanism: &TestConnectMechanism,
) -> (
    CachedNodeInfo,
    console::ConsoleReqExtra,
    auth::BackendType<'_, ComputeUserInfo>,
) {
) -> (CachedNodeInfo, auth::BackendType<'_, ComputeUserInfo>) {
    let cache = helper_create_cached_node_info();
    let extra = console::ConsoleReqExtra {
        session_id: uuid::Uuid::new_v4(),
        application_name: "TEST".into(),
        options: vec![],
    };
    let creds = auth::BackendType::Test(mechanism);
    (cache, extra, creds)
    let user_info = auth::BackendType::Test(mechanism);
    (cache, user_info)
}

#[tokio::test]
async fn connect_to_compute_success() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Connect]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap();
    mechanism.verify();
@@ -514,9 +508,10 @@ async fn connect_to_compute_success() {
#[tokio::test]
async fn connect_to_compute_retry() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap();
    mechanism.verify();
@@ -526,9 +521,10 @@ async fn connect_to_compute_retry() {
#[tokio::test]
async fn connect_to_compute_non_retry_1() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap_err();
    mechanism.verify();
@@ -538,9 +534,10 @@ async fn connect_to_compute_non_retry_1() {
#[tokio::test]
async fn connect_to_compute_non_retry_2() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap();
    mechanism.verify();
@@ -551,12 +548,13 @@ async fn connect_to_compute_non_retry_2() {
async fn connect_to_compute_non_retry_3() {
    assert_eq!(NUM_RETRIES_CONNECT, 16);
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![
        Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
        Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
    ]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap_err();
    mechanism.verify();
@@ -566,9 +564,10 @@ async fn connect_to_compute_non_retry_3() {
#[tokio::test]
async fn wake_retry() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap();
    mechanism.verify();
@@ -578,9 +577,10 @@ async fn wake_retry() {
#[tokio::test]
async fn wake_non_retry() {
    use ConnectAction::*;
    let mut ctx = RequestMonitoring::test();
    let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
    let (cache, extra, creds) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
    let (cache, user_info) = helper_create_connect_info(&mechanism);
    connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
        .await
        .unwrap_err();
    mechanism.verify();
proxy/src/redis.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod notifications;
proxy/src/redis/notifications.rs (new file, 202 lines)
@@ -0,0 +1,202 @@
use std::{convert::Infallible, sync::Arc};

use futures::StreamExt;
use redis::aio::PubSub;
use serde::Deserialize;
use smol_str::SmolStr;

use crate::cache::project_info::ProjectInfoCache;

const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);

struct ConsoleRedisClient {
    client: redis::Client,
}

impl ConsoleRedisClient {
    pub fn new(url: &str) -> anyhow::Result<Self> {
        let client = redis::Client::open(url)?;
        Ok(Self { client })
    }
    async fn try_connect(&self) -> anyhow::Result<PubSub> {
        let mut conn = self.client.get_async_connection().await?.into_pubsub();
        tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
        conn.subscribe(CHANNEL_NAME).await?;
        Ok(conn)
    }
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct AllowedIpsUpdate {
    #[serde(rename = "project")]
    project_id: SmolStr,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct PasswordUpdate {
    #[serde(rename = "project")]
    project_id: SmolStr,
    #[serde(rename = "role")]
    role_name: SmolStr,
}
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
    T: for<'de2> serde::Deserialize<'de2>,
    D: serde::Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
}
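
`deserialize_json_string` exists because the payload arrives double-encoded: the outer message is JSON and its `data` field is itself a JSON *string*. A minimal sketch of that two-step decode using only serde and serde_json (the `Outer`/`Inner` types here are hypothetical stand-ins):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Outer {
    data: String, // still a plain string after the first decode
}

#[derive(Deserialize, Debug, PartialEq)]
struct Inner {
    project: String,
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{"topic": "/allowed_ips_updated", "data": "{\"project\": \"p1\"}"}"#;
    let outer: Outer = serde_json::from_str(raw)?;
    let inner: Inner = serde_json::from_str(&outer.data)?; // second decode
    assert_eq!(inner, Inner { project: "p1".into() });
    Ok(())
}
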
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
    use Notification::*;
    match msg {
        AllowedIpsUpdate { allowed_ips_update } => {
            cache.invalidate_allowed_ips_for_project(&allowed_ips_update.project_id)
        }
        PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
            &password_update.project_id,
            &password_update.role_name,
        ),
    }
}

#[tracing::instrument(skip(cache))]
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
where
    C: ProjectInfoCache + Send + Sync + 'static,
{
    let payload: String = msg.get_payload()?;
    tracing::debug!(?payload, "received a message payload");

    let msg: Notification = match serde_json::from_str(&payload) {
        Ok(msg) => msg,
        Err(e) => {
            tracing::error!("broken message: {e}");
            return Ok(());
        }
    };
    tracing::debug!(?msg, "received a message");
    invalidate_cache(cache.clone(), msg.clone());
    // It might happen that the invalidated entry is still on its way into the cache.
    // To make sure the entry really is gone, repeat the invalidation after INVALIDATION_LAG seconds.
    // TODO: include the version (or the timestamp) in the message and invalidate only if the entry was cached before the message.
    tokio::spawn(async move {
        tokio::time::sleep(INVALIDATION_LAG).await;
        invalidate_cache(cache, msg.clone());
    });

    Ok(())
}

/// Handle console's invalidation messages.
#[tracing::instrument(name = "console_notifications", skip_all)]
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
where
    C: ProjectInfoCache + Send + Sync + 'static,
{
    cache.enable_ttl();

    loop {
        let redis = ConsoleRedisClient::new(&url)?;
        let conn = match redis.try_connect().await {
            Ok(conn) => {
                cache.disable_ttl();
                conn
            }
            Err(e) => {
                tracing::error!(
                    "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
                );
                tokio::time::sleep(RECONNECT_TIMEOUT).await;
                continue;
            }
        };
        let mut stream = conn.into_on_message();
        while let Some(msg) = stream.next().await {
            match handle_message(msg, cache.clone()) {
                Ok(()) => {}
                Err(e) => {
                    tracing::error!("failed to handle message: {e}, will try to reconnect");
                    break;
                }
            }
        }
        cache.enable_ttl();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id = "new_project".to_string();
        let data = format!("{{\"project\": \"{project_id}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: project_id.into()
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id = "new_project".to_string();
        let role_name = "new_role".to_string();
        let data = format!("{{\"project\": \"{project_id}\", \"role\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: project_id.into(),
                    role_name: role_name.into()
                }
            }
        );

        Ok(())
    }
}
@@ -6,7 +6,7 @@ pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the [password](super::password::SaltedPassword).
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, PartialEq, Eq)]
#[derive(Clone, Default, PartialEq, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
    bytes: [u8; SCRAM_KEY_LEN],
@@ -5,7 +5,7 @@ use super::key::ScramKey;

/// Server secret is produced from [password](super::password::SaltedPassword)
/// and is used throughout the authentication process.
#[derive(Clone)]
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct ServerSecret {
    /// Number of iterations for `PBKDF2` function.
    pub iterations: u32,
@@ -17,6 +17,8 @@ pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio_util::task::TaskTracker;

use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
use crate::rate_limiter::EndpointRateLimiter;
@@ -68,15 +70,19 @@ pub async fn task_main(
        }
    });

    let tls_config = config.tls_config.as_ref().map(|cfg| cfg.to_server_config());
    let tls_acceptor: tokio_rustls::TlsAcceptor = match tls_config {
        Some(config) => config.into(),
    let tls_config = match config.tls_config.as_ref() {
        Some(config) => config,
        None => {
            warn!("TLS config is missing, WebSocket Secure server will not be started");
            return Ok(());
        }
    };

    let mut tls_server_config = rustls::ServerConfig::clone(&tls_config.to_server_config());
    // prefer http2, but support http/1.1
    tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();

    let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
    let _ = addr_incoming.set_nodelay(true);
    let addr_incoming = ProxyProtocolAccept {
@@ -101,6 +107,9 @@ pub async fn task_main(
        let client_addr = io.client_addr();
        let remote_addr = io.inner.remote_addr();
        let sni_name = tls.server_name().map(|s| s.to_string());
        let protocol = tls
            .alpn_protocol()
            .map(|s| String::from_utf8_lossy(s).into_owned());
        let conn_pool = conn_pool.clone();
        let ws_connections = ws_connections.clone();
        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -114,6 +123,7 @@ pub async fn task_main(
        Ok(MetricService::new(hyper::service::service_fn(
            move |req: Request<Body>| {
                let sni_name = sni_name.clone();
                let protocol = protocol.clone();
                let conn_pool = conn_pool.clone();
                let ws_connections = ws_connections.clone();
                let endpoint_rate_limiter = endpoint_rate_limiter.clone();
@@ -125,6 +135,7 @@ pub async fn task_main(
                    request_handler(
                        req,
                        config,
                        tls_config,
                        conn_pool,
                        ws_connections,
                        cancel_map,
@@ -137,6 +148,7 @@ pub async fn task_main(
                        "serverless",
                        session = %session_id,
                        %peer_addr,
                        http_protocol = ?protocol,
                    ))
                    .await
            }
@@ -147,6 +159,7 @@ pub async fn task_main(
    );

    hyper::Server::builder(accept::from_stream(tls_listener))
        .http2_enable_connect_protocol()
        .serve(make_svc)
        .with_graceful_shutdown(cancellation_token.cancelled())
        .await?;
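
`http2_enable_connect_protocol` turns on RFC 8441 extended CONNECT, which is what lets a WebSocket handshake ride an individual HTTP/2 stream; the `:protocol = websocket` pseudo-header later surfaces as `hyper::ext::Protocol` in the request handler. A minimal hyper 0.14 sketch of a server with the option enabled (placeholder service, address, and port; a real handler would branch on extended CONNECT):

use std::convert::Infallible;
use hyper::{service::{make_service_fn, service_fn}, Body, Request, Response, Server};

#[tokio::main]
async fn main() -> Result<(), hyper::Error> {
    let make_svc = make_service_fn(|_conn| async {
        Ok::<_, Infallible>(service_fn(|_req: Request<Body>| async {
            // A real handler would check for extended CONNECT here.
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        }))
    });
    Server::bind(&([127, 0, 0, 1], 8080).into())
        .http2_enable_connect_protocol()
        .serve(make_svc)
        .await
}
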
@@ -194,6 +207,7 @@ where
async fn request_handler(
    mut request: Request<Body>,
    config: &'static ProxyConfig,
    tls: &'static TlsConfig,
    conn_pool: Arc<conn_pool::GlobalConnPool>,
    ws_connections: TaskTracker,
    cancel_map: Arc<CancelMap>,
@@ -209,22 +223,25 @@ async fn request_handler(
        .and_then(|h| h.split(':').next())
        .map(|s| s.to_string());

    let ws_config = None;

    // Check if the request is a websocket upgrade request.
    if hyper_tungstenite::is_upgrade_request(&request) {
    if websocket::is_upgrade_request(&request) {
        info!(session_id = ?session_id, "performing websocket upgrade");

        let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
        let (response, websocket) = websocket::upgrade(&mut request, ws_config)
            .map_err(|e| ApiError::BadRequest(e.into()))?;

        ws_connections.spawn(
            async move {
                let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);

                if let Err(e) = websocket::serve_websocket(
                    websocket,
                    config,
                    &mut ctx,
                    websocket,
                    &cancel_map,
                    session_id,
                    host,
                    peer_addr,
                    endpoint_rate_limiter,
                )
                .await
@@ -235,21 +252,51 @@ async fn request_handler(
            .in_current_span(),
        );

        // Return the response so the spawned future can continue.
        Ok(response)
    } else if websocket::is_connect_request(&request) {
        info!(session_id = ?session_id, "performing http2 websocket upgrade");

        let (response, websocket) = websocket::connect(&mut request, ws_config)
            .map_err(|e| ApiError::BadRequest(e.into()))?;

        ws_connections.spawn(
            async move {
                let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws2", &config.region);

                if let Err(e) = websocket::serve_websocket(
                    config,
                    &mut ctx,
                    websocket,
                    &cancel_map,
                    host,
                    endpoint_rate_limiter,
                )
                .await
                {
                    error!(session_id = ?session_id, "error in http2 websocket connection: {e:#}");
                }
            }
            .in_current_span(),
        );

        // Return the response so the spawned future can continue.
        Ok(response)
    } else if request.uri().path() == "/sql" && request.method() == Method::POST {
        let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);

        sql_over_http::handle(
            tls,
            &config.http_config,
            &mut ctx,
            request,
            sni_hostname,
            conn_pool,
            session_id,
            peer_addr,
            &config.http_config,
        )
        .await
    } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
        Response::builder()
            .header("Allow", "OPTIONS, POST")
            .header("Allow", "OPTIONS, POST, CONNECT")
            .header("Access-Control-Allow-Origin", "*")
            .header(
                "Access-Control-Allow-Headers",
@@ -1,4 +1,4 @@
use anyhow::{anyhow, Context};
use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::{future::poll_fn, Future};
@@ -9,11 +9,10 @@ use pbkdf2::{
    password_hash::{PasswordHashString, PasswordHasher, PasswordVerifier, SaltString},
    Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use prometheus::{exponential_buckets, register_histogram, Histogram};
use rand::Rng;
use smol_str::SmolStr;
use std::{collections::HashMap, net::IpAddr, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use std::{
    fmt,
    task::{ready, Poll},
@@ -28,8 +27,9 @@ use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use crate::{
    auth::{self, backend::ComputeUserInfo, check_peer_addr_is_in_list},
    console,
    metrics::{LatencyTimer, NUM_DB_CONNECTIONS_GAUGE},
    proxy::{connect_compute::ConnectMechanism, neon_options},
    context::RequestMonitoring,
    metrics::NUM_DB_CONNECTIONS_GAUGE,
    proxy::connect_compute::ConnectMechanism,
    usage_metrics::{Ids, MetricCounter, USAGE_METRICS},
};
use crate::{compute, config};
@@ -37,28 +37,37 @@ use crate::{compute, config};
use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};

pub const APP_NAME: &str = "/sql_over_http";
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");

#[derive(Debug, Clone)]
pub struct ConnInfo {
    pub username: SmolStr,
    pub user_info: ComputeUserInfo,
    pub dbname: SmolStr,
    pub hostname: SmolStr,
    pub password: SmolStr,
    pub options: Option<SmolStr>,
}

impl ConnInfo {
    // hm, change to hasher to avoid cloning?
    pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
        (self.dbname.clone(), self.username.clone())
        (self.dbname.clone(), self.user_info.user.clone())
    }

    pub fn endpoint_cache_key(&self) -> SmolStr {
        self.user_info.endpoint_cache_key()
    }
}

impl fmt::Display for ConnInfo {
    // use custom display to avoid logging password
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}@{}/{}", self.username, self.hostname, self.dbname)
        write!(
            f,
            "{}@{}/{}?{}",
            self.user_info.user,
            self.user_info.endpoint,
            self.dbname,
            self.user_info.options.get_cache_key("")
        )
    }
}
@@ -309,18 +318,16 @@ impl GlobalConnPool {

    pub async fn get(
        self: &Arc<Self>,
        ctx: &mut RequestMonitoring,
        conn_info: ConnInfo,
        force_new: bool,
        session_id: uuid::Uuid,
        peer_addr: IpAddr,
    ) -> anyhow::Result<Client> {
        let mut client: Option<ClientInner> = None;
        let mut latency_timer = LatencyTimer::new("http");

        let mut hash_valid = false;
        let mut endpoint_pool = Weak::new();
        if !force_new {
            let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
            let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
            endpoint_pool = Arc::downgrade(&pool);
            let mut hash = None;

@@ -360,23 +367,21 @@ impl GlobalConnPool {
                info!(%conn_id, "pool: cached connection '{conn_info}' is closed, opening a new one");
                connect_to_compute(
                    self.proxy_config,
                    ctx,
                    &conn_info,
                    conn_id,
                    session_id,
                    latency_timer,
                    peer_addr,
                    endpoint_pool.clone(),
                )
                .await
            } else {
                info!("pool: reusing connection '{conn_info}'");
                client.session.send(session_id)?;
                client.session.send(ctx.session_id)?;
                tracing::Span::current().record(
                    "pid",
                    &tracing::field::display(client.inner.get_process_id()),
                );
                latency_timer.pool_hit();
                latency_timer.success();
                ctx.latency_timer.pool_hit();
                ctx.latency_timer.success();
                return Ok(Client::new(client, conn_info, endpoint_pool).await);
            }
        } else {
@@ -384,11 +389,9 @@ impl GlobalConnPool {
            info!(%conn_id, "pool: opening a new connection '{conn_info}'");
            connect_to_compute(
                self.proxy_config,
                ctx,
                &conn_info,
                conn_id,
                session_id,
                latency_timer,
                peer_addr,
                endpoint_pool.clone(),
            )
            .await
@@ -406,7 +409,7 @@ impl GlobalConnPool {
            Err(err)
                if hash_valid && err.to_string().contains("password authentication failed") =>
            {
                let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
                let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
                let mut pool = pool.write();
                if let Some(entry) = pool.pools.get_mut(&conn_info.db_and_user()) {
                    entry.password_hash = None;
@@ -423,7 +426,7 @@ impl GlobalConnPool {
        })
        .await??;

        let pool = self.get_or_create_endpoint_pool(&conn_info.hostname);
        let pool = self.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key());
        let mut pool = pool.write();
        pool.pools
            .entry(conn_info.db_and_user())
@@ -483,7 +486,6 @@ impl GlobalConnPool {
struct TokioMechanism<'a> {
    pool: Weak<RwLock<EndpointConnPool>>,
    conn_info: &'a ConnInfo,
    session_id: uuid::Uuid,
    conn_id: uuid::Uuid,
    idle: Duration,
}
@@ -496,15 +498,16 @@ impl ConnectMechanism for TokioMechanism<'_> {

    async fn connect_once(
        &self,
        ctx: &mut RequestMonitoring,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        connect_to_compute_once(
            ctx,
            node_info,
            self.conn_info,
            timeout,
            self.conn_id,
            self.session_id,
            self.pool.clone(),
            self.idle,
        )
@@ -520,80 +523,58 @@ impl ConnectMechanism for TokioMechanism<'_> {
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
async fn connect_to_compute(
    config: &config::ProxyConfig,
    ctx: &mut RequestMonitoring,
    conn_info: &ConnInfo,
    conn_id: uuid::Uuid,
    session_id: uuid::Uuid,
    latency_timer: LatencyTimer,
    peer_addr: IpAddr,
    pool: Weak<RwLock<EndpointConnPool>>,
) -> anyhow::Result<ClientInner> {
    let tls = config.tls_config.as_ref();
    let common_names = tls.and_then(|tls| tls.common_names.clone());
    ctx.set_application(Some(APP_NAME));
    let backend = config
        .auth_backend
        .as_ref()
        .map(|_| conn_info.user_info.clone());

    let params = StartupMessageParams::new([
        ("user", &conn_info.username),
        ("database", &conn_info.dbname),
        ("application_name", APP_NAME),
        ("options", conn_info.options.as_deref().unwrap_or("")),
    ]);
    let creds = auth::ClientCredentials::parse(
        &params,
        Some(&conn_info.hostname),
        common_names,
        peer_addr,
    )?;

    let creds =
        ComputeUserInfo::try_from(creds).map_err(|_| anyhow!("missing endpoint identifier"))?;
    let backend = config.auth_backend.as_ref().map(|_| creds);

    let console_options = neon_options(&params);

    let extra = console::ConsoleReqExtra {
        session_id: uuid::Uuid::new_v4(),
        application_name: APP_NAME.to_string(),
        options: console_options,
    };
    if !config.disable_ip_check_for_http {
        let allowed_ips = backend.get_allowed_ips(&extra).await?;
        if !check_peer_addr_is_in_list(&peer_addr, &allowed_ips) {
        let allowed_ips = backend.get_allowed_ips(ctx).await?;
        if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
            return Err(auth::AuthError::ip_address_not_allowed().into());
        }
    }
    let node_info = backend
        .wake_compute(&extra)
        .wake_compute(ctx)
        .await?
        .context("missing cache entry from wake_compute")?;

    ctx.set_project(node_info.aux.clone());

    crate::proxy::connect_compute::connect_to_compute(
        ctx,
        &TokioMechanism {
            conn_id,
            conn_info,
            session_id,
            pool,
            idle: config.http_config.pool_options.idle_timeout,
        },
        node_info,
        &extra,
        &backend,
        latency_timer,
    )
    .await
}

async fn connect_to_compute_once(
    ctx: &mut RequestMonitoring,
    node_info: &console::CachedNodeInfo,
    conn_info: &ConnInfo,
    timeout: time::Duration,
    conn_id: uuid::Uuid,
    mut session: uuid::Uuid,
    pool: Weak<RwLock<EndpointConnPool>>,
    idle: Duration,
) -> Result<ClientInner, tokio_postgres::Error> {
    let mut config = (*node_info.config).clone();
    let mut session = ctx.session_id;

    let (client, mut connection) = config
        .user(&conn_info.username)
        .user(&conn_info.user_info.user)
        .password(&*conn_info.password)
        .dbname(&conn_info.dbname)
        .connect_timeout(timeout)
@@ -601,7 +582,7 @@ async fn connect_to_compute_once(
        .await?;

    let conn_gauge = NUM_DB_CONNECTIONS_GAUGE
        .with_label_values(&["http"])
        .with_label_values(&[ctx.protocol])
        .guard();

    tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
@@ -1,4 +1,3 @@
use std::net::IpAddr;
use std::sync::Arc;

use anyhow::bail;
@@ -14,6 +13,7 @@ use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use smol_str::SmolStr;
use tokio_postgres::error::DbError;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
@@ -28,8 +28,13 @@ use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;

use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::config::HttpConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
use crate::proxy::NeonOptions;

use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
@@ -121,8 +126,10 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
}

fn get_conn_info(
    ctx: &mut RequestMonitoring,
    headers: &HeaderMap,
    sni_hostname: Option<String>,
    tls: &TlsConfig,
) -> Result<ConnInfo, anyhow::Error> {
    let connection_string = headers
        .get("Neon-Connection-String")
@@ -146,10 +153,11 @@ fn get_conn_info(
        .next()
        .ok_or(anyhow::anyhow!("invalid database name"))?;

    let username = connection_url.username();
    let username = SmolStr::from(connection_url.username());
    if username.is_empty() {
        return Err(anyhow::anyhow!("missing username"));
    }
    ctx.set_user(username.clone());

    let password = connection_url
        .password()
@@ -176,45 +184,47 @@ fn get_conn_info(
        }
    }

    let endpoint = endpoint_sni(hostname, &tls.common_names)?;

    let endpoint: SmolStr = endpoint.into();
    ctx.set_endpoint_id(Some(endpoint.clone()));

    let pairs = connection_url.query_pairs();

    let mut options = Option::None;

    for (key, value) in pairs {
        if key == "options" {
            options = Some(value.into());
            options = Some(NeonOptions::parse_options_raw(&value));
            break;
        }
    }

    let user_info = ComputeUserInfo {
        endpoint,
        user: username,
        options: options.unwrap_or_default(),
    };

    Ok(ConnInfo {
        username: username.into(),
        user_info,
        dbname: dbname.into(),
        hostname: hostname.into(),
        password: password.into(),
        options,
    })
}
// TODO: return different http error codes
pub async fn handle(
    tls: &'static TlsConfig,
    config: &'static HttpConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
    session_id: uuid::Uuid,
    peer_addr: IpAddr,
    config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.request_timeout,
        handle_inner(
            config,
            request,
            sni_hostname,
            conn_pool,
            session_id,
            peer_addr,
        ),
        handle_inner(tls, config, ctx, request, sni_hostname, conn_pool),
    )
    .await;
    let mut response = match result {
@@ -296,12 +306,12 @@ pub async fn handle(

#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
async fn handle_inner(
    tls: &'static TlsConfig,
    config: &'static HttpConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Body>,
    sni_hostname: Option<String>,
    conn_pool: Arc<GlobalConnPool>,
    session_id: uuid::Uuid,
    peer_addr: IpAddr,
) -> anyhow::Result<Response<Body>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&["http"])
@@ -311,7 +321,7 @@ async fn handle_inner(
    // Determine the destination and connection params
    //
    let headers = request.headers();
    let conn_info = get_conn_info(headers, sni_hostname)?;
    let conn_info = get_conn_info(ctx, headers, sni_hostname, tls)?;

    // Determine the output options. Default behaviour is 'false'. Anything that is not
    // strictly 'true' is assumed to be false.
@@ -340,10 +350,12 @@ async fn handle_inner(
    let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
    let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);

    let paused = ctx.latency_timer.pause();
    let request_content_length = match request.body().size_hint().upper() {
        Some(v) => v,
        None => MAX_REQUEST_SIZE + 1,
    };
    drop(paused);

    // we don't have streaming request support yet, so this is to prevent OOM
    // from a malicious user sending an extremely large request body
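
The `size_hint` check leans on `Content-Length` when it is present; a missing upper bound is treated as over-limit so an unbounded body is never buffered blindly. A compact sketch of the same guard (`MAX_REQUEST_SIZE` here is an illustrative constant, not the real configured value):

use hyper::{body::HttpBody, Body};

const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024;

fn within_limit(body: &Body) -> bool {
    // `upper()` is None for unbounded or streaming bodies; treat that as too big.
    body.size_hint().upper().unwrap_or(MAX_REQUEST_SIZE + 1) <= MAX_REQUEST_SIZE
}

fn main() {
    assert!(within_limit(&Body::from("tiny payload")));
}
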
@@ -359,9 +371,7 @@ async fn handle_inner(
    let body = hyper::body::to_bytes(request.into_body()).await?;
    let payload: Payload = serde_json::from_slice(&body)?;

    let mut client = conn_pool
        .get(conn_info, !allow_pool, session_id, peer_addr)
        .await?;
    let mut client = conn_pool.get(ctx, conn_info, !allow_pool).await?;

    let mut response = Response::builder()
        .status(StatusCode::OK)
@@ -449,6 +459,7 @@ async fn handle_inner(
        }
    };

    ctx.log();
    let metrics = client.metrics();

    // how could this possibly fail
@@ -1,18 +1,24 @@
use crate::{
    cancellation::CancelMap,
    config::ProxyConfig,
    context::RequestMonitoring,
    error::io_error,
    proxy::{handle_client, ClientMode},
    rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use hyper::{ext::Protocol, upgrade::Upgraded, Body, Method, Request, Response};
use pin_project_lite::pin_project;
use tokio_tungstenite::WebSocketStream;
use tungstenite::{
    error::{Error as WSError, ProtocolError},
    handshake::derive_accept_key,
    protocol::{Role, WebSocketConfig},
    Message,
};

use std::{
    net::IpAddr,
    pin::Pin,
    sync::Arc,
    task::{ready, Context, Poll},
@@ -130,41 +136,222 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}

pub async fn serve_websocket(
    websocket: HyperWebsocket,
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    websocket: HyperWebsocket,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    hostname: Option<String>,
    peer_addr: IpAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    let websocket = websocket.await?;
    handle_client(
        config,
        ctx,
        cancel_map,
        session_id,
        WebSocketRw::new(websocket),
        ClientMode::Websockets { hostname },
        peer_addr,
        endpoint_rate_limiter,
    )
    .await?;
    Ok(())
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns an HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This function checks the `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
pub fn upgrade<B>(
    mut request: impl std::borrow::BorrowMut<Request<B>>,
    config: Option<WebSocketConfig>,
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
    let request = request.borrow_mut();

    let key = request
        .headers()
        .get("Sec-WebSocket-Key")
        .ok_or(ProtocolError::MissingSecWebSocketKey)?;
    if request
        .headers()
        .get("Sec-WebSocket-Version")
        .map(|v| v.as_bytes())
        != Some(b"13")
    {
        return Err(ProtocolError::MissingSecWebSocketVersionHeader);
    }

    let response = Response::builder()
        .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
        .header(hyper::header::CONNECTION, "upgrade")
        .header(hyper::header::UPGRADE, "websocket")
        .header("Sec-WebSocket-Accept", &derive_accept_key(key.as_bytes()))
        .body(Body::from("switching to websocket protocol"))
        .expect("bug: failed to build response");

    let stream = HyperWebsocket {
        inner: hyper::upgrade::on(request),
        config,
    };

    Ok((response, stream))
}
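
`derive_accept_key` implements the RFC 6455 §4.2.2 handshake digest: base64(SHA-1(client key + fixed GUID)). The round trip can be checked against the sample values published in the RFC:

use tungstenite::handshake::derive_accept_key;

fn main() {
    // Sample nonce and expected accept value from RFC 6455.
    let key = "dGhlIHNhbXBsZSBub25jZQ==";
    assert_eq!(derive_accept_key(key.as_bytes()), "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
}
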
/// Check if a request is a websocket upgrade request.
///
/// If the `Upgrade` header lists multiple protocols,
/// this function returns true if one of them is `"websocket"`.
/// If the server supports multiple upgrade protocols,
/// it would be more appropriate to try each listed protocol in order.
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
    header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
        && header_contains_value(request.headers(), hyper::header::UPGRADE, "websocket")
}

/// Check if there is a header of the given name containing the wanted value.
fn header_contains_value(
    headers: &hyper::HeaderMap,
    header: impl hyper::header::AsHeaderName,
    value: impl AsRef<[u8]>,
) -> bool {
    let value = value.as_ref();
    for header in headers.get_all(header) {
        if header
            .as_bytes()
            .split(|&c| c == b',')
            .any(|x| trim(x).eq_ignore_ascii_case(value))
        {
            return true;
        }
    }
    false
}

fn trim(data: &[u8]) -> &[u8] {
    trim_end(trim_start(data))
}

fn trim_start(data: &[u8]) -> &[u8] {
    if let Some(start) = data.iter().position(|x| !x.is_ascii_whitespace()) {
        &data[start..]
    } else {
        b""
    }
}

fn trim_end(data: &[u8]) -> &[u8] {
    if let Some(last) = data.iter().rposition(|x| !x.is_ascii_whitespace()) {
        &data[..last + 1]
    } else {
        b""
    }
}
/// Try to upgrade a received `hyper::Request` to a websocket connection.
///
/// The function returns an HTTP response and a future that resolves to the websocket stream.
/// The response body *MUST* be sent to the client before the future can be resolved.
///
/// This function checks the `Sec-WebSocket-Version` header.
/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
/// You can inspect the headers manually before calling this function,
/// and modify the response headers appropriately.
///
/// This function also does not look at the `Connection` or `Upgrade` headers.
/// To check if a request is a websocket connect request, you can use [`is_connect_request`].
/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
///
pub fn connect<B>(
    mut request: impl std::borrow::BorrowMut<Request<B>>,
    config: Option<WebSocketConfig>,
) -> Result<(Response<Body>, HyperWebsocket), ProtocolError> {
    let request = request.borrow_mut();

    if request
        .headers()
        .get("Sec-WebSocket-Version")
        .map(|v| v.as_bytes())
        != Some(b"13")
    {
        return Err(ProtocolError::MissingSecWebSocketVersionHeader);
    }

    let response = Response::builder()
        .status(hyper::StatusCode::OK)
        .body(Body::from("switching to websocket protocol"))
        .expect("bug: failed to build response");

    let stream = HyperWebsocket {
        inner: hyper::upgrade::on(request),
        config,
    };

    Ok((response, stream))
}

/// Check if a request is a websocket connect request.
pub fn is_connect_request<B>(request: &hyper::Request<B>) -> bool {
    request.method() == Method::CONNECT
        && request
            .extensions()
            .get::<Protocol>()
            .is_some_and(|protocol| protocol.as_str() == "websocket")
}

pin_project_lite::pin_project! {
    /// A future that resolves to a websocket stream when the associated connection completes.
    #[derive(Debug)]
    pub struct HyperWebsocket {
        #[pin]
        inner: hyper::upgrade::OnUpgrade,
        config: Option<WebSocketConfig>
    }
}

impl std::future::Future for HyperWebsocket {
    type Output = Result<WebSocketStream<hyper::upgrade::Upgraded>, WSError>;

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        let this = self.project();
        let upgraded = match this.inner.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(x) => x,
        };

        let upgraded =
            upgraded.map_err(|_| WSError::Protocol(ProtocolError::HandshakeIncomplete))?;

        let stream = WebSocketStream::from_raw_socket(upgraded, Role::Server, None);
        tokio::pin!(stream);

        // The future returned by `from_raw_socket` is always ready.
        // Not sure why it is a future in the first place.
        match stream.as_mut().poll(cx) {
            Poll::Pending => unreachable!("from_raw_socket should always be created ready"),
            Poll::Ready(x) => Poll::Ready(Ok(x)),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::pin::pin;

    use futures::{SinkExt, StreamExt};
    use hyper_tungstenite::{
        tungstenite::{protocol::Role, Message},
        WebSocketStream,
    };
    use tokio::{
        io::{duplex, AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tokio_tungstenite::WebSocketStream;
    use tungstenite::{protocol::Role, Message};

    use super::WebSocketRw;
@@ -14,7 +14,7 @@ requests = "^2.31.0"
pytest-xdist = "^3.3.1"
asyncpg = "^0.29.0"
aiopg = "^1.4.0"
Jinja2 = "^3.0.2"
Jinja2 = "^3.1.3"
types-requests = "^2.31.0.0"
types-psycopg2 = "^2.9.21.10"
boto3 = "^1.34.11"
@@ -38,6 +38,8 @@ pytest-rerunfailures = "^13.0"
types-pytest-lazy-fixture = "^0.6.3.3"
pytest-split = "^0.8.1"
zstandard = "^0.21.0"
websockets = "^12.0"
httpx = {extras = ["http2"], version = "^0.26.0"}

[tool.poetry.group.dev.dependencies]
mypy = "==1.3.0"
@@ -1,5 +1,5 @@
[toolchain]
channel = "1.74.0"
channel = "1.75.0"
profile = "default"
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
# https://rust-lang.github.io/rustup/concepts/profiles.html
@@ -2,7 +2,10 @@
//! S3 objects which are either not referenced by any metadata, or are referenced by a
//! control plane tenant/timeline in a deleted state.

use std::{collections::HashMap, sync::Arc};
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use anyhow::Context;
use aws_sdk_s3::{
@@ -118,6 +121,13 @@ const S3_CONCURRENCY: usize = 32;
// How many concurrent API requests to make to the console API.
const CONSOLE_CONCURRENCY: usize = 128;

struct ConsoleCache {
    /// Set of tenants found in the control plane API
    projects: HashMap<TenantId, ProjectData>,
    /// Set of tenants for which the control plane API returned 404
    not_found: HashSet<TenantId>,
}

async fn find_garbage_inner(
    bucket_config: BucketConfig,
    console_config: ConsoleConfig,
@@ -143,23 +153,49 @@ async fn find_garbage_inner(
        console_projects.len()
    );

    // TODO(sharding): batch calls into Console so that we only call once for each TenantId,
    // rather than checking the same TenantId for multiple TenantShardId
    // Because many tenant shards may look up the same TenantId, we maintain a cache.
    let console_cache = Arc::new(std::sync::Mutex::new(ConsoleCache {
        projects: console_projects,
        not_found: HashSet::new(),
    }));

    // Enumerate Tenants in S3, and check if each one exists in Console
    tracing::info!("Finding all tenants in bucket {}...", bucket_config.bucket);
    let tenants = stream_tenants(&s3_client, &target);
    let tenants_checked = tenants.map_ok(|t| {
        let api_client = cloud_admin_api_client.clone();
        let console_projects = &console_projects;
        let console_cache = console_cache.clone();
        async move {
            match console_projects.get(&t.tenant_id) {
            // Check cache before issuing API call
            let project_data = {
                let cache = console_cache.lock().unwrap();
                let result = cache.projects.get(&t.tenant_id).cloned();
                if result.is_none() && cache.not_found.contains(&t.tenant_id) {
                    return Ok((t, None));
                }
                result
            };

            match project_data {
                Some(project_data) => Ok((t, Some(project_data.clone()))),
                None => api_client
                    .find_tenant_project(t.tenant_id)
                    .await
                    .map_err(|e| anyhow::anyhow!(e))
                    .map(|r| (t, r)),
                None => {
                    let project_data = api_client
                        .find_tenant_project(t.tenant_id)
                        .await
                        .map_err(|e| anyhow::anyhow!(e));

                    // Populate cache with result of API call
                    {
                        let mut cache = console_cache.lock().unwrap();
                        if let Ok(Some(project_data)) = &project_data {
                            cache.projects.insert(t.tenant_id, project_data.clone());
                        } else if let Ok(None) = &project_data {
                            cache.not_found.insert(t.tenant_id);
                        }
                    }

                    project_data.map(|r| (t, r))
                }
            }
        }
    });
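
The cache above memoizes both outcomes: successful lookups land in `projects`, confirmed 404s land in `not_found`, so many shards of one tenant trigger at most one console call. The same shape in miniature (a generic, synchronous sketch with made-up names, not the scrubber's real types):

use std::collections::{HashMap, HashSet};
use std::hash::Hash;

struct LookupCache<K, V> {
    found: HashMap<K, V>,
    not_found: HashSet<K>,
}

impl<K: Hash + Eq + Clone, V: Clone> LookupCache<K, V> {
    fn get_or_fetch(&mut self, key: &K, fetch: impl FnOnce() -> Option<V>) -> Option<V> {
        if let Some(v) = self.found.get(key) {
            return Some(v.clone()); // positive hit
        }
        if self.not_found.contains(key) {
            return None; // negative hit: skip the API call entirely
        }
        match fetch() {
            Some(v) => {
                self.found.insert(key.clone(), v.clone());
                Some(v)
            }
            None => {
                self.not_found.insert(key.clone());
                None
            }
        }
    }
}
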
@@ -17,7 +17,9 @@ use utils::id::TenantId;

#[derive(Serialize)]
pub struct MetadataSummary {
    count: usize,
    tenant_count: usize,
    timeline_count: usize,
    timeline_shard_count: usize,
    with_errors: HashSet<TenantShardTimelineId>,
    with_warnings: HashSet<TenantShardTimelineId>,
    with_orphans: HashSet<TenantShardTimelineId>,
@@ -87,7 +89,9 @@ impl MinMaxHisto {
impl MetadataSummary {
    fn new() -> Self {
        Self {
            count: 0,
            tenant_count: 0,
            timeline_count: 0,
            timeline_shard_count: 0,
            with_errors: HashSet::new(),
            with_warnings: HashSet::new(),
            with_orphans: HashSet::new(),
@@ -112,7 +116,7 @@ impl MetadataSummary {
    }

    fn update_data(&mut self, data: &S3TimelineBlobData) {
        self.count += 1;
        self.timeline_shard_count += 1;
        if let BlobDataParseResult::Parsed {
            index_part,
            index_part_generation: _,
@@ -158,16 +162,20 @@ impl MetadataSummary {
        );

        format!(
            "Timelines: {0}
With errors: {1}
With warnings: {2}
With orphan layers: {3}
            "Tenants: {}
Timelines: {}
Timeline-shards: {}
With errors: {}
With warnings: {}
With orphan layers: {}
Index versions: {version_summary}
Timeline size bytes: {4}
Layer size bytes: {5}
Timeline layer count: {6}
Timeline size bytes: {}
Layer size bytes: {}
Timeline layer count: {}
",
            self.count,
            self.tenant_count,
            self.timeline_count,
            self.timeline_shard_count,
            self.with_errors.len(),
            self.with_warnings.len(),
            self.with_orphans.len(),
@@ -182,7 +190,7 @@ Timeline layer count: {6}
    }

    pub fn is_empty(&self) -> bool {
        self.count == 0
        self.timeline_shard_count == 0
    }
}

@@ -233,8 +241,12 @@ pub async fn scan_metadata(
    mut tenant_objects: TenantObjectListing,
    timelines: Vec<(TenantShardTimelineId, S3TimelineBlobData)>,
) {
    summary.tenant_count += 1;

    let mut timeline_ids = HashSet::new();
    let mut timeline_generations = HashMap::new();
    for (ttid, data) in timelines {
        timeline_ids.insert(ttid.timeline_id);
        // Stash the generation of each timeline, for later use identifying orphan layers
        if let BlobDataParseResult::Parsed {
            index_part: _index_part,
@@ -252,6 +264,8 @@ pub async fn scan_metadata(
        summary.update_analysis(&ttid, &analysis);
    }

    summary.timeline_count += timeline_ids.len();

    // Identifying orphan layers must be done on a tenant-wide basis, because individual
    // shards' layers may be referenced by other shards.
    //
@@ -13,13 +13,16 @@ use std::time::Instant;

use crate::control_file_upgrade::upgrade_control_file;
use crate::metrics::PERSIST_CONTROL_FILE_SECONDS;
use crate::safekeeper::{SafeKeeperState, SK_FORMAT_VERSION, SK_MAGIC};
use crate::state::TimelinePersistentState;
use utils::{bin_ser::LeSer, id::TenantTimelineId};

use crate::SafeKeeperConf;

use std::convert::TryInto;

pub const SK_MAGIC: u32 = 0xcafeceefu32;
pub const SK_FORMAT_VERSION: u32 = 7;

// contains persistent metadata for safekeeper
const CONTROL_FILE_NAME: &str = "safekeeper.control";
// needed to atomically update the state using `rename`
@@ -29,9 +32,9 @@ pub const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
/// Storage should keep actual state inside of it. It should implement Deref
/// trait to access state fields and have persist method for updating that state.
#[async_trait::async_trait]
pub trait Storage: Deref<Target = SafeKeeperState> {
pub trait Storage: Deref<Target = TimelinePersistentState> {
    /// Persist safekeeper state on disk and update internal state.
    async fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
    async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()>;

    /// Timestamp of last persist.
    fn last_persist_at(&self) -> Instant;
@@ -44,7 +47,7 @@ pub struct FileStorage {
    conf: SafeKeeperConf,

    /// Last state persisted to disk.
    state: SafeKeeperState,
    state: TimelinePersistentState,
    /// Not preserved across restarts.
    last_persist_at: Instant,
}
@@ -68,7 +71,7 @@ impl FileStorage {
    pub fn create_new(
        timeline_dir: Utf8PathBuf,
        conf: &SafeKeeperConf,
        state: SafeKeeperState,
        state: TimelinePersistentState,
    ) -> Result<FileStorage> {
        let store = FileStorage {
            timeline_dir,
@@ -81,7 +84,7 @@ impl FileStorage {
    }

    /// Check the magic/version in the on-disk data and deserialize it, if possible.
    fn deser_sk_state(buf: &mut &[u8]) -> Result<SafeKeeperState> {
    fn deser_sk_state(buf: &mut &[u8]) -> Result<TimelinePersistentState> {
        // Read the version independent part
        let magic = ReadBytesExt::read_u32::<LittleEndian>(buf)?;
        if magic != SK_MAGIC {
@@ -93,7 +96,7 @@ impl FileStorage {
        }
        let version = ReadBytesExt::read_u32::<LittleEndian>(buf)?;
        if version == SK_FORMAT_VERSION {
            let res = SafeKeeperState::des(buf)?;
            let res = TimelinePersistentState::des(buf)?;
            return Ok(res);
        }
        // try to upgrade
@@ -104,13 +107,15 @@ impl FileStorage {
    pub fn load_control_file_conf(
        conf: &SafeKeeperConf,
        ttid: &TenantTimelineId,
    ) -> Result<SafeKeeperState> {
    ) -> Result<TimelinePersistentState> {
        let path = conf.timeline_dir(ttid).join(CONTROL_FILE_NAME);
        Self::load_control_file(path)
    }

    /// Read in the control file.
    pub fn load_control_file<P: AsRef<Path>>(control_file_path: P) -> Result<SafeKeeperState> {
    pub fn load_control_file<P: AsRef<Path>>(
        control_file_path: P,
    ) -> Result<TimelinePersistentState> {
        let mut control_file = std::fs::OpenOptions::new()
            .read(true)
            .write(true)
@@ -153,7 +158,7 @@ impl FileStorage {
}

impl Deref for FileStorage {
    type Target = SafeKeeperState;
    type Target = TimelinePersistentState;

    fn deref(&self) -> &Self::Target {
        &self.state
@@ -165,7 +170,7 @@ impl Storage for FileStorage {
    /// Persists state durably to the underlying storage.
    ///
    /// For a description, see <https://lwn.net/Articles/457667/>.
    async fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
    async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
        let _timer = PERSIST_CONTROL_FILE_SECONDS.start_timer();

        // write data to safekeeper.control.partial
@@ -242,7 +247,7 @@ impl Storage for FileStorage {
|
||||
mod test {
|
||||
use super::FileStorage;
|
||||
use super::*;
|
||||
use crate::{safekeeper::SafeKeeperState, SafeKeeperConf};
|
||||
use crate::SafeKeeperConf;
|
||||
use anyhow::Result;
|
||||
use utils::{id::TenantTimelineId, lsn::Lsn};
|
||||
|
||||
@@ -257,7 +262,7 @@ mod test {
|
||||
async fn load_from_control_file(
|
||||
conf: &SafeKeeperConf,
|
||||
ttid: &TenantTimelineId,
|
||||
) -> Result<(FileStorage, SafeKeeperState)> {
|
||||
) -> Result<(FileStorage, TimelinePersistentState)> {
|
||||
fs::create_dir_all(conf.timeline_dir(ttid))
|
||||
.await
|
||||
.expect("failed to create timeline dir");
|
||||
@@ -270,11 +275,11 @@ mod test {
|
||||
async fn create(
|
||||
conf: &SafeKeeperConf,
|
||||
ttid: &TenantTimelineId,
|
||||
) -> Result<(FileStorage, SafeKeeperState)> {
|
||||
) -> Result<(FileStorage, TimelinePersistentState)> {
|
||||
fs::create_dir_all(conf.timeline_dir(ttid))
|
||||
.await
|
||||
.expect("failed to create timeline dir");
|
||||
let state = SafeKeeperState::empty();
|
||||
let state = TimelinePersistentState::empty();
|
||||
let timeline_dir = conf.timeline_dir(ttid);
|
||||
let storage = FileStorage::create_new(timeline_dir, conf, state.clone())?;
|
||||
Ok((storage, state))
|
||||
|
||||
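FileStorage::persist writes to safekeeper.control.partial and then renames, per the comments above. A hedged sketch of that crash-safe pattern in isolation; the helper name and error handling here are illustrative, not the crate's actual API:

use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};

// Sketch of the "write .partial, fsync, rename" update referenced above,
// e.g. "safekeeper.control" -> "safekeeper.control.partial" -> rename back.
fn persist_atomically(path: &Path, bytes: &[u8]) -> std::io::Result<()> {
    let mut partial = path.as_os_str().to_owned();
    partial.push(".partial");
    let partial = PathBuf::from(partial);

    let mut f = fs::File::create(&partial)?;
    f.write_all(bytes)?;
    f.sync_all()?; // make the data itself durable before the swap
    fs::rename(&partial, path)?; // atomic replace on POSIX filesystems
    // A fully robust version also fsyncs the parent directory, per the
    // LWN article linked from the persist() doc comment.
    Ok(())
}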
@@ -1,6 +1,7 @@
//! Code to deal with safekeeper control file upgrades
use crate::safekeeper::{
    AcceptorState, PersistedPeers, PgUuid, SafeKeeperState, ServerInfo, Term, TermHistory, TermLsn,
use crate::{
    safekeeper::{AcceptorState, PgUuid, ServerInfo, Term, TermHistory, TermLsn},
    state::{PersistedPeers, TimelinePersistentState},
};
use anyhow::{bail, Result};
use pq_proto::SystemId;
@@ -137,7 +138,7 @@ pub struct SafeKeeperStateV4 {
    pub peers: PersistedPeers,
}

pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState> {
pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<TimelinePersistentState> {
    // migrate to storing full term history
    if version == 1 {
        info!("reading safekeeper control file version {}", version);
@@ -149,7 +150,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
                lsn: Lsn(0),
            }]),
        };
        return Ok(SafeKeeperState {
        return Ok(TimelinePersistentState {
            tenant_id: oldstate.server.tenant_id,
            timeline_id: oldstate.server.timeline_id,
            acceptor_state: ac,
@@ -176,7 +177,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
            system_id: oldstate.server.system_id,
            wal_seg_size: oldstate.server.wal_seg_size,
        };
        return Ok(SafeKeeperState {
        return Ok(TimelinePersistentState {
            tenant_id: oldstate.server.tenant_id,
            timeline_id: oldstate.server.timeline_id,
            acceptor_state: oldstate.acceptor_state,
@@ -199,7 +200,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
            system_id: oldstate.server.system_id,
            wal_seg_size: oldstate.server.wal_seg_size,
        };
        return Ok(SafeKeeperState {
        return Ok(TimelinePersistentState {
            tenant_id: oldstate.tenant_id,
            timeline_id: oldstate.timeline_id,
            acceptor_state: oldstate.acceptor_state,
@@ -222,7 +223,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
            system_id: oldstate.server.system_id,
            wal_seg_size: oldstate.server.wal_seg_size,
        };
        return Ok(SafeKeeperState {
        return Ok(TimelinePersistentState {
            tenant_id: oldstate.tenant_id,
            timeline_id: oldstate.timeline_id,
            acceptor_state: oldstate.acceptor_state,
@@ -238,7 +239,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
        });
    } else if version == 5 {
        info!("reading safekeeper control file version {}", version);
        let mut oldstate = SafeKeeperState::des(&buf[..buf.len()])?;
        let mut oldstate = TimelinePersistentState::des(&buf[..buf.len()])?;
        if oldstate.timeline_start_lsn != Lsn(0) {
            return Ok(oldstate);
        }
@@ -251,7 +252,7 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result<SafeKeeperState>
        return Ok(oldstate);
    } else if version == 6 {
        info!("reading safekeeper control file version {}", version);
        let mut oldstate = SafeKeeperState::des(&buf[..buf.len()])?;
        let mut oldstate = TimelinePersistentState::des(&buf[..buf.len()])?;
        if oldstate.server.pg_version != 0 {
            return Ok(oldstate);
        }
@@ -14,7 +14,7 @@ use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
    control_file::{FileStorage, Storage},
    pull_timeline::{create_temp_timeline_dir, load_temp_timeline, validate_temp_timeline},
    safekeeper::SafeKeeperState,
    state::TimelinePersistentState,
    timeline::{Timeline, TimelineError},
    wal_backup::copy_s3_segments,
    wal_storage::{wal_file_paths, WalReader},
@@ -137,7 +137,7 @@ pub async fn handle_request(request: Request) -> Result<()> {
    )
    .await?;

    let mut new_state = SafeKeeperState::new(
    let mut new_state = TimelinePersistentState::new(
        &request.destination_ttid,
        state.server.clone(),
        vec![],
@@ -160,7 +160,7 @@ pub async fn handle_request(request: Request) -> Result<()> {

async fn copy_disk_segments(
    conf: &SafeKeeperConf,
    persisted_state: &SafeKeeperState,
    persisted_state: &TimelinePersistentState,
    wal_seg_size: usize,
    source_ttid: &TenantTimelineId,
    start_lsn: Lsn,
@@ -22,14 +22,13 @@ use utils::id::TenantTimelineId;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;

use crate::safekeeper::SafeKeeperState;
use crate::safekeeper::SafekeeperMemState;
use crate::safekeeper::TermHistory;
use crate::SafeKeeperConf;

use crate::send_wal::WalSenderState;
use crate::state::TimelineMemState;
use crate::state::TimelinePersistentState;
use crate::wal_storage::WalReader;
use crate::GlobalTimelines;
use crate::SafeKeeperConf;

/// Various filters that influence the resulting JSON output.
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -143,7 +142,7 @@ pub struct Config {
pub struct Timeline {
    pub tenant_id: TenantId,
    pub timeline_id: TimelineId,
    pub control_file: Option<SafeKeeperState>,
    pub control_file: Option<TimelinePersistentState>,
    pub memory: Option<Memory>,
    pub disk_content: Option<DiskContent>,
}
@@ -158,7 +157,7 @@ pub struct Memory {
    pub num_computes: u32,
    pub last_removed_segno: XLogSegNo,
    pub epoch_start_lsn: Lsn,
    pub mem_state: SafekeeperMemState,
    pub mem_state: TimelineMemState,

    // PhysicalStorage state.
    pub write_lsn: Lsn,
@@ -160,7 +160,7 @@ async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body
        commit_lsn: inmem.commit_lsn,
        backup_lsn: inmem.backup_lsn,
        peer_horizon_lsn: inmem.peer_horizon_lsn,
        remote_consistent_lsn: tli.get_walsenders().get_remote_consistent_lsn(),
        remote_consistent_lsn: inmem.remote_consistent_lsn,
        peers: tli.get_peers(conf).await,
        walsenders: tli.get_walsenders().get_all(),
        walreceivers: tli.get_walreceivers().get_all(),
@@ -21,7 +21,8 @@ use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
use crate::safekeeper::{
    AppendRequest, AppendRequestHeader, ProposerAcceptorMessage, ProposerElected,
};
use crate::safekeeper::{SafeKeeperState, Term, TermHistory, TermLsn};
use crate::safekeeper::{Term, TermHistory, TermLsn};
use crate::state::TimelinePersistentState;
use crate::timeline::Timeline;
use crate::GlobalTimelines;
use postgres_backend::PostgresBackend;
@@ -56,7 +57,7 @@ pub struct AppendLogicalMessage {
#[derive(Debug, Serialize)]
struct AppendResult {
    // safekeeper state after append
    state: SafeKeeperState,
    state: TimelinePersistentState,
    // info about new record in the WAL
    inserted_wal: InsertedWAL,
}
@@ -28,6 +28,7 @@ pub mod recovery;
pub mod remove_wal;
pub mod safekeeper;
pub mod send_wal;
pub mod state;
pub mod timeline;
pub mod wal_backup;
pub mod wal_service;
@@ -21,7 +21,7 @@ use utils::pageserver_feedback::PageserverFeedback;
use utils::{id::TenantTimelineId, lsn::Lsn};

use crate::{
    safekeeper::{SafeKeeperState, SafekeeperMemState},
    state::{TimelineMemState, TimelinePersistentState},
    GlobalTimelines,
};

@@ -308,11 +308,10 @@ pub struct FullTimelineInfo {
    pub last_removed_segno: XLogSegNo,

    pub epoch_start_lsn: Lsn,
    pub mem_state: SafekeeperMemState,
    pub persisted_state: SafeKeeperState,
    pub mem_state: TimelineMemState,
    pub persisted_state: TimelinePersistentState,

    pub flush_lsn: Lsn,
    pub remote_consistent_lsn: Lsn,

    pub wal_storage: WalStorageMetrics,
}
@@ -608,7 +607,7 @@ impl Collector for TimelineCollector {
            .set(tli.mem_state.peer_horizon_lsn.into());
        self.remote_consistent_lsn
            .with_label_values(labels)
            .set(tli.remote_consistent_lsn.into());
            .set(tli.mem_state.remote_consistent_lsn.into());
        self.timeline_active
            .with_label_values(labels)
            .set(tli.timeline_is_active as u64);
@@ -18,17 +18,16 @@ use tracing::*;
use crate::control_file;
use crate::send_wal::HotStandbyFeedback;

use crate::state::TimelineState;
use crate::wal_storage;
use pq_proto::SystemId;
use utils::pageserver_feedback::PageserverFeedback;
use utils::{
    bin_ser::LeSer,
    id::{NodeId, TenantId, TenantTimelineId, TimelineId},
    id::{NodeId, TenantId, TimelineId},
    lsn::Lsn,
};

pub const SK_MAGIC: u32 = 0xcafeceefu32;
pub const SK_FORMAT_VERSION: u32 = 7;
const SK_PROTOCOL_VERSION: u32 = 2;
pub const UNKNOWN_SERVER_VERSION: u32 = 0;

@@ -222,7 +221,7 @@ pub struct PersistedPeerInfo {
}

impl PersistedPeerInfo {
    fn new() -> Self {
    pub fn new() -> Self {
        Self {
            backup_lsn: Lsn::INVALID,
            term: INVALID_TERM,
@@ -232,111 +231,10 @@ impl PersistedPeerInfo {
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PersistedPeers(pub Vec<(NodeId, PersistedPeerInfo)>);

/// Persistent information stored on safekeeper node
/// On disk data is prefixed by magic and format version and followed by checksum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SafeKeeperState {
    #[serde(with = "hex")]
    pub tenant_id: TenantId,
    #[serde(with = "hex")]
    pub timeline_id: TimelineId,
    /// persistent acceptor state
    pub acceptor_state: AcceptorState,
    /// information about server
    pub server: ServerInfo,
    /// Unique id of the last *elected* proposer we dealt with. Not needed
    /// for correctness, exists for monitoring purposes.
    #[serde(with = "hex")]
    pub proposer_uuid: PgUuid,
    /// Since which LSN this timeline generally starts. Safekeeper might have
    /// joined later.
    pub timeline_start_lsn: Lsn,
    /// Since which LSN safekeeper has (had) WAL for this timeline.
    /// All WAL segments next to one containing local_start_lsn are
    /// filled with data from the beginning.
    pub local_start_lsn: Lsn,
    /// Part of WAL acknowledged by quorum *and available locally*. Always points
    /// to record boundary.
    pub commit_lsn: Lsn,
    /// LSN that points to the end of the last backed up segment. Useful to
    /// persist to avoid finding out offloading progress on boot.
    pub backup_lsn: Lsn,
    /// Minimal LSN which may be needed for recovery of some safekeeper (end_lsn
    /// of last record streamed to everyone). Persisting it helps skipping
    /// recovery in walproposer, generally we compute it from peers. In
    /// walproposer proto called 'truncate_lsn'. Updates are currently drived
    /// only by walproposer.
    pub peer_horizon_lsn: Lsn,
    /// LSN of the oldest known checkpoint made by pageserver and successfully
    /// pushed to s3. We don't remove WAL beyond it. Persisted only for
    /// informational purposes, we receive it from pageserver (or broker).
    pub remote_consistent_lsn: Lsn,
    // Peers and their state as we remember it. Knowing peers themselves is
    // fundamental; but state is saved here only for informational purposes and
    // obviously can be stale. (Currently not saved at all, but let's provision
    // place to have less file version upgrades).
    pub peers: PersistedPeers,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
// In memory safekeeper state. Fields mirror ones in `SafeKeeperState`; values
// are not flushed yet.
pub struct SafekeeperMemState {
    pub commit_lsn: Lsn,
    pub backup_lsn: Lsn,
    pub peer_horizon_lsn: Lsn,
    #[serde(with = "hex")]
    pub proposer_uuid: PgUuid,
}

impl SafeKeeperState {
    pub fn new(
        ttid: &TenantTimelineId,
        server_info: ServerInfo,
        peers: Vec<NodeId>,
        commit_lsn: Lsn,
        local_start_lsn: Lsn,
    ) -> SafeKeeperState {
        SafeKeeperState {
            tenant_id: ttid.tenant_id,
            timeline_id: ttid.timeline_id,
            acceptor_state: AcceptorState {
                term: 0,
                term_history: TermHistory::empty(),
            },
            server: server_info,
            proposer_uuid: [0; 16],
            timeline_start_lsn: Lsn(0),
            local_start_lsn,
            commit_lsn,
            backup_lsn: local_start_lsn,
            peer_horizon_lsn: local_start_lsn,
            remote_consistent_lsn: Lsn(0),
            peers: PersistedPeers(
                peers
                    .iter()
                    .map(|p| (*p, PersistedPeerInfo::new()))
                    .collect(),
            ),
        }
    }

    #[cfg(test)]
    pub fn empty() -> Self {
        SafeKeeperState::new(
            &TenantTimelineId::empty(),
            ServerInfo {
                pg_version: UNKNOWN_SERVER_VERSION, /* Postgres server version */
                system_id: 0, /* Postgres system identifier */
                wal_seg_size: 0,
            },
            vec![],
            Lsn::INVALID,
            Lsn::INVALID,
        )
// make clippy happy
impl Default for PersistedPeerInfo {
    fn default() -> Self {
        Self::new()
    }
}

@@ -583,9 +481,7 @@ pub struct SafeKeeper<CTRL: control_file::Storage, WAL: wal_storage::Storage> {
    /// determines epoch switch point.
    pub epoch_start_lsn: Lsn,

    pub inmem: SafekeeperMemState, // in memory part
    pub state: CTRL,               // persistent state storage

    pub state: TimelineState<CTRL>, // persistent state storage
    pub wal_store: WAL,

    node_id: NodeId, // safekeeper's node id
@@ -612,13 +508,7 @@ where

        Ok(SafeKeeper {
            epoch_start_lsn: Lsn(0),
            inmem: SafekeeperMemState {
                commit_lsn: state.commit_lsn,
                backup_lsn: state.backup_lsn,
                peer_horizon_lsn: state.peer_horizon_lsn,
                proposer_uuid: state.proposer_uuid,
            },
            state,
            state: TimelineState::new(state),
            wal_store,
            node_id,
        })
@@ -726,12 +616,12 @@ where
            );
        }

        let mut state = self.state.clone();
        let mut state = self.state.start_change();
        state.server.system_id = msg.system_id;
        if msg.pg_version != UNKNOWN_SERVER_VERSION {
            state.server.pg_version = msg.pg_version;
        }
        self.state.persist(&state).await?;
        self.state.finish_change(&state).await?;
    }

    info!(
@@ -766,15 +656,15 @@ where
            term: self.state.acceptor_state.term,
            vote_given: false as u64,
            flush_lsn: self.flush_lsn(),
            truncate_lsn: self.inmem.peer_horizon_lsn,
            truncate_lsn: self.state.inmem.peer_horizon_lsn,
            term_history: self.get_term_history(),
            timeline_start_lsn: self.state.timeline_start_lsn,
        };
        if self.state.acceptor_state.term < msg.term {
            let mut state = self.state.clone();
            let mut state = self.state.start_change();
            state.acceptor_state.term = msg.term;
            // persist vote before sending it out
            self.state.persist(&state).await?;
            self.state.finish_change(&state).await?;

            resp.term = self.state.acceptor_state.term;
            resp.vote_given = true as u64;
@@ -803,9 +693,9 @@ where
    ) -> Result<Option<AcceptorProposerMessage>> {
        info!("received ProposerElected {:?}", msg);
        if self.state.acceptor_state.term < msg.term {
            let mut state = self.state.clone();
            let mut state = self.state.start_change();
            state.acceptor_state.term = msg.term;
            self.state.persist(&state).await?;
            self.state.finish_change(&state).await?;
        }

        // If our term is higher, ignore the message (next feedback will inform the compute)
@@ -825,10 +715,10 @@ where
        }
        // Otherwise we must never attempt to truncate committed data.
        assert!(
            msg.start_streaming_at >= self.inmem.commit_lsn,
            msg.start_streaming_at >= self.state.inmem.commit_lsn,
            "attempt to truncate committed data: start_streaming_at={}, commit_lsn={}",
            msg.start_streaming_at,
            self.inmem.commit_lsn
            self.state.inmem.commit_lsn
        );

        // TODO: cross check divergence point, check if msg.start_streaming_at corresponds to
@@ -839,7 +729,7 @@ where

        // and now adopt term history from proposer
        {
            let mut state = self.state.clone();
            let mut state = self.state.start_change();

            // Here we learn initial LSN for the first time, set fields
            // interested in that.
@@ -863,13 +753,13 @@ where
            // NB: on new clusters, this happens at the same time as
            // timeline_start_lsn initialization, it is taken outside to provide
            // upgrade.
            self.inmem.commit_lsn = max(self.inmem.commit_lsn, state.timeline_start_lsn);
            state.commit_lsn = max(state.commit_lsn, state.timeline_start_lsn);

            // Initializing backup_lsn is useful to avoid making backup think it should upload 0 segment.
            self.inmem.backup_lsn = max(self.inmem.backup_lsn, state.timeline_start_lsn);
            state.backup_lsn = max(state.backup_lsn, state.timeline_start_lsn);

            state.acceptor_state.term_history = msg.term_history.clone();
            self.persist_control_file(state).await?;
            self.state.finish_change(&state).await?;
        }

        info!("start receiving WAL since {:?}", msg.start_streaming_at);
@@ -892,63 +782,41 @@ where
    async fn update_commit_lsn(&mut self, mut candidate: Lsn) -> Result<()> {
        // Both peers and walproposer communicate this value, we might already
        // have a fresher (higher) version.
        candidate = max(candidate, self.inmem.commit_lsn);
        candidate = max(candidate, self.state.inmem.commit_lsn);
        let commit_lsn = min(candidate, self.flush_lsn());
        assert!(
            commit_lsn >= self.inmem.commit_lsn,
            commit_lsn >= self.state.inmem.commit_lsn,
            "commit_lsn monotonicity violated: old={} new={}",
            self.inmem.commit_lsn,
            self.state.inmem.commit_lsn,
            commit_lsn
        );

        self.inmem.commit_lsn = commit_lsn;
        self.state.inmem.commit_lsn = commit_lsn;

        // If new commit_lsn reached epoch switch, force sync of control
        // file: walproposer in sync mode is very interested when this
        // happens. Note: this is for sync-safekeepers mode only, as
        // otherwise commit_lsn might jump over epoch_start_lsn.
        if commit_lsn >= self.epoch_start_lsn && self.state.commit_lsn < self.epoch_start_lsn {
            self.persist_control_file(self.state.clone()).await?;
            self.state.flush().await?;
        }

        Ok(())
    }

    /// Persist in-memory state of control file to disk.
    //
    // TODO: passing inmem_remote_consistent_lsn everywhere is ugly, better
    // separate state completely and give Arc to all those who need it.
    pub async fn persist_inmem(&mut self, inmem_remote_consistent_lsn: Lsn) -> Result<()> {
        let mut state = self.state.clone();
        state.remote_consistent_lsn = inmem_remote_consistent_lsn;
        self.persist_control_file(state).await
    }

    /// Persist in-memory state to the disk, taking other data from state.
    async fn persist_control_file(&mut self, mut state: SafeKeeperState) -> Result<()> {
        state.commit_lsn = self.inmem.commit_lsn;
        state.backup_lsn = self.inmem.backup_lsn;
        state.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
        state.proposer_uuid = self.inmem.proposer_uuid;
        self.state.persist(&state).await
    }

    /// Persist control file if there is something to save and enough time
    /// passed after the last save.
    pub async fn maybe_persist_inmem_control_file(
        &mut self,
        inmem_remote_consistent_lsn: Lsn,
    ) -> Result<()> {
    pub async fn maybe_persist_inmem_control_file(&mut self) -> Result<()> {
        const CF_SAVE_INTERVAL: Duration = Duration::from_secs(300);
        if self.state.last_persist_at().elapsed() < CF_SAVE_INTERVAL {
        if self.state.pers.last_persist_at().elapsed() < CF_SAVE_INTERVAL {
            return Ok(());
        }
        let need_persist = self.inmem.commit_lsn > self.state.commit_lsn
            || self.inmem.backup_lsn > self.state.backup_lsn
            || self.inmem.peer_horizon_lsn > self.state.peer_horizon_lsn
            || inmem_remote_consistent_lsn > self.state.remote_consistent_lsn;
        let need_persist = self.state.inmem.commit_lsn > self.state.commit_lsn
            || self.state.inmem.backup_lsn > self.state.backup_lsn
            || self.state.inmem.peer_horizon_lsn > self.state.peer_horizon_lsn
            || self.state.inmem.remote_consistent_lsn > self.state.remote_consistent_lsn;
        if need_persist {
            self.persist_inmem(inmem_remote_consistent_lsn).await?;
            self.state.flush().await?;
            trace!("saved control file: {CF_SAVE_INTERVAL:?} passed");
        }
        Ok(())
@@ -974,7 +842,7 @@ where
        // Now we know that we are in the same term as the proposer,
        // processing the message.

        self.inmem.proposer_uuid = msg.h.proposer_uuid;
        self.state.inmem.proposer_uuid = msg.h.proposer_uuid;

        // do the job
        if !msg.wal_data.is_empty() {
@@ -998,15 +866,16 @@ where
        // - if we make safekeepers always send persistent value,
        //   any compute restart would pull it down.
        // Thus, take max before adopting.
        self.inmem.peer_horizon_lsn = max(self.inmem.peer_horizon_lsn, msg.h.truncate_lsn);
        self.state.inmem.peer_horizon_lsn =
            max(self.state.inmem.peer_horizon_lsn, msg.h.truncate_lsn);

        // Update truncate and commit LSN in control file.
        // To avoid negative impact on performance of extra fsync, do it only
        // when truncate_lsn delta exceeds WAL segment size.
        if self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64)
            < self.inmem.peer_horizon_lsn
        // when commit_lsn delta exceeds WAL segment size.
        if self.state.commit_lsn + (self.state.server.wal_seg_size as u64)
            < self.state.inmem.commit_lsn
        {
            self.persist_control_file(self.state.clone()).await?;
            self.state.flush().await?;
        }

        trace!(
@@ -1048,27 +917,27 @@ where
            }
        }

        let new_backup_lsn = max(Lsn(sk_info.backup_lsn), self.inmem.backup_lsn);
        sync_control_file |=
            self.state.backup_lsn + (self.state.server.wal_seg_size as u64) < new_backup_lsn;
        self.inmem.backup_lsn = new_backup_lsn;
        self.state.inmem.backup_lsn = max(Lsn(sk_info.backup_lsn), self.state.inmem.backup_lsn);
        sync_control_file |= self.state.backup_lsn + (self.state.server.wal_seg_size as u64)
            < self.state.inmem.backup_lsn;

        // value in sk_info should be maximized over our local in memory value.
        let new_remote_consistent_lsn = Lsn(sk_info.remote_consistent_lsn);
        assert!(self.state.remote_consistent_lsn <= new_remote_consistent_lsn);
        self.state.inmem.remote_consistent_lsn = max(
            Lsn(sk_info.remote_consistent_lsn),
            self.state.inmem.remote_consistent_lsn,
        );
        sync_control_file |= self.state.remote_consistent_lsn
            + (self.state.server.wal_seg_size as u64)
            < new_remote_consistent_lsn;
            < self.state.inmem.remote_consistent_lsn;

        let new_peer_horizon_lsn = max(Lsn(sk_info.peer_horizon_lsn), self.inmem.peer_horizon_lsn);
        self.state.inmem.peer_horizon_lsn = max(
            Lsn(sk_info.peer_horizon_lsn),
            self.state.inmem.peer_horizon_lsn,
        );
        sync_control_file |= self.state.peer_horizon_lsn + (self.state.server.wal_seg_size as u64)
            < new_peer_horizon_lsn;
        self.inmem.peer_horizon_lsn = new_peer_horizon_lsn;
            < self.state.inmem.peer_horizon_lsn;

        if sync_control_file {
            let mut state = self.state.clone();
            state.remote_consistent_lsn = new_remote_consistent_lsn;
            self.persist_control_file(state).await?;
            self.state.flush().await?;
        }
        Ok(())
    }
@@ -1096,17 +965,20 @@ mod tests {
    use postgres_ffi::WAL_SEGMENT_SIZE;

    use super::*;
    use crate::wal_storage::Storage;
    use crate::{
        state::{PersistedPeers, TimelinePersistentState},
        wal_storage::Storage,
    };
    use std::{ops::Deref, str::FromStr, time::Instant};

    // fake storage for tests
    struct InMemoryState {
        persisted_state: SafeKeeperState,
        persisted_state: TimelinePersistentState,
    }

    #[async_trait::async_trait]
    impl control_file::Storage for InMemoryState {
        async fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
        async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
            self.persisted_state = s.clone();
            Ok(())
        }
@@ -1117,15 +989,15 @@ mod tests {
    }

    impl Deref for InMemoryState {
        type Target = SafeKeeperState;
        type Target = TimelinePersistentState;

        fn deref(&self) -> &Self::Target {
            &self.persisted_state
        }
    }

    fn test_sk_state() -> SafeKeeperState {
        let mut state = SafeKeeperState::empty();
    fn test_sk_state() -> TimelinePersistentState {
        let mut state = TimelinePersistentState::empty();
        state.server.wal_seg_size = WAL_SEGMENT_SIZE as u32;
        state.tenant_id = TenantId::from([1u8; 16]);
        state.timeline_id = TimelineId::from([1u8; 16]);
@@ -1182,7 +1054,7 @@ mod tests {
        }

        // reboot...
        let state = sk.state.persisted_state.clone();
        let state = sk.state.deref().clone();
        let storage = InMemoryState {
            persisted_state: state,
        };
@@ -1321,7 +1193,7 @@ mod tests {
        use utils::Hex;
        let tenant_id = TenantId::from_str("cf0480929707ee75372337efaa5ecf96").unwrap();
        let timeline_id = TimelineId::from_str("112ded66422aa5e953e5440fa5427ac4").unwrap();
        let state = SafeKeeperState {
        let state = TimelinePersistentState {
            tenant_id,
            timeline_id,
            acceptor_state: AcceptorState {
@@ -1405,7 +1277,7 @@ mod tests {

        assert_eq!(Hex(&ser), Hex(&expected));

        let deser = SafeKeeperState::des(&ser).unwrap();
        let deser = TimelinePersistentState::des(&ser).unwrap();

        assert_eq!(deser, state);
    }
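The update_commit_lsn hunk above enforces two invariants: commit_lsn never regresses, and it never passes the locally flushed WAL. A standalone sketch of that clamp with LSNs as plain u64s:

use std::cmp::{max, min};

// Sketch of the commit_lsn update rule from update_commit_lsn above.
fn next_commit_lsn(candidate: u64, current_commit: u64, flush_lsn: u64) -> u64 {
    // Peers/walproposer may send stale values; never go backwards.
    let candidate = max(candidate, current_commit);
    // Never advertise commit beyond WAL we actually hold locally.
    let new_commit = min(candidate, flush_lsn);
    assert!(new_commit >= current_commit, "commit_lsn monotonicity violated");
    new_commit
}

fn main() {
    assert_eq!(next_commit_lsn(100, 50, 80), 80); // clamped to flushed WAL
    assert_eq!(next_commit_lsn(30, 50, 80), 50); // stale candidate ignored
}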
@@ -19,7 +19,6 @@ use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use utils::failpoint_support;
use utils::id::TenantTimelineId;
use utils::lsn::AtomicLsn;
use utils::pageserver_feedback::PageserverFeedback;

use std::cmp::{max, min};
@@ -90,16 +89,12 @@ pub struct StandbyFeedback {

/// WalSenders registry. Timeline holds it (wrapped in Arc).
pub struct WalSenders {
    /// Lsn maximized over all walsenders *and* peer data, so might be higher
    /// than what we receive from replicas.
    remote_consistent_lsn: AtomicLsn,
    mutex: Mutex<WalSendersShared>,
}

impl WalSenders {
    pub fn new(remote_consistent_lsn: Lsn) -> Arc<WalSenders> {
    pub fn new() -> Arc<WalSenders> {
        Arc::new(WalSenders {
            remote_consistent_lsn: AtomicLsn::from(remote_consistent_lsn),
            mutex: Mutex::new(WalSendersShared::new()),
        })
    }
@@ -157,7 +152,6 @@ impl WalSenders {
        let mut shared = self.mutex.lock();
        shared.get_slot_mut(id).feedback = ReplicationFeedback::Pageserver(*feedback);
        shared.update_ps_feedback();
        self.update_remote_consistent_lsn(shared.agg_ps_feedback.remote_consistent_lsn);
    }

    /// Record standby reply.
@@ -202,18 +196,6 @@ impl WalSenders {
        }
    }

    /// Get remote_consistent_lsn maximized across all walsenders and peers.
    pub fn get_remote_consistent_lsn(self: &Arc<WalSenders>) -> Lsn {
        self.remote_consistent_lsn.load()
    }

    /// Update maximized remote_consistent_lsn, return new (potentially) value.
    pub fn update_remote_consistent_lsn(self: &Arc<WalSenders>, candidate: Lsn) -> Lsn {
        self.remote_consistent_lsn
            .fetch_max(candidate)
            .max(candidate)
    }

    /// Unregister walsender.
    fn unregister(self: &Arc<WalSenders>, id: WalSenderId) {
        let mut shared = self.mutex.lock();
@@ -444,7 +426,11 @@ impl SafekeeperPostgresHandler {
            wal_reader,
            send_buf: [0; MAX_SEND_SIZE],
        };
        let mut reply_reader = ReplyReader { reader, ws_guard };
        let mut reply_reader = ReplyReader {
            reader,
            ws_guard,
            tli,
        };

        let res = tokio::select! {
            // todo: add read|write .context to these errors
@@ -638,17 +624,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
struct ReplyReader<IO> {
    reader: PostgresBackendReader<IO>,
    ws_guard: Arc<WalSenderGuard>,
    tli: Arc<Timeline>,
}

impl<IO: AsyncRead + AsyncWrite + Unpin> ReplyReader<IO> {
    async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
        loop {
            let msg = self.reader.read_copy_message().await?;
            self.handle_feedback(&msg)?
            self.handle_feedback(&msg).await?
        }
    }

    fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> {
    async fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> {
        match msg.first().cloned() {
            Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
                // Note: deserializing is on m[1..] because we skip the tag byte.
@@ -675,6 +662,9 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> ReplyReader<IO> {
                self.ws_guard
                    .walsenders
                    .record_ps_feedback(self.ws_guard.id, &ps_feedback);
                self.tli
                    .update_remote_consistent_lsn(ps_feedback.remote_consistent_lsn)
                    .await;
                // In principle the new remote_consistent_lsn could allow deactivating
                // the timeline, but we check that regularly through broker updates;
                // no need to do it here.
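The removed update_remote_consistent_lsn relied on a lock-free monotonic max; after this diff the value lives in the timeline's in-memory state instead. A sketch of the removed pattern, substituting a plain AtomicU64 for utils' AtomicLsn:

use std::sync::atomic::{AtomicU64, Ordering};

// Sketch of the monotonic update the removed helper performed: advance the
// shared LSN to `candidate` if it is larger, and return the new maximum.
fn update_remote_consistent_lsn(lsn: &AtomicU64, candidate: u64) -> u64 {
    // fetch_max returns the *previous* value, hence the trailing .max().
    lsn.fetch_max(candidate, Ordering::SeqCst).max(candidate)
}

fn main() {
    let lsn = AtomicU64::new(10);
    assert_eq!(update_remote_consistent_lsn(&lsn, 5), 10); // stale, unchanged
    assert_eq!(update_remote_consistent_lsn(&lsn, 20), 20); // advanced
}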
safekeeper/src/state.rs (new file, 197 lines)
@@ -0,0 +1,197 @@
//! Defines per timeline data stored persistently (TimelinePersistentState)
//! and its wrapper with an in-memory layer (TimelineState).

use std::ops::Deref;

use anyhow::Result;
use serde::{Deserialize, Serialize};
use utils::{
    id::{NodeId, TenantId, TenantTimelineId, TimelineId},
    lsn::Lsn,
};

use crate::{
    control_file,
    safekeeper::{AcceptorState, PersistedPeerInfo, PgUuid, ServerInfo, TermHistory},
};

/// Persistent information stored on safekeeper node about timeline.
/// On disk data is prefixed by magic and format version and followed by checksum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TimelinePersistentState {
    #[serde(with = "hex")]
    pub tenant_id: TenantId,
    #[serde(with = "hex")]
    pub timeline_id: TimelineId,
    /// persistent acceptor state
    pub acceptor_state: AcceptorState,
    /// information about server
    pub server: ServerInfo,
    /// Unique id of the last *elected* proposer we dealt with. Not needed
    /// for correctness, exists for monitoring purposes.
    #[serde(with = "hex")]
    pub proposer_uuid: PgUuid,
    /// Since which LSN this timeline generally starts. Safekeeper might have
    /// joined later.
    pub timeline_start_lsn: Lsn,
    /// Since which LSN safekeeper has (had) WAL for this timeline.
    /// All WAL segments next to one containing local_start_lsn are
    /// filled with data from the beginning.
    pub local_start_lsn: Lsn,
    /// Part of WAL acknowledged by quorum *and available locally*. Always points
    /// to record boundary.
    pub commit_lsn: Lsn,
    /// LSN that points to the end of the last backed up segment. Useful to
    /// persist to avoid finding out offloading progress on boot.
    pub backup_lsn: Lsn,
    /// Minimal LSN which may be needed for recovery of some safekeeper (end_lsn
    /// of last record streamed to everyone). Persisting it helps skip
    /// recovery in walproposer, generally we compute it from peers. In
    /// walproposer proto called 'truncate_lsn'. Updates are currently driven
    /// only by walproposer.
    pub peer_horizon_lsn: Lsn,
    /// LSN of the oldest known checkpoint made by pageserver and successfully
    /// pushed to s3. We don't remove WAL beyond it. Persisted only for
    /// informational purposes, we receive it from pageserver (or broker).
    pub remote_consistent_lsn: Lsn,
    // Peers and their state as we remember it. Knowing peers themselves is
    // fundamental; but state is saved here only for informational purposes and
    // obviously can be stale. (Currently not saved at all, but let's provision
    // place to have less file version upgrades).
    pub peers: PersistedPeers,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PersistedPeers(pub Vec<(NodeId, PersistedPeerInfo)>);

impl TimelinePersistentState {
    pub fn new(
        ttid: &TenantTimelineId,
        server_info: ServerInfo,
        peers: Vec<NodeId>,
        commit_lsn: Lsn,
        local_start_lsn: Lsn,
    ) -> TimelinePersistentState {
        TimelinePersistentState {
            tenant_id: ttid.tenant_id,
            timeline_id: ttid.timeline_id,
            acceptor_state: AcceptorState {
                term: 0,
                term_history: TermHistory::empty(),
            },
            server: server_info,
            proposer_uuid: [0; 16],
            timeline_start_lsn: Lsn(0),
            local_start_lsn,
            commit_lsn,
            backup_lsn: local_start_lsn,
            peer_horizon_lsn: local_start_lsn,
            remote_consistent_lsn: Lsn(0),
            peers: PersistedPeers(
                peers
                    .iter()
                    .map(|p| (*p, PersistedPeerInfo::new()))
                    .collect(),
            ),
        }
    }

    #[cfg(test)]
    pub fn empty() -> Self {
        use crate::safekeeper::UNKNOWN_SERVER_VERSION;

        TimelinePersistentState::new(
            &TenantTimelineId::empty(),
            ServerInfo {
                pg_version: UNKNOWN_SERVER_VERSION, /* Postgres server version */
                system_id: 0, /* Postgres system identifier */
                wal_seg_size: 0,
            },
            vec![],
            Lsn::INVALID,
            Lsn::INVALID,
        )
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
// In memory safekeeper state. Fields mirror ones in `TimelinePersistentState`;
// values are not flushed yet.
pub struct TimelineMemState {
    pub commit_lsn: Lsn,
    pub backup_lsn: Lsn,
    pub peer_horizon_lsn: Lsn,
    pub remote_consistent_lsn: Lsn,
    #[serde(with = "hex")]
    pub proposer_uuid: PgUuid,
}

/// Safekeeper persistent state plus an in-memory layer, to avoid frequent
/// fsyncs when we update fields like commit_lsn which don't need immediate
/// persistence. Provides a transaction-like API to atomically update the state.
///
/// Implements Deref into the *persistent* part.
pub struct TimelineState<CTRL: control_file::Storage> {
    pub inmem: TimelineMemState,
    pub pers: CTRL, // persistent
}

impl<CTRL> TimelineState<CTRL>
where
    CTRL: control_file::Storage,
{
    pub fn new(state: CTRL) -> Self {
        TimelineState {
            inmem: TimelineMemState {
                commit_lsn: state.commit_lsn,
                backup_lsn: state.backup_lsn,
                peer_horizon_lsn: state.peer_horizon_lsn,
                remote_consistent_lsn: state.remote_consistent_lsn,
                proposer_uuid: state.proposer_uuid,
            },
            pers: state,
        }
    }

    /// Start atomic change. Returns TimelinePersistentState with in-memory
    /// values applied; the protocol is to 1) change the returned struct as
    /// desired 2) atomically persist it with finish_change.
    pub fn start_change(&self) -> TimelinePersistentState {
        let mut s = self.pers.clone();
        s.commit_lsn = self.inmem.commit_lsn;
        s.backup_lsn = self.inmem.backup_lsn;
        s.peer_horizon_lsn = self.inmem.peer_horizon_lsn;
        s.remote_consistent_lsn = self.inmem.remote_consistent_lsn;
        s.proposer_uuid = self.inmem.proposer_uuid;
        s
    }

    /// Persist given state. cf. start_change.
    pub async fn finish_change(&mut self, s: &TimelinePersistentState) -> Result<()> {
        self.pers.persist(s).await?;
        // keep in memory values up to date
        self.inmem.commit_lsn = s.commit_lsn;
        self.inmem.backup_lsn = s.backup_lsn;
        self.inmem.peer_horizon_lsn = s.peer_horizon_lsn;
        self.inmem.remote_consistent_lsn = s.remote_consistent_lsn;
        self.inmem.proposer_uuid = s.proposer_uuid;
        Ok(())
    }

    /// Flush in memory values.
    pub async fn flush(&mut self) -> Result<()> {
        let s = self.start_change();
        self.finish_change(&s).await
    }
}

impl<CTRL> Deref for TimelineState<CTRL>
where
    CTRL: control_file::Storage,
{
    type Target = TimelinePersistentState;

    fn deref(&self) -> &Self::Target {
        &self.pers
    }
}
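The intended calling convention for TimelineState is: take a snapshot with start_change (persistent fields overlaid with in-memory values), mutate the snapshot, then finish_change to persist once and re-sync the in-memory layer. A self-contained miniature of that protocol; the one-field state and Vec-backed "disk" are stand-ins, not the real types:

// Miniature of the start_change/finish_change protocol from state.rs above.
#[derive(Clone, PartialEq, Debug)]
struct PersistentState {
    term: u64,
}

struct State {
    inmem_term: u64,            // like TimelineMemState
    pers: PersistentState,      // like the CTRL storage
    disk: Vec<PersistentState>, // pretend fsync log
}

impl State {
    fn start_change(&self) -> PersistentState {
        let mut s = self.pers.clone();
        s.term = s.term.max(self.inmem_term); // overlay in-memory values
        s
    }
    fn finish_change(&mut self, s: &PersistentState) {
        self.disk.push(s.clone()); // one durable write per change
        self.pers = s.clone();
        self.inmem_term = s.term; // re-sync the in-memory layer
    }
}

fn main() {
    let mut st = State { inmem_term: 3, pers: PersistentState { term: 1 }, disk: vec![] };
    let mut s = st.start_change();
    s.term = 5;
    st.finish_change(&s);
    assert_eq!(st.pers.term, 5);
    assert_eq!(st.disk.len(), 1); // exactly one persist for the whole change
}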
@@ -28,10 +28,11 @@ use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
use crate::receive_wal::WalReceivers;
use crate::recovery::{recovery_main, Donor, RecoveryNeededInfo};
use crate::safekeeper::{
    AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState,
    SafekeeperMemState, ServerInfo, Term, TermLsn, INVALID_TERM,
    AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, ServerInfo, Term, TermLsn,
    INVALID_TERM,
};
use crate::send_wal::WalSenders;
use crate::state::{TimelineMemState, TimelinePersistentState};
use crate::{control_file, safekeeper::UNKNOWN_SERVER_VERSION};

use crate::metrics::FullTimelineInfo;
@@ -121,7 +122,7 @@ impl SharedState {
    fn create_new(
        conf: &SafeKeeperConf,
        ttid: &TenantTimelineId,
        state: SafeKeeperState,
        state: TimelinePersistentState,
    ) -> Result<Self> {
        if state.server.wal_seg_size == 0 {
            bail!(TimelineError::UninitializedWalSegSize(*ttid));
@@ -175,30 +176,28 @@ impl SharedState {
        })
    }

    fn is_active(&self, num_computes: usize, remote_consistent_lsn: Lsn) -> bool {
    fn is_active(&self, num_computes: usize) -> bool {
        self.is_wal_backup_required(num_computes)
            // FIXME: add tracking of relevant pageservers and check them here individually,
            // otherwise migration won't work (we suspend too early).
            || remote_consistent_lsn < self.sk.inmem.commit_lsn
            || self.sk.state.inmem.remote_consistent_lsn < self.sk.state.inmem.commit_lsn
    }

    /// Mark timeline active/inactive and return whether s3 offloading requires
    /// start/stop action. If timeline is deactivated, control file is persisted
    /// as maintenance task does that only for active timelines.
    async fn update_status(
        &mut self,
        num_computes: usize,
        remote_consistent_lsn: Lsn,
        ttid: TenantTimelineId,
    ) -> bool {
        let is_active = self.is_active(num_computes, remote_consistent_lsn);
    async fn update_status(&mut self, num_computes: usize, ttid: TenantTimelineId) -> bool {
        let is_active = self.is_active(num_computes);
        if self.active != is_active {
            info!(
                "timeline {} active={} now, remote_consistent_lsn={}, commit_lsn={}",
                ttid, is_active, remote_consistent_lsn, self.sk.inmem.commit_lsn
                ttid,
                is_active,
                self.sk.state.inmem.remote_consistent_lsn,
                self.sk.state.inmem.commit_lsn
            );
            if !is_active {
                if let Err(e) = self.sk.persist_inmem(remote_consistent_lsn).await {
                if let Err(e) = self.sk.state.flush().await {
                    warn!("control file save in update_status failed: {:?}", e);
                }
            }
@@ -212,8 +211,8 @@ impl SharedState {
        let seg_size = self.get_wal_seg_size();
        num_computes > 0 ||
        // Currently only the whole segment is offloaded, so compare segment numbers.
            (self.sk.inmem.commit_lsn.segment_number(seg_size) >
             self.sk.inmem.backup_lsn.segment_number(seg_size))
            (self.sk.state.inmem.commit_lsn.segment_number(seg_size) >
             self.sk.state.inmem.backup_lsn.segment_number(seg_size))
    }

    /// Is the current state of s3 offloading not what it ought to be?
@@ -227,7 +226,7 @@ impl SharedState {
            };
            trace!(
                "timeline {} s3 offloading action {} pending: num_computes={}, commit_lsn={}, backup_lsn={}",
                self.sk.state.timeline_id, action_pending, num_computes, self.sk.inmem.commit_lsn, self.sk.inmem.backup_lsn
                self.sk.state.timeline_id, action_pending, num_computes, self.sk.state.inmem.commit_lsn, self.sk.state.inmem.backup_lsn
            );
        }
        res
@@ -248,7 +247,6 @@ impl SharedState {
        &self,
        ttid: &TenantTimelineId,
        conf: &SafeKeeperConf,
        remote_consistent_lsn: Lsn,
    ) -> SafekeeperTimelineInfo {
        SafekeeperTimelineInfo {
            safekeeper_id: conf.my_id.0,
@@ -260,15 +258,15 @@ impl SharedState {
            last_log_term: self.sk.get_epoch(),
            flush_lsn: self.sk.flush_lsn().0,
            // note: this value is not flushed to control file yet and can be lost
            commit_lsn: self.sk.inmem.commit_lsn.0,
            remote_consistent_lsn: remote_consistent_lsn.0,
            peer_horizon_lsn: self.sk.inmem.peer_horizon_lsn.0,
            commit_lsn: self.sk.state.inmem.commit_lsn.0,
            remote_consistent_lsn: self.sk.state.inmem.remote_consistent_lsn.0,
            peer_horizon_lsn: self.sk.state.inmem.peer_horizon_lsn.0,
            safekeeper_connstr: conf
                .advertise_pg_addr
                .to_owned()
                .unwrap_or(conf.listen_pg_addr.clone()),
            http_connstr: conf.listen_http_addr.to_owned(),
            backup_lsn: self.sk.inmem.backup_lsn.0,
            backup_lsn: self.sk.state.inmem.backup_lsn.0,
            local_start_lsn: self.sk.state.local_start_lsn.0,
            availability_zone: conf.availability_zone.clone(),
        }
@@ -366,7 +364,6 @@ impl Timeline {
        let _enter = info_span!("load_timeline", timeline = %ttid.timeline_id).entered();

        let shared_state = SharedState::restore(conf, &ttid)?;
        let rcl = shared_state.sk.state.remote_consistent_lsn;
        let (commit_lsn_watch_tx, commit_lsn_watch_rx) =
            watch::channel(shared_state.sk.state.commit_lsn);
        let (term_flush_lsn_watch_tx, term_flush_lsn_watch_rx) = watch::channel(TermLsn::from((
@@ -383,7 +380,7 @@ impl Timeline {
            term_flush_lsn_watch_tx,
            term_flush_lsn_watch_rx,
            mutex: Mutex::new(shared_state),
            walsenders: WalSenders::new(rcl),
            walsenders: WalSenders::new(),
            walreceivers: WalReceivers::new(),
            cancellation_rx,
            cancellation_tx,
@@ -404,7 +401,8 @@ impl Timeline {
        let (term_flush_lsn_watch_tx, term_flush_lsn_watch_rx) =
            watch::channel(TermLsn::from((INVALID_TERM, Lsn::INVALID)));
        let (cancellation_tx, cancellation_rx) = watch::channel(false);
        let state = SafeKeeperState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);
        let state =
            TimelinePersistentState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);

        Ok(Timeline {
            ttid,
@@ -414,7 +412,7 @@ impl Timeline {
            term_flush_lsn_watch_tx,
            term_flush_lsn_watch_rx,
            mutex: Mutex::new(SharedState::create_new(conf, &ttid, state)?),
            walsenders: WalSenders::new(Lsn(0)),
            walsenders: WalSenders::new(),
            walreceivers: WalReceivers::new(),
            cancellation_rx,
            cancellation_tx,
@@ -448,7 +446,7 @@ impl Timeline {
        fs::create_dir_all(&self.timeline_dir).await?;

        // Write timeline to disk and start background tasks.
        if let Err(e) = shared_state.sk.persist_inmem(Lsn::INVALID).await {
        if let Err(e) = shared_state.sk.state.flush().await {
            // Bootstrap failed, cancel timeline and remove timeline directory.
            self.cancel(shared_state);

@@ -523,11 +521,7 @@ impl Timeline {

    async fn update_status(&self, shared_state: &mut SharedState) -> bool {
        shared_state
            .update_status(
                self.walreceivers.get_num(),
                self.get_walsenders().get_remote_consistent_lsn(),
                self.ttid,
            )
            .update_status(self.walreceivers.get_num(), self.ttid)
            .await
    }

@@ -558,8 +552,8 @@ impl Timeline {
        }
        let shared_state = self.write_shared_state().await;
        if self.walreceivers.get_num() == 0 {
            return shared_state.sk.inmem.commit_lsn == Lsn(0) || // no data at all yet
            reported_remote_consistent_lsn >= shared_state.sk.inmem.commit_lsn;
            return shared_state.sk.state.inmem.commit_lsn == Lsn(0) || // no data at all yet
            reported_remote_consistent_lsn >= shared_state.sk.state.inmem.commit_lsn;
        }
        false
    }
@@ -623,7 +617,7 @@ impl Timeline {
                resp.pageserver_feedback = ps_feedback;
            }

            commit_lsn = shared_state.sk.inmem.commit_lsn;
            commit_lsn = shared_state.sk.state.inmem.commit_lsn;
            term_flush_lsn =
                TermLsn::from((shared_state.sk.get_term(), shared_state.sk.flush_lsn()));
        }
@@ -647,14 +641,14 @@ impl Timeline {
    }

    /// Returns state of the timeline.
    pub async fn get_state(&self) -> (SafekeeperMemState, SafeKeeperState) {
    pub async fn get_state(&self) -> (TimelineMemState, TimelinePersistentState) {
        let state = self.write_shared_state().await;
        (state.sk.inmem.clone(), state.sk.state.clone())
        (state.sk.state.inmem.clone(), state.sk.state.clone())
    }

    /// Returns latest backup_lsn.
    pub async fn get_wal_backup_lsn(&self) -> Lsn {
        self.write_shared_state().await.sk.inmem.backup_lsn
        self.write_shared_state().await.sk.state.inmem.backup_lsn
    }

    /// Sets backup_lsn to the given value.
@@ -664,7 +658,7 @@ impl Timeline {
        }

        let mut state = self.write_shared_state().await;
        state.sk.inmem.backup_lsn = max(state.sk.inmem.backup_lsn, backup_lsn);
        state.sk.state.inmem.backup_lsn = max(state.sk.state.inmem.backup_lsn, backup_lsn);
        // we should check whether to shut down offloader, but this will be done
        // soon by peer communication anyway.
        Ok(())
@@ -673,21 +667,11 @@ impl Timeline {
    /// Get safekeeper info for broadcasting to broker and other peers.
    pub async fn get_safekeeper_info(&self, conf: &SafeKeeperConf) -> SafekeeperTimelineInfo {
        let shared_state = self.write_shared_state().await;
        shared_state.get_safekeeper_info(
            &self.ttid,
            conf,
            self.walsenders.get_remote_consistent_lsn(),
        )
        shared_state.get_safekeeper_info(&self.ttid, conf)
    }

    /// Update timeline state with peer safekeeper data.
    pub async fn record_safekeeper_info(&self, mut sk_info: SafekeeperTimelineInfo) -> Result<()> {
        // Update local remote_consistent_lsn in memory (in .walsenders) and in
        // sk_info to pass it down to control file.
        sk_info.remote_consistent_lsn = self
            .walsenders
            .update_remote_consistent_lsn(Lsn(sk_info.remote_consistent_lsn))
            .0;
    pub async fn record_safekeeper_info(&self, sk_info: SafekeeperTimelineInfo) -> Result<()> {
        let is_wal_backup_action_pending: bool;
        let commit_lsn: Lsn;
        {
@@ -696,7 +680,7 @@ impl Timeline {
            let peer_info = PeerInfo::from_sk_info(&sk_info, Instant::now());
            shared_state.peers_info.upsert(&peer_info);
            is_wal_backup_action_pending = self.update_status(&mut shared_state).await;
            commit_lsn = shared_state.sk.inmem.commit_lsn;
            commit_lsn = shared_state.sk.state.inmem.commit_lsn;
        }
        self.commit_lsn_watch_tx.send(commit_lsn)?;
        // Wake up wal backup launcher, if it is time to stop the offloading.
@@ -706,6 +690,13 @@ impl Timeline {
        Ok(())
    }

    /// Update in memory remote consistent lsn.
    pub async fn update_remote_consistent_lsn(&self, candidate: Lsn) {
        let mut shared_state = self.write_shared_state().await;
        shared_state.sk.state.inmem.remote_consistent_lsn =
            max(shared_state.sk.state.inmem.remote_consistent_lsn, candidate);
    }

    pub async fn get_peers(&self, conf: &SafeKeeperConf) -> Vec<PeerInfo> {
        let shared_state = self.write_shared_state().await;
        shared_state.get_peers(conf.heartbeat_timeout)
@@ -836,11 +827,10 @@ impl Timeline {
    /// to date so that storage nodes restart doesn't cause many pageserver ->
    /// safekeeper reconnections.
    pub async fn maybe_persist_control_file(&self) -> Result<()> {
        let remote_consistent_lsn = self.walsenders.get_remote_consistent_lsn();
        self.write_shared_state()
            .await
            .sk
            .maybe_persist_inmem_control_file(remote_consistent_lsn)
            .maybe_persist_inmem_control_file()
            .await
    }

@@ -862,10 +852,9 @@ impl Timeline {
                num_computes: self.walreceivers.get_num() as u32,
                last_removed_segno: state.last_removed_segno,
                epoch_start_lsn: state.sk.epoch_start_lsn,
                mem_state: state.sk.inmem.clone(),
                mem_state: state.sk.state.inmem.clone(),
                persisted_state: state.sk.state.clone(),
                flush_lsn: state.sk.wal_store.flush_lsn(),
                remote_consistent_lsn: self.get_walsenders().get_remote_consistent_lsn(),
                wal_storage: state.sk.wal_store.get_metrics(),
            })
        } else {
@@ -889,7 +878,7 @@ impl Timeline {
            num_computes: self.walreceivers.get_num() as u32,
            last_removed_segno: state.last_removed_segno,
            epoch_start_lsn: state.sk.epoch_start_lsn,
            mem_state: state.sk.inmem.clone(),
            mem_state: state.sk.state.inmem.clone(),
            write_lsn,
            write_record_lsn,
            flush_lsn,
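With remote_consistent_lsn now part of TimelineMemState, is_active needs no extra argument; condensed, the predicate reads as below (LSNs as plain u64s, wal_backup_required standing in for is_wal_backup_required):

// Condensed sketch of SharedState::is_active after this change: the timeline
// stays active while WAL backup has pending work or the pageserver's
// remote_consistent_lsn still trails the quorum commit point.
fn is_active(wal_backup_required: bool, remote_consistent_lsn: u64, commit_lsn: u64) -> bool {
    wal_backup_required || remote_consistent_lsn < commit_lsn
}

fn main() {
    assert!(is_active(false, 90, 100)); // pageserver still catching up
    assert!(!is_active(false, 100, 100)); // fully caught up, can suspend
}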
@@ -23,7 +23,7 @@ use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tracing::*;

use crate::metrics::{time_io_closure, WalStorageMetrics, REMOVED_WAL_SEGMENTS};
use crate::safekeeper::SafeKeeperState;
use crate::state::TimelinePersistentState;
use crate::wal_backup::read_object;
use crate::SafeKeeperConf;
use postgres_ffi::waldecoder::WalStreamDecoder;
@@ -125,7 +125,7 @@ impl PhysicalStorage {
        ttid: &TenantTimelineId,
        timeline_dir: Utf8PathBuf,
        conf: &SafeKeeperConf,
        state: &SafeKeeperState,
        state: &TimelinePersistentState,
    ) -> Result<PhysicalStorage> {
        let wal_seg_size = state.server.wal_seg_size as usize;

@@ -525,7 +525,7 @@ impl WalReader {
    pub fn new(
        workdir: Utf8PathBuf,
        timeline_dir: Utf8PathBuf,
        state: &SafeKeeperState,
        state: &TimelinePersistentState,
        start_pos: Lsn,
        enable_remote_read: bool,
    ) -> Result<Self> {
@@ -54,7 +54,10 @@ class NeonBroker:
             else:
                 break  # success

-    def stop(self):
+    def stop(self, immediate: bool = False):
         if self.handle is not None:
-            self.handle.terminate()
+            if immediate:
+                self.handle.kill()
+            else:
+                self.handle.terminate()
             self.handle.wait()

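stop(immediate=True) swaps SIGTERM for SIGKILL, so teardown no longer waits for the broker to exit gracefully; handle.wait() still runs in both branches so the child is reaped. A self-contained sketch of the same pattern (the Service class is a hypothetical stand-in, not part of the fixtures):

import subprocess

class Service:
    def __init__(self, cmd):
        self.handle = subprocess.Popen(cmd)

    def stop(self, immediate: bool = False):
        if self.handle is not None:
            if immediate:
                self.handle.kill()       # SIGKILL: fast, no graceful cleanup
            else:
                self.handle.terminate()  # SIGTERM: allow orderly shutdown
            self.handle.wait()           # reap the child in either case

svc = Service(["sleep", "60"])
svc.stop(immediate=True)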
@@ -24,6 +24,7 @@ from urllib.parse import urlparse

 import asyncpg
 import backoff
+import httpx
 import jwt
 import psycopg2
 import pytest
@@ -40,6 +41,7 @@ from psycopg2.extensions import make_dsn, parse_dsn
 from typing_extensions import Literal
 from urllib3.util.retry import Retry

+from fixtures import overlayfs
 from fixtures.broker import NeonBroker
 from fixtures.log_helper import log
 from fixtures.pageserver.allowed_errors import (
@@ -424,6 +426,7 @@ class NeonEnvBuilder:
         pg_version: PgVersion,
         test_name: str,
         test_output_dir: Path,
+        test_overlay_dir: Optional[Path] = None,
         pageserver_remote_storage: Optional[RemoteStorage] = None,
         pageserver_config_override: Optional[str] = None,
         num_safekeepers: int = 1,
@@ -468,6 +471,9 @@ class NeonEnvBuilder:
         self.initial_timeline = initial_timeline or TimelineId.generate()
         self.scrub_on_exit = False
         self.test_output_dir = test_output_dir
+        self.test_overlay_dir = test_overlay_dir
+        self.overlay_mounts_created_by_us: List[Tuple[str, Path]] = []
+        self.config_init_force: Optional[str] = None

         assert test_name.startswith(
             "test_"
@@ -547,7 +553,10 @@ class NeonEnvBuilder:
             tenants_to_dir = self.repo_dir / ps_dir.name / "tenants"

             log.info(f"Copying pageserver tenants directory {tenants_from_dir} to {tenants_to_dir}")
-            shutil.copytree(tenants_from_dir, tenants_to_dir)
+            if self.test_overlay_dir is None:
+                shutil.copytree(tenants_from_dir, tenants_to_dir)
+            else:
+                self.overlay_mount(f"{ps_dir.name}:tenants", tenants_from_dir, tenants_to_dir)

         for sk_from_dir in (repo_dir / "safekeepers").glob("sk*"):
             sk_to_dir = self.repo_dir / "safekeepers" / sk_from_dir.name
@@ -556,9 +565,16 @@ class NeonEnvBuilder:
             shutil.copytree(sk_from_dir, sk_to_dir, ignore=shutil.ignore_patterns("*.log", "*.pid"))

         shutil.rmtree(self.repo_dir / "local_fs_remote_storage", ignore_errors=True)
-        shutil.copytree(
-            repo_dir / "local_fs_remote_storage", self.repo_dir / "local_fs_remote_storage"
-        )
+        if self.test_overlay_dir is None:
+            shutil.copytree(
+                repo_dir / "local_fs_remote_storage", self.repo_dir / "local_fs_remote_storage"
+            )
+        else:
+            self.overlay_mount(
+                "local_fs_remote_storage",
+                repo_dir / "local_fs_remote_storage",
+                self.repo_dir / "local_fs_remote_storage",
+            )

         if (attachments_json := Path(repo_dir / "attachments.json")).exists():
             shutil.copyfile(attachments_json, self.repo_dir / attachments_json.name)
@@ -575,6 +591,69 @@ class NeonEnvBuilder:

         return self.env

+    def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path):
+        """
+        Mount `srcdir` as an overlayfs mount at `dstdir`.
+        The overlayfs `upperdir` and `workdir` will be placed in test_overlay_dir.
+        """
+        assert self.test_overlay_dir
+        assert (
+            self.test_output_dir in dstdir.parents
+        )  # so that teardown & test_overlay_dir fixture work
+        assert srcdir.is_dir()
+        dstdir.mkdir(exist_ok=False, parents=False)
+        ident_state_dir = self.test_overlay_dir / ident
+        upper = ident_state_dir / "upper"
+        work = ident_state_dir / "work"
+        ident_state_dir.mkdir(
+            exist_ok=False, parents=False
+        )  # exists_ok=False also checks uniqueness in self.overlay_mounts
+        upper.mkdir()
+        work.mkdir()
+        cmd = [
+            "sudo",
+            "mount",
+            "-t",
+            "overlay",
+            "overlay",
+            "-o",
+            f"lowerdir={srcdir},upperdir={upper},workdir={work}",
+            str(dstdir),
+        ]
+        log.info(f"Mounting overlayfs srcdir={srcdir} dstdir={dstdir}: {cmd}")
+        subprocess_capture(
+            self.test_output_dir, cmd, check=True, echo_stderr=True, echo_stdout=True
+        )
+        self.overlay_mounts_created_by_us.append((ident, dstdir))
+
+    def overlay_cleanup_teardown(self):
+        """
+        Unmount the overlayfs mounts created by `self.overlay_mount()`.
+        Supposed to be called during env teardown.
+        """
+        if self.test_overlay_dir is None:
+            return
+        while len(self.overlay_mounts_created_by_us) > 0:
+            (ident, mountpoint) = self.overlay_mounts_created_by_us.pop()
+            ident_state_dir = self.test_overlay_dir / ident
+            cmd = ["sudo", "umount", str(mountpoint)]
+            log.info(
+                f"Unmounting overlayfs mount created during setup for ident {ident} at {mountpoint}: {cmd}"
+            )
+            subprocess_capture(
+                self.test_output_dir, cmd, check=True, echo_stderr=True, echo_stdout=True
+            )
+            log.info(
+                f"Cleaning up overlayfs state dir (owned by root user) for ident {ident} at {ident_state_dir}"
+            )
+            cmd = ["sudo", "rm", "-rf", str(ident_state_dir)]
+            subprocess_capture(
+                self.test_output_dir, cmd, check=True, echo_stderr=True, echo_stdout=True
+            )
+
+        # assert all overlayfs mounts in our test directory are gone
+        assert [] == list(overlayfs.iter_mounts_beneath(self.test_overlay_dir))
+
     def enable_scrub_on_exit(self):
         """
         Call this if you would like the fixture to automatically run
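overlay_mount gives a test a copy-on-write view of a prepared snapshot: the snapshot stays untouched as the read-only lowerdir, writes go to upperdir, and workdir is scratch space the kernel requires on the same filesystem as upperdir. A hedged sketch of the same mount invocation outside the fixture machinery (paths are examples; requires root and Linux overlayfs):

import subprocess
from pathlib import Path

def overlay_mount(lower: Path, state: Path, dst: Path) -> None:
    # upperdir receives all writes; workdir must be an empty dir on the same fs.
    upper, work = state / "upper", state / "work"
    for d in (upper, work, dst):
        d.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        ["sudo", "mount", "-t", "overlay", "overlay",
         "-o", f"lowerdir={lower},upperdir={upper},workdir={work}", str(dst)],
        check=True,
    )

# Example: overlay_mount(Path("snapshot/tenants"), Path("overlay-state"), Path("repo/tenants"))
# Writes land in overlay-state/upper; `sudo umount repo/tenants` discards the view.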
@@ -676,17 +755,12 @@ class NeonEnvBuilder:
         # Stop all the nodes.
         if self.env:
             log.info("Cleaning up all storage and compute nodes")
-            self.env.endpoints.stop_all()
-            for sk in self.env.safekeepers:
-                sk.stop(immediate=True)
-
-            for pageserver in self.env.pageservers:
-                pageserver.assert_no_metric_errors()
-
-                pageserver.stop(immediate=True)
-
-            self.env.attachment_service.stop(immediate=True)
-
+            self.env.stop(
+                immediate=True,
+                # if the test threw an exception, don't check for errors
+                # as a failing assertion would cause the cleanup below to fail
+                ps_assert_metric_no_errors=(exc_type is None),
+            )
             cleanup_error = None

             if self.scrub_on_exit:
@@ -696,6 +770,13 @@ class NeonEnvBuilder:
                 log.error(f"Error during remote storage scrub: {e}")
                 cleanup_error = e

+        try:
+            self.overlay_cleanup_teardown()
+        except Exception as e:
+            log.error(f"Error cleaning up overlay state: {e}")
+            if cleanup_error is not None:
+                cleanup_error = e
+
         try:
             self.cleanup_remote_storage()
         except Exception as e:
@@ -848,7 +929,7 @@ class NeonEnv:
             cfg["safekeepers"].append(sk_cfg)

         log.info(f"Config: {cfg}")
-        self.neon_cli.init(cfg)
+        self.neon_cli.init(cfg, force=config.config_init_force)

     def start(self):
         # Start up broker, pageserver and all safekeepers
@@ -862,6 +943,20 @@ class NeonEnv:
         for safekeeper in self.safekeepers:
             safekeeper.start()

+    def stop(self, immediate=False, ps_assert_metric_no_errors=False):
+        """
+        After this method returns, there should be no child processes running.
+        """
+        self.endpoints.stop_all()
+        for sk in self.safekeepers:
+            sk.stop(immediate=immediate)
+        for pageserver in self.pageservers:
+            if ps_assert_metric_no_errors:
+                pageserver.assert_no_metric_errors()
+            pageserver.stop(immediate=immediate)
+        self.attachment_service.stop(immediate=immediate)
+        self.broker.stop(immediate=immediate)
+
     @property
     def pageserver(self) -> NeonPageserver:
         """
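The teardown above passes ps_assert_metric_no_errors=(exc_type is None) so that a clean run still asserts the pageservers logged no metric errors, while a failing test skips the assertion and the original exception is not masked by cleanup. A self-contained sketch of that decision (FakeEnv is a stand-in for the real NeonEnv):

def teardown(env, exc_type) -> None:
    # Only check metrics when the test body itself did not raise.
    env.stop(
        immediate=True,
        ps_assert_metric_no_errors=(exc_type is None),
    )

class FakeEnv:
    def stop(self, immediate, ps_assert_metric_no_errors):
        print(f"stop(immediate={immediate}, assert_metrics={ps_assert_metric_no_errors})")

teardown(FakeEnv(), exc_type=None)          # clean run: metrics asserted
teardown(FakeEnv(), exc_type=RuntimeError)  # failed run: assertion skipped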
@@ -1017,6 +1112,7 @@ def neon_env_builder(
     default_broker: NeonBroker,
     run_id: uuid.UUID,
     request: FixtureRequest,
+    test_overlay_dir: Path,
 ) -> Iterator[NeonEnvBuilder]:
     """
     Fixture to create a Neon environment for test.
@@ -1047,6 +1143,7 @@ def neon_env_builder(
         preserve_database_files=pytestconfig.getoption("--preserve-database-files"),
         test_name=request.node.name,
         test_output_dir=test_output_dir,
+        test_overlay_dir=test_overlay_dir,
     ) as builder:
         yield builder

@@ -1334,6 +1431,7 @@ class NeonCli(AbstractNeonCli):
     def init(
         self,
         config: Dict[str, Any],
+        force: Optional[str] = None,
     ) -> "subprocess.CompletedProcess[str]":
         with tempfile.NamedTemporaryFile(mode="w+") as tmp:
             tmp.write(toml.dumps(config))
@@ -1341,6 +1439,9 @@ class NeonCli(AbstractNeonCli):

         cmd = ["init", f"--config={tmp.name}", "--pg-version", self.env.pg_version]

+        if force is not None:
+            cmd.extend(["--force", force])
+
         storage = self.env.pageserver_remote_storage

         append_pageserver_param_overrides(
@@ -1828,18 +1929,24 @@ class NeonPageserver(PgProtocol):
         return None

     def tenant_attach(
-        self, tenant_id: TenantId, config: None | Dict[str, Any] = None, config_null: bool = False
+        self,
+        tenant_id: TenantId,
+        config: None | Dict[str, Any] = None,
+        config_null: bool = False,
+        generation: Optional[int] = None,
     ):
         """
         Tenant attachment passes through here to acquire a generation number before proceeding
         to call into the pageserver HTTP client.
         """
         client = self.http_client()
+        if generation is None:
+            generation = self.env.attachment_service.attach_hook_issue(tenant_id, self.id)
         return client.tenant_attach(
             tenant_id,
             config,
             config_null,
-            generation=self.env.attachment_service.attach_hook_issue(tenant_id, self.id),
+            generation=generation,
         )

     def tenant_detach(self, tenant_id: TenantId):
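tenant_attach now accepts an explicit generation, so a retry can reuse the number issued for an earlier attempt instead of always incrementing via attach_hook_issue. A self-contained sketch of the issue-then-reuse flow (GenerationStore is a hypothetical stand-in for the attachment service):

class GenerationStore:
    def __init__(self):
        self._gen = {}

    def attach_hook_issue(self, tenant_id: str, pageserver_id: int) -> int:
        # Every fresh issuance bumps the tenant's generation number.
        self._gen[tenant_id] = self._gen.get(tenant_id, 0) + 1
        return self._gen[tenant_id]

    def inspect(self, tenant_id: str):
        gen = self._gen.get(tenant_id)
        return None if gen is None else (gen,)

store = GenerationStore()
first = store.attach_hook_issue("tenant-a", pageserver_id=1)  # fresh attach: gen 1
gen_state = store.inspect("tenant-a")
assert gen_state is not None
retry_generation = gen_state[0]  # a retry reuses gen 1 rather than issuing gen 2
assert retry_generation == first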
@@ -2371,6 +2478,33 @@ class NeonProxy(PgProtocol):
         assert response.status_code == kwargs["expected_code"], f"response: {response.json()}"
         return response.json()

+    async def http2_query(self, query, args, **kwargs):
+        # TODO maybe use default values if not provided
+        user = kwargs["user"]
+        password = kwargs["password"]
+        expected_code = kwargs.get("expected_code")
+
+        connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres"
+        async with httpx.AsyncClient(
+            http2=True, verify=str(self.test_output_dir / "proxy.crt")
+        ) as client:
+            response = await client.post(
+                f"https://{self.domain}:{self.external_http_port}/sql",
+                json={"query": query, "params": args},
+                headers={
+                    "Content-Type": "application/sql",
+                    "Neon-Connection-String": connstr,
+                    "Neon-Pool-Opt-In": "true",
+                },
+            )
+            assert response.http_version == "HTTP/2"
+
+            if expected_code is not None:
+                assert (
+                    response.status_code == kwargs["expected_code"]
+                ), f"response: {response.json()}"
+            return response.json()
+
     def get_metrics(self) -> str:
         request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics")
         request_result.raise_for_status()
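http2=True makes httpx negotiate HTTP/2 through TLS ALPN, which is why the request above goes over https with the proxy's certificate pinned via verify=; it also requires the optional h2 dependency (installed with pip install 'httpx[http2]'). A minimal standalone probe, assuming network access to an HTTP/2-capable server:

import asyncio

import httpx

async def main() -> None:
    async with httpx.AsyncClient(http2=True) as client:
        resp = await client.get("https://example.com/")
        # http_version reports what ALPN actually negotiated: "HTTP/2" or "HTTP/1.1".
        print(resp.http_version, resp.status_code)

asyncio.run(main())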
@@ -3194,10 +3328,10 @@ class S3Scrubber:
             raise


-def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
-    """Compute the working directory for an individual test."""
+def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path:
+    """Compute the path to a working directory for an individual test."""
     test_name = request.node.name
-    test_dir = top_output_dir / test_name.replace("/", "-")
+    test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}"

     # We rerun flaky tests multiple times, use a separate directory for each run.
     if (suffix := getattr(request.node, "execution_count", None)) is not None:
@@ -3209,6 +3343,21 @@ def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
     return test_dir


+def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
+    """
+    The working directory for a test.
+    """
+    return _get_test_dir(request, top_output_dir, "")
+
+
+def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
+    """
+    Directory that contains `upperdir` and `workdir` for overlayfs mounts
+    that a test creates. See `NeonEnvBuilder.overlay_mount`.
+    """
+    return _get_test_dir(request, top_output_dir, "overlay-")
+
+
 def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
     return get_test_output_dir(request, top_output_dir) / "repo"

@@ -3236,8 +3385,12 @@ SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile(  # type: ignore[type-arg]
 # scope. So it uses the get_test_output_dir() function to get the path, and
 # this fixture ensures that the directory exists. That works because
 # 'autouse' fixtures are run before other fixtures.
+#
+# NB: we request the overlay dir fixture so the fixture does its cleanups
 @pytest.fixture(scope="function", autouse=True)
-def test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
+def test_output_dir(
+    request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path
+) -> Iterator[Path]:
     """Create the working directory for an individual test."""

     # one directory per test
@@ -3251,6 +3404,43 @@ def test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Iterator[P
     allure_attach_from_dir(test_dir)


+@pytest.fixture(scope="function")
+def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
+    """
+    Idempotently create a test's overlayfs mount state directory.
+    If the functionality isn't enabled via env var, returns None.
+
+    The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
+    """
+
+    if os.getenv("NEON_ENV_BUILDER_FROM_REPO_DIR_USE_OVERLAYFS") is None:
+        return None
+
+    overlay_dir = get_test_overlay_dir(request, top_output_dir)
+    log.info(f"test_overlay_dir is {overlay_dir}")
+
+    overlay_dir.mkdir(exist_ok=True)
+    # unmount stale overlayfs mounts which use subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
+    for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
+        cmd = ["sudo", "umount", str(mountpoint)]
+        log.info(
+            f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
+        )
+        subprocess.run(cmd, capture_output=True, check=True)
+    # the overlayfs `workdir` is owned by `root`, shutil.rmtree won't work.
+    cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
+    subprocess.run(cmd, capture_output=True, check=True)
+
+    overlay_dir.mkdir()
+
+    return overlay_dir
+
+    # no need to clean up anything: on clean shutdown,
+    # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
+    # and on unclean shutdown, this function will take care of it
+    # on the next test run
+
+
 SKIP_DIRS = frozenset(
     (
         "pg_wal",

test_runner/fixtures/overlayfs.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from pathlib import Path
+from typing import Iterator
+
+import psutil
+
+
+def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
+    """
+    Iterate over the overlayfs mounts beneath the specified `topdir`.
+    The `topdir` itself isn't considered.
+    """
+    for part in psutil.disk_partitions(all=True):
+        if part.fstype == "overlay":
+            mountpoint = Path(part.mountpoint)
+            if topdir in mountpoint.parents:
+                yield mountpoint
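iter_mounts_beneath relies on psutil.disk_partitions(all=True), which includes virtual filesystems such as overlay mounts; filtering on fstype == "overlay" plus the parents check scopes the scan to a single test's directory tree. A small usage sketch (the path is an example; assumes it runs inside test_runner so the fixtures package is importable):

from pathlib import Path

from fixtures import overlayfs  # the module added above

test_dir = Path("/tmp/test_output/test_example")
# Report overlay mounts a previous, aborted run may have left behind.
for mountpoint in overlayfs.iter_mounts_beneath(test_dir):
    print(f"stale overlayfs mount: {mountpoint}")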
@@ -441,6 +441,7 @@ class PageserverHttpClient(requests.Session):
         timeline_id: TimelineId,
         include_non_incremental_logical_size: bool = False,
         include_timeline_dir_layer_file_size_sum: bool = False,
+        force_await_initial_logical_size: bool = False,
         **kwargs,
     ) -> Dict[Any, Any]:
         params = {}
@@ -448,6 +449,8 @@ class PageserverHttpClient(requests.Session):
             params["include-non-incremental-logical-size"] = "true"
         if include_timeline_dir_layer_file_size_sum:
             params["include-timeline-dir-layer-file-size-sum"] = "true"
+        if force_await_initial_logical_size:
+            params["force-await-initial-logical-size"] = "true"

         res = self.get(
             f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}",

@@ -254,7 +254,9 @@ def test_generations_upgrade(neon_env_builder: NeonEnvBuilder):
     metadata_summary = S3Scrubber(
         neon_env_builder.test_output_dir, neon_env_builder
     ).scan_metadata()
-    assert metadata_summary["count"] == 1  # Scrubber should have seen our timeline
+    assert metadata_summary["tenant_count"] == 1  # Scrubber should have seen our timeline
+    assert metadata_summary["timeline_count"] == 1
+    assert metadata_summary["timeline_shard_count"] == 1
     assert not metadata_summary["with_errors"]
     assert not metadata_summary["with_warnings"]

@@ -500,3 +500,13 @@ def test_sql_over_http_pool_custom_types(static_proxy: NeonProxy):
         "select array['foo'::foo, 'bar'::foo, 'baz'::foo] as data",
     )
     assert response["rows"][0]["data"] == ["foo", "bar", "baz"]
+
+
+@pytest.mark.asyncio
+async def test_sql_over_http2(static_proxy: NeonProxy):
+    static_proxy.safe_psql("create role http with login password 'http' superuser")
+
+    resp = await static_proxy.http2_query(
+        "select 42 as answer", [], user="http", password="http", expected_code=200
+    )
+    assert resp["rows"] == [{"answer": 42}]

test_runner/regress/test_proxy_websockets.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import ssl
+
+import pytest
+import websockets
+from fixtures.neon_fixtures import NeonProxy
+
+
+@pytest.mark.asyncio
+async def test_websockets(static_proxy: NeonProxy):
+    static_proxy.safe_psql("create user ws_auth with password 'ws' superuser")
+
+    user = "ws_auth"
+    password = "ws"
+
+    version = b"\x00\x03\x00\x00"
+    params = {
+        "user": user,
+        "database": "postgres",
+        "client_encoding": "UTF8",
+    }
+
+    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    ssl_context.load_verify_locations(str(static_proxy.test_output_dir / "proxy.crt"))
+
+    async with websockets.connect(
+        f"wss://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
+        ssl=ssl_context,
+    ) as websocket:
+        startup_message = bytearray(version)
+        for key, value in params.items():
+            startup_message.extend(key.encode("ascii"))
+            startup_message.extend(b"\0")
+            startup_message.extend(value.encode("ascii"))
+            startup_message.extend(b"\0")
+        startup_message.extend(b"\0")
+        length = (4 + len(startup_message)).to_bytes(4, byteorder="big")
+
+        await websocket.send([length, startup_message])
+
+        startup_response = await websocket.recv()
+        assert startup_response[0:1] == b"R", "should be authentication message"
+        assert startup_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
+        assert startup_response[5:9] == b"\x00\x00\x00\x03", "should be cleartext"
+
+        auth_message = password.encode("utf-8") + b"\0"
+        length = (4 + len(auth_message)).to_bytes(4, byteorder="big")
+        await websocket.send([b"p", length, auth_message])
+
+        auth_response = await websocket.recv()
+        assert auth_response[0:1] == b"R", "should be authentication message"
+        assert auth_response[1:5] == b"\x00\x00\x00\x08", "should be 8 bytes long message"
+        assert auth_response[5:9] == b"\x00\x00\x00\x00", "should be authenticated"
+
+        query_message = "SELECT 1".encode("utf-8") + b"\0"
+        length = (4 + len(query_message)).to_bytes(4, byteorder="big")
+        await websocket.send([b"Q", length, query_message])
+
+        _query_response = await websocket.recv()
+
+        # close
+        await websocket.send(b"X\x00\x00\x00\x04")
+        await websocket.wait_closed()
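Every frame the test sends follows the PostgreSQL v3 wire format: an optional one-byte type tag ('p' for password, 'Q' for simple query, 'X' for terminate), then a 4-byte big-endian length that counts itself plus the payload, then the payload. The StartupMessage is the one untagged message; it starts with protocol version 3.0 (b"\x00\x03\x00\x00") and carries NUL-terminated key/value pairs closed by one extra NUL. A helper capturing that framing (illustrative only, not part of the test):

def frame(payload: bytes, tag: bytes = b"") -> bytes:
    # The length field covers its own 4 bytes plus the payload, but not the tag.
    return tag + (4 + len(payload)).to_bytes(4, byteorder="big") + payload

def startup_message(params: dict) -> bytes:
    body = bytearray(b"\x00\x03\x00\x00")  # protocol version 3.0
    for key, value in params.items():
        body += key.encode("ascii") + b"\0" + value.encode("ascii") + b"\0"
    body += b"\0"  # terminator after the last key/value pair
    return frame(bytes(body))

assert frame(b"SELECT 1\0", tag=b"Q") == b"Q\x00\x00\x00\rSELECT 1\x00"
print(startup_message({"user": "ws_auth", "database": "postgres"}).hex())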
@@ -144,8 +144,11 @@ def test_remote_storage_backup_and_restore(
     # Introduce failpoint in list remote timelines code path to make tenant_attach fail.
     # This is before the failures injected by test_remote_failures, so it's a permanent error.
     pageserver_http.configure_failpoints(("storage-sync-list-remote-timelines", "return"))
-    env.pageserver.allowed_errors.append(
-        ".*attach failed.*: storage-sync-list-remote-timelines",
+    env.pageserver.allowed_errors.extend(
+        [
+            ".*attach failed.*: storage-sync-list-remote-timelines",
+            ".*Tenant state is Broken: storage-sync-list-remote-timelines.*",
+        ]
     )
     # Attach it. This HTTP request will succeed and launch a
     # background task to load the tenant. In that background task,
@@ -159,9 +162,13 @@ def test_remote_storage_backup_and_restore(
         "data": {"reason": "storage-sync-list-remote-timelines"},
     }

-    # Ensure that even though the tenant is broken, we can't attach it again.
-    with pytest.raises(Exception, match=f"tenant {tenant_id} already exists, state: Broken"):
-        env.pageserver.tenant_attach(tenant_id)
+    # Ensure that even though the tenant is broken, retrying the attachment fails
+    with pytest.raises(Exception, match="Tenant state is Broken"):
+        # Use same generation as in previous attempt
+        gen_state = env.attachment_service.inspect(tenant_id)
+        assert gen_state is not None
+        generation = gen_state[0]
+        env.pageserver.tenant_attach(tenant_id, generation=generation)

     # Restart again, this implicitly clears the failpoint.
     # test_remote_failures=1 remains active, though, as it's in the pageserver config.
@@ -176,10 +183,8 @@ def test_remote_storage_backup_and_restore(
     ), "we shouldn't have tried any layer downloads yet since list remote timelines has a failpoint"
     env.pageserver.start()

-    # Ensure that the pageserver remembers that the tenant was attaching, by
-    # trying to attach it again. It should fail.
-    with pytest.raises(Exception, match=f"tenant {tenant_id} already exists, state:"):
-        env.pageserver.tenant_attach(tenant_id)
+    # The attach should have got far enough that it recovers on restart (i.e. tenant's
+    # config was written to local storage).
     log.info("waiting for tenant to become active. this should be quick with on-demand download")

     wait_until_tenant_active(

Some files were not shown because too many files have changed in this diff.