diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 4008cd0d36..330e875d56 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -218,6 +218,9 @@ runs: name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }} # Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/ + # The lack of compatibility snapshot shouldn't fail the job + # (for example if we didn't run the test for non build-and-test workflow) + skip-if-does-not-exist: true - name: Upload test results if: ${{ !cancelled() }} diff --git a/.github/actions/upload/action.yml b/.github/actions/upload/action.yml index edcece7d2b..8a4cfe2eff 100644 --- a/.github/actions/upload/action.yml +++ b/.github/actions/upload/action.yml @@ -7,6 +7,10 @@ inputs: path: description: "A directory or file to upload" required: true + skip-if-does-not-exist: + description: "Allow to skip if path doesn't exist, fail otherwise" + default: false + required: false prefix: description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'" required: false @@ -15,10 +19,12 @@ runs: using: "composite" steps: - name: Prepare artifact + id: prepare-artifact shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst + SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }} run: | mkdir -p $(dirname $ARCHIVE) @@ -33,14 +39,22 @@ runs: elif [ -f ${SOURCE} ]; then time tar -cf ${ARCHIVE} --zstd ${SOURCE} elif ! 
ls ${SOURCE} > /dev/null 2>&1; then - echo >&2 "${SOURCE} does not exist" - exit 2 + if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then + echo 'SKIPPED=true' >> $GITHUB_OUTPUT + exit 0 + else + echo >&2 "${SOURCE} does not exist" + exit 2 + fi else echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it" exit 3 fi + echo 'SKIPPED=false' >> $GITHUB_OUTPUT + - name: Upload artifact + if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }} shell: bash -euxo pipefail {0} env: SOURCE: ${{ inputs.path }} diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a759efb56c..e7193cfe19 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -193,16 +193,15 @@ jobs: with: submodules: true -# Disabled for now -# - name: Restore cargo deps cache -# id: cache_cargo -# uses: actions/cache@v4 -# with: -# path: | -# !~/.cargo/registry/src -# ~/.cargo/git/ -# target/ -# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }} + - name: Cache cargo deps + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + !~/.cargo/registry/src + ~/.cargo/git + target + key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust # Some of our rust modules use FFI and need those to be checked - name: Get postgres headers diff --git a/.github/workflows/report-workflow-stats.yml b/.github/workflows/report-workflow-stats.yml index 1afe896600..6abeff7695 100644 --- a/.github/workflows/report-workflow-stats.yml +++ b/.github/workflows/report-workflow-stats.yml @@ -33,7 +33,7 @@ jobs: actions: read steps: - name: Export GH Workflow Stats - uses: fedordikarev/gh-workflow-stats-action@v0.1.2 + uses: neondatabase/gh-workflow-stats-action@v0.1.4 with: DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} DB_TABLE: "gh_workflow_stats_neon" diff --git 
a/CODEOWNERS b/CODEOWNERS index 606dbb4e22..f8ed4be816 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,5 +1,6 @@ /compute_tools/ @neondatabase/control-plane @neondatabase/compute /storage_controller @neondatabase/storage +/storage_scrubber @neondatabase/storage /libs/pageserver_api/ @neondatabase/storage /libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage /libs/remote_storage/ @neondatabase/storage diff --git a/libs/utils/src/auth.rs b/libs/utils/src/auth.rs index 7b735875b7..5bd6f4bedc 100644 --- a/libs/utils/src/auth.rs +++ b/libs/utils/src/auth.rs @@ -31,9 +31,12 @@ pub enum Scope { /// The scope used by pageservers in upcalls to storage controller and cloud control plane #[serde(rename = "generations_api")] GenerationsApi, - /// Allows access to control plane managment API and some storage controller endpoints. + /// Allows access to control plane management API and all storage controller endpoints. Admin, + /// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration) + Infra, + /// Allows access to storage controller APIs used by the scrubber, to interrogate the state /// of a tenant & post scrub results. 
Scrubber, diff --git a/pageserver/src/auth.rs b/pageserver/src/auth.rs index 9e3dedb75a..5c931fcfdb 100644 --- a/pageserver/src/auth.rs +++ b/pageserver/src/auth.rs @@ -14,14 +14,19 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } (Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope (Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope - (Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Pageserver auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::SafekeeperData + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Pageserver auth", + claims.scope + ) + .into(), + )), } } diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index a7cc678187..127be545e1 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -25,6 +25,10 @@ pub(crate) enum WebAuthError { Io(#[from] std::io::Error), } +pub struct ConsoleRedirectBackend { + console_uri: reqwest::Url, +} + impl UserFacingError for WebAuthError { fn to_string_client(&self) -> String { "Internal error".to_string() @@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String { hex::encode(rand::random::<[u8; 8]>()) } -pub(super) async fn authenticate( +impl ConsoleRedirectBackend { + pub fn new(console_uri: reqwest::Url) -> Self { + Self { console_uri } + } + + pub(super) fn url(&self) -> &reqwest::Url { + &self.console_uri + } + + pub(crate) async fn authenticate( + &self, + ctx: &RequestMonitoring, + auth_config: &'static AuthenticationConfig, + client: &mut PqStream, + ) -> auth::Result { + authenticate(ctx, auth_config, &self.console_uri, client).await + } +} + +async fn authenticate( ctx: 
&RequestMonitoring, auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index c9aa5b7e61..27c9f1876e 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -8,6 +8,7 @@ use std::net::IpAddr; use std::sync::Arc; use std::time::Duration; +pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::WebAuthError; use ipnet::{Ipv4Net, Ipv6Net}; use local::LocalBackend; @@ -36,7 +37,7 @@ use crate::{ provider::{CachedAllowedIps, CachedNodeInfo}, Api, }, - stream, url, + stream, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; @@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> { /// Cloud API (V2). ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T), /// Authentication via a web browser. - ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D), + ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D), /// Local proxy uses configured auth credentials and does not wake compute Local(MaybeOwned<'a, LocalBackend>), } @@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> { #[cfg(test)] ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(), }, - Self::ConsoleRedirect(url, ()) => fmt + Self::ConsoleRedirect(backend, ()) => fmt .debug_tuple("ConsoleRedirect") - .field(&url.as_str()) + .field(&backend.url().as_str()) .finish(), Self::Local(_) => fmt.debug_tuple("Local").finish(), } @@ -241,7 +242,6 @@ impl AuthenticationConfig { pub(crate) fn check_rate_limit( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, secret: AuthSecret, endpoint: &EndpointId, is_cleartext: bool, @@ -265,7 +265,7 @@ impl AuthenticationConfig { let limit_not_exceeded = self.rate_limiter.check( ( endpoint_int, - MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet), + MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet), ), password_weight, ); @@ -339,7 +339,6 @@ async fn auth_quirks( let secret = if 
let Some(secret) = secret { config.check_rate_limit( ctx, - config, secret, &info.endpoint, unauthenticated_password.is_some() || allow_cleartext, @@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> { Backend::ControlPlane(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Self::ConsoleRedirect(url, ()) => { + Self::ConsoleRedirect(backend, ()) => { info!("performing web authentication"); - let info = console_redirect::authenticate(ctx, config, &url, client).await?; + let info = backend.authenticate(ctx, config, client).await?; - Backend::ConsoleRedirect(url, info) + Backend::ConsoleRedirect(backend, info) } Self::Local(_) => { return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) diff --git a/proxy/src/bin/local_proxy.rs b/proxy/src/bin/local_proxy.rs index ae8a7f0841..c781af846a 100644 --- a/proxy/src/bin/local_proxy.rs +++ b/proxy/src/bin/local_proxy.rs @@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec; use dashmap::DashMap; use futures::future::Either; use proxy::{ - auth::backend::{ - jwt::JwkCache, - local::{LocalBackend, JWKS_ROLE_MAP}, + auth::{ + self, + backend::{ + jwt::JwkCache, + local::{LocalBackend, JWKS_ROLE_MAP}, + }, }, cancellation::CancellationHandlerMain, config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig}, @@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> { let args = LocalProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; // before we bind to any ports, write the process ID to a file // so that compute-ctl can find our process later @@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> { let task = serverless::task_main( config, + auth_backend, http_listener, shutdown.clone(), Arc::new(CancellationHandlerMain::new( @@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig Ok(Box::leak(Box::new(ProxyConfig { tls_config: None, - 
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( - LocalBackend::new(args.compute), - )), metric_collection: None, allow_self_signed_compute: false, http_config, @@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }))) } +/// auth::Backend is created at proxy startup, and lives forever. +fn build_auth_backend( + args: &LocalProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { + let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned( + LocalBackend::new(args.compute), + )); + + Ok(Box::leak(Box::new(auth_backend))) +} + async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc) { loop { rx.notified().await; diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 7488cce3c4..3f4c2df809 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -10,6 +10,7 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::jwt::JwkCache; use proxy::auth::backend::AuthRateLimiter; +use proxy::auth::backend::ConsoleRedirectBackend; use proxy::auth::backend::MaybeOwned; use proxy::cancellation::CancelMap; use proxy::cancellation::CancellationHandler; @@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> { let args = ProxyCliArgs::parse(); let config = build_config(&args)?; + let auth_backend = build_auth_backend(&args)?; - info!("Authentication backend: {}", config.auth_backend); + info!("Authentication backend: {}", auth_backend); info!("Using region: {}", args.aws_region); let region_provider = @@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> { if let Some(proxy_listener) = proxy_listener { client_tasks.spawn(proxy::proxy::task_main( config, + auth_backend, proxy_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> { if let Some(serverless_listener) = serverless_listener { client_tasks.spawn(serverless::task_main( config, 
+ auth_backend, serverless_listener, cancellation_token.clone(), cancellation_handler.clone(), @@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> { )); } - if let auth::Backend::ControlPlane(api, _) = &config.auth_backend { + if let auth::Backend::ControlPlane(api, _) = auth_backend { if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api { match (redis_notifications_client, regional_redis_client.clone()) { (None, None) => {} @@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { bail!("dynamic rate limiter should be disabled"); } + let config::ConcurrencyLockOptions { + shards, + limiter, + epoch, + timeout, + } = args.connect_compute_lock.parse()?; + info!( + ?limiter, + shards, + ?epoch, + "Using NodeLocks (connect_compute)" + ); + let connect_compute_locks = control_plane::locks::ApiLocks::new( + "connect_compute_lock", + limiter, + shards, + timeout, + epoch, + &Metrics::get().proxy.connect_compute_lock, + )?; + + let http_config = HttpConfig { + accept_websockets: !args.is_auth_broker, + pool_options: GlobalConnPoolOptions { + max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, + gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, + pool_shards: args.sql_over_http.sql_over_http_pool_shards, + idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, + opt_in: args.sql_over_http.sql_over_http_pool_opt_in, + max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, + }, + cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), + client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, + max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, + max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, + }; + let authentication_config = AuthenticationConfig { + jwks_cache: JwkCache::default(), + thread_pool, + 
scram_protocol_timeout: args.scram_protocol_timeout, + rate_limiter_enabled: args.auth_rate_limit_enabled, + rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), + rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, + ip_allowlist_check_enabled: !args.is_private_access_proxy, + is_auth_broker: args.is_auth_broker, + accept_jwts: args.is_auth_broker, + webauth_confirmation_timeout: args.webauth_confirmation_timeout, + }; + + let config = Box::leak(Box::new(ProxyConfig { + tls_config, + metric_collection, + allow_self_signed_compute: args.allow_self_signed_compute, + http_config, + authentication_config, + proxy_protocol_v2: args.proxy_protocol_v2, + handshake_timeout: args.handshake_timeout, + region: args.region.clone(), + wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, + connect_compute_locks, + connect_to_compute_retry_config: config::RetryConfig::parse( + &args.connect_to_compute_retry, + )?, + })); + + tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); + + Ok(config) +} + +/// auth::Backend is created at proxy startup, and lives forever. 
+fn build_auth_backend( + args: &ProxyCliArgs, +) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> { let auth_backend = match &args.auth_backend { AuthBackendType::Console => { let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?; @@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { AuthBackendType::Web => { let url = args.uri.parse()?; - auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ()) + auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ()) } #[cfg(feature = "testing")] @@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } }; - let config::ConcurrencyLockOptions { - shards, - limiter, - epoch, - timeout, - } = args.connect_compute_lock.parse()?; - info!( - ?limiter, - shards, - ?epoch, - "Using NodeLocks (connect_compute)" - ); - let connect_compute_locks = control_plane::locks::ApiLocks::new( - "connect_compute_lock", - limiter, - shards, - timeout, - epoch, - &Metrics::get().proxy.connect_compute_lock, - )?; - - let http_config = HttpConfig { - accept_websockets: !args.is_auth_broker, - pool_options: GlobalConnPoolOptions { - max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint, - gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch, - pool_shards: args.sql_over_http.sql_over_http_pool_shards, - idle_timeout: args.sql_over_http.sql_over_http_idle_timeout, - opt_in: args.sql_over_http.sql_over_http_pool_opt_in, - max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns, - }, - cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards), - client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold, - max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes, - max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes, - }; - let authentication_config = 
AuthenticationConfig { - jwks_cache: JwkCache::default(), - thread_pool, - scram_protocol_timeout: args.scram_protocol_timeout, - rate_limiter_enabled: args.auth_rate_limit_enabled, - rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()), - rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet, - ip_allowlist_check_enabled: !args.is_private_access_proxy, - is_auth_broker: args.is_auth_broker, - accept_jwts: args.is_auth_broker, - webauth_confirmation_timeout: args.webauth_confirmation_timeout, - }; - - let config = Box::leak(Box::new(ProxyConfig { - tls_config, - auth_backend, - metric_collection, - allow_self_signed_compute: args.allow_self_signed_compute, - http_config, - authentication_config, - proxy_protocol_v2: args.proxy_protocol_v2, - handshake_timeout: args.handshake_timeout, - region: args.region.clone(), - wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, - connect_compute_locks, - connect_to_compute_retry_config: config::RetryConfig::parse( - &args.connect_to_compute_retry, - )?, - })); - - tokio::spawn(config.connect_compute_locks.garbage_collect_worker()); - - Ok(config) + Ok(Box::leak(Box::new(auth_backend))) } #[cfg(test)] diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 55d0b6374c..c068fc50fb 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,8 +1,5 @@ use crate::{ - auth::{ - self, - backend::{jwt::JwkCache, AuthRateLimiter}, - }, + auth::backend::{jwt::JwkCache, AuthRateLimiter}, control_plane::locks::ApiLocks, rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig}, scram::threadpool::ThreadPool, @@ -29,7 +26,6 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::Backend<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/control_plane/provider/mod.rs b/proxy/src/control_plane/provider/mod.rs index 01d93dee43..6cc525a324 
100644 --- a/proxy/src/control_plane/provider/mod.rs +++ b/proxy/src/control_plane/provider/mod.rs @@ -81,12 +81,12 @@ pub(crate) mod errors { Reason::EndpointNotFound => ErrorKind::User, Reason::BranchNotFound => ErrorKind::User, Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit, - Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User, - Reason::ActiveTimeQuotaExceeded => ErrorKind::User, - Reason::ComputeTimeQuotaExceeded => ErrorKind::User, - Reason::WrittenDataQuotaExceeded => ErrorKind::User, - Reason::DataTransferQuotaExceeded => ErrorKind::User, - Reason::LogicalSizeQuotaExceeded => ErrorKind::User, + Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota, + Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota, + Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota, + Reason::WrittenDataQuotaExceeded => ErrorKind::Quota, + Reason::DataTransferQuotaExceeded => ErrorKind::Quota, + Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota, Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane, Reason::LockAlreadyTaken => ErrorKind::ControlPlane, Reason::RunningOperations => ErrorKind::ControlPlane, @@ -103,7 +103,7 @@ pub(crate) mod errors { } if error .contains("compute time quota of non-primary branches is exceeded") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } ControlPlaneError { http_status_code: http::StatusCode::LOCKED, @@ -112,7 +112,7 @@ pub(crate) mod errors { } if error.contains("quota exceeded") || error.contains("the limit for current plan reached") => { - crate::error::ErrorKind::User + crate::error::ErrorKind::Quota } ControlPlaneError { http_status_code: http::StatusCode::TOO_MANY_REQUESTS, diff --git a/proxy/src/control_plane/provider/neon.rs b/proxy/src/control_plane/provider/neon.rs index e5f8b5c741..d01878741c 100644 --- a/proxy/src/control_plane/provider/neon.rs +++ b/proxy/src/control_plane/provider/neon.rs @@ -22,7 +22,7 @@ use futures::TryFutureExt; use std::{sync::Arc, time::Duration}; 
use tokio::time::Instant; use tokio_postgres::config::SslMode; -use tracing::{debug, error, info, info_span, warn, Instrument}; +use tracing::{debug, info, info_span, warn, Instrument}; const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); @@ -456,7 +456,7 @@ async fn parse_body serde::Deserialize<'a>>( }); body.http_status_code = status; - error!("console responded with an error ({status}): {body:?}"); + warn!("console responded with an error ({status}): {body:?}"); Err(ApiError::ControlPlane(body)) } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 53f9f75c5b..1cd4dc2c22 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -49,6 +49,10 @@ pub enum ErrorKind { #[label(rename = "serviceratelimit")] ServiceRateLimit, + /// Proxy quota limit violation + #[label(rename = "quota")] + Quota, + /// internal errors Service, @@ -70,6 +74,7 @@ impl ErrorKind { ErrorKind::ClientDisconnect => "clientdisconnect", ErrorKind::RateLimit => "ratelimit", ErrorKind::ServiceRateLimit => "serviceratelimit", + ErrorKind::Quota => "quota", ErrorKind::Service => "service", ErrorKind::ControlPlane => "controlplane", ErrorKind::Postgres => "postgres", diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 7003af2aba..3a43ccb74a 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -35,7 +35,7 @@ use std::sync::Arc; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, Instrument}; +use tracing::{error, info, warn, Instrument}; use self::{ connect_compute::{connect_to_compute, TcpMechanism}, @@ -61,6 +61,7 @@ pub async fn run_until_cancelled( pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, listener: tokio::net::TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -95,15 +96,15 @@ pub async fn task_main( connections.spawn(async move { let (socket, 
peer_addr) = match read_proxy_protocol(socket).await { Err(e) => { - error!("per-client task finished with an error: {e:#}"); + warn!("per-client task finished with an error: {e:#}"); return; } Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { - error!("missing required proxy protocol header"); + warn!("missing required proxy protocol header"); return; } Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { - error!("proxy protocol header not supported"); + warn!("proxy protocol header not supported"); return; } Ok((socket, Some(addr))) => (socket, addr.ip()), @@ -129,6 +130,7 @@ pub async fn task_main( let startup = Box::pin( handle_client( config, + auth_backend, &ctx, cancellation_handler, socket, @@ -144,7 +146,7 @@ pub async fn task_main( Err(e) => { // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - error!(parent: &span, "per-client task finished with an error: {e:#}"); + warn!(parent: &span, "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); @@ -155,7 +157,7 @@ pub async fn task_main( match p.proxy_pass().instrument(span.clone()).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); @@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn handle_client( config: &'static ProxyConfig, + auth_backend: &'static auth::Backend<'static, (), ()>, ctx: &RequestMonitoring, cancellation_handler: Arc, stream: S, @@ -285,8 +289,7 @@ pub(crate) async fn handle_client( let common_names = tls.map(|tls| &tls.common_names); // Extract credentials which we're going to 
use for auth. - let result = config - .auth_backend + let result = auth_backend .as_ref() .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)) .transpose(); diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index bbea47f8af..497cf4bfd5 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -71,7 +71,7 @@ impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { - tracing::error!(?err, "could not cancel the query in the database"); + tracing::warn!(?err, "could not cancel the query in the database"); } res } diff --git a/proxy/src/redis/connection_with_credentials_provider.rs b/proxy/src/redis/connection_with_credentials_provider.rs index 2de66b58b1..ccd48f1481 100644 --- a/proxy/src/redis/connection_with_credentials_provider.rs +++ b/proxy/src/redis/connection_with_credentials_provider.rs @@ -6,7 +6,7 @@ use redis::{ ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult, }; use tokio::task::JoinHandle; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use super::elasticache::CredentialsProvider; @@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider { return Ok(()); } Err(e) => { - error!("Error during PING: {e:?}"); + warn!("Error during PING: {e:?}"); } } } else { @@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider { info!("Connection succesfully established"); } Err(e) => { - error!("Connection is broken. Error during PING: {e:?}"); + warn!("Connection is broken. 
Error during PING: {e:?}"); } } self.con = Some(con); diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 36a3443603..c3af6740cb 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -146,7 +146,7 @@ impl MessageHandler { { Ok(()) => {} Err(e) => { - tracing::error!("failed to cancel session: {e}"); + tracing::warn!("failed to cancel session: {e}"); } } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8a8f38d181..9e49478cf3 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -13,7 +13,7 @@ use crate::{ check_peer_addr_is_in_list, AuthError, }, compute, - config::{AuthenticationConfig, ProxyConfig}, + config::ProxyConfig, context::RequestMonitoring, control_plane::{ errors::{GetAuthInfoError, WakeComputeError}, @@ -28,7 +28,7 @@ use crate::{ retry::{CouldRetry, ShouldRetryWakeCompute}, }, rate_limiter::EndpointRateLimiter, - Host, + EndpointId, Host, }; use super::{ @@ -42,6 +42,7 @@ pub(crate) struct PoolingBackend { pub(crate) local_pool: Arc>, pub(crate) pool: Arc>, pub(crate) config: &'static ProxyConfig, + pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>, pub(crate) endpoint_rate_limiter: Arc, } @@ -49,18 +50,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_password( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, password: &[u8], ) -> Result { let user_info = user_info.clone(); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| user_info.clone()); + let backend = self.auth_backend.as_ref().map(|()| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; - if config.ip_allowlist_check_enabled + if self.config.authentication_config.ip_allowlist_check_enabled && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) { return 
Err(AuthError::ip_address_not_allowed(ctx.peer_addr())); @@ -79,7 +75,6 @@ impl PoolingBackend { let secret = match cached_secret.value.clone() { Some(secret) => self.config.authentication_config.check_rate_limit( ctx, - config, secret, &user_info.endpoint, true, @@ -91,9 +86,13 @@ impl PoolingBackend { } }; let ep = EndpointIdInt::from(&user_info.endpoint); - let auth_outcome = - crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret) - .await?; + let auth_outcome = crate::auth::validate_password_and_exchange( + &self.config.authentication_config.thread_pool, + ep, + password, + secret, + ) + .await?; let res = match auth_outcome { crate::sasl::Outcome::Success(key) => { info!("user successfully authenticated"); @@ -113,13 +112,13 @@ impl PoolingBackend { pub(crate) async fn authenticate_with_jwt( &self, ctx: &RequestMonitoring, - config: &AuthenticationConfig, user_info: &ComputeUserInfo, jwt: String, ) -> Result { - match &self.config.auth_backend { + match &self.auth_backend { crate::auth::Backend::ControlPlane(console, ()) => { - config + self.config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -140,7 +139,9 @@ impl PoolingBackend { "JWT login over web auth proxy is not supported", )), crate::auth::Backend::Local(_) => { - let keys = config + let keys = self + .config + .authentication_config .jwks_cache .check_jwt( ctx, @@ -185,7 +186,7 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.config.auth_backend.as_ref().map(|()| keys); + let backend = self.auth_backend.as_ref().map(|()| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -217,14 +218,14 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection 
'{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|()| ComputeCredentials { - info: conn_info.user_info.clone(), - keys: crate::auth::backend::ComputeCredentialKeys::None, - }); + let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { + info: ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)), + options: conn_info.user_info.options.clone(), + }, + keys: crate::auth::backend::ComputeCredentialKeys::None, + }); crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { @@ -262,7 +263,7 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = match &self.config.auth_backend { + let mut node_info = match &self.auth_backend { auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => { unreachable!("only local_proxy can connect to local postgres") } @@ -507,8 +508,12 @@ impl ConnectMechanism for HyperMechanism { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - // let port = node_info.config.get_ports().first().unwrap_or_else(10432); - let res = connect_http2(&host, 10432, timeout).await; + let port = *node_info.config.get_ports().first().ok_or_else(|| { + HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress( + "local-proxy port missing on compute address".into(), + )) + })?; + let res = connect_http2(&host, port, timeout).await; drop(pause); let (client, connection) = permit.release_result(res)?; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 9be6b592bd..95f64e972c 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -48,13 +48,14 @@ use std::pin::{pin, Pin}; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, 
warn, Instrument}; +use tracing::{info, warn, Instrument}; use utils::http::error::ApiError; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub async fn task_main( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ws_listener: TcpListener, cancellation_token: CancellationToken, cancellation_handler: Arc, @@ -110,6 +111,7 @@ pub async fn task_main( local_pool, pool: Arc::clone(&conn_pool), config, + auth_backend, endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter), }); let tls_acceptor: Arc = match config.tls_config.as_ref() { @@ -241,7 +243,7 @@ async fn connection_startup( let (conn, peer) = match read_proxy_protocol(conn).await { Ok(c) => c, Err(e) => { - tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); + tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"); return None; } }; @@ -397,6 +399,7 @@ async fn request_handler( async move { if let Err(e) = websocket::serve_websocket( config, + backend.auth_backend, ctx, websocket, cancellation_handler, @@ -405,7 +408,7 @@ async fn request_handler( ) .await { - error!("error in websocket connection: {e:#}"); + warn!("error in websocket connection: {e:#}"); } } .instrument(span), diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f7c3b26917..cf3324926c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo; use crate::auth::endpoint_sni; use crate::auth::ComputeUserInfoParseError; use crate::config::AuthenticationConfig; +use crate::config::HttpConfig; use crate::config::ProxyConfig; use crate::config::TlsConfig; use crate::context::RequestMonitoring; @@ -554,7 +555,7 @@ async fn handle_inner( match conn_info.auth { AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => { - 
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await + handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await } auth => { handle_db_inner( @@ -622,28 +623,17 @@ async fn handle_db_inner( let authenticate_and_connect = Box::pin( async { - let is_local_proxy = - matches!(backend.config.auth_backend, crate::auth::Backend::Local(_)); + let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_)); let keys = match auth { AuthData::Password(pw) => { backend - .authenticate_with_password( - ctx, - &config.authentication_config, - &conn_info.user_info, - &pw, - ) + .authenticate_with_password(ctx, &conn_info.user_info, &pw) .await? } AuthData::Jwt(jwt) => { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await? } }; @@ -691,7 +681,7 @@ async fn handle_db_inner( // Now execute the query and return the result. let json_output = match payload { Payload::Single(stmt) => { - stmt.process(config, cancel, &mut client, parsed_headers) + stmt.process(&config.http_config, cancel, &mut client, parsed_headers) .await? } Payload::Batch(statements) => { @@ -709,7 +699,7 @@ async fn handle_db_inner( } statements - .process(config, cancel, &mut client, parsed_headers) + .process(&config.http_config, cancel, &mut client, parsed_headers) .await? 
} }; @@ -749,7 +739,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[ ]; async fn handle_auth_broker_inner( - config: &'static ProxyConfig, ctx: &RequestMonitoring, request: Request, conn_info: ConnInfo, @@ -757,12 +746,7 @@ async fn handle_auth_broker_inner( backend: Arc, ) -> Result>, SqlOverHttpError> { backend - .authenticate_with_jwt( - ctx, - &config.authentication_config, - &conn_info.user_info, - jwt, - ) + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) .await .map_err(HttpConnError::from)?; @@ -800,7 +784,7 @@ async fn handle_auth_broker_inner( impl QueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -831,7 +815,7 @@ impl QueryData { Either::Right((_cancelled, query)) => { tracing::info!("cancelling query"); if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // wait for the query cancellation match time::timeout(time::Duration::from_millis(100), query).await { @@ -874,7 +858,7 @@ impl QueryData { impl BatchQueryData { async fn process( self, - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, client: &mut Client, parsed_headers: HttpHeaders, @@ -920,7 +904,7 @@ impl BatchQueryData { } Err(SqlOverHttpError::Cancelled(_)) => { if let Err(err) = cancel_token.cancel_query(NoTls).await { - tracing::error!(?err, "could not cancel query"); + tracing::warn!(?err, "could not cancel query"); } // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe. 
discard.discard(); @@ -944,7 +928,7 @@ impl BatchQueryData { } async fn query_batch( - config: &'static ProxyConfig, + config: &'static HttpConfig, cancel: CancellationToken, transaction: &Transaction<'_>, queries: BatchQueryData, @@ -983,7 +967,7 @@ async fn query_batch( } async fn query_to_json( - config: &'static ProxyConfig, + config: &'static HttpConfig, client: &T, data: QueryData, current_size: &mut usize, @@ -1004,9 +988,9 @@ async fn query_to_json( rows.push(row); // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) - if *current_size > config.http_config.max_response_size_bytes { + if *current_size > config.max_response_size_bytes { return Err(SqlOverHttpError::ResponseTooLarge( - config.http_config.max_response_size_bytes, + config.max_response_size_bytes, )); } } diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 08d5da9bef..fd0f0cac7f 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -129,6 +129,7 @@ impl AsyncBufRead for WebSocketRw { pub(crate) async fn serve_websocket( config: &'static ProxyConfig, + auth_backend: &'static crate::auth::Backend<'static, (), ()>, ctx: RequestMonitoring, websocket: OnUpgrade, cancellation_handler: Arc, @@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket( let res = Box::pin(handle_client( config, + auth_backend, &ctx, cancellation_handler, WebSocketRw::new(websocket), diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index bd3e62bc12..ee36ed462d 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -27,7 +27,7 @@ use std::{ }; use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; -use tracing::{error, info, instrument, trace}; +use tracing::{error, info, instrument, trace, warn}; use utils::backoff; use uuid::{NoContext, Timestamp}; @@ -346,7 +346,7 @@ async fn collect_metrics_iteration( error!("metrics endpoint 
refused the sent metrics: {:?}", res); for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) { // Report if the metric value is suspiciously large - error!("potentially abnormal metric value: {:?}", metric); + warn!("potentially abnormal metric value: {:?}", metric); } } } diff --git a/safekeeper/src/auth.rs b/safekeeper/src/auth.rs index c5c9393c00..fdd0830b02 100644 --- a/safekeeper/src/auth.rs +++ b/safekeeper/src/auth.rs @@ -15,15 +15,20 @@ pub fn check_permission(claims: &Claims, tenant_id: Option) -> Result< } Ok(()) } - (Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi | Scope::Scrubber, _) => { - Err(AuthError( - format!( - "JWT scope '{:?}' is ineligible for Safekeeper auth", - claims.scope - ) - .into(), - )) - } + ( + Scope::Admin + | Scope::PageServerApi + | Scope::GenerationsApi + | Scope::Infra + | Scope::Scrubber, + _, + ) => Err(AuthError( + format!( + "JWT scope '{:?}' is ineligible for Safekeeper auth", + claims.scope + ) + .into(), + )), (Scope::SafekeeperData, _) => Ok(()), } } diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 4dd8badd03..46b6f4f2bf 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -636,7 +636,7 @@ async fn handle_tenant_list( } async fn handle_node_register(req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let mut req = match maybe_forward(req).await { ForwardOutcome::Forwarded(res) => { @@ -1182,7 +1182,7 @@ async fn handle_get_safekeeper(req: Request) -> Result, Api /// Assumes information is only relayed to storage controller after first selecting an unique id on /// control plane database, which means we have an id field in the request and payload. 
async fn handle_upsert_safekeeper(mut req: Request) -> Result, ApiError> { - check_permissions(&req, Scope::Admin)?; + check_permissions(&req, Scope::Infra)?; let body = json_request::(&mut req).await?; let id = parse_request_param::(&req, "id")?; diff --git a/storage_scrubber/src/scan_pageserver_metadata.rs b/storage_scrubber/src/scan_pageserver_metadata.rs index c1ea589f7f..cb3299d413 100644 --- a/storage_scrubber/src/scan_pageserver_metadata.rs +++ b/storage_scrubber/src/scan_pageserver_metadata.rs @@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata( tenant_timeline_results.push((ttid, data)); } - let tenant_id = tenant_id.expect("Must be set if results are present"); - if !tenant_timeline_results.is_empty() { + let tenant_id = tenant_id.expect("Must be set if results are present"); analyze_tenant( &remote_client, tenant_id, diff --git a/test_runner/README.md b/test_runner/README.md index d754e60d17..e087241c1f 100644 --- a/test_runner/README.md +++ b/test_runner/README.md @@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan Useful environment variables: `NEON_BIN`: The directory where neon binaries can be found. +`COMPATIBILITY_NEON_BIN`: The directory where the previous version of Neon binaries can be found. `POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found. Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain a subdirectory for each version with naming convention `v{PG_VERSION}/`. Inside that dir, a `bin/postgres` binary should be present. +`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the previous version of postgres distribution can be found. `DEFAULT_PG_VERSION`: The version of Postgres to use, This is used to construct full path to the postgres binaries. Format is 2-digit major version nubmer, i.e. 
`DEFAULT_PG_VERSION=16` @@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder): client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id) ``` +All the tests which rely on NeonEnvBuilder can check the various version combinations of the components. +To do this you may want to add the parametrize decorator with the function fixtures.utils.allpairs_versions() +E.g. + +```python +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_something( +... +``` + For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html At the end of a test, all the nodes in the environment are automatically stopped, so you diff --git a/test_runner/conftest.py b/test_runner/conftest.py index d6e7fcf7ca..4a3194c691 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -6,6 +6,7 @@ pytest_plugins = ( "fixtures.httpserver", "fixtures.compute_reconfigure", "fixtures.storage_controller_proxy", + "fixtures.paths", "fixtures.neon_fixtures", "fixtures.benchmark_fixture", "fixtures.pg_stats", diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 88f9ec1cd0..74fe39ef53 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -7,7 +7,6 @@ import json import os import re import timeit -from collections.abc import Iterator from contextlib import contextmanager from datetime import datetime from pathlib import Path @@ -25,7 +24,8 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import NeonPageserver if TYPE_CHECKING: - from typing import Callable, ClassVar, Optional + from collections.abc import Iterator, Mapping + from typing import Callable, Optional """ @@ -141,6 +141,28 @@ class PgBenchRunResult: ) +# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 +# +# This used to be a class variable on PgBenchInitResult. 
However later versions +# of Python complain: +# +# ValueError: mutable default for field EXTRACTORS is not allowed: use default_factory +# +# When you do what the error tells you to do, it seems to fail our Python 3.9 +# test environment. So let's just move it to a private module constant, and move +# on. +_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = { + "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), + "create_tables": re.compile(r"create tables (\d+\.\d+) s"), + "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), + "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), + "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), + "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), + "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), + "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench +} + + @dataclasses.dataclass class PgBenchInitResult: total: Optional[float] @@ -155,20 +177,6 @@ class PgBenchInitResult: start_timestamp: int end_timestamp: int - # Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171 - EXTRACTORS: ClassVar[dict[str, re.Pattern[str]]] = dataclasses.field( - default_factory=lambda: { - "drop_tables": re.compile(r"drop tables (\d+\.\d+) s"), - "create_tables": re.compile(r"create tables (\d+\.\d+) s"), - "client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"), - "server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"), - "vacuum": re.compile(r"vacuum (\d+\.\d+) s"), - "primary_keys": re.compile(r"primary keys (\d+\.\d+) s"), - "foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"), - "total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench - } - ) - @classmethod def parse_from_stderr( cls, @@ -185,7 +193,7 @@ class PgBenchInitResult: timings: dict[str, Optional[float]] = {} last_line_items = re.split(r"\(|\)|,", last_line) for item in last_line_items: - for 
key, regex in cls.EXTRACTORS.items(): + for key, regex in _PGBENCH_INIT_EXTRACTORS.items(): if (m := regex.match(item.strip())) is not None: if key in timings: raise RuntimeError( diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index 3022c0279f..0ea7148f50 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -6,6 +6,8 @@ from enum import Enum from functools import total_ordering from typing import TYPE_CHECKING, TypeVar +from typing_extensions import override + if TYPE_CHECKING: from typing import Any, Union @@ -31,33 +33,36 @@ class Lsn: self.lsn_int = (int(left, 16) << 32) + int(right, 16) assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF + @override def __str__(self) -> str: """Convert lsn from int to standard hex notation.""" return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}" + @override def __repr__(self) -> str: return f'Lsn("{str(self)}")' def __int__(self) -> int: return self.lsn_int - def __lt__(self, other: Any) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int < other.lsn_int - def __gt__(self, other: Any) -> bool: + def __gt__(self, other: object) -> bool: if not isinstance(other, Lsn): raise NotImplementedError return self.lsn_int > other.lsn_int - def __eq__(self, other: Any) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int == other.lsn_int # Returns the difference between two Lsns, in bytes - def __sub__(self, other: Any) -> int: + def __sub__(self, other: object) -> int: if not isinstance(other, Lsn): return NotImplemented return self.lsn_int - other.lsn_int @@ -70,6 +75,7 @@ class Lsn: else: raise NotImplementedError + @override def __hash__(self) -> int: return hash(self.lsn_int) @@ -116,19 +122,22 @@ class Id: self.id = bytearray.fromhex(x) assert len(self.id) == 16 + @override def 
__str__(self) -> str: return self.id.hex() - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id < other.id - def __eq__(self, other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self.id == other.id + @override def __hash__(self) -> int: return hash(str(self.id)) @@ -139,25 +148,31 @@ class Id: class TenantId(Id): + @override def __repr__(self) -> str: return f'`TenantId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class NodeId(Id): + @override def __repr__(self) -> str: return f'`NodeId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() class TimelineId(Id): + @override def __repr__(self) -> str: return f'TimelineId("{self.id.hex()}")' + @override def __str__(self) -> str: return self.id.hex() @@ -187,7 +202,7 @@ class TenantShardId: assert self.shard_number < self.shard_count or self.shard_count == 0 @classmethod - def parse(cls: type[TTenantShardId], input) -> TTenantShardId: + def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId: if len(input) == 32: return cls( tenant_id=TenantId(input), @@ -203,6 +218,7 @@ class TenantShardId: else: raise ValueError(f"Invalid TenantShardId '{input}'") + @override def __str__(self): if self.shard_count > 0: return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}" @@ -210,22 +226,25 @@ class TenantShardId: # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id) return str(self.tenant_id) + @override def __repr__(self): return self.__str__() def _tuple(self) -> tuple[TenantId, int, int]: return (self.tenant_id, self.shard_number, self.shard_count) - def __lt__(self, other) -> bool: + def __lt__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() < other._tuple() - def __eq__(self, 
other) -> bool: + @override + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return NotImplemented return self._tuple() == other._tuple() + @override def __hash__(self) -> int: return hash(self._tuple()) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index ce191ac91c..2195ae8225 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager # Type-related stuff from pathlib import Path +from typing import TYPE_CHECKING import pytest from _pytest.fixtures import FixtureRequest +from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log @@ -24,6 +26,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pg_stats import PgStatTable +if TYPE_CHECKING: + from collections.abc import Iterator + class PgCompare(ABC): """Common interface of all postgres implementations, useful for benchmarks. 
@@ -65,12 +70,12 @@ class PgCompare(ABC): @contextmanager @abstractmethod - def record_pageserver_writes(self, out_name): + def record_pageserver_writes(self, out_name: str): pass @contextmanager @abstractmethod - def record_duration(self, out_name): + def record_duration(self, out_name: str): pass @contextmanager @@ -122,28 +127,34 @@ class NeonCompare(PgCompare): self._pg = self.env.endpoints.create_start("main", "main", self.tenant) @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg_bin + @override def flush(self, compact: bool = True, gc: bool = True): wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline) self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact) if gc: self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0) + @override def compact(self): self.pageserver_http_client.timeline_compact( self.tenant, self.timeline, wait_until_uploaded=True ) + @override def report_peak_memory_use(self): self.zenbenchmark.record( "peak_mem", @@ -152,6 +163,7 @@ class NeonCompare(PgCompare): report=MetricReport.LOWER_IS_BETTER, ) + @override def report_size(self): timeline_size = self.zenbenchmark.get_timeline_size( self.env.repo_dir, self.tenant, self.timeline @@ -185,9 +197,11 @@ class NeonCompare(PgCompare): "num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER ) + @override def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name) + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -211,26 +225,33 @@ class VanillaCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> 
VanillaPostgres: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin + @override def flush(self, compact: bool = False, gc: bool = False): self.cur.execute("checkpoint") + @override def compact(self): pass + @override def report_peak_memory_use(self): pass # TODO find something + @override def report_size(self): data_size = self.pg.get_subdir_size(Path("base")) self.zenbenchmark.record( @@ -245,6 +266,7 @@ class VanillaCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) @@ -261,28 +283,35 @@ class RemoteCompare(PgCompare): self.cur = self.conn.cursor() @property + @override def pg(self) -> PgProtocol: return self._pg @property + @override def zenbenchmark(self) -> NeonBenchmarker: return self._zenbenchmark @property + @override def pg_bin(self) -> PgBin: return self._pg.pg_bin - def flush(self): + @override + def flush(self, compact: bool = False, gc: bool = False): # TODO: flush the remote pageserver pass + @override def compact(self): pass + @override def report_peak_memory_use(self): # TODO: get memory usage from remote pageserver pass + @override def report_size(self): # TODO: get storage size from remote pageserver pass @@ -291,6 +320,7 @@ class RemoteCompare(PgCompare): def record_pageserver_writes(self, out_name: str) -> Iterator[None]: yield # Do nothing + @override def record_duration(self, out_name: str) -> _GeneratorContextManager[None]: return self.zenbenchmark.record_duration(out_name) diff --git a/test_runner/fixtures/compute_reconfigure.py b/test_runner/fixtures/compute_reconfigure.py index d2305ea431..6354b7f833 100644 --- a/test_runner/fixtures/compute_reconfigure.py +++ b/test_runner/fixtures/compute_reconfigure.py @@ -1,27 +1,31 
@@ from __future__ import annotations import concurrent.futures -from typing import Any +from typing import TYPE_CHECKING import pytest +from pytest_httpserver import HTTPServer from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response from fixtures.common_types import TenantId from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Callable, Optional + class ComputeReconfigure: - def __init__(self, server): + def __init__(self, server: HTTPServer): self.server = server self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach" - self.workloads = {} - self.on_notify = None + self.workloads: dict[TenantId, Any] = {} + self.on_notify: Optional[Callable[[Any], None]] = None - def register_workload(self, workload): + def register_workload(self, workload: Any): self.workloads[workload.tenant_id] = workload - def register_on_notify(self, fn): + def register_on_notify(self, fn: Optional[Callable[[Any], None]]): """ Add some extra work during a notification, like sleeping to slow things down, or logging what was notified. @@ -30,7 +34,7 @@ class ComputeReconfigure: @pytest.fixture(scope="function") -def compute_reconfigure_listener(make_httpserver): +def compute_reconfigure_listener(make_httpserver: HTTPServer): """ This fixture exposes an HTTP listener for the storage controller to submit compute notifications to us, instead of updating neon_local endpoints itself. @@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver): # accept a healthy rate of calls into notify-attach. 
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1) - def handler(request: Request): + def handler(request: Request) -> Response: assert request.json is not None body: dict[str, Any] = request.json log.info(f"notify-attach request: {body}") diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py index 4ca87520a0..01634a29c5 100644 --- a/test_runner/fixtures/flaky.py +++ b/test_runner/fixtures/flaky.py @@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels from fixtures.log_helper import log if TYPE_CHECKING: + from collections.abc import MutableMapping from typing import Any + """ The plugin reruns flaky tests. It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` diff --git a/test_runner/fixtures/httpserver.py b/test_runner/fixtures/httpserver.py index 9d5b5d6422..f653fd804c 100644 --- a/test_runner/fixtures/httpserver.py +++ b/test_runner/fixtures/httpserver.py @@ -1,8 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pytest from pytest_httpserver import HTTPServer +if TYPE_CHECKING: + from collections.abc import Iterator + + from fixtures.port_distributor import PortDistributor + # TODO: mypy fails with: # Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined] # from fixtures.neon_fixtures import PortDistributor @@ -17,7 +24,7 @@ def httpserver_ssl_context(): @pytest.fixture(scope="function") -def make_httpserver(httpserver_listen_address, httpserver_ssl_context): +def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]: host, port = httpserver_listen_address if not host: host = HTTPServer.DEFAULT_LISTEN_HOST @@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context): @pytest.fixture(scope="function") -def httpserver(make_httpserver): +def httpserver(make_httpserver: 
HTTPServer) -> Iterator[HTTPServer]: server = make_httpserver yield server server.clear() @pytest.fixture(scope="function") -def httpserver_listen_address(port_distributor) -> tuple[str, int]: +def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]: port = port_distributor.get_port() return ("localhost", port) diff --git a/test_runner/fixtures/log_helper.py b/test_runner/fixtures/log_helper.py index 70d76a39c4..ebf5c8d803 100644 --- a/test_runner/fixtures/log_helper.py +++ b/test_runner/fixtures/log_helper.py @@ -31,7 +31,7 @@ LOGGING = { } -def getLogger(name="root") -> logging.Logger: +def getLogger(name: str = "root") -> logging.Logger: """Method to get logger for tests. Should be used to get correctly initialized logger.""" diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index adc90a41d0..e056ea77d4 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -22,7 +22,7 @@ class Metrics: def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]: filter = filter or {} - res = [] + res: list[Sample] = [] for sample in self.metrics[name]: try: @@ -59,7 +59,7 @@ class MetricsGetter: return results[0].value def get_metrics_values( - self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False + self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False ) -> dict[str, float]: """ When fetching multiple named metrics, it is more efficient to use this diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 846a790f1f..5934baccff 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast import requests if TYPE_CHECKING: - from typing import Any, Literal, Optional, Union + from typing import Any, Literal, Optional from fixtures.pg_version import PgVersion @@ -25,9 +25,7 @@ class NeonAPI: 
self.__neon_api_key = neon_api_key self.__neon_api_base_url = neon_api_base_url.strip("/") - def __request( - self, method: Union[str, bytes], endpoint: str, **kwargs: Any - ) -> requests.Response: + def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response: if "headers" not in kwargs: kwargs["headers"] = {} kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}" diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 5cb9821476..7789855fe4 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -18,7 +18,6 @@ from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime from enum import Enum -from fcntl import LOCK_EX, LOCK_UN, flock from functools import cached_property from pathlib import Path from types import TracebackType @@ -59,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import ( wait_for_last_record_lsn, ) +from fixtures.paths import get_test_repo_dir, shared_snapshot_dir from fixtures.pg_version import PgVersion from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import ( @@ -75,8 +75,8 @@ from fixtures.safekeeper.http import SafekeeperHttpClient from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, + COMPONENT_BINARIES, allure_add_grafana_links, - allure_attach_from_dir, assert_no_errors, get_dir_size, print_gc_result, @@ -96,6 +96,8 @@ if TYPE_CHECKING: Union, ) + from fixtures.paths import SnapshotDirLocked + T = TypeVar("T") @@ -118,65 +120,11 @@ put directly-importable functions into utils.py or another separate file. 
Env = dict[str, str] -DEFAULT_OUTPUT_DIR: str = "test_output" DEFAULT_BRANCH_NAME: str = "main" BASE_PORT: int = 15000 -@pytest.fixture(scope="session") -def base_dir() -> Iterator[Path]: - # find the base directory (currently this is the git root) - base_dir = Path(__file__).parents[2] - log.info(f"base_dir is {base_dir}") - - yield base_dir - - -@pytest.fixture(scope="function") -def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: - if os.getenv("REMOTE_ENV"): - # we are in remote env and do not have neon binaries locally - # this is the case for benchmarks run on self-hosted runner - return - - # Find the neon binaries. - if env_neon_bin := os.environ.get("NEON_BIN"): - binpath = Path(env_neon_bin) - else: - binpath = base_dir / "target" / build_type - log.info(f"neon_binpath is {binpath}") - - if not (binpath / "pageserver").exists(): - raise Exception(f"neon binaries not found at '{binpath}'") - - yield binpath - - -@pytest.fixture(scope="session") -def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: - if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): - distrib_dir = Path(env_postgres_bin).resolve() - else: - distrib_dir = base_dir / "pg_install" - - log.info(f"pg_distrib_dir is {distrib_dir}") - yield distrib_dir - - -@pytest.fixture(scope="session") -def top_output_dir(base_dir: Path) -> Iterator[Path]: - # Compute the top-level directory for all tests. 
- if env_test_output := os.environ.get("TEST_OUTPUT"): - output_dir = Path(env_test_output).resolve() - else: - output_dir = base_dir / DEFAULT_OUTPUT_DIR - output_dir.mkdir(exist_ok=True) - - log.info(f"top_output_dir is {output_dir}") - yield output_dir - - @pytest.fixture(scope="session") def neon_api_key() -> str: api_key = os.getenv("NEON_API_KEY") @@ -369,11 +317,14 @@ class NeonEnvBuilder: run_id: uuid.UUID, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, test_name: str, top_output_dir: Path, test_output_dir: Path, + combination, test_overlay_dir: Optional[Path] = None, pageserver_remote_storage: Optional[RemoteStorage] = None, # toml that will be decomposed into `--config-override` flags during `pageserver --init` @@ -455,6 +406,19 @@ class NeonEnvBuilder: "test_" ), "Unexpectedly instantiated from outside a test function" self.test_name = test_name + self.compatibility_neon_binpath = compatibility_neon_binpath + self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir + self.version_combination = combination + self.mixdir = self.test_output_dir / "mixdir_neon" + if self.version_combination is not None: + assert ( + self.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions" + assert ( + self.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions" + self.mixdir.mkdir(mode=0o755, exist_ok=True) + self._mix_versions() def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv: # Cannot create more than one environment from one builder @@ -655,6 +619,21 @@ class NeonEnvBuilder: return self.env + def _mix_versions(self): + assert self.version_combination is not None, "version combination must be set" + for component, paths in 
COMPONENT_BINARIES.items(): + directory = ( + self.neon_binpath + if self.version_combination[component] == "new" + else self.compatibility_neon_binpath + ) + for filename in paths: + destination = self.mixdir / filename + destination.symlink_to(directory / filename) + if self.version_combination["compute"] == "old": + self.pg_distrib_dir = self.compatibility_pg_distrib_dir + self.neon_binpath = self.mixdir + def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path): """ Mount `srcdir` as an overlayfs mount at `dstdir`. @@ -1403,7 +1382,9 @@ def neon_simple_env( top_output_dir: Path, test_output_dir: Path, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, pageserver_virtual_file_io_engine: str, pageserver_aux_file_policy: Optional[AuxFileStore], @@ -1418,6 +1399,11 @@ def neon_simple_env( # Create the environment in the per-test output directory repo_dir = get_test_repo_dir(request, top_output_dir) + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) with NeonEnvBuilder( top_output_dir=top_output_dir, @@ -1425,7 +1411,9 @@ def neon_simple_env( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), @@ -1435,6 +1423,7 @@ def neon_simple_env( pageserver_aux_file_policy=pageserver_aux_file_policy, pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm, pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode, + combination=combination, ) as builder: env = builder.init_start() @@ -1448,7 +1437,9 @@ def neon_env_builder( 
port_distributor: PortDistributor, mock_s3_server: MockS3Server, neon_binpath: Path, + compatibility_neon_binpath: Path, pg_distrib_dir: Path, + compatibility_pg_distrib_dir: Path, pg_version: PgVersion, run_id: uuid.UUID, request: FixtureRequest, @@ -1475,6 +1466,11 @@ def neon_env_builder( # Create the environment in the test-specific output dir repo_dir = os.path.join(test_output_dir, "repo") + combination = ( + request._pyfuncitem.callspec.params["combination"] + if "combination" in request._pyfuncitem.callspec.params + else None + ) # Return the builder to the caller with NeonEnvBuilder( @@ -1483,7 +1479,10 @@ def neon_env_builder( port_distributor=port_distributor, mock_s3_server=mock_s3_server, neon_binpath=neon_binpath, + compatibility_neon_binpath=compatibility_neon_binpath, pg_distrib_dir=pg_distrib_dir, + compatibility_pg_distrib_dir=compatibility_pg_distrib_dir, + combination=combination, pg_version=pg_version, run_id=run_id, preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")), @@ -3657,7 +3656,7 @@ class Endpoint(PgProtocol, LogUtils): config_lines: Optional[list[str]] = None, remote_ext_config: Optional[str] = None, pageserver_id: Optional[int] = None, - allow_multiple=False, + allow_multiple: bool = False, basebackup_request_tries: Optional[int] = None, ) -> Endpoint: """ @@ -3998,7 +3997,7 @@ class Safekeeper(LogUtils): def timeline_dir(self, tenant_id, timeline_id) -> Path: return self.data_dir / str(tenant_id) / str(timeline_id) - # List partial uploaded segments of this safekeeper. Works only for + # list partial uploaded segments of this safekeeper. Works only for # RemoteStorageKind.LOCAL_FS. 
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId): tline_path = ( @@ -4246,44 +4245,6 @@ class StorageScrubber: raise -def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path: - """Compute the path to a working directory for an individual test.""" - test_name = request.node.name - test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}" - - # We rerun flaky tests multiple times, use a separate directory for each run. - if (suffix := getattr(request.node, "execution_count", None)) is not None: - test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" - - log.info(f"get_test_output_dir is {test_dir}") - # make mypy happy - assert isinstance(test_dir, Path) - return test_dir - - -def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - The working directory for a test. - """ - return _get_test_dir(request, top_output_dir, "") - - -def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - """ - Directory that contains `upperdir` and `workdir` for overlayfs mounts - that a test creates. See `NeonEnvBuilder.overlay_mount`. - """ - return _get_test_dir(request, top_output_dir, "overlay-") - - -def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: - return top_output_dir / "shared-snapshots" / snapshot_name - - -def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: - return get_test_output_dir(request, top_output_dir) / "repo" - - def pytest_addoption(parser: Parser): parser.addoption( "--preserve-database-files", @@ -4293,154 +4254,11 @@ def pytest_addoption(parser: Parser): ) -SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile( r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)" ) -# This is autouse, so the test output directory always gets created, even -# if a test doesn't put anything there. 
-# -# NB: we request the overlay dir fixture so the fixture does its cleanups -@pytest.fixture(scope="function", autouse=True) -def test_output_dir( - request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path -) -> Iterator[Path]: - """Create the working directory for an individual test.""" - - # one directory per test - test_dir = get_test_output_dir(request, top_output_dir) - log.info(f"test_output_dir is {test_dir}") - shutil.rmtree(test_dir, ignore_errors=True) - test_dir.mkdir() - - yield test_dir - - # Allure artifacts creation might involve the creation of `.tar.zst` archives, - # which aren't going to be used if Allure results collection is not enabled - # (i.e. --alluredir is not set). - # Skip `allure_attach_from_dir` in this case - if not request.config.getoption("--alluredir"): - return - - preserve_database_files = False - for k, v in request.node.user_properties: - # NB: the neon_env_builder fixture uses this fixture (test_output_dir). - # So, neon_env_builder's cleanup runs before here. - # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. - if k == "preserve_database_files": - assert isinstance(v, bool) - preserve_database_files = v - - allure_attach_from_dir(test_dir, preserve_database_files) - - -class FileAndThreadLock: - def __init__(self, path: Path): - self.path = path - self.thread_lock = threading.Lock() - self.fd: Optional[int] = None - - def __enter__(self): - self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) - # lock thread lock before file lock so that there's no race - # around flocking / funlocking the file lock - self.thread_lock.acquire() - flock(self.fd, LOCK_EX) - - def __exit__(self, exc_type, exc_value, exc_traceback): - assert self.fd is not None - assert self.thread_lock.locked() # ... 
by us - flock(self.fd, LOCK_UN) - self.thread_lock.release() - os.close(self.fd) - self.fd = None - - -class SnapshotDirLocked: - def __init__(self, parent: SnapshotDir): - self._parent = parent - - def is_initialized(self): - # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. - # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. - return self._parent._marker_file_path.exists() - - def set_initialized(self): - self._parent._marker_file_path.write_text("") - - @property - def path(self) -> Path: - return self._parent._path / "snapshot" - - -class SnapshotDir: - _path: Path - - def __init__(self, path: Path): - self._path = path - assert self._path.is_dir() - self._lock = FileAndThreadLock(self._lock_file_path) - - @property - def _lock_file_path(self) -> Path: - return self._path / "initializing.flock" - - @property - def _marker_file_path(self) -> Path: - return self._path / "initialized.marker" - - def __enter__(self) -> SnapshotDirLocked: - self._lock.__enter__() - return SnapshotDirLocked(self) - - def __exit__(self, exc_type, exc_value, exc_traceback): - self._lock.__exit__(exc_type, exc_value, exc_traceback) - - -def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir: - snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) - snapshot_dir_path.mkdir(exist_ok=True, parents=True) - return SnapshotDir(snapshot_dir_path) - - -@pytest.fixture(scope="function") -def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: - """ - Idempotently create a test's overlayfs mount state directory. - If the functionality isn't enabled via env var, returns None. - - The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). 
- """ - - if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: - return None - - overlay_dir = get_test_overlay_dir(request, top_output_dir) - log.info(f"test_overlay_dir is {overlay_dir}") - - overlay_dir.mkdir(exist_ok=True) - # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` - for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): - cmd = ["sudo", "umount", str(mountpoint)] - log.info( - f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" - ) - subprocess.run(cmd, capture_output=True, check=True) - # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. - cmd = ["sudo", "rm", "-rf", str(overlay_dir)] - subprocess.run(cmd, capture_output=True, check=True) - - overlay_dir.mkdir() - - return overlay_dir - - # no need to clean up anything: on clean shutdown, - # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup - # and on unclean shutdown, this function will take care of it - # on the next test run - - SKIP_DIRS = frozenset( ( "pg_wal", diff --git a/test_runner/fixtures/overlayfs.py b/test_runner/fixtures/overlayfs.py index e0ebfeb8f4..ea11cd272c 100644 --- a/test_runner/fixtures/overlayfs.py +++ b/test_runner/fixtures/overlayfs.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Iterator from pathlib import Path +from typing import TYPE_CHECKING import psutil +if TYPE_CHECKING: + from collections.abc import Iterator + def iter_mounts_beneath(topdir: Path) -> Iterator[Path]: """ diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 84a7e5f0a2..aa4435af4e 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -886,7 +886,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, - 
batch_size: int | None = None, + batch_size: Optional[int] = None, **kwargs, ) -> set[TimelineId]: params = {} diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 3bbac4b8ee..4114c2fcb3 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -9,7 +9,12 @@ import toml from _pytest.python import Metafunc from fixtures.pg_version import PgVersion -from fixtures.utils import AuxFileStore + +if TYPE_CHECKING: + from typing import Any, Optional + + from fixtures.utils import AuxFileStore + if TYPE_CHECKING: from typing import Any, Optional diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py new file mode 100644 index 0000000000..65f8e432b0 --- /dev/null +++ b/test_runner/fixtures/paths.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import threading +from fcntl import LOCK_EX, LOCK_UN, flock +from pathlib import Path +from types import TracebackType +from typing import TYPE_CHECKING + +import pytest +from pytest import FixtureRequest + +from fixtures import overlayfs +from fixtures.log_helper import log +from fixtures.utils import allure_attach_from_dir + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Optional + + +DEFAULT_OUTPUT_DIR: str = "test_output" + + +def get_test_dir( + request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None +) -> Path: + """Compute the path to a working directory for an individual test.""" + test_name = request.node.name + test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}" + + # We rerun flaky tests multiple times, use a separate directory for each run. 
+ if (suffix := getattr(request.node, "execution_count", None)) is not None: + test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" + + return test_dir + + +def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + The working directory for a test. + """ + return get_test_dir(request, top_output_dir) + + +def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + """ + Directory that contains `upperdir` and `workdir` for overlayfs mounts + that a test creates. See `NeonEnvBuilder.overlay_mount`. + """ + return get_test_dir(request, top_output_dir, "overlay-") + + +def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path: + return top_output_dir / "shared-snapshots" / snapshot_name + + +def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path: + return get_test_output_dir(request, top_output_dir) / "repo" + + +@pytest.fixture(scope="session") +def base_dir() -> Iterator[Path]: + # find the base directory (currently this is the git root) + base_dir = Path(__file__).parents[2] + log.info(f"base_dir is {base_dir}") + + yield base_dir + + +@pytest.fixture(scope="session") +def compute_config_dir(base_dir: Path) -> Iterator[Path]: + """ + Retrieve the path to the compute configuration directory. + """ + yield base_dir / "compute" / "etc" + + +@pytest.fixture(scope="function") +def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + # we are in remote env and do not have neon binaries locally + # this is the case for benchmarks run on self-hosted runner + return + + # Find the neon binaries. 
+ if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + binpath = base_dir / "target" / build_type + log.info(f"neon_binpath is {binpath}") + + if not (binpath / "pageserver").exists(): + raise Exception(f"neon binaries not found at '{binpath}'") + + yield binpath.absolute() + + +@pytest.fixture(scope="session") +def compatibility_snapshot_dir() -> Iterator[Path]: + if os.getenv("REMOTE_ENV"): + return + compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") + assert ( + compatibility_snapshot_dir_env is not None + ), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" + compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() + yield compatibility_snapshot_dir + + +@pytest.fixture(scope="session") +def compatibility_neon_binpath() -> Optional[Iterator[Path]]: + if os.getenv("REMOTE_ENV"): + return + comp_binpath = None + if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"): + comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute() + yield comp_binpath + + +@pytest.fixture(scope="session") +def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: + if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"): + distrib_dir = Path(env_postgres_bin).resolve() + else: + distrib_dir = base_dir / "pg_install" + + log.info(f"pg_distrib_dir is {distrib_dir}") + yield distrib_dir + + +@pytest.fixture(scope="session") +def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]: + compat_distrib_dir = None + if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"): + compat_distrib_dir = Path(env_compat_postgres_bin).resolve() + if not compat_distrib_dir.exists(): + raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}") + + if compat_distrib_dir: + 
log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}") + yield compat_distrib_dir + + +@pytest.fixture(scope="session") +def top_output_dir(base_dir: Path) -> Iterator[Path]: + # Compute the top-level directory for all tests. + if env_test_output := os.environ.get("TEST_OUTPUT"): + output_dir = Path(env_test_output).resolve() + else: + output_dir = base_dir / DEFAULT_OUTPUT_DIR + output_dir.mkdir(exist_ok=True) + + log.info(f"top_output_dir is {output_dir}") + yield output_dir + + +# This is autouse, so the test output directory always gets created, even +# if a test doesn't put anything there. +# +# NB: we request the overlay dir fixture so the fixture does its cleanups +@pytest.fixture(scope="function", autouse=True) +def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]: + """Create the working directory for an individual test.""" + + # one directory per test + test_dir = get_test_output_dir(request, top_output_dir) + log.info(f"test_output_dir is {test_dir}") + shutil.rmtree(test_dir, ignore_errors=True) + test_dir.mkdir() + + yield test_dir + + # Allure artifacts creation might involve the creation of `.tar.zst` archives, + # which aren't going to be used if Allure results collection is not enabled + # (i.e. --alluredir is not set). + # Skip `allure_attach_from_dir` in this case + if not request.config.getoption("--alluredir"): + return + + preserve_database_files = False + for k, v in request.node.user_properties: + # NB: the neon_env_builder fixture uses this fixture (test_output_dir). + # So, neon_env_builder's cleanup runs before here. + # The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property. 
+ if k == "preserve_database_files": + assert isinstance(v, bool) + preserve_database_files = v + + allure_attach_from_dir(test_dir, preserve_database_files) + + +class FileAndThreadLock: + def __init__(self, path: Path): + self.path = path + self.thread_lock = threading.Lock() + self.fd: Optional[int] = None + + def __enter__(self): + self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) + # lock thread lock before file lock so that there's no race + # around flocking / funlocking the file lock + self.thread_lock.acquire() + flock(self.fd, LOCK_EX) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + assert self.fd is not None + assert self.thread_lock.locked() # ... by us + flock(self.fd, LOCK_UN) + self.thread_lock.release() + os.close(self.fd) + self.fd = None + + +class SnapshotDirLocked: + def __init__(self, parent: SnapshotDir): + self._parent = parent + + def is_initialized(self): + # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized. + # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed. 
+ return self._parent.marker_file_path.exists() + + def set_initialized(self): + self._parent.marker_file_path.write_text("") + + @property + def path(self) -> Path: + return self._parent.path / "snapshot" + + +class SnapshotDir: + _path: Path + + def __init__(self, path: Path): + self._path = path + assert self._path.is_dir() + self._lock = FileAndThreadLock(self.lock_file_path) + + @property + def path(self) -> Path: + return self._path + + @property + def lock_file_path(self) -> Path: + return self._path / "initializing.flock" + + @property + def marker_file_path(self) -> Path: + return self._path / "initialized.marker" + + def __enter__(self) -> SnapshotDirLocked: + self._lock.__enter__() + return SnapshotDirLocked(self) + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ): + self._lock.__exit__(exc_type, exc_value, exc_traceback) + + +def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir: + snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident) + snapshot_dir_path.mkdir(exist_ok=True, parents=True) + return SnapshotDir(snapshot_dir_path) + + +@pytest.fixture(scope="function") +def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: + """ + Idempotently create a test's overlayfs mount state directory. + If the functionality isn't enabled via env var, returns None. + + The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc). 
+ """ + + if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None: + return None + + overlay_dir = get_test_overlay_dir(request, top_output_dir) + log.info(f"test_overlay_dir is {overlay_dir}") + + overlay_dir.mkdir(exist_ok=True) + # unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir` + for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)): + cmd = ["sudo", "umount", str(mountpoint)] + log.info( + f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}" + ) + subprocess.run(cmd, capture_output=True, check=True) + # the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work. + cmd = ["sudo", "rm", "-rf", str(overlay_dir)] + subprocess.run(cmd, capture_output=True, check=True) + + overlay_dir.mkdir() + + return overlay_dir + + # no need to clean up anything: on clean shutdown, + # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup + # and on unclean shutdown, this function will take care of it + # on the next test run diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 5820b50a46..01f0245665 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -2,9 +2,14 @@ from __future__ import annotations import enum import os -from typing import Optional +from typing import TYPE_CHECKING import pytest +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Optional + """ This fixture is used to determine which version of Postgres to use for tests. 
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum): NOT_SET = "<-POSTRGRES VERSION IS NOT SET->" # Make it less confusing in logs + @override def __repr__(self) -> str: return f"'{self.value}'" # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums + @override def __str__(self) -> str: return self.value @@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum): return f"v{self.value}" @classmethod - def _missing_(cls, value) -> Optional[PgVersion]: + @override + def _missing_(cls, value: object) -> Optional[PgVersion]: known_values = {v.value for _, v in cls.__members__.items()} # Allow passing version as a string with "v" prefix (e.g. "v14") diff --git a/test_runner/fixtures/port_distributor.py b/test_runner/fixtures/port_distributor.py index 435f452a02..df0eb2a809 100644 --- a/test_runner/fixtures/port_distributor.py +++ b/test_runner/fixtures/port_distributor.py @@ -59,10 +59,7 @@ class PortDistributor: if isinstance(value, int): return self._replace_port_int(value) - if isinstance(value, str): - return self._replace_port_str(value) - - raise TypeError(f"unsupported type {type(value)} of {value=}") + return self._replace_port_str(value) def _replace_port_int(self, value: int) -> int: known_port = self.port_map.get(value) @@ -75,7 +72,7 @@ class PortDistributor: # Use regex to find port in a string # urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432" # See https://bugs.python.org/issue27657 - ports = re.findall(r":(\d+)(?:/|$)", value) + ports: list[str] = re.findall(r":(\d+)(?:/|$)", value) assert len(ports) == 1, f"can't find port in {value}" port_int = int(ports[0]) diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index 20e6bd9318..7024953661 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -13,6 +13,7 @@ import boto3 import toml from moto.server import ThreadedMotoServer from mypy_boto3_s3 
import S3Client +from typing_extensions import override from fixtures.common_types import TenantId, TenantShardId, TimelineId from fixtures.log_helper import log @@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum): EXTENSIONS = "ext" SAFEKEEPER = "safekeeper" + @override def __str__(self) -> str: return self.value @@ -81,11 +83,13 @@ class LocalFsStorage: def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path: return self.tenant_path(tenant_id) / "timelines" / str(timeline_id) - def timeline_latest_generation(self, tenant_id, timeline_id): + def timeline_latest_generation( + self, tenant_id: TenantId, timeline_id: TimelineId + ) -> Optional[int]: timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id)) index_parts = [f for f in timeline_files if f.startswith("index_part")] - def parse_gen(filename): + def parse_gen(filename: str) -> Optional[int]: log.info(f"parsing index_part '{filename}'") parts = filename.split("-") if len(parts) == 2: @@ -93,7 +97,7 @@ class LocalFsStorage: else: return None - generations = sorted([parse_gen(f) for f in index_parts]) + generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore if len(generations) == 0: raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}") return generations[-1] @@ -122,14 +126,14 @@ class LocalFsStorage: filename = f"{local_name}-{generation:08x}" return self.timeline_path(tenant_id, timeline_id) / filename - def index_content(self, tenant_id: TenantId, timeline_id: TimelineId): + def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any: with self.index_path(tenant_id, timeline_id).open("r") as f: return json.load(f) def heatmap_path(self, tenant_id: TenantId) -> Path: return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME - def heatmap_content(self, tenant_id): + def heatmap_content(self, tenant_id: TenantId) -> Any: with self.heatmap_path(tenant_id).open("r") as f: return json.load(f) @@ -297,7 +301,7 @@ 
class S3Storage: def heatmap_key(self, tenant_id: TenantId) -> str: return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}" - def heatmap_content(self, tenant_id: TenantId): + def heatmap_content(self, tenant_id: TenantId) -> Any: r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id)) return json.loads(r["Body"].read().decode("utf-8")) @@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum): def configure( self, repo_dir: Path, - mock_s3_server, + mock_s3_server: MockS3Server, run_id: str, test_name: str, user: RemoteStorageUser, @@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind: def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_dict() # serialize as toml inline table def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str: - if not isinstance(remote_storage, (LocalFsStorage, S3Storage)): - raise Exception("invalid remote storage type") - return remote_storage.to_toml_inline_table() diff --git a/test_runner/fixtures/storage_controller_proxy.py b/test_runner/fixtures/storage_controller_proxy.py index 02cf6fc33f..c174358ef5 100644 --- a/test_runner/fixtures/storage_controller_proxy.py +++ b/test_runner/fixtures/storage_controller_proxy.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Any, Optional +from typing import TYPE_CHECKING import pytest import requests @@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response from fixtures.log_helper import log +if TYPE_CHECKING: + from typing import Any, Optional + class StorageControllerProxy: def __init__(self, server: HTTPServer): @@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response: @pytest.fixture(scope="function") -def storage_controller_proxy(make_httpserver): +def 
storage_controller_proxy(make_httpserver: HTTPServer): """ Proxies requests into the storage controller to the currently selected storage controller instance via `StorageControllerProxy.route_to`. @@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver): log.info(f"Storage controller proxy listening on {self.listen}") - def handler(request: Request): + def handler(request: Request) -> Response: if self.route_to is None: log.info(f"Storage controller proxy has no routing configured for {request.url}") return Response("Routing not configured", status=503) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 23381e258a..76575d330c 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -18,6 +18,7 @@ from urllib.parse import urlencode import allure import zstandard from psycopg2.extensions import cursor +from typing_extensions import override from fixtures.log_helper import log from fixtures.pageserver.common_types import ( @@ -26,28 +27,45 @@ from fixtures.pageserver.common_types import ( ) if TYPE_CHECKING: - from typing import ( - IO, - Optional, - Union, - ) + from collections.abc import Iterable + from typing import IO, Optional + from fixtures.common_types import TimelineId from fixtures.neon_fixtures import PgBin -from fixtures.common_types import TimelineId + + WaitUntilRet = TypeVar("WaitUntilRet") + Fn = TypeVar("Fn", bound=Callable[..., Any]) +COMPONENT_BINARIES = { + "storage_controller": ("storage_controller",), + "storage_broker": ("storage_broker",), + "compute": ("compute_ctl",), + "safekeeper": ("safekeeper",), + "pageserver": ("pageserver", "pagectl"), +} +# Disable auto-formatting for better readability +# fmt: off +VERSIONS_COMBINATIONS = ( + {"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"}, + 
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"}, + {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"}, + {"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"}, +) +# fmt: on def subprocess_capture( capture_dir: Path, cmd: list[str], *, - check=False, - echo_stderr=False, - echo_stdout=False, - capture_stdout=False, - timeout=None, - with_command_header=True, + check: bool = False, + echo_stderr: bool = False, + echo_stdout: bool = False, + capture_stdout: bool = False, + timeout: Optional[float] = None, + with_command_header: bool = True, **popen_kwargs: Any, ) -> tuple[str, Optional[str], int]: """Run a process and bifurcate its output to files and the `log` logger @@ -84,6 +102,7 @@ def subprocess_capture( self.capture = capture self.captured = "" + @override def run(self): first = with_command_header for line in self.in_file: @@ -165,10 +184,10 @@ def global_counter() -> int: def print_gc_result(row: dict[str, Any]): log.info("GC duration {elapsed} ms".format_map(row)) log.info( - " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" - " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map( - row - ) + ( + " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}" + " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}" + ).format_map(row) ) @@ -226,7 +245,7 @@ def get_scale_for_db(size_mb: int) -> int: return round(0.06689 * size_mb - 0.5) -ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg] +ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile( 
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)" ) @@ -289,7 +308,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz" def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int): """Add links to server logs in Grafana to Allure report""" - links = {} + links: dict[str, str] = {} # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build endpoint_id, region_id, _ = host.split(".", 2) @@ -341,7 +360,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, def start_in_background( - command: list[str], cwd: Path, log_file_name: str, is_started: Fn + command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet] ) -> subprocess.Popen[bytes]: """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors.""" @@ -376,14 +395,11 @@ def start_in_background( return spawned_process -WaitUntilRet = TypeVar("WaitUntilRet") - - def wait_until( number_of_iterations: int, interval: float, func: Callable[[], WaitUntilRet], - show_intermediate_error=False, + show_intermediate_error: bool = False, ) -> WaitUntilRet: """ Wait until 'func' returns successfully, without exception. 
Returns the @@ -464,7 +480,7 @@ def humantime_to_ms(humantime: str) -> float: def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]: # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py error_or_warn = re.compile(r"\s(ERROR|WARN)") - errors = [] + errors: list[tuple[int, str]] = [] for lineno, line in enumerate(input, start=1): if len(line) == 0: continue @@ -484,7 +500,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list return errors -def assert_no_errors(log_file, service, allowed_errors): +def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]): if not log_file.exists(): log.warning(f"Skipping {service} log check: {log_file} does not exist") return @@ -504,9 +520,11 @@ class AuxFileStore(str, enum.Enum): V2 = "v2" CrossValidation = "cross-validation" + @override def __repr__(self) -> str: return f"'aux-{self.value}'" + @override def __str__(self) -> str: return f"'aux-{self.value}'" @@ -525,7 +543,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str """ started_at = time.time() - def hash_extracted(reader: Union[IO[bytes], None]) -> bytes: + def hash_extracted(reader: Optional[IO[bytes]]) -> bytes: assert reader is not None digest = sha256(usedforsecurity=False) while True: @@ -550,7 +568,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str right_list ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}" - mismatching = set() + mismatching: set[str] = set() for left_tuple, right_tuple in zip(left_list, right_list): left_path, left_hash = left_tuple @@ -575,6 +593,7 @@ class PropagatingThread(threading.Thread): Simple Thread wrapper with join() propagating the possible exception in the thread. 
""" + @override def run(self): self.exc = None try: @@ -582,7 +601,8 @@ class PropagatingThread(threading.Thread): except BaseException as e: self.exc = e - def join(self, timeout=None): + @override + def join(self, timeout: Optional[float] = None) -> Any: super().join(timeout) if self.exc: raise self.exc @@ -604,3 +624,19 @@ def human_bytes(amt: float) -> str: amt = amt / 1024 raise RuntimeError("unreachable") + + +def allpairs_versions(): + """ + Returns a dictionary with arguments for pytest parametrize + to test compatibility with the previous version of Neon components. + The combinations were pre-computed to cover all pairs of components with + different versions. + """ + ids = [] + for pair in VERSIONS_COMBINATIONS: + cur_id = [] + for component in sorted(pair.keys()): + cur_id.append(pair[component][0]) + ids.append(f"combination_{''.join(cur_id)}") + return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids} diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index 4f9c1125bf..e869c43185 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -1,7 +1,7 @@ from __future__ import annotations import threading -from typing import Any, Optional +from typing import TYPE_CHECKING from fixtures.common_types import TenantId, TimelineId from fixtures.log_helper import log @@ -14,6 +14,9 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_for_last_record_lsn +if TYPE_CHECKING: + from typing import Any, Optional + # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. 
ENDPOINT_LOCK = threading.Lock() @@ -100,7 +103,7 @@ class Workload: self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id ) - def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True): + def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True): endpoint = self.endpoint(pageserver_id) start = self.expect_rows end = start + n - 1 @@ -121,7 +124,9 @@ class Workload: else: return False - def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True): + def churn_rows( + self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True + ): assert self.expect_rows >= n max_iters = 10 diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 39d4a3a6d7..420055ac3a 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -4,7 +4,7 @@ import enum import json import os import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + AGGRESIVE_COMPACTION_TENANT_CONF = { # Disable gc and compaction. The test runs compaction manually. 
"gc_period": "0s", diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 791e38383e..96ba3dd5a4 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING +import fixtures.utils import pytest import toml from fixtures.common_types import TenantId, TimelineId @@ -93,6 +94,34 @@ if TYPE_CHECKING: # # Run forward compatibility test # ./scripts/pytest -k test_forward_compatibility # +# +# How to run `test_versions_mismatch` locally: +# +# export DEFAULT_PG_VERSION=16 +# export BUILD_TYPE=release +# export CHECK_ONDISK_DATA_COMPATIBILITY=true +# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE} +# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install +# export NEON_BIN=target/release +# export POSTGRES_DISTRIB_DIR=pg_install +# +# # Build previous version of binaries and store them somewhere: +# rm -rf pg_install target +# git checkout PREVIOUS_VERSION +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# mkdir -p neon_previous/target +# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE} +# cp -a pg_install ./neon_previous/pg_install +# +# # Build current version of binaries and create a data snapshot: +# rm -rf pg_install target +# git checkout CURRENT_VERSION +# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc` +# ./scripts/pytest -k test_create_snapshot +# +# # Run the version mismatch test +# ./scripts/pytest -k test_versions_mismatch + check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif( os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None, @@ -166,16 +195,11 @@ def test_backward_compatibility( neon_env_builder: NeonEnvBuilder, test_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the new binaries can read old data """ - compatibility_snapshot_dir_env =
os.environ.get("COMPATIBILITY_SNAPSHOT_DIR") - assert ( - compatibility_snapshot_dir_env is not None - ), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)" - compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve() - breaking_changes_allowed = ( os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -214,27 +238,11 @@ def test_forward_compatibility( test_output_dir: Path, top_output_dir: Path, pg_version: PgVersion, + compatibility_snapshot_dir: Path, ): """ Test that the old binaries can read new data """ - compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN") - assert compatibility_neon_bin_env is not None, ( - "COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries " - "(ideally generated by the previous version of Neon)" - ) - compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve() - - compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR") - assert ( - compatibility_postgres_distrib_dir_env is not None - ), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)" - compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve() - - compatibility_snapshot_dir = ( - top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}" - ) - breaking_changes_allowed = ( os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true" ) @@ -245,9 +253,14 @@ def test_forward_compatibility( # Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.). # But always use the current version's neon_local binary. 
# This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI. - neon_env_builder.neon_binpath = compatibility_neon_bin - neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir - neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath + assert ( + neon_env_builder.compatibility_neon_binpath is not None + ), "the environment variable COMPATIBILITY_NEON_BIN is required" + assert ( + neon_env_builder.compatibility_pg_distrib_dir is not None + ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required" + neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath + neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir env = neon_env_builder.from_repo_dir( compatibility_snapshot_dir / "repo", @@ -558,3 +571,29 @@ def test_historic_storage_formats( env.pageserver.http_client().timeline_compact( dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True ) + + +@check_ondisk_data_compatibility_if_enabled +@pytest.mark.xdist_group("compatibility") +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_versions_mismatch( + neon_env_builder: NeonEnvBuilder, + test_output_dir: Path, + pg_version: PgVersion, + compatibility_snapshot_dir, + combination, +): + """ + Checks compatibility of different combinations of versions of the components + """ + neon_env_builder.num_safekeepers = 3 + env = neon_env_builder.from_repo_dir( + compatibility_snapshot_dir / "repo", + ) + env.pageserver.allowed_errors.extend( + [".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"] + ) + env.start() + check_neon_works( + env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo" + ) diff --git a/test_runner/regress/test_neon_cli.py b/test_runner/regress/test_neon_cli.py index 3a0a4b10bf..783fb813cf 100644 --- a/test_runner/regress/test_neon_cli.py +++ 
b/test_runner/regress/test_neon_cli.py @@ -162,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder): env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID) env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1) + # We will stop the storage controller while it may have requests in + # flight, and the pageserver complains when requests are abandoned. + for ps in env.pageservers: + ps.allowed_errors.append(".*request was dropped before completing.*") + # Keep NeonEnv state up to date, it usually owns starting/stopping services env.pageservers[0].running = False env.pageservers[1].running = False diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 577a3a25ca..11ebb81023 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -15,7 +15,7 @@ import enum import os import re import time -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TimelineId @@ -40,6 +40,10 @@ from fixtures.remote_storage import ( from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + + # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. 
TENANT_CONF = { diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index d1d6b3af75..b1abcaa763 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload from pytest_httpserver import HTTPServer +from typing_extensions import override from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response @@ -954,6 +955,7 @@ class PageserverFailpoint(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.allowed_errors.extend( @@ -961,19 +963,23 @@ class PageserverFailpoint(Failure): ) pageserver.http_client().configure_failpoints((self.failpoint, "return(1)")) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.http_client().configure_failpoints((self.failpoint, "off")) if self._mitigate: env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"}) + @override def expect_available(self): return True + @override def can_mitigate(self): return self._mitigate - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure): self.pageserver_id = None self.action = action + @override def apply(self, env: NeonEnv): env.storage_controller.configure_failpoints((self.failpoint, self.action)) + @override def clear(self, env: NeonEnv): if "panic" in self.action: log.info("Restarting storage controller after panic") @@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure): else: env.storage_controller.configure_failpoints((self.failpoint, "off")) + @override def 
expect_available(self): # Controller panics _do_ leave pageservers available, but our test code relies # on using the locate API to update configurations in Workload, so we must skip # these actions when the controller has been panicked. return "panic" not in self.action + @override def can_mitigate(self): return False - def fails_forward(self, env): + @override + def fails_forward(self, env: NeonEnv): # Edge case: the very last failpoint that simulates a DB connection error, where # the abort path will fail-forward and result in a complete split. fail_forward = self.failpoint == "shard-split-post-complete" @@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure): return fail_forward + @override def expect_exception(self): if "panic" in self.action: return requests.exceptions.ConnectionError @@ -1029,18 +1041,22 @@ class NodeKill(Failure): self.pageserver_id = pageserver_id self._mitigate = mitigate + @override def apply(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.stop(immediate=True) + @override def clear(self, env: NeonEnv): pageserver = env.get_pageserver(self.pageserver_id) pageserver.start() + @override def expect_available(self): return False - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"}) @@ -1059,21 +1075,26 @@ class CompositeFailure(Failure): self.pageserver_id = f.pageserver_id break + @override def apply(self, env: NeonEnv): for f in self.failures: f.apply(env) - def clear(self, env): + @override + def clear(self, env: NeonEnv): for f in self.failures: f.clear(env) + @override def expect_available(self): return all(f.expect_available() for f in self.failures) - def mitigate(self, env): + @override + def mitigate(self, env: NeonEnv): for f in self.failures: f.mitigate(env) + @override def expect_exception(self): expect = set(f.expect_exception() for f in self.failures) @@ -1211,7 +1232,7 @@ def 
test_sharding_split_failures( assert attached_count == initial_shard_count - def assert_split_done(exclude_ps_id=None) -> None: + def assert_split_done(exclude_ps_id: Optional[int] = None) -> None: secondary_count = 0 attached_count = 0 for ps in env.pageservers: diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 202634477c..1dcc37c407 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -9,6 +9,7 @@ from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING +import fixtures.utils import pytest from fixtures.auth_tokens import TokenScope from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -38,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres from fixtures.port_distributor import PortDistributor from fixtures.remote_storage import RemoteStorageKind, s3_storage from fixtures.storage_controller_proxy import StorageControllerProxy -from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until +from fixtures.utils import ( + run_pg_bench_small, + subprocess_capture, + wait_until, +) from fixtures.workload import Workload from mypy_boto3_s3.type_defs import ( ObjectTypeDef, @@ -60,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids): return counts -def test_storage_controller_smoke( - neon_env_builder: NeonEnvBuilder, -): +@pytest.mark.parametrize(**fixtures.utils.allpairs_versions()) +def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination): """ Test the basic lifecycle of a storage controller: - Restarting @@ -1038,7 +1042,7 @@ def test_storage_controller_tenant_deletion( ) # Break the compute hook: we are checking that deletion does not depend on the compute hook being available - def break_hook(): + def break_hook(_body: Any): raise RuntimeError("Unexpected call to compute hook") 
compute_reconfigure_listener.register_on_notify(break_hook) @@ -1300,11 +1304,11 @@ def test_storage_controller_heartbeats( node_to_tenants = build_node_to_tenants_map(env) log.info(f"Back online: {node_to_tenants=}") - # ... expecting the storage controller to reach a consistent state - def storage_controller_consistent(): - env.storage_controller.consistency_check() + # ... background reconciliation may need to run to clean up the location on the node that was offline + env.storage_controller.reconcile_until_idle() - wait_until(30, 1, storage_controller_consistent) + # ... expecting the storage controller to reach a consistent state + env.storage_controller.consistency_check() def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder): diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index f999edc067..05db0fe977 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -6,7 +6,7 @@ import shutil import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Optional +from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload +if TYPE_CHECKING: + from typing import Optional + @pytest.mark.parametrize("shard_count", [None, 4]) def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):