Compare commits

...

23 Commits

Author SHA1 Message Date
Em Sharnoff
e614a95853 Merge pull request #5610 from neondatabase/sharnoff/rc-2023-10-20-vm-monitor-fixes
Release 2023-10-20: vm-monitor memory.high throttling fixes
2023-10-20 00:11:06 -07:00
Em Sharnoff
850db4cc13 vm-monitor: Deny not fail downscale if no memory stats yet (#5606)
Fixes an issue we observed on staging that happens when the
autoscaler-agent attempts to immediately downscale the VM after binding,
which is typical for pooled computes.

The issue was occurring because the autoscaler-agent was requesting
downscaling before the vm-monitor had gathered sufficient cgroup memory
stats to be confident in approving it. When the vm-monitor returned an
internal error instead of denying downscaling, the autoscaler-agent
retried the connection and immediately hit the same issue (in part
because cgroup stats are collected per-connection, rather than
globally).
2023-10-19 21:56:55 -07:00
Em Sharnoff
8a316b1277 vm-monitor: Log full error on message handling failure (#5604)
There's currently an issue with the vm-monitor on staging that's not
really feasible to debug because the current display impl gives no
context to the errors (just says "failed to downscale").

Logging the full error should help.

For communications with the autoscaler-agent, it's ok to only provide
the outermost cause, because we can cross-reference with the VM logs.
At some point in the future, we may want to change that.
2023-10-19 21:56:50 -07:00
Em Sharnoff
4d13bae449 vm-monitor: Switch from memory.high to polling memory.stat (#5524)
tl;dr it's really hard to avoid throttling from memory.high, and it
counts tmpfs & page cache usage, so it's also hard to make sense of.

In the interest of fixing things quickly with something that should be
*good enough*, this PR switches to instead periodically fetch memory
statistics from the cgroup's memory.stat and use that data to determine
if and when we should upscale.

This PR fixes #5444, which has a lot more detail on the difficulties
we've hit with memory.high. This PR also supersedes #5488.
2023-10-19 21:56:36 -07:00
Vadim Kharitonov
49377abd98 Merge pull request #5577 from neondatabase/releases/2023-10-17
Release 2023-10-17
2023-10-17 12:21:20 +02:00
Christian Schwarz
a6b2f4e54e limit imitate accesses concurrency, using same semaphore as compactions (#5578)
Before this PR, when we restarted pageserver, we'd see a rush of
`$number_of_tenants` concurrent eviction tasks starting to do imitate
accesses building up in the period of `[init_order allows activations,
$random_access_delay + EvictionPolicyLayerAccessThreshold::period]`.

We simply cannot handle that degree of concurrent IO.

We already solved the problem for compactions by adding a semaphore.
So, this PR shares that semaphore for use by evictions.

Part of https://github.com/neondatabase/neon/issues/5479

Which is again part of https://github.com/neondatabase/neon/issues/4743

Risks / Changes In System Behavior
==================================

* we don't do evictions as timely as we currently do
* we log a bunch of warnings about eviction taking too long
* imitate accesses and compactions compete for the same concurrency
limit, so, they'll slow each other down through this shares semaphore

Changes
=======

- Move the `CONCURRENT_COMPACTIONS` semaphore into `tasks.rs`
- Rename it to `CONCURRENT_BACKGROUND_TASKS`
- Use it also for the eviction imitate accesses:
    - Imitate acceses are both per-TIMELINE and per-TENANT
    - The per-TENANT is done through coalescing all the per-TIMELINE
      tasks via a tokio mutex `eviction_task_tenant_state`.
    - We acquire the CONCURRENT_BACKGROUND_TASKS permit early, at the
      beginning of the eviction iteration, much before the imitate
      acesses start (and they may not even start at all in the given
      iteration, as they happen only every $threshold).
    - Acquiring early is **sub-optimal** because when the per-timline
      tasks coalesce on the `eviction_task_tenant_state` mutex,
      they are already holding a CONCURRENT_BACKGROUND_TASKS permit.
    - It's also unfair because tenants with many timelines win
      the CONCURRENT_BACKGROUND_TASKS more often.
    - I don't think there's another way though, without refactoring
      more of the imitate accesses logic, e.g, making it all per-tenant.
- Add metrics for queue depth behind the semaphore.
I found these very useful to understand what work is queued in the
system.

    - The metrics are tagged by the new `BackgroundLoopKind`.
    - On a green slate, I would have used `TaskKind`, but we already had
      pre-existing labels whose names didn't map exactly to task kind.
      Also the task kind is kind of a lower-level detail, so, I think
it's fine to have a separate enum to identify background work kinds.

Future Work
===========

I guess I could move the eviction tasks from a ticker to "sleep for
$period".
The benefit would be that the semaphore automatically "smears" the
eviction task scheduling over time, so, we only have the rush on restart
but a smeared-out rush afterward.

The downside is that this perverts the meaning of "$period", as we'd
actually not run the eviction at a fixed period. It also means the the
"took to long" warning & metric becomes meaningless.

Then again, that is already the case for the compaction and gc tasks,
which do sleep for `$period` instead of using a ticker.

(cherry picked from commit 9256788273)
2023-10-17 12:16:26 +02:00
Arpad Müller
3666df6342 azure_blob.rs: use division instead of left shift (#5572)
Should have been a right shift but I did a left shift. It's constant
folded anyways so we just use a shift.
2023-10-16 19:52:07 +01:00
Alexey Kondratov
0ca342260c [compute_ctl+pgxn] Handle invalid databases after failed drop (#5561)
## Problem

In 89275f6c1e we fixed an issue, when we were dropping db in Postgres
even though cplane request failed. Yet, it introduced a new problem that
we now de-register db in cplane even if we didn't actually drop it in
Postgres.

## Summary of changes

Here we revert extension change, so we now again may leave db in invalid
state after failed drop. Instead, `compute_ctl` is now responsible for
cleaning up invalid databases during full configuration. Thus, there are
two ways of recovering from failed DROP DATABASE:
1. User can just repeat DROP DATABASE, same as in Vanilla Postgres.
2. If they didn't, then on next full configuration (dbs / roles changes
   in the API; password reset; or data availability check) invalid db
   will be cleaned up in the Postgres and re-created by `compute_ctl`. So
   again it follows pretty much the same semantics as Vanilla Postgres --
   you need to drop it again after failed drop.

That way, we have a recovery trajectory for both problems.

See this commit for info about `invalid` db state:
  a4b4cc1d60

According to it:
> An invalid database cannot be connected to anymore, but can still be
dropped.

While on it, this commit also fixes another issue, when `compute_ctl`
was trying to connect to databases with `ALLOW CONNECTIONS false`. Now
it will just skip them.

Fixes #5435
2023-10-16 20:46:45 +02:00
John Spray
ded7f48565 pageserver: measure startup duration spent fetching remote indices (#5564)
## Problem

Currently it's unclear how much of the `initial_tenant_load` period is
in S3 objects, and therefore how impactful it is to make changes to
remote operations during startup.

## Summary of changes

- `Tenant::load` is refactored to load remote indices in parallel and to
wait for all these remote downloads to finish before it proceeds to
construct any `Timeline` objects.
- `pageserver_startup_duration_seconds` gets a new `phase` value of
`initial_tenant_load_remote` which counts the time from startup to when
the last tenant finishes loading remote content.
- `test_pageserver_restart` is extended to validate this phase. The
previous version of the test was relying on order of dict entries, which
stopped working when adding a phase, so this is refactored a bit.
- `test_pageserver_restart` used to explicitly create a branch, now it
uses the default initial_timeline. This avoids startup getting held up
waiting for logical sizes, when one of the branches is not in use.
2023-10-16 18:21:37 +01:00
Arpad Müller
e09d5ada6a Azure blob storage support (#5546)
Adds prototype-level support for [Azure blob storage](https://azure.microsoft.com/en-us/products/storage/blobs). Some corners were cut, see the TODOs and the followup issue #5567 for details.

Steps to try it out:

* Create a storage account with block blobs (this is a per-storage
account setting).
* Create a container inside that storage account.
* Set the appropriate env vars: `AZURE_STORAGE_ACCOUNT,
AZURE_STORAGE_ACCESS_KEY, REMOTE_STORAGE_AZURE_CONTAINER,
REMOTE_STORAGE_AZURE_REGION`
* Set the env var `ENABLE_REAL_AZURE_REMOTE_STORAGE=y` and run `cargo
test -p remote_storage azure`

Fixes  #5562
2023-10-16 17:37:09 +02:00
Conrad Ludgate
8c522ea034 proxy: count cache-miss for compute latency (#5539)
## Problem

Would be good to view latency for hot-path vs cold-path

## Summary of changes

add some labels to latency metrics
2023-10-16 16:31:04 +01:00
John Spray
44b1c4c456 pageserver: fix eviction across generations (#5538)
## Problem

Bug was introduced by me in 83ae2bd82c

When eviction constructs a RemoteLayer to replace the layer it just
evicted, it is building a LayerFileMetadata using its _current_
generation, rather than the generation of the layer.

## Summary of changes

- Retrieve Generation from RemoteTimelineClient when evicting. This will
no longer be necessary when #4938 lands.
- Add a test for the scenario in question (this fails without the fix).
2023-10-15 20:23:18 +01:00
Christian Schwarz
99c15907c1 walredo: trim public interfaces (#5556)
Stacked atop https://github.com/neondatabase/neon/pull/5554.
2023-10-13 19:35:53 +01:00
Christian Schwarz
c3626e3432 walredo: remove legacy wal-redo-datadir cleanup code (#5554)
It says it in the comment.
2023-10-13 19:16:15 +01:00
Christian Schwarz
dd6990567f walredo: apply_batch_postgres: get a backtrace whenever it encounters an error (#5541)
For 2 weeks we've seen rare, spurious, not-reproducible page
reconstruction
failures with PG16 in prod.

One of the commits we deployed this week was

Commit

    commit fc467941f9
    Author: Joonas Koivunen <joonas@neon.tech>
    Date:   Wed Oct 4 16:19:19 2023 +0300

        walredo: log retryed error (#546)

With the logs from that commit, we learned that some read() or write()
system call that walredo does fails with `EAGAIN`, aka
`Resource temporarily unavailable (os error 11)`.

But we have no idea where exactly in the code we get back that error.

So, use anyhow instead of fake std::io::Error's as an easy way to get
a backtrace when the error happens, and change the logging to print
that backtrace (i.e., use `{:?}` instead of
`utils::error::report_compact_sources(e)`).

The `WalRedoError` type had to go because we add additional `.context()`
further up the call chain before we `{:?}`-print it. That additional
`.context()` further up doesn't see that there's already an
anyhow::Error
inside the `WalRedoError::ApplyWalRecords` variant, and hence captures
another backtrace and prints that one on `{:?}`-print instead of the
original one inside `WalRedoError::ApplyWalRecords`.

If we ever switch back to `report_compact_sources`, we should make sure
we have some other way to uniquely identify the places where we return
an error in the error message.
2023-10-13 14:08:23 +00:00
khanova
21deb81acb Fix case for array of jsons (#5523)
## Problem

Currently proxy doesn't handle array of json parameters correctly.

## Summary of changes

Added one more level of quotes escaping for the array of jsons case.
Resolves: https://github.com/neondatabase/neon/issues/5515
2023-10-12 14:32:49 +02:00
khanova
dbb21d6592 Make http timeout configurable (#5532)
## Problem

Currently http timeout is hardcoded to 15 seconds.

## Summary of changes

Added an option to configure it via cli args.

Context: https://neondb.slack.com/archives/C04DGM6SMTM/p1696941726151899
2023-10-12 11:41:07 +02:00
Joonas Koivunen
ddceb9e6cd fix(branching): read last record lsn only after Tenant::gc_cs (#5535)
Fixes #5531, at least the latest error of not being able to create a
branch from the head under write and gc pressure.
2023-10-11 16:24:36 +01:00
John Spray
0fc3708de2 pageserver: use a backoff::retry in Deleter (#5534)
## Problem

The `Deleter` currently doesn't use a backoff::retry because it doesn't
need to: it is already inside a loop when doing the deletion, so can
just let the loop go around.

However, this is a problem for logging, because we log on errors, which
includes things like 503/429 cases that would usually be swallowed by a
backoff::retry in most places we use the RemoteStorage interface.

The underlying problem is that RemoteStorage doesn't have a proper error
type, and an anyhow::Error can't easily be interrogated for its original
S3 SdkError because downcast_ref requires a concrete type, but SdkError
is parametrized on response type.

## Summary of changes

Wrap remote deletions in Deleter in a backoff::retry to avoid logging
warnings on transient 429/503 conditions, and for symmetry with how
RemoteStorage is used in other places.
2023-10-11 15:25:08 +01:00
John Spray
e0c8ad48d4 remote_storage: log detail errors in delete_objects (#5530)
## Problem

When we got an error in the payload of a DeleteObjects response, we only
logged how many errors, not what they were.

## Summary of changes

Log up to 10 specific errors. We do not log all of them because that
would be up to 1000 log lines per request.
2023-10-11 13:22:00 +01:00
John Spray
39e144696f pageserver: clean up mgr.rs types that needn't be public (#5529)
## Problem

These types/functions are public and it prevents clippy from catching
unused things.

## Summary of changes

Move to `pub(crate)` and remove the error enum that becomes clearly
unused as a result.
2023-10-11 11:50:16 +00:00
Alexander Bayandin
653044f754 test_runners: increase some timeouts to make tests less flaky (#5521)
## Problem
- `test_heavy_write_workload` is flaky, and fails because of to
statement timeout
- `test_wal_lagging` is flaky and fails because of the default pytest
timeout (see https://github.com/neondatabase/neon/issues/5305)

## Summary of changes
- `test_heavy_write_workload`: increase statement timeout to 5 minutes
(from default 2 minutes)
- `test_wal_lagging`: increase pytest timeout to 600s (from default
300s)
2023-10-11 10:49:15 +01:00
Vadim Kharitonov
80dcdfa8bf Update pgvector to 0.5.1 (#5525) 2023-10-11 09:47:19 +01:00
51 changed files with 3372 additions and 1217 deletions

946
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,9 @@ license = "Apache-2.0"
[workspace.dependencies]
anyhow = { version = "1.0", features = ["backtrace"] }
async-compression = { version = "0.4.0", features = ["tokio", "gzip"] }
azure_core = "0.16.0"
azure_storage = "0.16"
azure_storage_blobs = "0.16.0"
flate2 = "1.0.26"
async-stream = "0.3"
async-trait = "0.1"
@@ -76,6 +79,7 @@ hex = "0.4"
hex-literal = "0.4"
hmac = "0.12.1"
hostname = "0.3.1"
http-types = "2"
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"

View File

@@ -224,8 +224,8 @@ RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -
FROM build-deps AS vector-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.0.tar.gz -O pgvector.tar.gz && \
echo "d8aa3504b215467ca528525a6de12c3f85f9891b091ce0e5864dd8a9b757f77b pgvector.tar.gz" | sha256sum --check && \
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.1.tar.gz -O pgvector.tar.gz && \
echo "cc7a8e034a96e30a819911ac79d32f6bc47bdd1aa2de4d7d4904e26b83209dc8 pgvector.tar.gz" | sha256sum --check && \
mkdir pgvector-src && cd pgvector-src && tar xvzf ../pgvector.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \

View File

@@ -692,10 +692,11 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec;
create_neon_superuser(spec, &mut client)?;
cleanup_instance(&mut client)?;
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, self.connstr.as_str(), &mut client)?;
handle_grants(spec, self.connstr.as_str())?;
handle_grants(spec, &mut client, self.connstr.as_str())?;
handle_extensions(spec, &mut client)?;
create_availability_check_data(&mut client)?;
@@ -731,10 +732,11 @@ impl ComputeNode {
// Disable DDL forwarding because control plane already knows about these roles/databases.
if spec.mode == ComputeMode::Primary {
client.simple_query("SET neon.forward_ddl = false")?;
cleanup_instance(&mut client)?;
handle_roles(&spec, &mut client)?;
handle_databases(&spec, &mut client)?;
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
handle_grants(&spec, self.connstr.as_str())?;
handle_grants(&spec, &mut client, self.connstr.as_str())?;
handle_extensions(&spec, &mut client)?;
}

View File

@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::fmt::Write;
use std::fs;
use std::fs::File;
@@ -205,22 +206,37 @@ pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
}
/// Build a list of existing Postgres databases
pub fn get_existing_dbs(client: &mut Client) -> Result<Vec<Database>> {
let postgres_dbs = client
pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
// `pg_database.datconnlimit = -2` means that the database is in the
// invalid state. See:
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
let postgres_dbs: Vec<Database> = client
.query(
"SELECT datname, datdba::regrole::text as owner
FROM pg_catalog.pg_database;",
"SELECT
datname AS name,
datdba::regrole::text AS owner,
NOT datallowconn AS restrict_conn,
datconnlimit = - 2 AS invalid
FROM
pg_catalog.pg_database;",
&[],
)?
.iter()
.map(|row| Database {
name: row.get("datname"),
name: row.get("name"),
owner: row.get("owner"),
restrict_conn: row.get("restrict_conn"),
invalid: row.get("invalid"),
options: None,
})
.collect();
Ok(postgres_dbs)
let dbs_map = postgres_dbs
.iter()
.map(|db| (db.name.clone(), db.clone()))
.collect::<HashMap<_, _>>();
Ok(dbs_map)
}
/// Wait for Postgres to become ready to accept connections. It's ready to

View File

@@ -13,7 +13,7 @@ use crate::params::PG_HBA_ALL_MD5;
use crate::pg_helpers::*;
use compute_api::responses::{ControlPlaneComputeStatus, ControlPlaneSpecResponse};
use compute_api::spec::{ComputeSpec, Database, PgIdent, Role};
use compute_api::spec::{ComputeSpec, PgIdent, Role};
// Do control plane request and return response if any. In case of error it
// returns a bool flag indicating whether it makes sense to retry the request
@@ -161,6 +161,38 @@ pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
Ok(())
}
/// Compute could be unexpectedly shut down, for example, during the
/// database dropping. This leaves the database in the invalid state,
/// which prevents new db creation with the same name. This function
/// will clean it up before proceeding with catalog updates. All
/// possible future cleanup operations may go here too.
#[instrument(skip_all)]
pub fn cleanup_instance(client: &mut Client) -> Result<()> {
let existing_dbs = get_existing_dbs(client)?;
for (_, db) in existing_dbs {
if db.invalid {
// After recent commit in Postgres, interrupted DROP DATABASE
// leaves the database in the invalid state. According to the
// commit message, the only option for user is to drop it again.
// See:
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
//
// Postgres Neon extension is done the way, that db is de-registered
// in the control plane metadata only after it is dropped. So there is
// a chance that it still thinks that db should exist. This means
// that it will be re-created by `handle_databases()`. Yet, it's fine
// as user can just repeat drop (in vanilla Postgres they would need
// to do the same, btw).
let query = format!("DROP DATABASE IF EXISTS {}", db.name.pg_quote());
info!("dropping invalid database {}", db.name);
client.execute(query.as_str(), &[])?;
}
}
Ok(())
}
/// Given a cluster spec json and open transaction it handles roles creation,
/// deletion and update.
#[instrument(skip_all)]
@@ -379,13 +411,13 @@ fn reassign_owned_objects(spec: &ComputeSpec, connstr: &str, role_name: &PgIdent
/// which together provide us idempotency.
#[instrument(skip_all)]
pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
let existing_dbs = get_existing_dbs(client)?;
// Print a list of existing Postgres databases (only in debug mode)
if span_enabled!(Level::INFO) {
info!("postgres databases:");
for r in &existing_dbs {
info!(" {}:{}", r.name, r.owner);
for (dbname, db) in &existing_dbs {
info!(" {}:{}", dbname, db.owner);
}
}
@@ -439,8 +471,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
"rename_db" => {
let new_name = op.new_name.as_ref().unwrap();
// XXX: with a limited number of roles it is fine, but consider making it a HashMap
if existing_dbs.iter().any(|r| r.name == op.name) {
if existing_dbs.get(&op.name).is_some() {
let query: String = format!(
"ALTER DATABASE {} RENAME TO {}",
op.name.pg_quote(),
@@ -457,14 +488,12 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
}
// Refresh Postgres databases info to handle possible renames
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
let existing_dbs = get_existing_dbs(client)?;
info!("cluster spec databases:");
for db in &spec.cluster.databases {
let name = &db.name;
// XXX: with a limited number of databases it is fine, but consider making it a HashMap
let pg_db = existing_dbs.iter().find(|r| r.name == *name);
let pg_db = existing_dbs.get(name);
enum DatabaseAction {
None,
@@ -530,13 +559,32 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
/// to allow users creating trusted extensions and re-creating `public` schema, for example.
#[instrument(skip_all)]
pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
info!("cluster spec grants:");
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
info!("modifying database permissions");
let existing_dbs = get_existing_dbs(client)?;
// Do some per-database access adjustments. We'd better do this at db creation time,
// but CREATE DATABASE isn't transactional. So we cannot create db + do some grants
// atomically.
for db in &spec.cluster.databases {
match existing_dbs.get(&db.name) {
Some(pg_db) => {
if pg_db.restrict_conn || pg_db.invalid {
info!(
"skipping grants for db {} (invalid: {}, connections not allowed: {})",
db.name, pg_db.invalid, pg_db.restrict_conn
);
continue;
}
}
None => {
bail!(
"database {} doesn't exist in Postgres after handle_databases()",
db.name
);
}
}
let mut conf = Config::from_str(connstr)?;
conf.dbname(&db.name);
@@ -575,6 +623,11 @@ pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
// Explicitly grant CREATE ON SCHEMA PUBLIC to the web_access user.
// This is needed because since postgres 15 this privilege is removed by default.
// TODO: web_access isn't created for almost 1 year. It could be that we have
// active users of 1 year old projects, but hopefully not, so check it and
// remove this code if possible. The worst thing that could happen is that
// user won't be able to use public schema in NEW databases created in the
// very OLD project.
let grant_query = "DO $$\n\
BEGIN\n\
IF EXISTS(\n\

View File

@@ -86,7 +86,7 @@ where
.stdout(process_log_file)
.stderr(same_file_for_stderr)
.args(args);
let filled_cmd = fill_aws_secrets_vars(fill_rust_env_vars(background_command));
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
filled_cmd.envs(envs);
let pid_file_to_check = match initial_pid_file {
@@ -238,11 +238,13 @@ fn fill_rust_env_vars(cmd: &mut Command) -> &mut Command {
filled_cmd
}
fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command {
fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
for env_key in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AZURE_STORAGE_ACCOUNT",
"AZURE_STORAGE_ACCESS_KEY",
] {
if let Ok(value) = std::env::var(env_key) {
cmd = cmd.env(env_key, value);

View File

@@ -96,6 +96,16 @@ prefix_in_bucket = '/test_prefix/'
`AWS_SECRET_ACCESS_KEY` and `AWS_ACCESS_KEY_ID` env variables can be used to specify the S3 credentials if needed.
or
```toml
[remote_storage]
container_name = 'some-container-name'
container_region = 'us-east'
prefix_in_container = '/test-prefix/'
```
`AZURE_STORAGE_ACCOUNT` and `AZURE_STORAGE_ACCESS_KEY` env variables can be used to specify the azure credentials if needed.
## Repository background tasks

View File

@@ -200,6 +200,12 @@ pub struct Database {
pub name: PgIdent,
pub owner: PgIdent,
pub options: GenericOptions,
// These are derived flags, not present in the spec file.
// They are never set by the control plane.
#[serde(skip_deserializing, default)]
pub restrict_conn: bool,
#[serde(skip_deserializing, default)]
pub invalid: bool,
}
/// Common type representing both SQL statement params with or without value,

View File

@@ -13,6 +13,7 @@ aws-types.workspace = true
aws-config.workspace = true
aws-sdk-s3.workspace = true
aws-credential-types.workspace = true
bytes.workspace = true
camino.workspace = true
hyper = { workspace = true, features = ["stream"] }
serde.workspace = true
@@ -26,6 +27,12 @@ metrics.workspace = true
utils.workspace = true
pin-project-lite.workspace = true
workspace_hack.workspace = true
azure_core.workspace = true
azure_storage.workspace = true
azure_storage_blobs.workspace = true
futures-util.workspace = true
http-types.workspace = true
itertools.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -0,0 +1,381 @@
//! Azure Blob Storage wrapper
use std::num::NonZeroU32;
use std::{borrow::Cow, collections::HashMap, io::Cursor};
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::Header;
use azure_storage::StorageCredentials;
use azure_storage_blobs::prelude::ClientBuilder;
use azure_storage_blobs::{
blob::operations::GetBlobBuilder,
prelude::{BlobClient, ContainerClient},
};
use futures_util::StreamExt;
use http_types::StatusCode;
use tokio::io::AsyncRead;
use tracing::debug;
use crate::s3_bucket::RequestKind;
use crate::{
AzureConfig, ConcurrencyLimiter, Download, DownloadError, RemotePath, RemoteStorage,
StorageMetadata,
};
pub struct AzureBlobStorage {
client: ContainerClient,
prefix_in_container: Option<String>,
max_keys_per_list_response: Option<NonZeroU32>,
concurrency_limiter: ConcurrencyLimiter,
}
impl AzureBlobStorage {
pub fn new(azure_config: &AzureConfig) -> Result<Self> {
debug!(
"Creating azure remote storage for azure container {}",
azure_config.container_name
);
let account =
std::env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
let access_key =
std::env::var("AZURE_STORAGE_ACCESS_KEY").expect("missing AZURE_STORAGE_ACCESS_KEY");
let credentials = StorageCredentials::access_key(account.clone(), access_key);
let builder = ClientBuilder::new(account, credentials);
let client = builder.container_client(azure_config.container_name.to_owned());
let max_keys_per_list_response =
if let Some(limit) = azure_config.max_keys_per_list_response {
Some(
NonZeroU32::new(limit as u32)
.ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
)
} else {
None
};
Ok(AzureBlobStorage {
client,
prefix_in_container: azure_config.prefix_in_container.to_owned(),
max_keys_per_list_response,
concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
})
}
pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
let path_string = path
.get_path()
.as_str()
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
match &self.prefix_in_container {
Some(prefix) => {
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
prefix.clone() + path_string
} else {
format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
}
}
None => path_string.to_string(),
}
}
fn name_to_relative_path(&self, key: &str) -> RemotePath {
let relative_path =
match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
Some(stripped) => stripped,
// we rely on Azure to return properly prefixed paths
// for requests with a certain prefix
None => panic!(
"Key {key} does not start with container prefix {:?}",
self.prefix_in_container
),
};
RemotePath(
relative_path
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
.collect(),
)
}
async fn download_for_builder(
&self,
metadata: StorageMetadata,
builder: GetBlobBuilder,
) -> Result<Download, DownloadError> {
let mut response = builder.into_stream();
// TODO give proper streaming response instead of buffering into RAM
// https://github.com/neondatabase/neon/issues/5563
let mut buf = Vec::new();
while let Some(part) = response.next().await {
let part = match part {
Ok(l) => l,
Err(e) => {
return Err(if let Some(http_err) = e.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::BadRequest => {
DownloadError::BadInput(anyhow::Error::new(e))
}
_ => DownloadError::Other(anyhow::Error::new(e)),
}
} else {
DownloadError::Other(e.into())
});
}
};
let data = part
.data
.collect()
.await
.map_err(|e| DownloadError::Other(e.into()))?;
buf.extend_from_slice(&data.slice(..));
}
Ok(Download {
download_stream: Box::pin(Cursor::new(buf)),
metadata: Some(metadata),
})
}
// TODO get rid of this function once we have metadata included in the response
// https://github.com/Azure/azure-sdk-for-rust/issues/1439
async fn get_metadata(
&self,
blob_client: &BlobClient,
) -> Result<StorageMetadata, DownloadError> {
let builder = blob_client.get_metadata();
match builder.into_future().await {
Ok(r) => {
let mut map = HashMap::new();
for md in r.metadata.iter() {
map.insert(
md.name().as_str().to_string(),
md.value().as_str().to_string(),
);
}
Ok(StorageMetadata(map))
}
Err(e) => {
return Err(if let Some(http_err) = e.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(e)),
_ => DownloadError::Other(anyhow::Error::new(e)),
}
} else {
DownloadError::Other(e.into())
});
}
}
}
async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
self.concurrency_limiter
.acquire(kind)
.await
.expect("semaphore is never closed")
}
}
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
let mut res = Metadata::new();
for (k, v) in metadata.0.into_iter() {
res.insert(k, v);
}
res
}
#[async_trait::async_trait]
impl RemoteStorage for AzureBlobStorage {
async fn list_prefixes(
&self,
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let list_prefix = prefix
.map(|p| self.relative_path_to_name(p))
.or_else(|| self.prefix_in_container.clone())
.map(|mut p| {
// required to end with a separator
// otherwise request will return only the entry of a prefix
if !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
}
p
});
let mut builder = self
.client
.list_blobs()
.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
if let Some(prefix) = list_prefix {
builder = builder.prefix(Cow::from(prefix.to_owned()));
}
if let Some(limit) = self.max_keys_per_list_response {
builder = builder.max_results(MaxResults::new(limit));
}
let mut response = builder.into_stream();
let mut res = Vec::new();
while let Some(l) = response.next().await {
let entry = match l {
Ok(l) => l,
Err(e) => {
return Err(if let Some(http_err) = e.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::BadRequest => {
DownloadError::BadInput(anyhow::Error::new(e))
}
_ => DownloadError::Other(anyhow::Error::new(e)),
}
} else {
DownloadError::Other(e.into())
});
}
};
let name_iter = entry
.blobs
.prefixes()
.map(|prefix| self.name_to_relative_path(&prefix.name));
res.extend(name_iter);
}
Ok(res)
}
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
let folder_name = folder
.map(|p| self.relative_path_to_name(p))
.or_else(|| self.prefix_in_container.clone());
let mut builder = self.client.list_blobs();
if let Some(folder_name) = folder_name {
builder = builder.prefix(Cow::from(folder_name.to_owned()));
}
if let Some(limit) = self.max_keys_per_list_response {
builder = builder.max_results(MaxResults::new(limit));
}
let mut response = builder.into_stream();
let mut res = Vec::new();
while let Some(l) = response.next().await {
let entry = l.map_err(anyhow::Error::new)?;
let name_iter = entry
.blobs
.blobs()
.map(|bl| self.name_to_relative_path(&bl.name));
res.extend(name_iter);
}
Ok(res)
}
async fn upload(
&self,
mut from: impl AsyncRead + Unpin + Send + Sync + 'static,
data_size_bytes: usize,
to: &RemotePath,
metadata: Option<StorageMetadata>,
) -> anyhow::Result<()> {
let _permit = self.permit(RequestKind::Put).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(to));
// TODO FIX THIS UGLY HACK and don't buffer the entire object
// into RAM here, but use the streaming interface. For that,
// we'd have to change the interface though...
// https://github.com/neondatabase/neon/issues/5563
let mut buf = Vec::with_capacity(data_size_bytes);
tokio::io::copy(&mut from, &mut buf).await?;
let body = azure_core::Body::Bytes(buf.into());
let mut builder = blob_client.put_block_blob(body);
if let Some(metadata) = metadata {
builder = builder.metadata(to_azure_metadata(metadata));
}
let _response = builder.into_future().await?;
Ok(())
}
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
let _permit = self.permit(RequestKind::Get).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let metadata = self.get_metadata(&blob_client).await?;
let builder = blob_client.get();
self.download_for_builder(metadata, builder).await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
) -> Result<Download, DownloadError> {
let _permit = self.permit(RequestKind::Get).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let metadata = self.get_metadata(&blob_client).await?;
let mut builder = blob_client.get();
if let Some(end_exclusive) = end_exclusive {
builder = builder.range(Range::new(start_inclusive, end_exclusive));
} else {
// Open ranges are not supported by the SDK so we work around
// by setting the upper limit extremely high (but high enough
// to still be representable by signed 64 bit integers).
// TODO remove workaround once the SDK adds open range support
// https://github.com/Azure/azure-sdk-for-rust/issues/1438
let end_exclusive = u64::MAX / 4;
builder = builder.range(Range::new(start_inclusive, end_exclusive));
}
self.download_for_builder(metadata, builder).await
}
async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
let _permit = self.permit(RequestKind::Delete).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(path));
let builder = blob_client.delete();
match builder.into_future().await {
Ok(_response) => Ok(()),
Err(e) => {
if let Some(http_err) = e.as_http_error() {
if http_err.status() == StatusCode::NotFound {
return Ok(());
}
}
Err(anyhow::Error::new(e))
}
}
}
async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
// Permit is already obtained by inner delete function
// TODO batch requests are also not supported by the SDK
// https://github.com/Azure/azure-sdk-for-rust/issues/1068
// https://github.com/Azure/azure-sdk-for-rust/issues/1249
for path in paths {
self.delete(path).await?;
}
Ok(())
}
}

View File

@@ -4,7 +4,10 @@
//! [`RemoteStorage`] trait a CRUD-like generic abstraction to use for adapting external storages with a few implementations:
//! * [`local_fs`] allows to use local file system as an external storage
//! * [`s3_bucket`] uses AWS S3 bucket as an external storage
//! * [`azure_blob`] allows to use Azure Blob storage as an external storage
//!
mod azure_blob;
mod local_fs;
mod s3_bucket;
mod simulate_failures;
@@ -21,11 +24,15 @@ use anyhow::{bail, Context};
use camino::{Utf8Path, Utf8PathBuf};
use serde::{Deserialize, Serialize};
use tokio::io;
use tokio::{io, sync::Semaphore};
use toml_edit::Item;
use tracing::info;
pub use self::{local_fs::LocalFs, s3_bucket::S3Bucket, simulate_failures::UnreliableWrapper};
pub use self::{
azure_blob::AzureBlobStorage, local_fs::LocalFs, s3_bucket::S3Bucket,
simulate_failures::UnreliableWrapper,
};
use s3_bucket::RequestKind;
/// How many different timelines can be processed simultaneously when synchronizing layers with the remote storage.
/// During regular work, pageserver produces one layer file per timeline checkpoint, with bursts of concurrency
@@ -39,6 +46,11 @@ pub const DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS: u32 = 10;
/// ~3500 PUT/COPY/POST/DELETE or 5500 GET/HEAD S3 requests
/// <https://aws.amazon.com/premiumsupport/knowledge-center/s3-request-limit-avoid-throttling/>
pub const DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT: usize = 100;
/// We set this a little bit low as we currently buffer the entire file into RAM
///
/// Here, a limit of max 20k concurrent connections was noted.
/// <https://learn.microsoft.com/en-us/answers/questions/1301863/is-there-any-limitation-to-concurrent-connections>
pub const DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT: usize = 30;
/// No limits on the client side, which currenltly means 1000 for AWS S3.
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax>
pub const DEFAULT_MAX_KEYS_PER_LIST_RESPONSE: Option<i32> = None;
@@ -217,6 +229,7 @@ impl std::error::Error for DownloadError {}
pub enum GenericRemoteStorage {
LocalFs(LocalFs),
AwsS3(Arc<S3Bucket>),
AzureBlob(Arc<AzureBlobStorage>),
Unreliable(Arc<UnreliableWrapper>),
}
@@ -228,6 +241,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.list_files(folder).await,
Self::AwsS3(s) => s.list_files(folder).await,
Self::AzureBlob(s) => s.list_files(folder).await,
Self::Unreliable(s) => s.list_files(folder).await,
}
}
@@ -242,6 +256,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.list_prefixes(prefix).await,
Self::AwsS3(s) => s.list_prefixes(prefix).await,
Self::AzureBlob(s) => s.list_prefixes(prefix).await,
Self::Unreliable(s) => s.list_prefixes(prefix).await,
}
}
@@ -256,6 +271,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata).await,
}
}
@@ -264,6 +280,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.download(from).await,
Self::AwsS3(s) => s.download(from).await,
Self::AzureBlob(s) => s.download(from).await,
Self::Unreliable(s) => s.download(from).await,
}
}
@@ -283,6 +300,10 @@ impl GenericRemoteStorage {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
}
Self::AzureBlob(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
}
Self::Unreliable(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
@@ -294,6 +315,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.delete(path).await,
Self::AwsS3(s) => s.delete(path).await,
Self::AzureBlob(s) => s.delete(path).await,
Self::Unreliable(s) => s.delete(path).await,
}
}
@@ -302,6 +324,7 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.delete_objects(paths).await,
Self::AwsS3(s) => s.delete_objects(paths).await,
Self::AzureBlob(s) => s.delete_objects(paths).await,
Self::Unreliable(s) => s.delete_objects(paths).await,
}
}
@@ -319,6 +342,11 @@ impl GenericRemoteStorage {
s3_config.bucket_name, s3_config.bucket_region, s3_config.prefix_in_bucket, s3_config.endpoint);
Self::AwsS3(Arc::new(S3Bucket::new(s3_config)?))
}
RemoteStorageKind::AzureContainer(azure_config) => {
info!("Using azure container '{}' in region '{}' as a remote storage, prefix in container: '{:?}'",
azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container);
Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config)?))
}
})
}
@@ -383,6 +411,9 @@ pub enum RemoteStorageKind {
/// AWS S3 based storage, storing all files in the S3 bucket
/// specified by the config
AwsS3(S3Config),
/// Azure Blob based storage, storing all files in the container
/// specified by the config
AzureContainer(AzureConfig),
}
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
@@ -422,11 +453,45 @@ impl Debug for S3Config {
}
}
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
#[derive(Clone, PartialEq, Eq)]
pub struct AzureConfig {
/// Name of the container to connect to.
pub container_name: String,
/// The region where the bucket is located at.
pub container_region: String,
/// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
pub prefix_in_container: Option<String>,
/// Azure has various limits on its API calls, we need not to exceed those.
/// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
pub concurrency_limit: NonZeroUsize,
pub max_keys_per_list_response: Option<i32>,
}
impl Debug for AzureConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AzureConfig")
.field("bucket_name", &self.container_name)
.field("bucket_region", &self.container_region)
.field("prefix_in_bucket", &self.prefix_in_container)
.field("concurrency_limit", &self.concurrency_limit)
.field(
"max_keys_per_list_response",
&self.max_keys_per_list_response,
)
.finish()
}
}
impl RemoteStorageConfig {
pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
let local_path = toml.get("local_path");
let bucket_name = toml.get("bucket_name");
let bucket_region = toml.get("bucket_region");
let container_name = toml.get("container_name");
let container_region = toml.get("container_region");
let use_azure = container_name.is_some() && container_region.is_some();
let max_concurrent_syncs = NonZeroUsize::new(
parse_optional_integer("max_concurrent_syncs", toml)?
@@ -440,9 +505,13 @@ impl RemoteStorageConfig {
)
.context("Failed to parse 'max_sync_errors' as a positive integer")?;
let default_concurrency_limit = if use_azure {
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
} else {
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
};
let concurrency_limit = NonZeroUsize::new(
parse_optional_integer("concurrency_limit", toml)?
.unwrap_or(DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT),
parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit),
)
.context("Failed to parse 'concurrency_limit' as a positive integer")?;
@@ -451,33 +520,70 @@ impl RemoteStorageConfig {
.context("Failed to parse 'max_keys_per_list_response' as a positive integer")?
.or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
let storage = match (local_path, bucket_name, bucket_region) {
let endpoint = toml
.get("endpoint")
.map(|endpoint| parse_toml_string("endpoint", endpoint))
.transpose()?;
let storage = match (
local_path,
bucket_name,
bucket_region,
container_name,
container_region,
) {
// no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled
(None, None, None) => return Ok(None),
(_, Some(_), None) => {
(None, None, None, None, None) => return Ok(None),
(_, Some(_), None, ..) => {
bail!("'bucket_region' option is mandatory if 'bucket_name' is given ")
}
(_, None, Some(_)) => {
(_, None, Some(_), ..) => {
bail!("'bucket_name' option is mandatory if 'bucket_region' is given ")
}
(None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config {
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
prefix_in_bucket: toml
.get("prefix_in_bucket")
.map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket))
.transpose()?,
endpoint: toml
.get("endpoint")
.map(|endpoint| parse_toml_string("endpoint", endpoint))
.transpose()?,
concurrency_limit,
max_keys_per_list_response,
}),
(Some(local_path), None, None) => RemoteStorageKind::LocalFs(Utf8PathBuf::from(
parse_toml_string("local_path", local_path)?,
)),
(Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"),
(None, Some(bucket_name), Some(bucket_region), ..) => {
RemoteStorageKind::AwsS3(S3Config {
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
prefix_in_bucket: toml
.get("prefix_in_bucket")
.map(|prefix_in_bucket| {
parse_toml_string("prefix_in_bucket", prefix_in_bucket)
})
.transpose()?,
endpoint,
concurrency_limit,
max_keys_per_list_response,
})
}
(_, _, _, Some(_), None) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
}
(_, _, _, None, Some(_)) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
}
(None, None, None, Some(container_name), Some(container_region)) => {
RemoteStorageKind::AzureContainer(AzureConfig {
container_name: parse_toml_string("container_name", container_name)?,
container_region: parse_toml_string("container_region", container_region)?,
prefix_in_container: toml
.get("prefix_in_container")
.map(|prefix_in_container| {
parse_toml_string("prefix_in_container", prefix_in_container)
})
.transpose()?,
concurrency_limit,
max_keys_per_list_response,
})
}
(Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs(
Utf8PathBuf::from(parse_toml_string("local_path", local_path)?),
),
(Some(_), Some(_), ..) => {
bail!("'local_path' and 'bucket_name' are mutually exclusive")
}
(Some(_), _, _, Some(_), Some(_)) => {
bail!("local_path and 'container_name' are mutually exclusive")
}
};
Ok(Some(RemoteStorageConfig {
@@ -513,6 +619,46 @@ fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result<String> {
Ok(s.to_string())
}
struct ConcurrencyLimiter {
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
// The helps to ensure we don't exceed the thresholds.
write: Arc<Semaphore>,
read: Arc<Semaphore>,
}
impl ConcurrencyLimiter {
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
match kind {
RequestKind::Get => &self.read,
RequestKind::Put => &self.write,
RequestKind::List => &self.read,
RequestKind::Delete => &self.write,
}
}
async fn acquire(
&self,
kind: RequestKind,
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
self.for_kind(kind).acquire().await
}
async fn acquire_owned(
&self,
kind: RequestKind,
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
Arc::clone(self.for_kind(kind)).acquire_owned().await
}
fn new(limit: usize) -> ConcurrencyLimiter {
Self {
read: Arc::new(Semaphore::new(limit)),
write: Arc::new(Semaphore::new(limit)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -4,7 +4,7 @@
//! allowing multiple api users to independently work with the same S3 bucket, if
//! their bucket prefixes are both specified and different.
use std::sync::Arc;
use std::borrow::Cow;
use anyhow::Context;
use aws_config::{
@@ -24,22 +24,20 @@ use aws_sdk_s3::{
use aws_smithy_http::body::SdkBody;
use hyper::Body;
use scopeguard::ScopeGuard;
use tokio::{
io::{self, AsyncRead},
sync::Semaphore,
};
use tokio::io::{self, AsyncRead};
use tokio_util::io::ReaderStream;
use tracing::debug;
use super::StorageMetadata;
use crate::{
Download, DownloadError, RemotePath, RemoteStorage, S3Config, MAX_KEYS_PER_DELETE,
REMOTE_STORAGE_PREFIX_SEPARATOR,
ConcurrencyLimiter, Download, DownloadError, RemotePath, RemoteStorage, S3Config,
MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
};
pub(super) mod metrics;
use self::metrics::{AttemptOutcome, RequestKind};
use self::metrics::AttemptOutcome;
pub(super) use self::metrics::RequestKind;
/// AWS S3 storage.
pub struct S3Bucket {
@@ -50,46 +48,6 @@ pub struct S3Bucket {
concurrency_limiter: ConcurrencyLimiter,
}
struct ConcurrencyLimiter {
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
// The helps to ensure we don't exceed the thresholds.
write: Arc<Semaphore>,
read: Arc<Semaphore>,
}
impl ConcurrencyLimiter {
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
match kind {
RequestKind::Get => &self.read,
RequestKind::Put => &self.write,
RequestKind::List => &self.read,
RequestKind::Delete => &self.write,
}
}
async fn acquire(
&self,
kind: RequestKind,
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
self.for_kind(kind).acquire().await
}
async fn acquire_owned(
&self,
kind: RequestKind,
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
Arc::clone(self.for_kind(kind)).acquire_owned().await
}
fn new(limit: usize) -> ConcurrencyLimiter {
Self {
read: Arc::new(Semaphore::new(limit)),
write: Arc::new(Semaphore::new(limit)),
}
}
}
#[derive(Default)]
struct GetObjectRequest {
bucket: String,
@@ -556,6 +514,20 @@ impl RemoteStorage for S3Bucket {
.deleted_objects_total
.inc_by(chunk.len() as u64);
if let Some(errors) = resp.errors {
// Log a bounded number of the errors within the response:
// these requests can carry 1000 keys so logging each one
// would be too verbose, especially as errors may lead us
// to retry repeatedly.
const LOG_UP_TO_N_ERRORS: usize = 10;
for e in errors.iter().take(LOG_UP_TO_N_ERRORS) {
tracing::warn!(
"DeleteObjects key {} failed: {}: {}",
e.key.as_ref().map(Cow::from).unwrap_or("".into()),
e.code.as_ref().map(Cow::from).unwrap_or("".into()),
e.message.as_ref().map(Cow::from).unwrap_or("".into())
);
}
return Err(anyhow::format_err!(
"Failed to delete {} objects",
errors.len()

View File

@@ -6,7 +6,7 @@ use once_cell::sync::Lazy;
pub(super) static BUCKET_METRICS: Lazy<BucketMetrics> = Lazy::new(Default::default);
#[derive(Clone, Copy, Debug)]
pub(super) enum RequestKind {
pub(crate) enum RequestKind {
Get = 0,
Put = 1,
Delete = 2,

View File

@@ -0,0 +1,619 @@
use std::collections::HashSet;
use std::env;
use std::num::{NonZeroU32, NonZeroUsize};
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use anyhow::Context;
use camino::Utf8Path;
use once_cell::sync::OnceCell;
use remote_storage::{
AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
};
use test_context::{test_context, AsyncTestContext};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
const ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_AZURE_REMOTE_STORAGE";
const BASE_PREFIX: &str = "test";
/// Tests that the Azure client can list all prefixes, even if the response comes paginated and requires multiple HTTP queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified.
/// See the client creation in [`create_azure_client`] for details on the required env vars.
/// If real Azure tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
///
/// First, the test creates a set of Azure blobs with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_azure_data`]
/// where
/// * `random_prefix_part` is set for the entire Azure client during the Azure client creation in [`create_azure_client`], to avoid multiple test runs interference
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
///
/// Then, verifies that the client does return correct prefixes when queried:
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
///
/// With the real Azure enabled and `#[cfg(test)]` Rust configuration used, the Azure client test adds a `max-keys` param to limit the response keys.
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to Azure.
///
/// Lastly, the test attempts to clean up and remove all uploaded Azure files.
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
#[test_context(MaybeEnabledAzureWithTestBlobs)]
#[tokio::test]
async fn azure_pagination_should_work(
ctx: &mut MaybeEnabledAzureWithTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let expected_remote_prefixes = ctx.remote_prefixes.clone();
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
.context("common_prefix construction")?;
let root_remote_prefixes = test_client
.list_prefixes(None)
.await
.context("client list root prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
);
let nested_remote_prefixes = test_client
.list_prefixes(Some(&base_prefix))
.await
.context("client list nested prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
let remote_only_prefixes = nested_remote_prefixes
.difference(&expected_remote_prefixes)
.collect::<HashSet<_>>();
let missing_uploaded_prefixes = expected_remote_prefixes
.difference(&nested_remote_prefixes)
.collect::<HashSet<_>>();
assert_eq!(
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
);
Ok(())
}
/// Tests that Azure client can list all files in a folder, even if the response comes paginated and requirees multiple Azure queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified. Test will skip real code and pass if env vars not set.
/// See `Azure_pagination_should_work` for more information.
///
/// First, create a set of Azure objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_azure_data`]
/// Then performs the following queries:
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
#[test_context(MaybeEnabledAzureWithSimpleTestBlobs)]
#[tokio::test]
async fn azure_list_files_works(
ctx: &mut MaybeEnabledAzureWithSimpleTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithSimpleTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithSimpleTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithSimpleTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_files,
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()
.collect::<HashSet<_>>();
let trim_remote_blobs: HashSet<_> = ctx
.remote_blobs
.iter()
.map(|x| x.get_path())
.filter(|x| x.starts_with("folder1"))
.map(|x| RemotePath::new(x).expect("must be valid path"))
.collect();
assert_eq!(
nested_remote_files, trim_remote_blobs,
"remote storage list_files on subdirrectory mismatches with the uploads."
);
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_non_exising_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
ctx.client.delete(&path).await.expect("should succeed");
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let data1 = "remote blob data1".as_bytes();
let data1_len = data1.len();
let data2 = "remote blob data2".as_bytes();
let data2_len = data2.len();
let data3 = "remote blob data3".as_bytes();
let data3_len = data3.len();
ctx.client
.upload(std::io::Cursor::new(data1), data1_len, &path1, None)
.await?;
ctx.client
.upload(std::io::Cursor::new(data2), data2_len, &path2, None)
.await?;
ctx.client
.upload(std::io::Cursor::new(data3), data3_len, &path3, None)
.await?;
ctx.client.delete_objects(&[path1, path2]).await?;
let prefixes = ctx.client.list_prefixes(None).await?;
assert_eq!(prefixes.len(), 1);
ctx.client.delete_objects(&[path3]).await?;
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let MaybeEnabledAzure::Enabled(ctx) = ctx else {
return Ok(());
};
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let data = "remote blob data here".as_bytes();
let data_len = data.len() as u64;
ctx.client
.upload(std::io::Cursor::new(data), data.len(), &path, None)
.await?;
async fn download_and_compare(mut dl: Download) -> anyhow::Result<Vec<u8>> {
let mut buf = Vec::new();
tokio::io::copy(&mut dl.download_stream, &mut buf).await?;
Ok(buf)
}
// Normal download request
let dl = ctx.client.download(&path).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
// Full range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 0, Some(data_len))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
// partial range (end specified)
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[4..10]);
// partial range (end beyond real end)
let dl = ctx
.client
.download_byte_range(&path, 8, Some(data_len * 100))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[8..]);
// Partial range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[4..]);
// Full range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
Ok(())
}
fn ensure_logging_ready() {
LOGGING_DONE.get_or_init(|| {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
)
.expect("logging init failed");
});
}
struct EnabledAzure {
client: Arc<GenericRemoteStorage>,
base_prefix: &'static str,
}
impl EnabledAzure {
async fn setup(max_keys_in_list_response: Option<i32>) -> Self {
let client = create_azure_client(max_keys_in_list_response)
.context("Azure client creation")
.expect("Azure client creation failed");
EnabledAzure {
client,
base_prefix: BASE_PREFIX,
}
}
}
enum MaybeEnabledAzure {
Enabled(EnabledAzure),
Disabled,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzure {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
Self::Enabled(EnabledAzure::setup(None).await)
}
}
enum MaybeEnabledAzureWithTestBlobs {
Enabled(AzureWithTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithTestBlobs),
}
struct AzureWithTestBlobs {
enabled: EnabledAzure,
remote_prefixes: HashSet<RemotePath>,
remote_blobs: HashSet<RemotePath>,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzureWithTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
let max_keys_in_list_response = 10;
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_azure_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
Self::Enabled(AzureWithTestBlobs {
enabled,
remote_prefixes: uploads.prefixes,
remote_blobs: uploads.blobs,
})
}
ControlFlow::Break(uploads) => Self::UploadsFailed(
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
AzureWithTestBlobs {
enabled,
remote_prefixes: uploads.prefixes,
remote_blobs: uploads.blobs,
},
),
}
}
async fn teardown(self) {
match self {
Self::Disabled => {}
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
}
}
}
}
// NOTE: the setups for the list_prefixes test and the list_files test are very similar
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
// whereas the list_files function is concerned with listing files.
// See `RemoteStorage::list_files` documentation for more details
enum MaybeEnabledAzureWithSimpleTestBlobs {
Enabled(AzureWithSimpleTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithSimpleTestBlobs),
}
struct AzureWithSimpleTestBlobs {
enabled: EnabledAzure,
remote_blobs: HashSet<RemotePath>,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzureWithSimpleTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
let max_keys_in_list_response = 10;
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_simple_azure_data(&enabled.client, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
Self::Enabled(AzureWithSimpleTestBlobs {
enabled,
remote_blobs: uploads,
})
}
ControlFlow::Break(uploads) => Self::UploadsFailed(
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
AzureWithSimpleTestBlobs {
enabled,
remote_blobs: uploads,
},
),
}
}
async fn teardown(self) {
match self {
Self::Disabled => {}
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
}
}
}
}
fn create_azure_client(
max_keys_per_list_response: Option<i32>,
) -> anyhow::Result<Arc<GenericRemoteStorage>> {
use rand::Rng;
let remote_storage_azure_container = env::var("REMOTE_STORAGE_AZURE_CONTAINER").context(
"`REMOTE_STORAGE_AZURE_CONTAINER` env var is not set, but real Azure tests are enabled",
)?;
let remote_storage_azure_region = env::var("REMOTE_STORAGE_AZURE_REGION").context(
"`REMOTE_STORAGE_AZURE_REGION` env var is not set, but real Azure tests are enabled",
)?;
// due to how time works, we've had test runners use the same nanos as bucket prefixes.
// millis is just a debugging aid for easier finding the prefix later.
let millis = std::time::SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("random Azure test prefix part calculation")?
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::thread_rng().gen::<u32>();
let remote_storage_config = RemoteStorageConfig {
max_concurrent_syncs: NonZeroUsize::new(100).unwrap(),
max_sync_errors: NonZeroU32::new(5).unwrap(),
storage: RemoteStorageKind::AzureContainer(AzureConfig {
container_name: remote_storage_azure_container,
container_region: remote_storage_azure_region,
prefix_in_container: Some(format!("test_{millis}_{random:08x}/")),
concurrency_limit: NonZeroUsize::new(100).unwrap(),
max_keys_per_list_response,
}),
};
Ok(Arc::new(
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
))
}
struct Uploads {
prefixes: HashSet<RemotePath>,
blobs: HashSet<RemotePath>,
}
async fn upload_azure_data(
client: &Arc<GenericRemoteStorage>,
base_prefix_str: &'static str,
upload_tasks_count: usize,
) -> ControlFlow<Uploads, Uploads> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
debug!("Creating remote item {i} at path {blob_path:?}");
let data = format!("remote blob data {i}").into_bytes();
let data_len = data.len();
task_client
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
.await?;
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
});
}
let mut upload_tasks_failed = false;
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok((upload_prefix, upload_path)) => {
uploaded_prefixes.insert(upload_prefix);
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
let uploads = Uploads {
prefixes: uploaded_prefixes,
blobs: uploaded_blobs,
};
if upload_tasks_failed {
ControlFlow::Break(uploads)
} else {
ControlFlow::Continue(uploads)
}
}
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
info!(
"Removing {} objects from the remote storage during cleanup",
objects_to_delete.len()
);
let mut delete_tasks = JoinSet::new();
for object_to_delete in objects_to_delete {
let task_client = Arc::clone(client);
delete_tasks.spawn(async move {
debug!("Deleting remote item at path {object_to_delete:?}");
task_client
.delete(&object_to_delete)
.await
.with_context(|| format!("{object_to_delete:?} removal"))
});
}
while let Some(task_run_result) = delete_tasks.join_next().await {
match task_run_result {
Ok(task_result) => match task_result {
Ok(()) => {}
Err(e) => error!("Delete task failed: {e:?}"),
},
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
}
}
}
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
async fn upload_simple_azure_data(
client: &Arc<GenericRemoteStorage>,
upload_tasks_count: usize,
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
let blob_path = RemotePath::new(
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
)
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
debug!("Creating remote item {i} at path {blob_path:?}");
let data = format!("remote blob data {i}").into_bytes();
let data_len = data.len();
task_client
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
.await?;
Ok::<_, anyhow::Error>(blob_path)
});
}
let mut upload_tasks_failed = false;
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok(upload_path) => {
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
if upload_tasks_failed {
ControlFlow::Break(uploaded_blobs)
} else {
ControlFlow::Continue(uploaded_blobs)
}
}

View File

@@ -27,8 +27,8 @@ and old one if it exists.
* the filecache: a struct that allows communication with the Postgres file cache.
On startup, we connect to the filecache and hold on to the connection for the
entire monitor lifetime.
* the cgroup watcher: the `CgroupWatcher` manages the `neon-postgres` cgroup by
listening for `memory.high` events and setting its `memory.{high,max}` values.
* the cgroup watcher: the `CgroupWatcher` polls the `neon-postgres` cgroup's memory
usage and sends rolling aggregates to the runner.
* the runner: the runner marries the filecache and cgroup watcher together,
communicating with the agent throught the `Dispatcher`, and then calling filecache
and cgroup watcher functions as needed to upscale and downscale

View File

@@ -1,161 +1,38 @@
use std::{
fmt::{Debug, Display},
fs,
pin::pin,
sync::atomic::{AtomicU64, Ordering},
};
use std::fmt::{self, Debug, Formatter};
use std::time::{Duration, Instant};
use anyhow::{anyhow, bail, Context};
use anyhow::{anyhow, Context};
use cgroups_rs::{
freezer::FreezerController,
hierarchies::{self, is_cgroup2_unified_mode, UNIFIED_MOUNTPOINT},
hierarchies::{self, is_cgroup2_unified_mode},
memory::MemController,
MaxValue,
Subsystem::{Freezer, Mem},
Subsystem,
};
use inotify::{EventStream, Inotify, WatchMask};
use tokio::sync::mpsc::{self, error::TryRecvError};
use tokio::time::{Duration, Instant};
use tokio_stream::{Stream, StreamExt};
use tokio::sync::watch;
use tracing::{info, warn};
use crate::protocol::Resources;
use crate::MiB;
/// Monotonically increasing counter of the number of memory.high events
/// the cgroup has experienced.
///
/// We use this to determine if a modification to the `memory.events` file actually
/// changed the `high` field. If not, we don't care about the change. When we
/// read the file, we check the `high` field in the file against `MEMORY_EVENT_COUNT`
/// to see if it changed since last time.
pub static MEMORY_EVENT_COUNT: AtomicU64 = AtomicU64::new(0);
/// Monotonically increasing counter that gives each cgroup event a unique id.
///
/// This allows us to answer questions like "did this upscale arrive before this
/// memory.high?". This static is also used by the `Sequenced` type to "tag" values
/// with a sequence number. As such, prefer to used the `Sequenced` type rather
/// than this static directly.
static EVENT_SEQUENCE_NUMBER: AtomicU64 = AtomicU64::new(0);
/// A memory event type reported in memory.events.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum MemoryEvent {
Low,
High,
Max,
Oom,
OomKill,
OomGroupKill,
}
impl MemoryEvent {
fn as_str(&self) -> &str {
match self {
MemoryEvent::Low => "low",
MemoryEvent::High => "high",
MemoryEvent::Max => "max",
MemoryEvent::Oom => "oom",
MemoryEvent::OomKill => "oom_kill",
MemoryEvent::OomGroupKill => "oom_group_kill",
}
}
}
impl Display for MemoryEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
/// Configuration for a `CgroupWatcher`
#[derive(Debug, Clone)]
pub struct Config {
// The target difference between the total memory reserved for the cgroup
// and the value of the cgroup's memory.high.
//
// In other words, memory.high + oom_buffer_bytes will equal the total memory that the cgroup may
// use (equal to system memory, minus whatever's taken out for the file cache).
oom_buffer_bytes: u64,
/// Interval at which we should be fetching memory statistics
memory_poll_interval: Duration,
// The amount of memory, in bytes, below a proposed new value for
// memory.high that the cgroup's memory usage must be for us to downscale
//
// In other words, we can downscale only when:
//
// memory.current + memory_high_buffer_bytes < (proposed) memory.high
//
// TODO: there's some minor issues with this approach -- in particular, that we might have
// memory in use by the kernel's page cache that we're actually ok with getting rid of.
pub(crate) memory_high_buffer_bytes: u64,
// The maximum duration, in milliseconds, that we're allowed to pause
// the cgroup for while waiting for the autoscaler-agent to upscale us
max_upscale_wait: Duration,
// The required minimum time, in milliseconds, that we must wait before re-freezing
// the cgroup while waiting for the autoscaler-agent to upscale us.
do_not_freeze_more_often_than: Duration,
// The amount of memory, in bytes, that we should periodically increase memory.high
// by while waiting for the autoscaler-agent to upscale us.
//
// This exists to avoid the excessive throttling that happens when a cgroup is above its
// memory.high for too long. See more here:
// https://github.com/neondatabase/autoscaling/issues/44#issuecomment-1522487217
memory_high_increase_by_bytes: u64,
// The period, in milliseconds, at which we should repeatedly increase the value
// of the cgroup's memory.high while we're waiting on upscaling and memory.high
// is still being hit.
//
// Technically speaking, this actually serves as a rate limit to moderate responding to
// memory.high events, but these are roughly equivalent if the process is still allocating
// memory.
memory_high_increase_every: Duration,
}
impl Config {
/// Calculate the new value for the cgroups memory.high based on system memory
pub fn calculate_memory_high_value(&self, total_system_mem: u64) -> u64 {
total_system_mem.saturating_sub(self.oom_buffer_bytes)
}
/// The number of samples used in constructing aggregated memory statistics
memory_history_len: usize,
/// The number of most recent samples that will be periodically logged.
///
/// Each sample is logged exactly once. Increasing this value means that recent samples will be
/// logged less frequently, and vice versa.
///
/// For simplicity, this value must be greater than or equal to `memory_history_len`.
memory_history_log_interval: usize,
}
impl Default for Config {
fn default() -> Self {
Self {
oom_buffer_bytes: 100 * MiB,
memory_high_buffer_bytes: 100 * MiB,
// while waiting for upscale, don't freeze for more than 20ms every 1s
max_upscale_wait: Duration::from_millis(20),
do_not_freeze_more_often_than: Duration::from_millis(1000),
// while waiting for upscale, increase memory.high by 10MiB every 25ms
memory_high_increase_by_bytes: 10 * MiB,
memory_high_increase_every: Duration::from_millis(25),
}
}
}
/// Used to represent data that is associated with a certain point in time, such
/// as an upscale request or memory.high event.
///
/// Internally, creating a `Sequenced` uses a static atomic counter to obtain
/// a unique sequence number. Sequence numbers are monotonically increasing,
/// allowing us to answer questions like "did this upscale happen after this
/// memory.high event?" by comparing the sequence numbers of the two events.
#[derive(Debug, Clone)]
pub struct Sequenced<T> {
seqnum: u64,
data: T,
}
impl<T> Sequenced<T> {
pub fn new(data: T) -> Self {
Self {
seqnum: EVENT_SEQUENCE_NUMBER.fetch_add(1, Ordering::AcqRel),
data,
memory_poll_interval: Duration::from_millis(100),
memory_history_len: 5, // use 500ms of history for decision-making
memory_history_log_interval: 20, // but only log every ~2s (otherwise it's spammy)
}
}
}
@@ -170,74 +47,14 @@ impl<T> Sequenced<T> {
pub struct CgroupWatcher {
pub config: Config,
/// The sequence number of the last upscale.
///
/// If we receive a memory.high event that has a _lower_ sequence number than
/// `last_upscale_seqnum`, then we know it occured before the upscale, and we
/// can safely ignore it.
///
/// Note: Like the `events` field, this doesn't _need_ interior mutability but we
/// use it anyways so that methods take `&self`, not `&mut self`.
last_upscale_seqnum: AtomicU64,
/// A channel on which we send messages to request upscale from the dispatcher.
upscale_requester: mpsc::Sender<()>,
/// The actual cgroup we are watching and managing.
cgroup: cgroups_rs::Cgroup,
}
/// Read memory.events for the desired event type.
///
/// `path` specifies the path to the desired `memory.events` file.
/// For more info, see the `memory.events` section of the [kernel docs]
/// <https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files>
fn get_event_count(path: &str, event: MemoryEvent) -> anyhow::Result<u64> {
let contents = fs::read_to_string(path)
.with_context(|| format!("failed to read memory.events from {path}"))?;
// Then contents of the file look like:
// low 42
// high 101
// ...
contents
.lines()
.filter_map(|s| s.split_once(' '))
.find(|(e, _)| *e == event.as_str())
.ok_or_else(|| anyhow!("failed to find entry for memory.{event} events in {path}"))
.and_then(|(_, count)| {
count
.parse::<u64>()
.with_context(|| format!("failed to parse memory.{event} as u64"))
})
}
/// Create an event stream that produces events whenever the file at the provided
/// path is modified.
fn create_file_watcher(path: &str) -> anyhow::Result<EventStream<[u8; 1024]>> {
info!("creating file watcher for {path}");
let inotify = Inotify::init().context("failed to initialize file watcher")?;
inotify
.watches()
.add(path, WatchMask::MODIFY)
.with_context(|| format!("failed to start watching {path}"))?;
inotify
// The inotify docs use [0u8; 1024] so we'll just copy them. We only need
// to store one event at a time - if the event gets written over, that's
// ok. We still see that there is an event. For more information, see:
// https://man7.org/linux/man-pages/man7/inotify.7.html
.into_event_stream([0u8; 1024])
.context("failed to start inotify event stream")
}
impl CgroupWatcher {
/// Create a new `CgroupWatcher`.
#[tracing::instrument(skip_all, fields(%name))]
pub fn new(
name: String,
// A channel on which to send upscale requests
upscale_requester: mpsc::Sender<()>,
) -> anyhow::Result<(Self, impl Stream<Item = Sequenced<u64>>)> {
pub fn new(name: String) -> anyhow::Result<Self> {
// TODO: clarify exactly why we need v2
// Make sure cgroups v2 (aka unified) are supported
if !is_cgroup2_unified_mode() {
@@ -245,410 +62,203 @@ impl CgroupWatcher {
}
let cgroup = cgroups_rs::Cgroup::load(hierarchies::auto(), &name);
// Start monitoring the cgroup for memory events. In general, for
// cgroups v2 (aka unified), metrics are reported in files like
// > `/sys/fs/cgroup/{name}/{metric}`
// We are looking for `memory.high` events, which are stored in the
// file `memory.events`. For more info, see the `memory.events` section
// of https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files
let path = format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name);
let memory_events = create_file_watcher(&path)
.with_context(|| format!("failed to create event watcher for {path}"))?
// This would be nice with with .inspect_err followed by .ok
.filter_map(move |_| match get_event_count(&path, MemoryEvent::High) {
Ok(high) => Some(high),
Err(error) => {
// TODO: Might want to just panic here
warn!(?error, "failed to read high events count from {}", &path);
None
}
})
// Only report the event if the memory.high count increased
.filter_map(|high| {
if MEMORY_EVENT_COUNT.fetch_max(high, Ordering::AcqRel) < high {
Some(high)
} else {
None
}
})
.map(Sequenced::new);
let initial_count = get_event_count(
&format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name),
MemoryEvent::High,
)?;
info!(initial_count, "initial memory.high event count");
// Hard update `MEMORY_EVENT_COUNT` since there could have been processes
// running in the cgroup before that caused it to be non-zero.
MEMORY_EVENT_COUNT.fetch_max(initial_count, Ordering::AcqRel);
Ok((
Self {
cgroup,
upscale_requester,
last_upscale_seqnum: AtomicU64::new(0),
config: Default::default(),
},
memory_events,
))
Ok(Self {
cgroup,
config: Default::default(),
})
}
/// The entrypoint for the `CgroupWatcher`.
#[tracing::instrument(skip_all)]
pub async fn watch<E>(
pub async fn watch(
&self,
// These are ~dependency injected~ (fancy, I know) because this function
// should never return.
// -> therefore: when we tokio::spawn it, we don't await the JoinHandle.
// -> therefore: if we want to stick it in an Arc so many threads can access
// it, methods can never take mutable access.
// - note: we use the Arc strategy so that a) we can call this function
// right here and b) the runner can call the set/get_memory methods
// -> since calling recv() on a tokio::sync::mpsc::Receiver takes &mut self,
// we just pass them in here instead of holding them in fields, as that
// would require this method to take &mut self.
mut upscales: mpsc::Receiver<Sequenced<Resources>>,
events: E,
) -> anyhow::Result<()>
where
E: Stream<Item = Sequenced<u64>>,
{
let mut wait_to_freeze = pin!(tokio::time::sleep(Duration::ZERO));
let mut last_memory_high_increase_at: Option<Instant> = None;
let mut events = pin!(events);
// Are we waiting to be upscaled? Could be true if we request upscale due
// to a memory.high event and it does not arrive in time.
let mut waiting_on_upscale = false;
loop {
tokio::select! {
upscale = upscales.recv() => {
let Sequenced { seqnum, data } = upscale
.context("failed to listen on upscale notification channel")?;
waiting_on_upscale = false;
last_memory_high_increase_at = None;
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
}
event = events.next() => {
let Some(Sequenced { seqnum, .. }) = event else {
bail!("failed to listen for memory.high events")
};
// The memory.high came before our last upscale, so we consider
// it resolved
if self.last_upscale_seqnum.fetch_max(seqnum, Ordering::AcqRel) > seqnum {
info!(
"received memory.high event, but it came before our last upscale -> ignoring it"
);
continue;
}
// The memory.high came after our latest upscale. We don't
// want to do anything yet, so peek the next event in hopes
// that it's an upscale.
if let Some(upscale_num) = self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
{
if upscale_num > seqnum {
info!(
"received memory.high event, but it came before our last upscale -> ignoring it"
);
continue;
}
}
// If it's been long enough since we last froze, freeze the
// cgroup and request upscale
if wait_to_freeze.is_elapsed() {
info!("received memory.high event -> requesting upscale");
waiting_on_upscale = self
.handle_memory_high_event(&mut upscales)
.await
.context("failed to handle upscale")?;
wait_to_freeze
.as_mut()
.reset(Instant::now() + self.config.do_not_freeze_more_often_than);
continue;
}
// Ok, we can't freeze, just request upscale
if !waiting_on_upscale {
info!("received memory.high event, but too soon to refreeze -> requesting upscale");
// Make check to make sure we haven't been upscaled in the
// meantine (can happen if the agent independently decides
// to upscale us again)
if self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
.is_some()
{
info!("no need to request upscaling because we got upscaled");
continue;
}
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
waiting_on_upscale = true;
continue;
}
// Shoot, we can't freeze or and we're still waiting on upscale,
// increase memory.high to reduce throttling
let can_increase_memory_high = match last_memory_high_increase_at {
None => true,
Some(t) => t.elapsed() > self.config.memory_high_increase_every,
};
if can_increase_memory_high {
info!(
"received memory.high event, \
but too soon to refreeze and already requested upscale \
-> increasing memory.high"
);
// Make check to make sure we haven't been upscaled in the
// meantine (can happen if the agent independently decides
// to upscale us again)
if self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
.is_some()
{
info!("no need to increase memory.high because got upscaled");
continue;
}
// Request upscale anyways (the agent will handle deduplicating
// requests)
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
let memory_high =
self.get_memory_high_bytes().context("failed to get memory.high")?;
let new_high = memory_high + self.config.memory_high_increase_by_bytes;
info!(
current_high_bytes = memory_high,
new_high_bytes = new_high,
"updating memory.high"
);
self.set_memory_high_bytes(new_high)
.context("failed to set memory.high")?;
last_memory_high_increase_at = Some(Instant::now());
continue;
}
info!("received memory.high event, but can't do anything");
}
};
}
}
/// Handle a `memory.high`, returning whether we are still waiting on upscale
/// by the time the function returns.
///
/// The general plan for handling a `memory.high` event is as follows:
/// 1. Freeze the cgroup
/// 2. Start a timer for `self.config.max_upscale_wait`
/// 3. Request upscale
/// 4. After the timer elapses or we receive upscale, thaw the cgroup.
/// 5. Return whether or not we are still waiting for upscale. If we are,
/// we'll increase the cgroups memory.high to avoid getting oom killed
#[tracing::instrument(skip_all)]
async fn handle_memory_high_event(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
) -> anyhow::Result<bool> {
// Immediately freeze the cgroup before doing anything else.
info!("received memory.high event -> freezing cgroup");
self.freeze().context("failed to freeze cgroup")?;
// We'll use this for logging durations
let start_time = Instant::now();
// Await the upscale until we have to unfreeze
let timed =
tokio::time::timeout(self.config.max_upscale_wait, self.await_upscale(upscales));
// Request the upscale
info!(
wait = ?self.config.max_upscale_wait,
"sending request for immediate upscaling",
);
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
let waiting_on_upscale = match timed.await {
Ok(Ok(())) => {
info!(elapsed = ?start_time.elapsed(), "received upscale in time");
false
}
// **important**: unfreeze the cgroup before ?-reporting the error
Ok(Err(e)) => {
info!("error waiting for upscale -> thawing cgroup");
self.thaw()
.context("failed to thaw cgroup after errored waiting for upscale")?;
Err(e.context("failed to await upscale"))?
}
Err(_) => {
info!(elapsed = ?self.config.max_upscale_wait, "timed out waiting for upscale");
true
}
};
info!("thawing cgroup");
self.thaw().context("failed to thaw cgroup")?;
Ok(waiting_on_upscale)
}
/// Checks whether we were just upscaled, returning the upscale's sequence
/// number if so.
#[tracing::instrument(skip_all)]
fn upscaled(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
) -> anyhow::Result<Option<u64>> {
let Sequenced { seqnum, data } = match upscales.try_recv() {
Ok(upscale) => upscale,
Err(TryRecvError::Empty) => return Ok(None),
Err(TryRecvError::Disconnected) => {
bail!("upscale notification channel was disconnected")
}
};
// Make sure to update the last upscale sequence number
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
Ok(Some(seqnum))
}
/// Await an upscale event, discarding any `memory.high` events received in
/// the process.
///
/// This is used in `handle_memory_high_event`, where we need to listen
/// for upscales in particular so we know if we can thaw the cgroup early.
#[tracing::instrument(skip_all)]
async fn await_upscale(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
updates: watch::Sender<(Instant, MemoryHistory)>,
) -> anyhow::Result<()> {
let Sequenced { seqnum, .. } = upscales
.recv()
.await
.context("error listening for upscales")?;
// this requirement makes the code a bit easier to work with; see the config for more.
assert!(self.config.memory_history_len <= self.config.memory_history_log_interval);
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
Ok(())
}
let mut ticker = tokio::time::interval(self.config.memory_poll_interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
// ticker.reset_immediately(); // FIXME: enable this once updating to tokio >= 1.30.0
/// Get the cgroup's name.
pub fn path(&self) -> &str {
self.cgroup.path()
}
}
let mem_controller = self.memory()?;
// Methods for manipulating the actual cgroup
impl CgroupWatcher {
/// Get a handle on the freezer subsystem.
fn freezer(&self) -> anyhow::Result<&FreezerController> {
if let Some(Freezer(freezer)) = self
.cgroup
.subsystems()
.iter()
.find(|sub| matches!(sub, Freezer(_)))
{
Ok(freezer)
} else {
anyhow::bail!("could not find freezer subsystem")
// buffer for samples that will be logged. once full, it remains so.
let history_log_len = self.config.memory_history_log_interval;
let mut history_log_buf = vec![MemoryStatus::zeroed(); history_log_len];
for t in 0_u64.. {
ticker.tick().await;
let now = Instant::now();
let mem = Self::memory_usage(mem_controller);
let i = t as usize % history_log_len;
history_log_buf[i] = mem;
// We're taking *at most* memory_history_len values; we may be bounded by the total
// number of samples that have come in so far.
let samples_count = (t + 1).min(self.config.memory_history_len as u64) as usize;
// NB: in `ring_buf_recent_values_iter`, `i` is *inclusive*, which matches the fact
// that we just inserted a value there, so the end of the iterator will *include* the
// value at i, rather than stopping just short of it.
let samples = ring_buf_recent_values_iter(&history_log_buf, i, samples_count);
let summary = MemoryHistory {
avg_non_reclaimable: samples.map(|h| h.non_reclaimable).sum::<u64>()
/ samples_count as u64,
samples_count,
samples_span: self.config.memory_poll_interval * (samples_count - 1) as u32,
};
// Log the current history if it's time to do so. Because `history_log_buf` has length
// equal to the logging interval, we can just log the entire buffer every time we set
// the last entry, which also means that for this log line, we can ignore that it's a
// ring buffer (because all the entries are in order of increasing time).
if i == history_log_len - 1 {
info!(
history = ?MemoryStatus::debug_slice(&history_log_buf),
summary = ?summary,
"Recent cgroup memory statistics history"
);
}
updates
.send((now, summary))
.context("failed to send MemoryHistory")?;
}
}
/// Attempt to freeze the cgroup.
pub fn freeze(&self) -> anyhow::Result<()> {
self.freezer()
.context("failed to get freezer subsystem")?
.freeze()
.context("failed to freeze")
}
/// Attempt to thaw the cgroup.
pub fn thaw(&self) -> anyhow::Result<()> {
self.freezer()
.context("failed to get freezer subsystem")?
.thaw()
.context("failed to thaw")
unreachable!()
}
/// Get a handle on the memory subsystem.
///
/// Note: this method does not require `self.memory_update_lock` because
/// getting a handle to the subsystem does not access any of the files we
/// care about, such as memory.high and memory.events
fn memory(&self) -> anyhow::Result<&MemController> {
if let Some(Mem(memory)) = self
.cgroup
self.cgroup
.subsystems()
.iter()
.find(|sub| matches!(sub, Mem(_)))
{
Ok(memory)
} else {
anyhow::bail!("could not find memory subsystem")
}
}
/// Get cgroup current memory usage.
pub fn current_memory_usage(&self) -> anyhow::Result<u64> {
Ok(self
.memory()
.context("failed to get memory subsystem")?
.memory_stat()
.usage_in_bytes)
}
/// Set cgroup memory.high threshold.
pub fn set_memory_high_bytes(&self, bytes: u64) -> anyhow::Result<()> {
self.set_memory_high_internal(MaxValue::Value(u64::min(bytes, i64::MAX as u64) as i64))
}
/// Set the cgroup's memory.high to 'max', disabling it.
pub fn unset_memory_high(&self) -> anyhow::Result<()> {
self.set_memory_high_internal(MaxValue::Max)
}
fn set_memory_high_internal(&self, value: MaxValue) -> anyhow::Result<()> {
self.memory()
.context("failed to get memory subsystem")?
.set_mem(cgroups_rs::memory::SetMemory {
low: None,
high: Some(value),
min: None,
max: None,
.find_map(|sub| match sub {
Subsystem::Mem(c) => Some(c),
_ => None,
})
.map_err(anyhow::Error::from)
.ok_or_else(|| anyhow!("could not find memory subsystem"))
}
/// Get memory.high threshold.
pub fn get_memory_high_bytes(&self) -> anyhow::Result<u64> {
let high = self
.memory()
.context("failed to get memory subsystem while getting memory statistics")?
.get_mem()
.map(|mem| mem.high)
.context("failed to get memory statistics from subsystem")?;
match high {
Some(MaxValue::Max) => Ok(i64::MAX as u64),
Some(MaxValue::Value(high)) => Ok(high as u64),
None => anyhow::bail!("failed to read memory.high from memory subsystem"),
/// Given a handle on the memory subsystem, returns the current memory information
fn memory_usage(mem_controller: &MemController) -> MemoryStatus {
let stat = mem_controller.memory_stat().stat;
MemoryStatus {
non_reclaimable: stat.active_anon + stat.inactive_anon,
}
}
}
// Helper function for `CgroupWatcher::watch`
fn ring_buf_recent_values_iter<T>(
buf: &[T],
last_value_idx: usize,
count: usize,
) -> impl '_ + Iterator<Item = &T> {
// Assertion carried over from `CgroupWatcher::watch`, to make the logic in this function
// easier (we only have to add `buf.len()` once, rather than a dynamic number of times).
assert!(count <= buf.len());
buf.iter()
// 'cycle' because the values could wrap around
.cycle()
// with 'cycle', this skip is more like 'offset', and functionally this is
// offsettting by 'last_value_idx - count (mod buf.len())', but we have to be
// careful to avoid underflow, so we pre-add buf.len().
// The '+ 1' is because `last_value_idx` is inclusive, rather than exclusive.
.skip((buf.len() + last_value_idx + 1 - count) % buf.len())
.take(count)
}
/// Summary of recent memory usage
#[derive(Debug, Copy, Clone)]
pub struct MemoryHistory {
/// Rolling average of non-reclaimable memory usage samples over the last `history_period`
pub avg_non_reclaimable: u64,
/// The number of samples used to construct this summary
pub samples_count: usize,
/// Total timespan between the first and last sample used for this summary
pub samples_span: Duration,
}
#[derive(Debug, Copy, Clone)]
pub struct MemoryStatus {
non_reclaimable: u64,
}
impl MemoryStatus {
fn zeroed() -> Self {
MemoryStatus { non_reclaimable: 0 }
}
fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
struct DS<'a>(&'a [MemoryStatus]);
impl<'a> Debug for DS<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("[MemoryStatus]")
.field(
"non_reclaimable[..]",
&Fields(self.0, |stat: &MemoryStatus| {
BytesToGB(stat.non_reclaimable)
}),
)
.finish()
}
}
struct Fields<'a, F>(&'a [MemoryStatus], F);
impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_list().entries(self.0.iter().map(&self.1)).finish()
}
}
struct BytesToGB(u64);
impl Debug for BytesToGB {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.write_fmt(format_args!(
"{:.3}Gi",
self.0 as f64 / (1_u64 << 30) as f64
))
}
}
DS(slice)
}
}
#[cfg(test)]
mod tests {
#[test]
fn ring_buf_iter() {
let buf = vec![0_i32, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let values = |offset, count| {
super::ring_buf_recent_values_iter(&buf, offset, count)
.copied()
.collect::<Vec<i32>>()
};
// Boundary conditions: start, end, and entire thing:
assert_eq!(values(0, 1), [0]);
assert_eq!(values(3, 4), [0, 1, 2, 3]);
assert_eq!(values(9, 4), [6, 7, 8, 9]);
assert_eq!(values(9, 10), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
// "normal" operation: no wraparound
assert_eq!(values(7, 4), [4, 5, 6, 7]);
// wraparound:
assert_eq!(values(0, 4), [7, 8, 9, 0]);
assert_eq!(values(1, 4), [8, 9, 0, 1]);
assert_eq!(values(2, 4), [9, 0, 1, 2]);
assert_eq!(values(2, 10), [3, 4, 5, 6, 7, 8, 9, 0, 1, 2]);
}
}

View File

@@ -12,12 +12,10 @@ use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::sync::mpsc;
use tracing::info;
use crate::cgroup::Sequenced;
use crate::protocol::{
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, Resources, PROTOCOL_MAX_VERSION,
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, PROTOCOL_MAX_VERSION,
PROTOCOL_MIN_VERSION,
};
@@ -36,13 +34,6 @@ pub struct Dispatcher {
/// We send messages to the agent through `sink`
sink: SplitSink<WebSocket, Message>,
/// Used to notify the cgroup when we are upscaled.
pub(crate) notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
/// When the cgroup requests upscale it will send on this channel. In response
/// we send an `UpscaleRequst` to the agent.
pub(crate) request_upscale_events: mpsc::Receiver<()>,
/// The protocol version we have agreed to use with the agent. This is negotiated
/// during the creation of the dispatcher, and should be the highest shared protocol
/// version.
@@ -61,11 +52,7 @@ impl Dispatcher {
/// 1. Wait for the agent to sent the range of protocols it supports.
/// 2. Send a protocol version that works for us as well, or an error if there
/// is no compatible version.
pub async fn new(
stream: WebSocket,
notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
request_upscale_events: mpsc::Receiver<()>,
) -> anyhow::Result<Self> {
pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
let (mut sink, mut source) = stream.split();
// Figure out the highest protocol version we both support
@@ -119,22 +106,10 @@ impl Dispatcher {
Ok(Self {
sink,
source,
notify_upscale_events,
request_upscale_events,
proto_version: highest_shared_version,
})
}
/// Notify the cgroup manager that we have received upscale and wait for
/// the acknowledgement.
#[tracing::instrument(skip_all, fields(?resources))]
pub async fn notify_upscale(&self, resources: Sequenced<Resources>) -> anyhow::Result<()> {
self.notify_upscale_events
.send(resources)
.await
.context("failed to send resources and oneshot sender across channel")
}
/// Send a message to the agent.
///
/// Although this function is small, it has one major benefit: it is the only

View File

@@ -5,18 +5,16 @@
//! all functionality.
use std::fmt::Debug;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{bail, Context};
use axum::extract::ws::{Message, WebSocket};
use futures::StreamExt;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, watch};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use crate::cgroup::{CgroupWatcher, Sequenced};
use crate::cgroup::{self, CgroupWatcher};
use crate::dispatcher::Dispatcher;
use crate::filecache::{FileCacheConfig, FileCacheState};
use crate::protocol::{InboundMsg, InboundMsgKind, OutboundMsg, OutboundMsgKind, Resources};
@@ -28,7 +26,7 @@ use crate::{bytes_to_mebibytes, get_total_system_memory, spawn_with_cancel, Args
pub struct Runner {
config: Config,
filecache: Option<FileCacheState>,
cgroup: Option<Arc<CgroupWatcher>>,
cgroup: Option<CgroupState>,
dispatcher: Dispatcher,
/// We "mint" new message ids by incrementing this counter and taking the value.
@@ -45,6 +43,14 @@ pub struct Runner {
kill: broadcast::Receiver<()>,
}
#[derive(Debug)]
struct CgroupState {
watcher: watch::Receiver<(Instant, cgroup::MemoryHistory)>,
/// If [`cgroup::MemoryHistory::avg_non_reclaimable`] exceeds `threshold`, we send upscale
/// requests.
threshold: u64,
}
/// Configuration for a `Runner`
#[derive(Debug)]
pub struct Config {
@@ -62,16 +68,56 @@ pub struct Config {
/// upscale resource amounts (because we might not *actually* have been upscaled yet). This field
/// should be removed once we have a better solution there.
sys_buffer_bytes: u64,
/// Minimum fraction of total system memory reserved *before* the the cgroup threshold; in
/// other words, providing a ceiling for the highest value of the threshold by enforcing that
/// there's at least `cgroup_min_overhead_fraction` of the total memory remaining beyond the
/// threshold.
///
/// For example, a value of `0.1` means that 10% of total memory must remain after exceeding
/// the threshold, so the value of the cgroup threshold would always be capped at 90% of total
/// memory.
///
/// The default value of `0.15` means that we *guarantee* sending upscale requests if the
/// cgroup is using more than 85% of total memory (even if we're *not* separately reserving
/// memory for the file cache).
cgroup_min_overhead_fraction: f64,
cgroup_downscale_threshold_buffer_bytes: u64,
}
impl Default for Config {
fn default() -> Self {
Self {
sys_buffer_bytes: 100 * MiB,
cgroup_min_overhead_fraction: 0.15,
cgroup_downscale_threshold_buffer_bytes: 100 * MiB,
}
}
}
impl Config {
fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 {
// If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup,
// and thus be non-reclaimable, so we should allow for additional memory usage.
//
// If the file cache sits on disk, our desired stable system state is for it to be fully
// page cached (its contents should only be paged to/from disk in situations where we can't
// upscale fast enough). Page-cached memory is reclaimable, so we need to lower the
// threshold for non-reclaimable memory so we scale up *before* the kernel starts paging
// out the file cache.
let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
// Even if we're not separately making room for the file cache (if it's in tmpfs), we still
// want our threshold to be met gracefully instead of letting postgres get OOM-killed.
// So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory
// remaining above the threshold.
let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64;
memory_remaining_for_cgroup.min(max_threshold)
}
}
impl Runner {
/// Create a new monitor.
#[tracing::instrument(skip_all, fields(?config, ?args))]
@@ -87,12 +133,7 @@ impl Runner {
"invalid monitor Config: sys_buffer_bytes cannot be 0"
);
// *NOTE*: the dispatcher and cgroup manager talk through these channels
// so make sure they each get the correct half, nothing is droppped, etc.
let (notified_send, notified_recv) = mpsc::channel(1);
let (requesting_send, requesting_recv) = mpsc::channel(1);
let dispatcher = Dispatcher::new(ws, notified_send, requesting_recv)
let dispatcher = Dispatcher::new(ws)
.await
.context("error creating new dispatcher")?;
@@ -106,46 +147,10 @@ impl Runner {
kill,
};
// If we have both the cgroup and file cache integrations enabled, it's possible for
// temporary failures to result in cgroup throttling (from memory.high), that in turn makes
// it near-impossible to connect to the file cache (because it times out). Unfortunately,
// we *do* still want to determine the file cache size before setting the cgroup's
// memory.high, so it's not as simple as just swapping the order.
//
// Instead, the resolution here is that on vm-monitor startup (note: happens on each
// connection from autoscaler-agent, possibly multiple times per compute_ctl lifecycle), we
// temporarily unset memory.high, to allow any existing throttling to dissipate. It's a bit
// of a hacky solution, but helps with reliability.
if let Some(name) = &args.cgroup {
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
// now, and then set limits later.
info!("initializing cgroup");
let (cgroup, cgroup_event_stream) = CgroupWatcher::new(name.clone(), requesting_send)
.context("failed to create cgroup manager")?;
info!("temporarily unsetting memory.high");
// Temporarily un-set cgroup memory.high; see above.
cgroup
.unset_memory_high()
.context("failed to unset memory.high")?;
let cgroup = Arc::new(cgroup);
let cgroup_clone = Arc::clone(&cgroup);
spawn_with_cancel(
token.clone(),
|_| error!("cgroup watcher terminated"),
async move { cgroup_clone.watch(notified_recv, cgroup_event_stream).await },
);
state.cgroup = Some(cgroup);
}
let mut file_cache_reserved_bytes = 0;
let mem = get_total_system_memory();
let mut file_cache_disk_size = 0;
// We need to process file cache initialization before cgroup initialization, so that the memory
// allocated to the file cache is appropriately taken into account when we decide the cgroup's
// memory limits.
@@ -156,7 +161,7 @@ impl Runner {
false => FileCacheConfig::default_in_memory(),
};
let mut file_cache = FileCacheState::new(connstr, config, token)
let mut file_cache = FileCacheState::new(connstr, config, token.clone())
.await
.context("failed to create file cache")?;
@@ -181,23 +186,40 @@ impl Runner {
if actual_size != new_size {
info!("file cache size actually got set to {actual_size}")
}
// Mark the resources given to the file cache as reserved, but only if it's in memory.
if !args.file_cache_on_disk {
file_cache_reserved_bytes = actual_size;
if args.file_cache_on_disk {
file_cache_disk_size = actual_size;
}
state.filecache = Some(file_cache);
}
if let Some(cgroup) = &state.cgroup {
let available = mem - file_cache_reserved_bytes;
let value = cgroup.config.calculate_memory_high_value(available);
if let Some(name) = &args.cgroup {
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
// now, and then set limits later.
info!("initializing cgroup");
info!(value, "setting memory.high");
let cgroup =
CgroupWatcher::new(name.clone()).context("failed to create cgroup manager")?;
cgroup
.set_memory_high_bytes(value)
.context("failed to set cgroup memory.high")?;
let init_value = cgroup::MemoryHistory {
avg_non_reclaimable: 0,
samples_count: 0,
samples_span: Duration::ZERO,
};
let (hist_tx, hist_rx) = watch::channel((Instant::now(), init_value));
spawn_with_cancel(token, |_| error!("cgroup watcher terminated"), async move {
cgroup.watch(hist_tx).await
});
let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size);
info!(threshold, "set initial cgroup threshold",);
state.cgroup = Some(CgroupState {
watcher: hist_rx,
threshold,
});
}
Ok(state)
@@ -217,28 +239,51 @@ impl Runner {
let requested_mem = target.mem;
let usable_system_memory = requested_mem.saturating_sub(self.config.sys_buffer_bytes);
let expected_file_cache_mem_usage = self
let (expected_file_cache_size, expected_file_cache_disk_size) = self
.filecache
.as_ref()
.map(|file_cache| file_cache.config.calculate_cache_size(usable_system_memory))
.unwrap_or(0);
let mut new_cgroup_mem_high = 0;
.map(|file_cache| {
let size = file_cache.config.calculate_cache_size(usable_system_memory);
match file_cache.config.in_memory {
true => (size, 0),
false => (size, size),
}
})
.unwrap_or((0, 0));
if let Some(cgroup) = &self.cgroup {
new_cgroup_mem_high = cgroup
let (last_time, last_history) = *cgroup.watcher.borrow();
// NB: The ordering of these conditions is intentional. During startup, we should deny
// downscaling until we have enough information to determine that it's safe to do so
// (i.e. enough samples have come in). But if it's been a while and we *still* haven't
// received any information, we should *fail* instead of just denying downscaling.
//
// `last_time` is set to `Instant::now()` on startup, so checking `last_time.elapsed()`
// serves double-duty: it trips if we haven't received *any* metrics for long enough,
// OR if we haven't received metrics *recently enough*.
//
// TODO: make the duration here configurable.
if last_time.elapsed() > Duration::from_secs(5) {
bail!("haven't gotten cgroup memory stats recently enough to determine downscaling information");
} else if last_history.samples_count <= 1 {
let status = "haven't received enough cgroup memory stats yet";
info!(status, "discontinuing downscale");
return Ok((false, status.to_owned()));
}
let new_threshold = self
.config
.calculate_memory_high_value(usable_system_memory - expected_file_cache_mem_usage);
.cgroup_threshold(usable_system_memory, expected_file_cache_disk_size);
let current = cgroup
.current_memory_usage()
.context("failed to fetch cgroup memory")?;
let current = last_history.avg_non_reclaimable;
if new_cgroup_mem_high < current + cgroup.config.memory_high_buffer_bytes {
if new_threshold < current + self.config.cgroup_downscale_threshold_buffer_bytes {
let status = format!(
"{}: {} MiB (new high) < {} (current usage) + {} (buffer)",
"calculated memory.high too low",
bytes_to_mebibytes(new_cgroup_mem_high),
"{}: {} MiB (new threshold) < {} (current usage) + {} (downscale buffer)",
"calculated memory threshold too low",
bytes_to_mebibytes(new_threshold),
bytes_to_mebibytes(current),
bytes_to_mebibytes(cgroup.config.memory_high_buffer_bytes)
bytes_to_mebibytes(self.config.cgroup_downscale_threshold_buffer_bytes)
);
info!(status, "discontinuing downscale");
@@ -249,14 +294,14 @@ impl Runner {
// The downscaling has been approved. Downscale the file cache, then the cgroup.
let mut status = vec![];
let mut file_cache_mem_usage = 0;
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let actual_usage = file_cache
.set_file_cache_size(expected_file_cache_mem_usage)
.set_file_cache_size(expected_file_cache_size)
.await
.context("failed to set file cache size")?;
if file_cache.config.in_memory {
file_cache_mem_usage = actual_usage;
if !file_cache.config.in_memory {
file_cache_disk_size = actual_usage;
}
let message = format!(
"set file cache size to {} MiB (in memory = {})",
@@ -267,24 +312,18 @@ impl Runner {
status.push(message);
}
if let Some(cgroup) = &self.cgroup {
let available_memory = usable_system_memory - file_cache_mem_usage;
if file_cache_mem_usage != expected_file_cache_mem_usage {
new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
}
// new_cgroup_mem_high is initialized to 0 but it is guaranteed to not be here
// since it is properly initialized in the previous cgroup if let block
cgroup
.set_memory_high_bytes(new_cgroup_mem_high)
.context("failed to set cgroup memory.high")?;
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
let message = format!(
"set cgroup memory.high to {} MiB, of new max {} MiB",
bytes_to_mebibytes(new_cgroup_mem_high),
bytes_to_mebibytes(available_memory)
"set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB",
bytes_to_mebibytes(cgroup.threshold),
bytes_to_mebibytes(new_threshold),
bytes_to_mebibytes(usable_system_memory)
);
cgroup.threshold = new_threshold;
info!("downscale: {message}");
status.push(message);
}
@@ -305,8 +344,7 @@ impl Runner {
let new_mem = resources.mem;
let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes);
// Get the file cache's expected contribution to the memory usage
let mut file_cache_mem_usage = 0;
let mut file_cache_disk_size = 0;
if let Some(file_cache) = &mut self.filecache {
let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory);
info!(
@@ -319,8 +357,8 @@ impl Runner {
.set_file_cache_size(expected_usage)
.await
.context("failed to set file cache size")?;
if file_cache.config.in_memory {
file_cache_mem_usage = actual_usage;
if !file_cache.config.in_memory {
file_cache_disk_size = actual_usage;
}
if actual_usage != expected_usage {
@@ -332,18 +370,18 @@ impl Runner {
}
}
if let Some(cgroup) = &self.cgroup {
let available_memory = usable_system_memory - file_cache_mem_usage;
let new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
info!(
target = bytes_to_mebibytes(new_cgroup_mem_high),
total = bytes_to_mebibytes(new_mem),
name = cgroup.path(),
"updating cgroup memory.high",
"set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB",
bytes_to_mebibytes(cgroup.threshold),
bytes_to_mebibytes(new_threshold),
bytes_to_mebibytes(usable_system_memory)
);
cgroup
.set_memory_high_bytes(new_cgroup_mem_high)
.context("failed to set cgroup memory.high")?;
cgroup.threshold = new_threshold;
}
Ok(())
@@ -361,10 +399,6 @@ impl Runner {
self.handle_upscale(granted)
.await
.context("failed to handle upscale")?;
self.dispatcher
.notify_upscale(Sequenced::new(granted))
.await
.context("failed to notify notify cgroup of upscale")?;
Ok(Some(OutboundMsg::new(
OutboundMsgKind::UpscaleConfirmation {},
id,
@@ -408,33 +442,53 @@ impl Runner {
Err(e) => bail!("failed to receive kill signal: {e}")
}
}
// we need to propagate an upscale request
request = self.dispatcher.request_upscale_events.recv(), if self.cgroup.is_some() => {
if request.is_none() {
bail!("failed to listen for upscale event from cgroup")
// New memory stats from the cgroup, *may* need to request upscaling, if we've
// exceeded the threshold
result = self.cgroup.as_mut().unwrap().watcher.changed(), if self.cgroup.is_some() => {
result.context("failed to receive from cgroup memory stats watcher")?;
let cgroup = self.cgroup.as_ref().unwrap();
let (_time, cgroup_mem_stat) = *cgroup.watcher.borrow();
// If we haven't exceeded the threshold, then we're all ok
if cgroup_mem_stat.avg_non_reclaimable < cgroup.threshold {
continue;
}
// If it's been less than 1 second since the last time we requested upscaling,
// ignore the event, to avoid spamming the agent (otherwise, this can happen
// ~1k times per second).
// Otherwise, we generally want upscaling. But, if it's been less than 1 second
// since the last time we requested upscaling, ignore the event, to avoid
// spamming the agent.
if let Some(t) = self.last_upscale_request_at {
let elapsed = t.elapsed();
if elapsed < Duration::from_secs(1) {
info!(elapsed_millis = elapsed.as_millis(), "cgroup asked for upscale but too soon to forward the request, ignoring");
info!(
elapsed_millis = elapsed.as_millis(),
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
threshold = bytes_to_mebibytes(cgroup.threshold),
"cgroup memory stats are high enough to upscale but too soon to forward the request, ignoring",
);
continue;
}
}
self.last_upscale_request_at = Some(Instant::now());
info!("cgroup asking for upscale; forwarding request");
info!(
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
threshold = bytes_to_mebibytes(cgroup.threshold),
"cgroup memory stats are high enough to upscale, requesting upscale",
);
self.counter += 2; // Increment, preserving parity (i.e. keep the
// counter odd). See the field comment for more.
self.dispatcher
.send(OutboundMsg::new(OutboundMsgKind::UpscaleRequest {}, self.counter))
.await
.context("failed to send message")?;
}
},
// there is a message from the agent
msg = self.dispatcher.source.next() => {
if let Some(msg) = msg {
@@ -462,11 +516,14 @@ impl Runner {
Ok(Some(out)) => out,
Ok(None) => continue,
Err(e) => {
let error = e.to_string();
warn!(?error, "error handling message");
// use {:#} for our logging because the display impl only
// gives the outermost cause, and the debug impl
// pretty-prints the error, whereas {:#} contains all the
// causes, but is compact (no newlines).
warn!(error = format!("{e:#}"), "error handling message");
OutboundMsg::new(
OutboundMsgKind::InternalError {
error
error: e.to_string(),
},
message.id
)

View File

@@ -11,10 +11,7 @@ use std::sync::{Arc, Barrier};
use bytes::{Buf, Bytes};
use pageserver::{
config::PageServerConf,
repository::Key,
walrecord::NeonWalRecord,
walredo::{PostgresRedoManager, WalRedoError},
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
};
use utils::{id::TenantId, lsn::Lsn};
@@ -152,7 +149,7 @@ impl Drop for JoinOnDrop {
}
}
fn execute_all<I>(input: I, manager: &PostgresRedoManager) -> Result<(), WalRedoError>
fn execute_all<I>(input: I, manager: &PostgresRedoManager) -> anyhow::Result<()>
where
I: IntoIterator<Item = Request>,
{
@@ -160,7 +157,7 @@ where
input.into_iter().try_for_each(|req| {
let page = req.execute(manager)?;
assert_eq!(page.remaining(), 8192);
Ok::<_, WalRedoError>(())
anyhow::Ok(())
})
}
@@ -473,7 +470,7 @@ struct Request {
}
impl Request {
fn execute(self, manager: &PostgresRedoManager) -> Result<Bytes, WalRedoError> {
fn execute(self, manager: &PostgresRedoManager) -> anyhow::Result<Bytes> {
use pageserver::walredo::WalRedoManager;
let Request {

View File

@@ -355,6 +355,7 @@ fn start_pageserver(
// consumer side) will be dropped once we can start the background jobs. Currently it is behind
// completing all initial logical size calculations (init_logical_size_done_rx) and a timeout
// (background_task_maximum_delay).
let (init_remote_done_tx, init_remote_done_rx) = utils::completion::channel();
let (init_done_tx, init_done_rx) = utils::completion::channel();
let (init_logical_size_done_tx, init_logical_size_done_rx) = utils::completion::channel();
@@ -362,7 +363,8 @@ fn start_pageserver(
let (background_jobs_can_start, background_jobs_barrier) = utils::completion::channel();
let order = pageserver::InitializationOrder {
initial_tenant_load: Some(init_done_tx),
initial_tenant_load_remote: Some(init_done_tx),
initial_tenant_load: Some(init_remote_done_tx),
initial_logical_size_can_start: init_done_rx.clone(),
initial_logical_size_attempt: Some(init_logical_size_done_tx),
background_jobs_can_start: background_jobs_barrier.clone(),
@@ -388,6 +390,9 @@ fn start_pageserver(
// NOTE: unlike many futures in pageserver, this one is cancellation-safe
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial load completed"));
init_remote_done_rx.wait().await;
startup_checkpoint("initial_tenant_load_remote", "Remote part of initial load completed");
init_done_rx.wait().await;
startup_checkpoint("initial_tenant_load", "Initial load completed");
STARTUP_IS_LOADING.set(0);

View File

@@ -2,6 +2,7 @@
//! and push them to a HTTP endpoint.
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::{mgr, LogicalSizeCalculationCause};
use camino::Utf8PathBuf;
use consumption_metrics::EventType;
@@ -143,7 +144,7 @@ pub async fn collect_metrics(
crate::tenant::tasks::warn_when_period_overrun(
tick_at.elapsed(),
metric_collection_interval,
"consumption_metrics_collect_metrics",
BackgroundLoopKind::ConsumptionMetricsCollectMetrics,
);
}
}
@@ -268,6 +269,11 @@ async fn calculate_synthetic_size_worker(
}
if let Ok(tenant) = mgr::get_tenant(tenant_id, true).await {
// TODO should we use concurrent_background_tasks_rate_limit() here, like the other background tasks?
// We can put in some prioritization for consumption metrics.
// Same for the loop that fetches computed metrics.
// By using the same limiter, we centralize metrics collection for "start" and "finished" counters,
// which turns out is really handy to understand the system.
if let Err(e) = tenant.calculate_synthetic_size(cause, ctx).await {
error!("failed to calculate synthetic size for tenant {tenant_id}: {e:#}");
}
@@ -277,7 +283,7 @@ async fn calculate_synthetic_size_worker(
crate::tenant::tasks::warn_when_period_overrun(
tick_at.elapsed(),
synthetic_size_calculation_interval,
"consumption_metrics_synthetic_size_worker",
BackgroundLoopKind::ConsumptionMetricsSyntheticSizeWorker,
);
}
}

View File

@@ -13,6 +13,7 @@ use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use utils::backoff;
use crate::metrics;
@@ -63,7 +64,19 @@ impl Deleter {
Err(anyhow::anyhow!("failpoint hit"))
});
self.remote_storage.delete_objects(&self.accumulator).await
// A backoff::retry is used here for two reasons:
// - To provide a backoff rather than busy-polling the API on errors
// - To absorb transient 429/503 conditions without hitting our error
// logging path for issues deleting objects.
backoff::retry(
|| async { self.remote_storage.delete_objects(&self.accumulator).await },
|_| false,
3,
10,
"executing deletion batch",
backoff::Cancel::new(self.cancel.clone(), || anyhow::anyhow!("Shutting down")),
)
.await
}
/// Block until everything in accumulator has been executed
@@ -88,7 +101,10 @@ impl Deleter {
self.accumulator.clear();
}
Err(e) => {
warn!("DeleteObjects request failed: {e:#}, will retry");
if self.cancel.is_cancelled() {
return Err(DeletionQueueError::ShuttingDown);
}
warn!("DeleteObjects request failed: {e:#}, will continue trying");
metrics::DELETION_QUEUE
.remote_errors
.with_label_values(&["execute"])

View File

@@ -411,6 +411,11 @@ pub async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
evictions_failed.file_sizes += file_size;
evictions_failed.count += 1;
}
Some(Err(EvictionError::MetadataInconsistency(detail))) => {
warn!(%layer, "failed to evict layer: {detail}");
evictions_failed.file_sizes += file_size;
evictions_failed.count += 1;
}
None => {
assert!(cancel.is_cancelled());
return;

View File

@@ -136,9 +136,7 @@ impl From<PageReconstructError> for ApiError {
PageReconstructError::AncestorStopping(_) => {
ApiError::ResourceUnavailable(format!("{pre}").into())
}
PageReconstructError::WalRedo(pre) => {
ApiError::InternalServerError(anyhow::Error::new(pre))
}
PageReconstructError::WalRedo(pre) => ApiError::InternalServerError(pre),
}
}
}
@@ -164,9 +162,6 @@ impl From<TenantStateError> for ApiError {
fn from(tse: TenantStateError) -> ApiError {
match tse {
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid).into()),
TenantStateError::NotActive(_) => {
ApiError::ResourceUnavailable("Tenant not yet active".into())
}
TenantStateError::IsStopping(_) => {
ApiError::ResourceUnavailable("Tenant is stopping".into())
}

View File

@@ -173,6 +173,9 @@ fn is_walkdir_io_not_found(e: &walkdir::Error) -> bool {
/// delaying is needed.
#[derive(Clone)]
pub struct InitializationOrder {
/// Each initial tenant load task carries this until it is done loading timelines from remote storage
pub initial_tenant_load_remote: Option<utils::completion::Completion>,
/// Each initial tenant load task carries this until completion.
pub initial_tenant_load: Option<utils::completion::Completion>,

View File

@@ -1067,6 +1067,26 @@ pub(crate) static TENANT_TASK_EVENTS: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("Failed to register tenant_task_events metric")
});
pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_START_COUNT: Lazy<IntCounterVec> =
Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_semaphore_wait_start_count",
"Counter for background loop concurrency-limiting semaphore acquire calls started",
&["task"],
)
.unwrap()
});
pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_FINISH_COUNT: Lazy<IntCounterVec> =
Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_semaphore_wait_finish_count",
"Counter for background loop concurrency-limiting semaphore acquire calls finished",
&["task"],
)
.unwrap()
});
pub(crate) static BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_period_overrun_count",

View File

@@ -23,12 +23,14 @@ use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::completion;
use utils::completion::Completion;
use utils::crashsafe::path_with_suffix_extension;
use std::cmp::min;
use std::collections::hash_map::Entry;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::fmt::Display;
use std::fs;
@@ -185,6 +187,11 @@ impl AttachedTenantConf {
}
}
}
struct TimelinePreload {
timeline_id: TimelineId,
client: RemoteTimelineClient,
index_part: Result<MaybeDeletedIndexPart, DownloadError>,
}
///
/// Tenant consists of multiple timelines. Keep them in a hash table.
@@ -209,7 +216,7 @@ pub struct Tenant {
/// The remote storage generation, used to protect S3 objects from split-brain.
/// Does not change over the lifetime of the [`Tenant`] object.
///
///
/// This duplicates the generation stored in LocationConf, but that structure is mutable:
/// this copy enforces the invariant that generatio doesn't change during a Tenant's lifetime.
generation: Generation,
@@ -962,6 +969,9 @@ impl Tenant {
let _completion = init_order
.as_mut()
.and_then(|x| x.initial_tenant_load.take());
let remote_load_completion = init_order
.as_mut()
.and_then(|x| x.initial_tenant_load_remote.take());
// Dont block pageserver startup on figuring out deletion status
let pending_deletion = {
@@ -986,6 +996,7 @@ impl Tenant {
// as we are no longer loading, signal completion by dropping
// the completion while we resume deletion
drop(_completion);
drop(remote_load_completion);
// do not hold to initial_logical_size_attempt as it will prevent loading from proceeding without timeout
let _ = init_order
.as_mut()
@@ -1011,7 +1022,10 @@ impl Tenant {
let background_jobs_can_start =
init_order.as_ref().map(|x| &x.background_jobs_can_start);
match tenant_clone.load(init_order.as_ref(), &ctx).await {
match tenant_clone
.load(init_order.as_ref(), remote_load_completion, &ctx)
.await
{
Ok(()) => {
debug!("load finished");
@@ -1175,6 +1189,52 @@ impl Tenant {
})
}
async fn load_timeline_metadata(
self: &Arc<Tenant>,
timeline_ids: HashSet<TimelineId>,
remote_storage: &GenericRemoteStorage,
) -> anyhow::Result<HashMap<TimelineId, TimelinePreload>> {
let mut part_downloads = JoinSet::new();
for timeline_id in timeline_ids {
let client = RemoteTimelineClient::new(
remote_storage.clone(),
self.deletion_queue_client.clone(),
self.conf,
self.tenant_id,
timeline_id,
self.generation,
);
part_downloads.spawn(
async move {
debug!("starting index part download");
let index_part = client.download_index_file().await;
debug!("finished index part download");
Result::<_, anyhow::Error>::Ok(TimelinePreload {
client,
timeline_id,
index_part,
})
}
.map(move |res| {
res.with_context(|| format!("download index part for timeline {timeline_id}"))
})
.instrument(info_span!("download_index_part", %timeline_id)),
);
}
let mut timeline_preloads: HashMap<TimelineId, TimelinePreload> = HashMap::new();
while let Some(result) = part_downloads.join_next().await {
let preload_result = result.context("join preload task")?;
let preload = preload_result?;
timeline_preloads.insert(preload.timeline_id, preload);
}
Ok(timeline_preloads)
}
///
/// Background task to load in-memory data structures for this tenant, from
/// files on disk. Used at pageserver startup.
@@ -1183,14 +1243,13 @@ impl Tenant {
async fn load(
self: &Arc<Tenant>,
init_order: Option<&InitializationOrder>,
remote_completion: Option<Completion>,
ctx: &RequestContext,
) -> anyhow::Result<()> {
span::debug_assert_current_span_has_tenant_id();
debug!("loading tenant task");
crate::failpoint_support::sleep_millis_async!("before-loading-tenant");
// Load in-memory state to reflect the local files on disk
//
// Scan the directory, peek into the metadata file of each timeline, and
@@ -1209,10 +1268,38 @@ impl Tenant {
// FIXME original collect_timeline_files contained one more check:
// 1. "Timeline has no ancestor and no layer files"
// Load remote content for timelines in this tenant
let all_timeline_ids = scan
.sorted_timelines_to_load
.iter()
.map(|i| i.0)
.chain(scan.timelines_to_resume_deletion.iter().map(|i| i.0))
.collect();
let mut preload = if let Some(remote_storage) = &self.remote_storage {
Some(
self.load_timeline_metadata(all_timeline_ids, remote_storage)
.await?,
)
} else {
None
};
drop(remote_completion);
crate::failpoint_support::sleep_millis_async!("before-loading-tenant");
// Process loadable timelines first
for (timeline_id, local_metadata) in scan.sorted_timelines_to_load {
let timeline_preload = preload.as_mut().map(|p| p.remove(&timeline_id).unwrap());
if let Err(e) = self
.load_local_timeline(timeline_id, local_metadata, init_order, ctx, false)
.load_local_timeline(
timeline_id,
local_metadata,
timeline_preload,
init_order,
ctx,
false,
)
.await
{
match e {
@@ -1245,8 +1332,17 @@ impl Tenant {
}
}
Some(local_metadata) => {
let timeline_preload =
preload.as_mut().map(|p| p.remove(&timeline_id).unwrap());
if let Err(e) = self
.load_local_timeline(timeline_id, local_metadata, init_order, ctx, true)
.load_local_timeline(
timeline_id,
local_metadata,
timeline_preload,
init_order,
ctx,
true,
)
.await
{
match e {
@@ -1274,11 +1370,12 @@ impl Tenant {
/// Subroutine of `load_tenant`, to load an individual timeline
///
/// NB: The parent is assumed to be already loaded!
#[instrument(skip(self, local_metadata, init_order, ctx))]
#[instrument(skip(self, local_metadata, init_order, preload, ctx))]
async fn load_local_timeline(
self: &Arc<Self>,
timeline_id: TimelineId,
local_metadata: TimelineMetadata,
preload: Option<TimelinePreload>,
init_order: Option<&InitializationOrder>,
ctx: &RequestContext,
found_delete_mark: bool,
@@ -1287,74 +1384,81 @@ impl Tenant {
let mut resources = self.build_timeline_resources(timeline_id);
let (remote_startup_data, remote_client) = match resources.remote_client {
Some(remote_client) => match remote_client.download_index_file().await {
Ok(index_part) => {
let index_part = match index_part {
MaybeDeletedIndexPart::IndexPart(index_part) => index_part,
MaybeDeletedIndexPart::Deleted(index_part) => {
// TODO: we won't reach here if remote storage gets de-configured after start of the deletion operation.
// Example:
// start deletion operation
// finishes upload of index part
// pageserver crashes
// remote storage gets de-configured
// pageserver starts
//
// We don't really anticipate remote storage to be de-configured, so, for now, this is fine.
// Also, maybe we'll remove that option entirely in the future, see https://github.com/neondatabase/neon/issues/4099.
info!("is_deleted is set on remote, resuming removal of timeline data originally done by timeline deletion handler");
let (remote_startup_data, remote_client) = match preload {
Some(preload) => {
let TimelinePreload {
index_part,
client: remote_client,
timeline_id: _timeline_id,
} = preload;
match index_part {
Ok(index_part) => {
let index_part = match index_part {
MaybeDeletedIndexPart::IndexPart(index_part) => index_part,
MaybeDeletedIndexPart::Deleted(index_part) => {
// TODO: we won't reach here if remote storage gets de-configured after start of the deletion operation.
// Example:
// start deletion operation
// finishes upload of index part
// pageserver crashes
// remote storage gets de-configured
// pageserver starts
//
// We don't really anticipate remote storage to be de-configured, so, for now, this is fine.
// Also, maybe we'll remove that option entirely in the future, see https://github.com/neondatabase/neon/issues/4099.
info!("is_deleted is set on remote, resuming removal of timeline data originally done by timeline deletion handler");
remote_client
.init_upload_queue_stopped_to_continue_deletion(&index_part)
.context("init queue stopped")
remote_client
.init_upload_queue_stopped_to_continue_deletion(&index_part)
.context("init queue stopped")
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
DeleteTimelineFlow::resume_deletion(
Arc::clone(self),
timeline_id,
&local_metadata,
Some(remote_client),
self.deletion_queue_client.clone(),
init_order,
)
.await
.context("resume deletion")
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
DeleteTimelineFlow::resume_deletion(
Arc::clone(self),
return Ok(());
}
};
let remote_metadata = index_part.metadata.clone();
(
Some(RemoteStartupData {
index_part,
remote_metadata,
}),
Some(remote_client),
)
}
Err(DownloadError::NotFound) => {
info!("no index file was found on the remote, found_delete_mark: {found_delete_mark}");
if found_delete_mark {
// We could've resumed at a point where remote index was deleted, but metadata file wasnt.
// Cleanup:
return DeleteTimelineFlow::cleanup_remaining_timeline_fs_traces(
self,
timeline_id,
&local_metadata,
Some(remote_client),
self.deletion_queue_client.clone(),
init_order,
)
.await
.context("resume deletion")
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
return Ok(());
.context("cleanup_remaining_timeline_fs_traces")
.map_err(LoadLocalTimelineError::ResumeDeletion);
}
};
let remote_metadata = index_part.metadata.clone();
(
Some(RemoteStartupData {
index_part,
remote_metadata,
}),
Some(remote_client),
)
}
Err(DownloadError::NotFound) => {
info!("no index file was found on the remote, found_delete_mark: {found_delete_mark}");
if found_delete_mark {
// We could've resumed at a point where remote index was deleted, but metadata file wasnt.
// Cleanup:
return DeleteTimelineFlow::cleanup_remaining_timeline_fs_traces(
self,
timeline_id,
)
.await
.context("cleanup_remaining_timeline_fs_traces")
.map_err(LoadLocalTimelineError::ResumeDeletion);
// We're loading fresh timeline that didnt yet make it into remote.
(None, Some(remote_client))
}
// We're loading fresh timeline that didnt yet make it into remote.
(None, Some(remote_client))
Err(e) => return Err(LoadLocalTimelineError::Load(anyhow::Error::new(e))),
}
Err(e) => return Err(LoadLocalTimelineError::Load(anyhow::Error::new(e))),
},
}
None => {
// No remote client
if found_delete_mark {
@@ -2755,6 +2859,11 @@ impl Tenant {
) -> Result<Arc<Timeline>, CreateTimelineError> {
let src_id = src_timeline.timeline_id;
// First acquire the GC lock so that another task cannot advance the GC
// cutoff in 'gc_info', and make 'start_lsn' invalid, while we are
// creating the branch.
let _gc_cs = self.gc_cs.lock().await;
// If no start LSN is specified, we branch the new timeline from the source timeline's last record LSN
let start_lsn = start_lsn.unwrap_or_else(|| {
let lsn = src_timeline.get_last_record_lsn();
@@ -2762,11 +2871,6 @@ impl Tenant {
lsn
});
// First acquire the GC lock so that another task cannot advance the GC
// cutoff in 'gc_info', and make 'start_lsn' invalid, while we are
// creating the branch.
let _gc_cs = self.gc_cs.lock().await;
// Create a placeholder for the new branch. This will error
// out if the new timeline ID is already in use.
let timeline_uninit_mark = {
@@ -3448,11 +3552,8 @@ pub mod harness {
use crate::deletion_queue::mock::MockDeletionQueue;
use crate::{
config::PageServerConf,
repository::Key,
tenant::Tenant,
walrecord::NeonWalRecord,
walredo::{WalRedoError, WalRedoManager},
config::PageServerConf, repository::Key, tenant::Tenant, walrecord::NeonWalRecord,
walredo::WalRedoManager,
};
use super::*;
@@ -3597,7 +3698,7 @@ pub mod harness {
self.deletion_queue.new_client(),
));
tenant
.load(None, ctx)
.load(None, None, ctx)
.instrument(info_span!("try_load", tenant_id=%self.tenant_id))
.await?;
@@ -3625,7 +3726,7 @@ pub mod harness {
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
_pg_version: u32,
) -> Result<Bytes, WalRedoError> {
) -> anyhow::Result<Bytes> {
let s = format!(
"redo for {} to get to {}, with {} and {} records",
key,

View File

@@ -31,7 +31,7 @@ use super::{
const SHOULD_RESUME_DELETION_FETCH_MARK_ATTEMPTS: u32 = 3;
#[derive(Debug, thiserror::Error)]
pub enum DeleteTenantError {
pub(crate) enum DeleteTenantError {
#[error("GetTenant {0}")]
Get(#[from] GetTenantError),
@@ -376,7 +376,7 @@ impl DeleteTenantFlow {
Ok(())
}
pub async fn should_resume_deletion(
pub(crate) async fn should_resume_deletion(
conf: &'static PageServerConf,
remote_storage: Option<&GenericRemoteStorage>,
tenant: &Tenant,
@@ -432,7 +432,7 @@ impl DeleteTenantFlow {
// Tenant may not be loadable if we fail late in cleanup_remaining_fs_traces (e g remove timelines dir)
let timelines_path = tenant.conf.timelines_path(&tenant.tenant_id);
if timelines_path.exists() {
tenant.load(init_order, ctx).await.context("load")?;
tenant.load(init_order, None, ctx).await.context("load")?;
}
Self::background(

View File

@@ -50,7 +50,7 @@ use super::TenantSharedResources;
/// its lifetime, and we can preserve some important safety invariants like `Tenant` always
/// having a properly acquired generation (Secondary doesn't need a generation)
#[derive(Clone)]
pub enum TenantSlot {
pub(crate) enum TenantSlot {
Attached(Arc<Tenant>),
Secondary,
}
@@ -481,7 +481,7 @@ pub(crate) fn schedule_local_tenant_processing(
/// management API. For example, it could attach the tenant on a different pageserver.
/// We would then be in split-brain once this pageserver restarts.
#[instrument(skip_all)]
pub async fn shutdown_all_tenants() {
pub(crate) async fn shutdown_all_tenants() {
shutdown_all_tenants0(&TENANTS).await
}
@@ -593,7 +593,7 @@ async fn shutdown_all_tenants0(tenants: &tokio::sync::RwLock<TenantsMap>) {
// caller will log how long we took
}
pub async fn create_tenant(
pub(crate) async fn create_tenant(
conf: &'static PageServerConf,
tenant_conf: TenantConfOpt,
tenant_id: TenantId,
@@ -628,14 +628,14 @@ pub async fn create_tenant(
}
#[derive(Debug, thiserror::Error)]
pub enum SetNewTenantConfigError {
pub(crate) enum SetNewTenantConfigError {
#[error(transparent)]
GetTenant(#[from] GetTenantError),
#[error(transparent)]
Persist(anyhow::Error),
}
pub async fn set_new_tenant_config(
pub(crate) async fn set_new_tenant_config(
conf: &'static PageServerConf,
new_tenant_conf: TenantConfOpt,
tenant_id: TenantId,
@@ -776,7 +776,7 @@ pub(crate) async fn upsert_location(
}
#[derive(Debug, thiserror::Error)]
pub enum GetTenantError {
pub(crate) enum GetTenantError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is not active")]
@@ -792,7 +792,7 @@ pub enum GetTenantError {
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
///
/// This method is cancel-safe.
pub async fn get_tenant(
pub(crate) async fn get_tenant(
tenant_id: TenantId,
active_only: bool,
) -> Result<Arc<Tenant>, GetTenantError> {
@@ -817,7 +817,7 @@ pub async fn get_tenant(
}
}
pub async fn delete_tenant(
pub(crate) async fn delete_tenant(
conf: &'static PageServerConf,
remote_storage: Option<GenericRemoteStorage>,
tenant_id: TenantId,
@@ -826,7 +826,7 @@ pub async fn delete_tenant(
}
#[derive(Debug, thiserror::Error)]
pub enum DeleteTimelineError {
pub(crate) enum DeleteTimelineError {
#[error("Tenant {0}")]
Tenant(#[from] GetTenantError),
@@ -834,7 +834,7 @@ pub enum DeleteTimelineError {
Timeline(#[from] crate::tenant::DeleteTimelineError),
}
pub async fn delete_timeline(
pub(crate) async fn delete_timeline(
tenant_id: TenantId,
timeline_id: TimelineId,
_ctx: &RequestContext,
@@ -845,18 +845,16 @@ pub async fn delete_timeline(
}
#[derive(Debug, thiserror::Error)]
pub enum TenantStateError {
pub(crate) enum TenantStateError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is stopping")]
IsStopping(TenantId),
#[error("Tenant {0} is not active")]
NotActive(TenantId),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub async fn detach_tenant(
pub(crate) async fn detach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
detach_ignored: bool,
@@ -926,7 +924,7 @@ async fn detach_tenant0(
removal_result
}
pub async fn load_tenant(
pub(crate) async fn load_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
generation: Generation,
@@ -963,7 +961,7 @@ pub async fn load_tenant(
Ok(())
}
pub async fn ignore_tenant(
pub(crate) async fn ignore_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
@@ -991,7 +989,7 @@ async fn ignore_tenant0(
}
#[derive(Debug, thiserror::Error)]
pub enum TenantMapListError {
pub(crate) enum TenantMapListError {
#[error("tenant map is still initiailizing")]
Initializing,
}
@@ -999,7 +997,7 @@ pub enum TenantMapListError {
///
/// Get list of tenants, for the mgmt API
///
pub async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, TenantMapListError> {
pub(crate) async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, TenantMapListError> {
let tenants = TENANTS.read().await;
let m = match &*tenants {
TenantsMap::Initializing => return Err(TenantMapListError::Initializing),
@@ -1017,7 +1015,7 @@ pub async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, TenantMapLis
///
/// Downloading all the tenant data is performed in the background, this merely
/// spawns the background task and returns quickly.
pub async fn attach_tenant(
pub(crate) async fn attach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
generation: Generation,
@@ -1054,7 +1052,7 @@ pub async fn attach_tenant(
}
#[derive(Debug, thiserror::Error)]
pub enum TenantMapInsertError {
pub(crate) enum TenantMapInsertError {
#[error("tenant map is still initializing")]
StillInitializing,
#[error("tenant map is shutting down")]
@@ -1217,7 +1215,7 @@ use {
utils::http::error::ApiError,
};
pub async fn immediate_gc(
pub(crate) async fn immediate_gc(
tenant_id: TenantId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,

View File

@@ -1419,6 +1419,13 @@ impl RemoteTimelineClient {
}
}
}
pub(crate) fn get_layer_metadata(
&self,
name: &LayerFileName,
) -> anyhow::Result<Option<LayerFileMetadata>> {
self.upload_queue.lock().unwrap().get_layer_metadata(name)
}
}
pub fn remote_timelines_path(tenant_id: &TenantId) -> RemotePath {

View File

@@ -14,6 +14,73 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::completion;
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
// repartitioning in the async context. this should give leave us some workers
// unblocked to be blocked on other work, hopefully easing any outside visible
// effects of restarts.
//
// 6/8 is a guess; previously we ran with unlimited 8 and more from
// spawn_blocking.
(total_threads * 3).checked_div(4).unwrap_or(0),
);
assert_ne!(permits, 0, "we will not be adding in permits later");
assert!(
permits < total_threads,
"need threads avail for shorter work"
);
tokio::sync::Semaphore::new(permits)
});
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum_macros::IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub(crate) enum BackgroundLoopKind {
Compaction,
Gc,
Eviction,
ConsumptionMetricsCollectMetrics,
ConsumptionMetricsSyntheticSizeWorker,
}
impl BackgroundLoopKind {
fn as_static_str(&self) -> &'static str {
let s: &'static str = self.into();
s
}
}
pub(crate) enum RateLimitError {
Cancelled,
}
pub(crate) async fn concurrent_background_tasks_rate_limit(
loop_kind: BackgroundLoopKind,
_ctx: &RequestContext,
cancel: &CancellationToken,
) -> Result<impl Drop, RateLimitError> {
crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_START_COUNT
.with_label_values(&[loop_kind.as_static_str()])
.inc();
scopeguard::defer!(
crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_FINISH_COUNT.with_label_values(&[loop_kind.as_static_str()]).inc();
);
tokio::select! {
permit = CONCURRENT_BACKGROUND_TASKS.acquire() => {
match permit {
Ok(permit) => Ok(permit),
Err(_closed) => unreachable!("we never close the semaphore"),
}
},
_ = cancel.cancelled() => {
Err(RateLimitError::Cancelled)
}
}
}
/// Start per tenant background loops: compaction and gc.
pub fn start_background_loops(
tenant: &Arc<Tenant>,
@@ -116,7 +183,7 @@ async fn compaction_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
}
};
warn_when_period_overrun(started_at.elapsed(), period, "compaction");
warn_when_period_overrun(started_at.elapsed(), period, BackgroundLoopKind::Compaction);
// Sleep
if tokio::time::timeout(sleep_duration, cancel.cancelled())
@@ -184,7 +251,7 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
}
};
warn_when_period_overrun(started_at.elapsed(), period, "gc");
warn_when_period_overrun(started_at.elapsed(), period, BackgroundLoopKind::Gc);
// Sleep
if tokio::time::timeout(sleep_duration, cancel.cancelled())
@@ -258,7 +325,11 @@ pub(crate) async fn random_init_delay(
}
/// Attention: the `task` and `period` beocme labels of a pageserver-wide prometheus metric.
pub(crate) fn warn_when_period_overrun(elapsed: Duration, period: Duration, task: &str) {
pub(crate) fn warn_when_period_overrun(
elapsed: Duration,
period: Duration,
task: BackgroundLoopKind,
) {
// Duration::ZERO will happen because it's the "disable [bgtask]" value.
if elapsed >= period && period != Duration::ZERO {
// humantime does no significant digits clamping whereas Duration's debug is a bit more
@@ -267,11 +338,11 @@ pub(crate) fn warn_when_period_overrun(elapsed: Duration, period: Duration, task
warn!(
?elapsed,
period = %humantime::format_duration(period),
task,
?task,
"task iteration took longer than the configured period"
);
crate::metrics::BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT
.with_label_values(&[task, &format!("{}", period.as_secs())])
.with_label_values(&[task.as_static_str(), &format!("{}", period.as_secs())])
.inc();
}
}

View File

@@ -44,6 +44,7 @@ use crate::tenant::storage_layer::delta_layer::DeltaEntry;
use crate::tenant::storage_layer::{
DeltaLayerWriter, ImageLayerWriter, InMemoryLayer, LayerAccessStats, LayerFileName, RemoteLayer,
};
use crate::tenant::tasks::{BackgroundLoopKind, RateLimitError};
use crate::tenant::timeline::logical_size::CurrentLogicalSize;
use crate::tenant::{
layer_map::{LayerMap, SearchResult},
@@ -370,7 +371,7 @@ pub enum PageReconstructError {
/// An error happened replaying WAL records
#[error(transparent)]
WalRedo(#[from] crate::walredo::WalRedoError),
WalRedo(anyhow::Error),
}
impl std::fmt::Debug for PageReconstructError {
@@ -684,37 +685,17 @@ impl Timeline {
) -> anyhow::Result<()> {
const ROUNDS: usize = 2;
static CONCURRENT_COMPACTIONS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
// repartitioning in the async context. this should give leave us some workers
// unblocked to be blocked on other work, hopefully easing any outside visible
// effects of restarts.
//
// 6/8 is a guess; previously we ran with unlimited 8 and more from
// spawn_blocking.
(total_threads * 3).checked_div(4).unwrap_or(0),
);
assert_ne!(permits, 0, "we will not be adding in permits later");
assert!(
permits < total_threads,
"need threads avail for shorter work"
);
tokio::sync::Semaphore::new(permits)
});
// this wait probably never needs any "long time spent" logging, because we already nag if
// compaction task goes over it's period (20s) which is quite often in production.
let _permit = tokio::select! {
permit = CONCURRENT_COMPACTIONS.acquire() => {
permit
},
_ = cancel.cancelled() => {
return Ok(());
}
let _permit = match super::tasks::concurrent_background_tasks_rate_limit(
BackgroundLoopKind::Compaction,
ctx,
cancel,
)
.await
{
Ok(permit) => permit,
Err(RateLimitError::Cancelled) => return Ok(()),
};
let last_record_lsn = self.get_last_record_lsn();
@@ -1294,7 +1275,23 @@ impl Timeline {
Ok(delta) => Some(delta),
};
let layer_metadata = LayerFileMetadata::new(layer_file_size, self.generation);
// RemoteTimelineClient holds the metadata on layers' remote generations, so
// query it to construct a RemoteLayer.
let layer_metadata = self
.remote_client
.as_ref()
.expect("Eviction is not called without remote storage")
.get_layer_metadata(&local_layer.filename())
.map_err(EvictionError::LayerNotFound)?
.ok_or_else(|| {
EvictionError::LayerNotFound(anyhow::anyhow!("Layer not in remote metadata"))
})?;
if layer_metadata.file_size() != layer_file_size {
return Err(EvictionError::MetadataInconsistency(format!(
"Layer size {layer_file_size} doesn't match remote metadata file size {}",
layer_metadata.file_size()
)));
}
let new_remote_layer = Arc::new(match local_layer.filename() {
LayerFileName::Image(image_name) => RemoteLayer::new_img(
@@ -1373,6 +1370,10 @@ pub(crate) enum EvictionError {
/// different objects in memory.
#[error("layer was no longer part of LayerMap")]
LayerNotFound(#[source] anyhow::Error),
/// This should never happen
#[error("Metadata inconsistency")]
MetadataInconsistency(String),
}
/// Number of times we will compute partition within a checkpoint distance.

View File

@@ -30,6 +30,7 @@ use crate::{
tenant::{
config::{EvictionPolicy, EvictionPolicyLayerAccessThreshold},
storage_layer::PersistentLayer,
tasks::{BackgroundLoopKind, RateLimitError},
timeline::EvictionError,
LogicalSizeCalculationCause, Tenant,
},
@@ -129,7 +130,11 @@ impl Timeline {
ControlFlow::Continue(()) => (),
}
let elapsed = start.elapsed();
crate::tenant::tasks::warn_when_period_overrun(elapsed, p.period, "eviction");
crate::tenant::tasks::warn_when_period_overrun(
elapsed,
p.period,
BackgroundLoopKind::Eviction,
);
crate::metrics::EVICTION_ITERATION_DURATION
.get_metric_with_label_values(&[
&format!("{}", p.period.as_secs()),
@@ -150,6 +155,17 @@ impl Timeline {
) -> ControlFlow<()> {
let now = SystemTime::now();
let _permit = match crate::tenant::tasks::concurrent_background_tasks_rate_limit(
BackgroundLoopKind::Eviction,
ctx,
cancel,
)
.await
{
Ok(permit) => permit,
Err(RateLimitError::Cancelled) => return ControlFlow::Break(()),
};
// If we evict layers but keep cached values derived from those layers, then
// we face a storm of on-demand downloads after pageserver restart.
// The reason is that the restart empties the caches, and so, the values
@@ -285,6 +301,10 @@ impl Timeline {
warn!(layer = %l, "failed to evict layer: {e}");
stats.not_evictable += 1;
}
Some(Err(EvictionError::MetadataInconsistency(detail))) => {
warn!(layer = %l, "failed to evict layer: {detail}");
stats.not_evictable += 1;
}
}
}
if stats.candidates == stats.not_evictable {

View File

@@ -203,6 +203,18 @@ impl UploadQueue {
UploadQueue::Stopped(stopped) => Ok(stopped),
}
}
pub(crate) fn get_layer_metadata(
&self,
name: &LayerFileName,
) -> anyhow::Result<Option<LayerFileMetadata>> {
match self {
UploadQueue::Stopped(_) | UploadQueue::Uninitialized => {
anyhow::bail!("queue is in state {}", self.as_str())
}
UploadQueue::Initialized(inner) => Ok(inner.latest_files.get(name).cloned()),
}
}
}
/// An in-progress upload or delete task.

View File

@@ -18,11 +18,13 @@
//! any WAL records, so that even if an attacker hijacks the Postgres
//! process, he cannot escape out of it.
//!
use anyhow::Context;
use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::io;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
@@ -33,14 +35,13 @@ use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
use std::time::Duration;
use std::time::Instant;
use std::{fs, io};
use tracing::*;
use utils::crashsafe::path_with_suffix_extension;
use utils::{bin_ser::BeSer, id::TenantId, lsn::Lsn, nonblock::set_nonblock};
#[cfg(feature = "testing")]
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::config::PageServerConf;
use crate::metrics::{
WAL_REDO_BYTES_HISTOGRAM, WAL_REDO_RECORDS_HISTOGRAM, WAL_REDO_RECORD_COUNTER, WAL_REDO_TIME,
WAL_REDO_WAIT_TIME,
@@ -49,7 +50,6 @@ use crate::pgdatadir_mapping::{key_to_rel_block, key_to_slru_block};
use crate::repository::Key;
use crate::task_mgr::BACKGROUND_RUNTIME;
use crate::walrecord::NeonWalRecord;
use crate::{config::PageServerConf, TEMP_FILE_SUFFIX};
use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM;
@@ -66,7 +66,7 @@ use postgres_ffi::BLCKSZ;
/// [See more related comments here](https://github.com/postgres/postgres/blob/99c5852e20a0987eca1c38ba0c09329d4076b6a0/src/include/storage/buf_internals.h#L91).
///
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize)]
pub struct BufferTag {
pub(crate) struct BufferTag {
pub rel: RelTag,
pub blknum: u32,
}
@@ -89,7 +89,7 @@ pub trait WalRedoManager: Send + Sync {
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
) -> Result<Bytes, WalRedoError>;
) -> anyhow::Result<Bytes>;
}
struct ProcessInput {
@@ -140,20 +140,6 @@ fn can_apply_in_neon(rec: &NeonWalRecord) -> bool {
}
}
/// An error happened in WAL redo
#[derive(Debug, thiserror::Error)]
pub enum WalRedoError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("cannot perform WAL redo now")]
InvalidState,
#[error("cannot perform WAL redo for this request")]
InvalidRequest,
#[error("cannot perform WAL redo for this record")]
InvalidRecord,
}
///
/// Public interface of WAL redo manager
///
@@ -171,10 +157,9 @@ impl WalRedoManager for PostgresRedoManager {
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
) -> anyhow::Result<Bytes> {
if records.is_empty() {
error!("invalid WAL redo request with no records");
return Err(WalRedoError::InvalidRequest);
anyhow::bail!("invalid WAL redo request with no records");
}
let base_img_lsn = base_img.as_ref().map(|p| p.0).unwrap_or(Lsn::INVALID);
@@ -238,15 +223,6 @@ impl PostgresRedoManager {
}
}
/// Launch process pre-emptively. Should not be needed except for benchmarking.
pub fn launch_process(&self, pg_version: u32) -> anyhow::Result<()> {
let mut proc = self.stdin.lock().unwrap();
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
Ok(())
}
///
/// Process one request for WAL redo using wal-redo postgres
///
@@ -260,8 +236,8 @@ impl PostgresRedoManager {
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
) -> anyhow::Result<Bytes> {
let (rel, blknum) = key_to_rel_block(key).context("invalid record")?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
@@ -271,7 +247,8 @@ impl PostgresRedoManager {
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
self.launch(&mut proc, pg_version)
.context("launch process")?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
@@ -279,7 +256,7 @@ impl PostgresRedoManager {
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
.context("apply_wal_records");
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
@@ -309,15 +286,15 @@ impl PostgresRedoManager {
// next request will launch a new one.
if let Err(e) = result.as_ref() {
error!(
n_attempts,
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}: {}",
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {} n_attempts={}: {:?}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
records.last().map(|p| p.0).unwrap_or(Lsn(0)),
nbytes,
base_img_lsn,
lsn,
utils::error::report_compact_sources(e),
n_attempts,
e,
);
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
@@ -354,7 +331,7 @@ impl PostgresRedoManager {
lsn: Lsn,
base_img: Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
) -> Result<Bytes, WalRedoError> {
) -> anyhow::Result<Bytes> {
let start_time = Instant::now();
let mut page = BytesMut::new();
@@ -363,8 +340,7 @@ impl PostgresRedoManager {
page.extend_from_slice(&fpi[..]);
} else {
// All the current WAL record types that we can handle require a base image.
error!("invalid neon WAL redo request with no base image");
return Err(WalRedoError::InvalidRequest);
anyhow::bail!("invalid neon WAL redo request with no base image");
}
// Apply all the WAL records in the batch
@@ -392,14 +368,13 @@ impl PostgresRedoManager {
page: &mut BytesMut,
_record_lsn: Lsn,
record: &NeonWalRecord,
) -> Result<(), WalRedoError> {
) -> anyhow::Result<()> {
match record {
NeonWalRecord::Postgres {
will_init: _,
rec: _,
} => {
error!("tried to pass postgres wal record to neon WAL redo");
return Err(WalRedoError::InvalidRequest);
anyhow::bail!("tried to pass postgres wal record to neon WAL redo");
}
NeonWalRecord::ClearVisibilityMapFlags {
new_heap_blkno,
@@ -407,7 +382,7 @@ impl PostgresRedoManager {
flags,
} => {
// sanity check that this is modifying the correct relation
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
let (rel, blknum) = key_to_rel_block(key).context("invalid record")?;
assert!(
rel.forknum == VISIBILITYMAP_FORKNUM,
"ClearVisibilityMapFlags record on unexpected rel {}",
@@ -445,7 +420,7 @@ impl PostgresRedoManager {
// same effects as the corresponding Postgres WAL redo function.
NeonWalRecord::ClogSetCommitted { xids, timestamp } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
key_to_slru_block(key).context("invalid record")?;
assert_eq!(
slru_kind,
SlruKind::Clog,
@@ -495,7 +470,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::ClogSetAborted { xids } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
key_to_slru_block(key).context("invalid record")?;
assert_eq!(
slru_kind,
SlruKind::Clog,
@@ -526,7 +501,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::MultixactOffsetCreate { mid, moff } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
key_to_slru_block(key).context("invalid record")?;
assert_eq!(
slru_kind,
SlruKind::MultiXactOffsets,
@@ -559,7 +534,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::MultixactMembersCreate { moff, members } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
key_to_slru_block(key).context("invalid record")?;
assert_eq!(
slru_kind,
SlruKind::MultiXactMembers,
@@ -649,26 +624,6 @@ impl PostgresRedoManager {
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> Result<(), Error> {
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
Error::new(
e.kind(),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
@@ -759,7 +714,7 @@ impl PostgresRedoManager {
base_img: &Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
) -> anyhow::Result<Bytes> {
// Serialize all the messages to send the WAL redo process first.
//
// This could be problematic if there are millions of records to replay,
@@ -782,10 +737,7 @@ impl PostgresRedoManager {
{
build_apply_record_msg(*lsn, postgres_rec, &mut writebuf);
} else {
return Err(Error::new(
ErrorKind::Other,
"tried to pass neon wal record to postgres WAL redo",
));
anyhow::bail!("tried to pass neon wal record to postgres WAL redo");
}
}
build_get_page_msg(tag, &mut writebuf);
@@ -807,7 +759,7 @@ impl PostgresRedoManager {
writebuf: &[u8],
mut input: MutexGuard<Option<ProcessInput>>,
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
) -> anyhow::Result<Bytes> {
let proc = input.as_mut().unwrap();
let mut nwrite = 0usize;
let stdout_fd = proc.stdout_fd;
@@ -831,7 +783,7 @@ impl PostgresRedoManager {
}?;
if n == 0 {
return Err(Error::new(ErrorKind::Other, "WAL redo timed out"));
anyhow::bail!("WAL redo timed out");
}
// If we have some messages in stderr, forward them to the log.
@@ -855,10 +807,7 @@ impl PostgresRedoManager {
continue;
}
} else if err_revents.contains(PollFlags::POLLHUP) {
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stderr unexpectedly",
));
anyhow::bail!("WAL redo process closed its stderr unexpectedly");
}
// If 'stdin' is writeable, do write.
@@ -867,10 +816,7 @@ impl PostgresRedoManager {
nwrite += proc.stdin.write(&writebuf[nwrite..])?;
} else if in_revents.contains(PollFlags::POLLHUP) {
// We still have more data to write, but the process closed the pipe.
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdin unexpectedly",
));
anyhow::bail!("WAL redo process closed its stdin unexpectedly");
}
}
let request_no = proc.n_requests;
@@ -901,10 +847,7 @@ impl PostgresRedoManager {
//
// Cross-read this with the comment in apply_batch_postgres if result.is_err().
// That's where we kill the child process.
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdout unexpectedly",
));
anyhow::bail!("WAL redo process closed its stdout unexpectedly");
}
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
@@ -923,7 +866,7 @@ impl PostgresRedoManager {
}?;
if n == 0 {
return Err(Error::new(ErrorKind::Other, "WAL redo timed out"));
anyhow::bail!("WAL redo timed out");
}
// If we have some messages in stderr, forward them to the log.
@@ -947,10 +890,7 @@ impl PostgresRedoManager {
continue;
}
} else if err_revents.contains(PollFlags::POLLHUP) {
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stderr unexpectedly",
));
anyhow::bail!("WAL redo process closed its stderr unexpectedly");
}
// If we have some data in stdout, read it to the result buffer.
@@ -958,10 +898,7 @@ impl PostgresRedoManager {
if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
nresult += output.stdout.read(&mut resultbuf[nresult..])?;
} else if out_revents.contains(PollFlags::POLLHUP) {
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdout unexpectedly",
));
anyhow::bail!("WAL redo process closed its stdout unexpectedly");
}
}
output

View File

@@ -741,13 +741,6 @@ NeonProcessUtility(
break;
case T_DropdbStmt:
HandleDropDb(castNode(DropdbStmt, parseTree));
/*
* We do this here to hack around the fact that Postgres performs the drop
* INSIDE of standard_ProcessUtility, which means that if we try to
* abort the drop normally it'll be too late. DROP DATABASE can't be inside
* of a transaction block anyway, so this should be fine to do.
*/
NeonXactCallback(XACT_EVENT_PRE_COMMIT, NULL);
break;
case T_CreateRoleStmt:
HandleCreateRole(castNode(CreateRoleStmt, parseTree));

View File

@@ -1,5 +1,6 @@
use futures::future::Either;
use proxy::auth;
use proxy::config::HttpConfig;
use proxy::console;
use proxy::http;
use proxy::metrics;
@@ -79,6 +80,9 @@ struct ProxyCliArgs {
/// Allow self-signed certificates for compute nodes (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
allow_self_signed_compute: bool,
/// timeout for http connections
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
sql_over_http_timeout: tokio::time::Duration,
}
#[tokio::main]
@@ -220,12 +224,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
auth::BackendType::Link(Cow::Owned(url))
}
};
let http_config = HttpConfig {
sql_over_http_timeout: args.sql_over_http_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
}));
Ok(config)

View File

@@ -13,6 +13,7 @@ pub struct ProxyConfig {
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
}
#[derive(Debug)]
@@ -26,6 +27,10 @@ pub struct TlsConfig {
pub common_names: Option<HashSet<String>>,
}
pub struct HttpConfig {
pub sql_over_http_timeout: tokio::time::Duration,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()

View File

@@ -20,7 +20,7 @@ use tokio_postgres::AsyncMessage;
use crate::{
auth, console,
metrics::{Ids, MetricCounter, USAGE_METRICS},
proxy::{NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER},
proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER},
};
use crate::{compute, config};
@@ -139,6 +139,7 @@ impl GlobalConnPool {
session_id: uuid::Uuid,
) -> anyhow::Result<Client> {
let mut client: Option<Client> = None;
let mut latency_timer = LatencyTimer::new("http");
let mut hash_valid = false;
if !force_new {
@@ -182,15 +183,16 @@ impl GlobalConnPool {
let new_client = if let Some(client) = client {
if client.inner.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
connect_to_compute(self.proxy_config, conn_info, session_id).await
connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
} else {
latency_timer.pool_hit();
info!("pool: reusing connection '{conn_info}'");
client.session.send(session_id)?;
return Ok(client);
}
} else {
info!("pool: opening a new connection '{conn_info}'");
connect_to_compute(self.proxy_config, conn_info, session_id).await
connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
};
match &new_client {
@@ -347,6 +349,7 @@ async fn connect_to_compute(
config: &config::ProxyConfig,
conn_info: &ConnInfo,
session_id: uuid::Uuid,
latency_timer: LatencyTimer,
) -> anyhow::Result<Client> {
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
@@ -386,6 +389,7 @@ async fn connect_to_compute(
node_info,
&extra,
&creds,
latency_timer,
)
.await
}

View File

@@ -24,6 +24,7 @@ use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::config::HttpConfig;
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
use super::conn_pool::ConnInfo;
@@ -49,7 +50,6 @@ enum Payload {
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MiB
const MAX_REQUEST_SIZE: u64 = 10 * 1024 * 1024; // 10 MiB
const HTTP_CONNECTION_TIMEOUT: tokio::time::Duration = tokio::time::Duration::from_secs(15);
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
@@ -102,9 +102,9 @@ fn json_array_to_pg_array(value: &Value) -> Result<Option<String>, serde_json::E
// convert to text with escaping
Value::Bool(_) => serde_json::to_string(value).map(Some),
Value::Number(_) => serde_json::to_string(value).map(Some),
Value::Object(_) => serde_json::to_string(value).map(Some),
// here string needs to be escaped, as it is part of the array
Value::Object(_) => json_array_to_pg_array(&Value::String(serde_json::to_string(value)?)),
Value::String(_) => serde_json::to_string(value).map(Some),
// recurse into array
@@ -191,9 +191,10 @@ pub async fn handle(
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
session_id: uuid::Uuid,
config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
HTTP_CONNECTION_TIMEOUT,
config.sql_over_http_timeout,
handle_inner(request, sni_hostname, conn_pool, session_id),
)
.await;
@@ -224,7 +225,7 @@ pub async fn handle(
Err(_) => {
let message = format!(
"HTTP-Connection timed out, execution time exeeded {} seconds",
HTTP_CONNECTION_TIMEOUT.as_secs()
config.sql_over_http_timeout.as_secs()
);
error!(message);
json_response(
@@ -612,7 +613,7 @@ fn _pg_array_parse(
}
}
}
'}' => {
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
@@ -696,6 +697,14 @@ mod tests {
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]).unwrap();
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
);
}
#[test]
@@ -823,4 +832,23 @@ mod tests {
json!([[[1, 2, 3], [4, 5, 6]]])
);
}
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(
pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
json!([{"foo": 1, "bar": 2}])
);
assert_eq!(
pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
json!([{"foo": 1}, {"bar": 2}])
);
assert_eq!(
pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
json!([[{"foo": 1}, {"bar": 2}]])
);
}
}

View File

@@ -205,7 +205,14 @@ async fn ws_handler(
// TODO: that deserves a refactor as now this function also handles http json client besides websockets.
// Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
sql_over_http::handle(request, sni_hostname, conn_pool, session_id).await
sql_over_http::handle(
request,
sni_hostname,
conn_pool,
session_id,
&config.http_config,
)
.await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")

View File

@@ -15,12 +15,11 @@ use crate::{
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use metrics::{
exponential_buckets, register_histogram, register_int_counter_vec, Histogram, IntCounterVec,
};
use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::{error::Error, io, ops::ControlFlow, sync::Arc};
use prometheus::{register_histogram_vec, HistogramVec};
use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
@@ -93,16 +92,57 @@ pub static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
static COMPUTE_CONNECTION_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"proxy_compute_connection_latency_seconds",
"Time it took for proxy to establish a connection to the compute endpoint",
&["protocol", "cache_miss", "pool_miss"],
// largest bucket = 2^16 * 0.5ms = 32s
exponential_buckets(0.0005, 2.0, 16).unwrap(),
)
.unwrap()
});
pub struct LatencyTimer {
start: Instant,
pool_miss: bool,
cache_miss: bool,
protocol: &'static str,
}
impl LatencyTimer {
pub fn new(protocol: &'static str) -> Self {
Self {
start: Instant::now(),
cache_miss: false,
// by default we don't do pooling
pool_miss: true,
protocol,
}
}
pub fn cache_miss(&mut self) {
self.cache_miss = true;
}
pub fn pool_hit(&mut self) {
self.pool_miss = false;
}
}
impl Drop for LatencyTimer {
fn drop(&mut self) {
let duration = self.start.elapsed().as_secs_f64();
COMPUTE_CONNECTION_LATENCY
.with_label_values(&[
self.protocol,
bool_to_str(self.cache_miss),
bool_to_str(self.pool_miss),
])
.observe(duration)
}
}
static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_total",
@@ -495,13 +535,12 @@ pub async fn connect_to_compute<M: ConnectMechanism>(
mut node_info: console::CachedNodeInfo,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
mut latency_timer: LatencyTimer,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let _timer = COMPUTE_CONNECTION_LATENCY.start_timer();
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -513,6 +552,8 @@ where
}
};
latency_timer.cache_miss();
let mut num_retries = 1;
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
@@ -789,6 +830,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
application_name: params.get("application_name"),
};
let latency_timer = LatencyTimer::new(mode.protocol_label());
let auth_result = match creds
.authenticate(&extra, &mut stream, mode.allow_cleartext())
.await
@@ -812,9 +855,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
let mut node = connect_to_compute(
&TcpMechanism { params },
node_info,
&extra,
&creds,
latency_timer,
)
.or_else(|e| stream.throw_error(e))
.await?;
let proto = mode.protocol_label();
NUM_DB_CONNECTIONS_OPENED_COUNTER

View File

@@ -450,7 +450,7 @@ async fn connect_to_compute_success() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap();
mechanism.verify();
@@ -461,7 +461,7 @@ async fn connect_to_compute_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap();
mechanism.verify();
@@ -473,7 +473,7 @@ async fn connect_to_compute_non_retry_1() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap_err();
mechanism.verify();
@@ -485,7 +485,7 @@ async fn connect_to_compute_non_retry_2() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap();
mechanism.verify();
@@ -501,7 +501,7 @@ async fn connect_to_compute_non_retry_3() {
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap_err();
mechanism.verify();
@@ -513,7 +513,7 @@ async fn wake_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap();
mechanism.verify();
@@ -525,7 +525,7 @@ async fn wake_non_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds)
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
.await
.unwrap_err();
mechanism.verify();

View File

@@ -65,7 +65,7 @@ def start_heavy_write_workload(env: PgCompare, n_tables: int, scale: int, num_it
def start_single_table_workload(table_id: int):
for _ in range(num_iters):
with env.pg.connect().cursor() as cur:
with env.pg.connect(options="-cstatement_timeout=300s").cursor() as cur:
cur.execute(
f"INSERT INTO t{table_id} SELECT FROM generate_series(1,{new_rows_each_update})"
)

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import psycopg2
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import VanillaPostgres
from fixtures.neon_fixtures import NeonEnv, VanillaPostgres
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response
@@ -205,6 +205,10 @@ def test_ddl_forwarding(ddl: DdlForwardingContext):
ddl.wait()
assert ddl.dbs == {"stork": "cork"}
cur.execute("DROP DATABASE stork")
ddl.wait()
assert ddl.dbs == {}
with pytest.raises(psycopg2.InternalError):
ddl.failures(True)
cur.execute("CREATE DATABASE failure WITH OWNER=cork")
@@ -217,6 +221,94 @@ def test_ddl_forwarding(ddl: DdlForwardingContext):
ddl.failures(True)
cur.execute("DROP DATABASE failure")
ddl.wait()
ddl.pg.connect(dbname="failure") # Ensure we can connect after a failed drop
assert ddl.dbs == {"failure": "cork"}
ddl.failures(False)
# Check that db is still in the Postgres after failure
cur.execute("SELECT datconnlimit FROM pg_database WHERE datname = 'failure'")
result = cur.fetchone()
if not result:
raise AssertionError("Database 'failure' not found")
# -2 means invalid database
# It should be invalid because cplane request failed
assert result[0] == -2, "Database 'failure' is not invalid"
# Check that repeated drop succeeds
cur.execute("DROP DATABASE failure")
ddl.wait()
assert ddl.dbs == {}
# DB should be absent in the Postgres
cur.execute("SELECT count(*) FROM pg_database WHERE datname = 'failure'")
result = cur.fetchone()
if not result:
raise AssertionError("Could not count databases")
assert result[0] == 0, "Database 'failure' still exists after drop"
conn.close()
# Assert that specified database has a specific connlimit, throwing an AssertionError otherwise
# -2 means invalid database
# -1 means no specific per-db limit (default)
def assert_db_connlimit(endpoint: Any, db_name: str, connlimit: int, msg: str):
with endpoint.cursor() as cur:
cur.execute("SELECT datconnlimit FROM pg_database WHERE datname = %s", (db_name,))
result = cur.fetchone()
if not result:
raise AssertionError(f"Database '{db_name}' not found")
assert result[0] == connlimit, msg
# Test that compute_ctl can deal with invalid databases (drop them).
# If Postgres extension cannot reach cplane, then DROP will be aborted
# and database will be marked as invalid. Then there are two recovery
# flows:
# 1. User can just repeat DROP DATABASE command until it succeeds
# 2. User can ignore, then compute_ctl will drop invalid databases
# automatically during full configuration
# Here we test the latter. The first one is tested in test_ddl_forwarding
def test_ddl_forwarding_invalid_db(neon_simple_env: NeonEnv):
env = neon_simple_env
env.neon_cli.create_branch("test_ddl_forwarding_invalid_db", "empty")
endpoint = env.endpoints.create_start(
"test_ddl_forwarding_invalid_db",
# Some non-existent url
config_lines=["neon.console_url=http://localhost:9999/unknown/api/v0/roles_and_databases"],
)
log.info("postgres is running on 'test_ddl_forwarding_invalid_db' branch")
with endpoint.cursor() as cur:
cur.execute("SET neon.forward_ddl = false")
cur.execute("CREATE DATABASE failure")
cur.execute("COMMIT")
assert_db_connlimit(
endpoint, "failure", -1, "Database 'failure' doesn't have a valid connlimit"
)
with pytest.raises(psycopg2.InternalError):
with endpoint.cursor() as cur:
cur.execute("DROP DATABASE failure")
cur.execute("COMMIT")
# Should be invalid after failed drop
assert_db_connlimit(endpoint, "failure", -2, "Database 'failure' ins't invalid")
endpoint.stop()
endpoint.start()
# Still invalid after restart without full configuration
assert_db_connlimit(endpoint, "failure", -2, "Database 'failure' ins't invalid")
endpoint.stop()
endpoint.respec(skip_pg_catalog_updates=False)
endpoint.start()
# Should be cleaned up by compute_ctl during full configuration
with endpoint.cursor() as cur:
cur.execute("SELECT count(*) FROM pg_database WHERE datname = 'failure'")
result = cur.fetchone()
if not result:
raise AssertionError("Could not count databases")
assert result[0] == 0, "Database 'failure' still exists after restart"

View File

@@ -104,6 +104,22 @@ def generate_uploads_and_deletions(
assert gc_result["layers_removed"] > 0
def read_all(
env: NeonEnv, tenant_id: Optional[TenantId] = None, timeline_id: Optional[TimelineId] = None
):
if tenant_id is None:
tenant_id = env.initial_tenant
assert tenant_id is not None
if timeline_id is None:
timeline_id = env.initial_timeline
assert timeline_id is not None
env.pageserver.http_client()
with env.endpoints.create_start("main", tenant_id=tenant_id) as endpoint:
endpoint.safe_psql("SELECT SUM(LENGTH(val)) FROM foo;")
def get_metric_or_0(ps_http, metric: str) -> int:
v = ps_http.get_metric_value(metric)
return 0 if v is None else int(v)
@@ -467,3 +483,48 @@ def test_emergency_mode(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
assert get_deletion_queue_depth(ps_http) == 0
assert get_deletion_queue_validated(ps_http) > 0
assert get_deletion_queue_executed(ps_http) > 0
def evict_all_layers(env: NeonEnv, tenant_id: TenantId, timeline_id: TimelineId):
timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id)
initial_local_layers = sorted(
list(filter(lambda path: path.name != "metadata", timeline_path.glob("*")))
)
client = env.pageserver.http_client()
for layer in initial_local_layers:
if "ephemeral" in layer.name:
continue
log.info(f"Evicting layer {tenant_id}/{timeline_id} {layer.name}")
client.evict_layer(tenant_id=tenant_id, timeline_id=timeline_id, layer_name=layer.name)
def test_eviction_across_generations(neon_env_builder: NeonEnvBuilder):
"""
Eviction and on-demand downloads exercise a particular code path where RemoteLayer is constructed
and must be constructed using the proper generation for the layer, which may not be the same generation
that the tenant is running in.
"""
neon_env_builder.enable_generations = True
neon_env_builder.enable_pageserver_remote_storage(
RemoteStorageKind.MOCK_S3,
)
env = neon_env_builder.init_start(initial_tenant_conf=TENANT_CONF)
env.pageserver.http_client()
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
generate_uploads_and_deletions(env)
read_all(env, tenant_id, timeline_id)
evict_all_layers(env, tenant_id, timeline_id)
read_all(env, tenant_id, timeline_id)
# This will cause the generation to increment
env.pageserver.stop()
env.pageserver.start()
# Now we are running as generation 2, but must still correctly remember that the layers
# we are evicting and downloading are from generation 1.
read_all(env, tenant_id, timeline_id)
evict_all_layers(env, tenant_id, timeline_id)
read_all(env, tenant_id, timeline_id)

View File

@@ -4,6 +4,7 @@ import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder
from fixtures.remote_storage import s3_storage
from fixtures.utils import wait_until
# Test restarting page server, while safekeeper and compute node keep
@@ -16,8 +17,7 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder, generations: bool)
env = neon_env_builder.init_start()
env.neon_cli.create_branch("test_pageserver_restart")
endpoint = env.endpoints.create_start("test_pageserver_restart")
endpoint = env.endpoints.create_start("main")
pageserver_http = env.pageserver.http_client()
pg_conn = endpoint.connect()
@@ -75,27 +75,52 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder, generations: bool)
cur.execute("SELECT count(*) FROM foo")
assert cur.fetchone() == (100000,)
# Validate startup time metrics
metrics = pageserver_http.get_metrics()
# Wait for metrics to indicate startup complete, so that we can know all
# startup phases will be reflected in the subsequent checks
def assert_complete():
for sample in pageserver_http.get_metrics().query_all(
"pageserver_startup_duration_seconds"
):
labels = dict(sample.labels)
log.info(f"metric {labels['phase']}={sample.value}")
if labels["phase"] == "complete" and sample.value > 0:
return
raise AssertionError("No 'complete' metric yet")
wait_until(30, 1.0, assert_complete)
# Expectation callbacks: arg t is sample value, arg p is the previous phase's sample value
expectations = {
"initial": lambda t, p: True, # make no assumptions about the initial time point, it could be 0 in theory
expectations = [
(
"initial",
lambda t, p: True,
), # make no assumptions about the initial time point, it could be 0 in theory
# Remote phase of initial_tenant_load should happen before overall phase is complete
("initial_tenant_load_remote", lambda t, p: t >= 0.0 and t >= p),
# Initial tenant load should reflect the delay we injected
"initial_tenant_load": lambda t, p: t >= (tenant_load_delay_ms / 1000.0) and t >= p,
("initial_tenant_load", lambda t, p: t >= (tenant_load_delay_ms / 1000.0) and t >= p),
# Subsequent steps should occur in expected order
"initial_logical_sizes": lambda t, p: t > 0 and t >= p,
"background_jobs_can_start": lambda t, p: t > 0 and t >= p,
"complete": lambda t, p: t > 0 and t >= p,
}
("initial_logical_sizes", lambda t, p: t > 0 and t >= p),
("background_jobs_can_start", lambda t, p: t > 0 and t >= p),
("complete", lambda t, p: t > 0 and t >= p),
]
# Accumulate the runtime of each startup phase
values = {}
metrics = pageserver_http.get_metrics()
prev_value = None
for sample in metrics.query_all("pageserver_startup_duration_seconds"):
labels = dict(sample.labels)
phase = labels["phase"]
phase = sample.labels["phase"]
log.info(f"metric {phase}={sample.value}")
assert phase in expectations, f"Unexpected phase {phase}"
assert expectations[phase](
assert phase in [e[0] for e in expectations], f"Unexpected phase {phase}"
values[phase] = sample
# Apply expectations to the metrics retrieved
for phase, expectation in expectations:
assert phase in values, f"No data for phase {phase}"
sample = values[phase]
assert expectation(
sample.value, prev_value
), f"Unexpected value for {phase}: {sample.value}"
prev_value = sample.value

View File

@@ -188,7 +188,7 @@ def test_sql_over_http(static_proxy: NeonProxy):
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
verify=str(static_proxy.test_output_dir / "proxy.crt"),
)
assert response.status_code == 200
assert response.status_code == 200, response.text
return response.json()
rows = q("select 42 as answer")["rows"]
@@ -206,6 +206,12 @@ def test_sql_over_http(static_proxy: NeonProxy):
rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
assert rows == [{"answer": {"b": 42}}]
rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
assert rows == [{"answer": [{}]}]
rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]
rows = q("select * from pg_class limit 1")["rows"]
assert len(rows) == 1

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from typing import List, Optional
import asyncpg
import pytest
import toml
from fixtures.log_helper import getLogger
from fixtures.neon_fixtures import Endpoint, NeonEnv, NeonEnvBuilder, Safekeeper
@@ -597,7 +598,10 @@ async def run_wal_lagging(env: NeonEnv, endpoint: Endpoint, test_output_dir: Pat
assert res == expected_sum
# do inserts while restarting postgres and messing with safekeeper addresses
# Do inserts while restarting postgres and messing with safekeeper addresses.
# The test takes more than default 5 minutes on Postgres 16,
# see https://github.com/neondatabase/neon/issues/5305
@pytest.mark.timeout(600)
def test_wal_lagging(neon_env_builder: NeonEnvBuilder, test_output_dir: Path):
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.init_start()

View File

@@ -31,13 +31,14 @@ futures = { version = "0.3" }
futures-channel = { version = "0.3", features = ["sink"] }
futures-core = { version = "0.3" }
futures-executor = { version = "0.3" }
futures-io = { version = "0.3" }
futures-sink = { version = "0.3" }
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
hex = { version = "0.4", features = ["serde"] }
hyper = { version = "0.14", features = ["full"] }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] }
log = { version = "0.4", default-features = false, features = ["std"] }
log = { version = "0.4", default-features = false, features = ["kv_unstable", "std"] }
memchr = { version = "2" }
nom = { version = "7" }
num-bigint = { version = "0.4" }
@@ -47,7 +48,7 @@ prost = { version = "0.11" }
rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
regex-syntax = { version = "0.7" }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "multipart", "rustls-tls"] }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "default-tls", "json", "multipart", "rustls-tls", "stream"] }
ring = { version = "0.16", features = ["std"] }
rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
@@ -55,7 +56,8 @@ serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
socket2 = { version = "0.4", default-features = false, features = ["all"] }
time = { version = "0.3", features = ["formatting", "macros", "parsing"] }
standback = { version = "0.2", default-features = false, features = ["std"] }
time = { version = "0.3", features = ["macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
tokio-rustls = { version = "0.24" }
tokio-util = { version = "0.7", features = ["codec", "io"] }
@@ -75,14 +77,16 @@ cc = { version = "1", default-features = false, features = ["parallel"] }
either = { version = "1" }
itertools = { version = "0.10" }
libc = { version = "0.2", features = ["extra_traits"] }
log = { version = "0.4", default-features = false, features = ["std"] }
log = { version = "0.4", default-features = false, features = ["kv_unstable", "std"] }
memchr = { version = "2" }
nom = { version = "7" }
prost = { version = "0.11" }
regex = { version = "1" }
regex-syntax = { version = "0.7" }
serde = { version = "1", features = ["alloc", "derive"] }
standback = { version = "0.2", default-features = false, features = ["std"] }
syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit"] }
syn-f595c2ba2a3f28df = { package = "syn", version = "2", features = ["extra-traits", "full", "visit", "visit-mut"] }
time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] }
### END HAKARI SECTION