Compare commits

..

19 Commits

Author SHA1 Message Date
Conrad Ludgate
fe8b93ab9d fix Payload deser 2024-10-14 14:02:46 +01:00
Conrad Ludgate
7e3e7f1cca turns out we don't actually need to deser everything 2024-10-14 11:58:53 +01:00
Conrad Ludgate
0b0ed662d9 proxy: use RawValue to lazily process inputs 2024-10-14 11:48:58 +01:00
Conrad Ludgate
50bd65769f a seemingly random change... 2024-10-14 11:44:18 +01:00
Conrad Ludgate
90534b1745 remove iterator join 2024-10-14 11:42:20 +01:00
Conrad Ludgate
99d52df475 proxy: slight refactor to json parsing 2024-10-14 11:38:18 +01:00
Conrad Ludgate
ab5bbb445b proxy: refactor auth backends (#9271)
preliminary for #9270 

The auth::Backend didn't need to be in the mega ProxyConfig object, so I
split it off and passed it manually in the few places it was necessary.

I've also refined some of the uses of config I saw while doing this
small refactor.

I've also followed the trend and make the console redirect backend it's
own struct, same as LocalBackend and ControlPlaneBackend.
2024-10-11 20:14:52 +01:00
Alexander Bayandin
5ef805e12c CI(run-python-test-set): allow to skip missing compatibility snapshot (#9365)
## Problem
Action `run-python-test-set` fails if it is not used for `regress_tests`
on release PR, because it expects
`test_compatibility.py::test_create_snapshot` to generate a snapshot,
and the test exists only in `regress_tests` suite.
For example, in https://github.com/neondatabase/neon/pull/9291
[`test-postgres-client-libs`](https://github.com/neondatabase/neon/actions/runs/11209615321/job/31155111544)
job failed.

## Summary of changes
- Add `skip-if-does-not-exist` input to `.github/actions/upload` action
(the same way we do for `.github/actions/download`)
- Set `skip-if-does-not-exist=true` for "Upload compatibility snapshot"
step in `run-python-test-set` action
2024-10-11 16:58:41 +01:00
a-masterov
091a175a3e Test versions mismatch (#9167)
## Problem
We faced the problem of incompatibility of the different components of
different versions.
This should be detected automatically to prevent production bugs.
## Summary of changes
The test for this situation was implemented

Co-authored-by: Alexander Bayandin <alexander@neon.tech>
2024-10-11 15:29:54 +02:00
Fedor Dikarev
326cd80f0d ci: gh-workflow-stats-action v0.1.4: remove debug output and proper pagination (#9356)
## Problem
In previous version pagination didn't work so we collect information
only for first 30 jobs in WorkflowRun
2024-10-11 14:46:45 +02:00
Folke Behrens
6baf1aae33 proxy: Demote some errors to warnings in logs (#9354) 2024-10-11 11:29:08 +02:00
John Spray
184935619e tests: stabilize test_storage_controller_heartbeats (#9347)
## Problem

This could fail with `reconciliation in progress` if running on a slow
test node such that background reconciliation happens at the same time
as we call consistency_check.

Example:
https://neon-github-public-dev.s3.amazonaws.com/reports/main/11258171952/index.html#/testresult/54889c9469afb232

## Summary of changes

- Call reconcile_until_idle before calling consistency check once,
rather than calling consistency check until it passes
2024-10-11 09:41:08 +01:00
Ivan Efremov
b2ecbf3e80 Introduce "quota" ErrorKind (#9300)
## Problem
Fixes #8340
## Summary of changes
Introduced ErrorKind::quota to handle quota-related errors
## Checklist before requesting a review

- [x] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist
2024-10-11 10:45:55 +03:00
Tristan Partin
53147b51f9 Use valid type hints for Python 3.9
I have no idea how this made it past the linters.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-10 13:00:25 -05:00
Tristan Partin
006d9dfb6b Add compute_config_dir fixture
Allows easy access to various compute config files.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-10 12:43:40 -05:00
Tristan Partin
1f7904c917 Enable cargo caching in check-codestyle-rust
This job takes an extraordinary amount of time for what I understand it
to do. The obvious win is caching dependencies.

Rory disabled caching in cd5732d9d8.
I assume this was to get gen3 runners up and running.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-10 12:40:30 -05:00
John Spray
07c714343f tests: allow a log warning in test_cli_start_stop_multi (#9320)
## Problem

This test restarts services in an undefined order (whatever neon_local
does), which means we should be tolerant of warnings that come from
restarting the storage controller while a pageserver is running.

We can see failures with warnings from dropped requests, e.g.
https://neon-github-public-dev.s3.amazonaws.com/reports/pr-9307/11229000712/index.html#/testresult/d33d5cb206331e28
```
 WARN request{method=GET path=/v1/location_config request_id=b7dbda15-6efb-4610-8b19-a3772b65455f}: request was dropped before completing\n')
```

## Summary of changes

- allow-list the `request was dropped before completing` message on
pageservers before restarting services
2024-10-10 17:06:42 +01:00
Tristan Partin
264c34dfb7 Move path-related fixtures into their own module (#9304)
neon_fixtures.py has grown into quite a beast.

Signed-off-by: Tristan Partin <tristan@neon.tech>
2024-10-10 10:26:23 -05:00
Erik Grinaker
9dd80b9b4c storage_scrubber: fix faulty assertion when no timelines (#9345)
When there are no timelines in remote storage, the storage scrubber
would incorrectly trip an assertion with "Must be set if results are
present", referring to the last processed tenant ID. When there are no
timelines we don't expect there to be a tenant ID either.

The assertion was introduced in 37aa6fd.

Only apply the assertion when any timelines are present.
2024-10-10 09:09:53 -04:00
36 changed files with 1442 additions and 672 deletions

View File

@@ -218,6 +218,9 @@ runs:
name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }}
# Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test
path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/
# The lack of compatibility snapshot shouldn't fail the job
# (for example if we didn't run the test for non build-and-test workflow)
skip-if-does-not-exist: true
- name: Upload test results
if: ${{ !cancelled() }}

View File

@@ -7,6 +7,10 @@ inputs:
path:
description: "A directory or file to upload"
required: true
skip-if-does-not-exist:
description: "Allow to skip if path doesn't exist, fail otherwise"
default: false
required: false
prefix:
description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'"
required: false
@@ -15,10 +19,12 @@ runs:
using: "composite"
steps:
- name: Prepare artifact
id: prepare-artifact
shell: bash -euxo pipefail {0}
env:
SOURCE: ${{ inputs.path }}
ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst
SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }}
run: |
mkdir -p $(dirname $ARCHIVE)
@@ -33,14 +39,22 @@ runs:
elif [ -f ${SOURCE} ]; then
time tar -cf ${ARCHIVE} --zstd ${SOURCE}
elif ! ls ${SOURCE} > /dev/null 2>&1; then
echo >&2 "${SOURCE} does not exist"
exit 2
if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then
echo 'SKIPPED=true' >> $GITHUB_OUTPUT
exit 0
else
echo >&2 "${SOURCE} does not exist"
exit 2
fi
else
echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it"
exit 3
fi
echo 'SKIPPED=false' >> $GITHUB_OUTPUT
- name: Upload artifact
if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }}
shell: bash -euxo pipefail {0}
env:
SOURCE: ${{ inputs.path }}

View File

@@ -193,16 +193,15 @@ jobs:
with:
submodules: true
# Disabled for now
# - name: Restore cargo deps cache
# id: cache_cargo
# uses: actions/cache@v4
# with:
# path: |
# !~/.cargo/registry/src
# ~/.cargo/git/
# target/
# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }}
- name: Cache cargo deps
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
!~/.cargo/registry/src
~/.cargo/git
target
key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust
# Some of our rust modules use FFI and need those to be checked
- name: Get postgres headers

View File

@@ -33,7 +33,7 @@ jobs:
actions: read
steps:
- name: Export GH Workflow Stats
uses: fedordikarev/gh-workflow-stats-action@v0.1.2
uses: neondatabase/gh-workflow-stats-action@v0.1.4
with:
DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }}
DB_TABLE: "gh_workflow_stats_neon"

View File

@@ -25,6 +25,10 @@ pub(crate) enum WebAuthError {
Io(#[from] std::io::Error),
}
pub struct ConsoleRedirectBackend {
console_uri: reqwest::Url,
}
impl UserFacingError for WebAuthError {
fn to_string_client(&self) -> String {
"Internal error".to_string()
@@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub(super) async fn authenticate(
impl ConsoleRedirectBackend {
pub fn new(console_uri: reqwest::Url) -> Self {
Self { console_uri }
}
pub(super) fn url(&self) -> &reqwest::Url {
&self.console_uri
}
pub(crate) async fn authenticate(
&self,
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
authenticate(ctx, auth_config, &self.console_uri, client).await
}
}
async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,

View File

@@ -8,6 +8,7 @@ use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::WebAuthError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
@@ -36,7 +37,7 @@ use crate::{
provider::{CachedAllowedIps, CachedNodeInfo},
Api,
},
stream, url,
stream,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
@@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> {
/// Cloud API (V2).
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
/// Authentication via a web browser.
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D),
/// Local proxy uses configured auth credentials and does not wake compute
Local(MaybeOwned<'a, LocalBackend>),
}
@@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> {
#[cfg(test)]
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
Self::ConsoleRedirect(url, ()) => fmt
Self::ConsoleRedirect(backend, ()) => fmt
.debug_tuple("ConsoleRedirect")
.field(&url.as_str())
.field(&backend.url().as_str())
.finish(),
Self::Local(_) => fmt.debug_tuple("Local").finish(),
}
@@ -241,7 +242,6 @@ impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
@@ -265,7 +265,7 @@ impl AuthenticationConfig {
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
@@ -339,7 +339,6 @@ async fn auth_quirks(
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
config,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
@@ -456,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
Backend::ControlPlane(api, credentials)
}
// NOTE: this auth backend doesn't use client credentials.
Self::ConsoleRedirect(url, ()) => {
Self::ConsoleRedirect(backend, ()) => {
info!("performing web authentication");
let info = console_redirect::authenticate(ctx, config, &url, client).await?;
let info = backend.authenticate(ctx, config, client).await?;
Backend::ConsoleRedirect(url, info)
Backend::ConsoleRedirect(backend, info)
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))

View File

@@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
auth::{
self,
backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
@@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
@@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> {
let task = serverless::task_main(
config,
auth_backend,
http_listener,
shutdown.clone(),
Arc::new(CancellationHandlerMain::new(
@@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
Ok(Box::leak(Box::new(ProxyConfig {
tls_config: None,
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
)),
metric_collection: None,
allow_self_signed_compute: false,
http_config,
@@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &LocalProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
LocalBackend::new(args.compute),
));
Ok(Box::leak(Box::new(auth_backend)))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;

View File

@@ -10,6 +10,7 @@ use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::ConsoleRedirectBackend;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
use proxy::cancellation::CancellationHandler;
@@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> {
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
info!("Authentication backend: {}", config.auth_backend);
info!("Authentication backend: {}", auth_backend);
info!("Using region: {}", args.aws_region);
let region_provider =
@@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
@@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> {
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
@@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> {
));
}
if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
if let auth::Backend::ControlPlane(api, _) = auth_backend {
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
@@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
bail!("dynamic rate limiter should be disabled");
}
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
let connect_compute_locks = control_plane::locks::ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
}));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
}
/// auth::Backend is created at proxy startup, and lives forever.
fn build_auth_backend(
args: &ProxyCliArgs,
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
let auth_backend = match &args.auth_backend {
AuthBackendType::Console => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
@@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
AuthBackendType::Web => {
let url = args.uri.parse()?;
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ())
}
#[cfg(feature = "testing")]
@@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
}
};
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.connect_compute_lock.parse()?;
info!(
?limiter,
shards,
?epoch,
"Using NodeLocks (connect_compute)"
);
let connect_compute_locks = control_plane::locks::ApiLocks::new(
"connect_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
},
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute_retry_config: config::RetryConfig::parse(
&args.connect_to_compute_retry,
)?,
}));
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
Ok(config)
Ok(Box::leak(Box::new(auth_backend)))
}
#[cfg(test)]

View File

@@ -1,8 +1,5 @@
use crate::{
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
auth::backend::{jwt::JwkCache, AuthRateLimiter},
control_plane::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -29,7 +26,6 @@ use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: auth::Backend<'static, (), ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,

View File

@@ -81,12 +81,12 @@ pub(crate) mod errors {
Reason::EndpointNotFound => ErrorKind::User,
Reason::BranchNotFound => ErrorKind::User,
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
Reason::DataTransferQuotaExceeded => ErrorKind::User,
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
Reason::RunningOperations => ErrorKind::ControlPlane,
@@ -103,7 +103,7 @@ pub(crate) mod errors {
} if error
.contains("compute time quota of non-primary branches is exceeded") =>
{
crate::error::ErrorKind::User
crate::error::ErrorKind::Quota
}
ControlPlaneError {
http_status_code: http::StatusCode::LOCKED,
@@ -112,7 +112,7 @@ pub(crate) mod errors {
} if error.contains("quota exceeded")
|| error.contains("the limit for current plan reached") =>
{
crate::error::ErrorKind::User
crate::error::ErrorKind::Quota
}
ControlPlaneError {
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,

View File

@@ -22,7 +22,7 @@ use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing::{debug, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -456,7 +456,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
});
body.http_status_code = status;
error!("console responded with an error ({status}): {body:?}");
warn!("console responded with an error ({status}): {body:?}");
Err(ApiError::ControlPlane(body))
}

View File

@@ -49,6 +49,10 @@ pub enum ErrorKind {
#[label(rename = "serviceratelimit")]
ServiceRateLimit,
/// Proxy quota limit violation
#[label(rename = "quota")]
Quota,
/// internal errors
Service,
@@ -70,6 +74,7 @@ impl ErrorKind {
ErrorKind::ClientDisconnect => "clientdisconnect",
ErrorKind::RateLimit => "ratelimit",
ErrorKind::ServiceRateLimit => "serviceratelimit",
ErrorKind::Quota => "quota",
ErrorKind::Service => "service",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "postgres",

View File

@@ -35,7 +35,7 @@ use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, Instrument};
use tracing::{error, info, warn, Instrument};
use self::{
connect_compute::{connect_to_compute, TcpMechanism},
@@ -61,6 +61,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -95,15 +96,15 @@ pub async fn task_main(
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
warn!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
warn!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
warn!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
@@ -129,6 +130,7 @@ pub async fn task_main(
let startup = Box::pin(
handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
@@ -144,7 +146,7 @@ pub async fn task_main(
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
warn!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
@@ -155,7 +157,7 @@ pub async fn task_main(
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
@@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError {
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
@@ -285,8 +289,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();

View File

@@ -71,7 +71,7 @@ impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::error!(?err, "could not cancel the query in the database");
tracing::warn!(?err, "could not cancel the query in the database");
}
res
}

View File

@@ -6,7 +6,7 @@ use redis::{
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider {
return Ok(());
}
Err(e) => {
error!("Error during PING: {e:?}");
warn!("Error during PING: {e:?}");
}
}
} else {
@@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider {
info!("Connection succesfully established");
}
Err(e) => {
error!("Connection is broken. Error during PING: {e:?}");
warn!("Connection is broken. Error during PING: {e:?}");
}
}
self.con = Some(con);

View File

@@ -146,7 +146,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
tracing::warn!("failed to cancel session: {e}");
}
}
}

View File

@@ -13,7 +13,7 @@ use crate::{
check_peer_addr_is_in_list, AuthError,
},
compute,
config::{AuthenticationConfig, ProxyConfig},
config::ProxyConfig,
context::RequestMonitoring,
control_plane::{
errors::{GetAuthInfoError, WakeComputeError},
@@ -32,16 +32,17 @@ use crate::{
};
use super::{
conn_pool::{poll_client, Client, ConnInfo, ConnPool, EndpointConnPool},
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
local_conn_pool::{self, LocalClient},
local_conn_pool::{self, LocalClient, LocalConnPool},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) local_pool: Arc<ConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<ConnPool<tokio_postgres::Client>>,
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
@@ -49,18 +50,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
password: &[u8],
) -> Result<ComputeCredentials, AuthError> {
let user_info = user_info.clone();
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| user_info.clone());
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
if config.ip_allowlist_check_enabled
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
@@ -79,7 +75,6 @@ impl PoolingBackend {
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
config,
secret,
&user_info.endpoint,
true,
@@ -91,9 +86,13 @@ impl PoolingBackend {
}
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome =
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
.await?;
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,
password,
secret,
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
@@ -113,13 +112,13 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
config
self.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -140,7 +139,9 @@ impl PoolingBackend {
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
let keys = config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
ctx,
@@ -185,7 +186,7 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.config.auth_backend.as_ref().map(|()| keys);
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
@@ -217,21 +218,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}-local-proxy",
conn_info.user_info.endpoint
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
@@ -269,7 +263,7 @@ impl PoolingBackend {
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let mut node_info = match &self.config.auth_backend {
let mut node_info = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
@@ -439,7 +433,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
}
struct TokioMechanism {
pool: Arc<ConnPool<tokio_postgres::Client>>,
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,

View File

@@ -77,7 +77,7 @@ impl fmt::Display for ConnInfo {
}
}
pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
@@ -87,11 +87,10 @@ pub(crate) struct ConnPoolEntry<C: ClientInnerExt> {
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize, // max conns per endpoint
max_conns: usize,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
global_pool_size_max_conns: usize,
pool_name: String, // used for logging
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
@@ -134,23 +133,21 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
}
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
let p_name = pool.read().pool_name.clone();
if client.is_closed() {
info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because connection is closed");
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool
.read()
.global_connections_count
.load(atomic::Ordering::Relaxed)
>= global_max_conn
{
info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because pool is full");
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full");
return;
}
@@ -185,11 +182,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
// do logging outside of the mutex
if returned {
info!(%conn_id, "{p_name}: returning connection '{conn_info}' back to the pool,
total_conns={total_conns}, for this (db, user)={per_db_size}");
info!(%conn_id, "pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "{p_name}: throwing away connection '{conn_info}' because pool is full,
total_conns={total_conns}");
info!(%conn_id, "pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
@@ -219,7 +214,7 @@ impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
pub(crate) fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
@@ -230,7 +225,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
removed
}
pub(crate) fn get_conn_entry(
fn get_conn_entry(
&mut self,
conns: &mut usize,
global_connections_count: Arc<AtomicUsize>,
@@ -251,12 +246,12 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
}
}
pub(crate) struct ConnPool<C: ClientInnerExt> {
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
pub(crate) global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool<C>>>>,
/// Number of endpoint-connection pools
///
@@ -291,7 +286,7 @@ pub struct GlobalConnPoolOptions {
pub max_total_conns: usize,
}
impl<C: ClientInnerExt> ConnPool<C> {
impl<C: ClientInnerExt> GlobalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
@@ -433,7 +428,7 @@ impl<C: ClientInnerExt> ConnPool<C> {
Ok(None)
}
pub(crate) fn get_or_create_endpoint_pool(
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool<C>>> {
@@ -450,7 +445,6 @@ impl<C: ClientInnerExt> ConnPool<C> {
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
global_pool_size_max_conns: self.config.pool_options.max_total_conns,
pool_name: String::from("global_pool"),
}));
// find or create a pool for this endpoint
@@ -480,7 +474,7 @@ impl<C: ClientInnerExt> ConnPool<C> {
}
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<ConnPool<C>>,
global_pool: Arc<GlobalConnPool<C>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: C,
@@ -600,12 +594,6 @@ struct ClientInner<C: ClientInnerExt> {
conn_id: uuid::Uuid,
}
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
@@ -627,6 +615,22 @@ impl ClientInnerExt for tokio_postgres::Client {
}
}
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> Client<C> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
}
}
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
@@ -634,6 +638,11 @@ pub(crate) struct Client<C: ClientInnerExt> {
pool: Weak<RwLock<EndpointConnPool<C>>>,
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Client<C> {
pub(self) fn new(
inner: ClientInner<C>,
@@ -647,7 +656,6 @@ impl<C: ClientInnerExt> Client<C> {
pool,
}
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
@@ -658,15 +666,36 @@ impl<C: ClientInnerExt> Client<C> {
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
})
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
}
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
@@ -685,18 +714,6 @@ impl<C: ClientInnerExt> Client<C> {
}
}
impl<C: ClientInnerExt> Deref for Client<C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
@@ -705,26 +722,6 @@ impl<C: ClientInnerExt> Drop for Client<C> {
}
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
#[cfg(test)]
mod tests {
use std::{mem, sync::atomic::AtomicBool};
@@ -787,7 +784,7 @@ mod tests {
max_request_size_bytes: u64::MAX,
max_response_size_bytes: usize::MAX,
}));
let pool = ConnPool::new(config);
let pool = GlobalConnPool::new(config);
let conn_info = ConnInfo {
user_info: ComputeUserInfo {
user: "user".into(),

View File

@@ -1,30 +1,40 @@
use itertools::Itertools;
use serde_json::value::RawValue;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
use typed_json::json;
use super::json_raw_value::LazyValue;
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
pub(crate) fn json_to_pg_text(
json: &[&RawValue],
) -> Result<Vec<Option<String>>, serde_json::Error> {
json.iter().copied().map(json_value_to_pg_text).try_collect()
}
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
fn json_value_to_pg_text(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
let lazy_value = serde_json::from_str(value.get())?;
match lazy_value {
// special care for nulls
Value::Null => None,
LazyValue::Null => Ok(None),
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
LazyValue::Bool | LazyValue::Number | LazyValue::Object => {
Ok(Some(value.get().to_string()))
}
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
LazyValue::String(s) => Ok(Some(s.into_owned())),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
}
}
@@ -36,27 +46,42 @@ fn json_value_to_pg_text(value: &Value) -> Option<String> {
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
fn json_array_to_pg_array(arr: Vec<&RawValue>) -> Result<String, serde_json::Error> {
let mut output = String::new();
let mut first = true;
output.push('{');
for value in arr {
if !first {
output.push(',');
}
first = false;
let value = json_array_to_pg_array_inner(value)?;
output.push_str(value.as_deref().unwrap_or("NULL"));
}
output.push('}');
Ok(output)
}
fn json_array_to_pg_array_inner(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
let lazy_value = serde_json::from_str(value.get())?;
match lazy_value {
// special care for nulls
Value::Null => None,
LazyValue::Null => Ok(None),
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
LazyValue::Bool | LazyValue::Number | LazyValue::String(_) => {
Ok(Some(value.get().to_string()))
}
LazyValue::Object => Ok(Some(json!(value.get().to_string()).to_string())),
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{vals}}}"))
}
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
}
}
@@ -259,25 +284,31 @@ mod tests {
use super::*;
use serde_json::json;
fn json_to_pg_text_test(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = serde_json::Value::Array(json).to_string();
let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap();
json_to_pg_text(&json).unwrap()
}
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![None]);
}
@@ -286,7 +317,7 @@ mod tests {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(
@@ -297,7 +328,7 @@ mod tests {
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(
@@ -307,7 +338,7 @@ mod tests {
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]

View File

@@ -0,0 +1,193 @@
//! [`serde_json::Value`] but uses RawValue internally
//!
//! This code forks from the serde_json code, but replaces internal Value with RawValue where possible.
//!
//! Taken from <https://github.com/serde-rs/json/blob/faab2e8d2fcf781a3f77f329df836ffb3aaacfba/src/value/de.rs>
//! Licensed from serde-rs under MIT or APACHE-2.0, with modifications by Conrad Ludgate
use core::fmt;
use std::borrow::Cow;
use serde::{
de::{IgnoredAny, MapAccess, SeqAccess, Visitor},
Deserialize,
};
use serde_json::value::RawValue;
pub enum LazyValue<'de> {
Null,
Bool,
Number,
String(Cow<'de, str>),
Array(Vec<&'de RawValue>),
Object,
}
impl<'de> Deserialize<'de> for LazyValue<'de> {
#[inline]
fn deserialize<D>(deserializer: D) -> Result<LazyValue<'de>, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ValueVisitor;
impl<'de> Visitor<'de> for ValueVisitor {
type Value = LazyValue<'de>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("any valid JSON value")
}
#[inline]
fn visit_bool<E>(self, _value: bool) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Bool)
}
#[inline]
fn visit_i64<E>(self, _value: i64) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Number)
}
#[inline]
fn visit_u64<E>(self, _value: u64) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Number)
}
#[inline]
fn visit_f64<E>(self, _value: f64) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Number)
}
#[inline]
fn visit_str<E>(self, value: &str) -> Result<LazyValue<'de>, E>
where
E: serde::de::Error,
{
self.visit_string(String::from(value))
}
#[inline]
fn visit_borrowed_str<E>(self, value: &'de str) -> Result<LazyValue<'de>, E>
where
E: serde::de::Error,
{
Ok(LazyValue::String(Cow::Borrowed(value)))
}
#[inline]
fn visit_string<E>(self, value: String) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::String(Cow::Owned(value)))
}
#[inline]
fn visit_none<E>(self) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Null)
}
#[inline]
fn visit_some<D>(self, deserializer: D) -> Result<LazyValue<'de>, D::Error>
where
D: serde::Deserializer<'de>,
{
Deserialize::deserialize(deserializer)
}
#[inline]
fn visit_unit<E>(self) -> Result<LazyValue<'de>, E> {
Ok(LazyValue::Null)
}
#[inline]
fn visit_seq<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
where
V: SeqAccess<'de>,
{
let mut vec = Vec::new();
while let Some(elem) = visitor.next_element()? {
vec.push(elem);
}
Ok(LazyValue::Array(vec))
}
fn visit_map<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
where
V: MapAccess<'de>,
{
while visitor.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
Ok(LazyValue::Object)
}
}
deserializer.deserialize_any(ValueVisitor)
}
}
#[cfg(test)]
mod tests {
use std::borrow::Cow;
use typed_json::json;
use super::LazyValue;
#[test]
fn object() {
let json = json! {{
"foo": {
"bar": 1
},
"baz": [2, 3],
}}
.to_string();
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
let LazyValue::Object = lazy else {
panic!("expected object")
};
}
#[test]
fn array() {
let json = json! {[
{
"bar": 1
},
[2, 3],
]}
.to_string();
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
let LazyValue::Array(array) = lazy else {
panic!("expected array")
};
assert_eq!(array.len(), 2);
assert_eq!(array[0].get(), r#"{"bar":1}"#);
assert_eq!(array[1].get(), r#"[2,3]"#);
}
#[test]
fn string() {
let json = json! { "hello world" }.to_string();
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
let LazyValue::String(Cow::Borrowed(string)) = lazy else {
panic!("expected borrowed string")
};
assert_eq!(string, "hello world");
let json = json! { "hello \n world" }.to_string();
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
let LazyValue::String(Cow::Owned(string)) = lazy else {
panic!("expected owned string")
};
assert_eq!(string, "hello \n world");
}
}

View File

@@ -6,15 +6,7 @@ use rand::rngs::OsRng;
use serde_json::Value;
use signature::Signer;
use std::task::{ready, Poll};
use std::{
collections::HashMap,
pin::pin,
sync::atomic::{self, AtomicUsize},
sync::Arc,
sync::Weak,
time::Duration,
};
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
use tokio::time::Instant;
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::types::ToSql;
@@ -23,7 +15,7 @@ use tokio_util::sync::CancellationToken;
use typed_json::json;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::metrics::Metrics;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, DbName, RoleName};
@@ -31,10 +23,230 @@ use tracing::{debug, error, warn, Span};
use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;
use super::conn_pool::{ClientInnerExt, ConnInfo, ConnPool, EndpointConnPool};
use super::conn_pool::{ClientInnerExt, ConnInfo};
pub(crate) fn poll_client<C: ClientInnerExt>(
local_pool: Arc<ConnPool<C>>,
struct ConnPoolEntry<C: ClientInnerExt> {
conn: ClientInner<C>,
_last_access: std::time::Instant,
}
// /// key id for the pg_session_jwt state
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
global_pool_size_max_conns: usize,
}
impl<C: ClientInnerExt> EndpointConnPool<C> {
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
let Self {
pools, total_conns, ..
} = self;
pools
.get_mut(&db_user)
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
}
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
let Self {
pools, total_conns, ..
} = self;
if let Some(pool) = pools.get_mut(&db_user) {
let old_len = pool.conns.len();
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
let new_len = pool.conns.len();
let removed = old_len - new_len;
if removed > 0 {
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
*total_conns -= removed;
removed > 0
} else {
false
}
}
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
let conn_id = client.conn_id;
if client.is_closed() {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
return;
}
let global_max_conn = pool.read().global_pool_size_max_conns;
if pool.read().total_conns >= global_max_conn {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
return;
}
// return connection to the pool
let mut returned = false;
let mut per_db_size = 0;
let total_conns = {
let mut pool = pool.write();
if pool.total_conns < pool.max_conns {
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
pool_entries.conns.push(ConnPoolEntry {
conn: client,
_last_access: std::time::Instant::now(),
});
returned = true;
per_db_size = pool_entries.conns.len();
pool.total_conns += 1;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.inc();
}
pool.total_conns
};
// do logging outside of the mutex
if returned {
info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
fn drop(&mut self) {
if self.total_conns > 0 {
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.total_conns as i64);
}
}
}
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
}
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
fn default() -> Self {
Self { conns: Vec::new() }
}
}
impl<C: ClientInnerExt> DbUserConnPool<C> {
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
let old_len = self.conns.len();
self.conns.retain(|conn| !conn.conn.is_closed());
let new_len = self.conns.len();
let removed = old_len - new_len;
*conns -= removed;
removed
}
fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
let mut removed = self.clear_closed_clients(conns);
let conn = self.conns.pop();
if conn.is_some() {
*conns -= 1;
removed += 1;
}
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
conn
}
}
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
global_pool: RwLock<EndpointConnPool<C>>,
config: &'static crate::config::HttpConfig,
}
impl<C: ClientInnerExt> LocalConnPool<C> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
Arc::new(Self {
global_pool: RwLock::new(EndpointConnPool {
pools: HashMap::new(),
total_conns: 0,
max_conns: config.pool_options.max_conns_per_endpoint,
global_pool_size_max_conns: config.pool_options.max_total_conns,
}),
config,
})
}
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}
// pub(crate) fn shutdown(&self) {
// let mut pool = self.global_pool.write();
// pool.pools.clear();
// pool.total_conns = 0;
// }
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Result<Option<LocalClient<C>>, HttpConnError> {
let mut client: Option<ClientInner<C>> = None;
if let Some(entry) = self
.global_pool
.write()
.get_conn_entry(conn_info.db_and_user())
{
client = Some(entry.conn);
}
// ok return cached connection if found and establish a new one otherwise
if let Some(client) = client {
if client.is_closed() {
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"local_pool: reusing connection '{conn_info}'"
);
client.session.send(ctx.session_id())?;
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
return Ok(Some(LocalClient::new(
client,
conn_info.clone(),
Arc::downgrade(self),
)));
}
Ok(None)
}
}
pub(crate) fn poll_client(
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
client: tokio_postgres::Client,
@@ -51,11 +263,11 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = Arc::downgrade(&local_pool);
let pool = Arc::downgrade(&global_pool);
let pool_clone = pool.clone();
let db_user = conn_info.db_and_user();
let idle = local_pool.get_idle_timeout();
let idle = global_pool.get_idle_timeout();
let cancel = CancellationToken::new();
let cancelled = cancel.clone().cancelled_owned();
@@ -123,7 +335,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
@@ -160,23 +372,20 @@ struct ClientInner<C: ClientInnerExt> {
jti: u64,
}
pub(crate) struct LocalClient<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<ConnPool<C>>,
impl<C: ClientInnerExt> Drop for ClientInner<C> {
fn drop(&mut self) {
// on client drop, tell the conn to shut down
self.cancel.cancel();
}
}
impl<C: ClientInnerExt> ClientInner<C> {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(self) fn new(inner: ClientInner<C>, conn_info: ConnInfo, pool: Weak<ConnPool<C>>) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
@@ -184,7 +393,33 @@ impl<C: ClientInnerExt> LocalClient<C> {
branch_id: aux.branch_id,
})
}
}
pub(crate) struct LocalClient<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<LocalConnPool<C>>,
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<LocalConnPool<C>>,
}
impl<C: ClientInnerExt> LocalClient<C> {
pub(self) fn new(
inner: ClientInner<C>,
conn_info: ConnInfo,
pool: Weak<LocalConnPool<C>>,
) -> Self {
Self {
inner: Some(inner),
span: Span::current(),
conn_info,
pool,
}
}
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
@@ -195,7 +430,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
let inner = inner.as_mut().expect("client inner should not be removed");
(&mut inner.inner, Discard { conn_info, pool })
}
pub(crate) fn key(&self) -> &SigningKey {
let inner = &self
.inner
@@ -203,31 +437,6 @@ impl<C: ClientInnerExt> LocalClient<C> {
.expect("client inner should not be removed");
&inner.key
}
pub fn get_client(&self) -> &C {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool.local_pool, &conn_info, client);
});
}
None
}
}
impl LocalClient<tokio_postgres::Client> {
@@ -282,11 +491,6 @@ fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
format!("{message}.{base64_sig}")
}
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<ConnPool<C>>,
}
impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
@@ -299,14 +503,38 @@ impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!(
"local_pool: throwing away connection '{conn_info}'
because connection is potentially in a broken state"
);
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
}
}
}
impl<C: ClientInnerExt> LocalClient<C> {
pub fn get_client(&self) -> &C {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
fn do_drop(&mut self) -> Option<impl FnOnce()> {
let conn_info = self.conn_info.clone();
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
});
}
None
}
}
impl<C: ClientInnerExt> Drop for LocalClient<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {

View File

@@ -8,6 +8,7 @@ mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod json_raw_value;
mod local_conn_pool;
mod sql_over_http;
mod websocket;
@@ -48,13 +49,14 @@ use std::pin::{pin, Pin};
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use tracing::{info, warn, Instrument};
use utils::http::error::ApiError;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, (), ()>,
ws_listener: TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -64,8 +66,8 @@ pub async fn task_main(
info!("websocket server has shut down");
}
let local_pool = conn_pool::ConnPool::new(&config.http_config);
let conn_pool = conn_pool::ConnPool::new(&config.http_config);
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
{
let conn_pool = Arc::clone(&conn_pool);
tokio::spawn(async move {
@@ -110,6 +112,7 @@ pub async fn task_main(
local_pool,
pool: Arc::clone(&conn_pool),
config,
auth_backend,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
});
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
@@ -241,7 +244,7 @@ async fn connection_startup(
let (conn, peer) = match read_proxy_protocol(conn).await {
Ok(c) => c,
Err(e) => {
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
return None;
}
};
@@ -397,6 +400,7 @@ async fn request_handler(
async move {
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
websocket,
cancellation_handler,
@@ -405,7 +409,7 @@ async fn request_handler(
)
.await
{
error!("error in websocket connection: {e:#}");
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),

View File

@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::pin::pin;
use std::sync::Arc;
@@ -22,7 +23,7 @@ use hyper::StatusCode;
use hyper::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
@@ -45,6 +46,7 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::HttpConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -75,24 +77,28 @@ use super::local_conn_pool;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(bound = "'de: 'a")]
struct QueryData<'a> {
#[serde(borrow)]
query: Cow<'a, str>,
#[serde(borrow)]
params: Vec<&'a RawValue>,
#[serde(default)]
array_mode: Option<bool>,
}
#[derive(serde::Deserialize)]
struct BatchQueryData {
queries: Vec<QueryData>,
#[serde(rename_all = "camelCase")]
#[serde(bound = "'de: 'a")]
struct BatchQueryData<'a> {
queries: Vec<QueryData<'a>>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
enum Payload<'a> {
Batch(BatchQueryData<'a>),
Single(QueryData<'a>),
}
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
@@ -105,13 +111,18 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
fn parse_pg_params(params: &[&RawValue]) -> Result<Vec<Option<String>>, ReadPayloadError> {
json_to_pg_text(params).map_err(ReadPayloadError::Parse)
}
fn parse_payload(body: &[u8]) -> Result<Payload<'_>, ReadPayloadError> {
// RawValue doesn't work via untagged enums
// so instead we try parse each individually
if let Ok(batch) = serde_json::from_slice(body) {
Ok(Payload::Batch(batch))
} else {
Ok(Payload::Single(serde_json::from_slice(body)?))
}
}
#[derive(Debug, thiserror::Error)]
@@ -554,7 +565,7 @@ async fn handle_inner(
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
@@ -614,36 +625,24 @@ async fn handle_db_inner(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
Ok::<Bytes, ReadPayloadError>(body)
}
.map_err(SqlOverHttpError::from),
);
let authenticate_and_connect = Box::pin(
async {
let is_local_proxy =
matches!(backend.config.auth_backend, crate::auth::Backend::Local(_));
let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
)
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await?
}
};
@@ -670,7 +669,7 @@ async fn handle_db_inner(
.map_err(SqlOverHttpError::from),
);
let (payload, mut client) = match run_until_cancelled(
let (body, mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
@@ -684,6 +683,8 @@ async fn handle_db_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
let payload = parse_payload(&body)?;
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
@@ -691,7 +692,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, cancel, &mut client, parsed_headers)
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -709,7 +710,7 @@ async fn handle_db_inner(
}
statements
.process(config, cancel, &mut client, parsed_headers)
.process(&config.http_config, cancel, &mut client, parsed_headers)
.await?
}
};
@@ -749,7 +750,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
@@ -757,12 +757,7 @@ async fn handle_auth_broker_inner(
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
.await
.map_err(HttpConnError::from)?;
@@ -797,10 +792,10 @@ async fn handle_auth_broker_inner(
.map(|b| b.boxed()))
}
impl QueryData {
impl QueryData<'_> {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
@@ -831,7 +826,7 @@ impl QueryData {
Either::Right((_cancelled, query)) => {
tracing::info!("cancelling query");
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
tracing::warn!(?err, "could not cancel query");
}
// wait for the query cancellation
match time::timeout(time::Duration::from_millis(100), query).await {
@@ -871,10 +866,10 @@ impl QueryData {
}
}
impl BatchQueryData {
impl BatchQueryData<'_> {
async fn process(
self,
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
client: &mut Client,
parsed_headers: HttpHeaders,
@@ -920,7 +915,7 @@ impl BatchQueryData {
}
Err(SqlOverHttpError::Cancelled(_)) => {
if let Err(err) = cancel_token.cancel_query(NoTls).await {
tracing::error!(?err, "could not cancel query");
tracing::warn!(?err, "could not cancel query");
}
// TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
discard.discard();
@@ -944,10 +939,10 @@ impl BatchQueryData {
}
async fn query_batch(
config: &'static ProxyConfig,
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
queries: BatchQueryData<'_>,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let mut results = Vec::with_capacity(queries.queries.len());
@@ -983,14 +978,14 @@ async fn query_batch(
}
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
config: &'static HttpConfig,
client: &T,
data: QueryData,
data: QueryData<'_>,
current_size: &mut usize,
parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
info!("executing query");
let query_params = data.params;
let query_params = parse_pg_params(&data.params)?;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
info!("finished executing query");
@@ -1004,9 +999,9 @@ async fn query_to_json<T: GenericClient>(
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > config.http_config.max_response_size_bytes {
if *current_size > config.max_response_size_bytes {
return Err(SqlOverHttpError::ResponseTooLarge(
config.http_config.max_response_size_bytes,
config.max_response_size_bytes,
));
}
}
@@ -1116,3 +1111,41 @@ impl Discard<'_> {
}
}
}
#[cfg(test)]
mod tests {
use typed_json::json;
use super::parse_payload;
use super::Payload;
#[test]
fn raw_single_payload() {
let body = json! {
{"query":"select $1","params":["1"]}
}
.to_string();
let Payload::Single(query) = parse_payload(body.as_bytes()).unwrap() else {
panic!("expected single")
};
assert_eq!(&*query.query, "select $1");
assert_eq!(query.params[0].get(), "\"1\"");
}
#[test]
fn raw_batch_payload() {
let body = json! {{
"queries": [
{"query":"select $1","params":["1"]},
{"query":"select $1","params":["2"]},
]
}}
.to_string();
let Payload::Batch(query) = parse_payload(body.as_bytes()).unwrap() else {
panic!("expected batch")
};
assert_eq!(query.queries.len(), 2);
}
}

View File

@@ -129,6 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, (), ()>,
ctx: RequestMonitoring,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket(
let res = Box::pin(handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),

View File

@@ -27,7 +27,7 @@ use std::{
};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace};
use tracing::{error, info, instrument, trace, warn};
use utils::backoff;
use uuid::{NoContext, Timestamp};
@@ -346,7 +346,7 @@ async fn collect_metrics_iteration(
error!("metrics endpoint refused the sent metrics: {:?}", res);
for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) {
// Report if the metric value is suspiciously large
error!("potentially abnormal metric value: {:?}", metric);
warn!("potentially abnormal metric value: {:?}", metric);
}
}
}

View File

@@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata(
tenant_timeline_results.push((ttid, data));
}
let tenant_id = tenant_id.expect("Must be set if results are present");
if !tenant_timeline_results.is_empty() {
let tenant_id = tenant_id.expect("Must be set if results are present");
analyze_tenant(
&remote_client,
tenant_id,

View File

@@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan
Useful environment variables:
`NEON_BIN`: The directory where neon binaries can be found.
`COMPATIBILITY_NEON_BIN`: The directory where the previous version of Neon binaries can be found
`POSTGRES_DISTRIB_DIR`: The directory where postgres distribution can be found.
Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain
a subdirectory for each version with naming convention `v{PG_VERSION}/`.
Inside that dir, a `bin/postgres` binary should be present.
`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the prevoius version of postgres distribution can be found.
`DEFAULT_PG_VERSION`: The version of Postgres to use,
This is used to construct full path to the postgres binaries.
Format is 2-digit major version nubmer, i.e. `DEFAULT_PG_VERSION=16`
@@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder):
client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id)
```
All the test which rely on NeonEnvBuilder, can check the various version combinations of the components.
To do this yuo may want to add the parametrize decorator with the function fixtures.utils.allpairs_versions()
E.g.
```python
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_something(
...
```
For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html
At the end of a test, all the nodes in the environment are automatically stopped, so you

View File

@@ -6,6 +6,7 @@ pytest_plugins = (
"fixtures.httpserver",
"fixtures.compute_reconfigure",
"fixtures.storage_controller_proxy",
"fixtures.paths",
"fixtures.neon_fixtures",
"fixtures.benchmark_fixture",
"fixtures.pg_stats",

View File

@@ -185,8 +185,8 @@ class NeonAPI:
def get_connection_uri(
self,
project_id: str,
branch_id: str | None = None,
endpoint_id: str | None = None,
branch_id: Optional[str] = None,
endpoint_id: Optional[str] = None,
database_name: str = "neondb",
role_name: str = "neondb_owner",
pooled: bool = True,
@@ -262,7 +262,7 @@ class NeonAPI:
class NeonApiEndpoint:
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None):
def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]):
self.neon_api = neon_api
if project_id is None:
project = neon_api.create_project(pg_version)

View File

@@ -18,7 +18,6 @@ from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from fcntl import LOCK_EX, LOCK_UN, flock
from functools import cached_property
from pathlib import Path
from types import TracebackType
@@ -59,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient
from fixtures.pageserver.utils import (
wait_for_last_record_lsn,
)
from fixtures.paths import get_test_repo_dir, shared_snapshot_dir
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import (
@@ -75,8 +75,8 @@ from fixtures.safekeeper.http import SafekeeperHttpClient
from fixtures.safekeeper.utils import wait_walreceivers_absent
from fixtures.utils import (
ATTACHMENT_NAME_REGEX,
COMPONENT_BINARIES,
allure_add_grafana_links,
allure_attach_from_dir,
assert_no_errors,
get_dir_size,
print_gc_result,
@@ -96,6 +96,8 @@ if TYPE_CHECKING:
Union,
)
from fixtures.paths import SnapshotDirLocked
T = TypeVar("T")
@@ -118,65 +120,11 @@ put directly-importable functions into utils.py or another separate file.
Env = dict[str, str]
DEFAULT_OUTPUT_DIR: str = "test_output"
DEFAULT_BRANCH_NAME: str = "main"
BASE_PORT: int = 15000
@pytest.fixture(scope="session")
def base_dir() -> Iterator[Path]:
# find the base directory (currently this is the git root)
base_dir = Path(__file__).parents[2]
log.info(f"base_dir is {base_dir}")
yield base_dir
@pytest.fixture(scope="function")
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
# we are in remote env and do not have neon binaries locally
# this is the case for benchmarks run on self-hosted runner
return
# Find the neon binaries.
if env_neon_bin := os.environ.get("NEON_BIN"):
binpath = Path(env_neon_bin)
else:
binpath = base_dir / "target" / build_type
log.info(f"neon_binpath is {binpath}")
if not (binpath / "pageserver").exists():
raise Exception(f"neon binaries not found at '{binpath}'")
yield binpath
@pytest.fixture(scope="session")
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
distrib_dir = Path(env_postgres_bin).resolve()
else:
distrib_dir = base_dir / "pg_install"
log.info(f"pg_distrib_dir is {distrib_dir}")
yield distrib_dir
@pytest.fixture(scope="session")
def top_output_dir(base_dir: Path) -> Iterator[Path]:
# Compute the top-level directory for all tests.
if env_test_output := os.environ.get("TEST_OUTPUT"):
output_dir = Path(env_test_output).resolve()
else:
output_dir = base_dir / DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
log.info(f"top_output_dir is {output_dir}")
yield output_dir
@pytest.fixture(scope="session")
def neon_api_key() -> str:
api_key = os.getenv("NEON_API_KEY")
@@ -369,11 +317,14 @@ class NeonEnvBuilder:
run_id: uuid.UUID,
mock_s3_server: MockS3Server,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
test_name: str,
top_output_dir: Path,
test_output_dir: Path,
combination,
test_overlay_dir: Optional[Path] = None,
pageserver_remote_storage: Optional[RemoteStorage] = None,
# toml that will be decomposed into `--config-override` flags during `pageserver --init`
@@ -455,6 +406,19 @@ class NeonEnvBuilder:
"test_"
), "Unexpectedly instantiated from outside a test function"
self.test_name = test_name
self.compatibility_neon_binpath = compatibility_neon_binpath
self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir
self.version_combination = combination
self.mixdir = self.test_output_dir / "mixdir_neon"
if self.version_combination is not None:
assert (
self.compatibility_neon_binpath is not None
), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions"
assert (
self.compatibility_pg_distrib_dir is not None
), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions"
self.mixdir.mkdir(mode=0o755, exist_ok=True)
self._mix_versions()
def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv:
# Cannot create more than one environment from one builder
@@ -655,6 +619,21 @@ class NeonEnvBuilder:
return self.env
def _mix_versions(self):
assert self.version_combination is not None, "version combination must be set"
for component, paths in COMPONENT_BINARIES.items():
directory = (
self.neon_binpath
if self.version_combination[component] == "new"
else self.compatibility_neon_binpath
)
for filename in paths:
destination = self.mixdir / filename
destination.symlink_to(directory / filename)
if self.version_combination["compute"] == "old":
self.pg_distrib_dir = self.compatibility_pg_distrib_dir
self.neon_binpath = self.mixdir
def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path):
"""
Mount `srcdir` as an overlayfs mount at `dstdir`.
@@ -1403,7 +1382,9 @@ def neon_simple_env(
top_output_dir: Path,
test_output_dir: Path,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
pageserver_virtual_file_io_engine: str,
pageserver_aux_file_policy: Optional[AuxFileStore],
@@ -1418,6 +1399,11 @@ def neon_simple_env(
# Create the environment in the per-test output directory
repo_dir = get_test_repo_dir(request, top_output_dir)
combination = (
request._pyfuncitem.callspec.params["combination"]
if "combination" in request._pyfuncitem.callspec.params
else None
)
with NeonEnvBuilder(
top_output_dir=top_output_dir,
@@ -1425,7 +1411,9 @@ def neon_simple_env(
port_distributor=port_distributor,
mock_s3_server=mock_s3_server,
neon_binpath=neon_binpath,
compatibility_neon_binpath=compatibility_neon_binpath,
pg_distrib_dir=pg_distrib_dir,
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
pg_version=pg_version,
run_id=run_id,
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
@@ -1435,6 +1423,7 @@ def neon_simple_env(
pageserver_aux_file_policy=pageserver_aux_file_policy,
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
combination=combination,
) as builder:
env = builder.init_start()
@@ -1448,7 +1437,9 @@ def neon_env_builder(
port_distributor: PortDistributor,
mock_s3_server: MockS3Server,
neon_binpath: Path,
compatibility_neon_binpath: Path,
pg_distrib_dir: Path,
compatibility_pg_distrib_dir: Path,
pg_version: PgVersion,
run_id: uuid.UUID,
request: FixtureRequest,
@@ -1475,6 +1466,11 @@ def neon_env_builder(
# Create the environment in the test-specific output dir
repo_dir = os.path.join(test_output_dir, "repo")
combination = (
request._pyfuncitem.callspec.params["combination"]
if "combination" in request._pyfuncitem.callspec.params
else None
)
# Return the builder to the caller
with NeonEnvBuilder(
@@ -1483,7 +1479,10 @@ def neon_env_builder(
port_distributor=port_distributor,
mock_s3_server=mock_s3_server,
neon_binpath=neon_binpath,
compatibility_neon_binpath=compatibility_neon_binpath,
pg_distrib_dir=pg_distrib_dir,
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
combination=combination,
pg_version=pg_version,
run_id=run_id,
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
@@ -4246,44 +4245,6 @@ class StorageScrubber:
raise
def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path:
"""Compute the path to a working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}"
# We rerun flaky tests multiple times, use a separate directory for each run.
if (suffix := getattr(request.node, "execution_count", None)) is not None:
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
log.info(f"get_test_output_dir is {test_dir}")
# make mypy happy
assert isinstance(test_dir, Path)
return test_dir
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
The working directory for a test.
"""
return _get_test_dir(request, top_output_dir, "")
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
Directory that contains `upperdir` and `workdir` for overlayfs mounts
that a test creates. See `NeonEnvBuilder.overlay_mount`.
"""
return _get_test_dir(request, top_output_dir, "overlay-")
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
return top_output_dir / "shared-snapshots" / snapshot_name
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
return get_test_output_dir(request, top_output_dir) / "repo"
def pytest_addoption(parser: Parser):
parser.addoption(
"--preserve-database-files",
@@ -4298,149 +4259,6 @@ SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
)
# This is autouse, so the test output directory always gets created, even
# if a test doesn't put anything there.
#
# NB: we request the overlay dir fixture so the fixture does its cleanups
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(
request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path
) -> Iterator[Path]:
"""Create the working directory for an individual test."""
# one directory per test
test_dir = get_test_output_dir(request, top_output_dir)
log.info(f"test_output_dir is {test_dir}")
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir()
yield test_dir
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
# which aren't going to be used if Allure results collection is not enabled
# (i.e. --alluredir is not set).
# Skip `allure_attach_from_dir` in this case
if not request.config.getoption("--alluredir"):
return
preserve_database_files = False
for k, v in request.node.user_properties:
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
# So, neon_env_builder's cleanup runs before here.
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
if k == "preserve_database_files":
assert isinstance(v, bool)
preserve_database_files = v
allure_attach_from_dir(test_dir, preserve_database_files)
class FileAndThreadLock:
def __init__(self, path: Path):
self.path = path
self.thread_lock = threading.Lock()
self.fd: Optional[int] = None
def __enter__(self):
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
# lock thread lock before file lock so that there's no race
# around flocking / funlocking the file lock
self.thread_lock.acquire()
flock(self.fd, LOCK_EX)
def __exit__(self, exc_type, exc_value, exc_traceback):
assert self.fd is not None
assert self.thread_lock.locked() # ... by us
flock(self.fd, LOCK_UN)
self.thread_lock.release()
os.close(self.fd)
self.fd = None
class SnapshotDirLocked:
def __init__(self, parent: SnapshotDir):
self._parent = parent
def is_initialized(self):
# TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
# Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
return self._parent._marker_file_path.exists()
def set_initialized(self):
self._parent._marker_file_path.write_text("")
@property
def path(self) -> Path:
return self._parent._path / "snapshot"
class SnapshotDir:
_path: Path
def __init__(self, path: Path):
self._path = path
assert self._path.is_dir()
self._lock = FileAndThreadLock(self._lock_file_path)
@property
def _lock_file_path(self) -> Path:
return self._path / "initializing.flock"
@property
def _marker_file_path(self) -> Path:
return self._path / "initialized.marker"
def __enter__(self) -> SnapshotDirLocked:
self._lock.__enter__()
return SnapshotDirLocked(self)
def __exit__(self, exc_type, exc_value, exc_traceback):
self._lock.__exit__(exc_type, exc_value, exc_traceback)
def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir:
snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
snapshot_dir_path.mkdir(exist_ok=True, parents=True)
return SnapshotDir(snapshot_dir_path)
@pytest.fixture(scope="function")
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
"""
Idempotently create a test's overlayfs mount state directory.
If the functionality isn't enabled via env var, returns None.
The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
"""
if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
return None
overlay_dir = get_test_overlay_dir(request, top_output_dir)
log.info(f"test_overlay_dir is {overlay_dir}")
overlay_dir.mkdir(exist_ok=True)
# unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
cmd = ["sudo", "umount", str(mountpoint)]
log.info(
f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
)
subprocess.run(cmd, capture_output=True, check=True)
# the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work.
cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
subprocess.run(cmd, capture_output=True, check=True)
overlay_dir.mkdir()
return overlay_dir
# no need to clean up anything: on clean shutdown,
# NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
# and on unclean shutdown, this function will take care of it
# on the next test run
SKIP_DIRS = frozenset(
(
"pg_wal",

View File

@@ -886,7 +886,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
self,
tenant_id: Union[TenantId, TenantShardId],
timeline_id: TimelineId,
batch_size: int | None = None,
batch_size: Optional[int] = None,
**kwargs,
) -> set[TimelineId]:
params = {}

View File

@@ -0,0 +1,312 @@
from __future__ import annotations
import os
import shutil
import subprocess
import threading
from fcntl import LOCK_EX, LOCK_UN, flock
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING
import pytest
from pytest import FixtureRequest
from fixtures import overlayfs
from fixtures.log_helper import log
from fixtures.utils import allure_attach_from_dir
if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Optional
DEFAULT_OUTPUT_DIR: str = "test_output"
def get_test_dir(
request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None
) -> Path:
"""Compute the path to a working directory for an individual test."""
test_name = request.node.name
test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}"
# We rerun flaky tests multiple times, use a separate directory for each run.
if (suffix := getattr(request.node, "execution_count", None)) is not None:
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
return test_dir
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
The working directory for a test.
"""
return get_test_dir(request, top_output_dir)
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
"""
Directory that contains `upperdir` and `workdir` for overlayfs mounts
that a test creates. See `NeonEnvBuilder.overlay_mount`.
"""
return get_test_dir(request, top_output_dir, "overlay-")
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
return top_output_dir / "shared-snapshots" / snapshot_name
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
return get_test_output_dir(request, top_output_dir) / "repo"
@pytest.fixture(scope="session")
def base_dir() -> Iterator[Path]:
# find the base directory (currently this is the git root)
base_dir = Path(__file__).parents[2]
log.info(f"base_dir is {base_dir}")
yield base_dir
@pytest.fixture(scope="session")
def compute_config_dir(base_dir: Path) -> Iterator[Path]:
"""
Retrieve the path to the compute configuration directory.
"""
yield base_dir / "compute" / "etc"
@pytest.fixture(scope="function")
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
# we are in remote env and do not have neon binaries locally
# this is the case for benchmarks run on self-hosted runner
return
# Find the neon binaries.
if env_neon_bin := os.environ.get("NEON_BIN"):
binpath = Path(env_neon_bin)
else:
binpath = base_dir / "target" / build_type
log.info(f"neon_binpath is {binpath}")
if not (binpath / "pageserver").exists():
raise Exception(f"neon binaries not found at '{binpath}'")
yield binpath.absolute()
@pytest.fixture(scope="session")
def compatibility_snapshot_dir() -> Iterator[Path]:
if os.getenv("REMOTE_ENV"):
return
compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
assert (
compatibility_snapshot_dir_env is not None
), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()
yield compatibility_snapshot_dir
@pytest.fixture(scope="session")
def compatibility_neon_binpath() -> Optional[Iterator[Path]]:
if os.getenv("REMOTE_ENV"):
return
comp_binpath = None
if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"):
comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute()
yield comp_binpath
@pytest.fixture(scope="session")
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
distrib_dir = Path(env_postgres_bin).resolve()
else:
distrib_dir = base_dir / "pg_install"
log.info(f"pg_distrib_dir is {distrib_dir}")
yield distrib_dir
@pytest.fixture(scope="session")
def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]:
compat_distrib_dir = None
if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"):
compat_distrib_dir = Path(env_compat_postgres_bin).resolve()
if not compat_distrib_dir.exists():
raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}")
if compat_distrib_dir:
log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}")
yield compat_distrib_dir
@pytest.fixture(scope="session")
def top_output_dir(base_dir: Path) -> Iterator[Path]:
# Compute the top-level directory for all tests.
if env_test_output := os.environ.get("TEST_OUTPUT"):
output_dir = Path(env_test_output).resolve()
else:
output_dir = base_dir / DEFAULT_OUTPUT_DIR
output_dir.mkdir(exist_ok=True)
log.info(f"top_output_dir is {output_dir}")
yield output_dir
# This is autouse, so the test output directory always gets created, even
# if a test doesn't put anything there.
#
# NB: we request the overlay dir fixture so the fixture does its cleanups
@pytest.fixture(scope="function", autouse=True)
def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
"""Create the working directory for an individual test."""
# one directory per test
test_dir = get_test_output_dir(request, top_output_dir)
log.info(f"test_output_dir is {test_dir}")
shutil.rmtree(test_dir, ignore_errors=True)
test_dir.mkdir()
yield test_dir
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
# which aren't going to be used if Allure results collection is not enabled
# (i.e. --alluredir is not set).
# Skip `allure_attach_from_dir` in this case
if not request.config.getoption("--alluredir"):
return
preserve_database_files = False
for k, v in request.node.user_properties:
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
# So, neon_env_builder's cleanup runs before here.
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
if k == "preserve_database_files":
assert isinstance(v, bool)
preserve_database_files = v
allure_attach_from_dir(test_dir, preserve_database_files)
class FileAndThreadLock:
def __init__(self, path: Path):
self.path = path
self.thread_lock = threading.Lock()
self.fd: Optional[int] = None
def __enter__(self):
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
# lock thread lock before file lock so that there's no race
# around flocking / funlocking the file lock
self.thread_lock.acquire()
flock(self.fd, LOCK_EX)
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
):
assert self.fd is not None
assert self.thread_lock.locked() # ... by us
flock(self.fd, LOCK_UN)
self.thread_lock.release()
os.close(self.fd)
self.fd = None
class SnapshotDirLocked:
def __init__(self, parent: SnapshotDir):
self._parent = parent
def is_initialized(self):
# TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
# Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
return self._parent.marker_file_path.exists()
def set_initialized(self):
self._parent.marker_file_path.write_text("")
@property
def path(self) -> Path:
return self._parent.path / "snapshot"
class SnapshotDir:
_path: Path
def __init__(self, path: Path):
self._path = path
assert self._path.is_dir()
self._lock = FileAndThreadLock(self.lock_file_path)
@property
def path(self) -> Path:
return self._path
@property
def lock_file_path(self) -> Path:
return self._path / "initializing.flock"
@property
def marker_file_path(self) -> Path:
return self._path / "initialized.marker"
def __enter__(self) -> SnapshotDirLocked:
self._lock.__enter__()
return SnapshotDirLocked(self)
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
exc_traceback: Optional[TracebackType],
):
self._lock.__exit__(exc_type, exc_value, exc_traceback)
def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir:
snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
snapshot_dir_path.mkdir(exist_ok=True, parents=True)
return SnapshotDir(snapshot_dir_path)
@pytest.fixture(scope="function")
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
"""
Idempotently create a test's overlayfs mount state directory.
If the functionality isn't enabled via env var, returns None.
The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
"""
if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
return None
overlay_dir = get_test_overlay_dir(request, top_output_dir)
log.info(f"test_overlay_dir is {overlay_dir}")
overlay_dir.mkdir(exist_ok=True)
# unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
cmd = ["sudo", "umount", str(mountpoint)]
log.info(
f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
)
subprocess.run(cmd, capture_output=True, check=True)
# the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work.
cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
subprocess.run(cmd, capture_output=True, check=True)
overlay_dir.mkdir()
return overlay_dir
# no need to clean up anything: on clean shutdown,
# NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
# and on unclean shutdown, this function will take care of it
# on the next test run

View File

@@ -37,6 +37,23 @@ if TYPE_CHECKING:
Fn = TypeVar("Fn", bound=Callable[..., Any])
COMPONENT_BINARIES = {
"storage_controller": ("storage_controller",),
"storage_broker": ("storage_broker",),
"compute": ("compute_ctl",),
"safekeeper": ("safekeeper",),
"pageserver": ("pageserver", "pagectl"),
}
# Disable auto-formatting for better readability
# fmt: off
VERSIONS_COMBINATIONS = (
{"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"},
{"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"},
{"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"},
)
# fmt: on
def subprocess_capture(
@@ -607,3 +624,19 @@ def human_bytes(amt: float) -> str:
amt = amt / 1024
raise RuntimeError("unreachable")
def allpairs_versions():
"""
Returns a dictionary with arguments for pytest parametrize
to test the compatibility with the previous version of Neon components
combinations were pre-computed to test all the pairs of the components with
the different versions.
"""
ids = []
for pair in VERSIONS_COMBINATIONS:
cur_id = []
for component in sorted(pair.keys()):
cur_id.append(pair[component][0])
ids.append(f"combination_{''.join(cur_id)}")
return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids}

View File

@@ -9,6 +9,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import fixtures.utils
import pytest
import toml
from fixtures.common_types import TenantId, TimelineId
@@ -93,6 +94,34 @@ if TYPE_CHECKING:
# # Run forward compatibility test
# ./scripts/pytest -k test_forward_compatibility
#
#
# How to run `test_version_mismatch` locally:
#
# export DEFAULT_PG_VERSION=16
# export BUILD_TYPE=release
# export CHECK_ONDISK_DATA_COMPATIBILITY=true
# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE}
# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install
# export NEON_BIN=target/release
# export POSTGRES_DISTRIB_DIR=pg_install
#
# # Build previous version of binaries and store them somewhere:
# rm -rf pg_install target
# git checkout <previous version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# mkdir -p neon_previous/target
# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE}
# cp -a pg_install ./neon_previous/pg_install
#
# # Build current version of binaries and create a data snapshot:
# rm -rf pg_install target
# git checkout <current version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# ./scripts/pytest -k test_create_snapshot
#
# # Run the version mismatch test
# ./scripts/pytest -k test_version_mismatch
check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif(
os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None,
@@ -166,16 +195,11 @@ def test_backward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir: Path,
):
"""
Test that the new binaries can read old data
"""
compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
assert (
compatibility_snapshot_dir_env is not None
), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()
breaking_changes_allowed = (
os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
)
@@ -214,27 +238,11 @@ def test_forward_compatibility(
test_output_dir: Path,
top_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir: Path,
):
"""
Test that the old binaries can read new data
"""
compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN")
assert compatibility_neon_bin_env is not None, (
"COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries "
"(ideally generated by the previous version of Neon)"
)
compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve()
compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR")
assert (
compatibility_postgres_distrib_dir_env is not None
), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)"
compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve()
compatibility_snapshot_dir = (
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
)
breaking_changes_allowed = (
os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
)
@@ -245,9 +253,14 @@ def test_forward_compatibility(
# Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.).
# But always use the current version's neon_local binary.
# This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI.
neon_env_builder.neon_binpath = compatibility_neon_bin
neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir
neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath
assert (
neon_env_builder.compatibility_neon_binpath is not None
), "the environment variable COMPATIBILITY_NEON_BIN is required"
assert (
neon_env_builder.compatibility_pg_distrib_dir is not None
), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required"
neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir
env = neon_env_builder.from_repo_dir(
compatibility_snapshot_dir / "repo",
@@ -558,3 +571,29 @@ def test_historic_storage_formats(
env.pageserver.http_client().timeline_compact(
dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True
)
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_versions_mismatch(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir,
combination,
):
"""
Checks compatibility of different combinations of versions of the components
"""
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.from_repo_dir(
compatibility_snapshot_dir / "repo",
)
env.pageserver.allowed_errors.extend(
[".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"]
)
env.start()
check_neon_works(
env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo"
)

View File

@@ -162,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder):
env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID)
env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1)
# We will stop the storage controller while it may have requests in
# flight, and the pageserver complains when requests are abandoned.
for ps in env.pageservers:
ps.allowed_errors.append(".*request was dropped before completing.*")
# Keep NeonEnv state up to date, it usually owns starting/stopping services
env.pageservers[0].running = False
env.pageservers[1].running = False

View File

@@ -9,6 +9,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING
import fixtures.utils
import pytest
from fixtures.auth_tokens import TokenScope
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -38,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import RemoteStorageKind, s3_storage
from fixtures.storage_controller_proxy import StorageControllerProxy
from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until
from fixtures.utils import (
run_pg_bench_small,
subprocess_capture,
wait_until,
)
from fixtures.workload import Workload
from mypy_boto3_s3.type_defs import (
ObjectTypeDef,
@@ -60,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
return counts
def test_storage_controller_smoke(
neon_env_builder: NeonEnvBuilder,
):
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination):
"""
Test the basic lifecycle of a storage controller:
- Restarting
@@ -1300,11 +1304,11 @@ def test_storage_controller_heartbeats(
node_to_tenants = build_node_to_tenants_map(env)
log.info(f"Back online: {node_to_tenants=}")
# ... expecting the storage controller to reach a consistent state
def storage_controller_consistent():
env.storage_controller.consistency_check()
# ... background reconciliation may need to run to clean up the location on the node that was offline
env.storage_controller.reconcile_until_idle()
wait_until(30, 1, storage_controller_consistent)
# ... expecting the storage controller to reach a consistent state
env.storage_controller.consistency_check()
def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder):