Merge branch 'main' into yuchen/direct-io-for-read

This commit is contained in:
Yuchen Liang
2024-10-14 00:41:57 -04:00
committed by GitHub
56 changed files with 1025 additions and 607 deletions

View File

@@ -218,6 +218,9 @@ runs:
name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }}
# Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test
path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/
# The lack of compatibility snapshot shouldn't fail the job
# (for example if we didn't run the test for non build-and-test workflow)
skip-if-does-not-exist: true
- name: Upload test results
if: ${{ !cancelled() }}

View File

@@ -7,6 +7,10 @@ inputs:
path:
description: "A directory or file to upload"
required: true
skip-if-does-not-exist:
description: "Allow to skip if path doesn't exist, fail otherwise"
default: false
required: false
prefix:
description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'"
required: false
@@ -15,10 +19,12 @@ runs:
using: "composite"
steps:
- name: Prepare artifact
id: prepare-artifact
shell: bash -euxo pipefail {0}
env:
SOURCE: ${{ inputs.path }}
ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst
SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }}
run: |
mkdir -p $(dirname $ARCHIVE)
@@ -33,14 +39,22 @@ runs:
elif [ -f ${SOURCE} ]; then
time tar -cf ${ARCHIVE} --zstd ${SOURCE}
elif ! ls ${SOURCE} > /dev/null 2>&1; then
echo >&2 "${SOURCE} does not exist"
exit 2
if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then
echo 'SKIPPED=true' >> $GITHUB_OUTPUT
exit 0
else
echo >&2 "${SOURCE} does not exist"
exit 2
fi
else
echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it"
exit 3
fi
echo 'SKIPPED=false' >> $GITHUB_OUTPUT
- name: Upload artifact
if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }}
shell: bash -euxo pipefail {0}
env:
SOURCE: ${{ inputs.path }}

View File

@@ -193,16 +193,15 @@ jobs:
with:
submodules: true
# Disabled for now
# - name: Restore cargo deps cache
# id: cache_cargo
# uses: actions/cache@v4
# with:
# path: |
# !~/.cargo/registry/src
# ~/.cargo/git/
# target/
# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }}
- name: Cache cargo deps
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
!~/.cargo/registry/src
~/.cargo/git
target
key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust
# Some of our rust modules use FFI and need those to be checked
- name: Get postgres headers

View File

@@ -33,7 +33,7 @@ jobs:
actions: read
steps:
- name: Export GH Workflow Stats
uses: fedordikarev/gh-workflow-stats-action@v0.1.2
uses: neondatabase/gh-workflow-stats-action@v0.1.4
with:
DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }}
DB_TABLE: "gh_workflow_stats_neon"

View File

@@ -1,5 +1,6 @@
/compute_tools/ @neondatabase/control-plane @neondatabase/compute
/storage_controller @neondatabase/storage
/storage_scrubber @neondatabase/storage
/libs/pageserver_api/ @neondatabase/storage
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
/libs/remote_storage/ @neondatabase/storage

View File

@@ -31,9 +31,12 @@ pub enum Scope {
/// The scope used by pageservers in upcalls to storage controller and cloud control plane
#[serde(rename = "generations_api")]
GenerationsApi,
/// Allows access to control plane managment API and some storage controller endpoints.
/// Allows access to control plane managment API and all storage controller endpoints.
Admin,
/// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration)
Infra,
/// Allows access to storage controller APIs used by the scrubber, to interrogate the state
/// of a tenant & post scrub results.
Scrubber,

View File

@@ -14,14 +14,19 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
}
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
(Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi | Scope::Scrubber, _) => {
Err(AuthError(
format!(
"JWT scope '{:?}' is ineligible for Pageserver auth",
claims.scope
)
.into(),
))
}
(
Scope::Admin
| Scope::SafekeeperData
| Scope::GenerationsApi
| Scope::Infra
| Scope::Scrubber,
_,
) => Err(AuthError(
format!(
"JWT scope '{:?}' is ineligible for Pageserver auth",
claims.scope
)
.into(),
)),
}
}

View File

@@ -25,6 +25,10 @@ pub(crate) enum WebAuthError {
Io(#[from] std::io::Error),
}
pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
}
impl UserFacingError for WebAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
@@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub(super) async fn authenticate(
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
}
pub(super) fn url(&self) -> &reqwest::Url {
&self.console_uri
}
pub(crate) async fn authenticate(
&self,
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
authenticate(ctx, auth_config, &self.console_uri, client).await
}
}
async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,

View File

@@ -8,6 +8,7 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::WebAuthError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
@@ -36,7 +37,7 @@ use crate::{
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
@@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
/// Authentication via a web browser.
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
@@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> {
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::ConsoleRedirect(url, ()) => fmt
Self::ConsoleRedirect(backend, ()) => fmt
.debug_tuple("ConsoleRedirect")
.field(&url.as_str())
.field(&backend.url().as_str())
.finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
@@ -241,7 +242,6 @@ impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
@@ -265,7 +265,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
@@ -339,7 +339,6 @@ async fn auth_quirks(
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
@@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
Backend::ControlPlane(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::ConsoleRedirect(url, ()) => {
Self::ConsoleRedirect(backend, ()) => {
info!("performing web authentication");
let info = console_redirect::authenticate(ctx, config, &url, client).await?;
let info = backend.authenticate(ctx, config, client).await?;
Backend::ConsoleRedirect(url, info)
Backend::ConsoleRedirect(backend, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))

View File

@@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
@@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
@@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> {
let task = serverless::task_main(
config,
auth_backend,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
@@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
@@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &LocalProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
));
Ok(Box::leak(Box::new(auth_backend)))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;

View File

@@ -10,6 +10,7 @@ use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
@@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> {
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
info!("Authentication backend: {}", config.auth_backend);
info!("Authentication backend: {}", auth_backend);
info!("Using region: {}", args.aws_region);
let region_provider =
@@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
@@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
@@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> {
));
}
if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
if let auth::Backend::ControlPlane(api, _) = auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
@@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
let connect_compute_locks = control_plane::locks::ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
}));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
@@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ())
}
#[cfg(feature = "testing")]
@@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
};
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
let connect_compute_locks = control_plane::locks::ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
}));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
Ok(Box::leak(Box::new(auth_backend)))
}
#[cfg(test)]

View File

@@ -1,8 +1,5 @@
use crate::{
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
auth::backend::{jwt::JwkCache, AuthRateLimiter},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -29,7 +26,6 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::Backend<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,

View File

@@ -81,12 +81,12 @@ pub(crate) mod errors {
Reason::EndpointNotFound => ErrorKind::User,
Reason::BranchNotFound => ErrorKind::User,
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
Reason::DataTransferQuotaExceeded => ErrorKind::User,
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
Reason::RunningOperations => ErrorKind::ControlPlane,
@@ -103,7 +103,7 @@ pub(crate) mod errors {
} if error
.contains("compute time quota of non-primary branches is exceeded") =>
{
crate::error::ErrorKind::User
crate::error::ErrorKind::Quota
}
ControlPlaneError {
http_status_code: http::StatusCode::LOCKED,
@@ -112,7 +112,7 @@ pub(crate) mod errors {
} if error.contains("quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
crate::error::ErrorKind::Quota
}
ControlPlaneError {
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,

View File

@@ -22,7 +22,7 @@ use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing::{debug, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -456,7 +456,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
});
body.http_status_code = status;
error!("console responded with an error ({status}): {body:?}");
warn!("console responded with an error ({status}): {body:?}");
Err(ApiError::ControlPlane(body))
}

View File

@@ -49,6 +49,10 @@ pub enum ErrorKind {
#[label(rename = "serviceratelimit")]
ServiceRateLimit,
/// Proxy quota limit violation
#[label(rename = "quota")]
Quota,
/// internal errors
Service,
@@ -70,6 +74,7 @@ impl ErrorKind {
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::ServiceRateLimit => "serviceratelimit",
ErrorKind::Quota => "quota",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",

View File

@@ -35,7 +35,7 @@ use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, Instrument};
use tracing::{error, info, warn, Instrument};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
@@ -61,6 +61,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -95,15 +96,15 @@ pub async fn task_main(
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
warn!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
warn!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
warn!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
@@ -129,6 +130,7 @@ pub async fn task_main(
let startup = Box::pin(
handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
@@ -144,7 +146,7 @@ pub async fn task_main(
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
warn!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
@@ -155,7 +157,7 @@ pub async fn task_main(
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
@@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
@@ -285,8 +289,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();

View File

@@ -71,7 +71,7 @@ impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::error!(?err, "could not cancel the query in the database");
tracing::warn!(?err, "could not cancel the query in the database");
}
res
}

View File

@@ -6,7 +6,7 @@ use redis::{
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider {
return Ok(());
}
Err(e) => {
error!("Error during PING: {e:?}");
warn!("Error during PING: {e:?}");
}
}
} else {
@@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider {
info!("Connection succesfully established");
}
Err(e) => {
error!("Connection is broken. Error during PING: {e:?}");
warn!("Connection is broken. Error during PING: {e:?}");
}
}
self.con = Some(con);

View File

@@ -146,7 +146,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
tracing::warn!("failed to cancel session: {e}");
}
}
}

View File

@@ -13,7 +13,7 @@ use crate::{
check_peer_addr_is_in_list, AuthError,
},
compute,
config::{AuthenticationConfig, ProxyConfig},
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
@@ -28,7 +28,7 @@ use crate::{
retry::{CouldRetry, ShouldRetryWakeCompute},
},
rate_limiter::EndpointRateLimiter,
Host,
EndpointId, Host,
};
use super::{
@@ -42,6 +42,7 @@ pub(crate) struct PoolingBackend {
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -49,18 +50,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if config.ip_allowlist_check_enabled
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
@@ -79,7 +75,6 @@ impl PoolingBackend {
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,
@@ -91,9 +86,13 @@ impl PoolingBackend {
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -113,13 +112,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -140,7 +139,9 @@ impl PoolingBackend {
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
let keys = config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -185,7 +186,7 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -217,14 +218,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
@@ -262,7 +263,7 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = match &self.config.auth_backend {
let mut node_info = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
@@ -507,8 +508,12 @@ impl ConnectMechanism for HyperMechanism {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
let port = *node_info.config.get_ports().first().ok_or_else(|| {
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
"local-proxy port missing on compute address".into(),
))
})?;
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

View File

@@ -48,13 +48,14 @@ use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, (), ()>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -110,6 +111,7 @@ pub async fn task_main(
local_pool,
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -241,7 +243,7 @@ async fn connection_startup(
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
};
@@ -397,6 +399,7 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
websocket,
cancellation_handler,
@@ -405,7 +408,7 @@ async fn request_handler(
)
.await
{
error!("error in websocket connection: {e:#}");
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),

View File

@@ -45,6 +45,7 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -554,7 +555,7 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
@@ -622,28 +623,17 @@ async fn handle_db_inner(
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy =
matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
@@ -691,7 +681,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -709,7 +699,7 @@ async fn handle_db_inner(
}
statements
.process(config, cancel, &mut client, parsed_headers)
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -749,7 +739,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
@@ -757,12 +746,7 @@ async fn handle_auth_broker_inner(
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
@@ -800,7 +784,7 @@ async fn handle_auth_broker_inner(
impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
@@ -831,7 +815,7 @@ impl QueryData {
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
tracing::warn!(?err, "could not cancel query");
}
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), query).await {
@@ -874,7 +858,7 @@ impl QueryData {
impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
@@ -920,7 +904,7 @@ impl BatchQueryData {
}
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
tracing::warn!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
discard.discard();
@@ -944,7 +928,7 @@ impl BatchQueryData {
}
async fn query_batch(
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
@@ -983,7 +967,7 @@ async fn query_batch(
}
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
config: &'static HttpConfig,
client: &T,
data: QueryData,
current_size: &mut usize,
@@ -1004,9 +988,9 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > config.http_config.max_response_size_bytes {
if *current_size > config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge(
config.http_config.max_response_size_bytes,
config.max_response_size_bytes,
));
}
}

View File

@@ -129,6 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, (), ()>,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket(
let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),

View File

@@ -27,7 +27,7 @@ use std::{
};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace};
use tracing::{error, info, instrument, trace, warn};
use utils::backoff;
use uuid::{NoContext, Timestamp};
@@ -346,7 +346,7 @@ async fn collect_metrics_iteration(
error!("metrics endpoint refused the sent metrics: {:?}", res);
for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) {
// Report if the metric value is suspiciously large
error!("potentially abnormal metric value: {:?}", metric);
warn!("potentially abnormal metric value: {:?}", metric);
}
}
}

View File

@@ -15,15 +15,20 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
}
Ok(())
}
(Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi | Scope::Scrubber, _) => {
Err(AuthError(
format!(
"JWT scope '{:?}' is ineligible for Safekeeper auth",
claims.scope
)
.into(),
))
}
(
Scope::Admin
| Scope::PageServerApi
| Scope::GenerationsApi
| Scope::Infra
| Scope::Scrubber,
_,
) => Err(AuthError(
format!(
"JWT scope '{:?}' is ineligible for Safekeeper auth",
claims.scope
)
.into(),
)),
(Scope::SafekeeperData, _) => Ok(()),
}
}

View File

@@ -636,7 +636,7 @@ async fn handle_tenant_list(
}
async fn handle_node_register(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let mut req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -1182,7 +1182,7 @@ async fn handle_get_safekeeper(req: Request<Body>) -> Result<Response<Body>, Api
/// Assumes information is only relayed to storage controller after first selecting an unique id on
/// control plane database, which means we have an id field in the request and payload.
async fn handle_upsert_safekeeper(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::Admin)?;
check_permissions(&req, Scope::Infra)?;
let body = json_request::<SafekeeperPersistence>(&mut req).await?;
let id = parse_request_param::<i64>(&req, "id")?;

View File

@@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata(
tenant_timeline_results.push((ttid, data));
}
let tenant_id = tenant_id.expect("Must be set if results are present");
if !tenant_timeline_results.is_empty() {
let tenant_id = tenant_id.expect("Must be set if results are present");
analyze_tenant(
&remote_client,
tenant_id,

View File

@@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan
Useful environment variables:
`NEON_BIN`: The directory where neon binaries can be found.
`COMPATIBILITY_NEON_BIN`: The directory where the previous version of Neon binaries can be found
`POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found.
Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain
a subdirectory for each version with naming convention `v{PG_VERSION}/`.
Inside that dir, a `bin/postgres` binary should be present.
`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the prevoius version of postgres distribution can be found.
`DEFAULT_PG_VERSION`: The version of Postgres to use,
This is used to construct full path to the postgres binaries.
Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16`
@@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder):
client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id)
```
All the test which rely on NeonEnvBuilder, can check the various version combinations of the components.
To do this yuo may want to add the parametrize decorator with the function fixtures.utils.allpairs_versions()
E.g.
```python
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_something(
...
```
For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html
At the end of a test, all the nodes in the environment are automatically stopped, so you

View File

@@ -6,6 +6,7 @@ pytest_plugins = (
"fixtures.httpserver",
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
"fixtures.paths",
"fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",

View File

@@ -7,7 +7,6 @@ import json
import os
import re
import timeit
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
@@ -25,7 +24,8 @@ from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonPageserver
if TYPE_CHECKING:
from typing import Callable, ClassVar, Optional
from collections.abc import Iterator, Mapping
from typing import Callable, Optional
"""
@@ -141,6 +141,28 @@ class PgBenchRunResult:
)
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
#
# This used to be a class variable on PgBenchInitResult. However later versions
# of Python complain:
#
# ValueError: mutable default <class 'dict'> for field EXTRACTORS is not allowed: use default_factory
#
# When you do what the error tells you to do, it seems to fail our Python 3.9
# test environment. So let's just move it to a private module constant, and move
# on.
_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = {
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
@dataclasses.dataclass
class PgBenchInitResult:
total: Optional[float]
@@ -155,20 +177,6 @@ class PgBenchInitResult:
start_timestamp: int
end_timestamp: int
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
EXTRACTORS: ClassVar[dict[str, re.Pattern[str]]] = dataclasses.field(
default_factory=lambda: {
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
}
)
@classmethod
def parse_from_stderr(
cls,
@@ -185,7 +193,7 @@ class PgBenchInitResult:
timings: dict[str, Optional[float]] = {}
last_line_items = re.split(r"\(|\)|,", last_line)
for item in last_line_items:
for key, regex in cls.EXTRACTORS.items():
for key, regex in _PGBENCH_INIT_EXTRACTORS.items():
if (m := regex.match(item.strip())) is not None:
if key in timings:
raise RuntimeError(

View File

@@ -6,6 +6,8 @@ from enum import Enum
from functools import total_ordering
from typing import TYPE_CHECKING, TypeVar
from typing_extensions import override
if TYPE_CHECKING:
from typing import Any, Union
@@ -31,33 +33,36 @@ class Lsn:
self.lsn_int = (int(left, 16) << 32) + int(right, 16)
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
@override
def __str__(self) -> str:
"""Convert lsn from int to standard hex notation."""
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
@override
def __repr__(self) -> str:
return f'Lsn("{str(self)}")'
def __int__(self) -> int:
return self.lsn_int
def __lt__(self, other: Any) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int < other.lsn_int
def __gt__(self, other: Any) -> bool:
def __gt__(self, other: object) -> bool:
if not isinstance(other, Lsn):
raise NotImplementedError
return self.lsn_int > other.lsn_int
def __eq__(self, other: Any) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int == other.lsn_int
# Returns the difference between two Lsns, in bytes
def __sub__(self, other: Any) -> int:
def __sub__(self, other: object) -> int:
if not isinstance(other, Lsn):
return NotImplemented
return self.lsn_int - other.lsn_int
@@ -70,6 +75,7 @@ class Lsn:
else:
raise NotImplementedError
@override
def __hash__(self) -> int:
return hash(self.lsn_int)
@@ -116,19 +122,22 @@ class Id:
self.id = bytearray.fromhex(x)
assert len(self.id) == 16
@override
def __str__(self) -> str:
return self.id.hex()
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id < other.id
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self.id == other.id
@override
def __hash__(self) -> int:
return hash(str(self.id))
@@ -139,25 +148,31 @@ class Id:
class TenantId(Id):
@override
def __repr__(self) -> str:
return f'`TenantId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class NodeId(Id):
@override
def __repr__(self) -> str:
return f'`NodeId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
class TimelineId(Id):
@override
def __repr__(self) -> str:
return f'TimelineId("{self.id.hex()}")'
@override
def __str__(self) -> str:
return self.id.hex()
@@ -187,7 +202,7 @@ class TenantShardId:
assert self.shard_number < self.shard_count or self.shard_count == 0
@classmethod
def parse(cls: type[TTenantShardId], input) -> TTenantShardId:
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
if len(input) == 32:
return cls(
tenant_id=TenantId(input),
@@ -203,6 +218,7 @@ class TenantShardId:
else:
raise ValueError(f"Invalid TenantShardId '{input}'")
@override
def __str__(self):
if self.shard_count > 0:
return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
@@ -210,22 +226,25 @@ class TenantShardId:
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
return str(self.tenant_id)
@override
def __repr__(self):
return self.__str__()
def _tuple(self) -> tuple[TenantId, int, int]:
return (self.tenant_id, self.shard_number, self.shard_count)
def __lt__(self, other) -> bool:
def __lt__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() < other._tuple()
def __eq__(self, other) -> bool:
@override
def __eq__(self, other: object) -> bool:
if not isinstance(other, type(self)):
return NotImplemented
return self._tuple() == other._tuple()
@override
def __hash__(self) -> int:
return hash(self._tuple())

View File

@@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager
# Type-related stuff
from pathlib import Path
from typing import TYPE_CHECKING
import pytest
from _pytest.fixtures import FixtureRequest
from typing_extensions import override
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
from fixtures.log_helper import log
@@ -24,6 +26,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pg_stats import PgStatTable
if TYPE_CHECKING:
from collections.abc import Iterator
class PgCompare(ABC):
"""Common interface of all postgres implementations, useful for benchmarks.
@@ -65,12 +70,12 @@ class PgCompare(ABC):
@contextmanager
@abstractmethod
def record_pageserver_writes(self, out_name):
def record_pageserver_writes(self, out_name: str):
pass
@contextmanager
@abstractmethod
def record_duration(self, out_name):
def record_duration(self, out_name: str):
pass
@contextmanager
@@ -122,28 +127,34 @@ class NeonCompare(PgCompare):
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg_bin
@override
def flush(self, compact: bool = True, gc: bool = True):
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
if gc:
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
@override
def compact(self):
self.pageserver_http_client.timeline_compact(
self.tenant, self.timeline, wait_until_uploaded=True
)
@override
def report_peak_memory_use(self):
self.zenbenchmark.record(
"peak_mem",
@@ -152,6 +163,7 @@ class NeonCompare(PgCompare):
report=MetricReport.LOWER_IS_BETTER,
)
@override
def report_size(self):
timeline_size = self.zenbenchmark.get_timeline_size(
self.env.repo_dir, self.tenant, self.timeline
@@ -185,9 +197,11 @@ class NeonCompare(PgCompare):
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
)
@override
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -211,26 +225,33 @@ class VanillaCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> VanillaPostgres:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
@override
def flush(self, compact: bool = False, gc: bool = False):
self.cur.execute("checkpoint")
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
pass # TODO find something
@override
def report_size(self):
data_size = self.pg.get_subdir_size(Path("base"))
self.zenbenchmark.record(
@@ -245,6 +266,7 @@ class VanillaCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)
@@ -261,28 +283,35 @@ class RemoteCompare(PgCompare):
self.cur = self.conn.cursor()
@property
@override
def pg(self) -> PgProtocol:
return self._pg
@property
@override
def zenbenchmark(self) -> NeonBenchmarker:
return self._zenbenchmark
@property
@override
def pg_bin(self) -> PgBin:
return self._pg.pg_bin
def flush(self):
@override
def flush(self, compact: bool = False, gc: bool = False):
# TODO: flush the remote pageserver
pass
@override
def compact(self):
pass
@override
def report_peak_memory_use(self):
# TODO: get memory usage from remote pageserver
pass
@override
def report_size(self):
# TODO: get storage size from remote pageserver
pass
@@ -291,6 +320,7 @@ class RemoteCompare(PgCompare):
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
yield # Do nothing
@override
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
return self.zenbenchmark.record_duration(out_name)

View File

@@ -1,27 +1,31 @@
from __future__ import annotations
import concurrent.futures
from typing import Any
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
from fixtures.common_types import TenantId
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Callable, Optional
class ComputeReconfigure:
def __init__(self, server):
def __init__(self, server: HTTPServer):
self.server = server
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
self.workloads = {}
self.on_notify = None
self.workloads: dict[TenantId, Any] = {}
self.on_notify: Optional[Callable[[Any], None]] = None
def register_workload(self, workload):
def register_workload(self, workload: Any):
self.workloads[workload.tenant_id] = workload
def register_on_notify(self, fn):
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
"""
Add some extra work during a notification, like sleeping to slow things down, or
logging what was notified.
@@ -30,7 +34,7 @@ class ComputeReconfigure:
@pytest.fixture(scope="function")
def compute_reconfigure_listener(make_httpserver):
def compute_reconfigure_listener(make_httpserver: HTTPServer):
"""
This fixture exposes an HTTP listener for the storage controller to submit
compute notifications to us, instead of updating neon_local endpoints itself.
@@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver):
# accept a healthy rate of calls into notify-attach.
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def handler(request: Request):
def handler(request: Request) -> Response:
assert request.json is not None
body: dict[str, Any] = request.json
log.info(f"notify-attach request: {body}")

View File

@@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels
from fixtures.log_helper import log
if TYPE_CHECKING:
from collections.abc import MutableMapping
from typing import Any
"""
The plugin reruns flaky tests.
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`

View File

@@ -1,8 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from pytest_httpserver import HTTPServer
if TYPE_CHECKING:
from collections.abc import Iterator
from fixtures.port_distributor import PortDistributor
# TODO: mypy fails with:
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
# from fixtures.neon_fixtures import PortDistributor
@@ -17,7 +24,7 @@ def httpserver_ssl_context():
@pytest.fixture(scope="function")
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
host, port = httpserver_listen_address
if not host:
host = HTTPServer.DEFAULT_LISTEN_HOST
@@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
@pytest.fixture(scope="function")
def httpserver(make_httpserver):
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
server = make_httpserver
yield server
server.clear()
@pytest.fixture(scope="function")
def httpserver_listen_address(port_distributor) -> tuple[str, int]:
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
port = port_distributor.get_port()
return ("localhost", port)

View File

@@ -31,7 +31,7 @@ LOGGING = {
}
def getLogger(name="root") -> logging.Logger:
def getLogger(name: str = "root") -> logging.Logger:
"""Method to get logger for tests.
Should be used to get correctly initialized logger."""

View File

@@ -22,7 +22,7 @@ class Metrics:
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
filter = filter or {}
res = []
res: list[Sample] = []
for sample in self.metrics[name]:
try:
@@ -59,7 +59,7 @@ class MetricsGetter:
return results[0].value
def get_metrics_values(
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
) -> dict[str, float]:
"""
When fetching multiple named metrics, it is more efficient to use this

View File

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast
import requests
if TYPE_CHECKING:
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Optional
from fixtures.pg_version import PgVersion
@@ -25,9 +25,7 @@ class NeonAPI:
self.__neon_api_key = neon_api_key
self.__neon_api_base_url = neon_api_base_url.strip("/")
def __request(
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
) -> requests.Response:
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"

View File

@@ -18,7 +18,6 @@ from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from fcntl import LOCK_EX, LOCK_UN, flock
from functools import cached_property
from pathlib import Path
from types import TracebackType
@@ -59,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient
from fixtures.pageserver.utils import (
wait_for_last_record_lsn,
)
from fixtures.paths import get_test_repo_dir, shared_snapshot_dir
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import (
@@ -75,8 +75,8 @@ from fixtures.safekeeper.http import SafekeeperHttpClient
from fixtures.safekeeper.utils import wait_walreceivers_absent
from fixtures.utils import (
ATTACHMENT_NAME_REGEX,
COMPONENT_BINARIES,
allure_add_grafana_links,
allure_attach_from_dir,
assert_no_errors,
get_dir_size,
print_gc_result,
@@ -96,6 +96,8 @@ if TYPE_CHECKING:
Union,
)
from fixtures.paths import SnapshotDirLocked
T = TypeVar("T")
@@ -118,65 +120,11 @@ put directly-importable functions into utils.py or another separate file.
Env = dict[str, str]
DEFAULT_OUTPUT_DIR: str = "test_output"
DEFAULT_BRANCH_NAME: str = "main"
BASE_PORT: int = 15000
@pytest.fixture(scope="session")
def base_dir() -> Iterator[Path]:
# find the base directory (currently this is the git root)
base_dir = Path(__file__).parents[2]
log.info(f"base_dir is {base_dir}")
yield base_dir
@pytest.fixture(scope="function")
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
# we are in remote env and do not have neon binaries locally
# this is the case for benchmarks run on self-hosted runner
return
# Find the neon binaries.
if env_neon_bin := os.environ.get("NEON_BIN"):
binpath = Path(env_neon_bin)
else:
binpath = base_dir / "target" / build_type
log.info(f"neon_binpath is {binpath}")
if not (binpath / "pageserver").exists():
raise Exception(f"neon binaries not found at '{binpath}'")
yield binpath
@pytest.fixture(scope="session")
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
distrib_dir = Path(env_postgres_bin).resolve()
else:
distrib_dir = base_dir / "pg_install"
log.info(f"pg_distrib_dir is {distrib_dir}")
yield distrib_dir
@pytest.fixture(scope="session")
def top_output_dir(base_dir: Path) -> Iterator[Path]:
# Compute the top-level directory for all tests.
if env_test_output := os.environ.get("TEST_OUTPUT"):
output_dir = Path(env_test_output).resolve()
else:
output_dir = base_dir / DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
log.info(f"top_output_dir is {output_dir}")
yield output_dir
@pytest.fixture(scope="session")
def neon_api_key() -> str:
api_key = os.getenv("NEON_API_KEY")
@@ -369,11 +317,14 @@ class NeonEnvBuilder:
run_id: uuid.UUID,
mock_s3_server: MockS3Server,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
test_name: str,
top_output_dir: Path,
test_output_dir: Path,
combination,
test_overlay_dir: Optional[Path] = None,
pageserver_remote_storage: Optional[RemoteStorage] = None,
# toml that will be decomposed into `--config-override` flags during `pageserver --init`
@@ -455,6 +406,19 @@ class NeonEnvBuilder:
"test_"
), "Unexpectedly instantiated from outside a test function"
self.test_name = test_name
self.compatibility_neon_binpath = compatibility_neon_binpath
self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir
self.version_combination = combination
self.mixdir = self.test_output_dir / "mixdir_neon"
if self.version_combination is not None:
assert (
self.compatibility_neon_binpath is not None
), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions"
assert (
self.compatibility_pg_distrib_dir is not None
), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions"
self.mixdir.mkdir(mode=0o755, exist_ok=True)
self._mix_versions()
def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv:
# Cannot create more than one environment from one builder
@@ -655,6 +619,21 @@ class NeonEnvBuilder:
return self.env
def _mix_versions(self):
assert self.version_combination is not None, "version combination must be set"
for component, paths in COMPONENT_BINARIES.items():
directory = (
self.neon_binpath
if self.version_combination[component] == "new"
else self.compatibility_neon_binpath
)
for filename in paths:
destination = self.mixdir / filename
destination.symlink_to(directory / filename)
if self.version_combination["compute"] == "old":
self.pg_distrib_dir = self.compatibility_pg_distrib_dir
self.neon_binpath = self.mixdir
def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path):
"""
Mount `srcdir` as an overlayfs mount at `dstdir`.
@@ -1403,7 +1382,9 @@ def neon_simple_env(
top_output_dir: Path,
test_output_dir: Path,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
pageserver_virtual_file_io_engine: str,
pageserver_aux_file_policy: Optional[AuxFileStore],
@@ -1418,6 +1399,11 @@ def neon_simple_env(
# Create the environment in the per-test output directory
repo_dir = get_test_repo_dir(request, top_output_dir)
combination = (
request._pyfuncitem.callspec.params["combination"]
if "combination" in request._pyfuncitem.callspec.params
else None
)
with NeonEnvBuilder(
top_output_dir=top_output_dir,
@@ -1425,7 +1411,9 @@ def neon_simple_env(
port_distributor=port_distributor,
mock_s3_server=mock_s3_server,
neon_binpath=neon_binpath,
compatibility_neon_binpath=compatibility_neon_binpath,
pg_distrib_dir=pg_distrib_dir,
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
pg_version=pg_version,
run_id=run_id,
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
@@ -1435,6 +1423,7 @@ def neon_simple_env(
pageserver_aux_file_policy=pageserver_aux_file_policy,
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
combination=combination,
) as builder:
env = builder.init_start()
@@ -1448,7 +1437,9 @@ def neon_env_builder(
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
run_id: uuid.UUID,
request: FixtureRequest,
@@ -1475,6 +1466,11 @@ def neon_env_builder(
# Create the environment in the test-specific output dir
repo_dir = os.path.join(test_output_dir, "repo")
combination = (
request._pyfuncitem.callspec.params["combination"]
if "combination" in request._pyfuncitem.callspec.params
else None
)
# Return the builder to the caller
with NeonEnvBuilder(
@@ -1483,7 +1479,10 @@ def neon_env_builder(
port_distributor=port_distributor,
mock_s3_server=mock_s3_server,
neon_binpath=neon_binpath,
compatibility_neon_binpath=compatibility_neon_binpath,
pg_distrib_dir=pg_distrib_dir,
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
combination=combination,
pg_version=pg_version,
run_id=run_id,
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
@@ -3657,7 +3656,7 @@ class Endpoint(PgProtocol, LogUtils):
config_lines: Optional[list[str]] = None,
remote_ext_config: Optional[str] = None,
pageserver_id: Optional[int] = None,
allow_multiple=False,
allow_multiple: bool = False,
basebackup_request_tries: Optional[int] = None,
) -> Endpoint:
"""
@@ -3998,7 +3997,7 @@ class Safekeeper(LogUtils):
def timeline_dir(self, tenant_id, timeline_id) -> Path:
return self.data_dir / str(tenant_id) / str(timeline_id)
# List partial uploaded segments of this safekeeper. Works only for
# list partial uploaded segments of this safekeeper. Works only for
# RemoteStorageKind.LOCAL_FS.
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
tline_path = (
@@ -4246,44 +4245,6 @@ class StorageScrubber:
raise
def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path:
"""Compute the path to a working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}"
# We rerun flaky tests multiple times, use a separate directory for each run.
if (suffix := getattr(request.node, "execution_count", None)) is not None:
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
log.info(f"get_test_output_dir is {test_dir}")
# make mypy happy
assert isinstance(test_dir, Path)
return test_dir
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
The working directory for a test.
"""
return _get_test_dir(request, top_output_dir, "")
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
Directory that contains `upperdir` and `workdir` for overlayfs mounts
that a test creates. See `NeonEnvBuilder.overlay_mount`.
"""
return _get_test_dir(request, top_output_dir, "overlay-")
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
return top_output_dir / "shared-snapshots" / snapshot_name
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
return get_test_output_dir(request, top_output_dir) / "repo"
def pytest_addoption(parser: Parser):
parser.addoption(
"--preserve-database-files",
@@ -4293,154 +4254,11 @@ def pytest_addoption(parser: Parser):
)
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
)
# This is autouse, so the test output directory always gets created, even
# if a test doesn't put anything there.
#
# NB: we request the overlay dir fixture so the fixture does its cleanups
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(
request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path
) -> Iterator[Path]:
"""Create the working directory for an individual test."""
# one directory per test
test_dir = get_test_output_dir(request, top_output_dir)
log.info(f"test_output_dir is {test_dir}")
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir()
yield test_dir
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
# which aren't going to be used if Allure results collection is not enabled
# (i.e. --alluredir is not set).
# Skip `allure_attach_from_dir` in this case
if not request.config.getoption("--alluredir"):
return
preserve_database_files = False
for k, v in request.node.user_properties:
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
# So, neon_env_builder's cleanup runs before here.
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
if k == "preserve_database_files":
assert isinstance(v, bool)
preserve_database_files = v
allure_attach_from_dir(test_dir, preserve_database_files)
class FileAndThreadLock:
def __init__(self, path: Path):
self.path = path
self.thread_lock = threading.Lock()
self.fd: Optional[int] = None
def __enter__(self):
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
# lock thread lock before file lock so that there's no race
# around flocking / funlocking the file lock
self.thread_lock.acquire()
flock(self.fd, LOCK_EX)
def __exit__(self, exc_type, exc_value, exc_traceback):
assert self.fd is not None
assert self.thread_lock.locked() # ... by us
flock(self.fd, LOCK_UN)
self.thread_lock.release()
os.close(self.fd)
self.fd = None
class SnapshotDirLocked:
def __init__(self, parent: SnapshotDir):
self._parent = parent
def is_initialized(self):
# TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
# Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
return self._parent._marker_file_path.exists()
def set_initialized(self):
self._parent._marker_file_path.write_text("")
@property
def path(self) -> Path:
return self._parent._path / "snapshot"
class SnapshotDir:
_path: Path
def __init__(self, path: Path):
self._path = path
assert self._path.is_dir()
self._lock = FileAndThreadLock(self._lock_file_path)
@property
def _lock_file_path(self) -> Path:
return self._path / "initializing.flock"
@property
def _marker_file_path(self) -> Path:
return self._path / "initialized.marker"
def __enter__(self) -> SnapshotDirLocked:
self._lock.__enter__()
return SnapshotDirLocked(self)
def __exit__(self, exc_type, exc_value, exc_traceback):
self._lock.__exit__(exc_type, exc_value, exc_traceback)
def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir:
snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
snapshot_dir_path.mkdir(exist_ok=True, parents=True)
return SnapshotDir(snapshot_dir_path)
@pytest.fixture(scope="function")
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
"""
Idempotently create a test's overlayfs mount state directory.
If the functionality isn't enabled via env var, returns None.
The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
"""
if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
return None
overlay_dir = get_test_overlay_dir(request, top_output_dir)
log.info(f"test_overlay_dir is {overlay_dir}")
overlay_dir.mkdir(exist_ok=True)
# unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
cmd = ["sudo", "umount", str(mountpoint)]
log.info(
f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
)
subprocess.run(cmd, capture_output=True, check=True)
# the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work.
cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
subprocess.run(cmd, capture_output=True, check=True)
overlay_dir.mkdir()
return overlay_dir
# no need to clean up anything: on clean shutdown,
# NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
# and on unclean shutdown, this function will take care of it
# on the next test run
SKIP_DIRS = frozenset(
(
"pg_wal",

View File

@@ -1,10 +1,13 @@
from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
import psutil
if TYPE_CHECKING:
from collections.abc import Iterator
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
"""

View File

@@ -886,7 +886,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
self,
tenant_id: Union[TenantId, TenantShardId],
timeline_id: TimelineId,
batch_size: int | None = None,
batch_size: Optional[int] = None,
**kwargs,
) -> set[TimelineId]:
params = {}

View File

@@ -9,7 +9,12 @@ import toml
from _pytest.python import Metafunc
from fixtures.pg_version import PgVersion
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional
from fixtures.utils import AuxFileStore
if TYPE_CHECKING:
from typing import Any, Optional

View File

@@ -0,0 +1,312 @@
from __future__ import annotations
import os
import shutil
import subprocess
import threading
from fcntl import LOCK_EX, LOCK_UN, flock
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING
import pytest
from pytest import FixtureRequest
from fixtures import overlayfs
from fixtures.log_helper import log
from fixtures.utils import allure_attach_from_dir
if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Optional
DEFAULT_OUTPUT_DIR: str = "test_output"
def get_test_dir(
request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None
) -> Path:
"""Compute the path to a working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}"
# We rerun flaky tests multiple times, use a separate directory for each run.
if (suffix := getattr(request.node, "execution_count", None)) is not None:
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
return test_dir
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
The working directory for a test.
"""
return get_test_dir(request, top_output_dir)
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
Directory that contains `upperdir` and `workdir` for overlayfs mounts
that a test creates. See `NeonEnvBuilder.overlay_mount`.
"""
return get_test_dir(request, top_output_dir, "overlay-")
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
return top_output_dir / "shared-snapshots" / snapshot_name
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
return get_test_output_dir(request, top_output_dir) / "repo"
@pytest.fixture(scope="session")
def base_dir() -> Iterator[Path]:
# find the base directory (currently this is the git root)
base_dir = Path(__file__).parents[2]
log.info(f"base_dir is {base_dir}")
yield base_dir
@pytest.fixture(scope="session")
def compute_config_dir(base_dir: Path) -> Iterator[Path]:
"""
Retrieve the path to the compute configuration directory.
"""
yield base_dir / "compute" / "etc"
@pytest.fixture(scope="function")
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
# we are in remote env and do not have neon binaries locally
# this is the case for benchmarks run on self-hosted runner
return
# Find the neon binaries.
if env_neon_bin := os.environ.get("NEON_BIN"):
binpath = Path(env_neon_bin)
else:
binpath = base_dir / "target" / build_type
log.info(f"neon_binpath is {binpath}")
if not (binpath / "pageserver").exists():
raise Exception(f"neon binaries not found at '{binpath}'")
yield binpath.absolute()
@pytest.fixture(scope="session")
def compatibility_snapshot_dir() -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
return
compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
assert (
compatibility_snapshot_dir_env is not None
), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()
yield compatibility_snapshot_dir
@pytest.fixture(scope="session")
def compatibility_neon_binpath() -> Optional[Iterator[Path]]:
if os.getenv("REMOTE_ENV"):
return
comp_binpath = None
if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"):
comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute()
yield comp_binpath
@pytest.fixture(scope="session")
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
distrib_dir = Path(env_postgres_bin).resolve()
else:
distrib_dir = base_dir / "pg_install"
log.info(f"pg_distrib_dir is {distrib_dir}")
yield distrib_dir
@pytest.fixture(scope="session")
def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]:
compat_distrib_dir = None
if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"):
compat_distrib_dir = Path(env_compat_postgres_bin).resolve()
if not compat_distrib_dir.exists():
raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}")
if compat_distrib_dir:
log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}")
yield compat_distrib_dir
@pytest.fixture(scope="session")
def top_output_dir(base_dir: Path) -> Iterator[Path]:
# Compute the top-level directory for all tests.
if env_test_output := os.environ.get("TEST_OUTPUT"):
output_dir = Path(env_test_output).resolve()
else:
output_dir = base_dir / DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
log.info(f"top_output_dir is {output_dir}")
yield output_dir
# This is autouse, so the test output directory always gets created, even
# if a test doesn't put anything there.
#
# NB: we request the overlay dir fixture so the fixture does its cleanups
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
"""Create the working directory for an individual test."""
# one directory per test
test_dir = get_test_output_dir(request, top_output_dir)
log.info(f"test_output_dir is {test_dir}")
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir()
yield test_dir
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
# which aren't going to be used if Allure results collection is not enabled
# (i.e. --alluredir is not set).
# Skip `allure_attach_from_dir` in this case
if not request.config.getoption("--alluredir"):
return
preserve_database_files = False
for k, v in request.node.user_properties:
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
# So, neon_env_builder's cleanup runs before here.
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
if k == "preserve_database_files":
assert isinstance(v, bool)
preserve_database_files = v
allure_attach_from_dir(test_dir, preserve_database_files)
class FileAndThreadLock:
def __init__(self, path: Path):
self.path = path
self.thread_lock = threading.Lock()
self.fd: Optional[int] = None
def __enter__(self):
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
# lock thread lock before file lock so that there's no race
# around flocking / funlocking the file lock
self.thread_lock.acquire()
flock(self.fd, LOCK_EX)
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
):
assert self.fd is not None
assert self.thread_lock.locked() # ... by us
flock(self.fd, LOCK_UN)
self.thread_lock.release()
os.close(self.fd)
self.fd = None
class SnapshotDirLocked:
def __init__(self, parent: SnapshotDir):
self._parent = parent
def is_initialized(self):
# TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
# Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
return self._parent.marker_file_path.exists()
def set_initialized(self):
self._parent.marker_file_path.write_text("")
@property
def path(self) -> Path:
return self._parent.path / "snapshot"
class SnapshotDir:
_path: Path
def __init__(self, path: Path):
self._path = path
assert self._path.is_dir()
self._lock = FileAndThreadLock(self.lock_file_path)
@property
def path(self) -> Path:
return self._path
@property
def lock_file_path(self) -> Path:
return self._path / "initializing.flock"
@property
def marker_file_path(self) -> Path:
return self._path / "initialized.marker"
def __enter__(self) -> SnapshotDirLocked:
self._lock.__enter__()
return SnapshotDirLocked(self)
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
):
self._lock.__exit__(exc_type, exc_value, exc_traceback)
def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir:
snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
snapshot_dir_path.mkdir(exist_ok=True, parents=True)
return SnapshotDir(snapshot_dir_path)
@pytest.fixture(scope="function")
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
"""
Idempotently create a test's overlayfs mount state directory.
If the functionality isn't enabled via env var, returns None.
The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
"""
if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
return None
overlay_dir = get_test_overlay_dir(request, top_output_dir)
log.info(f"test_overlay_dir is {overlay_dir}")
overlay_dir.mkdir(exist_ok=True)
# unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
cmd = ["sudo", "umount", str(mountpoint)]
log.info(
f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
)
subprocess.run(cmd, capture_output=True, check=True)
# the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work.
cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
subprocess.run(cmd, capture_output=True, check=True)
overlay_dir.mkdir()
return overlay_dir
# no need to clean up anything: on clean shutdown,
# NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
# and on unclean shutdown, this function will take care of it
# on the next test run

View File

@@ -2,9 +2,14 @@ from __future__ import annotations
import enum
import os
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from typing_extensions import override
if TYPE_CHECKING:
from typing import Optional
"""
This fixture is used to determine which version of Postgres to use for tests.
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum):
NOT_SET = "<-POSTRGRES VERSION IS NOT SET->"
# Make it less confusing in logs
@override
def __repr__(self) -> str:
return f"'{self.value}'"
# Make this explicit for Python 3.11 compatibility, which changes the behavior of enums
@override
def __str__(self) -> str:
return self.value
@@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum):
return f"v{self.value}"
@classmethod
def _missing_(cls, value) -> Optional[PgVersion]:
@override
def _missing_(cls, value: object) -> Optional[PgVersion]:
known_values = {v.value for _, v in cls.__members__.items()}
# Allow passing version as a string with "v" prefix (e.g. "v14")

View File

@@ -59,10 +59,7 @@ class PortDistributor:
if isinstance(value, int):
return self._replace_port_int(value)
if isinstance(value, str):
return self._replace_port_str(value)
raise TypeError(f"unsupported type {type(value)} of {value=}")
return self._replace_port_str(value)
def _replace_port_int(self, value: int) -> int:
known_port = self.port_map.get(value)
@@ -75,7 +72,7 @@ class PortDistributor:
# Use regex to find port in a string
# urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432"
# See https://bugs.python.org/issue27657
ports = re.findall(r":(\d+)(?:/|$)", value)
ports: list[str] = re.findall(r":(\d+)(?:/|$)", value)
assert len(ports) == 1, f"can't find port in {value}"
port_int = int(ports[0])

View File

@@ -13,6 +13,7 @@ import boto3
import toml
from moto.server import ThreadedMotoServer
from mypy_boto3_s3 import S3Client
from typing_extensions import override
from fixtures.common_types import TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
@@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum):
EXTENSIONS = "ext"
SAFEKEEPER = "safekeeper"
@override
def __str__(self) -> str:
return self.value
@@ -81,11 +83,13 @@ class LocalFsStorage:
def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
return self.tenant_path(tenant_id) / "timelines" / str(timeline_id)
def timeline_latest_generation(self, tenant_id, timeline_id):
def timeline_latest_generation(
self, tenant_id: TenantId, timeline_id: TimelineId
) -> Optional[int]:
timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id))
index_parts = [f for f in timeline_files if f.startswith("index_part")]
def parse_gen(filename):
def parse_gen(filename: str) -> Optional[int]:
log.info(f"parsing index_part '{filename}'")
parts = filename.split("-")
if len(parts) == 2:
@@ -93,7 +97,7 @@ class LocalFsStorage:
else:
return None
generations = sorted([parse_gen(f) for f in index_parts])
generations = sorted([parse_gen(f) for f in index_parts]) # type: ignore
if len(generations) == 0:
raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}")
return generations[-1]
@@ -122,14 +126,14 @@ class LocalFsStorage:
filename = f"{local_name}-{generation:08x}"
return self.timeline_path(tenant_id, timeline_id) / filename
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId):
def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any:
with self.index_path(tenant_id, timeline_id).open("r") as f:
return json.load(f)
def heatmap_path(self, tenant_id: TenantId) -> Path:
return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME
def heatmap_content(self, tenant_id):
def heatmap_content(self, tenant_id: TenantId) -> Any:
with self.heatmap_path(tenant_id).open("r") as f:
return json.load(f)
@@ -297,7 +301,7 @@ class S3Storage:
def heatmap_key(self, tenant_id: TenantId) -> str:
return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}"
def heatmap_content(self, tenant_id: TenantId):
def heatmap_content(self, tenant_id: TenantId) -> Any:
r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id))
return json.loads(r["Body"].read().decode("utf-8"))
@@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum):
def configure(
self,
repo_dir: Path,
mock_s3_server,
mock_s3_server: MockS3Server,
run_id: str,
test_name: str,
user: RemoteStorageUser,
@@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind:
def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_dict()
# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
raise Exception("invalid remote storage type")
return remote_storage.to_toml_inline_table()

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from typing import Any, Optional
from typing import TYPE_CHECKING
import pytest
import requests
@@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response
from fixtures.log_helper import log
if TYPE_CHECKING:
from typing import Any, Optional
class StorageControllerProxy:
def __init__(self, server: HTTPServer):
@@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response:
@pytest.fixture(scope="function")
def storage_controller_proxy(make_httpserver):
def storage_controller_proxy(make_httpserver: HTTPServer):
"""
Proxies requests into the storage controller to the currently
selected storage controller instance via `StorageControllerProxy.route_to`.
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver):
log.info(f"Storage controller proxy listening on {self.listen}")
def handler(request: Request):
def handler(request: Request) -> Response:
if self.route_to is None:
log.info(f"Storage controller proxy has no routing configured for {request.url}")
return Response("Routing not configured", status=503)

View File

@@ -18,6 +18,7 @@ from urllib.parse import urlencode
import allure
import zstandard
from psycopg2.extensions import cursor
from typing_extensions import override
from fixtures.log_helper import log
from fixtures.pageserver.common_types import (
@@ -26,28 +27,45 @@ from fixtures.pageserver.common_types import (
)
if TYPE_CHECKING:
from typing import (
IO,
Optional,
Union,
)
from collections.abc import Iterable
from typing import IO, Optional
from fixtures.common_types import TimelineId
from fixtures.neon_fixtures import PgBin
from fixtures.common_types import TimelineId
WaitUntilRet = TypeVar("WaitUntilRet")
Fn = TypeVar("Fn", bound=Callable[..., Any])
COMPONENT_BINARIES = {
"storage_controller": ("storage_controller",),
"storage_broker": ("storage_broker",),
"compute": ("compute_ctl",),
"safekeeper": ("safekeeper",),
"pageserver": ("pageserver", "pagectl"),
}
# Disable auto-formatting for better readability
# fmt: off
VERSIONS_COMBINATIONS = (
{"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"},
{"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"},
)
# fmt: on
def subprocess_capture(
capture_dir: Path,
cmd: list[str],
*,
check=False,
echo_stderr=False,
echo_stdout=False,
capture_stdout=False,
timeout=None,
with_command_header=True,
check: bool = False,
echo_stderr: bool = False,
echo_stdout: bool = False,
capture_stdout: bool = False,
timeout: Optional[float] = None,
with_command_header: bool = True,
**popen_kwargs: Any,
) -> tuple[str, Optional[str], int]:
"""Run a process and bifurcate its output to files and the `log` logger
@@ -84,6 +102,7 @@ def subprocess_capture(
self.capture = capture
self.captured = ""
@override
def run(self):
first = with_command_header
for line in self.in_file:
@@ -165,10 +184,10 @@ def global_counter() -> int:
def print_gc_result(row: dict[str, Any]):
log.info("GC duration {elapsed} ms".format_map(row))
log.info(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
row
)
(
" total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
" needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
).format_map(row)
)
@@ -226,7 +245,7 @@ def get_scale_for_db(size_mb: int) -> int:
return round(0.06689 * size_mb - 0.5)
ATTACHMENT_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
)
@@ -289,7 +308,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"
def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
"""Add links to server logs in Grafana to Allure report"""
links = {}
links: dict[str, str] = {}
# We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
endpoint_id, region_id, _ = host.split(".", 2)
@@ -341,7 +360,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,
def start_in_background(
command: list[str], cwd: Path, log_file_name: str, is_started: Fn
command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
) -> subprocess.Popen[bytes]:
"""Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""
@@ -376,14 +395,11 @@ def start_in_background(
return spawned_process
WaitUntilRet = TypeVar("WaitUntilRet")
def wait_until(
number_of_iterations: int,
interval: float,
func: Callable[[], WaitUntilRet],
show_intermediate_error=False,
show_intermediate_error: bool = False,
) -> WaitUntilRet:
"""
Wait until 'func' returns successfully, without exception. Returns the
@@ -464,7 +480,7 @@ def humantime_to_ms(humantime: str) -> float:
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
# FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
error_or_warn = re.compile(r"\s(ERROR|WARN)")
errors = []
errors: list[tuple[int, str]] = []
for lineno, line in enumerate(input, start=1):
if len(line) == 0:
continue
@@ -484,7 +500,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
return errors
def assert_no_errors(log_file, service, allowed_errors):
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
if not log_file.exists():
log.warning(f"Skipping {service} log check: {log_file} does not exist")
return
@@ -504,9 +520,11 @@ class AuxFileStore(str, enum.Enum):
V2 = "v2"
CrossValidation = "cross-validation"
@override
def __repr__(self) -> str:
return f"'aux-{self.value}'"
@override
def __str__(self) -> str:
return f"'aux-{self.value}'"
@@ -525,7 +543,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
"""
started_at = time.time()
def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
assert reader is not None
digest = sha256(usedforsecurity=False)
while True:
@@ -550,7 +568,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
right_list
), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"
mismatching = set()
mismatching: set[str] = set()
for left_tuple, right_tuple in zip(left_list, right_list):
left_path, left_hash = left_tuple
@@ -575,6 +593,7 @@ class PropagatingThread(threading.Thread):
Simple Thread wrapper with join() propagating the possible exception in the thread.
"""
@override
def run(self):
self.exc = None
try:
@@ -582,7 +601,8 @@ class PropagatingThread(threading.Thread):
except BaseException as e:
self.exc = e
def join(self, timeout=None):
@override
def join(self, timeout: Optional[float] = None) -> Any:
super().join(timeout)
if self.exc:
raise self.exc
@@ -604,3 +624,19 @@ def human_bytes(amt: float) -> str:
amt = amt / 1024
raise RuntimeError("unreachable")
def allpairs_versions():
"""
Returns a dictionary with arguments for pytest parametrize
to test the compatibility with the previous version of Neon components
combinations were pre-computed to test all the pairs of the components with
the different versions.
"""
ids = []
for pair in VERSIONS_COMBINATIONS:
cur_id = []
for component in sorted(pair.keys()):
cur_id.append(pair[component][0])
ids.append(f"combination_{''.join(cur_id)}")
return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids}

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import threading
from typing import Any, Optional
from typing import TYPE_CHECKING
from fixtures.common_types import TenantId, TimelineId
from fixtures.log_helper import log
@@ -14,6 +14,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pageserver.utils import wait_for_last_record_lsn
if TYPE_CHECKING:
from typing import Any, Optional
# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
ENDPOINT_LOCK = threading.Lock()
@@ -100,7 +103,7 @@ class Workload:
self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
)
def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True):
def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True):
endpoint = self.endpoint(pageserver_id)
start = self.expect_rows
end = start + n - 1
@@ -121,7 +124,9 @@ class Workload:
else:
return False
def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
def churn_rows(
self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True
):
assert self.expect_rows >= n
max_iters = 10

View File

@@ -4,7 +4,7 @@ import enum
import json
import os
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.log_helper import log
@@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
AGGRESIVE_COMPACTION_TENANT_CONF = {
# Disable gc and compaction. The test runs compaction manually.
"gc_period": "0s",

View File

@@ -9,6 +9,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import fixtures.utils
import pytest
import toml
from fixtures.common_types import TenantId, TimelineId
@@ -93,6 +94,34 @@ if TYPE_CHECKING:
# # Run forward compatibility test
# ./scripts/pytest -k test_forward_compatibility
#
#
# How to run `test_version_mismatch` locally:
#
# export DEFAULT_PG_VERSION=16
# export BUILD_TYPE=release
# export CHECK_ONDISK_DATA_COMPATIBILITY=true
# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE}
# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install
# export NEON_BIN=target/release
# export POSTGRES_DISTRIB_DIR=pg_install
#
# # Build previous version of binaries and store them somewhere:
# rm -rf pg_install target
# git checkout <previous version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# mkdir -p neon_previous/target
# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE}
# cp -a pg_install ./neon_previous/pg_install
#
# # Build current version of binaries and create a data snapshot:
# rm -rf pg_install target
# git checkout <current version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# ./scripts/pytest -k test_create_snapshot
#
# # Run the version mismatch test
# ./scripts/pytest -k test_version_mismatch
check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif(
os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None,
@@ -166,16 +195,11 @@ def test_backward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir: Path,
):
"""
Test that the new binaries can read old data
"""
compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
assert (
compatibility_snapshot_dir_env is not None
), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()
breaking_changes_allowed = (
os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
)
@@ -214,27 +238,11 @@ def test_forward_compatibility(
test_output_dir: Path,
top_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir: Path,
):
"""
Test that the old binaries can read new data
"""
compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN")
assert compatibility_neon_bin_env is not None, (
"COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries "
"(ideally generated by the previous version of Neon)"
)
compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve()
compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR")
assert (
compatibility_postgres_distrib_dir_env is not None
), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)"
compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve()
compatibility_snapshot_dir = (
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
)
breaking_changes_allowed = (
os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
)
@@ -245,9 +253,14 @@ def test_forward_compatibility(
# Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.).
# But always use the current version's neon_local binary.
# This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI.
neon_env_builder.neon_binpath = compatibility_neon_bin
neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir
neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath
assert (
neon_env_builder.compatibility_neon_binpath is not None
), "the environment variable COMPATIBILITY_NEON_BIN is required"
assert (
neon_env_builder.compatibility_pg_distrib_dir is not None
), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required"
neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir
env = neon_env_builder.from_repo_dir(
compatibility_snapshot_dir / "repo",
@@ -558,3 +571,29 @@ def test_historic_storage_formats(
env.pageserver.http_client().timeline_compact(
dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True
)
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_versions_mismatch(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir,
combination,
):
"""
Checks compatibility of different combinations of versions of the components
"""
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.from_repo_dir(
compatibility_snapshot_dir / "repo",
)
env.pageserver.allowed_errors.extend(
[".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"]
)
env.start()
check_neon_works(
env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo"
)

View File

@@ -162,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder):
env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID)
env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1)
# We will stop the storage controller while it may have requests in
# flight, and the pageserver complains when requests are abandoned.
for ps in env.pageservers:
ps.allowed_errors.append(".*request was dropped before completing.*")
# Keep NeonEnv state up to date, it usually owns starting/stopping services
env.pageservers[0].running = False
env.pageservers[1].running = False

View File

@@ -15,7 +15,7 @@ import enum
import os
import re
import time
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TimelineId
@@ -40,6 +40,10 @@ from fixtures.remote_storage import (
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
# A tenant configuration that is convenient for generating uploads and deletions
# without a large amount of postgres traffic.
TENANT_CONF = {

View File

@@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
from pytest_httpserver import HTTPServer
from typing_extensions import override
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
@@ -954,6 +955,7 @@ class PageserverFailpoint(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.allowed_errors.extend(
@@ -961,19 +963,23 @@ class PageserverFailpoint(Failure):
)
pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.http_client().configure_failpoints((self.failpoint, "off"))
if self._mitigate:
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"})
@override
def expect_available(self):
return True
@override
def can_mitigate(self):
return self._mitigate
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure):
self.pageserver_id = None
self.action = action
@override
def apply(self, env: NeonEnv):
env.storage_controller.configure_failpoints((self.failpoint, self.action))
@override
def clear(self, env: NeonEnv):
if "panic" in self.action:
log.info("Restarting storage controller after panic")
@@ -994,16 +1002,19 @@ class StorageControllerFailpoint(Failure):
else:
env.storage_controller.configure_failpoints((self.failpoint, "off"))
@override
def expect_available(self):
# Controller panics _do_ leave pageservers available, but our test code relies
# on using the locate API to update configurations in Workload, so we must skip
# these actions when the controller has been panicked.
return "panic" not in self.action
@override
def can_mitigate(self):
return False
def fails_forward(self, env):
@override
def fails_forward(self, env: NeonEnv):
# Edge case: the very last failpoint that simulates a DB connection error, where
# the abort path will fail-forward and result in a complete split.
fail_forward = self.failpoint == "shard-split-post-complete"
@@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure):
return fail_forward
@override
def expect_exception(self):
if "panic" in self.action:
return requests.exceptions.ConnectionError
@@ -1029,18 +1041,22 @@ class NodeKill(Failure):
self.pageserver_id = pageserver_id
self._mitigate = mitigate
@override
def apply(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.stop(immediate=True)
@override
def clear(self, env: NeonEnv):
pageserver = env.get_pageserver(self.pageserver_id)
pageserver.start()
@override
def expect_available(self):
return False
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})
@@ -1059,21 +1075,26 @@ class CompositeFailure(Failure):
self.pageserver_id = f.pageserver_id
break
@override
def apply(self, env: NeonEnv):
for f in self.failures:
f.apply(env)
def clear(self, env):
@override
def clear(self, env: NeonEnv):
for f in self.failures:
f.clear(env)
@override
def expect_available(self):
return all(f.expect_available() for f in self.failures)
def mitigate(self, env):
@override
def mitigate(self, env: NeonEnv):
for f in self.failures:
f.mitigate(env)
@override
def expect_exception(self):
expect = set(f.expect_exception() for f in self.failures)
@@ -1211,7 +1232,7 @@ def test_sharding_split_failures(
assert attached_count == initial_shard_count
def assert_split_done(exclude_ps_id=None) -> None:
def assert_split_done(exclude_ps_id: Optional[int] = None) -> None:
secondary_count = 0
attached_count = 0
for ps in env.pageservers:

View File

@@ -9,6 +9,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING
import fixtures.utils
import pytest
from fixtures.auth_tokens import TokenScope
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -38,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import RemoteStorageKind, s3_storage
from fixtures.storage_controller_proxy import StorageControllerProxy
from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until
from fixtures.utils import (
run_pg_bench_small,
subprocess_capture,
wait_until,
)
from fixtures.workload import Workload
from mypy_boto3_s3.type_defs import (
ObjectTypeDef,
@@ -60,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
return counts
def test_storage_controller_smoke(
neon_env_builder: NeonEnvBuilder,
):
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination):
"""
Test the basic lifecycle of a storage controller:
- Restarting
@@ -1038,7 +1042,7 @@ def test_storage_controller_tenant_deletion(
)
# Break the compute hook: we are checking that deletion does not depend on the compute hook being available
def break_hook():
def break_hook(_body: Any):
raise RuntimeError("Unexpected call to compute hook")
compute_reconfigure_listener.register_on_notify(break_hook)
@@ -1300,11 +1304,11 @@ def test_storage_controller_heartbeats(
node_to_tenants = build_node_to_tenants_map(env)
log.info(f"Back online: {node_to_tenants=}")
# ... expecting the storage controller to reach a consistent state
def storage_controller_consistent():
env.storage_controller.consistency_check()
# ... background reconciliation may need to run to clean up the location on the node that was offline
env.storage_controller.reconcile_until_idle()
wait_until(30, 1, storage_controller_consistent)
# ... expecting the storage controller to reach a consistent state
env.storage_controller.consistency_check()
def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder):

View File

@@ -6,7 +6,7 @@ import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
if TYPE_CHECKING:
from typing import Optional
@pytest.mark.parametrize("shard_count", [None, 4])
def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):