Compare commits

..

2 Commits

Author SHA1 Message Date
Konstantin Knizhnik
08cf2749ca Reduce number of iteration in test_physical_replication to reduce test time 2024-05-22 11:52:31 +03:00
Konstantin Knizhnik
c85fd74d34 Fix test_physical_replication test taken in acount autocommit behaviour of psycopg 2024-04-08 17:37:48 +03:00
61 changed files with 857 additions and 1705 deletions

View File

@@ -10,7 +10,7 @@ inputs:
required: true
api_host:
desctiption: 'Neon API host'
default: console-stage.neon.build
default: console.stage.neon.tech
outputs:
dsn:
description: 'Created Branch DSN (for main database)'

View File

@@ -13,7 +13,7 @@ inputs:
required: true
api_host:
desctiption: 'Neon API host'
default: console-stage.neon.build
default: console.stage.neon.tech
runs:
using: "composite"

View File

@@ -13,7 +13,7 @@ inputs:
default: 15
api_host:
desctiption: 'Neon API host'
default: console-stage.neon.build
default: console.stage.neon.tech
provisioner:
desctiption: 'k8s-pod or k8s-neonvm'
default: 'k8s-pod'

View File

@@ -10,7 +10,7 @@ inputs:
required: true
api_host:
desctiption: 'Neon API host'
default: console-stage.neon.build
default: console.stage.neon.tech
runs:
using: "composite"

396
Cargo.lock generated
View File

@@ -270,12 +270,6 @@ dependencies = [
"critical-section",
]
[[package]]
name = "atomic-take"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3"
[[package]]
name = "autocfg"
version = "1.1.0"
@@ -304,7 +298,7 @@ dependencies = [
"fastrand 2.0.0",
"hex",
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"ring 0.17.6",
"time",
"tokio",
@@ -341,7 +335,7 @@ dependencies = [
"bytes",
"fastrand 2.0.0",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"percent-encoding",
"pin-project-lite",
"tracing",
@@ -392,7 +386,7 @@ dependencies = [
"aws-types",
"bytes",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"once_cell",
"percent-encoding",
"regex-lite",
@@ -520,7 +514,7 @@ dependencies = [
"crc32fast",
"hex",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"md-5",
"pin-project-lite",
"sha1",
@@ -552,7 +546,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"once_cell",
"percent-encoding",
"pin-project-lite",
@@ -591,10 +585,10 @@ dependencies = [
"aws-smithy-types",
"bytes",
"fastrand 2.0.0",
"h2 0.3.26",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-rustls",
"once_cell",
"pin-project-lite",
@@ -632,7 +626,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"itoa",
"num-integer",
"pin-project-lite",
@@ -681,8 +675,8 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
@@ -697,7 +691,7 @@ dependencies = [
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite 0.20.0",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@@ -713,7 +707,7 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"mime",
"rustversion",
"tower-layer",
@@ -1130,7 +1124,7 @@ version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "191d9573962933b4027f932c600cd252ce27a8ad5979418fe78e43c07996f27b"
dependencies = [
"heck 0.4.1",
"heck",
"proc-macro2",
"quote",
"syn 2.0.52",
@@ -1202,7 +1196,7 @@ dependencies = [
"compute_api",
"flate2",
"futures",
"hyper 0.14.26",
"hyper",
"nix 0.27.1",
"notify",
"num_cpus",
@@ -1319,7 +1313,7 @@ dependencies = [
"git-version",
"hex",
"humantime",
"hyper 0.14.26",
"hyper",
"nix 0.27.1",
"once_cell",
"pageserver_api",
@@ -1468,9 +1462,12 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b"
dependencies = [
"cfg-if",
]
[[package]]
name = "crossterm"
@@ -1843,12 +1840,23 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.8"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
dependencies = [
"errno-dragonfly",
"libc",
"windows-sys 0.48.0",
]
[[package]]
name = "errno-dragonfly"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf"
dependencies = [
"cc",
"libc",
"windows-sys 0.52.0",
]
[[package]]
@@ -2205,25 +2213,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "h2"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "816ec7294445779408f36fe57bc5b7fc1cf59664059096c65f905c1c61f58069"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 1.1.0",
"indexmap 2.0.1",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "1.8.2"
@@ -2305,12 +2294,6 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.3"
@@ -2395,29 +2378,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.1.0",
]
[[package]]
name = "http-body-util"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"pin-project-lite",
]
[[package]]
name = "http-types"
version = "2.12.0"
@@ -2476,9 +2436,9 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"httparse",
"httpdate",
"itoa",
@@ -2490,26 +2450,6 @@ dependencies = [
"want",
]
[[package]]
name = "hyper"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-rustls"
version = "0.24.0"
@@ -2517,7 +2457,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
dependencies = [
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"log",
"rustls 0.21.9",
"rustls-native-certs 0.6.2",
@@ -2531,7 +2471,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
"hyper 0.14.26",
"hyper",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
@@ -2544,7 +2484,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper 0.14.26",
"hyper",
"native-tls",
"tokio",
"tokio-native-tls",
@@ -2552,33 +2492,15 @@ dependencies = [
[[package]]
name = "hyper-tungstenite"
version = "0.13.0"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
dependencies = [
"http-body-util",
"hyper 1.2.0",
"hyper-util",
"hyper",
"pin-project-lite",
"tokio",
"tokio-tungstenite 0.21.0",
"tungstenite 0.21.0",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
dependencies = [
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"hyper 1.2.0",
"pin-project-lite",
"socket2 0.5.5",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
@@ -2872,12 +2794,6 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519"
[[package]]
name = "linux-raw-sys"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "lock_api"
version = "0.4.10"
@@ -2932,12 +2848,11 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "measured"
version = "0.0.20"
version = "0.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cbf033874bea03565f2449572c8640ca37ec26300455faf36001f24755da452"
checksum = "f246648d027839a34b420e27c7de1165ace96e19ef894985d0a6ff89a7840a9f"
dependencies = [
"bytes",
"crossbeam-utils",
"hashbrown 0.14.0",
"itoa",
"lasso",
@@ -2950,27 +2865,16 @@ dependencies = [
[[package]]
name = "measured-derive"
version = "0.0.20"
version = "0.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be9e29b682b38f8af2a89f960455054ab1a9f5a06822f6f3500637ad9fa57def"
checksum = "edaa5cc22d99d5d6d7d99c3b5b5f7e7f8034c22f1b5d62a1adecd2ed005d9b80"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "measured-process"
version = "0.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a20849acdd04c5d6a88f565559044546904648a1842a2937cfff0b48b4ca7ef2"
dependencies = [
"libc",
"measured",
"procfs 0.16.0",
]
[[package]]
name = "memchr"
version = "2.6.4"
@@ -3010,10 +2914,8 @@ version = "0.1.0"
dependencies = [
"chrono",
"libc",
"measured",
"measured-process",
"once_cell",
"procfs 0.14.2",
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
@@ -3563,17 +3465,12 @@ dependencies = [
"camino",
"clap",
"git-version",
"humantime",
"pageserver",
"pageserver_api",
"postgres_ffi",
"remote_storage",
"serde",
"serde_json",
"svg_fmt",
"tokio",
"tokio-util",
"toml_edit",
"utils",
"workspace_hack",
]
@@ -3609,7 +3506,7 @@ dependencies = [
"hex-literal",
"humantime",
"humantime-serde",
"hyper 0.14.26",
"hyper",
"itertools",
"leaky-bucket",
"md5",
@@ -3628,7 +3525,7 @@ dependencies = [
"postgres_connection",
"postgres_ffi",
"pq_proto",
"procfs 0.14.2",
"procfs",
"rand 0.8.5",
"regex",
"remote_storage",
@@ -3719,6 +3616,7 @@ dependencies = [
"anyhow",
"async-compression",
"async-stream",
"async-trait",
"byteorder",
"bytes",
"chrono",
@@ -4188,29 +4086,6 @@ dependencies = [
"rustix 0.36.16",
]
[[package]]
name = "procfs"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "731e0d9356b0c25f16f33b5be79b1c57b562f141ebfcdb0ad8ac2c13a24293b4"
dependencies = [
"bitflags 2.4.1",
"hex",
"lazy_static",
"procfs-core",
"rustix 0.38.28",
]
[[package]]
name = "procfs-core"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3554923a69f4ce04c4a754260c338f505ce22642d3830e049a399fc2059a29"
dependencies = [
"bitflags 2.4.1",
"hex",
]
[[package]]
name = "prometheus"
version = "0.13.3"
@@ -4223,7 +4098,7 @@ dependencies = [
"libc",
"memchr",
"parking_lot 0.12.1",
"procfs 0.14.2",
"procfs",
"thiserror",
]
@@ -4244,7 +4119,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270"
dependencies = [
"bytes",
"heck 0.4.1",
"heck",
"itertools",
"lazy_static",
"log",
@@ -4288,7 +4163,6 @@ dependencies = [
"anyhow",
"async-compression",
"async-trait",
"atomic-take",
"aws-config",
"aws-sdk-iam",
"aws-sigv4",
@@ -4312,12 +4186,9 @@ dependencies = [
"hmac",
"hostname",
"http 1.1.0",
"http-body-util",
"humantime",
"hyper 0.14.26",
"hyper 1.2.0",
"hyper",
"hyper-tungstenite",
"hyper-util",
"ipnet",
"itertools",
"lasso",
@@ -4650,7 +4521,7 @@ dependencies = [
"futures-util",
"http-types",
"humantime",
"hyper 0.14.26",
"hyper",
"itertools",
"metrics",
"once_cell",
@@ -4680,10 +4551,10 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-rustls",
"hyper-tls",
"ipnet",
@@ -4741,7 +4612,7 @@ dependencies = [
"futures",
"getrandom 0.2.11",
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"parking_lot 0.11.2",
"reqwest",
"reqwest-middleware",
@@ -4828,7 +4699,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
dependencies = [
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"lazy_static",
"percent-encoding",
"regex",
@@ -4940,19 +4811,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "rustix"
version = "0.38.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
dependencies = [
"bitflags 2.4.1",
"errno",
"libc",
"linux-raw-sys 0.4.13",
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.21.9"
@@ -5133,7 +4991,7 @@ dependencies = [
"git-version",
"hex",
"humantime",
"hyper 0.14.26",
"hyper",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -5618,9 +5476,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.1"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]]
name = "smol_str"
@@ -5712,7 +5570,7 @@ dependencies = [
"futures-util",
"git-version",
"humantime",
"hyper 0.14.26",
"hyper",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -5743,7 +5601,7 @@ dependencies = [
"git-version",
"hex",
"humantime",
"hyper 0.14.26",
"hyper",
"itertools",
"lasso",
"measured",
@@ -5772,7 +5630,7 @@ dependencies = [
"anyhow",
"clap",
"comfy-table",
"hyper 0.14.26",
"hyper",
"pageserver_api",
"pageserver_client",
"reqwest",
@@ -5813,7 +5671,7 @@ version = "0.24.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59"
dependencies = [
"heck 0.4.1",
"heck",
"proc-macro2",
"quote",
"rustversion",
@@ -6255,19 +6113,7 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.20.1",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.21.0",
"tungstenite",
]
[[package]]
@@ -6334,10 +6180,10 @@ dependencies = [
"bytes",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -6523,7 +6369,7 @@ dependencies = [
name = "tracing-utils"
version = "0.1.0"
dependencies = [
"hyper 0.14.26",
"hyper",
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
@@ -6560,25 +6406,6 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.1.0",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "twox-hash"
version = "1.6.3"
@@ -6743,8 +6570,7 @@ dependencies = [
"heapless",
"hex",
"hex-literal",
"humantime",
"hyper 0.14.26",
"hyper",
"jsonwebtoken",
"leaky-bucket",
"metrics",
@@ -7104,15 +6930,6 @@ dependencies = [
"windows-targets 0.48.0",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets 0.52.4",
]
[[package]]
name = "windows-targets"
version = "0.42.2"
@@ -7143,21 +6960,6 @@ dependencies = [
"windows_x86_64_msvc 0.48.0",
]
[[package]]
name = "windows-targets"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dd37b7e5ab9018759f893a1952c9420d060016fc19a472b4bb20d1bdd694d1b"
dependencies = [
"windows_aarch64_gnullvm 0.52.4",
"windows_aarch64_msvc 0.52.4",
"windows_i686_gnu 0.52.4",
"windows_i686_msvc 0.52.4",
"windows_x86_64_gnu 0.52.4",
"windows_x86_64_gnullvm 0.52.4",
"windows_x86_64_msvc 0.52.4",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
@@ -7170,12 +6972,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bcf46cf4c365c6f2d1cc93ce535f2c8b244591df96ceee75d8e83deb70a9cac9"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
@@ -7188,12 +6984,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da9f259dd3bcf6990b55bffd094c4f7235817ba4ceebde8e6d11cd0c5633b675"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
@@ -7206,12 +6996,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
[[package]]
name = "windows_i686_gnu"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b474d8268f99e0995f25b9f095bc7434632601028cf86590aea5c8a5cb7801d3"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
@@ -7224,12 +7008,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
[[package]]
name = "windows_i686_msvc"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1515e9a29e5bed743cb4415a9ecf5dfca648ce85ee42e15873c3cd8610ff8e02"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
@@ -7242,12 +7020,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5eee091590e89cc02ad514ffe3ead9eb6b660aedca2183455434b93546371a03"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
@@ -7260,12 +7032,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77ca79f2451b49fa9e2af39f0747fe999fcda4f5e241b2898624dca97a1f2177"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
@@ -7278,12 +7044,6 @@ version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8"
[[package]]
name = "winnow"
version = "0.4.6"
@@ -7332,10 +7092,11 @@ dependencies = [
"futures-sink",
"futures-util",
"getrandom 0.2.11",
"hashbrown 0.13.2",
"hashbrown 0.14.0",
"hex",
"hmac",
"hyper 0.14.26",
"hyper",
"indexmap 1.9.3",
"itertools",
"libc",
@@ -7373,6 +7134,7 @@ dependencies = [
"tower",
"tracing",
"tracing-core",
"tungstenite",
"url",
"uuid",
"zeroize",

View File

@@ -44,7 +44,6 @@ license = "Apache-2.0"
anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
atomic-take = "1.1.0"
azure_core = "0.18"
azure_identity = "0.18"
azure_storage = "0.18"
@@ -98,7 +97,7 @@ http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.13.0"
hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
@@ -107,8 +106,7 @@ lasso = "0.7"
leaky-bucket = "1.0.1"
libc = "0.2"
md5 = "0.7.0"
measured = { version = "0.0.20", features=["lasso"] }
measured-process = { version = "0.0.20" }
measured = { version = "0.0.13", features=["default", "lasso"] }
memoffset = "0.8"
native-tls = "0.2"
nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] }

View File

@@ -86,10 +86,7 @@ where
.stdout(process_log_file)
.stderr(same_file_for_stderr)
.args(args);
let filled_cmd = fill_env_vars_prefixed_neon(fill_remote_storage_secrets_vars(
fill_rust_env_vars(background_command),
));
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
filled_cmd.envs(envs);
let pid_file_to_check = match &initial_pid_file {
@@ -271,15 +268,6 @@ fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
cmd
}
fn fill_env_vars_prefixed_neon(mut cmd: &mut Command) -> &mut Command {
for (var, val) in std::env::vars() {
if var.starts_with("NEON_PAGESERVER_") {
cmd = cmd.env(var, val);
}
}
cmd
}
/// Add a `pre_exec` to the cmd that, inbetween fork() and exec(),
/// 1. Claims a pidfile with a fcntl lock on it and
/// 2. Sets up the pidfile's file descriptor so that it (and the lock)

View File

@@ -10,13 +10,11 @@ libc.workspace = true
once_cell.workspace = true
chrono.workspace = true
twox-hash.workspace = true
measured.workspace = true
workspace_hack.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]
procfs.workspace = true
measured-process.workspace = true
[dev-dependencies]
rand = "0.8"

View File

@@ -4,17 +4,6 @@
//! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)]
use measured::{
label::{LabelGroupVisitor, LabelName, NoLabels},
metric::{
counter::CounterState,
gauge::GaugeState,
group::{Encoding, MetricValue},
name::{MetricName, MetricNameEncoder},
MetricEncoding, MetricFamilyEncoding,
},
FixedCardinalityLabel, LabelGroup, MetricGroup,
};
use once_cell::sync::Lazy;
use prometheus::core::{
Atomic, AtomicU64, Collector, GenericCounter, GenericCounterVec, GenericGauge, GenericGaugeVec,
@@ -22,7 +11,6 @@ use prometheus::core::{
pub use prometheus::opts;
pub use prometheus::register;
pub use prometheus::Error;
use prometheus::Registry;
pub use prometheus::{core, default_registry, proto};
pub use prometheus::{exponential_buckets, linear_buckets};
pub use prometheus::{register_counter_vec, Counter, CounterVec};
@@ -35,6 +23,7 @@ pub use prometheus::{register_int_counter_vec, IntCounterVec};
pub use prometheus::{register_int_gauge, IntGauge};
pub use prometheus::{register_int_gauge_vec, IntGaugeVec};
pub use prometheus::{Encoder, TextEncoder};
use prometheus::{Registry, Result};
pub mod launch_timestamp;
mod wrappers;
@@ -70,7 +59,7 @@ static INTERNAL_REGISTRY: Lazy<Registry> = Lazy::new(Registry::new);
/// Register a collector in the internal registry. MUST be called before the first call to `gather()`.
/// Otherwise, we can have a deadlock in the `gather()` call, trying to register a new collector
/// while holding the lock.
pub fn register_internal(c: Box<dyn Collector>) -> prometheus::Result<()> {
pub fn register_internal(c: Box<dyn Collector>) -> Result<()> {
INTERNAL_REGISTRY.register(c)
}
@@ -107,127 +96,6 @@ pub const DISK_WRITE_SECONDS_BUCKETS: &[f64] = &[
0.000_050, 0.000_100, 0.000_500, 0.001, 0.003, 0.005, 0.01, 0.05, 0.1, 0.3, 0.5,
];
pub struct BuildInfo {
pub revision: &'static str,
pub build_tag: &'static str,
}
// todo: allow label group without the set
impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision");
v.write_value(REVISION, &self.revision);
const BUILD_TAG: &LabelName = LabelName::from_str("build_tag");
v.write_value(BUILD_TAG, &self.build_tag);
}
}
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics {
#[cfg(target_os = "linux")]
#[metric(namespace = "process")]
#[metric(init = measured_process::ProcessCollector::for_self())]
process: measured_process::ProcessCollector,
#[metric(namespace = "libmetrics")]
#[metric(init = LibMetrics::new(build_info))]
libmetrics: LibMetrics,
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct LibMetrics {
#[metric(init = build_info)]
build_info: BuildInfo,
#[metric(flatten)]
rusage: Rusage,
serve_count: CollectionCounter,
}
fn write_gauge<Enc: Encoding>(
x: i64,
labels: impl LabelGroup,
name: impl MetricNameEncoder,
enc: &mut Enc,
) -> Result<(), Enc::Err> {
enc.write_metric_value(name, labels, MetricValue::Int(x))
}
#[derive(Default)]
struct Rusage;
#[derive(FixedCardinalityLabel, Clone, Copy)]
#[label(singleton = "io_operation")]
enum IoOp {
Read,
Write,
}
impl<T: Encoding> MetricGroup<T> for Rusage
where
GaugeState: MetricEncoding<T>,
{
fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> {
const DISK_IO: &MetricName = MetricName::from_str("disk_io_bytes_total");
const MAXRSS: &MetricName = MetricName::from_str("maxrss_kb");
let ru = get_rusage_stats();
enc.write_help(
DISK_IO,
"Bytes written and read from disk, grouped by the operation (read|write)",
)?;
GaugeState::write_type(DISK_IO, enc)?;
write_gauge(ru.ru_inblock * BYTES_IN_BLOCK, IoOp::Read, DISK_IO, enc)?;
write_gauge(ru.ru_oublock * BYTES_IN_BLOCK, IoOp::Write, DISK_IO, enc)?;
enc.write_help(MAXRSS, "Memory usage (Maximum Resident Set Size)")?;
GaugeState::write_type(MAXRSS, enc)?;
write_gauge(ru.ru_maxrss, IoOp::Read, MAXRSS, enc)?;
Ok(())
}
}
#[derive(Default)]
struct CollectionCounter(CounterState);
impl<T: Encoding> MetricFamilyEncoding<T> for CollectionCounter
where
CounterState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
self.0.inc();
enc.write_help(&name, "Number of metric requests made")?;
self.0.collect_into(&(), NoLabels, name, enc)
}
}
pub fn set_build_info_metric(revision: &str, build_tag: &str) {
let metric = register_int_gauge_vec!(
"libmetrics_build_info",
@@ -237,7 +105,6 @@ pub fn set_build_info_metric(revision: &str, build_tag: &str) {
.expect("Failed to register build info metric");
metric.with_label_values(&[revision, build_tag]).set(1);
}
const BYTES_IN_BLOCK: i64 = 512;
// Records I/O stats in a "cross-platform" way.
// Compiles both on macOS and Linux, but current macOS implementation always returns 0 as values for I/O stats.
@@ -250,6 +117,7 @@ const BYTES_IN_BLOCK: i64 = 512;
fn update_rusage_metrics() {
let rusage_stats = get_rusage_stats();
const BYTES_IN_BLOCK: i64 = 512;
DISK_IO_BYTES
.with_label_values(&["read"])
.set(rusage_stats.ru_inblock * BYTES_IN_BLOCK);
@@ -283,7 +151,6 @@ macro_rules! register_int_counter_pair_vec {
}
}};
}
/// Create an [`IntCounterPair`] and registers to default registry.
#[macro_export(local_inner_macros)]
macro_rules! register_int_counter_pair {
@@ -321,10 +188,7 @@ impl<P: Atomic> GenericCounterPairVec<P> {
///
/// An error is returned if the number of label values is not the same as the
/// number of VariableLabels in Desc.
pub fn get_metric_with_label_values(
&self,
vals: &[&str],
) -> prometheus::Result<GenericCounterPair<P>> {
pub fn get_metric_with_label_values(&self, vals: &[&str]) -> Result<GenericCounterPair<P>> {
Ok(GenericCounterPair {
inc: self.inc.get_metric_with_label_values(vals)?,
dec: self.dec.get_metric_with_label_values(vals)?,
@@ -337,7 +201,7 @@ impl<P: Atomic> GenericCounterPairVec<P> {
self.get_metric_with_label_values(vals).unwrap()
}
pub fn remove_label_values(&self, res: &mut [prometheus::Result<()>; 2], vals: &[&str]) {
pub fn remove_label_values(&self, res: &mut [Result<()>; 2], vals: &[&str]) {
res[0] = self.inc.remove_label_values(vals);
res[1] = self.dec.remove_label_values(vals);
}

View File

@@ -20,7 +20,6 @@ use utils::{
history_buffer::HistoryBufferWithDropCounter,
id::{NodeId, TenantId, TimelineId},
lsn::Lsn,
serde_system_time,
};
use crate::controller_api::PlacementPolicy;
@@ -759,7 +758,11 @@ pub struct WalRedoManagerStatus {
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
pub struct SecondaryProgress {
/// The remote storage LastModified time of the heatmap object we last downloaded.
pub heatmap_mtime: Option<serde_system_time::SystemTime>,
#[serde(
serialize_with = "opt_ser_rfc3339_millis",
deserialize_with = "opt_deser_rfc3339_millis"
)]
pub heatmap_mtime: Option<SystemTime>,
/// The number of layers currently on-disk
pub layers_downloaded: usize,
@@ -772,6 +775,29 @@ pub struct SecondaryProgress {
pub bytes_total: u64,
}
fn opt_ser_rfc3339_millis<S: serde::Serializer>(
ts: &Option<SystemTime>,
serializer: S,
) -> Result<S::Ok, S::Error> {
match ts {
Some(ts) => serializer.collect_str(&humantime::format_rfc3339_millis(*ts)),
None => serializer.serialize_none(),
}
}
fn opt_deser_rfc3339_millis<'de, D>(deserializer: D) -> Result<Option<SystemTime>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s: Option<String> = serde::de::Deserialize::deserialize(deserializer)?;
match s {
None => Ok(None),
Some(s) => humantime::parse_rfc3339(&s)
.map_err(serde::de::Error::custom)
.map(Some),
}
}
pub mod virtual_file {
#[derive(
Copy,

View File

@@ -1,4 +1,4 @@
use utils::serde_system_time::SystemTime;
use std::time::SystemTime;
/// Pageserver current utilization and scoring for how good candidate the pageserver would be for
/// the next tenant.
@@ -21,9 +21,28 @@ pub struct PageserverUtilization {
/// When was this snapshot captured, pageserver local time.
///
/// Use millis to give confidence that the value is regenerated often enough.
#[serde(
serialize_with = "ser_rfc3339_millis",
deserialize_with = "deser_rfc3339_millis"
)]
pub captured_at: SystemTime,
}
fn ser_rfc3339_millis<S: serde::Serializer>(
ts: &SystemTime,
serializer: S,
) -> Result<S::Ok, S::Error> {
serializer.collect_str(&humantime::format_rfc3339_millis(*ts))
}
fn deser_rfc3339_millis<'de, D>(deserializer: D) -> Result<SystemTime, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s: String = serde::de::Deserialize::deserialize(deserializer)?;
humantime::parse_rfc3339(&s).map_err(serde::de::Error::custom)
}
/// openapi knows only `format: int64`, so avoid outputting a non-parseable value by generated clients.
///
/// Instead of newtype, use this because a newtype would get require handling deserializing values
@@ -50,9 +69,7 @@ mod tests {
disk_usage_bytes: u64::MAX,
free_space_bytes: 0,
utilization_score: u64::MAX,
captured_at: SystemTime(
std::time::SystemTime::UNIX_EPOCH + Duration::from_secs(1708509779),
),
captured_at: SystemTime::UNIX_EPOCH + Duration::from_secs(1708509779),
};
let s = serde_json::to_string(&doc).unwrap();

View File

@@ -22,7 +22,6 @@ camino.workspace = true
chrono.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
humantime.workspace = true
hyper = { workspace = true, features = ["full"] }
fail.workspace = true
futures = { workspace = true}

View File

@@ -1,21 +0,0 @@
//! Wrapper around `std::env::var` for parsing environment variables.
use std::{fmt::Display, str::FromStr};
pub fn var<V, E>(varname: &str) -> Option<V>
where
V: FromStr<Err = E>,
E: Display,
{
match std::env::var(varname) {
Ok(s) => Some(
s.parse()
.map_err(|e| format!("failed to parse env var {varname}: {e:#}"))
.unwrap(),
),
Err(std::env::VarError::NotPresent) => None,
Err(std::env::VarError::NotUnicode(_)) => {
panic!("env var {varname} is not unicode")
}
}
}

View File

@@ -63,7 +63,6 @@ pub mod measured_stream;
pub mod serde_percent;
pub mod serde_regex;
pub mod serde_system_time;
pub mod pageserver_feedback;
@@ -90,8 +89,6 @@ pub mod yielding_loop;
pub mod zstd;
pub mod env;
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
///
/// we have several cases:

View File

@@ -1,55 +0,0 @@
//! A `serde::{Deserialize,Serialize}` type for SystemTime with RFC3339 format and millisecond precision.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct SystemTime(
#[serde(
deserialize_with = "deser_rfc3339_millis",
serialize_with = "ser_rfc3339_millis"
)]
pub std::time::SystemTime,
);
fn ser_rfc3339_millis<S: serde::ser::Serializer>(
ts: &std::time::SystemTime,
serializer: S,
) -> Result<S::Ok, S::Error> {
serializer.collect_str(&humantime::format_rfc3339_millis(*ts))
}
fn deser_rfc3339_millis<'de, D>(deserializer: D) -> Result<std::time::SystemTime, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s: String = serde::de::Deserialize::deserialize(deserializer)?;
humantime::parse_rfc3339(&s).map_err(serde::de::Error::custom)
}
#[cfg(test)]
mod tests {
use super::*;
/// Helper function to make a SystemTime have millisecond precision by truncating additional nanoseconds.
fn to_millisecond_precision(time: SystemTime) -> SystemTime {
match time.0.duration_since(std::time::SystemTime::UNIX_EPOCH) {
Ok(duration) => {
let total_millis = duration.as_secs() * 1_000 + u64::from(duration.subsec_millis());
SystemTime(
std::time::SystemTime::UNIX_EPOCH
+ std::time::Duration::from_millis(total_millis),
)
}
Err(_) => time,
}
}
#[test]
fn test_serialize_deserialize() {
let input = SystemTime(std::time::SystemTime::now());
let expected_serialized = format!("\"{}\"", humantime::format_rfc3339_millis(input.0));
let serialized = serde_json::to_string(&input).unwrap();
assert_eq!(expected_serialized, serialized);
let deserialized: SystemTime = serde_json::from_str(&expected_serialized).unwrap();
assert_eq!(to_millisecond_precision(input), deserialized);
}
}

View File

@@ -11,6 +11,7 @@ default = []
anyhow.workspace = true
async-compression.workspace = true
async-stream.workspace = true
async-trait.workspace = true
byteorder.workspace = true
bytes.workspace = true
chrono = { workspace = true, features = ["serde"] }

View File

@@ -180,7 +180,7 @@ where
match top.deref_mut() {
LazyLoadLayer::Unloaded(ref mut l) => {
let fut = l.load_keys(this.ctx);
this.load_future.set(Some(Box::pin(fut)));
this.load_future.set(Some(fut));
continue;
}
LazyLoadLayer::Loaded(ref mut entries) => {

View File

@@ -3,6 +3,7 @@
//!
//! All the heavy lifting is done by the create_image and create_delta
//! functions that the implementor provides.
use async_trait::async_trait;
use futures::Future;
use pageserver_api::{key::Key, keyspace::key_range_size};
use std::ops::Range;
@@ -140,16 +141,18 @@ pub trait CompactionLayer<K: CompactionKey + ?Sized> {
fn is_delta(&self) -> bool;
}
#[async_trait]
pub trait CompactionDeltaLayer<E: CompactionJobExecutor + ?Sized>: CompactionLayer<E::Key> {
type DeltaEntry<'a>: CompactionDeltaEntry<'a, E::Key>
where
Self: 'a;
/// Return all keys in this delta layer.
fn load_keys<'a>(
async fn load_keys<'a>(
&self,
ctx: &E::RequestContext,
) -> impl Future<Output = anyhow::Result<Vec<Self::DeltaEntry<'_>>>> + Send;
) -> anyhow::Result<Vec<Self::DeltaEntry<'_>>>;
}
pub trait CompactionImageLayer<E: CompactionJobExecutor + ?Sized>: CompactionLayer<E::Key> {}

View File

@@ -2,6 +2,7 @@ mod draw;
use draw::{LayerTraceEvent, LayerTraceFile, LayerTraceOp};
use async_trait::async_trait;
use futures::StreamExt;
use rand::Rng;
use tracing::info;
@@ -138,6 +139,7 @@ impl interface::CompactionLayer<Key> for Arc<MockDeltaLayer> {
}
}
#[async_trait]
impl interface::CompactionDeltaLayer<MockTimeline> for Arc<MockDeltaLayer> {
type DeltaEntry<'a> = MockRecord;

View File

@@ -12,14 +12,9 @@ bytes.workspace = true
camino.workspace = true
clap = { workspace = true, features = ["string"] }
git-version.workspace = true
humantime.workspace = true
pageserver = { path = ".." }
pageserver_api.workspace = true
remote_storage = { path = "../../libs/remote_storage" }
postgres_ffi.workspace = true
tokio.workspace = true
tokio-util.workspace = true
toml_edit.workspace = true
utils.workspace = true
svg_fmt.workspace = true
workspace_hack.workspace = true

View File

@@ -9,11 +9,6 @@ mod index_part;
mod layer_map_analyzer;
mod layers;
use std::{
str::FromStr,
time::{Duration, SystemTime},
};
use camino::{Utf8Path, Utf8PathBuf};
use clap::{Parser, Subcommand};
use index_part::IndexPartCmd;
@@ -25,16 +20,8 @@ use pageserver::{
tenant::{dump_layerfile_from_path, metadata::TimelineMetadata},
virtual_file,
};
use pageserver_api::shard::TenantShardId;
use postgres_ffi::ControlFileData;
use remote_storage::{RemotePath, RemoteStorageConfig};
use tokio_util::sync::CancellationToken;
use utils::{
id::TimelineId,
logging::{self, LogFormat, TracingErrorLayerEnablement},
lsn::Lsn,
project_git_version,
};
use utils::{lsn::Lsn, project_git_version};
project_git_version!(GIT_VERSION);
@@ -56,7 +43,6 @@ enum Commands {
#[command(subcommand)]
IndexPart(IndexPartCmd),
PrintLayerFile(PrintLayerFileCmd),
TimeTravelRemotePrefix(TimeTravelRemotePrefixCmd),
DrawTimeline {},
AnalyzeLayerMap(AnalyzeLayerMapCmd),
#[command(subcommand)]
@@ -82,26 +68,6 @@ struct PrintLayerFileCmd {
path: Utf8PathBuf,
}
/// Roll back the time for the specified prefix using S3 history.
///
/// The command is fairly low level and powerful. Validation is only very light,
/// so it is more powerful, and thus potentially more dangerous.
#[derive(Parser)]
struct TimeTravelRemotePrefixCmd {
/// A configuration string for the remote_storage configuration.
///
/// Example: `remote_storage = { bucket_name = "aws-storage-bucket-name", bucket_region = "us-east-2" }`
config_toml_str: String,
/// remote prefix to time travel recover. For safety reasons, we require it to contain
/// a timeline or tenant ID in the prefix.
prefix: String,
/// Timestamp to travel to. Given in format like `2024-01-20T10:45:45Z`. Assumes UTC and second accuracy.
travel_to: String,
/// Timestamp of the start of the operation, must be after any changes we want to roll back and after.
/// You can use a few seconds before invoking the command. Same format as `travel_to`.
done_if_after: Option<String>,
}
#[derive(Parser)]
struct AnalyzeLayerMapCmd {
/// Pageserver data path
@@ -112,14 +78,6 @@ struct AnalyzeLayerMapCmd {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
logging::init(
LogFormat::Plain,
TracingErrorLayerEnablement::EnableWithRustLogFilter,
logging::Output::Stdout,
)?;
logging::replace_panic_hook_with_tracing_panic_hook().forget();
let cli = CliOpts::parse();
match cli.command {
@@ -147,42 +105,6 @@ async fn main() -> anyhow::Result<()> {
print_layerfile(&cmd.path).await?;
}
}
Commands::TimeTravelRemotePrefix(cmd) => {
let timestamp = humantime::parse_rfc3339(&cmd.travel_to)
.map_err(|_e| anyhow::anyhow!("Invalid time for travel_to: '{}'", cmd.travel_to))?;
let done_if_after = if let Some(done_if_after) = &cmd.done_if_after {
humantime::parse_rfc3339(done_if_after).map_err(|_e| {
anyhow::anyhow!("Invalid time for done_if_after: '{}'", done_if_after)
})?
} else {
const SAFETY_MARGIN: Duration = Duration::from_secs(3);
tokio::time::sleep(SAFETY_MARGIN).await;
// Convert to string representation and back to get rid of sub-second values
let done_if_after = SystemTime::now();
tokio::time::sleep(SAFETY_MARGIN).await;
done_if_after
};
let timestamp = strip_subsecond(timestamp);
let done_if_after = strip_subsecond(done_if_after);
let Some(prefix) = validate_prefix(&cmd.prefix) else {
println!("specified prefix '{}' failed validation", cmd.prefix);
return Ok(());
};
let toml_document = toml_edit::Document::from_str(&cmd.config_toml_str)?;
let toml_item = toml_document
.get("remote_storage")
.expect("need remote_storage");
let config = RemoteStorageConfig::from_toml(toml_item)?.expect("incomplete config");
let storage = remote_storage::GenericRemoteStorage::from_config(&config);
let cancel = CancellationToken::new();
storage
.unwrap()
.time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel)
.await?;
}
};
Ok(())
}
@@ -263,89 +185,3 @@ fn handle_metadata(
Ok(())
}
/// Ensures that the given S3 prefix is sufficiently constrained.
/// The command is very risky already and we don't want to expose something
/// that allows usually unintentional and quite catastrophic time travel of
/// an entire bucket, which would be a major catastrophy and away
/// by only one character change (similar to "rm -r /home /username/foobar").
fn validate_prefix(prefix: &str) -> Option<RemotePath> {
if prefix.is_empty() {
// Empty prefix means we want to specify the *whole* bucket
return None;
}
let components = prefix.split('/').collect::<Vec<_>>();
let (last, components) = {
let last = components.last()?;
if last.is_empty() {
(
components.iter().nth_back(1)?,
&components[..(components.len() - 1)],
)
} else {
(last, &components[..])
}
};
'valid: {
if let Ok(_timeline_id) = TimelineId::from_str(last) {
// Ends in either a tenant or timeline ID
break 'valid;
}
if *last == "timelines" {
if let Some(before_last) = components.iter().nth_back(1) {
if let Ok(_tenant_id) = TenantShardId::from_str(before_last) {
// Has a valid tenant id
break 'valid;
}
}
}
return None;
}
RemotePath::from_string(prefix).ok()
}
fn strip_subsecond(timestamp: SystemTime) -> SystemTime {
let ts_str = humantime::format_rfc3339_seconds(timestamp).to_string();
humantime::parse_rfc3339(&ts_str).expect("can't parse just created timestamp")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_prefix() {
assert_eq!(validate_prefix(""), None);
assert_eq!(validate_prefix("/"), None);
#[track_caller]
fn assert_valid(prefix: &str) {
let remote_path = RemotePath::from_string(prefix).unwrap();
assert_eq!(validate_prefix(prefix), Some(remote_path));
}
assert_valid("wal/3aa8fcc61f6d357410b7de754b1d9001/641e5342083b2235ee3deb8066819683/");
// Path is not relative but absolute
assert_eq!(
validate_prefix(
"/wal/3aa8fcc61f6d357410b7de754b1d9001/641e5342083b2235ee3deb8066819683/"
),
None
);
assert_valid("wal/3aa8fcc61f6d357410b7de754b1d9001/");
// Partial tenant IDs should be invalid, S3 will match all tenants with the specific ID prefix
assert_eq!(validate_prefix("wal/3aa8fcc61f6d357410b7d"), None);
assert_eq!(validate_prefix("wal"), None);
assert_eq!(validate_prefix("/wal/"), None);
assert_valid("pageserver/v1/tenants/3aa8fcc61f6d357410b7de754b1d9001");
// Partial tenant ID
assert_eq!(
validate_prefix("pageserver/v1/tenants/3aa8fcc61f6d357410b"),
None
);
assert_valid("pageserver/v1/tenants/3aa8fcc61f6d357410b7de754b1d9001/timelines");
assert_valid("pageserver/v1/tenants/3aa8fcc61f6d357410b7de754b1d9001-0004/timelines");
assert_valid("pageserver/v1/tenants/3aa8fcc61f6d357410b7de754b1d9001/timelines/");
assert_valid("pageserver/v1/tenants/3aa8fcc61f6d357410b7de754b1d9001/timelines/641e5342083b2235ee3deb8066819683");
assert_eq!(validate_prefix("pageserver/v1/tenants/"), None);
}
}

View File

@@ -2100,7 +2100,6 @@ pub(crate) fn remove_tenant_metrics(tenant_shard_id: &TenantShardId) {
use futures::Future;
use pin_project_lite::pin_project;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
@@ -2670,26 +2669,6 @@ pub(crate) mod disk_usage_based_eviction {
pub(crate) static METRICS: Lazy<Metrics> = Lazy::new(Metrics::default);
}
static TOKIO_EXECUTOR_THREAD_COUNT: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_tokio_executor_thread_configured_count",
"Total number of configued tokio executor threads in the process.
The `setup` label denotes whether we're running with multiple runtimes or a single runtime.",
&["setup"],
)
.unwrap()
});
pub(crate) fn set_tokio_runtime_setup(setup: &str, num_threads: NonZeroUsize) {
static SERIALIZE: std::sync::Mutex<()> = std::sync::Mutex::new(());
let _guard = SERIALIZE.lock().unwrap();
TOKIO_EXECUTOR_THREAD_COUNT.reset();
TOKIO_EXECUTOR_THREAD_COUNT
.get_metric_with_label_values(&[setup])
.unwrap()
.set(u64::try_from(num_threads.get()).unwrap());
}
pub fn preinitialize_metrics() {
// Python tests need these and on some we do alerting.
//

View File

@@ -33,14 +33,13 @@
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::num::NonZeroUsize;
use std::panic::AssertUnwindSafe;
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use futures::FutureExt;
use pageserver_api::shard::TenantShardId;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio::task_local;
use tokio_util::sync::CancellationToken;
@@ -49,11 +48,8 @@ use tracing::{debug, error, info, warn};
use once_cell::sync::Lazy;
use utils::env;
use utils::id::TimelineId;
use crate::metrics::set_tokio_runtime_setup;
//
// There are four runtimes:
//
@@ -102,119 +98,52 @@ use crate::metrics::set_tokio_runtime_setup;
// other operations, if the upload tasks e.g. get blocked on locks. It shouldn't
// happen, but still.
//
pub static COMPUTE_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("compute request worker")
.enable_all()
.build()
.expect("Failed to create compute request runtime")
});
pub(crate) static TOKIO_WORKER_THREADS: Lazy<NonZeroUsize> = Lazy::new(|| {
pub static MGMT_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("mgmt request worker")
.enable_all()
.build()
.expect("Failed to create mgmt request runtime")
});
pub static WALRECEIVER_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("walreceiver worker")
.enable_all()
.build()
.expect("Failed to create walreceiver runtime")
});
pub static BACKGROUND_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("background op worker")
// if you change the number of worker threads please change the constant below
.enable_all()
.build()
.expect("Failed to create background op runtime")
});
pub(crate) static BACKGROUND_RUNTIME_WORKER_THREADS: Lazy<usize> = Lazy::new(|| {
// force init and thus panics
let _ = BACKGROUND_RUNTIME.handle();
// replicates tokio-1.28.1::loom::sys::num_cpus which is not available publicly
// tokio would had already panicked for parsing errors or NotUnicode
//
// this will be wrong if any of the runtimes gets their worker threads configured to something
// else, but that has not been needed in a long time.
NonZeroUsize::new(
std::env::var("TOKIO_WORKER_THREADS")
.map(|s| s.parse::<usize>().unwrap())
.unwrap_or_else(|_e| usize::max(2, num_cpus::get())),
)
.expect("the max() ensures that this is not zero")
std::env::var("TOKIO_WORKER_THREADS")
.map(|s| s.parse::<usize>().unwrap())
.unwrap_or_else(|_e| usize::max(2, num_cpus::get()))
});
enum TokioRuntimeMode {
SingleThreaded,
MultiThreaded { num_workers: NonZeroUsize },
}
impl FromStr for TokioRuntimeMode {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"current_thread" => Ok(TokioRuntimeMode::SingleThreaded),
s => match s.strip_prefix("multi_thread:") {
Some("default") => Ok(TokioRuntimeMode::MultiThreaded {
num_workers: *TOKIO_WORKER_THREADS,
}),
Some(suffix) => {
let num_workers = suffix.parse::<NonZeroUsize>().map_err(|e| {
format!(
"invalid number of multi-threaded runtime workers ({suffix:?}): {e}",
)
})?;
Ok(TokioRuntimeMode::MultiThreaded { num_workers })
}
None => Err(format!("invalid runtime config: {s:?}")),
},
}
}
}
static ONE_RUNTIME: Lazy<Option<tokio::runtime::Runtime>> = Lazy::new(|| {
let thread_name = "pageserver-tokio";
let Some(mode) = env::var("NEON_PAGESERVER_USE_ONE_RUNTIME") else {
// If the env var is not set, leave this static as None.
set_tokio_runtime_setup(
"multiple-runtimes",
NUM_MULTIPLE_RUNTIMES
.checked_mul(*TOKIO_WORKER_THREADS)
.unwrap(),
);
return None;
};
Some(match mode {
TokioRuntimeMode::SingleThreaded => {
set_tokio_runtime_setup("one-runtime-single-threaded", NonZeroUsize::new(1).unwrap());
tokio::runtime::Builder::new_current_thread()
.thread_name(thread_name)
.enable_all()
.build()
.expect("failed to create one single runtime")
}
TokioRuntimeMode::MultiThreaded { num_workers } => {
set_tokio_runtime_setup("one-runtime-multi-threaded", num_workers);
tokio::runtime::Builder::new_multi_thread()
.thread_name(thread_name)
.enable_all()
.worker_threads(num_workers.get())
.build()
.expect("failed to create one multi-threaded runtime")
}
})
});
/// Declare a lazy static variable named `$varname` that will resolve
/// to a tokio runtime handle. If the env var `NEON_PAGESERVER_USE_ONE_RUNTIME`
/// is set, this will resolve to `ONE_RUNTIME`. Otherwise, the macro invocation
/// declares a separate runtime and the lazy static variable `$varname`
/// will resolve to that separate runtime.
///
/// The result is is that `$varname.spawn()` will use `ONE_RUNTIME` if
/// `NEON_PAGESERVER_USE_ONE_RUNTIME` is set, and will use the separate runtime
/// otherwise.
macro_rules! pageserver_runtime {
($varname:ident, $name:literal) => {
pub static $varname: Lazy<&'static tokio::runtime::Runtime> = Lazy::new(|| {
if let Some(runtime) = &*ONE_RUNTIME {
return runtime;
}
static RUNTIME: Lazy<tokio::runtime::Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name($name)
.worker_threads(TOKIO_WORKER_THREADS.get())
.enable_all()
.build()
.expect(std::concat!("Failed to create runtime ", $name))
});
&*RUNTIME
});
};
}
pageserver_runtime!(COMPUTE_REQUEST_RUNTIME, "compute request worker");
pageserver_runtime!(MGMT_REQUEST_RUNTIME, "mgmt request worker");
pageserver_runtime!(WALRECEIVER_RUNTIME, "walreceiver worker");
pageserver_runtime!(BACKGROUND_RUNTIME, "background op worker");
// Bump this number when adding a new pageserver_runtime!
// SAFETY: it's obviously correct
const NUM_MULTIPLE_RUNTIMES: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(4) };
#[derive(Debug, Clone, Copy)]
pub struct PageserverTaskId(u64);

View File

@@ -51,7 +51,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{info_span, instrument, warn, Instrument};
use utils::{
backoff, completion::Barrier, crashsafe::path_with_suffix_extension, failpoint_support, fs_ext,
id::TimelineId, serde_system_time,
id::TimelineId,
};
use super::{
@@ -591,7 +591,7 @@ impl<'a> TenantDownloader<'a> {
let mut progress = SecondaryProgress {
layers_total: heatmap_stats.layers,
bytes_total: heatmap_stats.bytes,
heatmap_mtime: Some(serde_system_time::SystemTime(heatmap_mtime)),
heatmap_mtime: Some(heatmap_mtime),
layers_downloaded: 0,
bytes_downloaded: 0,
};

View File

@@ -19,7 +19,6 @@ use pageserver_api::models::InMemoryLayerInfo;
use pageserver_api::shard::TenantShardId;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::{Arc, OnceLock};
use std::time::Instant;
use tracing::*;
use utils::{bin_ser::BeSer, id::TimelineId, lsn::Lsn, vec_map::VecMap};
// avoid binding to Write (conflicts with std::io::Write)
@@ -54,8 +53,6 @@ pub struct InMemoryLayer {
/// Writes are only allowed when this is `None`.
end_lsn: OnceLock<Lsn>,
opened_at: Instant,
/// The above fields never change, except for `end_lsn`, which is only set once.
/// All other changing parts are in `inner`, and protected by a mutex.
inner: RwLock<InMemoryLayerInner>,
@@ -463,7 +460,6 @@ impl InMemoryLayer {
tenant_shard_id,
start_lsn,
end_lsn: OnceLock::new(),
opened_at: Instant::now(),
inner: RwLock::new(InMemoryLayerInner {
index: HashMap::new(),
file,
@@ -524,10 +520,6 @@ impl InMemoryLayer {
Ok(())
}
pub(crate) fn get_opened_at(&self) -> Instant {
self.opened_at
}
pub(crate) async fn tick(&self) -> Option<u64> {
let mut inner = self.inner.write().await;
let size = inner.file.len();

View File

@@ -18,7 +18,7 @@ use utils::{backoff, completion};
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = task_mgr::TOKIO_WORKER_THREADS.get();
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
@@ -72,7 +72,6 @@ pub(crate) async fn concurrent_background_tasks_rate_limit_permit(
loop_kind == BackgroundLoopKind::InitialLogicalSizeCalculation
);
// TODO: assert that we run on BACKGROUND_RUNTIME; requires tokio_unstable Handle::id();
match CONCURRENT_BACKGROUND_TASKS.acquire().await {
Ok(permit) => permit,
Err(_closed) => unreachable!("we never close the semaphore"),

View File

@@ -1257,7 +1257,7 @@ impl Timeline {
checkpoint_distance,
self.get_last_record_lsn(),
self.last_freeze_at.load(),
open_layer.get_opened_at(),
*self.last_freeze_ts.read().unwrap(),
) {
match open_layer.info() {
InMemoryLayerInfo::Frozen { lsn_start, lsn_end } => {
@@ -1622,7 +1622,7 @@ impl Timeline {
checkpoint_distance: u64,
projected_lsn: Lsn,
last_freeze_at: Lsn,
opened_at: Instant,
last_freeze_ts: Instant,
) -> bool {
let distance = projected_lsn.widening_sub(last_freeze_at);
@@ -1648,13 +1648,13 @@ impl Timeline {
);
true
} else if distance > 0 && opened_at.elapsed() >= self.get_checkpoint_timeout() {
} else if distance > 0 && last_freeze_ts.elapsed() >= self.get_checkpoint_timeout() {
info!(
"Will roll layer at {} with layer size {} due to time since first write to the layer ({:?})",
projected_lsn,
layer_size,
opened_at.elapsed()
);
"Will roll layer at {} with layer size {} due to time since last flush ({:?})",
projected_lsn,
layer_size,
last_freeze_ts.elapsed()
);
true
} else {
@@ -4703,16 +4703,23 @@ struct TimelineWriterState {
max_lsn: Option<Lsn>,
// Cached details of the last freeze. Avoids going trough the atomic/lock on every put.
cached_last_freeze_at: Lsn,
cached_last_freeze_ts: Instant,
}
impl TimelineWriterState {
fn new(open_layer: Arc<InMemoryLayer>, current_size: u64, last_freeze_at: Lsn) -> Self {
fn new(
open_layer: Arc<InMemoryLayer>,
current_size: u64,
last_freeze_at: Lsn,
last_freeze_ts: Instant,
) -> Self {
Self {
open_layer,
current_size,
prev_lsn: None,
max_lsn: None,
cached_last_freeze_at: last_freeze_at,
cached_last_freeze_ts: last_freeze_ts,
}
}
}
@@ -4811,10 +4818,12 @@ impl<'a> TimelineWriter<'a> {
let initial_size = layer.size().await?;
let last_freeze_at = self.last_freeze_at.load();
let last_freeze_ts = *self.last_freeze_ts.read().unwrap();
self.write_guard.replace(TimelineWriterState::new(
layer,
initial_size,
last_freeze_at,
last_freeze_ts,
));
Ok(())
@@ -4861,7 +4870,7 @@ impl<'a> TimelineWriter<'a> {
self.get_checkpoint_distance(),
lsn,
state.cached_last_freeze_at,
state.open_layer.get_opened_at(),
state.cached_last_freeze_ts,
) {
OpenLayerAction::Roll
} else {

View File

@@ -12,6 +12,7 @@ use super::layer_manager::LayerManager;
use super::{CompactFlags, DurationRecorder, RecordedDuration, Timeline};
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use enumset::EnumSet;
use fail::fail_point;
use itertools::Itertools;
@@ -1121,6 +1122,7 @@ impl CompactionLayer<Key> for ResidentDeltaLayer {
}
}
#[async_trait]
impl CompactionDeltaLayer<TimelineAdaptor> for ResidentDeltaLayer {
type DeltaEntry<'a> = DeltaEntry<'a>;

View File

@@ -41,7 +41,7 @@ pub(crate) fn regenerate(tenants_path: &Path) -> anyhow::Result<PageserverUtiliz
//
// note that u64::MAX will be output as i64::MAX as u64, but that should not matter
utilization_score: u64::MAX,
captured_at: utils::serde_system_time::SystemTime(captured_at),
captured_at,
};
// TODO: make utilization_score into a metric

View File

@@ -12,7 +12,6 @@ testing = []
anyhow.workspace = true
async-compression.workspace = true
async-trait.workspace = true
atomic-take.workspace = true
aws-config.workspace = true
aws-sdk-iam.workspace = true
aws-sigv4.workspace = true
@@ -37,9 +36,6 @@ http.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
http-body-util = { version = "0.1" }
ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }

View File

@@ -27,7 +27,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -186,7 +186,7 @@ impl AuthenticationConfig {
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
let endpoint_int = EndpointIdInt::from(endpoint);
// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing

View File

@@ -189,9 +189,7 @@ struct ProxyCliArgs {
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
@@ -412,9 +410,6 @@ async fn main() -> anyhow::Result<()> {
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
let cache = api.caches.endpoints_cache.clone();
let con = redis_notifications_client.clone();
maintenance_tasks.spawn(async move { cache.do_read(con).await });
}
}
}
@@ -494,18 +489,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::WakeComputeLockOptions {
@@ -516,10 +507,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
} = args.wake_compute_lock.parse()?;
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch)
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker());
tokio::spawn(locks.garbage_collect_worker(epoch));
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));

View File

@@ -1,5 +1,4 @@
pub mod common;
pub mod endpoints;
pub mod project_info;
mod timed_lru;

View File

@@ -1,190 +0,0 @@
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::REDIS_BROKEN_MESSAGES,
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId,
};
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all(deserialize = "snake_case"))]
pub enum ControlPlaneEventKey {
EndpointCreated,
BranchCreated,
ProjectCreated,
}
pub struct EndpointsCache {
config: EndpointCacheConfig,
endpoints: DashSet<EndpointIdInt>,
branches: DashSet<BranchIdInt>,
projects: DashSet<ProjectIdInt>,
ready: AtomicBool,
limiter: Arc<Mutex<GlobalRateLimiter>>,
}
impl EndpointsCache {
pub fn new(config: EndpointCacheConfig) -> Self {
Self {
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
config.limiter_info.clone(),
))),
config,
endpoints: DashSet::new(),
branches: DashSet::new(),
projects: DashSet::new(),
ready: AtomicBool::new(false),
}
}
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
return true;
}
// If cache is disabled, just collect the metrics and return.
if self.config.disable_cache {
ctx.set_rejected(self.should_reject(endpoint));
return true;
}
// If the limiter allows, we don't need to check the cache.
if self.limiter.lock().await.check() {
return true;
}
let rejected = self.should_reject(endpoint);
ctx.set_rejected(rejected);
!rejected
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
if endpoint.is_endpoint() {
!self.endpoints.contains(&EndpointIdInt::from(endpoint))
} else if endpoint.is_branch() {
!self
.branches
.contains(&BranchIdInt::from(&endpoint.as_branch()))
} else {
!self
.projects
.contains(&ProjectIdInt::from(&endpoint.as_project()))
}
}
fn insert_event(&self, key: ControlPlaneEventKey, value: String) {
// Do not do normalization here, we expect the events to be normalized.
match key {
ControlPlaneEventKey::EndpointCreated => {
self.endpoints.insert(EndpointIdInt::from(&value.into()));
}
ControlPlaneEventKey::BranchCreated => {
self.branches.insert(BranchIdInt::from(&value.into()));
}
ControlPlaneEventKey::ProjectCreated => {
self.projects.insert(ProjectIdInt::from(&value.into()));
}
}
}
pub async fn do_read(
&self,
mut con: ConnectionWithCredentialsProvider,
) -> anyhow::Result<Infallible> {
let mut last_id = "0-0".to_string();
loop {
self.ready.store(false, Ordering::Release);
if let Err(e) = con.connect().await {
tracing::error!("error connecting to redis: {:?}", e);
continue;
}
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
tracing::error!("error reading from redis: {:?}", e);
}
}
}
async fn read_from_stream(
&self,
con: &mut ConnectionWithCredentialsProvider,
last_id: &mut String,
) -> anyhow::Result<()> {
tracing::info!("reading endpoints/branches/projects from redis");
self.batch_read(
con,
StreamReadOptions::default().count(self.config.initial_batch_size),
last_id,
true,
)
.await?;
tracing::info!("ready to filter user requests");
self.ready.store(true, Ordering::Release);
self.batch_read(
con,
StreamReadOptions::default()
.count(self.config.initial_batch_size)
.block(self.config.xread_timeout.as_millis() as usize),
last_id,
false,
)
.await
}
fn parse_key_value(key: &str, value: &Value) -> anyhow::Result<(ControlPlaneEventKey, String)> {
Ok((serde_json::from_str(key)?, String::from_redis_value(value)?))
}
async fn batch_read(
&self,
conn: &mut ConnectionWithCredentialsProvider,
opts: StreamReadOptions,
last_id: &mut String,
return_when_finish: bool,
) -> anyhow::Result<()> {
let mut total: usize = 0;
loop {
let mut res: StreamReadReply = conn
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
.await?;
if res.keys.len() != 1 {
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
}
let res = res.keys.pop().expect("Checked length above");
if return_when_finish && res.ids.len() <= self.config.default_batch_size {
break;
}
for x in res.ids {
total += 1;
for (k, v) in x.map {
let (key, value) = match Self::parse_key_value(&k, &v) {
Ok(x) => x,
Err(e) => {
REDIS_BROKEN_MESSAGES
.with_label_values(&[&self.config.stream_name])
.inc();
tracing::error!("error parsing key-value {k}-{v:?}: {e:?}");
continue;
}
};
self.insert_event(key, value);
}
if total.is_power_of_two() {
tracing::debug!("endpoints read {}", total);
}
*last_id = x.id;
}
}
tracing::info!("read {} endpoints/branches/projects from redis", total);
Ok(())
}
}

View File

@@ -313,75 +313,6 @@ impl CertResolver {
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
/// Batch size to receive endpoints.
pub default_batch_size: usize,
/// Timeouts for the stream read operation.
pub xread_timeout: Duration,
/// Stream name to read from.
pub stream_name: String,
/// Limiter info (to distinguish when to enable cache).
pub limiter_info: Vec<RateBucketInfo>,
/// Disable cache.
/// If true, cache is ignored, but reports all statistics.
pub disable_cache: bool,
}
impl EndpointCacheConfig {
/// Default options for [`crate::console::provider::NodeInfoCache`].
/// Notice that by default the limiter is empty, which means that cache is disabled.
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut initial_batch_size = None;
let mut default_batch_size = None;
let mut xread_timeout = None;
let mut stream_name = None;
let mut limiter_info = vec![];
let mut disable_cache = false;
for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
"default_batch_size" => default_batch_size = Some(value.parse()?),
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
"stream_name" => stream_name = Some(value.to_string()),
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
"disable_cache" => disable_cache = value.parse()?,
unknown => bail!("unknown key: {unknown}"),
}
}
RateBucketInfo::validate(&mut limiter_info)?;
Ok(Self {
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
stream_name: stream_name.context("missing `stream_name`")?,
disable_cache,
limiter_info,
})
}
}
impl FromStr for EndpointCacheConfig {
type Err = anyhow::Error;
fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse endpoint cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub interval: Duration,

View File

@@ -8,15 +8,15 @@ use crate::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
config::{CacheOptions, ProjectInfoCacheOptions},
context::RequestMonitoring,
intern::ProjectIdInt,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
use tracing::info;
@@ -416,15 +416,12 @@ pub struct ApiCaches {
pub node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
/// List of all valid endpoints.
pub endpoints_cache: Arc<EndpointsCache>,
}
impl ApiCaches {
pub fn new(
wake_compute_cache_config: CacheOptions,
project_info_cache_config: ProjectInfoCacheOptions,
endpoint_cache_config: EndpointCacheConfig,
) -> Self {
Self {
node_info: NodeInfoCache::new(
@@ -434,7 +431,6 @@ impl ApiCaches {
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
}
}
}
@@ -445,7 +441,6 @@ pub struct ApiLocks {
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
epoch: std::time::Duration,
registered: prometheus::IntCounter,
unregistered: prometheus::IntCounter,
reclamation_lag: prometheus::Histogram,
@@ -458,7 +453,6 @@ impl ApiLocks {
permits: usize,
shards: usize,
timeout: Duration,
epoch: std::time::Duration,
) -> prometheus::Result<Self> {
let registered = prometheus::IntCounter::with_opts(
prometheus::Opts::new(
@@ -503,7 +497,6 @@ impl ApiLocks {
node_locks: DashMap::with_shard_amount(shards),
permits,
timeout,
epoch,
lock_acquire_lag,
registered,
unregistered,
@@ -543,12 +536,12 @@ impl ApiLocks {
})
}
pub async fn garbage_collect_worker(&self) {
pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
if self.permits == 0 {
return;
}
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
loop {
for (i, shard) in self.node_locks.shards().iter().enumerate() {
interval.tick().await;

View File

@@ -8,7 +8,6 @@ use super::{
};
use crate::{
auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
Normalize,
};
use crate::{
cache::Cached,
@@ -24,7 +23,7 @@ use tracing::{error, info, info_span, warn, Instrument};
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks,
locks: &'static ApiLocks,
jwt: String,
}
@@ -56,15 +55,6 @@ impl Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
.await
{
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name();
async {
@@ -91,9 +81,7 @@ impl Api {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => {
return Ok(AuthInfo::default());
}
Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
_otherwise => return Err(e.into()),
},
};
@@ -186,27 +174,23 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let ep = &user_info.endpoint;
let user = &user_info.user;
if let Some(role_secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
{
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
return Ok(role_secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
ep_int,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
@@ -220,8 +204,8 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
let ep = &user_info.endpoint;
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();
@@ -234,18 +218,16 @@ impl super::Api for Api {
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches
.project_info
.insert_allowed_ips(project_id, ep_int, allowed_ips.clone());
ctx.set_project_id(project_id);
}
Ok((

View File

@@ -12,9 +12,7 @@ use crate::{
console::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{
bool_to_str, LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND, NUM_INVALID_ENDPOINTS,
},
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
DbName, EndpointId, RoleName,
};
@@ -52,8 +50,6 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestData>>,
pub latency_timer: LatencyTimer,
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
rejected: bool,
}
#[derive(Clone, Debug)]
@@ -97,7 +93,6 @@ impl RequestMonitoring {
error_kind: None,
auth_method: None,
success: false,
rejected: false,
cold_start_info: ColdStartInfo::Unknown,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -118,10 +113,6 @@ impl RequestMonitoring {
)
}
pub fn set_rejected(&mut self, rejected: bool) {
self.rejected = rejected;
}
pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
self.cold_start_info = info;
self.latency_timer.cold_start_info(info);
@@ -187,10 +178,6 @@ impl RequestMonitoring {
impl Drop for RequestMonitoring {
fn drop(&mut self) {
let outcome = if self.success { "success" } else { "failure" };
NUM_INVALID_ENDPOINTS
.with_label_values(&[self.protocol, bool_to_str(self.rejected), outcome])
.inc();
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(RequestData::from(&*self));
}

View File

@@ -160,11 +160,6 @@ impl From<&EndpointId> for EndpointIdInt {
EndpointIdTag::get_interner().get_or_intern(value)
}
}
impl From<EndpointId> for EndpointIdInt {
fn from(value: EndpointId) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;
@@ -180,11 +175,6 @@ impl From<&BranchId> for BranchIdInt {
BranchIdTag::get_interner().get_or_intern(value)
}
}
impl From<BranchId> for BranchIdInt {
fn from(value: BranchId) -> Self {
BranchIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProjectIdTag;
@@ -200,11 +190,6 @@ impl From<&ProjectId> for ProjectIdInt {
ProjectIdTag::get_interner().get_or_intern(value)
}
}
impl From<ProjectId> for ProjectIdInt {
fn from(value: ProjectId) -> Self {
ProjectIdTag::get_interner().get_or_intern(&value)
}
}
#[cfg(test)]
mod tests {

View File

@@ -127,24 +127,6 @@ macro_rules! smol_str_wrapper {
};
}
const POOLER_SUFFIX: &str = "-pooler";
pub trait Normalize {
fn normalize(&self) -> Self;
}
impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
fn normalize(&self) -> Self {
if self.as_ref().ends_with(POOLER_SUFFIX) {
let mut s = self.as_ref().to_string();
s.truncate(s.len() - POOLER_SUFFIX.len());
s.into()
} else {
self.clone()
}
}
}
// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
@@ -158,22 +140,3 @@ smol_str_wrapper!(ProjectId);
smol_str_wrapper!(EndpointCacheKey);
smol_str_wrapper!(DbName);
// Endpoints are a bit tricky. Rare they might be branches or projects.
impl EndpointId {
pub fn is_endpoint(&self) -> bool {
self.0.starts_with("ep-")
}
pub fn is_branch(&self) -> bool {
self.0.starts_with("br-")
}
pub fn is_project(&self) -> bool {
!self.is_endpoint() && !self.is_branch()
}
pub fn as_branch(&self) -> BranchId {
BranchId(self.0.clone())
}
pub fn as_project(&self) -> ProjectId {
ProjectId(self.0.clone())
}
}

View File

@@ -169,18 +169,6 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub static NUM_INVALID_ENDPOINTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_invalid_endpoints_total",
"Number of invalid endpoints (per protocol, per rejected).",
// http/ws/tcp, true/false, success/failure
// TODO(anna): the last dimension is just a proxy to what we actually want to measure.
// We need to measure whether the endpoint was found by cplane or not.
&["protocol", "rejected", "outcome"],
)
.unwrap()
});
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";

View File

@@ -5,13 +5,19 @@ use std::{
io,
net::SocketAddr,
pin::{pin, Pin},
sync::Mutex,
task::{ready, Context, Poll},
};
use bytes::{Buf, BytesMut};
use hyper::server::conn::AddrIncoming;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use metrics::IntCounterPairGuard;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;
use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
@@ -325,6 +331,103 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
}
impl Accept for ProxyProtocolAccept {
type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
let conn_id = uuid::Uuid::new_v4();
let span = tracing::info_span!("http_conn", ?conn_id);
{
let _enter = span.enter();
tracing::info!("accepted new TCP connection");
}
let Some(conn) = conn else {
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(WithConnectionGuard {
inner: WithClientIp::new(conn),
connection_id: Uuid::new_v4(),
gauge: Mutex::new(Some(
NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&[self.protocol])
.guard(),
)),
span,
})))
}
}
pin_project! {
pub struct WithConnectionGuard<T> {
#[pin]
pub inner: T,
pub connection_id: Uuid,
pub gauge: Mutex<Option<IntCounterPairGuard>>,
pub span: tracing::Span,
}
impl<S> PinnedDrop for WithConnectionGuard<S> {
fn drop(this: Pin<&mut Self>) {
let _enter = this.span.enter();
tracing::info!("HTTP connection closed")
}
}
}
impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_shutdown(cx)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
#[cfg(test)]
mod tests {
use std::pin::pin;

View File

@@ -20,7 +20,7 @@ use crate::{
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey, Normalize,
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
@@ -280,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep.normalize(), 1) {
if !endpoint_rate_limiter.check(ep, 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;

View File

@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};

View File

@@ -24,13 +24,13 @@ use super::{
RateLimiterConfig,
};
pub struct GlobalRateLimiter {
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
info: &'static [RateBucketInfo],
}
impl GlobalRateLimiter {
pub fn new(info: Vec<RateBucketInfo>) -> Self {
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
@@ -50,7 +50,7 @@ impl GlobalRateLimiter {
let should_allow_request = self
.data
.iter_mut()
.zip(&self.info)
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
if should_allow_request {

View File

@@ -5,7 +5,7 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
@@ -80,7 +80,7 @@ impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: GlobalRateLimiter,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
@@ -92,7 +92,7 @@ impl RedisPublisherClient {
Ok(Self {
client,
region_id,
limiter: GlobalRateLimiter::new(info.into()),
limiter: RedisRateLimiter::new(info),
})
}

View File

@@ -4,48 +4,42 @@
mod backend;
mod conn_pool;
mod http_util;
mod json;
mod sql_over_http;
pub mod tls_listener;
mod websocket;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use anyhow::bail;
use hyper::StatusCode;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::task::TaskTracker;
use tracing::instrument::Instrumented;
use crate::cancellation::CancellationHandlerMain;
use crate::config::ProxyConfig;
use crate::context::RequestMonitoring;
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use hyper::{
server::conn::{AddrIncoming, AddrStream},
Body, Method, Request, Response,
};
use std::net::{IpAddr, SocketAddr};
use std::pin::pin;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use std::task::Poll;
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::{CancellationToken, DropGuard};
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;
use utils::http::{error::ApiError, json::json_response};
pub const SERVERLESS_DRIVER_SNI: &str = "api";
@@ -97,174 +91,161 @@ pub async fn task_main(
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
let addr_incoming = ProxyProtocolAccept {
incoming: addr_incoming,
protocol: "http",
};
let server = Builder::new(hyper_util::rt::TokioExecutor::new());
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming, config.handshake_timeout);
connections.spawn(
connection_handler(
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
cancellation_token.clone(),
server.clone(),
tls_acceptor.clone(),
conn,
peer_addr,
)
.instrument(http_conn_span),
);
}
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
WithConnectionGuard<WithClientIp<AddrStream>>,
>| {
let (conn, _) = stream.get_ref();
connections.wait().await;
// this is jank. should dissapear with hyper 1.0 migration.
let gauge = conn
.gauge
.lock()
.expect("lock should not be poisoned")
.take()
.expect("gauge should be set on connection start");
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let cancel_connection = http_cancellation_token.clone().drop_guard();
let span = conn.span.clone();
let client_addr = conn.inner.client_addr();
let remote_addr = conn.inner.inner.remote_addr();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(
hyper::service::service_fn(move |req: Request<Body>| {
let backend = backend.clone();
let ws_connections2 = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
let http_cancellation_token = http_cancellation_token.child_token();
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
ws_connections.spawn(
async move {
// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let _cancel_session = http_cancellation_token.clone().drop_guard();
let res = request_handler(
req,
config,
backend,
ws_connections2,
cancellation_handler,
peer_addr.ip(),
endpoint_rate_limiter,
http_cancellation_token,
)
.await
.map_or_else(|e| e.into_response(), |r| r);
_cancel_session.disarm();
res
}
.in_current_span(),
)
}),
gauge,
cancel_connection,
span,
))
}
},
);
hyper::Server::builder(tls_listener)
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
// await websocket connections
ws_connections.wait().await;
Ok(())
}
/// Handles the TCP lifecycle.
///
/// 1. Parses PROXY protocol V2
/// 2. Handles TLS handshake
/// 3. Handles HTTP connection
/// 1. With graceful shutdowns
/// 2. With graceful request cancellation with connection failure
/// 3. With websocket upgrade support.
#[allow(clippy::too_many_arguments)]
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
conn: TcpStream,
peer_addr: SocketAddr,
) {
let session_id = uuid::Uuid::new_v4();
struct MetricService<S> {
inner: S,
_gauge: IntCounterPairGuard,
_cancel: DropGuard,
span: tracing::Span,
}
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["http"])
.guard();
// handle PROXY protocol
let mut conn = WithClientIp::new(conn);
let peer = match conn.wait_for_addr().await {
Ok(peer) => peer,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return;
impl<S> MetricService<S> {
fn new(
inner: S,
_gauge: IntCounterPairGuard,
_cancel: DropGuard,
span: tracing::Span,
) -> MetricService<S> {
MetricService {
inner,
_gauge,
_cancel,
span,
}
};
}
}
let peer_addr = peer.unwrap_or(peer_addr).ip();
info!(?session_id, %peer_addr, "accepted new TCP connection");
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
S: hyper::service::Service<Request<ReqBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = Instrumented<S::Future>;
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %peer_addr, "accepted new TLS connection");
conn
}
// The handshake failed
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
// The handshake timed out
Err(e) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(?session_id, %peer_addr, "failed to accept TLS connection: {e:?}");
return;
}
};
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = http_cancellation_token.child_token();
let cancel_request = http_request_token.clone().drop_guard();
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
endpoint_rate_limiter.clone(),
http_request_token,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let res = handler.await;
cancel_request.disarm();
res
}
}),
);
// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
Either::Left((_cancelled, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
Either::Right((res, _)) => res,
};
match res {
Ok(()) => tracing::info!(%peer_addr, "HTTP connection closed"),
Err(e) => tracing::warn!(%peer_addr, "HTTP connection error {e}"),
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.span
.in_scope(|| self.inner.call(req))
.instrument(self.span.clone())
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper1::Request<Incoming>,
mut request: Request<Body>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandlerMain>,
session_id: uuid::Uuid,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
let session_id = uuid::Uuid::new_v4();
let host = request
.headers()
.get("host")
@@ -301,14 +282,14 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
@@ -318,7 +299,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new()))
.body(Body::empty())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,92 +0,0 @@
//! Things stolen from `libs/utils/src/http` to add hyper 1.0 compatibility
//! Will merge back in at some point in the future.
use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::Full;
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
}
ApiError::NotFound(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
}
ApiError::Conflict(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
}
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
this.to_string(),
StatusCode::PRECONDITION_FAILED,
),
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
"Shutting down".to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
/// Same as [`utils::http::error::HttpErrorBody`]
#[derive(Serialize)]
struct HttpErrorBody {
pub msg: String,
}
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
}
}
/// Same as [`utils::http::json::json_response`]
pub fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -1,22 +1,18 @@
use std::pin::pin;
use std::sync::Arc;
use bytes::Bytes;
use futures::future::select;
use futures::future::try_join;
use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::time;
@@ -33,6 +29,7 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
@@ -55,7 +52,6 @@ use crate::RoleName;
use super::backend::PoolingBackend;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
@@ -222,10 +218,10 @@ fn get_conn_info(
pub async fn handle(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
request: Request<Incoming>,
request: Request<Body>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
let result = handle_inner(cancel, config, &mut ctx, request, backend).await;
let mut response = match result {
@@ -336,9 +332,10 @@ pub async fn handle(
}
};
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
Ok(response)
}
@@ -399,7 +396,7 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper1::Error),
Read(#[from] hyper::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
@@ -440,7 +437,7 @@ struct HttpHeaders {
}
impl HttpHeaders {
fn try_parse(headers: &hyper1::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
fn try_parse(headers: &hyper::http::HeaderMap) -> Result<Self, SqlOverHttpError> {
// Determine the output options. Default behaviour is 'false'. Anything that is not
// strictly 'true' assumed to be false.
let raw_output = headers.get(&RAW_TEXT_OUTPUT) == Some(&HEADER_VALUE_TRUE);
@@ -491,9 +488,9 @@ async fn handle_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Incoming>,
request: Request<Body>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
) -> Result<Response<Body>, SqlOverHttpError> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
@@ -531,7 +528,7 @@ async fn handle_inner(
}
let fetch_and_process_request = async {
let body = request.into_body().collect().await?.to_bytes();
let body = hyper::body::to_bytes(request.into_body()).await?;
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
@@ -599,7 +596,7 @@ async fn handle_inner(
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Full::new(Bytes::from(body)))
.body(Body::from(body))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -642,7 +639,6 @@ impl QueryData {
}
// The query was cancelled.
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
}

View File

@@ -0,0 +1,123 @@
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use hyper::server::{accept::Accept, conn::AddrStream};
use pin_project_lite::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::JoinSet,
time::timeout,
};
use tokio_rustls::{server::TlsStream, TlsAcceptor};
use tracing::{info, warn, Instrument};
use crate::{
metrics::TLS_HANDSHAKE_FAILURES,
protocol2::{WithClientIp, WithConnectionGuard},
};
pin_project! {
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
pub(crate) struct TlsListener<A: Accept> {
#[pin]
listener: A,
tls: TlsAcceptor,
waiting: JoinSet<Option<TlsStream<A::Conn>>>,
timeout: Duration,
}
}
impl<A: Accept> TlsListener<A> {
/// Create a `TlsListener` with default options.
pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
TlsListener {
listener,
tls,
waiting: JoinSet::new(),
timeout,
}
}
}
impl<A> Accept for TlsListener<A>
where
A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
A::Error: std::error::Error,
A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Conn = TlsStream<A::Conn>;
type Error = Infallible;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let mut this = self.project();
loop {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(mut conn))) => {
let t = *this.timeout;
let tls = this.tls.clone();
let span = conn.span.clone();
this.waiting.spawn(async move {
let peer_addr = match conn.inner.wait_for_addr().await {
Ok(Some(addr)) => addr,
Err(e) => {
tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
Ok(None) => conn.inner.inner.remote_addr()
};
let accept = tls.accept(conn);
match timeout(t, accept).await {
Ok(Ok(conn)) => {
info!(%peer_addr, "accepted new TLS connection");
Some(conn)
},
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
None
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, "failed to accept TLS connection: timeout");
None
}
}
}.instrument(span));
}
Poll::Ready(Some(Err(e))) => {
tracing::error!("error accepting TCP connection: {e}");
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
// The handshake failed to complete, try getting another connection from the queue
Poll::Ready(Some(Ok(None))) => continue,
// The handshake panicked or was cancelled. ignore and get another connection
Poll::Ready(Some(Err(e))) => {
tracing::warn!("handshake aborted: {e}");
continue;
}
_ => Poll::Pending,
};
}
}
}

View File

@@ -15,7 +15,8 @@ FLAKY_TESTS_QUERY = """
DISTINCT parent_suite, suite, name
FROM results
WHERE
started_at > CURRENT_DATE - INTERVAL '%s' day
started_at > CURRENT_DATE - INTERVAL '10' day
AND started_at > '2024-03-11 14:50:11.845+00' -- we switched the default PAGESERVER_VIRTUAL_FILE_IO_ENGINE to `tokio-epoll-uring` from `std-fs` on this date, we want to ignore the flaky tests for `std-fs`
AND (
(status IN ('failed', 'broken') AND reference = 'refs/heads/main')
OR flaky

View File

@@ -22,7 +22,7 @@ parser.add_argument("--safekeeper-host", required=True, type=str)
args = parser.parse_args()
access_key = os.getenv("CONSOLE_API_TOKEN")
endpoint: str = "https://console-stage.neon.build/api"
endpoint: str = "https://console.stage.neon.tech/api"
trash_dir: Path = args.trash_dir
dry_run: bool = args.dry_run

View File

@@ -3,7 +3,7 @@
3. Issue admin token (add/remove .stage from url for staging/prod and setting proper API key):
```
# staging:
AUTH_TOKEN=$(curl https://console-stage.neon.build/regions/console/api/v1/admin/issue_token -H "Accept: application/json" -H "Content-Type: application/json" -H "Authorization: Bearer $NEON_STAGING_KEY" -X POST -d '{"ttl_seconds": 43200, "scope": "safekeeperdata"}' 2>/dev/null | jq --raw-output '.jwt')
AUTH_TOKEN=$(curl https://console.stage.neon.tech/regions/console/api/v1/admin/issue_token -H "Accept: application/json" -H "Content-Type: application/json" -H "Authorization: Bearer $NEON_STAGING_KEY" -X POST -d '{"ttl_seconds": 43200, "scope": "safekeeperdata"}' 2>/dev/null | jq --raw-output '.jwt')
# prod:
AUTH_TOKEN=$(curl https://console.neon.tech/regions/console/api/v1/admin/issue_token -H "Accept: application/json" -H "Content-Type: application/json" -H "Authorization: Bearer $NEON_PROD_KEY" -X POST -d '{"ttl_seconds": 43200, "scope": "safekeeperdata"}' 2>/dev/null | jq --raw-output '.jwt')
# check

View File

@@ -8,7 +8,6 @@ use futures::Future;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Request, Response};
use hyper::{StatusCode, Uri};
use metrics::{BuildInfo, NeonMetrics};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantLocationConfigRequest, TenantShardSplitRequest,
TenantTimeTravelRequest, TimelineCreateRequest,
@@ -45,19 +44,15 @@ use control_plane::storage_controller::{AttachHookRequest, InspectRequest};
use routerify::Middleware;
/// State available to HTTP request handlers
#[derive(Clone)]
pub struct HttpState {
service: Arc<crate::service::Service>,
auth: Option<Arc<SwappableJwtAuth>>,
neon_metrics: NeonMetrics,
allowlist_routes: Vec<Uri>,
}
impl HttpState {
pub fn new(
service: Arc<crate::service::Service>,
auth: Option<Arc<SwappableJwtAuth>>,
build_info: BuildInfo,
) -> Self {
pub fn new(service: Arc<crate::service::Service>, auth: Option<Arc<SwappableJwtAuth>>) -> Self {
let allowlist_routes = ["/status", "/ready", "/metrics"]
.iter()
.map(|v| v.parse().unwrap())
@@ -65,7 +60,6 @@ impl HttpState {
Self {
service,
auth,
neon_metrics: NeonMetrics::new(build_info),
allowlist_routes,
}
}
@@ -678,11 +672,10 @@ fn epilogue_metrics_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>
})
}
pub async fn measured_metrics_handler(req: Request<Body>) -> Result<Response<Body>, ApiError> {
pub async fn measured_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
pub const TEXT_FORMAT: &str = "text/plain; version=0.0.4";
let state = get_state(&req);
let payload = crate::metrics::METRICS_REGISTRY.encode(&state.neon_metrics);
let payload = crate::metrics::METRICS_REGISTRY.encode();
let response = Response::builder()
.status(200)
.header(CONTENT_TYPE, TEXT_FORMAT)
@@ -711,7 +704,6 @@ where
pub fn make_router(
service: Arc<Service>,
auth: Option<Arc<SwappableJwtAuth>>,
build_info: BuildInfo,
) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router()
.middleware(prologue_metrics_middleware())
@@ -728,7 +720,7 @@ pub fn make_router(
}
router
.data(Arc::new(HttpState::new(service, auth, build_info)))
.data(Arc::new(HttpState::new(service, auth)))
.get("/metrics", |r| {
named_request_span(r, measured_metrics_handler, RequestName("metrics"))
})

View File

@@ -3,7 +3,6 @@ use camino::Utf8PathBuf;
use clap::Parser;
use diesel::Connection;
use metrics::launch_timestamp::LaunchTimestamp;
use metrics::BuildInfo;
use std::sync::Arc;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
@@ -193,11 +192,6 @@ async fn async_main() -> anyhow::Result<()> {
args.listen
);
let build_info = BuildInfo {
revision: GIT_VERSION,
build_tag: BUILD_TAG,
};
let strict_mode = if args.dev {
StrictMode::Dev
} else {
@@ -259,7 +253,7 @@ async fn async_main() -> anyhow::Result<()> {
let auth = secrets
.public_key
.map(|jwt_auth| Arc::new(SwappableJwtAuth::new(jwt_auth)));
let router = make_router(service.clone(), auth, build_info)
let router = make_router(service.clone(), auth)
.build()
.map_err(|err| anyhow!(err))?;
let router_service = utils::http::RouterService::new(router).unwrap();

View File

@@ -8,8 +8,10 @@
//! The rest of the code defines label group types and deals with converting outer types to labels.
//!
use bytes::Bytes;
use measured::{label::LabelValue, metric::histogram, FixedCardinalityLabel, MetricGroup};
use metrics::NeonMetrics;
use measured::{
label::{LabelValue, StaticLabelSet},
FixedCardinalityLabel, MetricGroup,
};
use once_cell::sync::Lazy;
use std::sync::Mutex;
@@ -24,15 +26,13 @@ pub fn preinitialize_metrics() {
pub(crate) struct StorageControllerMetrics {
pub(crate) metrics_group: StorageControllerMetricGroup,
encoder: Mutex<measured::text::BufferedTextEncoder>,
encoder: Mutex<measured::text::TextEncoder>,
}
#[derive(measured::MetricGroup)]
#[metric(new())]
pub(crate) struct StorageControllerMetricGroup {
/// Count of how many times we spawn a reconcile task
pub(crate) storage_controller_reconcile_spawn: measured::Counter,
/// Reconciler tasks completed, broken down by success/failure/cancelled
pub(crate) storage_controller_reconcile_complete:
measured::CounterVec<ReconcileCompleteLabelGroupSet>,
@@ -43,9 +43,7 @@ pub(crate) struct StorageControllerMetricGroup {
/// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status:
measured::CounterVec<HttpRequestStatusLabelGroupSet>,
/// HTTP request handler latency across all status codes
#[metric(metadata = histogram::Thresholds::exponential_buckets(0.1, 2.0))]
pub(crate) storage_controller_http_request_latency:
measured::HistogramVec<HttpRequestLatencyLabelGroupSet, 5>,
@@ -57,7 +55,6 @@ pub(crate) struct StorageControllerMetricGroup {
/// Latency of HTTP requests to the pageserver, broken down by pageserver
/// node id, request name and method. This include both successful and unsuccessful
/// requests.
#[metric(metadata = histogram::Thresholds::exponential_buckets(0.1, 2.0))]
pub(crate) storage_controller_pageserver_request_latency:
measured::HistogramVec<PageserverRequestLabelGroupSet, 5>,
@@ -69,7 +66,6 @@ pub(crate) struct StorageControllerMetricGroup {
/// Latency of pass-through HTTP requests to the pageserver, broken down by pageserver
/// node id, request name and method. This include both successful and unsuccessful
/// requests.
#[metric(metadata = histogram::Thresholds::exponential_buckets(0.1, 2.0))]
pub(crate) storage_controller_passthrough_request_latency:
measured::HistogramVec<PageserverRequestLabelGroupSet, 5>,
@@ -78,34 +74,76 @@ pub(crate) struct StorageControllerMetricGroup {
measured::CounterVec<DatabaseQueryErrorLabelGroupSet>,
/// Latency of database queries, broken down by operation.
#[metric(metadata = histogram::Thresholds::exponential_buckets(0.1, 2.0))]
pub(crate) storage_controller_database_query_latency:
measured::HistogramVec<DatabaseQueryLatencyLabelGroupSet, 5>,
}
impl StorageControllerMetrics {
pub(crate) fn encode(&self, neon_metrics: &NeonMetrics) -> Bytes {
pub(crate) fn encode(&self) -> Bytes {
let mut encoder = self.encoder.lock().unwrap();
neon_metrics
.collect_group_into(&mut *encoder)
.unwrap_or_else(|infallible| match infallible {});
self.metrics_group
.collect_group_into(&mut *encoder)
.unwrap_or_else(|infallible| match infallible {});
self.metrics_group.collect_into(&mut *encoder);
encoder.finish()
}
}
impl Default for StorageControllerMetrics {
fn default() -> Self {
let mut metrics_group = StorageControllerMetricGroup::new();
metrics_group
.storage_controller_reconcile_complete
.init_all_dense();
Self {
metrics_group,
encoder: Mutex::new(measured::text::BufferedTextEncoder::new()),
metrics_group: StorageControllerMetricGroup::new(),
encoder: Mutex::new(measured::text::TextEncoder::new()),
}
}
}
impl StorageControllerMetricGroup {
pub(crate) fn new() -> Self {
Self {
storage_controller_reconcile_spawn: measured::Counter::new(),
storage_controller_reconcile_complete: measured::CounterVec::new(
ReconcileCompleteLabelGroupSet {
status: StaticLabelSet::new(),
},
),
storage_controller_schedule_optimization: measured::Counter::new(),
storage_controller_http_request_status: measured::CounterVec::new(
HttpRequestStatusLabelGroupSet {
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
status: StaticLabelSet::new(),
},
),
storage_controller_http_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_pageserver_request_error: measured::CounterVec::new(
PageserverRequestLabelGroupSet {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
},
),
storage_controller_pageserver_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_passthrough_request_error: measured::CounterVec::new(
PageserverRequestLabelGroupSet {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
},
),
storage_controller_passthrough_request_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
storage_controller_database_query_error: measured::CounterVec::new(
DatabaseQueryErrorLabelGroupSet {
operation: StaticLabelSet::new(),
error_type: StaticLabelSet::new(),
},
),
storage_controller_database_query_latency: measured::HistogramVec::new(
measured::metric::histogram::Thresholds::exponential_buckets(0.1, 2.0),
),
}
}
}
@@ -119,7 +157,7 @@ pub(crate) struct ReconcileCompleteLabelGroup {
#[derive(measured::LabelGroup)]
#[label(set = HttpRequestStatusLabelGroupSet)]
pub(crate) struct HttpRequestStatusLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
pub(crate) status: StatusCode,
@@ -128,21 +166,40 @@ pub(crate) struct HttpRequestStatusLabelGroup<'a> {
#[derive(measured::LabelGroup)]
#[label(set = HttpRequestLatencyLabelGroupSet)]
pub(crate) struct HttpRequestLatencyLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
}
impl Default for HttpRequestLatencyLabelGroupSet {
fn default() -> Self {
Self {
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
}
}
}
#[derive(measured::LabelGroup, Clone)]
#[label(set = PageserverRequestLabelGroupSet)]
pub(crate) struct PageserverRequestLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) pageserver_id: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
#[label(dynamic_with = lasso::ThreadedRodeo)]
pub(crate) path: &'a str,
pub(crate) method: Method,
}
impl Default for PageserverRequestLabelGroupSet {
fn default() -> Self {
Self {
pageserver_id: lasso::ThreadedRodeo::new(),
path: lasso::ThreadedRodeo::new(),
method: StaticLabelSet::new(),
}
}
}
#[derive(measured::LabelGroup)]
#[label(set = DatabaseQueryErrorLabelGroupSet)]
pub(crate) struct DatabaseQueryErrorLabelGroup {
@@ -156,7 +213,7 @@ pub(crate) struct DatabaseQueryLatencyLabelGroup {
pub(crate) operation: DatabaseOperation,
}
#[derive(FixedCardinalityLabel, Clone, Copy)]
#[derive(FixedCardinalityLabel)]
pub(crate) enum ReconcileOutcome {
#[label(rename = "ok")]
Success,
@@ -164,7 +221,7 @@ pub(crate) enum ReconcileOutcome {
Cancel,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[derive(FixedCardinalityLabel, Clone)]
pub(crate) enum Method {
Get,
Put,
@@ -189,12 +246,11 @@ impl From<hyper::Method> for Method {
}
}
#[derive(Clone, Copy)]
pub(crate) struct StatusCode(pub(crate) hyper::http::StatusCode);
impl LabelValue for StatusCode {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0.as_u16() as i64)
v.write_int(self.0.as_u16() as u64)
}
}
@@ -212,7 +268,7 @@ impl FixedCardinalityLabel for StatusCode {
}
}
#[derive(FixedCardinalityLabel, Clone, Copy)]
#[derive(FixedCardinalityLabel)]
pub(crate) enum DatabaseErrorLabel {
Query,
Connection,

View File

@@ -79,7 +79,7 @@ pub(crate) enum DatabaseError {
Logical(String),
}
#[derive(measured::FixedCardinalityLabel, Copy, Clone)]
#[derive(measured::FixedCardinalityLabel, Clone)]
pub(crate) enum DatabaseOperation {
InsertNode,
UpdateNode,
@@ -153,7 +153,9 @@ impl Persistence {
let latency = &METRICS_REGISTRY
.metrics_group
.storage_controller_database_query_latency;
let _timer = latency.start_timer(DatabaseQueryLatencyLabelGroup { operation: op });
let _timer = latency.start_timer(DatabaseQueryLatencyLabelGroup {
operation: op.clone(),
});
let res = self.with_conn(func).await;

View File

@@ -1,7 +1,6 @@
import asyncio
import os
import time
from typing import Optional, Tuple
from typing import Tuple
import psutil
import pytest
@@ -21,30 +20,20 @@ ENTRIES_PER_TIMELINE = 10_000
CHECKPOINT_TIMEOUT_SECONDS = 60
async def run_worker_for_tenant(
env: NeonEnv, entries: int, tenant: TenantId, offset: Optional[int] = None
) -> Lsn:
if offset is None:
offset = 0
async def run_worker(env: NeonEnv, tenant_conf, entries: int) -> Tuple[TenantId, TimelineId, Lsn]:
tenant, timeline = env.neon_cli.create_tenant(conf=tenant_conf)
with env.endpoints.create_start("main", tenant_id=tenant) as ep:
conn = await ep.connect_async()
try:
await conn.execute("CREATE TABLE IF NOT EXISTS t(key serial primary key, value text)")
await conn.execute(
f"INSERT INTO t SELECT i, CONCAT('payload_', i) FROM generate_series({offset},{entries}) as i"
f"INSERT INTO t SELECT i, CONCAT('payload_', i) FROM generate_series(0,{entries}) as i"
)
finally:
await conn.close(timeout=10)
last_flush_lsn = Lsn(ep.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0])
return last_flush_lsn
async def run_worker(env: NeonEnv, tenant_conf, entries: int) -> Tuple[TenantId, TimelineId, Lsn]:
tenant, timeline = env.neon_cli.create_tenant(conf=tenant_conf)
last_flush_lsn = await run_worker_for_tenant(env, entries, tenant)
return tenant, timeline, last_flush_lsn
return tenant, timeline, last_flush_lsn
async def workload(
@@ -100,9 +89,7 @@ def assert_dirty_bytes(env, v):
def assert_dirty_bytes_nonzero(env):
dirty_bytes = get_dirty_bytes(env)
assert dirty_bytes > 0
return dirty_bytes
assert get_dirty_bytes(env) > 0
@pytest.mark.parametrize("immediate_shutdown", [True, False])
@@ -195,31 +182,6 @@ def test_idle_checkpoints(neon_env_builder: NeonEnvBuilder):
log.info("Waiting for background checkpoints...")
wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(env, 0)) # type: ignore
# The code below verifies that we do not flush on the first write
# after an idle period longer than the checkpoint timeout.
# Sit quietly for longer than the checkpoint timeout
time.sleep(CHECKPOINT_TIMEOUT_SECONDS + CHECKPOINT_TIMEOUT_SECONDS / 2)
# Restart the safekeepers and write a bit of extra data into one tenant
for sk in env.safekeepers:
sk.start()
tenant_with_extra_writes = last_flush_lsns[0][0]
asyncio.run(
run_worker_for_tenant(env, 5, tenant_with_extra_writes, offset=ENTRIES_PER_TIMELINE)
)
dirty_after_write = wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # type: ignore
# We shouldn't flush since we've just opened a new layer
waited_for = 0
while waited_for < CHECKPOINT_TIMEOUT_SECONDS // 4:
time.sleep(5)
waited_for += 5
assert get_dirty_bytes(env) >= dirty_after_write
@pytest.mark.skipif(
# We have to use at least ~100MB of data to hit the lowest limit we can configure, which is

View File

@@ -6,7 +6,7 @@ from fixtures.neon_fixtures import NeonEnv
def test_physical_replication(neon_simple_env: NeonEnv):
env = neon_simple_env
n_records = 100000
n_records = 10000
with env.endpoints.create_start(
branch_name="main",
endpoint_id="primary",
@@ -18,12 +18,15 @@ def test_physical_replication(neon_simple_env: NeonEnv):
)
time.sleep(1)
with env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") as secondary:
with primary.connect() as p_con:
with p_con.cursor() as p_cur:
with secondary.connect() as s_con:
with s_con.cursor() as s_cur:
for pk in range(n_records):
p_cur.execute("insert into t (pk) values (%s)", (pk,))
s_cur.execute(
"select * from t where pk=%s", (random.randrange(1, n_records),)
)
p_con = primary.connect()
s_con = secondary.connect()
count = 0
with p_con.cursor() as p_cur:
with s_con.cursor() as s_cur:
for pk in range(n_records):
p_cur.execute("insert into t (pk) values (%s)", (pk,))
s_cur.execute(
"select count(*) from t where pk=%s", (random.randrange(1, n_records),)
)
count += s_cur.fetchall()[0][0]
assert count > 0

View File

@@ -0,0 +1,84 @@
import asyncio
import time
from pathlib import Path
from typing import Iterator
import pytest
from fixtures.neon_fixtures import (
PSQL,
NeonProxy,
)
from fixtures.port_distributor import PortDistributor
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.response import Response
def waiting_handler(status_code: int) -> Response:
# wait more than timeout to make sure that both (two) connections are open.
# It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver.
time.sleep(2)
return Response(status=status_code)
@pytest.fixture(scope="function")
def proxy_with_rate_limit(
port_distributor: PortDistributor,
neon_binpath: Path,
httpserver_listen_address,
test_output_dir: Path,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
(host, port) = httpserver_listen_address
endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
) as proxy:
proxy.start()
yield proxy
@pytest.mark.asyncio
async def test_proxy_rate_limit(
httpserver: HTTPServer,
proxy_with_rate_limit: NeonProxy,
):
uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
# mock control plane service
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: Response(status=200)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(429)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(500)
)
psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# Limit should be 2.
# Run two queries in parallel.
f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
await proxy_with_rate_limit.find_auth_link(uri, f1)
await proxy_with_rate_limit.find_auth_link(uri, f2)
# Now limit should be 0.
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# There last query shouldn't reach the http-server.
assert httpserver.assertions == []

View File

@@ -37,7 +37,8 @@ futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown = { version = "0.14", features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] }
hex = { version = "0.4", features = ["serde"] }
hmac = { version = "0.12", default-features = false, features = ["reset"] }
hyper = { version = "0.14", features = ["full"] }
@@ -63,7 +64,7 @@ scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
sha2 = { version = "0.10", features = ["asm"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
subtle = { version = "2" }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
@@ -75,6 +76,7 @@ tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
tungstenite = { version = "0.20" }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive"] }
@@ -89,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] }
either = { version = "1" }
getrandom = { version = "0.2", default-features = false, features = ["std"] }
hashbrown = { version = "0.14", features = ["raw"] }
hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] }
indexmap = { version = "1", default-features = false, features = ["std"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits", "use_std"] }