diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2bcda7cc8e..d27713f083 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1121,10 +1121,16 @@ jobs: run: | if [[ "$GITHUB_REF_NAME" == "main" ]]; then gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false - - # TODO: move deployPreprodRegion to release (`"$GITHUB_REF_NAME" == "release"` block), once Staging support different compute tag prefixes for different regions - gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true elif [[ "$GITHUB_REF_NAME" == "release" ]]; then + gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \ + -f deployPgSniRouter=false \ + -f deployProxy=false \ + -f deployStorage=true \ + -f deployStorageBroker=true \ + -f branch=main \ + -f dockerTag=${{needs.tag.outputs.build-tag}} \ + -f deployPreprodRegion=true + gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \ -f deployPgSniRouter=false \ -f deployProxy=false \ @@ -1133,6 +1139,15 @@ jobs: -f branch=main \ -f dockerTag=${{needs.tag.outputs.build-tag}} elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then + gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \ + -f deployPgSniRouter=true \ + -f deployProxy=true \ + -f deployStorage=false \ + -f deployStorageBroker=false \ + -f branch=main \ + -f dockerTag=${{needs.tag.outputs.build-tag}} \ + -f deployPreprodRegion=true + gh workflow --repo neondatabase/aws run deploy-proxy-prod.yml --ref main \ -f deployPgSniRouter=true \ -f deployProxy=true \ diff --git a/Cargo.lock b/Cargo.lock index 9930a6c323..92befd09ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -276,7 +276,6 @@ version = "0.1.0" dependencies = [ "anyhow", "aws-config", - "aws-sdk-secretsmanager", "bytes", "camino", "clap", @@ -347,9 +346,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33cc49dcdd31c8b6e79850a179af4c367669150c7ac0135f176c61bec81a70f7" +checksum = "fa8587ae17c8e967e4b05a62d495be2fb7701bec52a97f7acfe8a29f938384c8" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -359,9 +358,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb031bff99877c26c28895766f7bb8484a05e24547e370768d6cc9db514662aa" +checksum = "b13dc54b4b49f8288532334bba8f87386a40571c47c37b1304979b556dc613c8" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -381,6 +380,29 @@ dependencies = [ "uuid", ] +[[package]] +name = "aws-sdk-iam" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8ae76026bfb1b80a6aed0bb400c1139cd9c0563e26bce1986cd021c6a968c7b" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "http 0.2.9", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-s3" version = "1.14.0" @@ -410,29 +432,6 @@ dependencies = [ "url", ] -[[package]] -name = "aws-sdk-secretsmanager" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a0b64e61e7d632d9df90a2e0f32630c68c24960cab1d27d848718180af883d3" -dependencies = [ - "aws-credential-types", - "aws-runtime", - "aws-smithy-async", - "aws-smithy-http", - "aws-smithy-json", - "aws-smithy-runtime", - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-types", - "bytes", - "fastrand 2.0.0", - "http 0.2.9", - "once_cell", - "regex-lite", - "tracing", -] - [[package]] name = "aws-sdk-sso" version = "1.12.0" @@ -502,9 +501,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.1.4" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c371c6b0ac54d4605eb6f016624fb5c7c2925d315fdf600ac1bf21b19d5f1742" +checksum = "11d6f29688a4be9895c0ba8bef861ad0c0dac5c15e9618b9b7a6c233990fc263" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", @@ -517,7 +516,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.0.0", + "http 1.1.0", "once_cell", "p256", "percent-encoding", @@ -531,9 +530,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72ee2d09cce0ef3ae526679b522835d63e75fb427aca5413cd371e490d52dcc6" +checksum = "d26ea8fa03025b2face2b3038a63525a10891e3d8829901d502e5384a0d8cd46" dependencies = [ "futures-util", "pin-project-lite", @@ -574,9 +573,9 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab56aea3cd9e1101a0a999447fb346afb680ab1406cebc44b32346e25b4117d" +checksum = "3f10fa66956f01540051b0aa7ad54574640f748f9839e843442d99b970d3aff9" dependencies = [ "aws-smithy-eventstream", "aws-smithy-runtime-api", @@ -595,18 +594,18 @@ dependencies = [ [[package]] name = "aws-smithy-json" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd3898ca6518f9215f62678870064398f00031912390efd03f1f6ef56d83aa8e" +checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" dependencies = [ "aws-smithy-types", ] [[package]] name = "aws-smithy-query" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda4b1dfc9810e35fba8a620e900522cd1bd4f9578c446e82f49d1ce41d2e9f9" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" dependencies = [ "aws-smithy-types", "urlencoding", @@ -614,9 +613,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fafdab38f40ad7816e7da5dec279400dd505160780083759f01441af1bbb10ea" +checksum = "ec81002d883e5a7fd2bb063d6fb51c4999eb55d404f4fff3dd878bf4733b9f01" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -639,14 +638,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.1.4" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c18276dd28852f34b3bf501f4f3719781f4999a51c7bff1a5c6dc8c4529adc29" +checksum = "9acb931e0adaf5132de878f1398d83f8677f90ba70f01f65ff87f6d7244be1c5" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.9", + "http 1.1.0", "pin-project-lite", "tokio", "tracing", @@ -655,9 +655,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb3e134004170d3303718baa2a4eb4ca64ee0a1c0a7041dca31b38be0fb414f3" +checksum = "abe14dceea1e70101d38fbf2a99e6a34159477c0fb95e68e05c66bd7ae4c3729" dependencies = [ "base64-simd", "bytes", @@ -678,18 +678,18 @@ dependencies = [ [[package]] name = "aws-smithy-xml" -version = "0.60.4" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8604a11b25e9ecaf32f9aa56b9fe253c5e2f606a3477f0071e96d3155a5ed218" +checksum = "872c68cf019c0e4afc5de7753c4f7288ce4b71663212771bf5e4542eb9346ca9" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "1.1.4" +version = "1.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "789bbe008e65636fe1b6dbbb374c40c8960d1232b96af5ff4aec349f9c4accf4" +checksum = "0dbf2f3da841a8930f159163175cf6a3d16ddde517c1b0fba7aa776822800f40" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -2396,9 +2396,9 @@ dependencies = [ [[package]] name = "http" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" dependencies = [ "bytes", "fnv", @@ -2498,7 +2498,7 @@ dependencies = [ "hyper", "log", "rustls 0.21.9", - "rustls-native-certs", + "rustls-native-certs 0.6.2", "tokio", "tokio-rustls 0.24.0", ] @@ -4199,6 +4199,10 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "aws-config", + "aws-sdk-iam", + "aws-sigv4", + "aws-types", "base64 0.13.1", "bstr", "bytes", @@ -4209,6 +4213,7 @@ dependencies = [ "consumption_metrics", "dashmap", "env_logger", + "fallible-iterator", "futures", "git-version", "hashbrown 0.13.2", @@ -4216,6 +4221,7 @@ dependencies = [ "hex", "hmac", "hostname", + "http 1.1.0", "humantime", "hyper", "hyper-tungstenite", @@ -4431,9 +4437,9 @@ dependencies = [ [[package]] name = "redis" -version = "0.24.0" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c580d9cbbe1d1b479e8d67cf9daf6a62c957e6846048408b80b43ac3f6af84cd" +checksum = "71d64e978fd98a0e6b105d066ba4889a7301fca65aeac850a877d8797343feeb" dependencies = [ "async-trait", "bytes", @@ -4442,15 +4448,15 @@ dependencies = [ "itoa", "percent-encoding", "pin-project-lite", - "rustls 0.21.9", - "rustls-native-certs", - "rustls-pemfile 1.0.2", - "rustls-webpki 0.101.7", + "rustls 0.22.2", + "rustls-native-certs 0.7.0", + "rustls-pemfile 2.1.1", + "rustls-pki-types", "ryu", "sha1_smol", - "socket2 0.4.9", + "socket2 0.5.5", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.25.0", "tokio-util", "url", ] @@ -4879,6 +4885,19 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.1", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.2" @@ -6159,7 +6178,7 @@ dependencies = [ "percent-encoding", "pin-project", "prost", - "rustls-native-certs", + "rustls-native-certs 0.6.2", "rustls-pemfile 1.0.2", "tokio", "tokio-rustls 0.24.0", @@ -7045,7 +7064,6 @@ dependencies = [ "aws-sigv4", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-runtime-api", "aws-smithy-types", "axum", "base64 0.21.1", diff --git a/Cargo.toml b/Cargo.toml index 65308db3c1..e9ba0c9004 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,10 +52,12 @@ async-stream = "0.3" async-trait = "0.1" aws-config = { version = "1.1.4", default-features = false, features=["rustls"] } aws-sdk-s3 = "1.14" -aws-sdk-secretsmanager = { version = "1.14.0" } +aws-sdk-iam = "1.15.0" aws-smithy-async = { version = "1.1.4", default-features = false, features=["rt-tokio"] } aws-smithy-types = "1.1.4" aws-credential-types = "1.1.4" +aws-sigv4 = { version = "1.2.0", features = ["sign-http"] } +aws-types = "1.1.7" axum = { version = "0.6.20", features = ["ws"] } base64 = "0.13.0" bincode = "1.3" @@ -76,6 +78,7 @@ either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" fail = "0.5.0" +fallible-iterator = "0.2" fs2 = "0.4.3" futures = "0.3" futures-core = "0.3" @@ -88,6 +91,7 @@ hex = "0.4" hex-literal = "0.4" hmac = "0.12.1" hostname = "0.3.1" +http = {version = "1.1.0", features = ["std"]} http-types = { version = "2", default-features = false } humantime = "2.1" humantime-serde = "1.1.1" @@ -121,7 +125,7 @@ procfs = "0.14" prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency prost = "0.11" rand = "0.8" -redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] } +redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] } reqwest-tracing = { version = "0.4.7", features = ["opentelemetry_0_20"] } diff --git a/Dockerfile.build-tools b/Dockerfile.build-tools index 3a452fec32..1ed6f87473 100644 --- a/Dockerfile.build-tools +++ b/Dockerfile.build-tools @@ -135,7 +135,7 @@ WORKDIR /home/nonroot # Rust # Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`) -ENV RUSTC_VERSION=1.76.0 +ENV RUSTC_VERSION=1.77.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \ @@ -149,7 +149,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux cargo install --git https://github.com/paritytech/cachepot && \ cargo install rustfilt && \ cargo install cargo-hakari && \ - cargo install cargo-deny && \ + cargo install cargo-deny --locked && \ cargo install cargo-hack && \ cargo install cargo-nextest && \ rm -rf /home/nonroot/.cargo/registry && \ diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 42b8480211..f1fd8637f5 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -17,6 +17,7 @@ pub fn line_in_file(path: &Path, line: &str) -> Result { .write(true) .create(true) .append(false) + .truncate(false) .open(path)?; let buf = io::BufReader::new(&file); let mut count: usize = 0; diff --git a/control_plane/attachment_service/Cargo.toml b/control_plane/attachment_service/Cargo.toml index 34882659e3..0201e0ed86 100644 --- a/control_plane/attachment_service/Cargo.toml +++ b/control_plane/attachment_service/Cargo.toml @@ -16,7 +16,6 @@ testing = [] [dependencies] anyhow.workspace = true aws-config.workspace = true -aws-sdk-secretsmanager.workspace = true bytes.workspace = true camino.workspace = true clap.workspace = true diff --git a/control_plane/attachment_service/src/heartbeater.rs b/control_plane/attachment_service/src/heartbeater.rs index e15de28920..7669680eb6 100644 --- a/control_plane/attachment_service/src/heartbeater.rs +++ b/control_plane/attachment_service/src/heartbeater.rs @@ -139,7 +139,7 @@ impl HeartbeaterTask { .with_client_retries( |client| async move { client.get_utilization().await }, &jwt_token, - 2, + 3, 3, Duration::from_secs(1), &cancel, diff --git a/control_plane/attachment_service/src/main.rs b/control_plane/attachment_service/src/main.rs index 0a925a63f6..bd8d7f5c59 100644 --- a/control_plane/attachment_service/src/main.rs +++ b/control_plane/attachment_service/src/main.rs @@ -3,7 +3,6 @@ use attachment_service::http::make_router; use attachment_service::metrics::preinitialize_metrics; use attachment_service::persistence::Persistence; use attachment_service::service::{Config, Service, MAX_UNAVAILABLE_INTERVAL_DEFAULT}; -use aws_config::{BehaviorVersion, Region}; use camino::Utf8PathBuf; use clap::Parser; use diesel::Connection; @@ -55,11 +54,31 @@ struct Cli { #[arg(long)] database_url: Option, + /// Flag to enable dev mode, which permits running without auth + #[arg(long, default_value = "false")] + dev: bool, + /// Grace period before marking unresponsive pageserver offline #[arg(long)] max_unavailable_interval: Option, } +enum StrictMode { + /// In strict mode, we will require that all secrets are loaded, i.e. security features + /// may not be implicitly turned off by omitting secrets in the environment. + Strict, + /// In dev mode, secrets are optional, and omitting a particular secret will implicitly + /// disable the auth related to it (e.g. no pageserver jwt key -> send unauthenticated + /// requests, no public key -> don't authenticate incoming requests). + Dev, +} + +impl Default for StrictMode { + fn default() -> Self { + Self::Strict + } +} + /// Secrets may either be provided on the command line (for testing), or loaded from AWS SecretManager: this /// type encapsulates the logic to decide which and do the loading. struct Secrets { @@ -70,13 +89,6 @@ struct Secrets { } impl Secrets { - const DATABASE_URL_SECRET: &'static str = "rds-neon-storage-controller-url"; - const PAGESERVER_JWT_TOKEN_SECRET: &'static str = - "neon-storage-controller-pageserver-jwt-token"; - const CONTROL_PLANE_JWT_TOKEN_SECRET: &'static str = - "neon-storage-controller-control-plane-jwt-token"; - const PUBLIC_KEY_SECRET: &'static str = "neon-storage-controller-public-key"; - const DATABASE_URL_ENV: &'static str = "DATABASE_URL"; const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN"; const CONTROL_PLANE_JWT_TOKEN_ENV: &'static str = "CONTROL_PLANE_JWT_TOKEN"; @@ -87,111 +99,41 @@ impl Secrets { /// - Environment variables if DATABASE_URL is set. /// - AWS Secrets Manager secrets async fn load(args: &Cli) -> anyhow::Result { - match &args.database_url { - Some(url) => Self::load_cli(url, args), - None => match std::env::var(Self::DATABASE_URL_ENV) { - Ok(database_url) => Self::load_env(database_url), - Err(_) => Self::load_aws_sm().await, - }, - } - } - - fn load_env(database_url: String) -> anyhow::Result { - let public_key = match std::env::var(Self::PUBLIC_KEY_ENV) { - Ok(public_key) => Some(JwtAuth::from_key(public_key).context("Loading public key")?), - Err(_) => None, - }; - Ok(Self { - database_url, - public_key, - jwt_token: std::env::var(Self::PAGESERVER_JWT_TOKEN_ENV).ok(), - control_plane_jwt_token: std::env::var(Self::CONTROL_PLANE_JWT_TOKEN_ENV).ok(), - }) - } - - async fn load_aws_sm() -> anyhow::Result { - let Ok(region) = std::env::var("AWS_REGION") else { - anyhow::bail!("AWS_REGION is not set, cannot load secrets automatically: either set this, or use CLI args to supply secrets"); - }; - let config = aws_config::defaults(BehaviorVersion::v2023_11_09()) - .region(Region::new(region.clone())) - .load() - .await; - - let asm = aws_sdk_secretsmanager::Client::new(&config); - - let Some(database_url) = asm - .get_secret_value() - .secret_id(Self::DATABASE_URL_SECRET) - .send() - .await? - .secret_string() - .map(str::to_string) + let Some(database_url) = + Self::load_secret(&args.database_url, Self::DATABASE_URL_ENV).await else { anyhow::bail!( - "Database URL secret not found at {region}/{}", - Self::DATABASE_URL_SECRET + "Database URL is not set (set `--database-url`, or `DATABASE_URL` environment)" ) }; - let jwt_token = asm - .get_secret_value() - .secret_id(Self::PAGESERVER_JWT_TOKEN_SECRET) - .send() - .await? - .secret_string() - .map(str::to_string); - if jwt_token.is_none() { - tracing::warn!("No pageserver JWT token set: this will only work if authentication is disabled on the pageserver"); - } - - let control_plane_jwt_token = asm - .get_secret_value() - .secret_id(Self::CONTROL_PLANE_JWT_TOKEN_SECRET) - .send() - .await? - .secret_string() - .map(str::to_string); - if jwt_token.is_none() { - tracing::warn!("No control plane JWT token set: this will only work if authentication is disabled on the pageserver"); - } - - let public_key = asm - .get_secret_value() - .secret_id(Self::PUBLIC_KEY_SECRET) - .send() - .await? - .secret_string() - .map(str::to_string); - let public_key = match public_key { - Some(key) => Some(JwtAuth::from_key(key)?), - None => { - tracing::warn!( - "No public key set: inccoming HTTP requests will not be authenticated" - ); - None - } + let public_key = match Self::load_secret(&args.public_key, Self::PUBLIC_KEY_ENV).await { + Some(v) => Some(JwtAuth::from_key(v).context("Loading public key")?), + None => None, }; - Ok(Self { + let this = Self { database_url, public_key, - jwt_token, - control_plane_jwt_token, - }) + jwt_token: Self::load_secret(&args.jwt_token, Self::PAGESERVER_JWT_TOKEN_ENV).await, + control_plane_jwt_token: Self::load_secret( + &args.control_plane_jwt_token, + Self::CONTROL_PLANE_JWT_TOKEN_ENV, + ) + .await, + }; + + Ok(this) } - fn load_cli(database_url: &str, args: &Cli) -> anyhow::Result { - let public_key = match &args.public_key { - None => None, - Some(key) => Some(JwtAuth::from_key(key.clone()).context("Loading public key")?), - }; - Ok(Self { - database_url: database_url.to_owned(), - public_key, - jwt_token: args.jwt_token.clone(), - control_plane_jwt_token: args.control_plane_jwt_token.clone(), - }) + async fn load_secret(cli: &Option, env_name: &str) -> Option { + if let Some(v) = cli { + Some(v.clone()) + } else if let Ok(v) = std::env::var(env_name) { + Some(v) + } else { + None + } } } @@ -247,8 +189,42 @@ async fn async_main() -> anyhow::Result<()> { args.listen ); + let strict_mode = if args.dev { + StrictMode::Dev + } else { + StrictMode::Strict + }; + let secrets = Secrets::load(&args).await?; + // Validate required secrets and arguments are provided in strict mode + match strict_mode { + StrictMode::Strict + if (secrets.public_key.is_none() + || secrets.jwt_token.is_none() + || secrets.control_plane_jwt_token.is_none()) => + { + // Production systems should always have secrets configured: if public_key was not set + // then we would implicitly disable auth. + anyhow::bail!( + "Insecure config! One or more secrets is not set. This is only permitted in `--dev` mode" + ); + } + StrictMode::Strict if args.compute_hook_url.is_none() => { + // Production systems should always have a compute hook set, to prevent falling + // back to trying to use neon_local. + anyhow::bail!( + "`--compute-hook-url` is not set: this is only permitted in `--dev` mode" + ); + } + StrictMode::Strict => { + tracing::info!("Starting in strict mode: configuration is OK.") + } + StrictMode::Dev => { + tracing::warn!("Starting in dev mode: this may be an insecure configuration.") + } + } + let config = Config { jwt_token: secrets.jwt_token, control_plane_jwt_token: secrets.control_plane_jwt_token, diff --git a/control_plane/src/storage_controller.rs b/control_plane/src/storage_controller.rs index e7697ecac8..7f2b973391 100644 --- a/control_plane/src/storage_controller.rs +++ b/control_plane/src/storage_controller.rs @@ -279,6 +279,7 @@ impl StorageController { &self.listen, "-p", self.path.as_ref(), + "--dev", "--database-url", &database_url, "--max-unavailable-interval", diff --git a/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs b/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs index e87ca27e90..41afcea6c2 100644 --- a/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs +++ b/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs @@ -1,5 +1,6 @@ use anyhow::*; use clap::{value_parser, Arg, ArgMatches, Command}; +use postgres::Client; use std::{path::PathBuf, str::FromStr}; use wal_craft::*; @@ -8,8 +9,8 @@ fn main() -> Result<()> { .init(); let arg_matches = cli().get_matches(); - let wal_craft = |arg_matches: &ArgMatches, client| { - let (intermediate_lsns, end_of_wal_lsn) = match arg_matches + let wal_craft = |arg_matches: &ArgMatches, client: &mut Client| { + let intermediate_lsns = match arg_matches .get_one::("type") .map(|s| s.as_str()) .context("'type' is required")? @@ -25,6 +26,7 @@ fn main() -> Result<()> { LastWalRecordCrossingSegment::NAME => LastWalRecordCrossingSegment::craft(client)?, a => panic!("Unknown --type argument: {a}"), }; + let end_of_wal_lsn = client.pg_current_wal_insert_lsn()?; for lsn in intermediate_lsns { println!("intermediate_lsn = {lsn}"); } diff --git a/libs/postgres_ffi/wal_craft/src/lib.rs b/libs/postgres_ffi/wal_craft/src/lib.rs index 281a180e3b..23786e3b08 100644 --- a/libs/postgres_ffi/wal_craft/src/lib.rs +++ b/libs/postgres_ffi/wal_craft/src/lib.rs @@ -5,7 +5,6 @@ use postgres::types::PgLsn; use postgres::Client; use postgres_ffi::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ}; use postgres_ffi::{XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD}; -use std::cmp::Ordering; use std::path::{Path, PathBuf}; use std::process::Command; use std::time::{Duration, Instant}; @@ -232,59 +231,52 @@ pub fn ensure_server_config(client: &mut impl postgres::GenericClient) -> anyhow pub trait Crafter { const NAME: &'static str; - /// Generates WAL using the client `client`. Returns a pair of: - /// * A vector of some valid "interesting" intermediate LSNs which one may start reading from. - /// May include or exclude Lsn(0) and the end-of-wal. - /// * The expected end-of-wal LSN. - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)>; + /// Generates WAL using the client `client`. Returns a vector of some valid + /// "interesting" intermediate LSNs which one may start reading from. + /// test_end_of_wal uses this to check various starting points. + /// + /// Note that postgres is generally keen about writing some WAL. While we + /// try to disable it (autovacuum, big wal_writer_delay, etc) it is always + /// possible, e.g. xl_running_xacts are dumped each 15s. So checks about + /// stable WAL end would be flaky unless postgres is shut down. For this + /// reason returning potential end of WAL here is pointless. Most of the + /// time this doesn't happen though, so it is reasonable to create needed + /// WAL structure and immediately kill postgres like test_end_of_wal does. + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result>; } +/// Wraps some WAL craft function, providing current LSN to it before the +/// insertion and flushing WAL afterwards. Also pushes initial LSN to the +/// result. fn craft_internal( client: &mut C, - f: impl Fn(&mut C, PgLsn) -> anyhow::Result<(Vec, Option)>, -) -> anyhow::Result<(Vec, PgLsn)> { + f: impl Fn(&mut C, PgLsn) -> anyhow::Result>, +) -> anyhow::Result> { ensure_server_config(client)?; let initial_lsn = client.pg_current_wal_insert_lsn()?; info!("LSN initial = {}", initial_lsn); - let (mut intermediate_lsns, last_lsn) = f(client, initial_lsn)?; - let last_lsn = match last_lsn { - None => client.pg_current_wal_insert_lsn()?, - Some(last_lsn) => { - let insert_lsn = client.pg_current_wal_insert_lsn()?; - match last_lsn.cmp(&insert_lsn) { - Ordering::Less => bail!( - "Some records were inserted after the crafted WAL: {} vs {}", - last_lsn, - insert_lsn - ), - Ordering::Equal => last_lsn, - Ordering::Greater => bail!("Reported LSN is greater than insert_lsn"), - } - } - }; + let mut intermediate_lsns = f(client, initial_lsn)?; if !intermediate_lsns.starts_with(&[initial_lsn]) { intermediate_lsns.insert(0, initial_lsn); } // Some records may be not flushed, e.g. non-transactional logical messages. + // + // Note: this is broken if pg_current_wal_insert_lsn is at page boundary + // because pg_current_wal_insert_lsn skips page headers. client.execute("select neon_xlogflush(pg_current_wal_insert_lsn())", &[])?; - match last_lsn.cmp(&client.pg_current_wal_flush_lsn()?) { - Ordering::Less => bail!("Some records were flushed after the crafted WAL"), - Ordering::Equal => {} - Ordering::Greater => bail!("Reported LSN is greater than flush_lsn"), - } - Ok((intermediate_lsns, last_lsn)) + Ok(intermediate_lsns) } pub struct Simple; impl Crafter for Simple { const NAME: &'static str = "simple"; - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)> { + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result> { craft_internal(client, |client, _| { client.execute("CREATE table t(x int)", &[])?; - Ok((Vec::new(), None)) + Ok(Vec::new()) }) } } @@ -292,29 +284,36 @@ impl Crafter for Simple { pub struct LastWalRecordXlogSwitch; impl Crafter for LastWalRecordXlogSwitch { const NAME: &'static str = "last_wal_record_xlog_switch"; - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)> { - // Do not use generate_internal because here we end up with flush_lsn exactly on + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result> { + // Do not use craft_internal because here we end up with flush_lsn exactly on // the segment boundary and insert_lsn after the initial page header, which is unusual. ensure_server_config(client)?; client.execute("CREATE table t(x int)", &[])?; let before_xlog_switch = client.pg_current_wal_insert_lsn()?; - let after_xlog_switch: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0); - let next_segment = PgLsn::from(0x0200_0000); + // pg_switch_wal returns end of last record of the switched segment, + // i.e. end of SWITCH itself. + let xlog_switch_record_end: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0); + let before_xlog_switch_u64 = u64::from(before_xlog_switch); + let next_segment = PgLsn::from( + before_xlog_switch_u64 - (before_xlog_switch_u64 % WAL_SEGMENT_SIZE as u64) + + WAL_SEGMENT_SIZE as u64, + ); ensure!( - after_xlog_switch <= next_segment, - "XLOG_SWITCH message ended after the expected segment boundary: {} > {}", - after_xlog_switch, + xlog_switch_record_end <= next_segment, + "XLOG_SWITCH record ended after the expected segment boundary: {} > {}", + xlog_switch_record_end, next_segment ); - Ok((vec![before_xlog_switch, after_xlog_switch], next_segment)) + Ok(vec![before_xlog_switch, xlog_switch_record_end]) } } pub struct LastWalRecordXlogSwitchEndsOnPageBoundary; +/// Craft xlog SWITCH record ending at page boundary. impl Crafter for LastWalRecordXlogSwitchEndsOnPageBoundary { const NAME: &'static str = "last_wal_record_xlog_switch_ends_on_page_boundary"; - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)> { + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result> { // Do not use generate_internal because here we end up with flush_lsn exactly on // the segment boundary and insert_lsn after the initial page header, which is unusual. ensure_server_config(client)?; @@ -361,28 +360,29 @@ impl Crafter for LastWalRecordXlogSwitchEndsOnPageBoundary { // Emit the XLOG_SWITCH let before_xlog_switch = client.pg_current_wal_insert_lsn()?; - let after_xlog_switch: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0); + let xlog_switch_record_end: PgLsn = client.query_one("SELECT pg_switch_wal()", &[])?.get(0); let next_segment = PgLsn::from(0x0200_0000); ensure!( - after_xlog_switch < next_segment, - "XLOG_SWITCH message ended on or after the expected segment boundary: {} > {}", - after_xlog_switch, + xlog_switch_record_end < next_segment, + "XLOG_SWITCH record ended on or after the expected segment boundary: {} > {}", + xlog_switch_record_end, next_segment ); ensure!( - u64::from(after_xlog_switch) as usize % XLOG_BLCKSZ == XLOG_SIZE_OF_XLOG_SHORT_PHD, + u64::from(xlog_switch_record_end) as usize % XLOG_BLCKSZ == XLOG_SIZE_OF_XLOG_SHORT_PHD, "XLOG_SWITCH message ended not on page boundary: {}, offset = {}", - after_xlog_switch, - u64::from(after_xlog_switch) as usize % XLOG_BLCKSZ + xlog_switch_record_end, + u64::from(xlog_switch_record_end) as usize % XLOG_BLCKSZ ); - Ok((vec![before_xlog_switch, after_xlog_switch], next_segment)) + Ok(vec![before_xlog_switch, xlog_switch_record_end]) } } -fn craft_single_logical_message( +/// Write ~16MB logical message; it should cross WAL segment. +fn craft_seg_size_logical_message( client: &mut impl postgres::GenericClient, transactional: bool, -) -> anyhow::Result<(Vec, PgLsn)> { +) -> anyhow::Result> { craft_internal(client, |client, initial_lsn| { ensure!( initial_lsn < PgLsn::from(0x0200_0000 - 1024 * 1024), @@ -405,34 +405,24 @@ fn craft_single_logical_message( "Logical message crossed two segments" ); - if transactional { - // Transactional logical messages are part of a transaction, so the one above is - // followed by a small COMMIT record. - - let after_message_lsn = client.pg_current_wal_insert_lsn()?; - ensure!( - message_lsn < after_message_lsn, - "No record found after the emitted message" - ); - Ok((vec![message_lsn], Some(after_message_lsn))) - } else { - Ok((Vec::new(), Some(message_lsn))) - } + Ok(vec![message_lsn]) }) } pub struct WalRecordCrossingSegmentFollowedBySmallOne; impl Crafter for WalRecordCrossingSegmentFollowedBySmallOne { const NAME: &'static str = "wal_record_crossing_segment_followed_by_small_one"; - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)> { - craft_single_logical_message(client, true) + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result> { + // Transactional message crossing WAL segment will be followed by small + // commit record. + craft_seg_size_logical_message(client, true) } } pub struct LastWalRecordCrossingSegment; impl Crafter for LastWalRecordCrossingSegment { const NAME: &'static str = "last_wal_record_crossing_segment"; - fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result<(Vec, PgLsn)> { - craft_single_logical_message(client, false) + fn craft(client: &mut impl postgres::GenericClient) -> anyhow::Result> { + craft_seg_size_logical_message(client, false) } } diff --git a/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs b/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs index 6ff4c563b2..496458b2e4 100644 --- a/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs +++ b/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs @@ -11,13 +11,15 @@ use utils::const_assert; use utils::lsn::Lsn; fn init_logging() { - let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or( - format!("crate=info,postgres_ffi::{PG_MAJORVERSION}::xlog_utils=trace"), - )) + let _ = env_logger::Builder::from_env(env_logger::Env::default().default_filter_or(format!( + "crate=info,postgres_ffi::{PG_MAJORVERSION}::xlog_utils=trace" + ))) .is_test(true) .try_init(); } +/// Test that find_end_of_wal returns the same results as pg_dump on various +/// WALs created by Crafter. fn test_end_of_wal(test_name: &str) { use crate::*; @@ -38,13 +40,13 @@ fn test_end_of_wal(test_name: &str) { } cfg.initdb().unwrap(); let srv = cfg.start_server().unwrap(); - let (intermediate_lsns, expected_end_of_wal_partial) = - C::craft(&mut srv.connect_with_timeout().unwrap()).unwrap(); + let intermediate_lsns = C::craft(&mut srv.connect_with_timeout().unwrap()).unwrap(); let intermediate_lsns: Vec = intermediate_lsns .iter() .map(|&lsn| u64::from(lsn).into()) .collect(); - let expected_end_of_wal: Lsn = u64::from(expected_end_of_wal_partial).into(); + // Kill postgres. Note that it might have inserted to WAL something after + // 'craft' did its job. srv.kill(); // Check find_end_of_wal on the initial WAL @@ -56,7 +58,7 @@ fn test_end_of_wal(test_name: &str) { .filter(|fname| IsXLogFileName(fname)) .max() .unwrap(); - check_pg_waldump_end_of_wal(&cfg, &last_segment, expected_end_of_wal); + let expected_end_of_wal = find_pg_waldump_end_of_wal(&cfg, &last_segment); for start_lsn in intermediate_lsns .iter() .chain(std::iter::once(&expected_end_of_wal)) @@ -91,11 +93,7 @@ fn test_end_of_wal(test_name: &str) { } } -fn check_pg_waldump_end_of_wal( - cfg: &crate::Conf, - last_segment: &str, - expected_end_of_wal: Lsn, -) { +fn find_pg_waldump_end_of_wal(cfg: &crate::Conf, last_segment: &str) -> Lsn { // Get the actual end of WAL by pg_waldump let waldump_output = cfg .pg_waldump("000000010000000000000001", last_segment) @@ -113,11 +111,8 @@ fn check_pg_waldump_end_of_wal( } }; let waldump_wal_end = Lsn::from_str(caps.get(1).unwrap().as_str()).unwrap(); - info!( - "waldump erred on {}, expected wal end at {}", - waldump_wal_end, expected_end_of_wal - ); - assert_eq!(waldump_wal_end, expected_end_of_wal); + info!("waldump erred on {}", waldump_wal_end); + waldump_wal_end } fn check_end_of_wal( @@ -210,9 +205,9 @@ pub fn test_update_next_xid() { #[test] pub fn test_encode_logical_message() { let expected = [ - 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 170, 34, 166, 227, 255, - 38, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 112, 114, - 101, 102, 105, 120, 0, 109, 101, 115, 115, 97, 103, 101, + 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 21, 0, 0, 170, 34, 166, 227, 255, 38, + 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 112, 114, 101, 102, + 105, 120, 0, 109, 101, 115, 115, 97, 103, 101, ]; let actual = encode_logical_message("prefix", "message"); assert_eq!(expected, actual[..]); diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 313d8226b1..8cad863731 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -198,6 +198,7 @@ impl LocalFs { fs::OpenOptions::new() .write(true) .create(true) + .truncate(true) .open(&temp_file_path) .await .with_context(|| { diff --git a/libs/utils/src/lock_file.rs b/libs/utils/src/lock_file.rs index 987b9d9ad2..59c66ca757 100644 --- a/libs/utils/src/lock_file.rs +++ b/libs/utils/src/lock_file.rs @@ -63,6 +63,7 @@ impl UnwrittenLockFile { pub fn create_exclusive(lock_file_path: &Utf8Path) -> anyhow::Result { let lock_file = fs::OpenOptions::new() .create(true) // O_CREAT + .truncate(true) .write(true) .open(lock_file_path) .context("open lock file")?; diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 1fd7c775d5..f4a231f217 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -15,9 +15,9 @@ use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp}; use pageserver::control_plane_client::ControlPlaneClient; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; -use pageserver::task_mgr::WALRECEIVER_RUNTIME; use pageserver::tenant::{secondary, TenantSharedResources}; use remote_storage::GenericRemoteStorage; +use tokio::signal::unix::SignalKind; use tokio::time::Instant; use tracing::*; @@ -28,7 +28,7 @@ use pageserver::{ deletion_queue::DeletionQueue, http, page_cache, page_service, task_mgr, task_mgr::TaskKind, - task_mgr::{BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME}, + task_mgr::THE_RUNTIME, tenant::mgr, virtual_file, }; @@ -323,7 +323,7 @@ fn start_pageserver( // Launch broker client // The storage_broker::connect call needs to happen inside a tokio runtime thread. - let broker_client = WALRECEIVER_RUNTIME + let broker_client = THE_RUNTIME .block_on(async { // Note: we do not attempt connecting here (but validate endpoints sanity). storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval) @@ -391,7 +391,7 @@ fn start_pageserver( conf, ); if let Some(deletion_workers) = deletion_workers { - deletion_workers.spawn_with(BACKGROUND_RUNTIME.handle()); + deletion_workers.spawn_with(THE_RUNTIME.handle()); } // Up to this point no significant I/O has been done: this should have been fast. Record @@ -423,7 +423,7 @@ fn start_pageserver( // Scan the local 'tenants/' directory and start loading the tenants let deletion_queue_client = deletion_queue.new_client(); - let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr( + let tenant_manager = THE_RUNTIME.block_on(mgr::init_tenant_mgr( conf, TenantSharedResources { broker_client: broker_client.clone(), @@ -435,7 +435,7 @@ fn start_pageserver( ))?; let tenant_manager = Arc::new(tenant_manager); - BACKGROUND_RUNTIME.spawn({ + THE_RUNTIME.spawn({ let shutdown_pageserver = shutdown_pageserver.clone(); let drive_init = async move { // NOTE: unlike many futures in pageserver, this one is cancellation-safe @@ -545,7 +545,7 @@ fn start_pageserver( // Start up the service to handle HTTP mgmt API request. We created the // listener earlier already. { - let _rt_guard = MGMT_REQUEST_RUNTIME.enter(); + let _rt_guard = THE_RUNTIME.enter(); let router_state = Arc::new( http::routes::State::new( @@ -569,7 +569,6 @@ fn start_pageserver( .with_graceful_shutdown(task_mgr::shutdown_watcher()); task_mgr::spawn( - MGMT_REQUEST_RUNTIME.handle(), TaskKind::HttpEndpointListener, None, None, @@ -594,7 +593,6 @@ fn start_pageserver( let local_disk_storage = conf.workdir.join("last_consumption_metrics.json"); task_mgr::spawn( - crate::BACKGROUND_RUNTIME.handle(), TaskKind::MetricsCollection, None, None, @@ -615,6 +613,7 @@ fn start_pageserver( pageserver::consumption_metrics::collect_metrics( metric_collection_endpoint, + &conf.metric_collection_bucket, conf.metric_collection_interval, conf.cached_metric_collection_interval, conf.synthetic_size_calculation_interval, @@ -642,7 +641,6 @@ fn start_pageserver( DownloadBehavior::Error, ); task_mgr::spawn( - COMPUTE_REQUEST_RUNTIME.handle(), TaskKind::LibpqEndpointListener, None, None, @@ -666,42 +664,37 @@ fn start_pageserver( let mut shutdown_pageserver = Some(shutdown_pageserver.drop_guard()); // All started up! Now just sit and wait for shutdown signal. - { - use signal_hook::consts::*; - let signal_handler = BACKGROUND_RUNTIME.spawn_blocking(move || { - let mut signals = - signal_hook::iterator::Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap(); - return signals - .forever() - .next() - .expect("forever() never returns None unless explicitly closed"); - }); - let signal = BACKGROUND_RUNTIME - .block_on(signal_handler) - .expect("join error"); - match signal { - SIGQUIT => { - info!("Got signal {signal}. Terminating in immediate shutdown mode",); - std::process::exit(111); - } - SIGINT | SIGTERM => { - info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",); - // This cancels the `shutdown_pageserver` cancellation tree. - // Right now that tree doesn't reach very far, and `task_mgr` is used instead. - // The plan is to change that over time. - shutdown_pageserver.take(); - let bg_remote_storage = remote_storage.clone(); - let bg_deletion_queue = deletion_queue.clone(); - BACKGROUND_RUNTIME.block_on(pageserver::shutdown_pageserver( - &tenant_manager, - bg_remote_storage.map(|_| bg_deletion_queue), - 0, - )); - unreachable!() - } - _ => unreachable!(), - } + { + THE_RUNTIME.block_on(async move { + let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()).unwrap(); + let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()).unwrap(); + let mut sigquit = tokio::signal::unix::signal(SignalKind::quit()).unwrap(); + let signal = tokio::select! { + _ = sigquit.recv() => { + info!("Got signal SIGQUIT. Terminating in immediate shutdown mode",); + std::process::exit(111); + } + _ = sigint.recv() => { "SIGINT" }, + _ = sigterm.recv() => { "SIGTERM" }, + }; + + info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",); + + // This cancels the `shutdown_pageserver` cancellation tree. + // Right now that tree doesn't reach very far, and `task_mgr` is used instead. + // The plan is to change that over time. + shutdown_pageserver.take(); + let bg_remote_storage = remote_storage.clone(); + let bg_deletion_queue = deletion_queue.clone(); + pageserver::shutdown_pageserver( + &tenant_manager, + bg_remote_storage.map(|_| bg_deletion_queue), + 0, + ) + .await; + unreachable!() + }) } } diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 89fda958ff..8b4e08e0d1 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -230,6 +230,7 @@ pub struct PageServerConf { // How often to send unchanged cached metrics to the metrics endpoint. pub cached_metric_collection_interval: Duration, pub metric_collection_endpoint: Option, + pub metric_collection_bucket: Option, pub synthetic_size_calculation_interval: Duration, pub disk_usage_based_eviction: Option, @@ -368,6 +369,7 @@ struct PageServerConfigBuilder { cached_metric_collection_interval: BuilderValue, metric_collection_endpoint: BuilderValue>, synthetic_size_calculation_interval: BuilderValue, + metric_collection_bucket: BuilderValue>, disk_usage_based_eviction: BuilderValue>, @@ -448,6 +450,8 @@ impl PageServerConfigBuilder { .expect("cannot parse default synthetic size calculation interval")), metric_collection_endpoint: Set(DEFAULT_METRIC_COLLECTION_ENDPOINT), + metric_collection_bucket: Set(None), + disk_usage_based_eviction: Set(None), test_remote_failures: Set(0), @@ -575,6 +579,13 @@ impl PageServerConfigBuilder { self.metric_collection_endpoint = BuilderValue::Set(metric_collection_endpoint) } + pub fn metric_collection_bucket( + &mut self, + metric_collection_bucket: Option, + ) { + self.metric_collection_bucket = BuilderValue::Set(metric_collection_bucket) + } + pub fn synthetic_size_calculation_interval( &mut self, synthetic_size_calculation_interval: Duration, @@ -682,6 +693,7 @@ impl PageServerConfigBuilder { metric_collection_interval, cached_metric_collection_interval, metric_collection_endpoint, + metric_collection_bucket, synthetic_size_calculation_interval, disk_usage_based_eviction, test_remote_failures, @@ -929,6 +941,9 @@ impl PageServerConf { let endpoint = parse_toml_string(key, item)?.parse().context("failed to parse metric_collection_endpoint")?; builder.metric_collection_endpoint(Some(endpoint)); }, + "metric_collection_bucket" => { + builder.metric_collection_bucket(RemoteStorageConfig::from_toml(item)?) + } "synthetic_size_calculation_interval" => builder.synthetic_size_calculation_interval(parse_toml_duration(key, item)?), "test_remote_failures" => builder.test_remote_failures(parse_toml_u64(key, item)?), @@ -1043,6 +1058,7 @@ impl PageServerConf { metric_collection_interval: Duration::from_secs(60), cached_metric_collection_interval: Duration::from_secs(60 * 60), metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT, + metric_collection_bucket: None, synthetic_size_calculation_interval: Duration::from_secs(60), disk_usage_based_eviction: None, test_remote_failures: 0, @@ -1273,6 +1289,7 @@ background_task_maximum_delay = '334 s' defaults::DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL )?, metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT, + metric_collection_bucket: None, synthetic_size_calculation_interval: humantime::parse_duration( defaults::DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL )?, @@ -1346,6 +1363,7 @@ background_task_maximum_delay = '334 s' metric_collection_interval: Duration::from_secs(222), cached_metric_collection_interval: Duration::from_secs(22200), metric_collection_endpoint: Some(Url::parse("http://localhost:80/metrics")?), + metric_collection_bucket: None, synthetic_size_calculation_interval: Duration::from_secs(333), disk_usage_based_eviction: None, test_remote_failures: 0, diff --git a/pageserver/src/consumption_metrics.rs b/pageserver/src/consumption_metrics.rs index c7f9d596c6..c82be8c581 100644 --- a/pageserver/src/consumption_metrics.rs +++ b/pageserver/src/consumption_metrics.rs @@ -1,12 +1,13 @@ //! Periodically collect consumption metrics for all active tenants //! and push them to a HTTP endpoint. use crate::context::{DownloadBehavior, RequestContext}; -use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME}; +use crate::task_mgr::{self, TaskKind}; use crate::tenant::tasks::BackgroundLoopKind; use crate::tenant::{mgr, LogicalSizeCalculationCause, PageReconstructError, Tenant}; use camino::Utf8PathBuf; use consumption_metrics::EventType; use pageserver_api::models::TenantState; +use remote_storage::{GenericRemoteStorage, RemoteStorageConfig}; use reqwest::Url; use std::collections::HashMap; use std::sync::Arc; @@ -41,6 +42,7 @@ type Cache = HashMap; #[allow(clippy::too_many_arguments)] pub async fn collect_metrics( metric_collection_endpoint: &Url, + metric_collection_bucket: &Option, metric_collection_interval: Duration, _cached_metric_collection_interval: Duration, synthetic_size_calculation_interval: Duration, @@ -59,7 +61,6 @@ pub async fn collect_metrics( let worker_ctx = ctx.detached_child(TaskKind::CalculateSyntheticSize, DownloadBehavior::Download); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::CalculateSyntheticSize, None, None, @@ -94,6 +95,20 @@ pub async fn collect_metrics( .build() .expect("Failed to create http client with timeout"); + let bucket_client = if let Some(bucket_config) = metric_collection_bucket { + match GenericRemoteStorage::from_config(bucket_config) { + Ok(client) => Some(client), + Err(e) => { + // Non-fatal error: if we were given an invalid config, we will proceed + // with sending metrics over the network, but not to S3. + tracing::warn!("Invalid configuration for metric_collection_bucket: {e}"); + None + } + } + } else { + None + }; + let node_id = node_id.to_string(); loop { @@ -118,10 +133,18 @@ pub async fn collect_metrics( tracing::error!("failed to persist metrics to {path:?}: {e:#}"); } } + + if let Some(bucket_client) = &bucket_client { + let res = + upload::upload_metrics_bucket(bucket_client, &cancel, &node_id, &metrics).await; + if let Err(e) = res { + tracing::error!("failed to upload to S3: {e:#}"); + } + } }; let upload = async { - let res = upload::upload_metrics( + let res = upload::upload_metrics_http( &client, metric_collection_endpoint, &cancel, @@ -132,7 +155,7 @@ pub async fn collect_metrics( .await; if let Err(e) = res { // serialization error which should never happen - tracing::error!("failed to upload due to {e:#}"); + tracing::error!("failed to upload via HTTP due to {e:#}"); } }; diff --git a/pageserver/src/consumption_metrics/upload.rs b/pageserver/src/consumption_metrics/upload.rs index 6b840a3136..4e8283c3e4 100644 --- a/pageserver/src/consumption_metrics/upload.rs +++ b/pageserver/src/consumption_metrics/upload.rs @@ -1,4 +1,9 @@ +use std::time::SystemTime; + +use chrono::{DateTime, Utc}; use consumption_metrics::{Event, EventChunk, IdempotencyKey, CHUNK_SIZE}; +use remote_storage::{GenericRemoteStorage, RemotePath}; +use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; use tracing::Instrument; @@ -13,8 +18,9 @@ struct Ids { pub(super) timeline_id: Option, } +/// Serialize and write metrics to an HTTP endpoint #[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))] -pub(super) async fn upload_metrics( +pub(super) async fn upload_metrics_http( client: &reqwest::Client, metric_collection_endpoint: &reqwest::Url, cancel: &CancellationToken, @@ -74,6 +80,60 @@ pub(super) async fn upload_metrics( Ok(()) } +/// Serialize and write metrics to a remote storage object +#[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))] +pub(super) async fn upload_metrics_bucket( + client: &GenericRemoteStorage, + cancel: &CancellationToken, + node_id: &str, + metrics: &[RawMetric], +) -> anyhow::Result<()> { + if metrics.is_empty() { + // Skip uploads if we have no metrics, so that readers don't have to handle the edge case + // of an empty object. + return Ok(()); + } + + // Compose object path + let datetime: DateTime = SystemTime::now().into(); + let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/%H:%M:%SZ"); + let path = RemotePath::from_string(&format!("{ts_prefix}_{node_id}.ndjson.gz"))?; + + // Set up a gzip writer into a buffer + let mut compressed_bytes: Vec = Vec::new(); + let compressed_writer = std::io::Cursor::new(&mut compressed_bytes); + let mut gzip_writer = async_compression::tokio::write::GzipEncoder::new(compressed_writer); + + // Serialize and write into compressed buffer + let started_at = std::time::Instant::now(); + for res in serialize_in_chunks(CHUNK_SIZE, metrics, node_id) { + let (_chunk, body) = res?; + gzip_writer.write_all(&body).await?; + } + gzip_writer.flush().await?; + gzip_writer.shutdown().await?; + let compressed_length = compressed_bytes.len(); + + // Write to remote storage + client + .upload_storage_object( + futures::stream::once(futures::future::ready(Ok(compressed_bytes.into()))), + compressed_length, + &path, + cancel, + ) + .await?; + let elapsed = started_at.elapsed(); + + tracing::info!( + compressed_length, + elapsed_ms = elapsed.as_millis(), + "write metrics bucket at {path}", + ); + + Ok(()) +} + // The return type is quite ugly, but we gain testability in isolation fn serialize_in_chunks<'a, F>( chunk_size: usize, diff --git a/pageserver/src/control_plane_client.rs b/pageserver/src/control_plane_client.rs index 42c800822b..55d80c2966 100644 --- a/pageserver/src/control_plane_client.rs +++ b/pageserver/src/control_plane_client.rs @@ -173,8 +173,6 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient { register, }; - fail::fail_point!("control-plane-client-re-attach"); - let response: ReAttachResponse = self.retry_http_forever(&re_attach_path, request).await?; tracing::info!( "Received re-attach response with {} tenants", @@ -210,7 +208,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient { .collect(), }; - fail::fail_point!("control-plane-client-validate"); + crate::tenant::pausable_failpoint!("control-plane-client-validate"); let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?; diff --git a/pageserver/src/disk_usage_eviction_task.rs b/pageserver/src/disk_usage_eviction_task.rs index 92c1475aef..6b68acd1c7 100644 --- a/pageserver/src/disk_usage_eviction_task.rs +++ b/pageserver/src/disk_usage_eviction_task.rs @@ -59,7 +59,7 @@ use utils::{completion, id::TimelineId}; use crate::{ config::PageServerConf, metrics::disk_usage_based_eviction::METRICS, - task_mgr::{self, TaskKind, BACKGROUND_RUNTIME}, + task_mgr::{self, TaskKind}, tenant::{ self, mgr::TenantManager, @@ -202,7 +202,6 @@ pub fn launch_disk_usage_global_eviction_task( info!("launching disk usage based eviction task"); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::DiskUsageEviction, None, None, diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index bcee5613b6..36c3513933 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -699,6 +699,14 @@ pub static STARTUP_IS_LOADING: Lazy = Lazy::new(|| { .expect("Failed to register pageserver_startup_is_loading") }); +pub(crate) static TIMELINE_EPHEMERAL_BYTES: Lazy = Lazy::new(|| { + register_uint_gauge!( + "pageserver_timeline_ephemeral_bytes", + "Total number of bytes in ephemeral layers, summed for all timelines. Approximate, lazily updated." + ) + .expect("Failed to register metric") +}); + /// Metrics related to the lifecycle of a [`crate::tenant::Tenant`] object: things /// like how long it took to load. /// diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index f3ceb7d3e6..fa1a0f535b 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -180,7 +180,6 @@ pub async fn libpq_listener_main( // only deal with a particular timeline, but we don't know which one // yet. task_mgr::spawn( - &tokio::runtime::Handle::current(), TaskKind::PageRequestHandler, None, None, diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 69e163effa..2d97389982 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -98,42 +98,22 @@ use utils::id::TimelineId; // other operations, if the upload tasks e.g. get blocked on locks. It shouldn't // happen, but still. // -pub static COMPUTE_REQUEST_RUNTIME: Lazy = Lazy::new(|| { - tokio::runtime::Builder::new_multi_thread() - .thread_name("compute request worker") - .enable_all() - .build() - .expect("Failed to create compute request runtime") -}); -pub static MGMT_REQUEST_RUNTIME: Lazy = Lazy::new(|| { +/// The single tokio runtime used by all pageserver code. +/// In the past, we had multiple runtimes, and in the future we should weed out +/// remaining references to this global field and rely on ambient runtime instead, +/// i.e., use `tokio::spawn` instead of `THE_RUNTIME.spawn()`, etc. +pub static THE_RUNTIME: Lazy = Lazy::new(|| { tokio::runtime::Builder::new_multi_thread() - .thread_name("mgmt request worker") - .enable_all() - .build() - .expect("Failed to create mgmt request runtime") -}); - -pub static WALRECEIVER_RUNTIME: Lazy = Lazy::new(|| { - tokio::runtime::Builder::new_multi_thread() - .thread_name("walreceiver worker") - .enable_all() - .build() - .expect("Failed to create walreceiver runtime") -}); - -pub static BACKGROUND_RUNTIME: Lazy = Lazy::new(|| { - tokio::runtime::Builder::new_multi_thread() - .thread_name("background op worker") // if you change the number of worker threads please change the constant below .enable_all() .build() .expect("Failed to create background op runtime") }); -pub(crate) static BACKGROUND_RUNTIME_WORKER_THREADS: Lazy = Lazy::new(|| { +pub(crate) static THE_RUNTIME_WORKER_THREADS: Lazy = Lazy::new(|| { // force init and thus panics - let _ = BACKGROUND_RUNTIME.handle(); + let _ = THE_RUNTIME.handle(); // replicates tokio-1.28.1::loom::sys::num_cpus which is not available publicly // tokio would had already panicked for parsing errors or NotUnicode // @@ -325,7 +305,6 @@ struct PageServerTask { /// Note: if shutdown_process_on_error is set to true failure /// of the task will lead to shutdown of entire process pub fn spawn( - runtime: &tokio::runtime::Handle, kind: TaskKind, tenant_shard_id: Option, timeline_id: Option, @@ -354,7 +333,7 @@ where let task_name = name.to_string(); let task_cloned = Arc::clone(&task); - let join_handle = runtime.spawn(task_wrapper( + let join_handle = THE_RUNTIME.spawn(task_wrapper( task_name, task_id, task_cloned, diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index fe48741a89..7bd85b6fd5 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -144,6 +144,7 @@ macro_rules! pausable_failpoint { } }; } +pub(crate) use pausable_failpoint; pub mod blob_io; pub mod block_io; @@ -661,7 +662,6 @@ impl Tenant { let tenant_clone = Arc::clone(&tenant); let ctx = ctx.detached_child(TaskKind::Attach, DownloadBehavior::Warn); task_mgr::spawn( - &tokio::runtime::Handle::current(), TaskKind::Attach, Some(tenant_shard_id), None, diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index cab60c3111..3866136dbd 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -111,6 +111,7 @@ async fn create_local_delete_mark( let _ = std::fs::OpenOptions::new() .write(true) .create(true) + .truncate(true) .open(&marker_path) .with_context(|| format!("could not create delete marker file {marker_path:?}"))?; @@ -481,7 +482,6 @@ impl DeleteTenantFlow { let tenant_shard_id = tenant.tenant_shard_id; task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), TaskKind::TimelineDeletionWorker, Some(tenant_shard_id), None, diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 97a505ded9..34ca43a173 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -1850,7 +1850,6 @@ impl TenantManager { let task_tenant_id = None; task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), TaskKind::MgmtRequest, task_tenant_id, None, @@ -2816,15 +2815,12 @@ pub(crate) fn immediate_gc( // TODO: spawning is redundant now, need to hold the gate task_mgr::spawn( - &tokio::runtime::Handle::current(), TaskKind::GarbageCollector, Some(tenant_shard_id), Some(timeline_id), &format!("timeline_gc_handler garbage collection run for tenant {tenant_shard_id} timeline {timeline_id}"), false, async move { - fail::fail_point!("immediate_gc_task_pre"); - #[allow(unused_mut)] let mut result = tenant .gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx) diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 40be2ca8f3..c0a150eb0d 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -223,7 +223,6 @@ use crate::{ config::PageServerConf, task_mgr, task_mgr::TaskKind, - task_mgr::BACKGROUND_RUNTIME, tenant::metadata::TimelineMetadata, tenant::upload_queue::{ UploadOp, UploadQueue, UploadQueueInitialized, UploadQueueStopped, UploadTask, @@ -307,8 +306,6 @@ pub enum PersistIndexPartWithDeletedFlagError { pub struct RemoteTimelineClient { conf: &'static PageServerConf, - runtime: tokio::runtime::Handle, - tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: Generation, @@ -341,12 +338,6 @@ impl RemoteTimelineClient { ) -> RemoteTimelineClient { RemoteTimelineClient { conf, - runtime: if cfg!(test) { - // remote_timeline_client.rs tests rely on current-thread runtime - tokio::runtime::Handle::current() - } else { - BACKGROUND_RUNTIME.handle().clone() - }, tenant_shard_id, timeline_id, generation, @@ -1281,7 +1272,6 @@ impl RemoteTimelineClient { let tenant_shard_id = self.tenant_shard_id; let timeline_id = self.timeline_id; task_mgr::spawn( - &self.runtime, TaskKind::RemoteUploadTask, Some(self.tenant_shard_id), Some(self.timeline_id), @@ -1876,7 +1866,6 @@ mod tests { fn build_client(&self, generation: Generation) -> Arc { Arc::new(RemoteTimelineClient { conf: self.harness.conf, - runtime: tokio::runtime::Handle::current(), tenant_shard_id: self.harness.tenant_shard_id, timeline_id: TIMELINE_ID, generation, diff --git a/pageserver/src/tenant/secondary.rs b/pageserver/src/tenant/secondary.rs index 19f36c722e..b0babb1308 100644 --- a/pageserver/src/tenant/secondary.rs +++ b/pageserver/src/tenant/secondary.rs @@ -8,7 +8,7 @@ use std::{sync::Arc, time::SystemTime}; use crate::{ config::PageServerConf, disk_usage_eviction_task::DiskUsageEvictionInfo, - task_mgr::{self, TaskKind, BACKGROUND_RUNTIME}, + task_mgr::{self, TaskKind}, virtual_file::MaybeFatalIo, }; @@ -317,7 +317,6 @@ pub fn spawn_tasks( tokio::sync::mpsc::channel::>(16); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::SecondaryDownloads, None, None, @@ -338,7 +337,6 @@ pub fn spawn_tasks( ); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::SecondaryUploads, None, None, diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 40f19e3b05..8782a9f04e 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -11,11 +11,11 @@ use crate::{ disk_usage_eviction_task::{ finite_f32, DiskUsageEvictionInfo, EvictionCandidate, EvictionLayer, EvictionSecondaryLayer, }, - is_temporary, metrics::SECONDARY_MODE, tenant::{ config::SecondaryLocationConfig, debug_assert_current_span_has_tenant_and_timeline_id, + ephemeral_file::is_ephemeral_file, remote_timeline_client::{ index::LayerFileMetadata, is_temp_download_file, FAILED_DOWNLOAD_WARN_THRESHOLD, FAILED_REMOTE_OP_RETRIES, @@ -964,7 +964,7 @@ async fn init_timeline_state( continue; } else if crate::is_temporary(&file_path) || is_temp_download_file(&file_path) - || is_temporary(&file_path) + || is_ephemeral_file(file_name) { // Temporary files are frequently left behind from restarting during downloads tracing::info!("Cleaning up temporary file {file_path}"); diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index 5f1db21d49..869d175d8d 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -23,8 +23,12 @@ use tracing::*; use utils::{bin_ser::BeSer, id::TimelineId, lsn::Lsn, vec_map::VecMap}; // avoid binding to Write (conflicts with std::io::Write) // while being able to use std::fmt::Write's methods +use crate::metrics::TIMELINE_EPHEMERAL_BYTES; +use std::cmp::Ordering; use std::fmt::Write as _; use std::ops::Range; +use std::sync::atomic::Ordering as AtomicOrdering; +use std::sync::atomic::{AtomicU64, AtomicUsize}; use tokio::sync::{RwLock, RwLockWriteGuard}; use super::{ @@ -70,6 +74,8 @@ pub struct InMemoryLayerInner { /// Each serialized Value is preceded by a 'u32' length field. /// PerSeg::page_versions map stores offsets into this file. file: EphemeralFile, + + resource_units: GlobalResourceUnits, } impl std::fmt::Debug for InMemoryLayerInner { @@ -78,6 +84,101 @@ impl std::fmt::Debug for InMemoryLayerInner { } } +/// State shared by all in-memory (ephemeral) layers. Updated infrequently during background ticks in Timeline, +/// to minimize contention. +/// +/// This global state is used to implement behaviors that require a global view of the system, e.g. +/// rolling layers proactively to limit the total amount of dirty data. +struct GlobalResources { + // How many bytes are in all EphemeralFile objects + dirty_bytes: AtomicU64, + // How many layers are contributing to dirty_bytes + dirty_layers: AtomicUsize, +} + +// Per-timeline RAII struct for its contribution to [`GlobalResources`] +struct GlobalResourceUnits { + // How many dirty bytes have I added to the global dirty_bytes: this guard object is responsible + // for decrementing the global counter by this many bytes when dropped. + dirty_bytes: u64, +} + +impl GlobalResourceUnits { + // Hint for the layer append path to update us when the layer size differs from the last + // call to update_size by this much. If we don't reach this threshold, we'll still get + // updated when the Timeline "ticks" in the background. + const MAX_SIZE_DRIFT: u64 = 10 * 1024 * 1024; + + fn new() -> Self { + GLOBAL_RESOURCES + .dirty_layers + .fetch_add(1, AtomicOrdering::Relaxed); + Self { dirty_bytes: 0 } + } + + /// Do not call this frequently: all timelines will write to these same global atomics, + /// so this is a relatively expensive operation. Wait at least a few seconds between calls. + fn publish_size(&mut self, size: u64) { + let new_global_dirty_bytes = match size.cmp(&self.dirty_bytes) { + Ordering::Equal => { + return; + } + Ordering::Greater => { + let delta = size - self.dirty_bytes; + let old = GLOBAL_RESOURCES + .dirty_bytes + .fetch_add(delta, AtomicOrdering::Relaxed); + old + delta + } + Ordering::Less => { + let delta = self.dirty_bytes - size; + let old = GLOBAL_RESOURCES + .dirty_bytes + .fetch_sub(delta, AtomicOrdering::Relaxed); + old - delta + } + }; + + // This is a sloppy update: concurrent updates to the counter will race, and the exact + // value of the metric might not be the exact latest value of GLOBAL_RESOURCES::dirty_bytes. + // That's okay: as long as the metric contains some recent value, it doesn't have to always + // be literally the last update. + TIMELINE_EPHEMERAL_BYTES.set(new_global_dirty_bytes); + + self.dirty_bytes = size; + } + + // Call publish_size if the input size differs from last published size by more than + // the drift limit + fn maybe_publish_size(&mut self, size: u64) { + let publish = match size.cmp(&self.dirty_bytes) { + Ordering::Equal => false, + Ordering::Greater => size - self.dirty_bytes > Self::MAX_SIZE_DRIFT, + Ordering::Less => self.dirty_bytes - size > Self::MAX_SIZE_DRIFT, + }; + + if publish { + self.publish_size(size); + } + } +} + +impl Drop for GlobalResourceUnits { + fn drop(&mut self) { + GLOBAL_RESOURCES + .dirty_layers + .fetch_sub(1, AtomicOrdering::Relaxed); + + // Subtract our contribution to the global total dirty bytes + self.publish_size(0); + } +} + +static GLOBAL_RESOURCES: GlobalResources = GlobalResources { + dirty_bytes: AtomicU64::new(0), + dirty_layers: AtomicUsize::new(0), +}; + impl InMemoryLayer { pub(crate) fn get_timeline_id(&self) -> TimelineId { self.timeline_id @@ -328,6 +429,7 @@ impl InMemoryLayer { inner: RwLock::new(InMemoryLayerInner { index: HashMap::new(), file, + resource_units: GlobalResourceUnits::new(), }), }) } @@ -378,9 +480,18 @@ impl InMemoryLayer { warn!("Key {} at {} already exists", key, lsn); } + let size = locked_inner.file.len(); + locked_inner.resource_units.maybe_publish_size(size); + Ok(()) } + pub(crate) async fn tick(&self) { + let mut inner = self.inner.write().await; + let size = inner.file.len(); + inner.resource_units.publish_size(size); + } + pub(crate) async fn put_tombstones(&self, _key_ranges: &[(Range, Lsn)]) -> Result<()> { // TODO: Currently, we just leak the storage for any deleted keys Ok(()) diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index 8ba37b5a86..e101a40da4 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -1447,7 +1447,7 @@ impl LayerInner { #[cfg(test)] tokio::task::spawn(fut); #[cfg(not(test))] - crate::task_mgr::BACKGROUND_RUNTIME.spawn(fut); + crate::task_mgr::THE_RUNTIME.spawn(fut); } /// Needed to use entered runtime in tests, but otherwise use BACKGROUND_RUNTIME. @@ -1458,7 +1458,7 @@ impl LayerInner { #[cfg(test)] tokio::task::spawn_blocking(f); #[cfg(not(test))] - crate::task_mgr::BACKGROUND_RUNTIME.spawn_blocking(f); + crate::task_mgr::THE_RUNTIME.spawn_blocking(f); } } diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index e4f5f75132..db32223a60 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -8,7 +8,7 @@ use std::time::{Duration, Instant}; use crate::context::{DownloadBehavior, RequestContext}; use crate::metrics::TENANT_TASK_EVENTS; use crate::task_mgr; -use crate::task_mgr::{TaskKind, BACKGROUND_RUNTIME}; +use crate::task_mgr::TaskKind; use crate::tenant::throttle::Stats; use crate::tenant::timeline::CompactionError; use crate::tenant::{Tenant, TenantState}; @@ -18,7 +18,7 @@ use utils::{backoff, completion}; static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy = once_cell::sync::Lazy::new(|| { - let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS; + let total_threads = *crate::task_mgr::THE_RUNTIME_WORKER_THREADS; let permits = usize::max( 1, // while a lot of the work is done on spawn_blocking, we still do @@ -85,7 +85,6 @@ pub fn start_background_loops( ) { let tenant_shard_id = tenant.tenant_shard_id; task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::Compaction, Some(tenant_shard_id), None, @@ -109,7 +108,6 @@ pub fn start_background_loops( }, ); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::GarbageCollector, Some(tenant_shard_id), None, diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 7523130f23..15ffa72aaa 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -1723,7 +1723,6 @@ impl Timeline { initdb_optimization_count: 0, }; task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::LayerFlushTask, Some(self.tenant_shard_id), Some(self.timeline_id), @@ -2086,7 +2085,6 @@ impl Timeline { DownloadBehavior::Download, ); task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::InitialLogicalSizeCalculation, Some(self.tenant_shard_id), Some(self.timeline_id), @@ -2264,7 +2262,6 @@ impl Timeline { DownloadBehavior::Download, ); task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::OndemandLogicalSizeCalculation, Some(self.tenant_shard_id), Some(self.timeline_id), @@ -3840,7 +3837,7 @@ impl Timeline { }; let timer = self.metrics.garbage_collect_histo.start_timer(); - fail_point!("before-timeline-gc"); + pausable_failpoint!("before-timeline-gc"); // Is the timeline being deleted? if self.is_stopping() { @@ -4151,7 +4148,6 @@ impl Timeline { let self_clone = Arc::clone(&self); let task_id = task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), task_mgr::TaskKind::DownloadAllRemoteLayers, Some(self.tenant_shard_id), Some(self.timeline_id), @@ -4469,6 +4465,9 @@ impl<'a> TimelineWriter<'a> { let action = self.get_open_layer_action(last_record_lsn, 0); if action == OpenLayerAction::Roll { self.roll_layer(last_record_lsn).await?; + } else if let Some(writer_state) = &mut *self.write_guard { + // Periodic update of statistics + writer_state.open_layer.tick().await; } Ok(()) diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index a0c9d99196..d2272fc75f 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -443,7 +443,6 @@ impl DeleteTimelineFlow { let timeline_id = timeline.timeline_id; task_mgr::spawn( - task_mgr::BACKGROUND_RUNTIME.handle(), TaskKind::TimelineDeletionWorker, Some(tenant_shard_id), Some(timeline_id), diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index dd769d4121..f84a4b0dac 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -28,7 +28,7 @@ use tracing::{debug, error, info, info_span, instrument, warn, Instrument}; use crate::{ context::{DownloadBehavior, RequestContext}, pgdatadir_mapping::CollectKeySpaceError, - task_mgr::{self, TaskKind, BACKGROUND_RUNTIME}, + task_mgr::{self, TaskKind}, tenant::{ tasks::BackgroundLoopKind, timeline::EvictionError, LogicalSizeCalculationCause, Tenant, }, @@ -56,7 +56,6 @@ impl Timeline { let self_clone = Arc::clone(self); let background_tasks_can_start = background_tasks_can_start.cloned(); task_mgr::spawn( - BACKGROUND_RUNTIME.handle(), TaskKind::Eviction, Some(self.tenant_shard_id), Some(self.timeline_id), diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 2fab6722b8..3592dda8d7 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -24,7 +24,7 @@ mod connection_manager; mod walreceiver_connection; use crate::context::{DownloadBehavior, RequestContext}; -use crate::task_mgr::{self, TaskKind, WALRECEIVER_RUNTIME}; +use crate::task_mgr::{self, TaskKind}; use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id; use crate::tenant::timeline::walreceiver::connection_manager::{ connection_manager_loop_step, ConnectionManagerState, @@ -82,7 +82,6 @@ impl WalReceiver { let loop_status = Arc::new(std::sync::RwLock::new(None)); let manager_status = Arc::clone(&loop_status); task_mgr::spawn( - WALRECEIVER_RUNTIME.handle(), TaskKind::WalReceiverManager, Some(timeline.tenant_shard_id), Some(timeline_id), @@ -181,7 +180,7 @@ impl TaskHandle { let (events_sender, events_receiver) = watch::channel(TaskStateUpdate::Started); let cancellation_clone = cancellation.clone(); - let join_handle = WALRECEIVER_RUNTIME.spawn(async move { + let join_handle = tokio::spawn(async move { events_sender.send(TaskStateUpdate::Started).ok(); task(events_sender, cancellation_clone).await // events_sender is dropped at some point during the .await above. diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index d9f780cfd1..cf87cc6ce0 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -11,7 +11,6 @@ use std::{ use anyhow::{anyhow, Context}; use bytes::BytesMut; use chrono::{NaiveDateTime, Utc}; -use fail::fail_point; use futures::StreamExt; use postgres::{error::SqlState, SimpleQueryMessage, SimpleQueryRow}; use postgres_ffi::WAL_SEGMENT_SIZE; @@ -27,9 +26,7 @@ use super::TaskStateUpdate; use crate::{ context::RequestContext, metrics::{LIVE_CONNECTIONS_COUNT, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST}, - task_mgr, - task_mgr::TaskKind, - task_mgr::WALRECEIVER_RUNTIME, + task_mgr::{self, TaskKind}, tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo}, walingest::WalIngest, walrecord::DecodedWALRecord, @@ -163,7 +160,6 @@ pub(super) async fn handle_walreceiver_connection( ); let connection_cancellation = cancellation.clone(); task_mgr::spawn( - WALRECEIVER_RUNTIME.handle(), TaskKind::WalReceiverConnectionPoller, Some(timeline.tenant_shard_id), Some(timeline.timeline_id), @@ -329,7 +325,17 @@ pub(super) async fn handle_walreceiver_connection( filtered_records += 1; } - fail_point!("walreceiver-after-ingest"); + // don't simply use pausable_failpoint here because its spawn_blocking slows + // slows down the tests too much. + fail::fail_point!("walreceiver-after-ingest-blocking"); + if let Err(()) = (|| { + fail::fail_point!("walreceiver-after-ingest-pause-activate", |_| { + Err(()) + }); + Ok(()) + })() { + pausable_failpoint!("walreceiver-after-ingest-pause"); + } last_rec_lsn = lsn; diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 6ede78a576..8d236144b5 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -312,7 +312,7 @@ pg_cluster_size(PG_FUNCTION_ARGS) { int64 size; - size = GetZenithCurrentClusterSize(); + size = GetNeonCurrentClusterSize(); if (size == 0) PG_RETURN_NULL(); diff --git a/pgxn/neon/neon.h b/pgxn/neon/neon.h index a0f8c97497..5c653fc6c6 100644 --- a/pgxn/neon/neon.h +++ b/pgxn/neon/neon.h @@ -26,6 +26,8 @@ extern void pg_init_libpagestore(void); extern void pg_init_walproposer(void); extern uint64 BackpressureThrottlingTime(void); +extern void SetNeonCurrentClusterSize(uint64 size); +extern uint64 GetNeonCurrentClusterSize(void); extern void replication_feedback_get_lsns(XLogRecPtr *writeLsn, XLogRecPtr *flushLsn, XLogRecPtr *applyLsn); extern void PGDLLEXPORT WalProposerSync(int argc, char *argv[]); diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 0256de2b9a..2d222e3c7c 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -1831,7 +1831,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT && !IsAutoVacuumWorkerProcess()) { - uint64 current_size = GetZenithCurrentClusterSize(); + uint64 current_size = GetNeonCurrentClusterSize(); if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024) ereport(ERROR, @@ -1912,7 +1912,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum, reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT && !IsAutoVacuumWorkerProcess()) { - uint64 current_size = GetZenithCurrentClusterSize(); + uint64 current_size = GetNeonCurrentClusterSize(); if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024) ereport(ERROR, diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 28585eb4e7..69a557fdf2 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -287,6 +287,7 @@ typedef struct WalproposerShmemState slock_t mutex; term_t mineLastElectedTerm; pg_atomic_uint64 backpressureThrottlingTime; + pg_atomic_uint64 currentClusterSize; /* last feedback from each shard */ PageserverFeedback shard_ps_feedback[MAX_SHARDS]; diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 002bf4e2ce..7debb6325e 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -282,6 +282,7 @@ WalproposerShmemInit(void) memset(walprop_shared, 0, WalproposerShmemSize()); SpinLockInit(&walprop_shared->mutex); pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0); + pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0); } LWLockRelease(AddinShmemInitLock); @@ -1972,7 +1973,7 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk) /* Only one main shard sends non-zero currentClusterSize */ if (sk->appendResponse.ps_feedback.currentClusterSize > 0) - SetZenithCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); + SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); if (min_feedback.disk_consistent_lsn != standby_apply_lsn) { @@ -2094,6 +2095,18 @@ GetLogRepRestartLSN(WalProposer *wp) return lrRestartLsn; } +void SetNeonCurrentClusterSize(uint64 size) +{ + pg_atomic_write_u64(&walprop_shared->currentClusterSize, size); +} + +uint64 GetNeonCurrentClusterSize(void) +{ + return pg_atomic_read_u64(&walprop_shared->currentClusterSize); +} +uint64 GetNeonCurrentClusterSize(void); + + static const walproposer_api walprop_pg = { .get_shmem_state = walprop_pg_get_shmem_state, .start_streaming = walprop_pg_start_streaming, diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 93a1fe85db..57a2736d5b 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -11,6 +11,10 @@ testing = [] [dependencies] anyhow.workspace = true async-trait.workspace = true +aws-config.workspace = true +aws-sdk-iam.workspace = true +aws-sigv4.workspace = true +aws-types.workspace = true base64.workspace = true bstr.workspace = true bytes = { workspace = true, features = ["serde"] } @@ -27,6 +31,7 @@ hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true +http.workspace = true humantime.workspace = true hyper-tungstenite.workspace = true hyper.workspace = true @@ -92,6 +97,7 @@ workspace_hack.workspace = true [dev-dependencies] camino-tempfile.workspace = true +fallible-iterator.workspace = true rcgen.workspace = true rstest.workspace = true tokio-postgres-rustls.workspace = true diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index bc307230dd..04fe83d8eb 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -408,3 +408,228 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use bytes::BytesMut; + use fallible_iterator::FallibleIterator; + use postgres_protocol::{ + authentication::sasl::{ChannelBinding, ScramSha256}, + message::{backend::Message as PgMessage, frontend}, + }; + use provider::AuthSecret; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + + use crate::{ + auth::{ComputeUserInfoMaybeEndpoint, IpPattern}, + config::AuthenticationConfig, + console::{ + self, + provider::{self, CachedAllowedIps, CachedRoleSecret}, + CachedNodeInfo, + }, + context::RequestMonitoring, + proxy::NeonOptions, + scram::ServerSecret, + stream::{PqStream, Stream}, + }; + + use super::auth_quirks; + + struct Auth { + ips: Vec, + secret: AuthSecret, + } + + impl console::Api for Auth { + async fn get_role_secret( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result { + Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone()))) + } + + async fn get_allowed_ips_and_secret( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result<(CachedAllowedIps, Option), console::errors::GetAuthInfoError> + { + Ok(( + CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())), + Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))), + )) + } + + async fn wake_compute( + &self, + _ctx: &mut RequestMonitoring, + _user_info: &super::ComputeUserInfo, + ) -> Result { + unimplemented!() + } + } + + static CONFIG: &AuthenticationConfig = &AuthenticationConfig { + scram_protocol_timeout: std::time::Duration::from_secs(5), + }; + + async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage { + loop { + r.read_buf(&mut *b).await.unwrap(); + if let Some(m) = PgMessage::parse(&mut *b).unwrap() { + break m; + } + } + } + + #[tokio::test] + async fn auth_quirks_scram() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: Some("endpoint".into()), + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported()); + + let mut read = BytesMut::new(); + + // server should offer scram + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSasl(a) => { + let options: Vec<&str> = a.mechanisms().collect().unwrap(); + assert_eq!(options, ["SCRAM-SHA-256"]); + } + _ => panic!("wrong message"), + } + + // client sends client-first-message + let mut write = BytesMut::new(); + frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + + // server response with server-first-message + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSaslContinue(a) => { + scram.update(a.data()).await.unwrap(); + } + _ => panic!("wrong message"), + } + + // client response with client-final-message + write.clear(); + frontend::sasl_response(scram.message(), &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + + // server response with server-final-message + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationSaslFinal(a) => { + scram.finish(a.data()).unwrap(); + } + _ => panic!("wrong message"), + } + }); + + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG) + .await + .unwrap(); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn auth_quirks_cleartext() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: Some("endpoint".into()), + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut read = BytesMut::new(); + let mut write = BytesMut::new(); + + // server should offer cleartext + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationCleartextPassword => {} + _ => panic!("wrong message"), + } + + // client responds with password + write.clear(); + frontend::password_message(b"my-secret-password", &mut write).unwrap(); + client.write_all(&write).await.unwrap(); + }); + + let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + .await + .unwrap(); + + handle.await.unwrap(); + } + + #[tokio::test] + async fn auth_quirks_password_hack() { + let (mut client, server) = tokio::io::duplex(1024); + let mut stream = PqStream::new(Stream::from_raw(server)); + + let mut ctx = RequestMonitoring::test(); + let api = Auth { + ips: vec![], + secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()), + }; + + let user_info = ComputeUserInfoMaybeEndpoint { + user: "conrad".into(), + endpoint_id: None, + options: NeonOptions::default(), + }; + + let handle = tokio::spawn(async move { + let mut read = BytesMut::new(); + + // server should offer cleartext + match read_message(&mut client, &mut read).await { + PgMessage::AuthenticationCleartextPassword => {} + _ => panic!("wrong message"), + } + + // client responds with password + let mut write = BytesMut::new(); + frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write) + .unwrap(); + client.write_all(&write).await.unwrap(); + }); + + let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG) + .await + .unwrap(); + + assert_eq!(creds.info.endpoint, "my-endpoint"); + + handle.await.unwrap(); + } +} diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index b3d4fc0411..d38439c2a0 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,3 +1,10 @@ +use aws_config::environment::EnvironmentVariableCredentialsProvider; +use aws_config::imds::credentials::ImdsCredentialsProvider; +use aws_config::meta::credentials::CredentialsProviderChain; +use aws_config::meta::region::RegionProviderChain; +use aws_config::profile::ProfileFileCredentialsProvider; +use aws_config::provider_config::ProviderConfig; +use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider; use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; @@ -10,11 +17,14 @@ use proxy::config::ProjectInfoCacheOptions; use proxy::console; use proxy::context::parquet::ParquetUploadArgs; use proxy::http; +use proxy::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT; use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; +use proxy::redis::cancellation_publisher::RedisPublisherClient; +use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; +use proxy::redis::elasticache; use proxy::redis::notifications; -use proxy::redis::publisher::RedisPublisherClient; use proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -150,9 +160,24 @@ struct ProxyCliArgs { /// disable ip check for http requests. If it is too time consuming, it could be turned off. #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] disable_ip_check_for_http: bool, - /// redis url for notifications. + /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) #[clap(long)] redis_notifications: Option, + /// redis host for streaming connections (might be different from the notifications host) + #[clap(long)] + redis_host: Option, + /// redis port for streaming connections (might be different from the notifications host) + #[clap(long)] + redis_port: Option, + /// redis cluster name, used in aws elasticache + #[clap(long)] + redis_cluster_name: Option, + /// redis user_id, used in aws elasticache + #[clap(long)] + redis_user_id: Option, + /// aws region to retrieve credentials + #[clap(long, default_value_t = String::new())] + aws_region: String, /// cache for `project_info` (use `size=0` to disable) #[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)] project_info_cache: String, @@ -216,6 +241,61 @@ async fn main() -> anyhow::Result<()> { let config = build_config(&args)?; info!("Authentication backend: {}", config.auth_backend); + info!("Using region: {}", config.aws_region); + + let region_provider = RegionProviderChain::default_provider().or_else(&*config.aws_region); // Replace with your Redis region if needed + let provider_conf = + ProviderConfig::without_region().with_region(region_provider.region().await); + let aws_credentials_provider = { + // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY" + CredentialsProviderChain::first_try("env", EnvironmentVariableCredentialsProvider::new()) + // uses "AWS_PROFILE" / `aws sso login --profile ` + .or_else( + "profile-sso", + ProfileFileCredentialsProvider::builder() + .configure(&provider_conf) + .build(), + ) + // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME" + // needed to access remote extensions bucket + .or_else( + "token", + WebIdentityTokenCredentialsProvider::builder() + .configure(&provider_conf) + .build(), + ) + // uses imds v2 + .or_else("imds", ImdsCredentialsProvider::builder().build()) + }; + let elasticache_credentials_provider = Arc::new(elasticache::CredentialsProvider::new( + elasticache::AWSIRSAConfig::new( + config.aws_region.clone(), + args.redis_cluster_name, + args.redis_user_id, + ), + aws_credentials_provider, + )); + let redis_notifications_client = + match (args.redis_notifications, (args.redis_host, args.redis_port)) { + (Some(url), _) => { + info!("Starting redis notifications listener ({url})"); + Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url)) + } + (None, (Some(host), Some(port))) => Some( + ConnectionWithCredentialsProvider::new_with_credentials_provider( + host, + port, + elasticache_credentials_provider.clone(), + ), + ), + (None, (None, None)) => { + warn!("Redis is disabled"); + None + } + _ => { + bail!("redis-host and redis-port must be specified together"); + } + }; // Check that we can bind to address before further initialization let http_address: SocketAddr = args.http.parse()?; @@ -233,17 +313,22 @@ async fn main() -> anyhow::Result<()> { let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); let cancel_map = CancelMap::default(); - let redis_publisher = match &args.redis_notifications { - Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( - url, + + // let redis_notifications_client = redis_notifications_client.map(|x| Box::leak(Box::new(x))); + let redis_publisher = match &redis_notifications_client { + Some(redis_publisher) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + redis_publisher.clone(), args.region.clone(), &config.redis_rps_limit, )?))), None => None, }; - let cancellation_handler = Arc::new(CancellationHandler::new( + let cancellation_handler = Arc::new(CancellationHandler::< + Option>>, + >::new( cancel_map.clone(), redis_publisher, + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT, )); // client facing tasks. these will exit on error or on cancellation @@ -290,17 +375,16 @@ async fn main() -> anyhow::Result<()> { if let auth::BackendType::Console(api, _) = &config.auth_backend { if let proxy::console::provider::ConsoleBackend::Console(api) = &**api { - let cache = api.caches.project_info.clone(); - if let Some(url) = args.redis_notifications { - info!("Starting redis notifications listener ({url})"); + if let Some(redis_notifications_client) = redis_notifications_client { + let cache = api.caches.project_info.clone(); maintenance_tasks.spawn(notifications::task_main( - url.to_owned(), + redis_notifications_client.clone(), cache.clone(), cancel_map.clone(), args.region.clone(), )); + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } } @@ -445,8 +529,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { endpoint_rps_limit, redis_rps_limit, handshake_timeout: args.handshake_timeout, - // TODO: add this argument region: args.region.clone(), + aws_region: args.aws_region.clone(), })); Ok(config) diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index c9607909b3..6151513614 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; @@ -10,18 +9,26 @@ use tracing::info; use uuid::Uuid; use crate::{ - error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, - redis::publisher::RedisPublisherClient, + error::ReportableError, + metrics::NUM_CANCELLATION_REQUESTS, + redis::cancellation_publisher::{ + CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, + }, }; pub type CancelMap = Arc>>; +pub type CancellationHandlerMain = CancellationHandler>>>; +pub type CancellationHandlerMainInternal = Option>>; /// Enables serving `CancelRequest`s. /// -/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. -pub struct CancellationHandler { +/// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler

{ map: CancelMap, - redis_client: Option>>, + client: P, + /// This field used for the monitoring purposes. + /// Represents the source of the cancellation request. + from: &'static str, } #[derive(Debug, Error)] @@ -44,49 +51,9 @@ impl ReportableError for CancelError { } } -impl CancellationHandler { - pub fn new(map: CancelMap, redis_client: Option>>) -> Self { - Self { map, redis_client } - } - /// Cancel a running query for the corresponding connection. - pub async fn cancel_session( - &self, - key: CancelKeyData, - session_id: Uuid, - ) -> Result<(), CancelError> { - let from = "from_client"; - // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { - tracing::warn!("query cancellation key not found: {key}"); - if let Some(redis_client) = &self.redis_client { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "not_found"]) - .inc(); - info!("publishing cancellation key to Redis"); - match redis_client.lock().await.try_publish(key, session_id).await { - Ok(()) => { - info!("cancellation key successfuly published to Redis"); - } - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - return Err(CancelError::IO(std::io::Error::new( - std::io::ErrorKind::Other, - e.to_string(), - ))); - } - } - } - return Ok(()); - }; - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "found"]) - .inc(); - info!("cancelling query per user's request using key {key}"); - cancel_closure.try_cancel_query().await - } - +impl CancellationHandler

{ /// Run async action within an ephemeral session identified by [`CancelKeyData`]. - pub fn get_session(self: Arc) -> Session { + pub fn get_session(self: Arc) -> Session

{ // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't // expose it and we don't want to do another roundtrip to query // for it. The client will be able to notice that this is not the @@ -112,9 +79,39 @@ impl CancellationHandler { cancellation_handler: self, } } + /// Try to cancel a running query for the corresponding connection. + /// If the cancellation key is not found, it will be published to Redis. + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + // NB: we should immediately release the lock after cloning the token. + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { + tracing::warn!("query cancellation key not found: {key}"); + NUM_CANCELLATION_REQUESTS + .with_label_values(&[self.from, "not_found"]) + .inc(); + match self.client.try_publish(key, session_id).await { + Ok(()) => {} // do nothing + Err(e) => { + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + return Ok(()); + }; + NUM_CANCELLATION_REQUESTS + .with_label_values(&[self.from, "found"]) + .inc(); + info!("cancelling query per user's request using key {key}"); + cancel_closure.try_cancel_query().await + } #[cfg(test)] - fn contains(&self, session: &Session) -> bool { + fn contains(&self, session: &Session

) -> bool { self.map.contains_key(&session.key) } @@ -124,31 +121,19 @@ impl CancellationHandler { } } -#[async_trait] -pub trait NotificationsCancellationHandler { - async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +impl CancellationHandler<()> { + pub fn new(map: CancelMap, from: &'static str) -> Self { + Self { + map, + client: (), + from, + } + } } -#[async_trait] -impl NotificationsCancellationHandler for CancellationHandler { - async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { - let from = "from_redis"; - let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); - match cancel_closure { - Some(cancel_closure) => { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "found"]) - .inc(); - cancel_closure.try_cancel_query().await - } - None => { - NUM_CANCELLATION_REQUESTS - .with_label_values(&[from, "not_found"]) - .inc(); - tracing::warn!("query cancellation key not found: {key}"); - Ok(()) - } - } +impl CancellationHandler>>> { + pub fn new(map: CancelMap, client: Option>>, from: &'static str) -> Self { + Self { map, client, from } } } @@ -178,14 +163,14 @@ impl CancelClosure { } /// Helper for registering query cancellation tokens. -pub struct Session { +pub struct Session

{ /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancellation_handler: Arc, + cancellation_handler: Arc>, } -impl Session { +impl

Session

{ /// Store the cancel token for the given session. /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { @@ -198,7 +183,7 @@ impl Session { } } -impl Drop for Session { +impl

Drop for Session

{ fn drop(&mut self) { self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); @@ -207,14 +192,16 @@ impl Drop for Session { #[cfg(test)] mod tests { + use crate::metrics::NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS; + use super::*; #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancellation_handler = Arc::new(CancellationHandler { - map: CancelMap::default(), - redis_client: None, - }); + let cancellation_handler = Arc::new(CancellationHandler::<()>::new( + CancelMap::default(), + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + )); let session = cancellation_handler.clone().get_session(); assert!(cancellation_handler.contains(&session)); @@ -224,4 +211,19 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn cancel_session_noop_regression() { + let handler = CancellationHandler::<()>::new(Default::default(), "local"); + handler + .cancel_session( + CancelKeyData { + backend_pid: 0, + cancel_key: 0, + }, + Uuid::new_v4(), + ) + .await + .unwrap(); + } } diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index b61c1fb9ef..65153babcb 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -82,14 +82,13 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; /// A config for establishing a connection to compute node. /// Eventually, `tokio_postgres` will be replaced with something better. /// Newtype allows us to implement methods on top of it. -#[derive(Clone)] -#[repr(transparent)] +#[derive(Clone, Default)] pub struct ConnCfg(Box); /// Creation and initialization routines. impl ConnCfg { pub fn new() -> Self { - Self(Default::default()) + Self::default() } /// Reuse password or auth keys from the other config. @@ -165,12 +164,6 @@ impl std::ops::DerefMut for ConnCfg { } } -impl Default for ConnCfg { - fn default() -> Self { - Self::new() - } -} - impl ConnCfg { /// Establish a raw TCP connection to the compute node. async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 437ec9f401..45f8d76144 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -28,6 +28,7 @@ pub struct ProxyConfig { pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, + pub aws_region: String, } #[derive(Debug)] diff --git a/proxy/src/console.rs b/proxy/src/console.rs index fd3c46b946..ea95e83437 100644 --- a/proxy/src/console.rs +++ b/proxy/src/console.rs @@ -6,7 +6,7 @@ pub mod messages; /// Wrappers for console APIs and their mocks. pub mod provider; -pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; +pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo}; /// Various cache-related types. pub mod caches { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index 8609606273..69bfd6b045 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -14,7 +14,6 @@ use crate::{ context::RequestMonitoring, scram, EndpointCacheKey, ProjectId, }; -use async_trait::async_trait; use dashmap::DashMap; use std::{sync::Arc, time::Duration}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; @@ -326,8 +325,7 @@ pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc), } -#[async_trait] impl Api for ConsoleBackend { async fn get_role_secret( &self, diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 0579ef6fc4..b759c81373 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -8,7 +8,6 @@ use crate::console::provider::{CachedAllowedIps, CachedRoleSecret}; use crate::context::RequestMonitoring; use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl}; use crate::{auth::IpPattern, cache::Cached}; -use async_trait::async_trait; use futures::TryFutureExt; use std::{str::FromStr, sync::Arc}; use thiserror::Error; @@ -144,7 +143,6 @@ async fn get_execute_postgres_query( Ok(Some(entry)) } -#[async_trait] impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index b36663518d..89ebfa57f1 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -14,7 +14,6 @@ use crate::{ context::RequestMonitoring, metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER}, }; -use async_trait::async_trait; use futures::TryFutureExt; use std::sync::Arc; use tokio::time::Instant; @@ -168,7 +167,6 @@ impl Api { } } -#[async_trait] impl super::Api for Api { #[tracing::instrument(skip_all)] async fn get_role_secret( diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 02ebcd6aaa..eed45e421b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -161,6 +161,9 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { .unwrap() }); +pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client"; +pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis"; + pub enum Waiting { Cplane, Client, diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ab5bf5d494..843bfc08cf 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancellationHandler}, + cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,7 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -233,12 +233,12 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancellation_handler: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: IntCounterPairGuard, -) -> Result>, ClientRequestError> { +) -> Result>, ClientRequestError> { info!("handling interactive connection from client"); let proto = ctx.protocol; @@ -338,9 +338,9 @@ pub async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. #[tracing::instrument(skip_all)] -async fn prepare_client_connection( +async fn prepare_client_connection

( node: &compute::PostgresConnection, - session: &cancellation::Session, + session: &cancellation::Session

, stream: &mut PqStream, ) -> Result<(), std::io::Error> { // Register compute's query cancellation token and produce a new, unique one. diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b2f682fd2f..f6d4314391 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -55,17 +55,17 @@ pub async fn proxy_pass( Ok(()) } -pub struct ProxyPassthrough { +pub struct ProxyPassthrough { pub client: Stream, pub compute: PostgresConnection, pub aux: MetricsAuxInfo, pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, - pub cancel: cancellation::Session, + pub cancel: cancellation::Session

, } -impl ProxyPassthrough { +impl ProxyPassthrough { pub async fn proxy_pass(self) -> anyhow::Result<()> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; self.compute.cancel_closure.try_cancel_query().await?; diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index 35d6db074e..a322f0368c 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1,2 +1,4 @@ +pub mod cancellation_publisher; +pub mod connection_with_credentials_provider; +pub mod elasticache; pub mod notifications; -pub mod publisher; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs new file mode 100644 index 0000000000..422789813c --- /dev/null +++ b/proxy/src/redis/cancellation_publisher.rs @@ -0,0 +1,161 @@ +use std::sync::Arc; + +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use tokio::sync::Mutex; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::{ + connection_with_credentials_provider::ConnectionWithCredentialsProvider, + notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}, +}; + +pub trait CancellationPublisherMut: Send + Sync + 'static { + #[allow(async_fn_in_trait)] + async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()>; +} + +pub trait CancellationPublisher: Send + Sync + 'static { + #[allow(async_fn_in_trait)] + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()>; +} + +impl CancellationPublisher for () { + async fn try_publish( + &self, + _cancel_key_data: CancelKeyData, + _session_id: Uuid, + ) -> anyhow::Result<()> { + Ok(()) + } +} + +impl CancellationPublisherMut for P { + async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { +

::try_publish(self, cancel_key_data, session_id).await + } +} + +impl CancellationPublisher for Option

{ + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if let Some(p) = self { + p.try_publish(cancel_key_data, session_id).await + } else { + Ok(()) + } + } +} + +impl CancellationPublisher for Arc> { + async fn try_publish( + &self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + self.lock() + .await + .try_publish(cancel_key_data, session_id) + .await + } +} + +pub struct RedisPublisherClient { + client: ConnectionWithCredentialsProvider, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + client: ConnectionWithCredentialsProvider, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + Ok(Self { + client, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + self.client.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.connect().await { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e); + } + } + Ok(()) + } + async fn try_publish_internal( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + } + } + tracing::info!("Publisher is disconnected. Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } +} + +impl CancellationPublisherMut for RedisPublisherClient { + async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + tracing::info!("publishing cancellation key to Redis"); + match self.try_publish_internal(cancel_key_data, session_id).await { + Ok(()) => { + tracing::info!("cancellation key successfuly published to Redis"); + Ok(()) + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + Err(e) + } + } + } +} diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs new file mode 100644 index 0000000000..d183abb53a --- /dev/null +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -0,0 +1,225 @@ +use std::{sync::Arc, time::Duration}; + +use futures::FutureExt; +use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, +}; +use tokio::task::JoinHandle; +use tracing::{error, info}; + +use super::elasticache::CredentialsProvider; + +enum Credentials { + Static(ConnectionInfo), + Dynamic(Arc, redis::ConnectionAddr), +} + +impl Clone for Credentials { + fn clone(&self) -> Self { + match self { + Credentials::Static(info) => Credentials::Static(info.clone()), + Credentials::Dynamic(provider, addr) => { + Credentials::Dynamic(Arc::clone(provider), addr.clone()) + } + } + } +} + +/// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token. +/// Provides PubSub connection without credentials refresh. +pub struct ConnectionWithCredentialsProvider { + credentials: Credentials, + con: Option, + refresh_token_task: Option>, + mutex: tokio::sync::Mutex<()>, +} + +impl Clone for ConnectionWithCredentialsProvider { + fn clone(&self) -> Self { + Self { + credentials: self.credentials.clone(), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } +} + +impl ConnectionWithCredentialsProvider { + pub fn new_with_credentials_provider( + host: String, + port: u16, + credentials_provider: Arc, + ) -> Self { + Self { + credentials: Credentials::Dynamic( + credentials_provider, + redis::ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params: None, + }, + ), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } + + pub fn new_with_static_credentials(params: T) -> Self { + Self { + credentials: Credentials::Static(params.into_connection_info().unwrap()), + con: None, + refresh_token_task: None, + mutex: tokio::sync::Mutex::new(()), + } + } + + pub async fn connect(&mut self) -> anyhow::Result<()> { + let _guard = self.mutex.lock().await; + if let Some(con) = self.con.as_mut() { + match redis::cmd("PING").query_async(con).await { + Ok(()) => { + return Ok(()); + } + Err(e) => { + error!("Error during PING: {e:?}"); + } + } + } else { + info!("Connection is not established"); + } + info!("Establishing a new connection..."); + self.con = None; + if let Some(f) = self.refresh_token_task.take() { + f.abort() + } + let con = self + .get_client() + .await? + .get_multiplexed_tokio_connection() + .await?; + if let Credentials::Dynamic(credentials_provider, _) = &self.credentials { + let credentials_provider = credentials_provider.clone(); + let con2 = con.clone(); + let f = tokio::spawn(async move { + let _ = Self::keep_connection(con2, credentials_provider).await; + }); + self.refresh_token_task = Some(f); + } + self.con = Some(con); + Ok(()) + } + + async fn get_connection_info(&self) -> anyhow::Result { + match &self.credentials { + Credentials::Static(info) => Ok(info.clone()), + Credentials::Dynamic(provider, addr) => { + let (username, password) = provider.provide_credentials().await?; + Ok(ConnectionInfo { + addr: addr.clone(), + redis: RedisConnectionInfo { + db: 0, + username: Some(username), + password: Some(password.clone()), + }, + }) + } + } + } + + async fn get_client(&self) -> anyhow::Result { + let client = redis::Client::open(self.get_connection_info().await?)?; + Ok(client) + } + + // PubSub does not support credentials refresh. + // Requires manual reconnection every 12h. + pub async fn get_async_pubsub(&self) -> anyhow::Result { + Ok(self.get_client().await?.get_async_pubsub().await?) + } + + // The connection lives for 12h. + // It can be prolonged with sending `AUTH` commands with the refreshed token. + // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits + async fn keep_connection( + mut con: MultiplexedConnection, + credentials_provider: Arc, + ) -> anyhow::Result<()> { + loop { + // The connection lives for 12h, for the sanity check we refresh it every hour. + tokio::time::sleep(Duration::from_secs(60 * 60)).await; + match Self::refresh_token(&mut con, credentials_provider.clone()).await { + Ok(()) => { + info!("Token refreshed"); + } + Err(e) => { + error!("Error during token refresh: {e:?}"); + } + } + } + } + async fn refresh_token( + con: &mut MultiplexedConnection, + credentials_provider: Arc, + ) -> anyhow::Result<()> { + let (user, password) = credentials_provider.provide_credentials().await?; + redis::cmd("AUTH") + .arg(user) + .arg(password) + .query_async(con) + .await?; + Ok(()) + } + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult { + // Clone connection to avoid having to lock the ArcSwap in write mode + let con = self.con.as_mut().ok_or(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Connection not established", + )))?; + con.send_packed_command(cmd).await + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &redis::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + // Clone shared connection future to avoid having to lock the ArcSwap in write mode + let con = self.con.as_mut().ok_or(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Connection not established", + )))?; + con.send_packed_commands(cmd, offset, count).await + } +} + +impl ConnectionLike for ConnectionWithCredentialsProvider { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> redis::RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + 0 + } +} diff --git a/proxy/src/redis/elasticache.rs b/proxy/src/redis/elasticache.rs new file mode 100644 index 0000000000..eded8250af --- /dev/null +++ b/proxy/src/redis/elasticache.rs @@ -0,0 +1,110 @@ +use std::time::{Duration, SystemTime}; + +use aws_config::meta::credentials::CredentialsProviderChain; +use aws_sdk_iam::config::ProvideCredentials; +use aws_sigv4::http_request::{ + self, SignableBody, SignableRequest, SignatureLocation, SigningSettings, +}; +use tracing::info; + +#[derive(Debug)] +pub struct AWSIRSAConfig { + region: String, + service_name: String, + cluster_name: String, + user_id: String, + token_ttl: Duration, + action: String, +} + +impl AWSIRSAConfig { + pub fn new(region: String, cluster_name: Option, user_id: Option) -> Self { + AWSIRSAConfig { + region, + service_name: "elasticache".to_string(), + cluster_name: cluster_name.unwrap_or_default(), + user_id: user_id.unwrap_or_default(), + // "The IAM authentication token is valid for 15 minutes" + // https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits + token_ttl: Duration::from_secs(15 * 60), + action: "connect".to_string(), + } + } +} + +/// Credentials provider for AWS elasticache authentication. +/// +/// Official documentation: +/// +/// +/// Useful resources: +/// +pub struct CredentialsProvider { + config: AWSIRSAConfig, + credentials_provider: CredentialsProviderChain, +} + +impl CredentialsProvider { + pub fn new(config: AWSIRSAConfig, credentials_provider: CredentialsProviderChain) -> Self { + CredentialsProvider { + config, + credentials_provider, + } + } + pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> { + let aws_credentials = self + .credentials_provider + .provide_credentials() + .await? + .into(); + info!("AWS credentials successfully obtained"); + info!("Connecting to Redis with configuration: {:?}", self.config); + let mut settings = SigningSettings::default(); + settings.signature_location = SignatureLocation::QueryParams; + settings.expires_in = Some(self.config.token_ttl); + let signing_params = aws_sigv4::sign::v4::SigningParams::builder() + .identity(&aws_credentials) + .region(&self.config.region) + .name(&self.config.service_name) + .time(SystemTime::now()) + .settings(settings) + .build()? + .into(); + let auth_params = [ + ("Action", &self.config.action), + ("User", &self.config.user_id), + ]; + let auth_params = url::form_urlencoded::Serializer::new(String::new()) + .extend_pairs(auth_params) + .finish(); + let auth_uri = http::Uri::builder() + .scheme("http") + .authority(self.config.cluster_name.as_bytes()) + .path_and_query(format!("/?{auth_params}")) + .build()?; + info!("{}", auth_uri); + + // Convert the HTTP request into a signable request + let signable_request = SignableRequest::new( + "GET", + auth_uri.to_string(), + std::iter::empty(), + SignableBody::Bytes(&[]), + )?; + + // Sign and then apply the signature to the request + let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts(); + let mut signable_request = http::Request::builder() + .method("GET") + .uri(auth_uri) + .body(())?; + si.apply_to_request_http1x(&mut signable_request); + Ok(( + self.config.user_id.clone(), + signable_request + .uri() + .to_string() + .replacen("http://", "", 1), + )) + } +} diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 6ae848c0d2..8b7e3e3419 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -6,11 +6,12 @@ use redis::aio::PubSub; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::{ cache::project_info::ProjectInfoCache, - cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, + cancellation::{CancelMap, CancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, - metrics::REDIS_BROKEN_MESSAGES, + metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES}, }; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -18,23 +19,13 @@ pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct RedisConsumerClient { - client: redis::Client, -} - -impl RedisConsumerClient { - pub fn new(url: &str) -> anyhow::Result { - let client = redis::Client::open(url)?; - Ok(Self { client }) - } - async fn try_connect(&self) -> anyhow::Result { - let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); - conn.subscribe(CPLANE_CHANNEL_NAME).await?; - tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); - conn.subscribe(PROXY_CHANNEL_NAME).await?; - Ok(conn) - } +async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result { + let mut conn = client.get_async_pubsub().await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; + Ok(conn) } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -80,21 +71,18 @@ where serde_json::from_str(&s).map_err(::custom) } -struct MessageHandler< - C: ProjectInfoCache + Send + Sync + 'static, - H: NotificationsCancellationHandler + Send + Sync + 'static, -> { +struct MessageHandler { cache: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc>, region_id: String, } -impl< - C: ProjectInfoCache + Send + Sync + 'static, - H: NotificationsCancellationHandler + Send + Sync + 'static, - > MessageHandler -{ - pub fn new(cache: Arc, cancellation_handler: Arc, region_id: String) -> Self { +impl MessageHandler { + pub fn new( + cache: Arc, + cancellation_handler: Arc>, + region_id: String, + ) -> Self { Self { cache, cancellation_handler, @@ -139,7 +127,7 @@ impl< // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. match self .cancellation_handler - .cancel_session_no_publish(cancel_session.cancel_key_data) + .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil()) .await { Ok(()) => {} @@ -182,7 +170,7 @@ fn invalidate_cache(cache: Arc, msg: Notification) { /// Handle console's invalidation messages. #[tracing::instrument(name = "console_notifications", skip_all)] pub async fn task_main( - url: String, + redis: ConnectionWithCredentialsProvider, cache: Arc, cancel_map: CancelMap, region_id: String, @@ -193,13 +181,15 @@ where cache.enable_ttl(); let handler = MessageHandler::new( cache, - Arc::new(CancellationHandler::new(cancel_map, None)), + Arc::new(CancellationHandler::<()>::new( + cancel_map, + NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, + )), region_id, ); loop { - let redis = RedisConsumerClient::new(&url)?; - let conn = match redis.try_connect().await { + let mut conn = match try_connect(&redis).await { Ok(conn) => { handler.disable_ttl(); conn @@ -212,7 +202,7 @@ where continue; } }; - let mut stream = conn.into_on_message(); + let mut stream = conn.on_message(); while let Some(msg) = stream.next().await { match handler.handle_message(msg).await { Ok(()) => {} diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs deleted file mode 100644 index f85593afdd..0000000000 --- a/proxy/src/redis/publisher.rs +++ /dev/null @@ -1,80 +0,0 @@ -use pq_proto::CancelKeyData; -use redis::AsyncCommands; -use uuid::Uuid; - -use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; - -use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; - -pub struct RedisPublisherClient { - client: redis::Client, - publisher: Option, - region_id: String, - limiter: RedisRateLimiter, -} - -impl RedisPublisherClient { - pub fn new( - url: &str, - region_id: String, - info: &'static [RateBucketInfo], - ) -> anyhow::Result { - let client = redis::Client::open(url)?; - Ok(Self { - client, - publisher: None, - region_id, - limiter: RedisRateLimiter::new(info), - }) - } - pub async fn try_publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. Skipping cancellation message"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - match self.publish(cancel_key_data, session_id).await { - Ok(()) => return Ok(()), - Err(e) => { - tracing::error!("failed to publish a message: {e}"); - self.publisher = None; - } - } - tracing::info!("Publisher is disconnected. Reconnectiong..."); - self.try_connect().await?; - self.publish(cancel_key_data, session_id).await - } - - async fn publish( - &mut self, - cancel_key_data: CancelKeyData, - session_id: Uuid, - ) -> anyhow::Result<()> { - let conn = self - .publisher - .as_mut() - .ok_or_else(|| anyhow::anyhow!("not connected"))?; - let payload = serde_json::to_string(&Notification::Cancel(CancelSession { - region_id: Some(self.region_id.clone()), - cancel_key_data, - session_id, - }))?; - conn.publish(PROXY_CHANNEL_NAME, payload).await?; - Ok(()) - } - pub async fn try_connect(&mut self) -> anyhow::Result<()> { - match self.client.get_async_connection().await { - Ok(conn) => { - self.publisher = Some(conn); - } - Err(e) => { - tracing::error!("failed to connect to redis: {e}"); - return Err(e.into()); - } - } - Ok(()) - } -} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 682cbe795f..89dd33e59f 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -3,9 +3,7 @@ use std::convert::Infallible; use hmac::{Hmac, Mac}; -use sha2::digest::FixedOutput; -use sha2::{Digest, Sha256}; -use subtle::{Choice, ConstantTimeEq}; +use sha2::Sha256; use tokio::task::yield_now; use super::messages::{ @@ -13,6 +11,7 @@ use super::messages::{ }; use super::secret::ServerSecret; use super::signature::SignatureBuilder; +use super::ScramKey; use crate::config; use crate::sasl::{self, ChannelBinding, Error as SaslError}; @@ -104,7 +103,7 @@ async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { } // copied from -async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], [u8; 32]) { +async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey { let salted_password = pbkdf2(password, salt, iterations).await; let make_key = |name| { @@ -116,7 +115,7 @@ async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32] <[u8; 32]>::from(key.into_bytes()) }; - (make_key(b"Client Key"), make_key(b"Server Key")) + make_key(b"Client Key").into() } pub async fn exchange( @@ -124,21 +123,12 @@ pub async fn exchange( password: &[u8], ) -> sasl::Result> { let salt = base64::decode(&secret.salt_base64)?; - let (client_key, server_key) = derive_keys(password, &salt, secret.iterations).await; - let stored_key: [u8; 32] = Sha256::default() - .chain_update(client_key) - .finalize_fixed() - .into(); + let client_key = derive_client_key(password, &salt, secret.iterations).await; - // constant time to not leak partial key match - let valid = stored_key.ct_eq(&secret.stored_key.as_bytes()) - | server_key.ct_eq(&secret.server_key.as_bytes()) - | Choice::from(secret.doomed as u8); - - if valid.into() { - Ok(sasl::Outcome::Success(super::ScramKey::from(client_key))) - } else { + if secret.is_password_invalid(&client_key).into() { Ok(sasl::Outcome::Failure("password doesn't match")) + } else { + Ok(sasl::Outcome::Success(client_key)) } } @@ -220,7 +210,7 @@ impl SaslSentInner { .derive_client_key(&client_final_message.proof); // Auth fails either if keys don't match or it's pre-determined to fail. - if client_key.sha256() != secret.stored_key || secret.doomed { + if secret.is_password_invalid(&client_key).into() { return Ok(sasl::Step::Failure("password doesn't match")); } diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index 973126e729..32a3dbd203 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,17 +1,31 @@ //! Tools for client/server/stored key management. +use subtle::ConstantTimeEq; + /// Faithfully taken from PostgreSQL. pub const SCRAM_KEY_LEN: usize = 32; /// One of the keys derived from the user's password. /// We use the same structure for all keys, i.e. /// `ClientKey`, `StoredKey`, and `ServerKey`. -#[derive(Clone, Default, PartialEq, Eq, Debug)] +#[derive(Clone, Default, Eq, Debug)] #[repr(transparent)] pub struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl PartialEq for ScramKey { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl ConstantTimeEq for ScramKey { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.bytes.ct_eq(&other.bytes) + } +} + impl ScramKey { pub fn sha256(&self) -> Self { super::sha256([self.as_ref()]).into() diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index b59baec508..f9372540ca 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -206,6 +206,28 @@ mod tests { } } + #[test] + fn parse_client_first_message_with_invalid_gs2_authz() { + assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none()) + } + + #[test] + fn parse_client_first_message_with_extra_params() { + let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap(); + assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz"); + assert_eq!(msg.username, "user"); + assert_eq!(msg.nonce, "nonce"); + assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient); + } + + #[test] + fn parse_client_first_message_with_extra_params_invalid() { + // must be of the form `=<...>` + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none()); + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none()); + assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none()); + } + #[test] fn parse_client_final_message() { let input = [ diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index b46d8c3ab5..f3414cb8ec 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -1,5 +1,7 @@ //! Tools for SCRAM server secret management. +use subtle::{Choice, ConstantTimeEq}; + use super::base64_decode_array; use super::key::ScramKey; @@ -40,6 +42,11 @@ impl ServerSecret { Some(secret) } + pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice { + // constant time to not leak partial key match + client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8) + } + /// To avoid revealing information to an attacker, we use a /// mocked server secret even if the user doesn't exist. /// See `auth-scram.c : mock_scram_secret` for details. diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index be9f90acde..a2010fd613 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -21,11 +21,12 @@ pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware}; use tokio_util::task::TaskTracker; use tracing::instrument::Instrumented; +use crate::cancellation::CancellationHandlerMain; +use crate::config::ProxyConfig; use crate::context::RequestMonitoring; use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use hyper::{ server::conn::{AddrIncoming, AddrStream}, Body, Method, Request, Response, @@ -47,7 +48,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, - cancellation_handler: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -237,7 +238,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancellation_handler: Arc, + cancellation_handler: Arc, peer_addr: IpAddr, endpoint_rate_limiter: Arc, // used to cancel in-flight HTTP requests. not used to cancel websockets diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index a72ede6d0a..ada6c974f4 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancellationHandler, + cancellation::CancellationHandlerMain, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -134,7 +134,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancellation_handler: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b0949c32b1..50a5a4185b 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.76.0" +channel = "1.77.0" profile = "default" # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. # https://rust-lang.github.io/rustup/concepts/profiles.html diff --git a/safekeeper/src/copy_timeline.rs b/safekeeper/src/copy_timeline.rs index 5bc877adbd..3023d4e2cb 100644 --- a/safekeeper/src/copy_timeline.rs +++ b/safekeeper/src/copy_timeline.rs @@ -225,6 +225,7 @@ async fn write_segment( assert!(from <= to); assert!(to <= wal_seg_size); + #[allow(clippy::suspicious_open_options)] let mut file = OpenOptions::new() .create(true) .write(true) diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 8bbd95e9e8..147f318b9f 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -221,6 +221,7 @@ impl PhysicalStorage { // half initialized segment, first bake it under tmp filename and // then rename. let tmp_path = self.timeline_dir.join("waltmp"); + #[allow(clippy::suspicious_open_options)] let mut file = OpenOptions::new() .create(true) .write(true) diff --git a/safekeeper/tests/walproposer_sim/walproposer_api.rs b/safekeeper/tests/walproposer_sim/walproposer_api.rs index 42340ba1df..c49495a4f3 100644 --- a/safekeeper/tests/walproposer_sim/walproposer_api.rs +++ b/safekeeper/tests/walproposer_sim/walproposer_api.rs @@ -244,6 +244,7 @@ impl SimulationApi { mutex: 0, mineLastElectedTerm: 0, backpressureThrottlingTime: pg_atomic_uint64 { value: 0 }, + currentClusterSize: pg_atomic_uint64 { value: 0 }, shard_ps_feedback: [empty_feedback; 128], num_shards: 0, min_ps_feedback: empty_feedback, diff --git a/test_runner/fixtures/pageserver/allowed_errors.py b/test_runner/fixtures/pageserver/allowed_errors.py index ec0f81b380..d7f682dad3 100755 --- a/test_runner/fixtures/pageserver/allowed_errors.py +++ b/test_runner/fixtures/pageserver/allowed_errors.py @@ -96,6 +96,8 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [ ".*Call to node.*management API.*failed.*ReceiveBody.*", # Many tests will start up with a node offline ".*startup_reconcile: Could not scan node.*", + # Tests run in dev mode + ".*Starting in dev mode.*", ] diff --git a/test_runner/fixtures/pageserver/utils.py b/test_runner/fixtures/pageserver/utils.py index cf64c86821..693771dd3d 100644 --- a/test_runner/fixtures/pageserver/utils.py +++ b/test_runner/fixtures/pageserver/utils.py @@ -62,9 +62,7 @@ def wait_for_upload( ) time.sleep(1) raise Exception( - "timed out while waiting for remote_consistent_lsn to reach {}, was {}".format( - lsn, current_lsn - ) + f"timed out while waiting for {tenant}/{timeline} remote_consistent_lsn to reach {lsn}, was {current_lsn}" ) diff --git a/test_runner/regress/test_backpressure.py b/test_runner/regress/test_backpressure.py index 819912dd05..af17a2e89d 100644 --- a/test_runner/regress/test_backpressure.py +++ b/test_runner/regress/test_backpressure.py @@ -116,7 +116,7 @@ def test_backpressure_received_lsn_lag(neon_env_builder: NeonEnvBuilder): # Configure failpoint to slow down walreceiver ingest with closing(env.pageserver.connect()) as psconn: with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur: - pscur.execute("failpoints walreceiver-after-ingest=sleep(20)") + pscur.execute("failpoints walreceiver-after-ingest-blocking=sleep(20)") # FIXME # Wait for the check thread to start diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index 0497e1965c..ac3315b86f 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -84,3 +84,21 @@ def test_hot_standby(neon_simple_env: NeonEnv): # clean up if slow_down_send: sk_http.configure_failpoints(("sk-send-wal-replica-sleep", "off")) + + +def test_2_replicas_start(neon_simple_env: NeonEnv): + env = neon_simple_env + + with env.endpoints.create_start( + branch_name="main", + endpoint_id="primary", + ) as primary: + time.sleep(1) + with env.endpoints.new_replica_start( + origin=primary, endpoint_id="secondary1" + ) as secondary1: + with env.endpoints.new_replica_start( + origin=primary, endpoint_id="secondary2" + ) as secondary2: + wait_replica_caughtup(primary, secondary1) + wait_replica_caughtup(primary, secondary2) diff --git a/test_runner/regress/test_pageserver_metric_collection.py b/test_runner/regress/test_pageserver_metric_collection.py index 042961baa5..c34ef46d07 100644 --- a/test_runner/regress/test_pageserver_metric_collection.py +++ b/test_runner/regress/test_pageserver_metric_collection.py @@ -1,4 +1,6 @@ +import gzip import json +import os import time from dataclasses import dataclass from pathlib import Path @@ -10,7 +12,11 @@ from fixtures.neon_fixtures import ( NeonEnvBuilder, wait_for_last_flush_lsn, ) -from fixtures.remote_storage import RemoteStorageKind +from fixtures.remote_storage import ( + LocalFsStorage, + RemoteStorageKind, + remote_storage_to_toml_inline_table, +) from fixtures.types import TenantId, TimelineId from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request @@ -40,6 +46,9 @@ def test_metric_collection( uploads.put((events, is_last == "true")) return Response(status=200) + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + assert neon_env_builder.pageserver_remote_storage is not None + # Require collecting metrics frequently, since we change # the timeline and want something to be logged about it. # @@ -48,12 +57,11 @@ def test_metric_collection( neon_env_builder.pageserver_config_override = f""" metric_collection_interval="1s" metric_collection_endpoint="{metric_collection_endpoint}" + metric_collection_bucket={remote_storage_to_toml_inline_table(neon_env_builder.pageserver_remote_storage)} cached_metric_collection_interval="0s" synthetic_size_calculation_interval="3s" """ - neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) - log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}") # mock http server that returns OK for the metrics @@ -70,6 +78,7 @@ def test_metric_collection( # we have a fast rate of calculation, these can happen at shutdown ".*synthetic_size_worker:calculate_synthetic_size.*:gather_size_inputs.*: failed to calculate logical size at .*: cancelled.*", ".*synthetic_size_worker: failed to calculate synthetic size for tenant .*: failed to calculate some logical_sizes", + ".*metrics_collection: failed to upload to S3: Failed to upload data of length .* to storage path.*", ] ) @@ -166,6 +175,20 @@ def test_metric_collection( httpserver.check() + # Check that at least one bucket output object is present, and that all + # can be decompressed and decoded. + bucket_dumps = {} + assert isinstance(env.pageserver_remote_storage, LocalFsStorage) + for dirpath, _dirs, files in os.walk(env.pageserver_remote_storage.root): + for file in files: + file_path = os.path.join(dirpath, file) + log.info(file_path) + if file.endswith(".gz"): + bucket_dumps[file_path] = json.load(gzip.open(file_path)) + + assert len(bucket_dumps) >= 1 + assert all("events" in data for data in bucket_dumps.values()) + def test_metric_collection_cleans_up_tempfile( httpserver: HTTPServer, diff --git a/test_runner/regress/test_pageserver_small_inmemory_layers.py b/test_runner/regress/test_pageserver_small_inmemory_layers.py index 5d55020e3c..714d1c1229 100644 --- a/test_runner/regress/test_pageserver_small_inmemory_layers.py +++ b/test_runner/regress/test_pageserver_small_inmemory_layers.py @@ -1,5 +1,4 @@ import asyncio -import time from typing import Tuple import pytest @@ -10,7 +9,7 @@ from fixtures.neon_fixtures import ( tenant_get_shards, ) from fixtures.pageserver.http import PageserverHttpClient -from fixtures.pageserver.utils import wait_for_last_record_lsn +from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload from fixtures.types import Lsn, TenantId, TimelineId from fixtures.utils import wait_until @@ -61,6 +60,15 @@ def wait_until_pageserver_is_caught_up( assert waited >= last_flush_lsn +def wait_until_pageserver_has_uploaded( + env: NeonEnv, last_flush_lsns: list[Tuple[TenantId, TimelineId, Lsn]] +): + for tenant, timeline, last_flush_lsn in last_flush_lsns: + shards = tenant_get_shards(env, tenant) + for tenant_shard_id, pageserver in shards: + wait_for_upload(pageserver.http_client(), tenant_shard_id, timeline, last_flush_lsn) + + def wait_for_wal_ingest_metric(pageserver_http: PageserverHttpClient) -> float: def query(): value = pageserver_http.get_metric_value("pageserver_wal_ingest_records_received_total") @@ -86,25 +94,50 @@ def test_pageserver_small_inmemory_layers( The workload creates a number of timelines and writes some data to each, but not enough to trigger flushes via the `checkpoint_distance` config. """ + + def get_dirty_bytes(): + v = ( + env.pageserver.http_client().get_metric_value("pageserver_timeline_ephemeral_bytes") + or 0 + ) + log.info(f"dirty_bytes: {v}") + return v + + def assert_dirty_bytes(v): + assert get_dirty_bytes() == v + env = neon_env_builder.init_configs() env.start() last_flush_lsns = asyncio.run(workload(env, TIMELINE_COUNT, ENTRIES_PER_TIMELINE)) wait_until_pageserver_is_caught_up(env, last_flush_lsns) + # We didn't write enough data to trigger a size-based checkpoint + assert get_dirty_bytes() > 0 + ps_http_client = env.pageserver.http_client() total_wal_ingested_before_restart = wait_for_wal_ingest_metric(ps_http_client) - log.info("Sleeping for checkpoint timeout ...") - time.sleep(CHECKPOINT_TIMEOUT_SECONDS + 5) + # Within ~ the checkpoint interval, all the ephemeral layers should be frozen and flushed, + # such that there are zero bytes of ephemeral layer left on the pageserver + log.info("Waiting for background checkpoints...") + wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(0)) # type: ignore + + # Zero ephemeral layer bytes does not imply that all the frozen layers were uploaded: they + # must be uploaded to remain visible to the pageserver after restart. + wait_until_pageserver_has_uploaded(env, last_flush_lsns) env.pageserver.restart(immediate=immediate_shutdown) wait_until_pageserver_is_caught_up(env, last_flush_lsns) + # Catching up with WAL ingest should have resulted in zero bytes of ephemeral layers, since + # we froze, flushed and uploaded everything before restarting. There can be no more WAL writes + # because we shut down compute endpoints before flushing. + assert get_dirty_bytes() == 0 + total_wal_ingested_after_restart = wait_for_wal_ingest_metric(ps_http_client) log.info(f"WAL ingested before restart: {total_wal_ingested_before_restart}") log.info(f"WAL ingested after restart: {total_wal_ingested_after_restart}") - leeway = total_wal_ingested_before_restart * 5 / 100 - assert total_wal_ingested_after_restart <= leeway + assert total_wal_ingested_after_restart == 0 diff --git a/test_runner/regress/test_recovery.py b/test_runner/regress/test_recovery.py index 6aac1e1d84..ab5c8be256 100644 --- a/test_runner/regress/test_recovery.py +++ b/test_runner/regress/test_recovery.py @@ -15,6 +15,13 @@ def test_pageserver_recovery(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() env.pageserver.is_testing_enabled_or_skip() + # We expect the pageserver to exit, which will cause storage storage controller + # requests to fail and warn. + env.storage_controller.allowed_errors.append(".*management API still failed.*") + env.storage_controller.allowed_errors.append( + ".*Reconcile error.*error sending request for url.*" + ) + # Create a branch for us env.neon_cli.create_branch("test_pageserver_recovery", "main") diff --git a/test_runner/regress/test_replication_start.py b/test_runner/regress/test_replication_start.py index b4699c7be8..2360745990 100644 --- a/test_runner/regress/test_replication_start.py +++ b/test_runner/regress/test_replication_start.py @@ -1,7 +1,9 @@ +import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, wait_replica_caughtup +@pytest.mark.xfail def test_replication_start(neon_simple_env: NeonEnv): env = neon_simple_env diff --git a/test_runner/regress/test_timeline_size.py b/test_runner/regress/test_timeline_size.py index 628c484fbd..efd257900d 100644 --- a/test_runner/regress/test_timeline_size.py +++ b/test_runner/regress/test_timeline_size.py @@ -931,7 +931,7 @@ def test_timeline_logical_size_task_priority(neon_env_builder: NeonEnvBuilder): env.pageserver.stop() env.pageserver.start( extra_env_vars={ - "FAILPOINTS": "initial-size-calculation-permit-pause=pause;walreceiver-after-ingest=pause" + "FAILPOINTS": "initial-size-calculation-permit-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause" } ) @@ -953,7 +953,11 @@ def test_timeline_logical_size_task_priority(neon_env_builder: NeonEnvBuilder): assert details["current_logical_size_is_accurate"] is True client.configure_failpoints( - [("initial-size-calculation-permit-pause", "off"), ("walreceiver-after-ingest", "off")] + [ + ("initial-size-calculation-permit-pause", "off"), + ("walreceiver-after-ingest-pause-activate", "off"), + ("walreceiver-after-ingest-pause", "off"), + ] ) @@ -983,7 +987,7 @@ def test_eager_attach_does_not_queue_up(neon_env_builder: NeonEnvBuilder): # pause at logical size calculation, also pause before walreceiver can give feedback so it will give priority to logical size calculation env.pageserver.start( extra_env_vars={ - "FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest=pause" + "FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause" } ) @@ -1029,7 +1033,11 @@ def test_eager_attach_does_not_queue_up(neon_env_builder: NeonEnvBuilder): other_is_attaching() client.configure_failpoints( - [("timeline-calculate-logical-size-pause", "off"), ("walreceiver-after-ingest", "off")] + [ + ("timeline-calculate-logical-size-pause", "off"), + ("walreceiver-after-ingest-pause-activate", "off"), + ("walreceiver-after-ingest-pause", "off"), + ] ) @@ -1059,7 +1067,7 @@ def test_lazy_attach_activation(neon_env_builder: NeonEnvBuilder, activation_met # pause at logical size calculation, also pause before walreceiver can give feedback so it will give priority to logical size calculation env.pageserver.start( extra_env_vars={ - "FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest=pause" + "FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause" } ) @@ -1111,3 +1119,11 @@ def test_lazy_attach_activation(neon_env_builder: NeonEnvBuilder, activation_met delete_lazy_activating(lazy_tenant, env.pageserver, expect_attaching=True) else: raise RuntimeError(activation_method) + + client.configure_failpoints( + [ + ("timeline-calculate-logical-size-pause", "off"), + ("walreceiver-after-ingest-pause-activate", "off"), + ("walreceiver-after-ingest-pause", "off"), + ] + ) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 3b09894ddb..748643b468 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 3b09894ddb8825b50c963942059eab1a2a0b0a89 +Subproject commit 748643b4683e9fe3b105011a6ba8a687d032cd65 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 80cef885ad..e7651e79c0 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 80cef885add1af6741aa31944c7d2c84d8f9098f +Subproject commit e7651e79c0c27fbddc3c724f5b9553222c28e395 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 9007894722..3946b2e2ea 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 90078947229aa7f9ac5f7ed4527b2c7386d5332b +Subproject commit 3946b2e2ea71d07af092099cb5bcae76a69b90d6 diff --git a/vendor/revisions.json b/vendor/revisions.json index ae524d70b1..3c1b866137 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { - "postgres-v16": "90078947229aa7f9ac5f7ed4527b2c7386d5332b", - "postgres-v15": "80cef885add1af6741aa31944c7d2c84d8f9098f", - "postgres-v14": "3b09894ddb8825b50c963942059eab1a2a0b0a89" + "postgres-v16": "3946b2e2ea71d07af092099cb5bcae76a69b90d6", + "postgres-v15": "e7651e79c0c27fbddc3c724f5b9553222c28e395", + "postgres-v14": "748643b4683e9fe3b105011a6ba8a687d032cd65" } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 152c452dd4..7b8228a082 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -19,8 +19,7 @@ aws-runtime = { version = "1", default-features = false, features = ["event-stre aws-sigv4 = { version = "1", features = ["http0-compat", "sign-eventstream", "sigv4a"] } aws-smithy-async = { version = "1", default-features = false, features = ["rt-tokio"] } aws-smithy-http = { version = "0.60", default-features = false, features = ["event-stream"] } -aws-smithy-runtime-api = { version = "1", features = ["client", "http-02x", "http-auth"] } -aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio"] } +aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "rt-tokio", "test-util"] } axum = { version = "0.6", features = ["ws"] } base64 = { version = "0.21", features = ["alloc"] } base64ct = { version = "1", default-features = false, features = ["std"] }