Compare commits

..

8 Commits

Author SHA1 Message Date
Konstantin Knizhnik
773a8836cc Update pgxn/neon/file_cache.c
Co-authored-by: Anastasia Lubennikova <anastasia@neon.tech>
2023-10-30 19:14:00 +02:00
Konstantin Knizhnik
da34a9ed89 Add neon_stats view 2023-10-12 09:44:15 +03:00
Konstantin Knizhnik
e4b74e375c Remove wrong comment 2023-10-11 08:48:22 +03:00
Konstantin Knizhnik
c3d1e26361 Add missed neon--1.0--1.1.sql file 2023-10-11 08:40:53 +03:00
Konstantin Knizhnik
665eff5e08 Show information about local file cache in EXPLAIN ANALYZE 2023-10-10 10:11:30 +03:00
Konstantin Knizhnik
7aeec2bc9d Apply review suggestions 2023-10-09 22:23:28 +03:00
Konstantin Knizhnik
bc3a5d571f Use popcount32 to calculate number of used pages 2023-10-09 22:07:10 +03:00
Konstantin Knizhnik
2324e5a3d3 Rebased version of LFC fixes 2023-10-08 15:02:42 +03:00
142 changed files with 3440 additions and 9860 deletions

View File

@@ -320,9 +320,6 @@ jobs:
- name: Build neon extensions
run: mold -run make neon-pg-ext -j$(nproc)
- name: Build walproposer-lib
run: mold -run make walproposer-lib -j$(nproc)
- name: Run cargo build
run: |
${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES --bins --tests
@@ -338,16 +335,6 @@ jobs:
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3
# Run separate tests for real Azure Blob Storage
# XXX: replace region with `eu-central-1`-like region
export AZURE_STORAGE_ACCOUNT="${{ secrets.AZURE_STORAGE_ACCOUNT_DEV }}"
export AZURE_STORAGE_ACCESS_KEY="${{ secrets.AZURE_STORAGE_ACCESS_KEY_DEV }}"
export ENABLE_REAL_AZURE_REMOTE_STORAGE=y
export REMOTE_STORAGE_AZURE_CONTAINER=neon-github-sandbox
export REMOTE_STORAGE_AZURE_REGION=eastus2
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_azure
- name: Install rust binaries
run: |
# Install target binaries
@@ -847,7 +834,7 @@ jobs:
run:
shell: sh -eu {0}
env:
VM_BUILDER_VERSION: v0.18.2
VM_BUILDER_VERSION: v0.17.12
steps:
- name: Checkout

View File

@@ -32,7 +32,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
@@ -90,21 +90,18 @@ jobs:
- name: Build postgres v14
if: steps.cache_pg_14.outputs.cache-hit != 'true'
run: make postgres-v14 -j$(sysctl -n hw.ncpu)
run: make postgres-v14 -j$(nproc)
- name: Build postgres v15
if: steps.cache_pg_15.outputs.cache-hit != 'true'
run: make postgres-v15 -j$(sysctl -n hw.ncpu)
run: make postgres-v15 -j$(nproc)
- name: Build postgres v16
if: steps.cache_pg_16.outputs.cache-hit != 'true'
run: make postgres-v16 -j$(sysctl -n hw.ncpu)
run: make postgres-v16 -j$(nproc)
- name: Build neon extensions
run: make neon-pg-ext -j$(sysctl -n hw.ncpu)
- name: Build walproposer-lib
run: make walproposer-lib -j$(sysctl -n hw.ncpu)
run: make neon-pg-ext -j$(nproc)
- name: Run cargo build
run: cargo build --all --release
@@ -129,7 +126,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 1
@@ -138,9 +135,6 @@ jobs:
- name: Get postgres headers
run: make postgres-headers -j$(nproc)
- name: Build walproposer-lib
run: make walproposer-lib -j$(nproc)
- name: Produce the build stats
run: cargo build --all --release --timings

1047
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -26,7 +26,6 @@ members = [
"libs/tracing-utils",
"libs/postgres_ffi/wal_craft",
"libs/vm_monitor",
"libs/walproposer",
]
[workspace.package]
@@ -37,10 +36,6 @@ license = "Apache-2.0"
[workspace.dependencies]
anyhow = { version = "1.0", features = ["backtrace"] }
async-compression = { version = "0.4.0", features = ["tokio", "gzip"] }
azure_core = "0.16"
azure_identity = "0.16"
azure_storage = "0.16"
azure_storage_blobs = "0.16"
flate2 = "1.0.26"
async-stream = "0.3"
async-trait = "0.1"
@@ -81,7 +76,6 @@ hex = "0.4"
hex-literal = "0.4"
hmac = "0.12.1"
hostname = "0.3.1"
http-types = "2"
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
@@ -161,11 +155,11 @@ env_logger = "0.10"
log = "0.4"
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
## Other git libraries
heapless = { default-features=false, features=[], git = "https://github.com/japaric/heapless.git", rev = "644653bf3b831c6bb4963be2de24804acf5e5001" } # upstream release pending
@@ -186,7 +180,6 @@ tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
@@ -202,7 +195,7 @@ tonic-build = "0.9"
# This is only needed for proxy's tests.
# TODO: we should probably fork `tokio-postgres-rustls` instead.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="7434d9388965a17a6d113e5dfc0e65666a03b4c2" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
################# Binary contents sections

View File

@@ -224,8 +224,8 @@ RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -
FROM build-deps AS vector-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.1.tar.gz -O pgvector.tar.gz && \
echo "cc7a8e034a96e30a819911ac79d32f6bc47bdd1aa2de4d7d4904e26b83209dc8 pgvector.tar.gz" | sha256sum --check && \
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.0.tar.gz -O pgvector.tar.gz && \
echo "d8aa3504b215467ca528525a6de12c3f85f9891b091ce0e5864dd8a9b757f77b pgvector.tar.gz" | sha256sum --check && \
mkdir pgvector-src && cd pgvector-src && tar xvzf ../pgvector.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -368,8 +368,8 @@ RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.5.3.tar.gz -O plpgsql_check.tar.gz && \
echo "6631ec3e7fb3769eaaf56e3dfedb829aa761abf163d13dba354b4c218508e1c0 plpgsql_check.tar.gz" | sha256sum --check && \
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.4.0.tar.gz -O plpgsql_check.tar.gz && \
echo "9ba58387a279b35a3bfa39ee611e5684e6cddb2ba046ddb2c5190b3bd2ca254a plpgsql_check.tar.gz" | sha256sum --check && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \

View File

@@ -62,7 +62,7 @@ all: neon postgres neon-pg-ext
#
# The 'postgres_ffi' depends on the Postgres headers.
.PHONY: neon
neon: postgres-headers walproposer-lib
neon: postgres-headers
+@echo "Compiling Neon"
$(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS)
@@ -168,42 +168,6 @@ neon-pg-ext-clean-%:
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
# Build walproposer as a static library. walproposer source code is located
# in the pgxn/neon directory.
#
# We also need to include libpgport.a and libpgcommon.a, because walproposer
# uses some functions from those libraries.
#
# Some object files are removed from libpgport.a and libpgcommon.a because
# they depend on openssl and other libraries that are not included in our
# Rust build.
.PHONY: walproposer-lib
walproposer-lib: neon-pg-ext-v16
+@echo "Compiling walproposer-lib"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgcommon.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
ifeq ($(UNAME_S),Linux)
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgport.a \
pg_strong_random.o
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgcommon.a \
pg_crc32c.o \
hmac_openssl.o \
cryptohash_openssl.o \
scram-common.o \
md5_common.o \
checksum_helper.o
endif
.PHONY: walproposer-lib-clean
walproposer-lib-clean:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \
neon-pg-ext-v14 \

4
NOTICE
View File

@@ -1,5 +1,5 @@
Neon
Copyright 2022 Neon Inc.
The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
See vendor/postgres-vX/COPYRIGHT for details.
The PostgreSQL submodules in vendor/postgres-v14 and vendor/postgres-v15 are licensed under the
PostgreSQL license. See vendor/postgres-v14/COPYRIGHT and vendor/postgres-v15/COPYRIGHT.

View File

@@ -252,7 +252,7 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()>
IF NOT EXISTS (
SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
THEN
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION IN ROLE pg_read_all_data, pg_write_all_data;
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN IN ROLE pg_read_all_data, pg_write_all_data;
IF array_length(roles, 1) IS NOT NULL THEN
EXECUTE format('GRANT neon_superuser TO %s',
array_to_string(ARRAY(SELECT quote_ident(x) FROM unnest(roles) as x), ', '));
@@ -692,11 +692,10 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec;
create_neon_superuser(spec, &mut client)?;
cleanup_instance(&mut client)?;
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, self.connstr.as_str(), &mut client)?;
handle_grants(spec, &mut client, self.connstr.as_str())?;
handle_grants(spec, self.connstr.as_str())?;
handle_extensions(spec, &mut client)?;
create_availability_check_data(&mut client)?;
@@ -732,11 +731,10 @@ impl ComputeNode {
// Disable DDL forwarding because control plane already knows about these roles/databases.
if spec.mode == ComputeMode::Primary {
client.simple_query("SET neon.forward_ddl = false")?;
cleanup_instance(&mut client)?;
handle_roles(&spec, &mut client)?;
handle_databases(&spec, &mut client)?;
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
handle_grants(&spec, &mut client, self.connstr.as_str())?;
handle_grants(&spec, self.connstr.as_str())?;
handle_extensions(&spec, &mut client)?;
}

View File

@@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::fmt::Write;
use std::fs;
use std::fs::File;
@@ -206,37 +205,22 @@ pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
}
/// Build a list of existing Postgres databases
pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
// `pg_database.datconnlimit = -2` means that the database is in the
// invalid state. See:
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
let postgres_dbs: Vec<Database> = client
pub fn get_existing_dbs(client: &mut Client) -> Result<Vec<Database>> {
let postgres_dbs = client
.query(
"SELECT
datname AS name,
datdba::regrole::text AS owner,
NOT datallowconn AS restrict_conn,
datconnlimit = - 2 AS invalid
FROM
pg_catalog.pg_database;",
"SELECT datname, datdba::regrole::text as owner
FROM pg_catalog.pg_database;",
&[],
)?
.iter()
.map(|row| Database {
name: row.get("name"),
name: row.get("datname"),
owner: row.get("owner"),
restrict_conn: row.get("restrict_conn"),
invalid: row.get("invalid"),
options: None,
})
.collect();
let dbs_map = postgres_dbs
.iter()
.map(|db| (db.name.clone(), db.clone()))
.collect::<HashMap<_, _>>();
Ok(dbs_map)
Ok(postgres_dbs)
}
/// Wait for Postgres to become ready to accept connections. It's ready to

View File

@@ -13,7 +13,7 @@ use crate::params::PG_HBA_ALL_MD5;
use crate::pg_helpers::*;
use compute_api::responses::{ControlPlaneComputeStatus, ControlPlaneSpecResponse};
use compute_api::spec::{ComputeSpec, PgIdent, Role};
use compute_api::spec::{ComputeSpec, Database, PgIdent, Role};
// Do control plane request and return response if any. In case of error it
// returns a bool flag indicating whether it makes sense to retry the request
@@ -161,38 +161,6 @@ pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
Ok(())
}
/// Compute could be unexpectedly shut down, for example, during the
/// database dropping. This leaves the database in the invalid state,
/// which prevents new db creation with the same name. This function
/// will clean it up before proceeding with catalog updates. All
/// possible future cleanup operations may go here too.
#[instrument(skip_all)]
pub fn cleanup_instance(client: &mut Client) -> Result<()> {
let existing_dbs = get_existing_dbs(client)?;
for (_, db) in existing_dbs {
if db.invalid {
// After recent commit in Postgres, interrupted DROP DATABASE
// leaves the database in the invalid state. According to the
// commit message, the only option for user is to drop it again.
// See:
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
//
// Postgres Neon extension is done the way, that db is de-registered
// in the control plane metadata only after it is dropped. So there is
// a chance that it still thinks that db should exist. This means
// that it will be re-created by `handle_databases()`. Yet, it's fine
// as user can just repeat drop (in vanilla Postgres they would need
// to do the same, btw).
let query = format!("DROP DATABASE IF EXISTS {}", db.name.pg_quote());
info!("dropping invalid database {}", db.name);
client.execute(query.as_str(), &[])?;
}
}
Ok(())
}
/// Given a cluster spec json and open transaction it handles roles creation,
/// deletion and update.
#[instrument(skip_all)]
@@ -302,7 +270,7 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
}
RoleAction::Create => {
let mut query: String = format!(
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS IN ROLE neon_superuser",
name.pg_quote()
);
info!("role create query: '{}'", &query);
@@ -411,13 +379,13 @@ fn reassign_owned_objects(spec: &ComputeSpec, connstr: &str, role_name: &PgIdent
/// which together provide us idempotency.
#[instrument(skip_all)]
pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
let existing_dbs = get_existing_dbs(client)?;
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
// Print a list of existing Postgres databases (only in debug mode)
if span_enabled!(Level::INFO) {
info!("postgres databases:");
for (dbname, db) in &existing_dbs {
info!(" {}:{}", dbname, db.owner);
for r in &existing_dbs {
info!(" {}:{}", r.name, r.owner);
}
}
@@ -471,7 +439,8 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
"rename_db" => {
let new_name = op.new_name.as_ref().unwrap();
if existing_dbs.get(&op.name).is_some() {
// XXX: with a limited number of roles it is fine, but consider making it a HashMap
if existing_dbs.iter().any(|r| r.name == op.name) {
let query: String = format!(
"ALTER DATABASE {} RENAME TO {}",
op.name.pg_quote(),
@@ -488,12 +457,14 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
}
// Refresh Postgres databases info to handle possible renames
let existing_dbs = get_existing_dbs(client)?;
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
info!("cluster spec databases:");
for db in &spec.cluster.databases {
let name = &db.name;
let pg_db = existing_dbs.get(name);
// XXX: with a limited number of databases it is fine, but consider making it a HashMap
let pg_db = existing_dbs.iter().find(|r| r.name == *name);
enum DatabaseAction {
None,
@@ -559,32 +530,13 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
/// to allow users creating trusted extensions and re-creating `public` schema, for example.
#[instrument(skip_all)]
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
info!("modifying database permissions");
let existing_dbs = get_existing_dbs(client)?;
pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
info!("cluster spec grants:");
// Do some per-database access adjustments. We'd better do this at db creation time,
// but CREATE DATABASE isn't transactional. So we cannot create db + do some grants
// atomically.
for db in &spec.cluster.databases {
match existing_dbs.get(&db.name) {
Some(pg_db) => {
if pg_db.restrict_conn || pg_db.invalid {
info!(
"skipping grants for db {} (invalid: {}, connections not allowed: {})",
db.name, pg_db.invalid, pg_db.restrict_conn
);
continue;
}
}
None => {
bail!(
"database {} doesn't exist in Postgres after handle_databases()",
db.name
);
}
}
let mut conf = Config::from_str(connstr)?;
conf.dbname(&db.name);
@@ -623,11 +575,6 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) ->
// Explicitly grant CREATE ON SCHEMA PUBLIC to the web_access user.
// This is needed because since postgres 15 this privilege is removed by default.
// TODO: web_access isn't created for almost 1 year. It could be that we have
// active users of 1 year old projects, but hopefully not, so check it and
// remove this code if possible. The worst thing that could happen is that
// user won't be able to use public schema in NEW databases created in the
// very OLD project.
let grant_query = "DO $$\n\
BEGIN\n\
IF EXISTS(\n\

View File

@@ -28,7 +28,7 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
r#"fsync = off
wal_level = logical
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on

View File

@@ -86,7 +86,7 @@ where
.stdout(process_log_file)
.stderr(same_file_for_stderr)
.args(args);
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
let filled_cmd = fill_aws_secrets_vars(fill_rust_env_vars(background_command));
filled_cmd.envs(envs);
let pid_file_to_check = match initial_pid_file {
@@ -238,13 +238,11 @@ fn fill_rust_env_vars(cmd: &mut Command) -> &mut Command {
filled_cmd
}
fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command {
for env_key in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AZURE_STORAGE_ACCOUNT",
"AZURE_STORAGE_ACCESS_KEY",
] {
if let Ok(value) = std::env::var(env_key) {
cmd = cmd.env(env_key, value);

View File

@@ -13,7 +13,6 @@ use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::{collections::HashMap, sync::Arc};
use utils::logging::{self, LogFormat};
use utils::signals::{ShutdownSignals, Signal};
use utils::{
http::{
@@ -269,16 +268,7 @@ async fn main() -> anyhow::Result<()> {
let server = hyper::Server::from_tcp(http_listener)?.serve(service);
tracing::info!("Serving on {0}", args.listen);
tokio::task::spawn(server);
ShutdownSignals::handle(|signal| match signal {
Signal::Interrupt | Signal::Terminate | Signal::Quit => {
tracing::info!("Got {}. Terminating", signal.name());
// We're just a test helper: no graceful shutdown.
std::process::exit(0);
}
})?;
server.await?;
Ok(())
}

View File

@@ -116,7 +116,6 @@ fn main() -> Result<()> {
"attachment_service" => handle_attachment_service(sub_args, &env),
"safekeeper" => handle_safekeeper(sub_args, &env),
"endpoint" => handle_endpoint(sub_args, &env),
"mappings" => handle_mappings(sub_args, &mut env),
"pg" => bail!("'pg' subcommand has been renamed to 'endpoint'"),
_ => bail!("unexpected subcommand {sub_name}"),
};
@@ -817,38 +816,6 @@ fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<(
Ok(())
}
fn handle_mappings(sub_match: &ArgMatches, env: &mut local_env::LocalEnv) -> Result<()> {
let (sub_name, sub_args) = match sub_match.subcommand() {
Some(ep_subcommand_data) => ep_subcommand_data,
None => bail!("no mappings subcommand provided"),
};
match sub_name {
"map" => {
let branch_name = sub_args
.get_one::<String>("branch-name")
.expect("branch-name argument missing");
let tenant_id = sub_args
.get_one::<String>("tenant-id")
.map(|x| TenantId::from_str(x))
.expect("tenant-id argument missing")
.expect("malformed tenant-id arg");
let timeline_id = sub_args
.get_one::<String>("timeline-id")
.map(|x| TimelineId::from_str(x))
.expect("timeline-id argument missing")
.expect("malformed timeline-id arg");
env.register_branch_mapping(branch_name.to_owned(), tenant_id, timeline_id)?;
Ok(())
}
other => unimplemented!("mappings subcommand {other}"),
}
}
fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
fn get_pageserver(env: &local_env::LocalEnv, args: &ArgMatches) -> Result<PageServerNode> {
let node_id = if let Some(id_str) = args.get_one::<String>("pageserver-id") {
@@ -1117,7 +1084,6 @@ fn cli() -> Command {
// --id, when using a pageserver command
let pageserver_id_arg = Arg::new("pageserver-id")
.long("id")
.global(true)
.help("pageserver id")
.required(false);
// --pageserver-id when using a non-pageserver command
@@ -1288,20 +1254,17 @@ fn cli() -> Command {
Command::new("pageserver")
.arg_required_else_help(true)
.about("Manage pageserver")
.arg(pageserver_id_arg)
.subcommand(Command::new("status"))
.subcommand(Command::new("start")
.about("Start local pageserver")
.arg(pageserver_config_args.clone())
)
.subcommand(Command::new("stop")
.about("Stop local pageserver")
.arg(stop_mode_arg.clone())
)
.subcommand(Command::new("restart")
.about("Restart local pageserver")
.arg(pageserver_config_args.clone())
)
.arg(pageserver_id_arg.clone())
.subcommand(Command::new("start").about("Start local pageserver")
.arg(pageserver_id_arg.clone())
.arg(pageserver_config_args.clone()))
.subcommand(Command::new("stop").about("Stop local pageserver")
.arg(pageserver_id_arg.clone())
.arg(stop_mode_arg.clone()))
.subcommand(Command::new("restart").about("Restart local pageserver")
.arg(pageserver_id_arg.clone())
.arg(pageserver_config_args.clone()))
)
.subcommand(
Command::new("attachment_service")
@@ -1358,8 +1321,8 @@ fn cli() -> Command {
.about("Start postgres.\n If the endpoint doesn't exist yet, it is created.")
.arg(endpoint_id_arg.clone())
.arg(tenant_id_arg.clone())
.arg(branch_name_arg.clone())
.arg(timeline_id_arg.clone())
.arg(branch_name_arg)
.arg(timeline_id_arg)
.arg(lsn_arg)
.arg(pg_port_arg)
.arg(http_port_arg)
@@ -1372,7 +1335,7 @@ fn cli() -> Command {
.subcommand(
Command::new("stop")
.arg(endpoint_id_arg)
.arg(tenant_id_arg.clone())
.arg(tenant_id_arg)
.arg(
Arg::new("destroy")
.help("Also delete data directory (now optional, should be default in future)")
@@ -1383,18 +1346,6 @@ fn cli() -> Command {
)
)
.subcommand(
Command::new("mappings")
.arg_required_else_help(true)
.about("Manage neon_local branch name mappings")
.subcommand(
Command::new("map")
.about("Create new mapping which cannot exist already")
.arg(branch_name_arg.clone())
.arg(tenant_id_arg.clone())
.arg(timeline_id_arg.clone())
)
)
// Obsolete old name for 'endpoint'. We now just print an error if it's used.
.subcommand(
Command::new("pg")

View File

@@ -253,7 +253,7 @@ impl Endpoint {
conf.append("shared_buffers", "1MB");
conf.append("fsync", "off");
conf.append("max_connections", "100");
conf.append("wal_level", "logical");
conf.append("wal_level", "replica");
// wal_sender_timeout is the maximum time to wait for WAL replication.
// It also defines how often the walreciever will send a feedback message to the wal sender.
conf.append("wal_sender_timeout", "5s");

View File

@@ -25,7 +25,7 @@
},
{
"name": "wal_level",
"value": "logical",
"value": "replica",
"vartype": "enum"
},
{

View File

@@ -188,60 +188,11 @@ that.
## Error message style
### PostgreSQL extensions
PostgreSQL has a style guide for writing error messages:
https://www.postgresql.org/docs/current/error-style-guide.html
Follow that guide when writing error messages in the PostgreSQL
extensions.
### Neon Rust code
#### Anyhow Context
When adding anyhow `context()`, use form `present-tense-verb+action`.
Example:
- Bad: `file.metadata().context("could not get file metadata")?;`
- Good: `file.metadata().context("get file metadata")?;`
#### Logging Errors
When logging any error `e`, use `could not {e:#}` or `failed to {e:#}`.
If `e` is an `anyhow` error and you want to log the backtrace that it contains,
use `{e:?}` instead of `{e:#}`.
#### Rationale
The `{:#}` ("alternate Display") of an `anyhow` error chain is concatenation fo the contexts, using `: `.
For example, the following Rust code will result in output
```
ERROR failed to list users: load users from server: parse response: invalid json
```
This is more concise / less noisy than what happens if you do `.context("could not ...")?` at each level, i.e.:
```
ERROR could not list users: could not load users from server: could not parse response: invalid json
```
```rust
fn main() {
match list_users().context("list users") else {
Ok(_) => ...,
Err(e) => tracing::error!("failed to {e:#}"),
}
}
fn list_users() {
http_get_users().context("load users from server")?;
}
fn http_get_users() {
let response = client....?;
response.parse().context("parse response")?; // fails with serde error "invalid json"
}
```
extension. We don't follow it strictly in the pageserver and
safekeeper, but the advice in the PostgreSQL style guide is generally
good, and you can't go wrong by following it.

View File

@@ -96,16 +96,6 @@ prefix_in_bucket = '/test_prefix/'
`AWS_SECRET_ACCESS_KEY` and `AWS_ACCESS_KEY_ID` env variables can be used to specify the S3 credentials if needed.
or
```toml
[remote_storage]
container_name = 'some-container-name'
container_region = 'us-east'
prefix_in_container = '/test-prefix/'
```
`AZURE_STORAGE_ACCOUNT` and `AZURE_STORAGE_ACCESS_KEY` env variables can be used to specify the azure credentials if needed.
## Repository background tasks

View File

@@ -200,12 +200,6 @@ pub struct Database {
pub name: PgIdent,
pub owner: PgIdent,
pub options: GenericOptions,
// These are derived flags, not present in the spec file.
// They are never set by the control plane.
#[serde(skip_deserializing, default)]
pub restrict_conn: bool,
#[serde(skip_deserializing, default)]
pub invalid: bool,
}
/// Common type representing both SQL statement params with or without value,

View File

@@ -76,7 +76,7 @@
},
{
"name": "wal_level",
"value": "logical",
"value": "replica",
"vartype": "enum"
},
{

View File

@@ -1,6 +1,6 @@
use std::io::{Read, Result, Write};
/// A wrapper for an object implementing [Read]
/// A wrapper for an object implementing [Read](std::io::Read)
/// which allows a closure to observe the amount of bytes read.
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
///
@@ -51,17 +51,17 @@ impl<'a, T> CountedReader<'a, T> {
}
}
/// Get an immutable reference to the underlying [Read] implementor
/// Get an immutable reference to the underlying [Read](std::io::Read) implementor
pub fn inner(&self) -> &T {
&self.reader
}
/// Get a mutable reference to the underlying [Read] implementor
/// Get a mutable reference to the underlying [Read](std::io::Read) implementor
pub fn inner_mut(&mut self) -> &mut T {
&mut self.reader
}
/// Consume the wrapper and return the underlying [Read] implementor
/// Consume the wrapper and return the underlying [Read](std::io::Read) implementor
pub fn into_inner(self) -> T {
self.reader
}
@@ -75,7 +75,7 @@ impl<T: Read> Read for CountedReader<'_, T> {
}
}
/// A wrapper for an object implementing [Write]
/// A wrapper for an object implementing [Write](std::io::Write)
/// which allows a closure to observe the amount of bytes written.
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
///
@@ -122,17 +122,17 @@ impl<'a, T> CountedWriter<'a, T> {
}
}
/// Get an immutable reference to the underlying [Write] implementor
/// Get an immutable reference to the underlying [Write](std::io::Write) implementor
pub fn inner(&self) -> &T {
&self.writer
}
/// Get a mutable reference to the underlying [Write] implementor
/// Get a mutable reference to the underlying [Write](std::io::Write) implementor
pub fn inner_mut(&mut self) -> &mut T {
&mut self.writer
}
/// Consume the wrapper and return the underlying [Write] implementor
/// Consume the wrapper and return the underlying [Write](std::io::Write) implementor
pub fn into_inner(self) -> T {
self.writer
}

View File

@@ -22,9 +22,9 @@ use postgres_ffi::Oid;
/// [See more related comments here](https:///github.com/postgres/postgres/blob/99c5852e20a0987eca1c38ba0c09329d4076b6a0/src/include/storage/relfilenode.h#L57).
///
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
// Then we could replace the custom Ord and PartialOrd implementations below with
// deriving them. This will require changes in walredoproc.c.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize)]
// Then we could replace the custo Ord and PartialOrd implementations below with
// deriving them.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct RelTag {
pub forknum: u8,
pub spcnode: Oid,
@@ -40,9 +40,21 @@ impl PartialOrd for RelTag {
impl Ord for RelTag {
fn cmp(&self, other: &Self) -> Ordering {
// Custom ordering where we put forknum to the end of the list
let other_tup = (other.spcnode, other.dbnode, other.relnode, other.forknum);
(self.spcnode, self.dbnode, self.relnode, self.forknum).cmp(&other_tup)
let mut cmp = self.spcnode.cmp(&other.spcnode);
if cmp != Ordering::Equal {
return cmp;
}
cmp = self.dbnode.cmp(&other.dbnode);
if cmp != Ordering::Equal {
return cmp;
}
cmp = self.relnode.cmp(&other.relnode);
if cmp != Ordering::Equal {
return cmp;
}
cmp = self.forknum.cmp(&other.forknum);
cmp
}
}

View File

@@ -19,8 +19,8 @@ use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
@@ -30,9 +30,6 @@ pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// We were instructed to shutdown while processing the query
#[error("Shutting down")]
Shutdown,
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
@@ -47,8 +44,7 @@ impl From<io::Error> for QueryError {
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
@@ -400,20 +396,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
match ret {
Ok(()) => Ok(()),
Err(QueryError::Shutdown) => {
info!("Stopped due to shutdown");
Ok(())
}
Err(QueryError::Disconnected(e)) => {
info!("Disconnected ({e:#})");
// Disconnection is not an error: we just use it that way internally to drop
// out of loops.
Ok(())
}
e => e,
}
ret
}
async fn run_message_loop<F, S>(
@@ -433,11 +416,15 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Err(QueryError::Shutdown)
return Ok(())
},
handshake_r = self.handshake(handler) => {
handshake_r?;
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
@@ -448,34 +435,17 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
return Err(QueryError::Shutdown)
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during response flush");
// If we exited process_message with a shutdown error, there may be
// some valid response content on in our transmit buffer: permit sending
// this within a short timeout. This is a best effort thing so we don't
// care about the result.
tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
return Err(QueryError::Shutdown)
},
flush_r = self.flush() => {
flush_r?;
}
);
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
@@ -580,9 +550,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
self.peer_addr
);
self.state = ProtoState::Closed;
return Err(QueryError::Disconnected(ConnectionError::Protocol(
ProtocolError::Protocol("EOF during handshake".to_string()),
)));
return Ok(());
}
}
}
@@ -621,9 +589,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
self.peer_addr
);
self.state = ProtoState::Closed;
return Err(QueryError::Disconnected(ConnectionError::Protocol(
ProtocolError::Protocol("EOF during auth".to_string()),
)));
return Ok(());
}
}
}
@@ -947,7 +913,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, I
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Shutdown => "shutdown".to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
@@ -964,9 +929,6 @@ fn log_query_error(query: &str, e: &QueryError) {
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Shutdown => {
info!("query handler for '{query}' cancelled during tenant shutdown")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}

View File

@@ -131,7 +131,6 @@ pub const MAX_SEND_SIZE: usize = XLOG_BLCKSZ * 16;
// Export some version independent functions that are used outside of this mod
pub use v14::xlog_utils::encode_logical_message;
pub use v14::xlog_utils::from_pg_timestamp;
pub use v14::xlog_utils::get_current_timestamp;
pub use v14::xlog_utils::to_pg_timestamp;
pub use v14::xlog_utils::XLogFileName;

View File

@@ -220,10 +220,6 @@ pub const XLOG_CHECKPOINT_ONLINE: u8 = 0x10;
pub const XLP_FIRST_IS_CONTRECORD: u16 = 0x0001;
pub const XLP_LONG_HEADER: u16 = 0x0002;
/* From replication/slot.h */
pub const REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN: usize = 4*4 /* offset of `slotdata` in ReplicationSlotOnDisk */
+ 64 /* NameData */ + 4*4;
/* From fsm_internals.h */
const FSM_NODES_PER_PAGE: usize = BLCKSZ as usize - SIZEOF_PAGE_HEADER_DATA - 4;
const FSM_NON_LEAF_NODES_PER_PAGE: usize = BLCKSZ as usize / 2 - 1;

View File

@@ -136,42 +136,21 @@ pub fn get_current_timestamp() -> TimestampTz {
to_pg_timestamp(SystemTime::now())
}
// Module to reduce the scope of the constants
mod timestamp_conversions {
use std::time::Duration;
use super::*;
const UNIX_EPOCH_JDATE: u64 = 2440588; // == date2j(1970, 1, 1)
const POSTGRES_EPOCH_JDATE: u64 = 2451545; // == date2j(2000, 1, 1)
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
const UNIX_EPOCH_JDATE: u64 = 2440588; /* == date2j(1970, 1, 1) */
const POSTGRES_EPOCH_JDATE: u64 = 2451545; /* == date2j(2000, 1, 1) */
const SECS_PER_DAY: u64 = 86400;
const USECS_PER_SEC: u64 = 1000000;
const SECS_DIFF_UNIX_TO_POSTGRES_EPOCH: u64 =
(POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY;
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
match time.duration_since(SystemTime::UNIX_EPOCH) {
Ok(n) => {
((n.as_secs() - SECS_DIFF_UNIX_TO_POSTGRES_EPOCH) * USECS_PER_SEC
+ n.subsec_micros() as u64) as i64
}
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
match time.duration_since(SystemTime::UNIX_EPOCH) {
Ok(n) => {
((n.as_secs() - ((POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY))
* USECS_PER_SEC
+ n.subsec_micros() as u64) as i64
}
}
pub fn from_pg_timestamp(time: TimestampTz) -> SystemTime {
let time: u64 = time
.try_into()
.expect("timestamp before millenium (postgres epoch)");
let since_unix_epoch = time + SECS_DIFF_UNIX_TO_POSTGRES_EPOCH * USECS_PER_SEC;
SystemTime::UNIX_EPOCH
.checked_add(Duration::from_micros(since_unix_epoch))
.expect("SystemTime overflow")
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
}
}
pub use timestamp_conversions::{from_pg_timestamp, to_pg_timestamp};
// Returns (aligned) end_lsn of the last record in data_dir with WAL segments.
// start_lsn must point to some previously known record boundary (beginning of
// the next record). If no valid record after is found, start_lsn is returned
@@ -502,24 +481,4 @@ pub fn encode_logical_message(prefix: &str, message: &str) -> Vec<u8> {
wal
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ts_conversion() {
let now = SystemTime::now();
let round_trip = from_pg_timestamp(to_pg_timestamp(now));
let now_since = now.duration_since(SystemTime::UNIX_EPOCH).unwrap();
let round_trip_since = round_trip.duration_since(SystemTime::UNIX_EPOCH).unwrap();
assert_eq!(now_since.as_micros(), round_trip_since.as_micros());
let now_pg = get_current_timestamp();
let round_trip_pg = to_pg_timestamp(from_pg_timestamp(now_pg));
assert_eq!(now_pg, round_trip_pg);
}
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.
}
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.

View File

@@ -670,7 +670,6 @@ pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {

View File

@@ -13,7 +13,6 @@ aws-types.workspace = true
aws-config.workspace = true
aws-sdk-s3.workspace = true
aws-credential-types.workspace = true
bytes.workspace = true
camino.workspace = true
hyper = { workspace = true, features = ["stream"] }
serde.workspace = true
@@ -27,13 +26,6 @@ metrics.workspace = true
utils.workspace = true
pin-project-lite.workspace = true
workspace_hack.workspace = true
azure_core.workspace = true
azure_identity.workspace = true
azure_storage.workspace = true
azure_storage_blobs.workspace = true
futures-util.workspace = true
http-types.workspace = true
itertools.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true

View File

@@ -1,356 +0,0 @@
//! Azure Blob Storage wrapper
use std::env;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::{borrow::Cow, collections::HashMap, io::Cursor};
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::Header;
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
use azure_storage_blobs::prelude::ClientBuilder;
use azure_storage_blobs::{
blob::operations::GetBlobBuilder,
prelude::{BlobClient, ContainerClient},
};
use futures_util::StreamExt;
use http_types::StatusCode;
use tokio::io::AsyncRead;
use tracing::debug;
use crate::s3_bucket::RequestKind;
use crate::{
AzureConfig, ConcurrencyLimiter, Download, DownloadError, RemotePath, RemoteStorage,
StorageMetadata,
};
pub struct AzureBlobStorage {
client: ContainerClient,
prefix_in_container: Option<String>,
max_keys_per_list_response: Option<NonZeroU32>,
concurrency_limiter: ConcurrencyLimiter,
}
impl AzureBlobStorage {
pub fn new(azure_config: &AzureConfig) -> Result<Self> {
debug!(
"Creating azure remote storage for azure container {}",
azure_config.container_name
);
let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
// If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
// otherwise try the token based credentials.
let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
StorageCredentials::access_key(account.clone(), access_key)
} else {
let token_credential = DefaultAzureCredential::default();
StorageCredentials::token_credential(Arc::new(token_credential))
};
let builder = ClientBuilder::new(account, credentials);
let client = builder.container_client(azure_config.container_name.to_owned());
let max_keys_per_list_response =
if let Some(limit) = azure_config.max_keys_per_list_response {
Some(
NonZeroU32::new(limit as u32)
.ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
)
} else {
None
};
Ok(AzureBlobStorage {
client,
prefix_in_container: azure_config.prefix_in_container.to_owned(),
max_keys_per_list_response,
concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
})
}
pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
let path_string = path
.get_path()
.as_str()
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
match &self.prefix_in_container {
Some(prefix) => {
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
prefix.clone() + path_string
} else {
format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
}
}
None => path_string.to_string(),
}
}
fn name_to_relative_path(&self, key: &str) -> RemotePath {
let relative_path =
match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
Some(stripped) => stripped,
// we rely on Azure to return properly prefixed paths
// for requests with a certain prefix
None => panic!(
"Key {key} does not start with container prefix {:?}",
self.prefix_in_container
),
};
RemotePath(
relative_path
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
.collect(),
)
}
async fn download_for_builder(
&self,
metadata: StorageMetadata,
builder: GetBlobBuilder,
) -> Result<Download, DownloadError> {
let mut response = builder.into_stream();
// TODO give proper streaming response instead of buffering into RAM
// https://github.com/neondatabase/neon/issues/5563
let mut buf = Vec::new();
while let Some(part) = response.next().await {
let part = part.map_err(to_download_error)?;
let data = part
.data
.collect()
.await
.map_err(|e| DownloadError::Other(e.into()))?;
buf.extend_from_slice(&data.slice(..));
}
Ok(Download {
download_stream: Box::pin(Cursor::new(buf)),
metadata: Some(metadata),
})
}
// TODO get rid of this function once we have metadata included in the response
// https://github.com/Azure/azure-sdk-for-rust/issues/1439
async fn get_metadata(
&self,
blob_client: &BlobClient,
) -> Result<StorageMetadata, DownloadError> {
let builder = blob_client.get_metadata();
let response = builder.into_future().await.map_err(to_download_error)?;
let mut map = HashMap::new();
for md in response.metadata.iter() {
map.insert(
md.name().as_str().to_string(),
md.value().as_str().to_string(),
);
}
Ok(StorageMetadata(map))
}
async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
self.concurrency_limiter
.acquire(kind)
.await
.expect("semaphore is never closed")
}
}
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
let mut res = Metadata::new();
for (k, v) in metadata.0.into_iter() {
res.insert(k, v);
}
res
}
fn to_download_error(error: azure_core::Error) -> DownloadError {
if let Some(http_err) = error.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
_ => DownloadError::Other(anyhow::Error::new(error)),
}
} else {
DownloadError::Other(error.into())
}
}
#[async_trait::async_trait]
impl RemoteStorage for AzureBlobStorage {
async fn list_prefixes(
&self,
prefix: Option<&RemotePath>,
) -> Result<Vec<RemotePath>, DownloadError> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let list_prefix = prefix
.map(|p| self.relative_path_to_name(p))
.or_else(|| self.prefix_in_container.clone())
.map(|mut p| {
// required to end with a separator
// otherwise request will return only the entry of a prefix
if !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
}
p
});
let mut builder = self
.client
.list_blobs()
.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
if let Some(prefix) = list_prefix {
builder = builder.prefix(Cow::from(prefix.to_owned()));
}
if let Some(limit) = self.max_keys_per_list_response {
builder = builder.max_results(MaxResults::new(limit));
}
let mut response = builder.into_stream();
let mut res = Vec::new();
while let Some(entry) = response.next().await {
let entry = entry.map_err(to_download_error)?;
let name_iter = entry
.blobs
.prefixes()
.map(|prefix| self.name_to_relative_path(&prefix.name));
res.extend(name_iter);
}
Ok(res)
}
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
let folder_name = folder
.map(|p| self.relative_path_to_name(p))
.or_else(|| self.prefix_in_container.clone());
let mut builder = self.client.list_blobs();
if let Some(folder_name) = folder_name {
builder = builder.prefix(Cow::from(folder_name.to_owned()));
}
if let Some(limit) = self.max_keys_per_list_response {
builder = builder.max_results(MaxResults::new(limit));
}
let mut response = builder.into_stream();
let mut res = Vec::new();
while let Some(l) = response.next().await {
let entry = l.map_err(anyhow::Error::new)?;
let name_iter = entry
.blobs
.blobs()
.map(|bl| self.name_to_relative_path(&bl.name));
res.extend(name_iter);
}
Ok(res)
}
async fn upload(
&self,
mut from: impl AsyncRead + Unpin + Send + Sync + 'static,
data_size_bytes: usize,
to: &RemotePath,
metadata: Option<StorageMetadata>,
) -> anyhow::Result<()> {
let _permit = self.permit(RequestKind::Put).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(to));
// TODO FIX THIS UGLY HACK and don't buffer the entire object
// into RAM here, but use the streaming interface. For that,
// we'd have to change the interface though...
// https://github.com/neondatabase/neon/issues/5563
let mut buf = Vec::with_capacity(data_size_bytes);
tokio::io::copy(&mut from, &mut buf).await?;
let body = azure_core::Body::Bytes(buf.into());
let mut builder = blob_client.put_block_blob(body);
if let Some(metadata) = metadata {
builder = builder.metadata(to_azure_metadata(metadata));
}
let _response = builder.into_future().await?;
Ok(())
}
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
let _permit = self.permit(RequestKind::Get).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let metadata = self.get_metadata(&blob_client).await?;
let builder = blob_client.get();
self.download_for_builder(metadata, builder).await
}
async fn download_byte_range(
&self,
from: &RemotePath,
start_inclusive: u64,
end_exclusive: Option<u64>,
) -> Result<Download, DownloadError> {
let _permit = self.permit(RequestKind::Get).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
let metadata = self.get_metadata(&blob_client).await?;
let mut builder = blob_client.get();
if let Some(end_exclusive) = end_exclusive {
builder = builder.range(Range::new(start_inclusive, end_exclusive));
} else {
// Open ranges are not supported by the SDK so we work around
// by setting the upper limit extremely high (but high enough
// to still be representable by signed 64 bit integers).
// TODO remove workaround once the SDK adds open range support
// https://github.com/Azure/azure-sdk-for-rust/issues/1438
let end_exclusive = u64::MAX / 4;
builder = builder.range(Range::new(start_inclusive, end_exclusive));
}
self.download_for_builder(metadata, builder).await
}
async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
let _permit = self.permit(RequestKind::Delete).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(path));
let builder = blob_client.delete();
match builder.into_future().await {
Ok(_response) => Ok(()),
Err(e) => {
if let Some(http_err) = e.as_http_error() {
if http_err.status() == StatusCode::NotFound {
return Ok(());
}
}
Err(anyhow::Error::new(e))
}
}
}
async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
// Permit is already obtained by inner delete function
// TODO batch requests are also not supported by the SDK
// https://github.com/Azure/azure-sdk-for-rust/issues/1068
// https://github.com/Azure/azure-sdk-for-rust/issues/1249
for path in paths {
self.delete(path).await?;
}
Ok(())
}
}

View File

@@ -4,10 +4,7 @@
//! [`RemoteStorage`] trait a CRUD-like generic abstraction to use for adapting external storages with a few implementations:
//! * [`local_fs`] allows to use local file system as an external storage
//! * [`s3_bucket`] uses AWS S3 bucket as an external storage
//! * [`azure_blob`] allows to use Azure Blob storage as an external storage
//!
mod azure_blob;
mod local_fs;
mod s3_bucket;
mod simulate_failures;
@@ -24,15 +21,11 @@ use anyhow::{bail, Context};
use camino::{Utf8Path, Utf8PathBuf};
use serde::{Deserialize, Serialize};
use tokio::{io, sync::Semaphore};
use tokio::io;
use toml_edit::Item;
use tracing::info;
pub use self::{
azure_blob::AzureBlobStorage, local_fs::LocalFs, s3_bucket::S3Bucket,
simulate_failures::UnreliableWrapper,
};
use s3_bucket::RequestKind;
pub use self::{local_fs::LocalFs, s3_bucket::S3Bucket, simulate_failures::UnreliableWrapper};
/// How many different timelines can be processed simultaneously when synchronizing layers with the remote storage.
/// During regular work, pageserver produces one layer file per timeline checkpoint, with bursts of concurrency
@@ -46,11 +39,6 @@ pub const DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS: u32 = 10;
/// ~3500 PUT/COPY/POST/DELETE or 5500 GET/HEAD S3 requests
/// <https://aws.amazon.com/premiumsupport/knowledge-center/s3-request-limit-avoid-throttling/>
pub const DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT: usize = 100;
/// We set this a little bit low as we currently buffer the entire file into RAM
///
/// Here, a limit of max 20k concurrent connections was noted.
/// <https://learn.microsoft.com/en-us/answers/questions/1301863/is-there-any-limitation-to-concurrent-connections>
pub const DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT: usize = 30;
/// No limits on the client side, which currenltly means 1000 for AWS S3.
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax>
pub const DEFAULT_MAX_KEYS_PER_LIST_RESPONSE: Option<i32> = None;
@@ -229,7 +217,6 @@ impl std::error::Error for DownloadError {}
pub enum GenericRemoteStorage {
LocalFs(LocalFs),
AwsS3(Arc<S3Bucket>),
AzureBlob(Arc<AzureBlobStorage>),
Unreliable(Arc<UnreliableWrapper>),
}
@@ -241,7 +228,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.list_files(folder).await,
Self::AwsS3(s) => s.list_files(folder).await,
Self::AzureBlob(s) => s.list_files(folder).await,
Self::Unreliable(s) => s.list_files(folder).await,
}
}
@@ -256,7 +242,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.list_prefixes(prefix).await,
Self::AwsS3(s) => s.list_prefixes(prefix).await,
Self::AzureBlob(s) => s.list_prefixes(prefix).await,
Self::Unreliable(s) => s.list_prefixes(prefix).await,
}
}
@@ -271,7 +256,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata).await,
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata).await,
}
}
@@ -280,7 +264,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.download(from).await,
Self::AwsS3(s) => s.download(from).await,
Self::AzureBlob(s) => s.download(from).await,
Self::Unreliable(s) => s.download(from).await,
}
}
@@ -300,10 +283,6 @@ impl GenericRemoteStorage {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
}
Self::AzureBlob(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
}
Self::Unreliable(s) => {
s.download_byte_range(from, start_inclusive, end_exclusive)
.await
@@ -315,7 +294,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.delete(path).await,
Self::AwsS3(s) => s.delete(path).await,
Self::AzureBlob(s) => s.delete(path).await,
Self::Unreliable(s) => s.delete(path).await,
}
}
@@ -324,7 +302,6 @@ impl GenericRemoteStorage {
match self {
Self::LocalFs(s) => s.delete_objects(paths).await,
Self::AwsS3(s) => s.delete_objects(paths).await,
Self::AzureBlob(s) => s.delete_objects(paths).await,
Self::Unreliable(s) => s.delete_objects(paths).await,
}
}
@@ -342,11 +319,6 @@ impl GenericRemoteStorage {
s3_config.bucket_name, s3_config.bucket_region, s3_config.prefix_in_bucket, s3_config.endpoint);
Self::AwsS3(Arc::new(S3Bucket::new(s3_config)?))
}
RemoteStorageKind::AzureContainer(azure_config) => {
info!("Using azure container '{}' in region '{}' as a remote storage, prefix in container: '{:?}'",
azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container);
Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config)?))
}
})
}
@@ -411,9 +383,6 @@ pub enum RemoteStorageKind {
/// AWS S3 based storage, storing all files in the S3 bucket
/// specified by the config
AwsS3(S3Config),
/// Azure Blob based storage, storing all files in the container
/// specified by the config
AzureContainer(AzureConfig),
}
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
@@ -453,45 +422,11 @@ impl Debug for S3Config {
}
}
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
#[derive(Clone, PartialEq, Eq)]
pub struct AzureConfig {
/// Name of the container to connect to.
pub container_name: String,
/// The region where the bucket is located at.
pub container_region: String,
/// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
pub prefix_in_container: Option<String>,
/// Azure has various limits on its API calls, we need not to exceed those.
/// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
pub concurrency_limit: NonZeroUsize,
pub max_keys_per_list_response: Option<i32>,
}
impl Debug for AzureConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AzureConfig")
.field("bucket_name", &self.container_name)
.field("bucket_region", &self.container_region)
.field("prefix_in_bucket", &self.prefix_in_container)
.field("concurrency_limit", &self.concurrency_limit)
.field(
"max_keys_per_list_response",
&self.max_keys_per_list_response,
)
.finish()
}
}
impl RemoteStorageConfig {
pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
let local_path = toml.get("local_path");
let bucket_name = toml.get("bucket_name");
let bucket_region = toml.get("bucket_region");
let container_name = toml.get("container_name");
let container_region = toml.get("container_region");
let use_azure = container_name.is_some() && container_region.is_some();
let max_concurrent_syncs = NonZeroUsize::new(
parse_optional_integer("max_concurrent_syncs", toml)?
@@ -505,13 +440,9 @@ impl RemoteStorageConfig {
)
.context("Failed to parse 'max_sync_errors' as a positive integer")?;
let default_concurrency_limit = if use_azure {
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
} else {
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
};
let concurrency_limit = NonZeroUsize::new(
parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit),
parse_optional_integer("concurrency_limit", toml)?
.unwrap_or(DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT),
)
.context("Failed to parse 'concurrency_limit' as a positive integer")?;
@@ -520,70 +451,33 @@ impl RemoteStorageConfig {
.context("Failed to parse 'max_keys_per_list_response' as a positive integer")?
.or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
let endpoint = toml
.get("endpoint")
.map(|endpoint| parse_toml_string("endpoint", endpoint))
.transpose()?;
let storage = match (
local_path,
bucket_name,
bucket_region,
container_name,
container_region,
) {
let storage = match (local_path, bucket_name, bucket_region) {
// no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled
(None, None, None, None, None) => return Ok(None),
(_, Some(_), None, ..) => {
(None, None, None) => return Ok(None),
(_, Some(_), None) => {
bail!("'bucket_region' option is mandatory if 'bucket_name' is given ")
}
(_, None, Some(_), ..) => {
(_, None, Some(_)) => {
bail!("'bucket_name' option is mandatory if 'bucket_region' is given ")
}
(None, Some(bucket_name), Some(bucket_region), ..) => {
RemoteStorageKind::AwsS3(S3Config {
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
prefix_in_bucket: toml
.get("prefix_in_bucket")
.map(|prefix_in_bucket| {
parse_toml_string("prefix_in_bucket", prefix_in_bucket)
})
.transpose()?,
endpoint,
concurrency_limit,
max_keys_per_list_response,
})
}
(_, _, _, Some(_), None) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
}
(_, _, _, None, Some(_)) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
}
(None, None, None, Some(container_name), Some(container_region)) => {
RemoteStorageKind::AzureContainer(AzureConfig {
container_name: parse_toml_string("container_name", container_name)?,
container_region: parse_toml_string("container_region", container_region)?,
prefix_in_container: toml
.get("prefix_in_container")
.map(|prefix_in_container| {
parse_toml_string("prefix_in_container", prefix_in_container)
})
.transpose()?,
concurrency_limit,
max_keys_per_list_response,
})
}
(Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs(
Utf8PathBuf::from(parse_toml_string("local_path", local_path)?),
),
(Some(_), Some(_), ..) => {
bail!("'local_path' and 'bucket_name' are mutually exclusive")
}
(Some(_), _, _, Some(_), Some(_)) => {
bail!("local_path and 'container_name' are mutually exclusive")
}
(None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config {
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
prefix_in_bucket: toml
.get("prefix_in_bucket")
.map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket))
.transpose()?,
endpoint: toml
.get("endpoint")
.map(|endpoint| parse_toml_string("endpoint", endpoint))
.transpose()?,
concurrency_limit,
max_keys_per_list_response,
}),
(Some(local_path), None, None) => RemoteStorageKind::LocalFs(Utf8PathBuf::from(
parse_toml_string("local_path", local_path)?,
)),
(Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"),
};
Ok(Some(RemoteStorageConfig {
@@ -619,46 +513,6 @@ fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result<String> {
Ok(s.to_string())
}
struct ConcurrencyLimiter {
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
// The helps to ensure we don't exceed the thresholds.
write: Arc<Semaphore>,
read: Arc<Semaphore>,
}
impl ConcurrencyLimiter {
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
match kind {
RequestKind::Get => &self.read,
RequestKind::Put => &self.write,
RequestKind::List => &self.read,
RequestKind::Delete => &self.write,
}
}
async fn acquire(
&self,
kind: RequestKind,
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
self.for_kind(kind).acquire().await
}
async fn acquire_owned(
&self,
kind: RequestKind,
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
Arc::clone(self.for_kind(kind)).acquire_owned().await
}
fn new(limit: usize) -> ConcurrencyLimiter {
Self {
read: Arc::new(Semaphore::new(limit)),
write: Arc::new(Semaphore::new(limit)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -4,7 +4,7 @@
//! allowing multiple api users to independently work with the same S3 bucket, if
//! their bucket prefixes are both specified and different.
use std::borrow::Cow;
use std::sync::Arc;
use anyhow::Context;
use aws_config::{
@@ -24,20 +24,22 @@ use aws_sdk_s3::{
use aws_smithy_http::body::SdkBody;
use hyper::Body;
use scopeguard::ScopeGuard;
use tokio::io::{self, AsyncRead};
use tokio::{
io::{self, AsyncRead},
sync::Semaphore,
};
use tokio_util::io::ReaderStream;
use tracing::debug;
use super::StorageMetadata;
use crate::{
ConcurrencyLimiter, Download, DownloadError, RemotePath, RemoteStorage, S3Config,
MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
Download, DownloadError, RemotePath, RemoteStorage, S3Config, MAX_KEYS_PER_DELETE,
REMOTE_STORAGE_PREFIX_SEPARATOR,
};
pub(super) mod metrics;
use self::metrics::AttemptOutcome;
pub(super) use self::metrics::RequestKind;
use self::metrics::{AttemptOutcome, RequestKind};
/// AWS S3 storage.
pub struct S3Bucket {
@@ -48,6 +50,46 @@ pub struct S3Bucket {
concurrency_limiter: ConcurrencyLimiter,
}
struct ConcurrencyLimiter {
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
// The helps to ensure we don't exceed the thresholds.
write: Arc<Semaphore>,
read: Arc<Semaphore>,
}
impl ConcurrencyLimiter {
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
match kind {
RequestKind::Get => &self.read,
RequestKind::Put => &self.write,
RequestKind::List => &self.read,
RequestKind::Delete => &self.write,
}
}
async fn acquire(
&self,
kind: RequestKind,
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
self.for_kind(kind).acquire().await
}
async fn acquire_owned(
&self,
kind: RequestKind,
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
Arc::clone(self.for_kind(kind)).acquire_owned().await
}
fn new(limit: usize) -> ConcurrencyLimiter {
Self {
read: Arc::new(Semaphore::new(limit)),
write: Arc::new(Semaphore::new(limit)),
}
}
}
#[derive(Default)]
struct GetObjectRequest {
bucket: String,
@@ -514,20 +556,6 @@ impl RemoteStorage for S3Bucket {
.deleted_objects_total
.inc_by(chunk.len() as u64);
if let Some(errors) = resp.errors {
// Log a bounded number of the errors within the response:
// these requests can carry 1000 keys so logging each one
// would be too verbose, especially as errors may lead us
// to retry repeatedly.
const LOG_UP_TO_N_ERRORS: usize = 10;
for e in errors.iter().take(LOG_UP_TO_N_ERRORS) {
tracing::warn!(
"DeleteObjects key {} failed: {}: {}",
e.key.as_ref().map(Cow::from).unwrap_or("".into()),
e.code.as_ref().map(Cow::from).unwrap_or("".into()),
e.message.as_ref().map(Cow::from).unwrap_or("".into())
);
}
return Err(anyhow::format_err!(
"Failed to delete {} objects",
errors.len()

View File

@@ -6,7 +6,7 @@ use once_cell::sync::Lazy;
pub(super) static BUCKET_METRICS: Lazy<BucketMetrics> = Lazy::new(Default::default);
#[derive(Clone, Copy, Debug)]
pub(crate) enum RequestKind {
pub(super) enum RequestKind {
Get = 0,
Put = 1,
Delete = 2,

View File

@@ -1,625 +0,0 @@
use std::collections::HashSet;
use std::env;
use std::num::{NonZeroU32, NonZeroUsize};
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use anyhow::Context;
use camino::Utf8Path;
use once_cell::sync::OnceCell;
use remote_storage::{
AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
};
use test_context::{test_context, AsyncTestContext};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
const ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_AZURE_REMOTE_STORAGE";
const BASE_PREFIX: &str = "test";
/// Tests that the Azure client can list all prefixes, even if the response comes paginated and requires multiple HTTP queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified.
/// See the client creation in [`create_azure_client`] for details on the required env vars.
/// If real Azure tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
///
/// First, the test creates a set of Azure blobs with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_azure_data`]
/// where
/// * `random_prefix_part` is set for the entire Azure client during the Azure client creation in [`create_azure_client`], to avoid multiple test runs interference
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
///
/// Then, verifies that the client does return correct prefixes when queried:
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
///
/// With the real Azure enabled and `#[cfg(test)]` Rust configuration used, the Azure client test adds a `max-keys` param to limit the response keys.
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to Azure.
///
/// Lastly, the test attempts to clean up and remove all uploaded Azure files.
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
#[test_context(MaybeEnabledAzureWithTestBlobs)]
#[tokio::test]
async fn azure_pagination_should_work(
ctx: &mut MaybeEnabledAzureWithTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let expected_remote_prefixes = ctx.remote_prefixes.clone();
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
.context("common_prefix construction")?;
let root_remote_prefixes = test_client
.list_prefixes(None)
.await
.context("client list root prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
);
let nested_remote_prefixes = test_client
.list_prefixes(Some(&base_prefix))
.await
.context("client list nested prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
let remote_only_prefixes = nested_remote_prefixes
.difference(&expected_remote_prefixes)
.collect::<HashSet<_>>();
let missing_uploaded_prefixes = expected_remote_prefixes
.difference(&nested_remote_prefixes)
.collect::<HashSet<_>>();
assert_eq!(
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
);
Ok(())
}
/// Tests that Azure client can list all files in a folder, even if the response comes paginated and requirees multiple Azure queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified. Test will skip real code and pass if env vars not set.
/// See `Azure_pagination_should_work` for more information.
///
/// First, create a set of Azure objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_azure_data`]
/// Then performs the following queries:
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
#[test_context(MaybeEnabledAzureWithSimpleTestBlobs)]
#[tokio::test]
async fn azure_list_files_works(
ctx: &mut MaybeEnabledAzureWithSimpleTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithSimpleTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithSimpleTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithSimpleTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_files,
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()
.collect::<HashSet<_>>();
let trim_remote_blobs: HashSet<_> = ctx
.remote_blobs
.iter()
.map(|x| x.get_path())
.filter(|x| x.starts_with("folder1"))
.map(|x| RemotePath::new(x).expect("must be valid path"))
.collect();
assert_eq!(
nested_remote_files, trim_remote_blobs,
"remote storage list_files on subdirrectory mismatches with the uploads."
);
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_non_exising_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
ctx.client.delete(&path).await.expect("should succeed");
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let data1 = "remote blob data1".as_bytes();
let data1_len = data1.len();
let data2 = "remote blob data2".as_bytes();
let data2_len = data2.len();
let data3 = "remote blob data3".as_bytes();
let data3_len = data3.len();
ctx.client
.upload(std::io::Cursor::new(data1), data1_len, &path1, None)
.await?;
ctx.client
.upload(std::io::Cursor::new(data2), data2_len, &path2, None)
.await?;
ctx.client
.upload(std::io::Cursor::new(data3), data3_len, &path3, None)
.await?;
ctx.client.delete_objects(&[path1, path2]).await?;
let prefixes = ctx.client.list_prefixes(None).await?;
assert_eq!(prefixes.len(), 1);
ctx.client.delete_objects(&[path3]).await?;
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let MaybeEnabledAzure::Enabled(ctx) = ctx else {
return Ok(());
};
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let data = "remote blob data here".as_bytes();
let data_len = data.len() as u64;
ctx.client
.upload(std::io::Cursor::new(data), data.len(), &path, None)
.await?;
async fn download_and_compare(mut dl: Download) -> anyhow::Result<Vec<u8>> {
let mut buf = Vec::new();
tokio::io::copy(&mut dl.download_stream, &mut buf).await?;
Ok(buf)
}
// Normal download request
let dl = ctx.client.download(&path).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
// Full range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 0, Some(data_len))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
// partial range (end specified)
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[4..10]);
// partial range (end beyond real end)
let dl = ctx
.client
.download_byte_range(&path, 8, Some(data_len * 100))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[8..]);
// Partial range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data[4..]);
// Full range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(buf, data);
debug!("Cleanup: deleting file at path {path:?}");
ctx.client
.delete(&path)
.await
.with_context(|| format!("{path:?} removal"))?;
Ok(())
}
fn ensure_logging_ready() {
LOGGING_DONE.get_or_init(|| {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
)
.expect("logging init failed");
});
}
struct EnabledAzure {
client: Arc<GenericRemoteStorage>,
base_prefix: &'static str,
}
impl EnabledAzure {
async fn setup(max_keys_in_list_response: Option<i32>) -> Self {
let client = create_azure_client(max_keys_in_list_response)
.context("Azure client creation")
.expect("Azure client creation failed");
EnabledAzure {
client,
base_prefix: BASE_PREFIX,
}
}
}
enum MaybeEnabledAzure {
Enabled(EnabledAzure),
Disabled,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzure {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
Self::Enabled(EnabledAzure::setup(None).await)
}
}
enum MaybeEnabledAzureWithTestBlobs {
Enabled(AzureWithTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithTestBlobs),
}
struct AzureWithTestBlobs {
enabled: EnabledAzure,
remote_prefixes: HashSet<RemotePath>,
remote_blobs: HashSet<RemotePath>,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzureWithTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
let max_keys_in_list_response = 10;
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_azure_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
Self::Enabled(AzureWithTestBlobs {
enabled,
remote_prefixes: uploads.prefixes,
remote_blobs: uploads.blobs,
})
}
ControlFlow::Break(uploads) => Self::UploadsFailed(
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
AzureWithTestBlobs {
enabled,
remote_prefixes: uploads.prefixes,
remote_blobs: uploads.blobs,
},
),
}
}
async fn teardown(self) {
match self {
Self::Disabled => {}
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
}
}
}
}
// NOTE: the setups for the list_prefixes test and the list_files test are very similar
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
// whereas the list_files function is concerned with listing files.
// See `RemoteStorage::list_files` documentation for more details
enum MaybeEnabledAzureWithSimpleTestBlobs {
Enabled(AzureWithSimpleTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithSimpleTestBlobs),
}
struct AzureWithSimpleTestBlobs {
enabled: EnabledAzure,
remote_blobs: HashSet<RemotePath>,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledAzureWithSimpleTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
info!(
"`{}` env variable is not set, skipping the test",
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
);
return Self::Disabled;
}
let max_keys_in_list_response = 10;
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_simple_azure_data(&enabled.client, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
Self::Enabled(AzureWithSimpleTestBlobs {
enabled,
remote_blobs: uploads,
})
}
ControlFlow::Break(uploads) => Self::UploadsFailed(
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
AzureWithSimpleTestBlobs {
enabled,
remote_blobs: uploads,
},
),
}
}
async fn teardown(self) {
match self {
Self::Disabled => {}
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
}
}
}
}
fn create_azure_client(
max_keys_per_list_response: Option<i32>,
) -> anyhow::Result<Arc<GenericRemoteStorage>> {
use rand::Rng;
let remote_storage_azure_container = env::var("REMOTE_STORAGE_AZURE_CONTAINER").context(
"`REMOTE_STORAGE_AZURE_CONTAINER` env var is not set, but real Azure tests are enabled",
)?;
let remote_storage_azure_region = env::var("REMOTE_STORAGE_AZURE_REGION").context(
"`REMOTE_STORAGE_AZURE_REGION` env var is not set, but real Azure tests are enabled",
)?;
// due to how time works, we've had test runners use the same nanos as bucket prefixes.
// millis is just a debugging aid for easier finding the prefix later.
let millis = std::time::SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("random Azure test prefix part calculation")?
.as_millis();
// because nanos can be the same for two threads so can millis, add randomness
let random = rand::thread_rng().gen::<u32>();
let remote_storage_config = RemoteStorageConfig {
max_concurrent_syncs: NonZeroUsize::new(100).unwrap(),
max_sync_errors: NonZeroU32::new(5).unwrap(),
storage: RemoteStorageKind::AzureContainer(AzureConfig {
container_name: remote_storage_azure_container,
container_region: remote_storage_azure_region,
prefix_in_container: Some(format!("test_{millis}_{random:08x}/")),
concurrency_limit: NonZeroUsize::new(100).unwrap(),
max_keys_per_list_response,
}),
};
Ok(Arc::new(
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
))
}
struct Uploads {
prefixes: HashSet<RemotePath>,
blobs: HashSet<RemotePath>,
}
async fn upload_azure_data(
client: &Arc<GenericRemoteStorage>,
base_prefix_str: &'static str,
upload_tasks_count: usize,
) -> ControlFlow<Uploads, Uploads> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
debug!("Creating remote item {i} at path {blob_path:?}");
let data = format!("remote blob data {i}").into_bytes();
let data_len = data.len();
task_client
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
.await?;
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
});
}
let mut upload_tasks_failed = false;
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok((upload_prefix, upload_path)) => {
uploaded_prefixes.insert(upload_prefix);
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
let uploads = Uploads {
prefixes: uploaded_prefixes,
blobs: uploaded_blobs,
};
if upload_tasks_failed {
ControlFlow::Break(uploads)
} else {
ControlFlow::Continue(uploads)
}
}
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
info!(
"Removing {} objects from the remote storage during cleanup",
objects_to_delete.len()
);
let mut delete_tasks = JoinSet::new();
for object_to_delete in objects_to_delete {
let task_client = Arc::clone(client);
delete_tasks.spawn(async move {
debug!("Deleting remote item at path {object_to_delete:?}");
task_client
.delete(&object_to_delete)
.await
.with_context(|| format!("{object_to_delete:?} removal"))
});
}
while let Some(task_run_result) = delete_tasks.join_next().await {
match task_run_result {
Ok(task_result) => match task_result {
Ok(()) => {}
Err(e) => error!("Delete task failed: {e:?}"),
},
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
}
}
}
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
async fn upload_simple_azure_data(
client: &Arc<GenericRemoteStorage>,
upload_tasks_count: usize,
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
let blob_path = RemotePath::new(
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
)
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
debug!("Creating remote item {i} at path {blob_path:?}");
let data = format!("remote blob data {i}").into_bytes();
let data_len = data.len();
task_client
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
.await?;
Ok::<_, anyhow::Error>(blob_path)
});
}
let mut upload_tasks_failed = false;
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok(upload_path) => {
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
if upload_tasks_failed {
ControlFlow::Break(uploaded_blobs)
} else {
ControlFlow::Continue(uploaded_blobs)
}
}

View File

@@ -119,7 +119,6 @@ pub fn api_error_handler(api_error: ApiError) -> Response<Body> {
match api_error {
ApiError::ResourceUnavailable(_) => info!("Error processing HTTP request: {api_error:#}"),
ApiError::NotFound(_) => info!("Error processing HTTP request: {api_error:#}"),
ApiError::InternalServerError(_) => error!("Error processing HTTP request: {api_error:?}"),
_ => error!("Error processing HTTP request: {api_error:#}"),
}

View File

@@ -27,8 +27,8 @@ and old one if it exists.
* the filecache: a struct that allows communication with the Postgres file cache.
On startup, we connect to the filecache and hold on to the connection for the
entire monitor lifetime.
* the cgroup watcher: the `CgroupWatcher` polls the `neon-postgres` cgroup's memory
usage and sends rolling aggregates to the runner.
* the cgroup watcher: the `CgroupWatcher` manages the `neon-postgres` cgroup by
listening for `memory.high` events and setting its `memory.{high,max}` values.
* the runner: the runner marries the filecache and cgroup watcher together,
communicating with the agent throught the `Dispatcher`, and then calling filecache
and cgroup watcher functions as needed to upscale and downscale

View File

@@ -1,38 +1,161 @@
use std::fmt::{self, Debug, Formatter};
use std::time::{Duration, Instant};
use anyhow::{anyhow, Context};
use cgroups_rs::{
hierarchies::{self, is_cgroup2_unified_mode},
memory::MemController,
Subsystem,
use std::{
fmt::{Debug, Display},
fs,
pin::pin,
sync::atomic::{AtomicU64, Ordering},
};
use tokio::sync::watch;
use anyhow::{anyhow, bail, Context};
use cgroups_rs::{
freezer::FreezerController,
hierarchies::{self, is_cgroup2_unified_mode, UNIFIED_MOUNTPOINT},
memory::MemController,
MaxValue,
Subsystem::{Freezer, Mem},
};
use inotify::{EventStream, Inotify, WatchMask};
use tokio::sync::mpsc::{self, error::TryRecvError};
use tokio::time::{Duration, Instant};
use tokio_stream::{Stream, StreamExt};
use tracing::{info, warn};
use crate::protocol::Resources;
use crate::MiB;
/// Monotonically increasing counter of the number of memory.high events
/// the cgroup has experienced.
///
/// We use this to determine if a modification to the `memory.events` file actually
/// changed the `high` field. If not, we don't care about the change. When we
/// read the file, we check the `high` field in the file against `MEMORY_EVENT_COUNT`
/// to see if it changed since last time.
pub static MEMORY_EVENT_COUNT: AtomicU64 = AtomicU64::new(0);
/// Monotonically increasing counter that gives each cgroup event a unique id.
///
/// This allows us to answer questions like "did this upscale arrive before this
/// memory.high?". This static is also used by the `Sequenced` type to "tag" values
/// with a sequence number. As such, prefer to used the `Sequenced` type rather
/// than this static directly.
static EVENT_SEQUENCE_NUMBER: AtomicU64 = AtomicU64::new(0);
/// A memory event type reported in memory.events.
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum MemoryEvent {
Low,
High,
Max,
Oom,
OomKill,
OomGroupKill,
}
impl MemoryEvent {
fn as_str(&self) -> &str {
match self {
MemoryEvent::Low => "low",
MemoryEvent::High => "high",
MemoryEvent::Max => "max",
MemoryEvent::Oom => "oom",
MemoryEvent::OomKill => "oom_kill",
MemoryEvent::OomGroupKill => "oom_group_kill",
}
}
}
impl Display for MemoryEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
/// Configuration for a `CgroupWatcher`
#[derive(Debug, Clone)]
pub struct Config {
/// Interval at which we should be fetching memory statistics
memory_poll_interval: Duration,
// The target difference between the total memory reserved for the cgroup
// and the value of the cgroup's memory.high.
//
// In other words, memory.high + oom_buffer_bytes will equal the total memory that the cgroup may
// use (equal to system memory, minus whatever's taken out for the file cache).
oom_buffer_bytes: u64,
/// The number of samples used in constructing aggregated memory statistics
memory_history_len: usize,
/// The number of most recent samples that will be periodically logged.
///
/// Each sample is logged exactly once. Increasing this value means that recent samples will be
/// logged less frequently, and vice versa.
///
/// For simplicity, this value must be greater than or equal to `memory_history_len`.
memory_history_log_interval: usize,
// The amount of memory, in bytes, below a proposed new value for
// memory.high that the cgroup's memory usage must be for us to downscale
//
// In other words, we can downscale only when:
//
// memory.current + memory_high_buffer_bytes < (proposed) memory.high
//
// TODO: there's some minor issues with this approach -- in particular, that we might have
// memory in use by the kernel's page cache that we're actually ok with getting rid of.
pub(crate) memory_high_buffer_bytes: u64,
// The maximum duration, in milliseconds, that we're allowed to pause
// the cgroup for while waiting for the autoscaler-agent to upscale us
max_upscale_wait: Duration,
// The required minimum time, in milliseconds, that we must wait before re-freezing
// the cgroup while waiting for the autoscaler-agent to upscale us.
do_not_freeze_more_often_than: Duration,
// The amount of memory, in bytes, that we should periodically increase memory.high
// by while waiting for the autoscaler-agent to upscale us.
//
// This exists to avoid the excessive throttling that happens when a cgroup is above its
// memory.high for too long. See more here:
// https://github.com/neondatabase/autoscaling/issues/44#issuecomment-1522487217
memory_high_increase_by_bytes: u64,
// The period, in milliseconds, at which we should repeatedly increase the value
// of the cgroup's memory.high while we're waiting on upscaling and memory.high
// is still being hit.
//
// Technically speaking, this actually serves as a rate limit to moderate responding to
// memory.high events, but these are roughly equivalent if the process is still allocating
// memory.
memory_high_increase_every: Duration,
}
impl Config {
/// Calculate the new value for the cgroups memory.high based on system memory
pub fn calculate_memory_high_value(&self, total_system_mem: u64) -> u64 {
total_system_mem.saturating_sub(self.oom_buffer_bytes)
}
}
impl Default for Config {
fn default() -> Self {
Self {
memory_poll_interval: Duration::from_millis(100),
memory_history_len: 5, // use 500ms of history for decision-making
memory_history_log_interval: 20, // but only log every ~2s (otherwise it's spammy)
oom_buffer_bytes: 100 * MiB,
memory_high_buffer_bytes: 100 * MiB,
// while waiting for upscale, don't freeze for more than 20ms every 1s
max_upscale_wait: Duration::from_millis(20),
do_not_freeze_more_often_than: Duration::from_millis(1000),
// while waiting for upscale, increase memory.high by 10MiB every 25ms
memory_high_increase_by_bytes: 10 * MiB,
memory_high_increase_every: Duration::from_millis(25),
}
}
}
/// Used to represent data that is associated with a certain point in time, such
/// as an upscale request or memory.high event.
///
/// Internally, creating a `Sequenced` uses a static atomic counter to obtain
/// a unique sequence number. Sequence numbers are monotonically increasing,
/// allowing us to answer questions like "did this upscale happen after this
/// memory.high event?" by comparing the sequence numbers of the two events.
#[derive(Debug, Clone)]
pub struct Sequenced<T> {
seqnum: u64,
data: T,
}
impl<T> Sequenced<T> {
pub fn new(data: T) -> Self {
Self {
seqnum: EVENT_SEQUENCE_NUMBER.fetch_add(1, Ordering::AcqRel),
data,
}
}
}
@@ -47,14 +170,74 @@ impl Default for Config {
pub struct CgroupWatcher {
pub config: Config,
/// The sequence number of the last upscale.
///
/// If we receive a memory.high event that has a _lower_ sequence number than
/// `last_upscale_seqnum`, then we know it occured before the upscale, and we
/// can safely ignore it.
///
/// Note: Like the `events` field, this doesn't _need_ interior mutability but we
/// use it anyways so that methods take `&self`, not `&mut self`.
last_upscale_seqnum: AtomicU64,
/// A channel on which we send messages to request upscale from the dispatcher.
upscale_requester: mpsc::Sender<()>,
/// The actual cgroup we are watching and managing.
cgroup: cgroups_rs::Cgroup,
}
/// Read memory.events for the desired event type.
///
/// `path` specifies the path to the desired `memory.events` file.
/// For more info, see the `memory.events` section of the [kernel docs]
/// <https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files>
fn get_event_count(path: &str, event: MemoryEvent) -> anyhow::Result<u64> {
let contents = fs::read_to_string(path)
.with_context(|| format!("failed to read memory.events from {path}"))?;
// Then contents of the file look like:
// low 42
// high 101
// ...
contents
.lines()
.filter_map(|s| s.split_once(' '))
.find(|(e, _)| *e == event.as_str())
.ok_or_else(|| anyhow!("failed to find entry for memory.{event} events in {path}"))
.and_then(|(_, count)| {
count
.parse::<u64>()
.with_context(|| format!("failed to parse memory.{event} as u64"))
})
}
/// Create an event stream that produces events whenever the file at the provided
/// path is modified.
fn create_file_watcher(path: &str) -> anyhow::Result<EventStream<[u8; 1024]>> {
info!("creating file watcher for {path}");
let inotify = Inotify::init().context("failed to initialize file watcher")?;
inotify
.watches()
.add(path, WatchMask::MODIFY)
.with_context(|| format!("failed to start watching {path}"))?;
inotify
// The inotify docs use [0u8; 1024] so we'll just copy them. We only need
// to store one event at a time - if the event gets written over, that's
// ok. We still see that there is an event. For more information, see:
// https://man7.org/linux/man-pages/man7/inotify.7.html
.into_event_stream([0u8; 1024])
.context("failed to start inotify event stream")
}
impl CgroupWatcher {
/// Create a new `CgroupWatcher`.
#[tracing::instrument(skip_all, fields(%name))]
pub fn new(name: String) -> anyhow::Result<Self> {
pub fn new(
name: String,
// A channel on which to send upscale requests
upscale_requester: mpsc::Sender<()>,
) -> anyhow::Result<(Self, impl Stream<Item = Sequenced<u64>>)> {
// TODO: clarify exactly why we need v2
// Make sure cgroups v2 (aka unified) are supported
if !is_cgroup2_unified_mode() {
@@ -62,203 +245,410 @@ impl CgroupWatcher {
}
let cgroup = cgroups_rs::Cgroup::load(hierarchies::auto(), &name);
Ok(Self {
cgroup,
config: Default::default(),
})
// Start monitoring the cgroup for memory events. In general, for
// cgroups v2 (aka unified), metrics are reported in files like
// > `/sys/fs/cgroup/{name}/{metric}`
// We are looking for `memory.high` events, which are stored in the
// file `memory.events`. For more info, see the `memory.events` section
// of https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files
let path = format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name);
let memory_events = create_file_watcher(&path)
.with_context(|| format!("failed to create event watcher for {path}"))?
// This would be nice with with .inspect_err followed by .ok
.filter_map(move |_| match get_event_count(&path, MemoryEvent::High) {
Ok(high) => Some(high),
Err(error) => {
// TODO: Might want to just panic here
warn!(?error, "failed to read high events count from {}", &path);
None
}
})
// Only report the event if the memory.high count increased
.filter_map(|high| {
if MEMORY_EVENT_COUNT.fetch_max(high, Ordering::AcqRel) < high {
Some(high)
} else {
None
}
})
.map(Sequenced::new);
let initial_count = get_event_count(
&format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name),
MemoryEvent::High,
)?;
info!(initial_count, "initial memory.high event count");
// Hard update `MEMORY_EVENT_COUNT` since there could have been processes
// running in the cgroup before that caused it to be non-zero.
MEMORY_EVENT_COUNT.fetch_max(initial_count, Ordering::AcqRel);
Ok((
Self {
cgroup,
upscale_requester,
last_upscale_seqnum: AtomicU64::new(0),
config: Default::default(),
},
memory_events,
))
}
/// The entrypoint for the `CgroupWatcher`.
#[tracing::instrument(skip_all)]
pub async fn watch(
pub async fn watch<E>(
&self,
updates: watch::Sender<(Instant, MemoryHistory)>,
) -> anyhow::Result<()> {
// this requirement makes the code a bit easier to work with; see the config for more.
assert!(self.config.memory_history_len <= self.config.memory_history_log_interval);
// These are ~dependency injected~ (fancy, I know) because this function
// should never return.
// -> therefore: when we tokio::spawn it, we don't await the JoinHandle.
// -> therefore: if we want to stick it in an Arc so many threads can access
// it, methods can never take mutable access.
// - note: we use the Arc strategy so that a) we can call this function
// right here and b) the runner can call the set/get_memory methods
// -> since calling recv() on a tokio::sync::mpsc::Receiver takes &mut self,
// we just pass them in here instead of holding them in fields, as that
// would require this method to take &mut self.
mut upscales: mpsc::Receiver<Sequenced<Resources>>,
events: E,
) -> anyhow::Result<()>
where
E: Stream<Item = Sequenced<u64>>,
{
let mut wait_to_freeze = pin!(tokio::time::sleep(Duration::ZERO));
let mut last_memory_high_increase_at: Option<Instant> = None;
let mut events = pin!(events);
let mut ticker = tokio::time::interval(self.config.memory_poll_interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
// ticker.reset_immediately(); // FIXME: enable this once updating to tokio >= 1.30.0
// Are we waiting to be upscaled? Could be true if we request upscale due
// to a memory.high event and it does not arrive in time.
let mut waiting_on_upscale = false;
let mem_controller = self.memory()?;
loop {
tokio::select! {
upscale = upscales.recv() => {
let Sequenced { seqnum, data } = upscale
.context("failed to listen on upscale notification channel")?;
waiting_on_upscale = false;
last_memory_high_increase_at = None;
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
}
event = events.next() => {
let Some(Sequenced { seqnum, .. }) = event else {
bail!("failed to listen for memory.high events")
};
// The memory.high came before our last upscale, so we consider
// it resolved
if self.last_upscale_seqnum.fetch_max(seqnum, Ordering::AcqRel) > seqnum {
info!(
"received memory.high event, but it came before our last upscale -> ignoring it"
);
continue;
}
// buffer for samples that will be logged. once full, it remains so.
let history_log_len = self.config.memory_history_log_interval;
let mut history_log_buf = vec![MemoryStatus::zeroed(); history_log_len];
// The memory.high came after our latest upscale. We don't
// want to do anything yet, so peek the next event in hopes
// that it's an upscale.
if let Some(upscale_num) = self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
{
if upscale_num > seqnum {
info!(
"received memory.high event, but it came before our last upscale -> ignoring it"
);
continue;
}
}
for t in 0_u64.. {
ticker.tick().await;
// If it's been long enough since we last froze, freeze the
// cgroup and request upscale
if wait_to_freeze.is_elapsed() {
info!("received memory.high event -> requesting upscale");
waiting_on_upscale = self
.handle_memory_high_event(&mut upscales)
.await
.context("failed to handle upscale")?;
wait_to_freeze
.as_mut()
.reset(Instant::now() + self.config.do_not_freeze_more_often_than);
continue;
}
let now = Instant::now();
let mem = Self::memory_usage(mem_controller);
// Ok, we can't freeze, just request upscale
if !waiting_on_upscale {
info!("received memory.high event, but too soon to refreeze -> requesting upscale");
let i = t as usize % history_log_len;
history_log_buf[i] = mem;
// Make check to make sure we haven't been upscaled in the
// meantine (can happen if the agent independently decides
// to upscale us again)
if self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
.is_some()
{
info!("no need to request upscaling because we got upscaled");
continue;
}
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
waiting_on_upscale = true;
continue;
}
// We're taking *at most* memory_history_len values; we may be bounded by the total
// number of samples that have come in so far.
let samples_count = (t + 1).min(self.config.memory_history_len as u64) as usize;
// NB: in `ring_buf_recent_values_iter`, `i` is *inclusive*, which matches the fact
// that we just inserted a value there, so the end of the iterator will *include* the
// value at i, rather than stopping just short of it.
let samples = ring_buf_recent_values_iter(&history_log_buf, i, samples_count);
// Shoot, we can't freeze or and we're still waiting on upscale,
// increase memory.high to reduce throttling
let can_increase_memory_high = match last_memory_high_increase_at {
None => true,
Some(t) => t.elapsed() > self.config.memory_high_increase_every,
};
if can_increase_memory_high {
info!(
"received memory.high event, \
but too soon to refreeze and already requested upscale \
-> increasing memory.high"
);
let summary = MemoryHistory {
avg_non_reclaimable: samples.map(|h| h.non_reclaimable).sum::<u64>()
/ samples_count as u64,
samples_count,
samples_span: self.config.memory_poll_interval * (samples_count - 1) as u32,
// Make check to make sure we haven't been upscaled in the
// meantine (can happen if the agent independently decides
// to upscale us again)
if self
.upscaled(&mut upscales)
.context("failed to check if we were upscaled")?
.is_some()
{
info!("no need to increase memory.high because got upscaled");
continue;
}
// Request upscale anyways (the agent will handle deduplicating
// requests)
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
let memory_high =
self.get_memory_high_bytes().context("failed to get memory.high")?;
let new_high = memory_high + self.config.memory_high_increase_by_bytes;
info!(
current_high_bytes = memory_high,
new_high_bytes = new_high,
"updating memory.high"
);
self.set_memory_high_bytes(new_high)
.context("failed to set memory.high")?;
last_memory_high_increase_at = Some(Instant::now());
continue;
}
info!("received memory.high event, but can't do anything");
}
};
// Log the current history if it's time to do so. Because `history_log_buf` has length
// equal to the logging interval, we can just log the entire buffer every time we set
// the last entry, which also means that for this log line, we can ignore that it's a
// ring buffer (because all the entries are in order of increasing time).
if i == history_log_len - 1 {
info!(
history = ?MemoryStatus::debug_slice(&history_log_buf),
summary = ?summary,
"Recent cgroup memory statistics history"
);
}
updates
.send((now, summary))
.context("failed to send MemoryHistory")?;
}
}
unreachable!()
/// Handle a `memory.high`, returning whether we are still waiting on upscale
/// by the time the function returns.
///
/// The general plan for handling a `memory.high` event is as follows:
/// 1. Freeze the cgroup
/// 2. Start a timer for `self.config.max_upscale_wait`
/// 3. Request upscale
/// 4. After the timer elapses or we receive upscale, thaw the cgroup.
/// 5. Return whether or not we are still waiting for upscale. If we are,
/// we'll increase the cgroups memory.high to avoid getting oom killed
#[tracing::instrument(skip_all)]
async fn handle_memory_high_event(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
) -> anyhow::Result<bool> {
// Immediately freeze the cgroup before doing anything else.
info!("received memory.high event -> freezing cgroup");
self.freeze().context("failed to freeze cgroup")?;
// We'll use this for logging durations
let start_time = Instant::now();
// Await the upscale until we have to unfreeze
let timed =
tokio::time::timeout(self.config.max_upscale_wait, self.await_upscale(upscales));
// Request the upscale
info!(
wait = ?self.config.max_upscale_wait,
"sending request for immediate upscaling",
);
self.upscale_requester
.send(())
.await
.context("failed to request upscale")?;
let waiting_on_upscale = match timed.await {
Ok(Ok(())) => {
info!(elapsed = ?start_time.elapsed(), "received upscale in time");
false
}
// **important**: unfreeze the cgroup before ?-reporting the error
Ok(Err(e)) => {
info!("error waiting for upscale -> thawing cgroup");
self.thaw()
.context("failed to thaw cgroup after errored waiting for upscale")?;
Err(e.context("failed to await upscale"))?
}
Err(_) => {
info!(elapsed = ?self.config.max_upscale_wait, "timed out waiting for upscale");
true
}
};
info!("thawing cgroup");
self.thaw().context("failed to thaw cgroup")?;
Ok(waiting_on_upscale)
}
/// Checks whether we were just upscaled, returning the upscale's sequence
/// number if so.
#[tracing::instrument(skip_all)]
fn upscaled(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
) -> anyhow::Result<Option<u64>> {
let Sequenced { seqnum, data } = match upscales.try_recv() {
Ok(upscale) => upscale,
Err(TryRecvError::Empty) => return Ok(None),
Err(TryRecvError::Disconnected) => {
bail!("upscale notification channel was disconnected")
}
};
// Make sure to update the last upscale sequence number
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
Ok(Some(seqnum))
}
/// Await an upscale event, discarding any `memory.high` events received in
/// the process.
///
/// This is used in `handle_memory_high_event`, where we need to listen
/// for upscales in particular so we know if we can thaw the cgroup early.
#[tracing::instrument(skip_all)]
async fn await_upscale(
&self,
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
) -> anyhow::Result<()> {
let Sequenced { seqnum, .. } = upscales
.recv()
.await
.context("error listening for upscales")?;
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
Ok(())
}
/// Get the cgroup's name.
pub fn path(&self) -> &str {
self.cgroup.path()
}
}
// Methods for manipulating the actual cgroup
impl CgroupWatcher {
/// Get a handle on the freezer subsystem.
fn freezer(&self) -> anyhow::Result<&FreezerController> {
if let Some(Freezer(freezer)) = self
.cgroup
.subsystems()
.iter()
.find(|sub| matches!(sub, Freezer(_)))
{
Ok(freezer)
} else {
anyhow::bail!("could not find freezer subsystem")
}
}
/// Attempt to freeze the cgroup.
pub fn freeze(&self) -> anyhow::Result<()> {
self.freezer()
.context("failed to get freezer subsystem")?
.freeze()
.context("failed to freeze")
}
/// Attempt to thaw the cgroup.
pub fn thaw(&self) -> anyhow::Result<()> {
self.freezer()
.context("failed to get freezer subsystem")?
.thaw()
.context("failed to thaw")
}
/// Get a handle on the memory subsystem.
///
/// Note: this method does not require `self.memory_update_lock` because
/// getting a handle to the subsystem does not access any of the files we
/// care about, such as memory.high and memory.events
fn memory(&self) -> anyhow::Result<&MemController> {
self.cgroup
if let Some(Mem(memory)) = self
.cgroup
.subsystems()
.iter()
.find_map(|sub| match sub {
Subsystem::Mem(c) => Some(c),
_ => None,
.find(|sub| matches!(sub, Mem(_)))
{
Ok(memory)
} else {
anyhow::bail!("could not find memory subsystem")
}
}
/// Get cgroup current memory usage.
pub fn current_memory_usage(&self) -> anyhow::Result<u64> {
Ok(self
.memory()
.context("failed to get memory subsystem")?
.memory_stat()
.usage_in_bytes)
}
/// Set cgroup memory.high threshold.
pub fn set_memory_high_bytes(&self, bytes: u64) -> anyhow::Result<()> {
self.set_memory_high_internal(MaxValue::Value(u64::min(bytes, i64::MAX as u64) as i64))
}
/// Set the cgroup's memory.high to 'max', disabling it.
pub fn unset_memory_high(&self) -> anyhow::Result<()> {
self.set_memory_high_internal(MaxValue::Max)
}
fn set_memory_high_internal(&self, value: MaxValue) -> anyhow::Result<()> {
self.memory()
.context("failed to get memory subsystem")?
.set_mem(cgroups_rs::memory::SetMemory {
low: None,
high: Some(value),
min: None,
max: None,
})
.ok_or_else(|| anyhow!("could not find memory subsystem"))
.map_err(anyhow::Error::from)
}
/// Given a handle on the memory subsystem, returns the current memory information
fn memory_usage(mem_controller: &MemController) -> MemoryStatus {
let stat = mem_controller.memory_stat().stat;
MemoryStatus {
non_reclaimable: stat.active_anon + stat.inactive_anon,
/// Get memory.high threshold.
pub fn get_memory_high_bytes(&self) -> anyhow::Result<u64> {
let high = self
.memory()
.context("failed to get memory subsystem while getting memory statistics")?
.get_mem()
.map(|mem| mem.high)
.context("failed to get memory statistics from subsystem")?;
match high {
Some(MaxValue::Max) => Ok(i64::MAX as u64),
Some(MaxValue::Value(high)) => Ok(high as u64),
None => anyhow::bail!("failed to read memory.high from memory subsystem"),
}
}
}
// Helper function for `CgroupWatcher::watch`
fn ring_buf_recent_values_iter<T>(
buf: &[T],
last_value_idx: usize,
count: usize,
) -> impl '_ + Iterator<Item = &T> {
// Assertion carried over from `CgroupWatcher::watch`, to make the logic in this function
// easier (we only have to add `buf.len()` once, rather than a dynamic number of times).
assert!(count <= buf.len());
buf.iter()
// 'cycle' because the values could wrap around
.cycle()
// with 'cycle', this skip is more like 'offset', and functionally this is
// offsettting by 'last_value_idx - count (mod buf.len())', but we have to be
// careful to avoid underflow, so we pre-add buf.len().
// The '+ 1' is because `last_value_idx` is inclusive, rather than exclusive.
.skip((buf.len() + last_value_idx + 1 - count) % buf.len())
.take(count)
}
/// Summary of recent memory usage
#[derive(Debug, Copy, Clone)]
pub struct MemoryHistory {
/// Rolling average of non-reclaimable memory usage samples over the last `history_period`
pub avg_non_reclaimable: u64,
/// The number of samples used to construct this summary
pub samples_count: usize,
/// Total timespan between the first and last sample used for this summary
pub samples_span: Duration,
}
#[derive(Debug, Copy, Clone)]
pub struct MemoryStatus {
non_reclaimable: u64,
}
impl MemoryStatus {
fn zeroed() -> Self {
MemoryStatus { non_reclaimable: 0 }
}
fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
struct DS<'a>(&'a [MemoryStatus]);
impl<'a> Debug for DS<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("[MemoryStatus]")
.field(
"non_reclaimable[..]",
&Fields(self.0, |stat: &MemoryStatus| {
BytesToGB(stat.non_reclaimable)
}),
)
.finish()
}
}
struct Fields<'a, F>(&'a [MemoryStatus], F);
impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_list().entries(self.0.iter().map(&self.1)).finish()
}
}
struct BytesToGB(u64);
impl Debug for BytesToGB {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.write_fmt(format_args!(
"{:.3}Gi",
self.0 as f64 / (1_u64 << 30) as f64
))
}
}
DS(slice)
}
}
#[cfg(test)]
mod tests {
#[test]
fn ring_buf_iter() {
let buf = vec![0_i32, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let values = |offset, count| {
super::ring_buf_recent_values_iter(&buf, offset, count)
.copied()
.collect::<Vec<i32>>()
};
// Boundary conditions: start, end, and entire thing:
assert_eq!(values(0, 1), [0]);
assert_eq!(values(3, 4), [0, 1, 2, 3]);
assert_eq!(values(9, 4), [6, 7, 8, 9]);
assert_eq!(values(9, 10), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
// "normal" operation: no wraparound
assert_eq!(values(7, 4), [4, 5, 6, 7]);
// wraparound:
assert_eq!(values(0, 4), [7, 8, 9, 0]);
assert_eq!(values(1, 4), [8, 9, 0, 1]);
assert_eq!(values(2, 4), [9, 0, 1, 2]);
assert_eq!(values(2, 10), [3, 4, 5, 6, 7, 8, 9, 0, 1, 2]);
}
}

View File

@@ -12,10 +12,12 @@ use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::sync::mpsc;
use tracing::info;
use crate::cgroup::Sequenced;
use crate::protocol::{
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, PROTOCOL_MAX_VERSION,
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, Resources, PROTOCOL_MAX_VERSION,
PROTOCOL_MIN_VERSION,
};
@@ -34,6 +36,13 @@ pub struct Dispatcher {
/// We send messages to the agent through `sink`
sink: SplitSink<WebSocket, Message>,
/// Used to notify the cgroup when we are upscaled.
pub(crate) notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
/// When the cgroup requests upscale it will send on this channel. In response
/// we send an `UpscaleRequst` to the agent.
pub(crate) request_upscale_events: mpsc::Receiver<()>,
/// The protocol version we have agreed to use with the agent. This is negotiated
/// during the creation of the dispatcher, and should be the highest shared protocol
/// version.
@@ -52,7 +61,11 @@ impl Dispatcher {
/// 1. Wait for the agent to sent the range of protocols it supports.
/// 2. Send a protocol version that works for us as well, or an error if there
/// is no compatible version.
pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
pub async fn new(
stream: WebSocket,
notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
request_upscale_events: mpsc::Receiver<()>,
) -> anyhow::Result<Self> {
let (mut sink, mut source) = stream.split();
// Figure out the highest protocol version we both support
@@ -106,10 +119,22 @@ impl Dispatcher {
Ok(Self {
sink,
source,
notify_upscale_events,
request_upscale_events,
proto_version: highest_shared_version,
})
}
/// Notify the cgroup manager that we have received upscale and wait for
/// the acknowledgement.
#[tracing::instrument(skip_all, fields(?resources))]
pub async fn notify_upscale(&self, resources: Sequenced<Resources>) -> anyhow::Result<()> {
self.notify_upscale_events
.send(resources)
.await
.context("failed to send resources and oneshot sender across channel")
}
/// Send a message to the agent.
///
/// Although this function is small, it has one major benefit: it is the only

View File

@@ -5,16 +5,18 @@
//! all functionality.
use std::fmt::Debug;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{bail, Context};
use axum::extract::ws::{Message, WebSocket};
use futures::StreamExt;
use tokio::sync::{broadcast, watch};
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use crate::cgroup::{self, CgroupWatcher};
use crate::cgroup::{CgroupWatcher, Sequenced};
use crate::dispatcher::Dispatcher;
use crate::filecache::{FileCacheConfig, FileCacheState};
use crate::protocol::{InboundMsg, InboundMsgKind, OutboundMsg, OutboundMsgKind, Resources};
@@ -26,7 +28,7 @@ use crate::{bytes_to_mebibytes, get_total_system_memory, spawn_with_cancel, Args
pub struct Runner {
config: Config,
filecache: Option<FileCacheState>,
cgroup: Option<CgroupState>,
cgroup: Option<Arc<CgroupWatcher>>,
dispatcher: Dispatcher,
/// We "mint" new message ids by incrementing this counter and taking the value.
@@ -43,14 +45,6 @@ pub struct Runner {
kill: broadcast::Receiver<()>,
}
#[derive(Debug)]
struct CgroupState {
watcher: watch::Receiver<(Instant, cgroup::MemoryHistory)>,
/// If [`cgroup::MemoryHistory::avg_non_reclaimable`] exceeds `threshold`, we send upscale
/// requests.
threshold: u64,
}
/// Configuration for a `Runner`
#[derive(Debug)]
pub struct Config {
@@ -68,56 +62,16 @@ pub struct Config {
/// upscale resource amounts (because we might not *actually* have been upscaled yet). This field
/// should be removed once we have a better solution there.
sys_buffer_bytes: u64,
/// Minimum fraction of total system memory reserved *before* the the cgroup threshold; in
/// other words, providing a ceiling for the highest value of the threshold by enforcing that
/// there's at least `cgroup_min_overhead_fraction` of the total memory remaining beyond the
/// threshold.
///
/// For example, a value of `0.1` means that 10% of total memory must remain after exceeding
/// the threshold, so the value of the cgroup threshold would always be capped at 90% of total
/// memory.
///
/// The default value of `0.15` means that we *guarantee* sending upscale requests if the
/// cgroup is using more than 85% of total memory (even if we're *not* separately reserving
/// memory for the file cache).
cgroup_min_overhead_fraction: f64,
cgroup_downscale_threshold_buffer_bytes: u64,
}
impl Default for Config {
fn default() -> Self {
Self {
sys_buffer_bytes: 100 * MiB,
cgroup_min_overhead_fraction: 0.15,
cgroup_downscale_threshold_buffer_bytes: 100 * MiB,
}
}
}
impl Config {
fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 {
// If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup,
// and thus be non-reclaimable, so we should allow for additional memory usage.
//
// If the file cache sits on disk, our desired stable system state is for it to be fully
// page cached (its contents should only be paged to/from disk in situations where we can't
// upscale fast enough). Page-cached memory is reclaimable, so we need to lower the
// threshold for non-reclaimable memory so we scale up *before* the kernel starts paging
// out the file cache.
let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
// Even if we're not separately making room for the file cache (if it's in tmpfs), we still
// want our threshold to be met gracefully instead of letting postgres get OOM-killed.
// So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory
// remaining above the threshold.
let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64;
memory_remaining_for_cgroup.min(max_threshold)
}
}
impl Runner {
/// Create a new monitor.
#[tracing::instrument(skip_all, fields(?config, ?args))]
@@ -133,7 +87,12 @@ impl Runner {
"invalid monitor Config: sys_buffer_bytes cannot be 0"
);
let dispatcher = Dispatcher::new(ws)
// *NOTE*: the dispatcher and cgroup manager talk through these channels
// so make sure they each get the correct half, nothing is droppped, etc.
let (notified_send, notified_recv) = mpsc::channel(1);
let (requesting_send, requesting_recv) = mpsc::channel(1);
let dispatcher = Dispatcher::new(ws, notified_send, requesting_recv)
.await
.context("error creating new dispatcher")?;
@@ -147,9 +106,45 @@ impl Runner {
kill,
};
let mem = get_total_system_memory();
// If we have both the cgroup and file cache integrations enabled, it's possible for
// temporary failures to result in cgroup throttling (from memory.high), that in turn makes
// it near-impossible to connect to the file cache (because it times out). Unfortunately,
// we *do* still want to determine the file cache size before setting the cgroup's
// memory.high, so it's not as simple as just swapping the order.
//
// Instead, the resolution here is that on vm-monitor startup (note: happens on each
// connection from autoscaler-agent, possibly multiple times per compute_ctl lifecycle), we
// temporarily unset memory.high, to allow any existing throttling to dissipate. It's a bit
// of a hacky solution, but helps with reliability.
if let Some(name) = &args.cgroup {
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
// now, and then set limits later.
info!("initializing cgroup");
let mut file_cache_disk_size = 0;
let (cgroup, cgroup_event_stream) = CgroupWatcher::new(name.clone(), requesting_send)
.context("failed to create cgroup manager")?;
info!("temporarily unsetting memory.high");
// Temporarily un-set cgroup memory.high; see above.
cgroup
.unset_memory_high()
.context("failed to unset memory.high")?;
let cgroup = Arc::new(cgroup);
let cgroup_clone = Arc::clone(&cgroup);
spawn_with_cancel(
token.clone(),
|_| error!("cgroup watcher terminated"),
async move { cgroup_clone.watch(notified_recv, cgroup_event_stream).await },
);
state.cgroup = Some(cgroup);
}
let mut file_cache_reserved_bytes = 0;
let mem = get_total_system_memory();
// We need to process file cache initialization before cgroup initialization, so that the memory
// allocated to the file cache is appropriately taken into account when we decide the cgroup's
@@ -161,7 +156,7 @@ impl Runner {
false => FileCacheConfig::default_in_memory(),
};
let mut file_cache = FileCacheState::new(connstr, config, token.clone())
let mut file_cache = FileCacheState::new(connstr, config, token)
.await
.context("failed to create file cache")?;
@@ -186,40 +181,23 @@ impl Runner {
if actual_size != new_size {
info!("file cache size actually got set to {actual_size}")
}
if args.file_cache_on_disk {
file_cache_disk_size = actual_size;
// Mark the resources given to the file cache as reserved, but only if it's in memory.
if !args.file_cache_on_disk {
file_cache_reserved_bytes = actual_size;
}
state.filecache = Some(file_cache);
}
if let Some(name) = &args.cgroup {
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
// now, and then set limits later.
info!("initializing cgroup");
if let Some(cgroup) = &state.cgroup {
let available = mem - file_cache_reserved_bytes;
let value = cgroup.config.calculate_memory_high_value(available);
let cgroup =
CgroupWatcher::new(name.clone()).context("failed to create cgroup manager")?;
info!(value, "setting memory.high");
let init_value = cgroup::MemoryHistory {
avg_non_reclaimable: 0,
samples_count: 0,
samples_span: Duration::ZERO,
};
let (hist_tx, hist_rx) = watch::channel((Instant::now(), init_value));
spawn_with_cancel(token, |_| error!("cgroup watcher terminated"), async move {
cgroup.watch(hist_tx).await
});
let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size);
info!(threshold, "set initial cgroup threshold",);
state.cgroup = Some(CgroupState {
watcher: hist_rx,
threshold,
});
cgroup
.set_memory_high_bytes(value)
.context("failed to set cgroup memory.high")?;
}
Ok(state)
@@ -239,51 +217,28 @@ impl Runner {
let requested_mem = target.mem;
let usable_system_memory = requested_mem.saturating_sub(self.config.sys_buffer_bytes);
let (expected_file_cache_size, expected_file_cache_disk_size) = self
let expected_file_cache_mem_usage = self
.filecache
.as_ref()
.map(|file_cache| {
let size = file_cache.config.calculate_cache_size(usable_system_memory);
match file_cache.config.in_memory {
true => (size, 0),
false => (size, size),
}
})
.unwrap_or((0, 0));
.map(|file_cache| file_cache.config.calculate_cache_size(usable_system_memory))
.unwrap_or(0);
let mut new_cgroup_mem_high = 0;
if let Some(cgroup) = &self.cgroup {
let (last_time, last_history) = *cgroup.watcher.borrow();
// NB: The ordering of these conditions is intentional. During startup, we should deny
// downscaling until we have enough information to determine that it's safe to do so
// (i.e. enough samples have come in). But if it's been a while and we *still* haven't
// received any information, we should *fail* instead of just denying downscaling.
//
// `last_time` is set to `Instant::now()` on startup, so checking `last_time.elapsed()`
// serves double-duty: it trips if we haven't received *any* metrics for long enough,
// OR if we haven't received metrics *recently enough*.
//
// TODO: make the duration here configurable.
if last_time.elapsed() > Duration::from_secs(5) {
bail!("haven't gotten cgroup memory stats recently enough to determine downscaling information");
} else if last_history.samples_count <= 1 {
let status = "haven't received enough cgroup memory stats yet";
info!(status, "discontinuing downscale");
return Ok((false, status.to_owned()));
}
let new_threshold = self
new_cgroup_mem_high = cgroup
.config
.cgroup_threshold(usable_system_memory, expected_file_cache_disk_size);
.calculate_memory_high_value(usable_system_memory - expected_file_cache_mem_usage);
let current = last_history.avg_non_reclaimable;
let current = cgroup
.current_memory_usage()
.context("failed to fetch cgroup memory")?;
if new_threshold < current + self.config.cgroup_downscale_threshold_buffer_bytes {
if new_cgroup_mem_high < current + cgroup.config.memory_high_buffer_bytes {
let status = format!(
"{}: {} MiB (new threshold) < {} (current usage) + {} (downscale buffer)",
"calculated memory threshold too low",
bytes_to_mebibytes(new_threshold),
"{}: {} MiB (new high) < {} (current usage) + {} (buffer)",
"calculated memory.high too low",
bytes_to_mebibytes(new_cgroup_mem_high),
bytes_to_mebibytes(current),
bytes_to_mebibytes(self.config.cgroup_downscale_threshold_buffer_bytes)
bytes_to_mebibytes(cgroup.config.memory_high_buffer_bytes)
);
info!(status, "discontinuing downscale");
@@ -294,14 +249,14 @@ impl Runner {
// The downscaling has been approved. Downscale the file cache, then the cgroup.
let mut status = vec![];
let mut file_cache_disk_size = 0;
let mut file_cache_mem_usage = 0;
if let Some(file_cache) = &mut self.filecache {
let actual_usage = file_cache
.set_file_cache_size(expected_file_cache_size)
.set_file_cache_size(expected_file_cache_mem_usage)
.await
.context("failed to set file cache size")?;
if !file_cache.config.in_memory {
file_cache_disk_size = actual_usage;
if file_cache.config.in_memory {
file_cache_mem_usage = actual_usage;
}
let message = format!(
"set file cache size to {} MiB (in memory = {})",
@@ -312,18 +267,24 @@ impl Runner {
status.push(message);
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
if let Some(cgroup) = &self.cgroup {
let available_memory = usable_system_memory - file_cache_mem_usage;
if file_cache_mem_usage != expected_file_cache_mem_usage {
new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
}
// new_cgroup_mem_high is initialized to 0 but it is guaranteed to not be here
// since it is properly initialized in the previous cgroup if let block
cgroup
.set_memory_high_bytes(new_cgroup_mem_high)
.context("failed to set cgroup memory.high")?;
let message = format!(
"set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB",
bytes_to_mebibytes(cgroup.threshold),
bytes_to_mebibytes(new_threshold),
bytes_to_mebibytes(usable_system_memory)
"set cgroup memory.high to {} MiB, of new max {} MiB",
bytes_to_mebibytes(new_cgroup_mem_high),
bytes_to_mebibytes(available_memory)
);
cgroup.threshold = new_threshold;
info!("downscale: {message}");
status.push(message);
}
@@ -344,7 +305,8 @@ impl Runner {
let new_mem = resources.mem;
let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes);
let mut file_cache_disk_size = 0;
// Get the file cache's expected contribution to the memory usage
let mut file_cache_mem_usage = 0;
if let Some(file_cache) = &mut self.filecache {
let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory);
info!(
@@ -357,8 +319,8 @@ impl Runner {
.set_file_cache_size(expected_usage)
.await
.context("failed to set file cache size")?;
if !file_cache.config.in_memory {
file_cache_disk_size = actual_usage;
if file_cache.config.in_memory {
file_cache_mem_usage = actual_usage;
}
if actual_usage != expected_usage {
@@ -370,18 +332,18 @@ impl Runner {
}
}
if let Some(cgroup) = &mut self.cgroup {
let new_threshold = self
.config
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
if let Some(cgroup) = &self.cgroup {
let available_memory = usable_system_memory - file_cache_mem_usage;
let new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
info!(
"set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB",
bytes_to_mebibytes(cgroup.threshold),
bytes_to_mebibytes(new_threshold),
bytes_to_mebibytes(usable_system_memory)
target = bytes_to_mebibytes(new_cgroup_mem_high),
total = bytes_to_mebibytes(new_mem),
name = cgroup.path(),
"updating cgroup memory.high",
);
cgroup.threshold = new_threshold;
cgroup
.set_memory_high_bytes(new_cgroup_mem_high)
.context("failed to set cgroup memory.high")?;
}
Ok(())
@@ -399,6 +361,10 @@ impl Runner {
self.handle_upscale(granted)
.await
.context("failed to handle upscale")?;
self.dispatcher
.notify_upscale(Sequenced::new(granted))
.await
.context("failed to notify notify cgroup of upscale")?;
Ok(Some(OutboundMsg::new(
OutboundMsgKind::UpscaleConfirmation {},
id,
@@ -442,53 +408,33 @@ impl Runner {
Err(e) => bail!("failed to receive kill signal: {e}")
}
}
// New memory stats from the cgroup, *may* need to request upscaling, if we've
// exceeded the threshold
result = self.cgroup.as_mut().unwrap().watcher.changed(), if self.cgroup.is_some() => {
result.context("failed to receive from cgroup memory stats watcher")?;
let cgroup = self.cgroup.as_ref().unwrap();
let (_time, cgroup_mem_stat) = *cgroup.watcher.borrow();
// If we haven't exceeded the threshold, then we're all ok
if cgroup_mem_stat.avg_non_reclaimable < cgroup.threshold {
continue;
// we need to propagate an upscale request
request = self.dispatcher.request_upscale_events.recv(), if self.cgroup.is_some() => {
if request.is_none() {
bail!("failed to listen for upscale event from cgroup")
}
// Otherwise, we generally want upscaling. But, if it's been less than 1 second
// since the last time we requested upscaling, ignore the event, to avoid
// spamming the agent.
// If it's been less than 1 second since the last time we requested upscaling,
// ignore the event, to avoid spamming the agent (otherwise, this can happen
// ~1k times per second).
if let Some(t) = self.last_upscale_request_at {
let elapsed = t.elapsed();
if elapsed < Duration::from_secs(1) {
info!(
elapsed_millis = elapsed.as_millis(),
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
threshold = bytes_to_mebibytes(cgroup.threshold),
"cgroup memory stats are high enough to upscale but too soon to forward the request, ignoring",
);
info!(elapsed_millis = elapsed.as_millis(), "cgroup asked for upscale but too soon to forward the request, ignoring");
continue;
}
}
self.last_upscale_request_at = Some(Instant::now());
info!(
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
threshold = bytes_to_mebibytes(cgroup.threshold),
"cgroup memory stats are high enough to upscale, requesting upscale",
);
info!("cgroup asking for upscale; forwarding request");
self.counter += 2; // Increment, preserving parity (i.e. keep the
// counter odd). See the field comment for more.
self.dispatcher
.send(OutboundMsg::new(OutboundMsgKind::UpscaleRequest {}, self.counter))
.await
.context("failed to send message")?;
},
}
// there is a message from the agent
msg = self.dispatcher.source.next() => {
if let Some(msg) = msg {
@@ -516,14 +462,11 @@ impl Runner {
Ok(Some(out)) => out,
Ok(None) => continue,
Err(e) => {
// use {:#} for our logging because the display impl only
// gives the outermost cause, and the debug impl
// pretty-prints the error, whereas {:#} contains all the
// causes, but is compact (no newlines).
warn!(error = format!("{e:#}"), "error handling message");
let error = e.to_string();
warn!(?error, "error handling message");
OutboundMsg::new(
OutboundMsgKind::InternalError {
error: e.to_string(),
error
},
message.id
)

View File

@@ -1,16 +0,0 @@
[package]
name = "walproposer"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
utils.workspace = true
postgres_ffi.workspace = true
workspace_hack.workspace = true
[build-dependencies]
anyhow.workspace = true
bindgen.workspace = true

View File

@@ -1 +0,0 @@
#include "walproposer.h"

View File

@@ -1,113 +0,0 @@
use std::{env, path::PathBuf, process::Command};
use anyhow::{anyhow, Context};
use bindgen::CargoCallbacks;
fn main() -> anyhow::Result<()> {
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=bindgen_deps.h");
// Finding the location of built libraries and Postgres C headers:
// - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/pg_install`
// - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/pg_install/{PG_MAJORVERSION}/include/postgresql/server`
let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
postgres_install_dir.into()
} else {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pg_install")
};
let pg_install_abs = std::fs::canonicalize(pg_install_dir)?;
let walproposer_lib_dir = pg_install_abs.join("build/walproposer-lib");
let walproposer_lib_search_str = walproposer_lib_dir
.to_str()
.ok_or(anyhow!("Bad non-UTF path"))?;
let pgxn_neon = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pgxn/neon");
let pgxn_neon = std::fs::canonicalize(pgxn_neon)?;
let pgxn_neon = pgxn_neon.to_str().ok_or(anyhow!("Bad non-UTF path"))?;
println!("cargo:rustc-link-lib=static=pgport");
println!("cargo:rustc-link-lib=static=pgcommon");
println!("cargo:rustc-link-lib=static=walproposer");
println!("cargo:rustc-link-search={walproposer_lib_search_str}");
let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")
.output()
.context("failed to execute `pg_config --includedir-server`")?;
if !output.status.success() {
panic!("`pg_config --includedir-server` failed")
}
String::from_utf8(output.stdout)
.context("pg_config output is not UTF-8")?
.trim_end()
.into()
} else {
let server_path = pg_install_abs
.join("v16")
.join("include")
.join("postgresql")
.join("server")
.into_os_string();
server_path
.into_string()
.map_err(|s| anyhow!("Bad postgres server path {s:?}"))?
};
// The bindgen::Builder is the main entry point
// to bindgen, and lets you build up options for
// the resulting bindings.
let bindings = bindgen::Builder::default()
// The input header we would like to generate
// bindings for.
.header("bindgen_deps.h")
// Tell cargo to invalidate the built crate whenever any of the
// included header files changed.
.parse_callbacks(Box::new(CargoCallbacks))
.allowlist_type("WalProposer")
.allowlist_type("WalProposerConfig")
.allowlist_type("walproposer_api")
.allowlist_function("WalProposerCreate")
.allowlist_function("WalProposerStart")
.allowlist_function("WalProposerBroadcast")
.allowlist_function("WalProposerPoll")
.allowlist_function("WalProposerFree")
.allowlist_var("DEBUG5")
.allowlist_var("DEBUG4")
.allowlist_var("DEBUG3")
.allowlist_var("DEBUG2")
.allowlist_var("DEBUG1")
.allowlist_var("LOG")
.allowlist_var("INFO")
.allowlist_var("NOTICE")
.allowlist_var("WARNING")
.allowlist_var("ERROR")
.allowlist_var("FATAL")
.allowlist_var("PANIC")
.allowlist_var("WPEVENT")
.allowlist_var("WL_LATCH_SET")
.allowlist_var("WL_SOCKET_READABLE")
.allowlist_var("WL_SOCKET_WRITEABLE")
.allowlist_var("WL_TIMEOUT")
.allowlist_var("WL_SOCKET_CLOSED")
.allowlist_var("WL_SOCKET_MASK")
.clang_arg("-DWALPROPOSER_LIB")
.clang_arg(format!("-I{pgxn_neon}"))
.clang_arg(format!("-I{inc_server_path}"))
// Finish the builder and generate the bindings.
.generate()
// Unwrap the Result and panic on failure.
.expect("Unable to generate bindings");
// Write the bindings to the $OUT_DIR/bindings.rs file.
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
bindings
.write_to_file(out_path)
.expect("Couldn't write bindings!");
Ok(())
}

View File

@@ -1,455 +0,0 @@
#![allow(dead_code)]
use std::ffi::CStr;
use std::ffi::CString;
use crate::bindings::uint32;
use crate::bindings::walproposer_api;
use crate::bindings::PGAsyncReadResult;
use crate::bindings::PGAsyncWriteResult;
use crate::bindings::Safekeeper;
use crate::bindings::Size;
use crate::bindings::StringInfoData;
use crate::bindings::TimeLineID;
use crate::bindings::TimestampTz;
use crate::bindings::WalProposer;
use crate::bindings::WalProposerConnStatusType;
use crate::bindings::WalProposerConnectPollStatusType;
use crate::bindings::WalProposerExecStatusType;
use crate::bindings::WalproposerShmemState;
use crate::bindings::XLogRecPtr;
use crate::walproposer::ApiImpl;
use crate::walproposer::WaitResult;
extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).get_shmem_state()
}
}
extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).start_streaming(startpos)
}
}
extern "C" fn get_flush_rec_ptr(wp: *mut WalProposer) -> XLogRecPtr {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).get_flush_rec_ptr()
}
}
extern "C" fn get_current_timestamp(wp: *mut WalProposer) -> TimestampTz {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).get_current_timestamp()
}
}
extern "C" fn conn_error_message(sk: *mut Safekeeper) -> *mut ::std::os::raw::c_char {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let msg = (*api).conn_error_message(&mut (*sk));
let msg = CString::new(msg).unwrap();
// TODO: fix leaking error message
msg.into_raw()
}
}
extern "C" fn conn_status(sk: *mut Safekeeper) -> WalProposerConnStatusType {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_status(&mut (*sk))
}
}
extern "C" fn conn_connect_start(sk: *mut Safekeeper) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_connect_start(&mut (*sk))
}
}
extern "C" fn conn_connect_poll(sk: *mut Safekeeper) -> WalProposerConnectPollStatusType {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_connect_poll(&mut (*sk))
}
}
extern "C" fn conn_send_query(sk: *mut Safekeeper, query: *mut ::std::os::raw::c_char) -> bool {
let query = unsafe { CStr::from_ptr(query) };
let query = query.to_str().unwrap();
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_send_query(&mut (*sk), query)
}
}
extern "C" fn conn_get_query_result(sk: *mut Safekeeper) -> WalProposerExecStatusType {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_get_query_result(&mut (*sk))
}
}
extern "C" fn conn_flush(sk: *mut Safekeeper) -> ::std::os::raw::c_int {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_flush(&mut (*sk))
}
}
extern "C" fn conn_finish(sk: *mut Safekeeper) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_finish(&mut (*sk))
}
}
extern "C" fn conn_async_read(
sk: *mut Safekeeper,
buf: *mut *mut ::std::os::raw::c_char,
amount: *mut ::std::os::raw::c_int,
) -> PGAsyncReadResult {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let (res, result) = (*api).conn_async_read(&mut (*sk));
// This function has guarantee that returned buf will be valid until
// the next call. So we can store a Vec in each Safekeeper and reuse
// it on the next call.
let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default();
inbuf.clear();
inbuf.extend_from_slice(res);
// Put a Vec back to sk->inbuf and return data ptr.
*buf = store_vec_u8(&mut (*sk).inbuf, inbuf);
*amount = res.len() as i32;
result
}
}
extern "C" fn conn_async_write(
sk: *mut Safekeeper,
buf: *const ::std::os::raw::c_void,
size: usize,
) -> PGAsyncWriteResult {
unsafe {
let buf = std::slice::from_raw_parts(buf as *const u8, size);
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_async_write(&mut (*sk), buf)
}
}
extern "C" fn conn_blocking_write(
sk: *mut Safekeeper,
buf: *const ::std::os::raw::c_void,
size: usize,
) -> bool {
unsafe {
let buf = std::slice::from_raw_parts(buf as *const u8, size);
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).conn_blocking_write(&mut (*sk), buf)
}
}
extern "C" fn recovery_download(
sk: *mut Safekeeper,
_timeline: TimeLineID,
startpos: XLogRecPtr,
endpos: XLogRecPtr,
) -> bool {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).recovery_download(&mut (*sk), startpos, endpos)
}
}
extern "C" fn wal_read(
sk: *mut Safekeeper,
buf: *mut ::std::os::raw::c_char,
startptr: XLogRecPtr,
count: Size,
) {
unsafe {
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, count);
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).wal_read(&mut (*sk), buf, startptr)
}
}
extern "C" fn wal_reader_allocate(sk: *mut Safekeeper) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).wal_reader_allocate(&mut (*sk));
}
}
extern "C" fn free_event_set(wp: *mut WalProposer) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).free_event_set(&mut (*wp));
}
}
extern "C" fn init_event_set(wp: *mut WalProposer) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).init_event_set(&mut (*wp));
}
}
extern "C" fn update_event_set(sk: *mut Safekeeper, events: uint32) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).update_event_set(&mut (*sk), events);
}
}
extern "C" fn add_safekeeper_event_set(sk: *mut Safekeeper, events: uint32) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).add_safekeeper_event_set(&mut (*sk), events);
}
}
extern "C" fn wait_event_set(
wp: *mut WalProposer,
timeout: ::std::os::raw::c_long,
event_sk: *mut *mut Safekeeper,
events: *mut uint32,
) -> ::std::os::raw::c_int {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let result = (*api).wait_event_set(&mut (*wp), timeout);
match result {
WaitResult::Latch => {
*event_sk = std::ptr::null_mut();
*events = crate::bindings::WL_LATCH_SET;
1
}
WaitResult::Timeout => {
*event_sk = std::ptr::null_mut();
*events = crate::bindings::WL_TIMEOUT;
0
}
WaitResult::Network(sk, event_mask) => {
*event_sk = sk;
*events = event_mask;
1
}
}
}
}
extern "C" fn strong_random(
wp: *mut WalProposer,
buf: *mut ::std::os::raw::c_void,
len: usize,
) -> bool {
unsafe {
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, len);
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).strong_random(buf)
}
}
extern "C" fn get_redo_start_lsn(wp: *mut WalProposer) -> XLogRecPtr {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).get_redo_start_lsn()
}
}
extern "C" fn finish_sync_safekeepers(wp: *mut WalProposer, lsn: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).finish_sync_safekeepers(lsn)
}
}
extern "C" fn process_safekeeper_feedback(wp: *mut WalProposer, commit_lsn: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).process_safekeeper_feedback(&mut (*wp), commit_lsn)
}
}
extern "C" fn confirm_wal_streamed(wp: *mut WalProposer, lsn: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).confirm_wal_streamed(&mut (*wp), lsn)
}
}
extern "C" fn log_internal(
wp: *mut WalProposer,
level: ::std::os::raw::c_int,
line: *const ::std::os::raw::c_char,
) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
let line = CStr::from_ptr(line);
let line = line.to_str().unwrap();
(*api).log_internal(&mut (*wp), Level::from(level as u32), line)
}
}
extern "C" fn after_election(wp: *mut WalProposer) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).after_election(&mut (*wp))
}
}
#[derive(Debug)]
pub enum Level {
Debug5,
Debug4,
Debug3,
Debug2,
Debug1,
Log,
Info,
Notice,
Warning,
Error,
Fatal,
Panic,
WPEvent,
}
impl Level {
pub fn from(elevel: u32) -> Level {
use crate::bindings::*;
match elevel {
DEBUG5 => Level::Debug5,
DEBUG4 => Level::Debug4,
DEBUG3 => Level::Debug3,
DEBUG2 => Level::Debug2,
DEBUG1 => Level::Debug1,
LOG => Level::Log,
INFO => Level::Info,
NOTICE => Level::Notice,
WARNING => Level::Warning,
ERROR => Level::Error,
FATAL => Level::Fatal,
PANIC => Level::Panic,
WPEVENT => Level::WPEvent,
_ => panic!("unknown log level {}", elevel),
}
}
}
pub(crate) fn create_api() -> walproposer_api {
walproposer_api {
get_shmem_state: Some(get_shmem_state),
start_streaming: Some(start_streaming),
get_flush_rec_ptr: Some(get_flush_rec_ptr),
get_current_timestamp: Some(get_current_timestamp),
conn_error_message: Some(conn_error_message),
conn_status: Some(conn_status),
conn_connect_start: Some(conn_connect_start),
conn_connect_poll: Some(conn_connect_poll),
conn_send_query: Some(conn_send_query),
conn_get_query_result: Some(conn_get_query_result),
conn_flush: Some(conn_flush),
conn_finish: Some(conn_finish),
conn_async_read: Some(conn_async_read),
conn_async_write: Some(conn_async_write),
conn_blocking_write: Some(conn_blocking_write),
recovery_download: Some(recovery_download),
wal_read: Some(wal_read),
wal_reader_allocate: Some(wal_reader_allocate),
free_event_set: Some(free_event_set),
init_event_set: Some(init_event_set),
update_event_set: Some(update_event_set),
add_safekeeper_event_set: Some(add_safekeeper_event_set),
wait_event_set: Some(wait_event_set),
strong_random: Some(strong_random),
get_redo_start_lsn: Some(get_redo_start_lsn),
finish_sync_safekeepers: Some(finish_sync_safekeepers),
process_safekeeper_feedback: Some(process_safekeeper_feedback),
confirm_wal_streamed: Some(confirm_wal_streamed),
log_internal: Some(log_internal),
after_election: Some(after_election),
}
}
impl std::fmt::Display for Level {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
/// Take ownership of `Vec<u8>` from StringInfoData.
pub(crate) fn take_vec_u8(pg: &mut StringInfoData) -> Option<Vec<u8>> {
if pg.data.is_null() {
return None;
}
let ptr = pg.data as *mut u8;
let length = pg.len as usize;
let capacity = pg.maxlen as usize;
pg.data = std::ptr::null_mut();
pg.len = 0;
pg.maxlen = 0;
unsafe { Some(Vec::from_raw_parts(ptr, length, capacity)) }
}
/// Store `Vec<u8>` in StringInfoData.
fn store_vec_u8(pg: &mut StringInfoData, vec: Vec<u8>) -> *mut ::std::os::raw::c_char {
let ptr = vec.as_ptr() as *mut ::std::os::raw::c_char;
let length = vec.len();
let capacity = vec.capacity();
assert!(pg.data.is_null());
pg.data = ptr;
pg.len = length as i32;
pg.maxlen = capacity as i32;
std::mem::forget(vec);
ptr
}

View File

@@ -1,14 +0,0 @@
pub mod bindings {
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
// bindgen creates some unsafe code with no doc comments.
#![allow(clippy::missing_safety_doc)]
// noted at 1.63 that in many cases there's a u32 -> u32 transmutes in bindgen code.
#![allow(clippy::useless_transmute)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
pub mod api_bindings;
pub mod walproposer;

View File

@@ -1,485 +0,0 @@
use std::ffi::CString;
use postgres_ffi::WAL_SEGMENT_SIZE;
use utils::id::TenantTimelineId;
use crate::{
api_bindings::{create_api, take_vec_u8, Level},
bindings::{
Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate, WalProposerFree,
WalProposerStart,
},
};
/// Rust high-level wrapper for C walproposer API. Many methods are not required
/// for simple cases, hence todo!() in default implementations.
///
/// Refer to `pgxn/neon/walproposer.h` for documentation.
pub trait ApiImpl {
fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState {
todo!()
}
fn start_streaming(&self, _startpos: u64) {
todo!()
}
fn get_flush_rec_ptr(&self) -> u64 {
todo!()
}
fn get_current_timestamp(&self) -> i64 {
todo!()
}
fn conn_error_message(&self, _sk: &mut Safekeeper) -> String {
todo!()
}
fn conn_status(&self, _sk: &mut Safekeeper) -> crate::bindings::WalProposerConnStatusType {
todo!()
}
fn conn_connect_start(&self, _sk: &mut Safekeeper) {
todo!()
}
fn conn_connect_poll(
&self,
_sk: &mut Safekeeper,
) -> crate::bindings::WalProposerConnectPollStatusType {
todo!()
}
fn conn_send_query(&self, _sk: &mut Safekeeper, _query: &str) -> bool {
todo!()
}
fn conn_get_query_result(
&self,
_sk: &mut Safekeeper,
) -> crate::bindings::WalProposerExecStatusType {
todo!()
}
fn conn_flush(&self, _sk: &mut Safekeeper) -> i32 {
todo!()
}
fn conn_finish(&self, _sk: &mut Safekeeper) {
todo!()
}
fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) {
todo!()
}
fn conn_async_write(
&self,
_sk: &mut Safekeeper,
_buf: &[u8],
) -> crate::bindings::PGAsyncWriteResult {
todo!()
}
fn conn_blocking_write(&self, _sk: &mut Safekeeper, _buf: &[u8]) -> bool {
todo!()
}
fn recovery_download(&self, _sk: &mut Safekeeper, _startpos: u64, _endpos: u64) -> bool {
todo!()
}
fn wal_read(&self, _sk: &mut Safekeeper, _buf: &mut [u8], _startpos: u64) {
todo!()
}
fn wal_reader_allocate(&self, _sk: &mut Safekeeper) {
todo!()
}
fn free_event_set(&self, _wp: &mut WalProposer) {
todo!()
}
fn init_event_set(&self, _wp: &mut WalProposer) {
todo!()
}
fn update_event_set(&self, _sk: &mut Safekeeper, _events_mask: u32) {
todo!()
}
fn add_safekeeper_event_set(&self, _sk: &mut Safekeeper, _events_mask: u32) {
todo!()
}
fn wait_event_set(&self, _wp: &mut WalProposer, _timeout_millis: i64) -> WaitResult {
todo!()
}
fn strong_random(&self, _buf: &mut [u8]) -> bool {
todo!()
}
fn get_redo_start_lsn(&self) -> u64 {
todo!()
}
fn finish_sync_safekeepers(&self, _lsn: u64) {
todo!()
}
fn process_safekeeper_feedback(&self, _wp: &mut WalProposer, _commit_lsn: u64) {
todo!()
}
fn confirm_wal_streamed(&self, _wp: &mut WalProposer, _lsn: u64) {
todo!()
}
fn log_internal(&self, _wp: &mut WalProposer, _level: Level, _msg: &str) {
todo!()
}
fn after_election(&self, _wp: &mut WalProposer) {
todo!()
}
}
pub enum WaitResult {
Latch,
Timeout,
Network(*mut Safekeeper, u32),
}
pub struct Config {
/// Tenant and timeline id
pub ttid: TenantTimelineId,
/// List of safekeepers in format `host:port`
pub safekeepers_list: Vec<String>,
/// Safekeeper reconnect timeout in milliseconds
pub safekeeper_reconnect_timeout: i32,
/// Safekeeper connection timeout in milliseconds
pub safekeeper_connection_timeout: i32,
/// walproposer mode, finish when all safekeepers are synced or subscribe
/// to WAL streaming
pub sync_safekeepers: bool,
}
/// WalProposer main struct. C methods are reexported as Rust functions.
pub struct Wrapper {
wp: *mut WalProposer,
_safekeepers_list_vec: Vec<u8>,
}
impl Wrapper {
pub fn new(api: Box<dyn ApiImpl>, config: Config) -> Wrapper {
let neon_tenant = CString::new(config.ttid.tenant_id.to_string())
.unwrap()
.into_raw();
let neon_timeline = CString::new(config.ttid.timeline_id.to_string())
.unwrap()
.into_raw();
let mut safekeepers_list_vec = CString::new(config.safekeepers_list.join(","))
.unwrap()
.into_bytes_with_nul();
assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity());
let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut i8;
let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void;
let c_config = WalProposerConfig {
neon_tenant,
neon_timeline,
safekeepers_list,
safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout,
safekeeper_connection_timeout: config.safekeeper_connection_timeout,
wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB
syncSafekeepers: config.sync_safekeepers,
systemId: 0,
pgTimeline: 1,
callback_data,
};
let c_config = Box::into_raw(Box::new(c_config));
let api = create_api();
let wp = unsafe { WalProposerCreate(c_config, api) };
Wrapper {
wp,
_safekeepers_list_vec: safekeepers_list_vec,
}
}
pub fn start(&self) {
unsafe { WalProposerStart(self.wp) }
}
}
impl Drop for Wrapper {
fn drop(&mut self) {
unsafe {
let config = (*self.wp).config;
drop(Box::from_raw(
(*config).callback_data as *mut Box<dyn ApiImpl>,
));
drop(CString::from_raw((*config).neon_tenant));
drop(CString::from_raw((*config).neon_timeline));
drop(Box::from_raw(config));
for i in 0..(*self.wp).n_safekeepers {
let sk = &mut (*self.wp).safekeeper[i as usize];
take_vec_u8(&mut sk.inbuf);
}
WalProposerFree(self.wp);
}
}
}
#[cfg(test)]
mod tests {
use std::{
cell::Cell,
sync::{atomic::AtomicUsize, mpsc::sync_channel},
};
use utils::id::TenantTimelineId;
use crate::{api_bindings::Level, walproposer::Wrapper};
use super::ApiImpl;
#[derive(Clone, Copy, Debug)]
struct WaitEventsData {
sk: *mut crate::bindings::Safekeeper,
event_mask: u32,
}
struct MockImpl {
// data to return from wait_event_set
wait_events: Cell<WaitEventsData>,
// walproposer->safekeeper messages
expected_messages: Vec<Vec<u8>>,
expected_ptr: AtomicUsize,
// safekeeper->walproposer messages
safekeeper_replies: Vec<Vec<u8>>,
replies_ptr: AtomicUsize,
// channel to send LSN to the main thread
sync_channel: std::sync::mpsc::SyncSender<u64>,
}
impl MockImpl {
fn check_walproposer_msg(&self, msg: &[u8]) {
let ptr = self
.expected_ptr
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if ptr >= self.expected_messages.len() {
panic!("unexpected message from walproposer");
}
let expected_msg = &self.expected_messages[ptr];
assert_eq!(msg, expected_msg.as_slice());
}
fn next_safekeeper_reply(&self) -> &[u8] {
let ptr = self
.replies_ptr
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if ptr >= self.safekeeper_replies.len() {
panic!("no more safekeeper replies");
}
&self.safekeeper_replies[ptr]
}
}
impl ApiImpl for MockImpl {
fn get_current_timestamp(&self) -> i64 {
println!("get_current_timestamp");
0
}
fn conn_status(
&self,
_: &mut crate::bindings::Safekeeper,
) -> crate::bindings::WalProposerConnStatusType {
println!("conn_status");
crate::bindings::WalProposerConnStatusType_WP_CONNECTION_OK
}
fn conn_connect_start(&self, _: &mut crate::bindings::Safekeeper) {
println!("conn_connect_start");
}
fn conn_connect_poll(
&self,
_: &mut crate::bindings::Safekeeper,
) -> crate::bindings::WalProposerConnectPollStatusType {
println!("conn_connect_poll");
crate::bindings::WalProposerConnectPollStatusType_WP_CONN_POLLING_OK
}
fn conn_send_query(&self, _: &mut crate::bindings::Safekeeper, query: &str) -> bool {
println!("conn_send_query: {}", query);
true
}
fn conn_get_query_result(
&self,
_: &mut crate::bindings::Safekeeper,
) -> crate::bindings::WalProposerExecStatusType {
println!("conn_get_query_result");
crate::bindings::WalProposerExecStatusType_WP_EXEC_SUCCESS_COPYBOTH
}
fn conn_async_read(
&self,
_: &mut crate::bindings::Safekeeper,
) -> (&[u8], crate::bindings::PGAsyncReadResult) {
println!("conn_async_read");
let reply = self.next_safekeeper_reply();
println!("conn_async_read result: {:?}", reply);
(
reply,
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS,
)
}
fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool {
println!("conn_blocking_write: {:?}", buf);
self.check_walproposer_msg(buf);
true
}
fn wal_reader_allocate(&self, _: &mut crate::bindings::Safekeeper) {
println!("wal_reader_allocate")
}
fn free_event_set(&self, _: &mut crate::bindings::WalProposer) {
println!("free_event_set")
}
fn init_event_set(&self, _: &mut crate::bindings::WalProposer) {
println!("init_event_set")
}
fn update_event_set(&self, sk: &mut crate::bindings::Safekeeper, event_mask: u32) {
println!(
"update_event_set, sk={:?}, events_mask={:#b}",
sk as *mut crate::bindings::Safekeeper, event_mask
);
self.wait_events.set(WaitEventsData { sk, event_mask });
}
fn add_safekeeper_event_set(&self, sk: &mut crate::bindings::Safekeeper, event_mask: u32) {
println!(
"add_safekeeper_event_set, sk={:?}, events_mask={:#b}",
sk as *mut crate::bindings::Safekeeper, event_mask
);
self.wait_events.set(WaitEventsData { sk, event_mask });
}
fn wait_event_set(
&self,
_: &mut crate::bindings::WalProposer,
timeout_millis: i64,
) -> super::WaitResult {
let data = self.wait_events.get();
println!(
"wait_event_set, timeout_millis={}, res={:?}",
timeout_millis, data
);
super::WaitResult::Network(data.sk, data.event_mask)
}
fn strong_random(&self, buf: &mut [u8]) -> bool {
println!("strong_random");
buf.fill(0);
true
}
fn finish_sync_safekeepers(&self, lsn: u64) {
self.sync_channel.send(lsn).unwrap();
panic!("sync safekeepers finished at lsn={}", lsn);
}
fn log_internal(&self, _wp: &mut crate::bindings::WalProposer, level: Level, msg: &str) {
println!("walprop_log[{}] {}", level, msg);
}
fn after_election(&self, _wp: &mut crate::bindings::WalProposer) {
println!("after_election");
}
}
/// Test that walproposer can successfully connect to safekeeper and finish
/// sync_safekeepers. API is mocked in MockImpl.
///
/// Run this test with valgrind to detect leaks:
/// `valgrind --leak-check=full target/debug/deps/walproposer-<build>`
#[test]
fn test_simple_sync_safekeepers() -> anyhow::Result<()> {
let ttid = TenantTimelineId::new(
"9e4c8f36063c6c6e93bc20d65a820f3d".parse()?,
"9e4c8f36063c6c6e93bc20d65a820f3d".parse()?,
);
let (sender, receiver) = sync_channel(1);
let my_impl: Box<dyn ApiImpl> = Box::new(MockImpl {
wait_events: Cell::new(WaitEventsData {
sk: std::ptr::null_mut(),
event_mask: 0,
}),
expected_messages: vec![
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160000, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110,
147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147,
188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1,
],
// VoteRequest(VoteRequest { term: 3 })
vec![
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0,
],
],
expected_ptr: AtomicUsize::new(0),
safekeeper_replies: vec![
// Greeting(AcceptorGreeting { term: 2, node_id: NodeId(1) })
vec![
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
],
// VoteResponse(VoteResponse { term: 3, vote_given: 1, flush_lsn: 0/539, truncate_lsn: 0/539, term_history: [(2, 0/539)], timeline_start_lsn: 0/539 })
vec![
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 57,
5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
0, 57, 5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0,
],
],
replies_ptr: AtomicUsize::new(0),
sync_channel: sender,
});
let config = crate::walproposer::Config {
ttid,
safekeepers_list: vec!["localhost:5000".to_string()],
safekeeper_reconnect_timeout: 1000,
safekeeper_connection_timeout: 10000,
sync_safekeepers: true,
};
let wp = Wrapper::new(my_impl, config);
// walproposer will panic when it finishes sync_safekeepers
std::panic::catch_unwind(|| wp.start()).unwrap_err();
// validate the resulting LSN
assert_eq!(receiver.recv()?, 1337);
Ok(())
// drop() will free up resources here
}
}

View File

@@ -11,7 +11,10 @@ use std::sync::{Arc, Barrier};
use bytes::{Buf, Bytes};
use pageserver::{
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
config::PageServerConf,
repository::Key,
walrecord::NeonWalRecord,
walredo::{PostgresRedoManager, WalRedoError},
};
use utils::{id::TenantId, lsn::Lsn};
@@ -32,15 +35,9 @@ fn redo_scenarios(c: &mut Criterion) {
let manager = Arc::new(manager);
{
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
tracing::info!("executing first");
short().execute(rt.handle(), &manager).unwrap();
tracing::info!("first executed");
}
tracing::info!("executing first");
short().execute(&manager).unwrap();
tracing::info!("first executed");
let thread_counts = [1, 2, 4, 8, 16];
@@ -83,14 +80,9 @@ fn add_multithreaded_walredo_requesters(
assert_ne!(threads, 0);
if threads == 1 {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = rt.handle();
b.iter_batched_ref(
|| Some(input_factory()),
|input| execute_all(input.take(), handle, manager),
|input| execute_all(input.take(), manager),
criterion::BatchSize::PerIteration,
);
} else {
@@ -106,26 +98,19 @@ fn add_multithreaded_walredo_requesters(
let manager = manager.clone();
let barrier = barrier.clone();
let work_rx = work_rx.clone();
move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let handle = rt.handle();
loop {
// queue up and wait if we want to go another round
if work_rx.lock().unwrap().recv().is_err() {
break;
}
let input = Some(input_factory());
barrier.wait();
execute_all(input, handle, &manager).unwrap();
barrier.wait();
move || loop {
// queue up and wait if we want to go another round
if work_rx.lock().unwrap().recv().is_err() {
break;
}
let input = Some(input_factory());
barrier.wait();
execute_all(input, &manager).unwrap();
barrier.wait();
}
})
})
@@ -167,19 +152,15 @@ impl Drop for JoinOnDrop {
}
}
fn execute_all<I>(
input: I,
handle: &tokio::runtime::Handle,
manager: &PostgresRedoManager,
) -> anyhow::Result<()>
fn execute_all<I>(input: I, manager: &PostgresRedoManager) -> Result<(), WalRedoError>
where
I: IntoIterator<Item = Request>,
{
// just fire all requests as fast as possible
input.into_iter().try_for_each(|req| {
let page = req.execute(handle, manager)?;
let page = req.execute(manager)?;
assert_eq!(page.remaining(), 8192);
anyhow::Ok(())
Ok::<_, WalRedoError>(())
})
}
@@ -492,11 +473,9 @@ struct Request {
}
impl Request {
fn execute(
self,
rt: &tokio::runtime::Handle,
manager: &PostgresRedoManager,
) -> anyhow::Result<Bytes> {
fn execute(self, manager: &PostgresRedoManager) -> Result<Bytes, WalRedoError> {
use pageserver::walredo::WalRedoManager;
let Request {
key,
lsn,
@@ -505,6 +484,6 @@ impl Request {
pg_version,
} = self;
rt.block_on(manager.request_redo(key, lsn, base_img, records, pg_version))
manager.request_redo(key, lsn, base_img, records, pg_version)
}
}

View File

@@ -13,7 +13,6 @@
use anyhow::{anyhow, bail, ensure, Context};
use bytes::{BufMut, BytesMut};
use fail::fail_point;
use postgres_ffi::pg_constants;
use std::fmt::Write as FmtWrite;
use std::time::SystemTime;
use tokio::io;
@@ -181,7 +180,6 @@ where
}
}
let mut min_restart_lsn: Lsn = Lsn::MAX;
// Create tablespace directories
for ((spcnode, dbnode), has_relmap_file) in
self.timeline.list_dbdirs(self.lsn, self.ctx).await?
@@ -215,34 +213,6 @@ where
self.add_rel(rel, rel).await?;
}
}
for (path, content) in self.timeline.list_aux_files(self.lsn, self.ctx).await? {
if path.starts_with("pg_replslot") {
let offs = pg_constants::REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN;
let restart_lsn = Lsn(u64::from_le_bytes(
content[offs..offs + 8].try_into().unwrap(),
));
info!("Replication slot {} restart LSN={}", path, restart_lsn);
min_restart_lsn = Lsn::min(min_restart_lsn, restart_lsn);
}
let header = new_tar_header(&path, content.len() as u64)?;
self.ar
.append(&header, &*content)
.await
.context("could not add aux file to basebackup tarball")?;
}
}
if min_restart_lsn != Lsn::MAX {
info!(
"Min restart LSN for logical replication is {}",
min_restart_lsn
);
let data = min_restart_lsn.0.to_le_bytes();
let header = new_tar_header("restart.lsn", data.len() as u64)?;
self.ar
.append(&header, &data[..])
.await
.context("could not add restart.lsn file to basebackup tarball")?;
}
for xid in self
.timeline

View File

@@ -2,7 +2,6 @@
use std::env::{var, VarError};
use std::sync::Arc;
use std::time::Duration;
use std::{env, ops::ControlFlow, str::FromStr};
use anyhow::{anyhow, Context};
@@ -201,51 +200,6 @@ fn initialize_config(
})
}
struct WaitForPhaseResult<F: std::future::Future + Unpin> {
timeout_remaining: Duration,
skipped: Option<F>,
}
/// During startup, we apply a timeout to our waits for readiness, to avoid
/// stalling the whole service if one Tenant experiences some problem. Each
/// phase may consume some of the timeout: this function returns the updated
/// timeout for use in the next call.
async fn wait_for_phase<F>(phase: &str, mut fut: F, timeout: Duration) -> WaitForPhaseResult<F>
where
F: std::future::Future + Unpin,
{
let initial_t = Instant::now();
let skipped = match tokio::time::timeout(timeout, &mut fut).await {
Ok(_) => None,
Err(_) => {
tracing::info!(
timeout_millis = timeout.as_millis(),
%phase,
"Startup phase timed out, proceeding anyway"
);
Some(fut)
}
};
WaitForPhaseResult {
timeout_remaining: timeout
.checked_sub(Instant::now().duration_since(initial_t))
.unwrap_or(Duration::ZERO),
skipped,
}
}
fn startup_checkpoint(started_at: Instant, phase: &str, human_phase: &str) {
let elapsed = started_at.elapsed();
let secs = elapsed.as_secs_f64();
STARTUP_DURATION.with_label_values(&[phase]).set(secs);
info!(
elapsed_ms = elapsed.as_millis(),
"{human_phase} ({secs:.3}s since start)"
)
}
fn start_pageserver(
launch_ts: &'static LaunchTimestamp,
conf: &'static PageServerConf,
@@ -253,6 +207,16 @@ fn start_pageserver(
// Monotonic time for later calculating startup duration
let started_startup_at = Instant::now();
let startup_checkpoint = move |phase: &str, human_phase: &str| {
let elapsed = started_startup_at.elapsed();
let secs = elapsed.as_secs_f64();
STARTUP_DURATION.with_label_values(&[phase]).set(secs);
info!(
elapsed_ms = elapsed.as_millis(),
"{human_phase} ({secs:.3}s since start)"
)
};
// Print version and launch timestamp to the log,
// and expose them as prometheus metrics.
// A changed version string indicates changed software.
@@ -377,7 +341,7 @@ fn start_pageserver(
// Up to this point no significant I/O has been done: this should have been fast. Record
// duration prior to starting I/O intensive phase of startup.
startup_checkpoint(started_startup_at, "initial", "Starting loading tenants");
startup_checkpoint("initial", "Starting loading tenants");
STARTUP_IS_LOADING.set(1);
// Startup staging or optimizing:
@@ -391,7 +355,6 @@ fn start_pageserver(
// consumer side) will be dropped once we can start the background jobs. Currently it is behind
// completing all initial logical size calculations (init_logical_size_done_rx) and a timeout
// (background_task_maximum_delay).
let (init_remote_done_tx, init_remote_done_rx) = utils::completion::channel();
let (init_done_tx, init_done_rx) = utils::completion::channel();
let (init_logical_size_done_tx, init_logical_size_done_rx) = utils::completion::channel();
@@ -399,8 +362,7 @@ fn start_pageserver(
let (background_jobs_can_start, background_jobs_barrier) = utils::completion::channel();
let order = pageserver::InitializationOrder {
initial_tenant_load_remote: Some(init_done_tx),
initial_tenant_load: Some(init_remote_done_tx),
initial_tenant_load: Some(init_done_tx),
initial_logical_size_can_start: init_done_rx.clone(),
initial_logical_size_attempt: Some(init_logical_size_done_tx),
background_jobs_can_start: background_jobs_barrier.clone(),
@@ -424,93 +386,55 @@ fn start_pageserver(
let shutdown_pageserver = shutdown_pageserver.clone();
let drive_init = async move {
// NOTE: unlike many futures in pageserver, this one is cancellation-safe
let guard = scopeguard::guard_on_success((), |_| {
tracing::info!("Cancelled before initial load completed")
});
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial load completed"));
let timeout = conf.background_task_maximum_delay;
let init_remote_done = std::pin::pin!(async {
init_remote_done_rx.wait().await;
startup_checkpoint(
started_startup_at,
"initial_tenant_load_remote",
"Remote part of initial load completed",
);
});
let WaitForPhaseResult {
timeout_remaining: timeout,
skipped: init_remote_skipped,
} = wait_for_phase("initial_tenant_load_remote", init_remote_done, timeout).await;
let init_load_done = std::pin::pin!(async {
init_done_rx.wait().await;
startup_checkpoint(
started_startup_at,
"initial_tenant_load",
"Initial load completed",
);
STARTUP_IS_LOADING.set(0);
});
let WaitForPhaseResult {
timeout_remaining: timeout,
skipped: init_load_skipped,
} = wait_for_phase("initial_tenant_load", init_load_done, timeout).await;
init_done_rx.wait().await;
startup_checkpoint("initial_tenant_load", "Initial load completed");
STARTUP_IS_LOADING.set(0);
// initial logical sizes can now start, as they were waiting on init_done_rx.
scopeguard::ScopeGuard::into_inner(guard);
let guard = scopeguard::guard_on_success((), |_| {
tracing::info!("Cancelled before initial logical sizes completed")
});
let mut init_sizes_done = std::pin::pin!(init_logical_size_done_rx.wait());
let logical_sizes_done = std::pin::pin!(async {
init_logical_size_done_rx.wait().await;
startup_checkpoint(
started_startup_at,
"initial_logical_sizes",
"Initial logical sizes completed",
);
});
let timeout = conf.background_task_maximum_delay;
let WaitForPhaseResult {
timeout_remaining: _,
skipped: logical_sizes_skipped,
} = wait_for_phase("initial_logical_sizes", logical_sizes_done, timeout).await;
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial logical sizes completed"));
let init_sizes_done = match tokio::time::timeout(timeout, &mut init_sizes_done).await {
Ok(_) => {
startup_checkpoint("initial_logical_sizes", "Initial logical sizes completed");
None
}
Err(_) => {
tracing::info!(
timeout_millis = timeout.as_millis(),
"Initial logical size timeout elapsed; starting background jobs"
);
Some(init_sizes_done)
}
};
scopeguard::ScopeGuard::into_inner(guard);
// allow background jobs to start: we either completed prior stages, or they reached timeout
// and were skipped. It is important that we do not let them block background jobs indefinitely,
// because things like consumption metrics for billing are blocked by this barrier.
// allow background jobs to start
drop(background_jobs_can_start);
startup_checkpoint(
started_startup_at,
"background_jobs_can_start",
"Starting background jobs",
);
startup_checkpoint("background_jobs_can_start", "Starting background jobs");
// We are done. If we skipped any phases due to timeout, run them to completion here so that
// they will eventually update their startup_checkpoint, and so that we do not declare the
// 'complete' stage until all the other stages are really done.
let guard = scopeguard::guard_on_success((), |_| {
tracing::info!("Cancelled before waiting for skipped phases done")
});
if let Some(f) = init_remote_skipped {
f.await;
}
if let Some(f) = init_load_skipped {
f.await;
}
if let Some(f) = logical_sizes_skipped {
f.await;
}
scopeguard::ScopeGuard::into_inner(guard);
if let Some(init_sizes_done) = init_sizes_done {
// ending up here is not a bug; at the latest logical sizes will be queried by
// consumption metrics.
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial logical sizes completed"));
init_sizes_done.await;
startup_checkpoint(started_startup_at, "complete", "Startup complete");
scopeguard::ScopeGuard::into_inner(guard);
startup_checkpoint("initial_logical_sizes", "Initial logical sizes completed after timeout (background jobs already started)");
}
startup_checkpoint("complete", "Startup complete");
};
async move {
@@ -650,7 +574,6 @@ fn start_pageserver(
pageserver_listener,
conf.pg_auth_type,
libpq_ctx,
task_mgr::shutdown_token(),
)
.await
},

View File

@@ -1479,8 +1479,6 @@ threshold = "20m"
Some(DiskUsageEvictionTaskConfig {
max_usage_pct: Percent::new(80).unwrap(),
min_avail_bytes: 0,
target_avail_bytes: None,
target_usage_pct: None,
period: Duration::from_secs(10),
#[cfg(feature = "testing")]
mock_statvfs: None,

View File

@@ -2,7 +2,6 @@
//! and push them to a HTTP endpoint.
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::{mgr, LogicalSizeCalculationCause};
use camino::Utf8PathBuf;
use consumption_metrics::EventType;
@@ -11,7 +10,6 @@ use reqwest::Url;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::time::Instant;
use tracing::*;
use utils::id::NodeId;
@@ -89,12 +87,22 @@ pub async fn collect_metrics(
let node_id = node_id.to_string();
// reminder: ticker is ready immediatedly
let mut ticker = tokio::time::interval(metric_collection_interval);
loop {
let started_at = Instant::now();
let tick_at = tokio::select! {
_ = cancel.cancelled() => return Ok(()),
tick_at = ticker.tick() => tick_at,
};
// these are point in time, with variable "now"
let metrics = metrics::collect_all_metrics(&cached_metrics, &ctx).await;
if metrics.is_empty() {
continue;
}
let metrics = Arc::new(metrics);
// why not race cancellation here? because we are one of the last tasks, and if we are
@@ -133,19 +141,10 @@ pub async fn collect_metrics(
let (_, _) = tokio::join!(flush, upload);
crate::tenant::tasks::warn_when_period_overrun(
started_at.elapsed(),
tick_at.elapsed(),
metric_collection_interval,
BackgroundLoopKind::ConsumptionMetricsCollectMetrics,
"consumption_metrics_collect_metrics",
);
let res = tokio::time::timeout_at(
started_at + metric_collection_interval,
task_mgr::shutdown_token().cancelled(),
)
.await;
if res.is_ok() {
return Ok(());
}
}
}
@@ -244,14 +243,16 @@ async fn calculate_synthetic_size_worker(
ctx: &RequestContext,
) -> anyhow::Result<()> {
info!("starting calculate_synthetic_size_worker");
scopeguard::defer! {
info!("calculate_synthetic_size_worker stopped");
};
// reminder: ticker is ready immediatedly
let mut ticker = tokio::time::interval(synthetic_size_calculation_interval);
let cause = LogicalSizeCalculationCause::ConsumptionMetricsSyntheticSize;
loop {
let started_at = Instant::now();
let tick_at = tokio::select! {
_ = task_mgr::shutdown_watcher() => return Ok(()),
tick_at = ticker.tick() => tick_at,
};
let tenants = match mgr::list_tenants().await {
Ok(tenants) => tenants,
@@ -267,11 +268,6 @@ async fn calculate_synthetic_size_worker(
}
if let Ok(tenant) = mgr::get_tenant(tenant_id, true).await {
// TODO should we use concurrent_background_tasks_rate_limit() here, like the other background tasks?
// We can put in some prioritization for consumption metrics.
// Same for the loop that fetches computed metrics.
// By using the same limiter, we centralize metrics collection for "start" and "finished" counters,
// which turns out is really handy to understand the system.
if let Err(e) = tenant.calculate_synthetic_size(cause, ctx).await {
error!("failed to calculate synthetic size for tenant {tenant_id}: {e:#}");
}
@@ -279,18 +275,9 @@ async fn calculate_synthetic_size_worker(
}
crate::tenant::tasks::warn_when_period_overrun(
started_at.elapsed(),
tick_at.elapsed(),
synthetic_size_calculation_interval,
BackgroundLoopKind::ConsumptionMetricsSyntheticSizeWorker,
"consumption_metrics_synthetic_size_worker",
);
let res = tokio::time::timeout_at(
started_at + synthetic_size_calculation_interval,
task_mgr::shutdown_token().cancelled(),
)
.await;
if res.is_ok() {
return Ok(());
}
}
}

View File

@@ -153,7 +153,7 @@ impl FlushOp {
#[derive(Clone, Debug)]
pub struct DeletionQueueClient {
tx: tokio::sync::mpsc::UnboundedSender<ListWriterQueueMessage>,
tx: tokio::sync::mpsc::Sender<ListWriterQueueMessage>,
executor_tx: tokio::sync::mpsc::Sender<DeleterMessage>,
lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
@@ -212,7 +212,7 @@ where
/// Files ending with this suffix will be ignored and erased
/// during recovery as startup.
const TEMP_SUFFIX: &str = "tmp";
const TEMP_SUFFIX: &str = ".tmp";
#[serde_as]
#[derive(Debug, Serialize, Deserialize)]
@@ -416,7 +416,7 @@ pub enum DeletionQueueError {
impl DeletionQueueClient {
pub(crate) fn broken() -> Self {
// Channels whose receivers are immediately dropped.
let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
let (tx, _rx) = tokio::sync::mpsc::channel(1);
let (executor_tx, _executor_rx) = tokio::sync::mpsc::channel(1);
Self {
tx,
@@ -428,12 +428,12 @@ impl DeletionQueueClient {
/// This is cancel-safe. If you drop the future before it completes, the message
/// is not pushed, although in the context of the deletion queue it doesn't matter: once
/// we decide to do a deletion the decision is always final.
fn do_push<T>(
async fn do_push<T>(
&self,
queue: &tokio::sync::mpsc::UnboundedSender<T>,
queue: &tokio::sync::mpsc::Sender<T>,
msg: T,
) -> Result<(), DeletionQueueError> {
match queue.send(msg) {
match queue.send(msg).await {
Ok(_) => Ok(()),
Err(e) => {
// This shouldn't happen, we should shut down all tenants before
@@ -445,7 +445,7 @@ impl DeletionQueueClient {
}
}
pub(crate) fn recover(
pub(crate) async fn recover(
&self,
attached_tenants: HashMap<TenantId, Generation>,
) -> Result<(), DeletionQueueError> {
@@ -453,6 +453,7 @@ impl DeletionQueueClient {
&self.tx,
ListWriterQueueMessage::Recover(RecoverOp { attached_tenants }),
)
.await
}
/// When a Timeline wishes to update the remote_consistent_lsn that it exposes to the outside
@@ -525,21 +526,6 @@ impl DeletionQueueClient {
return self.flush_immediate().await;
}
self.push_layers_sync(tenant_id, timeline_id, current_generation, layers)
}
/// When a Tenant has a generation, push_layers is always synchronous because
/// the ListValidator channel is an unbounded channel.
///
/// This can be merged into push_layers when we remove the Generation-less mode
/// support (`<https://github.com/neondatabase/neon/issues/5395>`)
pub(crate) fn push_layers_sync(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
current_generation: Generation,
layers: Vec<(LayerFileName, Generation)>,
) -> Result<(), DeletionQueueError> {
metrics::DELETION_QUEUE
.keys_submitted
.inc_by(layers.len() as u64);
@@ -553,16 +539,17 @@ impl DeletionQueueClient {
objects: Vec::new(),
}),
)
.await
}
/// This is cancel-safe. If you drop the future the flush may still happen in the background.
async fn do_flush<T>(
&self,
queue: &tokio::sync::mpsc::UnboundedSender<T>,
queue: &tokio::sync::mpsc::Sender<T>,
msg: T,
rx: tokio::sync::oneshot::Receiver<()>,
) -> Result<(), DeletionQueueError> {
self.do_push(queue, msg)?;
self.do_push(queue, msg).await?;
if rx.await.is_err() {
// This shouldn't happen if tenants are shut down before deletion queue. If we
// encounter a bug like this, then a flusher will incorrectly believe it has flushed
@@ -583,18 +570,6 @@ impl DeletionQueueClient {
.await
}
/// Issue a flush without waiting for it to complete. This is useful on advisory flushes where
/// the caller wants to avoid the risk of waiting for lots of enqueued work, such as on tenant
/// detach where flushing is nice but not necessary.
///
/// This function provides no guarantees of work being done.
pub fn flush_advisory(&self) {
let (flush_op, _) = FlushOp::new();
// Transmit the flush message, ignoring any result (such as a closed channel during shutdown).
drop(self.tx.send(ListWriterQueueMessage::FlushExecute(flush_op)));
}
// Wait until all previous deletions are executed
pub(crate) async fn flush_execute(&self) -> Result<(), DeletionQueueError> {
debug!("flush_execute: flushing to deletion lists...");
@@ -611,7 +586,9 @@ impl DeletionQueueClient {
// Flush any immediate-mode deletions (the above backend flush will only flush
// the executor if deletions had flowed through the backend)
debug!("flush_execute: flushing execution...");
self.flush_immediate().await?;
let (flush_op, rx) = FlushOp::new();
self.do_flush(&self.executor_tx, DeleterMessage::Flush(flush_op), rx)
.await?;
debug!("flush_execute: finished flushing execution...");
Ok(())
}
@@ -666,10 +643,8 @@ impl DeletionQueue {
where
C: ControlPlaneGenerationsApi + Send + Sync,
{
// Unbounded channel: enables non-async functions to submit deletions. The actual length is
// constrained by how promptly the ListWriter wakes up and drains it, which should be frequent
// enough to avoid this taking pathologically large amount of memory.
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
// Deep channel: it consumes deletions from all timelines and we do not want to block them
let (tx, rx) = tokio::sync::mpsc::channel(16384);
// Shallow channel: it carries DeletionLists which each contain up to thousands of deletions
let (backend_tx, backend_rx) = tokio::sync::mpsc::channel(16);
@@ -982,7 +957,7 @@ mod test {
// Basic test that the deletion queue processes the deletions we pass into it
let ctx = setup("deletion_queue_smoke").expect("Failed test setup");
let client = ctx.deletion_queue.new_client();
client.recover(HashMap::new())?;
client.recover(HashMap::new()).await?;
let layer_file_name_1: LayerFileName = "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap();
let tenant_id = ctx.harness.tenant_id;
@@ -1050,7 +1025,7 @@ mod test {
async fn deletion_queue_validation() -> anyhow::Result<()> {
let ctx = setup("deletion_queue_validation").expect("Failed test setup");
let client = ctx.deletion_queue.new_client();
client.recover(HashMap::new())?;
client.recover(HashMap::new()).await?;
// Generation that the control plane thinks is current
let latest_generation = Generation::new(0xdeadbeef);
@@ -1107,7 +1082,7 @@ mod test {
// Basic test that the deletion queue processes the deletions we pass into it
let mut ctx = setup("deletion_queue_recovery").expect("Failed test setup");
let client = ctx.deletion_queue.new_client();
client.recover(HashMap::new())?;
client.recover(HashMap::new()).await?;
let tenant_id = ctx.harness.tenant_id;
@@ -1170,7 +1145,9 @@ mod test {
drop(client);
ctx.restart().await;
let client = ctx.deletion_queue.new_client();
client.recover(HashMap::from([(tenant_id, now_generation)]))?;
client
.recover(HashMap::from([(tenant_id, now_generation)]))
.await?;
info!("Flush-executing");
client.flush_execute().await?;
@@ -1196,7 +1173,7 @@ pub(crate) mod mock {
};
pub struct ConsumerState {
rx: tokio::sync::mpsc::UnboundedReceiver<ListWriterQueueMessage>,
rx: tokio::sync::mpsc::Receiver<ListWriterQueueMessage>,
executor_rx: tokio::sync::mpsc::Receiver<DeleterMessage>,
}
@@ -1273,7 +1250,7 @@ pub(crate) mod mock {
}
pub struct MockDeletionQueue {
tx: tokio::sync::mpsc::UnboundedSender<ListWriterQueueMessage>,
tx: tokio::sync::mpsc::Sender<ListWriterQueueMessage>,
executor_tx: tokio::sync::mpsc::Sender<DeleterMessage>,
executed: Arc<AtomicUsize>,
remote_storage: Option<GenericRemoteStorage>,
@@ -1283,7 +1260,7 @@ pub(crate) mod mock {
impl MockDeletionQueue {
pub fn new(remote_storage: Option<GenericRemoteStorage>) -> Self {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let (tx, rx) = tokio::sync::mpsc::channel(16384);
let (executor_tx, executor_rx) = tokio::sync::mpsc::channel(16384);
let executed = Arc::new(AtomicUsize::new(0));
@@ -1298,6 +1275,10 @@ pub(crate) mod mock {
}
}
pub fn get_executed(&self) -> usize {
self.executed.load(Ordering::Relaxed)
}
#[allow(clippy::await_holding_lock)]
pub async fn pump(&self) {
if let Some(remote_storage) = &self.remote_storage {

View File

@@ -13,7 +13,6 @@ use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
use utils::backoff;
use crate::metrics;
@@ -64,19 +63,7 @@ impl Deleter {
Err(anyhow::anyhow!("failpoint hit"))
});
// A backoff::retry is used here for two reasons:
// - To provide a backoff rather than busy-polling the API on errors
// - To absorb transient 429/503 conditions without hitting our error
// logging path for issues deleting objects.
backoff::retry(
|| async { self.remote_storage.delete_objects(&self.accumulator).await },
|_| false,
3,
10,
"executing deletion batch",
backoff::Cancel::new(self.cancel.clone(), || anyhow::anyhow!("Shutting down")),
)
.await
self.remote_storage.delete_objects(&self.accumulator).await
}
/// Block until everything in accumulator has been executed
@@ -101,10 +88,7 @@ impl Deleter {
self.accumulator.clear();
}
Err(e) => {
if self.cancel.is_cancelled() {
return Err(DeletionQueueError::ShuttingDown);
}
warn!("DeleteObjects request failed: {e:#}, will continue trying");
warn!("DeleteObjects request failed: {e:#}, will retry");
metrics::DELETION_QUEUE
.remote_errors
.with_label_values(&["execute"])

View File

@@ -85,7 +85,7 @@ pub(super) struct ListWriter {
conf: &'static PageServerConf,
// Incoming frontend requests to delete some keys
rx: tokio::sync::mpsc::UnboundedReceiver<ListWriterQueueMessage>,
rx: tokio::sync::mpsc::Receiver<ListWriterQueueMessage>,
// Outbound requests to the backend to execute deletion lists we have composed.
tx: tokio::sync::mpsc::Sender<ValidatorQueueMessage>,
@@ -111,7 +111,7 @@ impl ListWriter {
pub(super) fn new(
conf: &'static PageServerConf,
rx: tokio::sync::mpsc::UnboundedReceiver<ListWriterQueueMessage>,
rx: tokio::sync::mpsc::Receiver<ListWriterQueueMessage>,
tx: tokio::sync::mpsc::Sender<ValidatorQueueMessage>,
cancel: CancellationToken,
) -> Self {
@@ -230,7 +230,6 @@ impl ListWriter {
let list_name_pattern =
Regex::new("(?<sequence>[a-zA-Z0-9]{16})-(?<version>[a-zA-Z0-9]{2}).list").unwrap();
let temp_extension = format!(".{TEMP_SUFFIX}");
let header_path = self.conf.deletion_header_path();
let mut seqs: Vec<u64> = Vec::new();
while let Some(dentry) = dir.next_entry().await? {
@@ -242,7 +241,7 @@ impl ListWriter {
continue;
}
if dentry_str.ends_with(&temp_extension) {
if dentry_str.ends_with(TEMP_SUFFIX) {
info!("Cleaning up temporary file {dentry_str}");
let absolute_path =
deletion_directory.join(dentry.file_name().to_str().expect("non-Unicode path"));

View File

@@ -67,40 +67,16 @@ use crate::{
pub struct DiskUsageEvictionTaskConfig {
pub max_usage_pct: Percent,
pub min_avail_bytes: u64,
// Control how far we will go when evicting: when usage exceeds max_usage_pct or min_avail_bytes,
// we will keep evicting layers until we reach the target. The resulting disk usage should look
// like a sawtooth bouncing between the upper max/min line and the lower target line.
#[serde(default)]
pub target_usage_pct: Option<Percent>,
#[serde(default)]
pub target_avail_bytes: Option<u64>,
#[serde(with = "humantime_serde")]
pub period: Duration,
#[cfg(feature = "testing")]
pub mock_statvfs: Option<crate::statvfs::mock::Behavior>,
}
#[derive(Default)]
enum Status {
/// We are within disk limits, and not currently doing any eviction
#[default]
Idle,
/// Disk limits have been exceeded: we will evict soon
UnderPressure,
/// We are currently doing an eviction pass.
Evicting,
}
#[derive(Default)]
pub struct State {
/// Exclude http requests and background task from running at the same time.
mutex: tokio::sync::Mutex<()>,
/// Publish the current status of eviction work, for visibility to other subsystems
/// that modify their behavior if disk pressure is high or if eviction is going on.
status: std::sync::RwLock<Status>,
}
pub fn launch_disk_usage_global_eviction_task(
@@ -200,9 +176,7 @@ async fn disk_usage_eviction_task(
}
pub trait Usage: Clone + Copy + std::fmt::Debug {
fn pressure(&self) -> f64;
fn over_pressure(&self) -> bool;
fn no_pressure(&self) -> bool;
fn has_pressure(&self) -> bool;
fn add_available_bytes(&mut self, bytes: u64);
}
@@ -215,19 +189,13 @@ async fn disk_usage_eviction_task_iteration(
) -> anyhow::Result<()> {
let usage_pre = filesystem_level_usage::get(tenants_dir, task_config)
.context("get filesystem-level disk usage before evictions")?;
if usage_pre.over_pressure() {
*state.status.write().unwrap() = Status::Evicting;
}
let res = disk_usage_eviction_task_iteration_impl(state, storage, usage_pre, cancel).await;
match res {
Ok(outcome) => {
debug!(?outcome, "disk_usage_eviction_iteration finished");
let new_status = match outcome {
match outcome {
IterationOutcome::NoPressure | IterationOutcome::Cancelled => {
// nothing to do, select statement below will handle things
Status::Idle
}
IterationOutcome::Finished(outcome) => {
// Verify with statvfs whether we made any real progress
@@ -237,30 +205,21 @@ async fn disk_usage_eviction_task_iteration(
debug!(?after, "disk usage");
if after.over_pressure() {
if after.has_pressure() {
// Don't bother doing an out-of-order iteration here now.
// In practice, the task period is set to a value in the tens-of-seconds range,
// which will cause another iteration to happen soon enough.
// TODO: deltas between the three different usages would be helpful,
// consider MiB, GiB, TiB
warn!(?outcome, ?after, "disk usage still high");
Status::UnderPressure
} else {
info!(?outcome, ?after, "disk usage pressure relieved");
Status::Idle
}
}
};
*state.status.write().unwrap() = new_status;
}
}
Err(e) => {
error!("disk_usage_eviction_iteration failed: {:#}", e);
*state.status.write().unwrap() = if usage_pre.over_pressure() {
Status::UnderPressure
} else {
Status::Idle
};
}
}
@@ -326,10 +285,8 @@ pub async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
debug!(?usage_pre, "disk usage");
if !usage_pre.over_pressure() {
if !usage_pre.has_pressure() {
return Ok(IterationOutcome::NoPressure);
} else {
*state.status.write().unwrap() = Status::Evicting;
}
warn!(
@@ -377,7 +334,7 @@ pub async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
let mut warned = None;
let mut usage_planned = usage_pre;
for (i, (partition, candidate)) in candidates.into_iter().enumerate() {
if usage_planned.no_pressure() {
if !usage_planned.has_pressure() {
debug!(
no_candidates_evicted = i,
"took enough candidates for pressure to be relieved"
@@ -454,11 +411,6 @@ pub async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
evictions_failed.file_sizes += file_size;
evictions_failed.count += 1;
}
Some(Err(EvictionError::MetadataInconsistency(detail))) => {
warn!(%layer, "failed to evict layer: {detail}");
evictions_failed.file_sizes += file_size;
evictions_failed.count += 1;
}
None => {
assert!(cancel.is_cancelled());
return;
@@ -687,57 +639,22 @@ mod filesystem_level_usage {
}
impl super::Usage for Usage<'_> {
/// Does the pressure exceed 1.0, i.e. has the disk usage exceeded upper bounds?
///
/// This is the condition for starting eviction.
fn over_pressure(&self) -> bool {
self.pressure() >= 1.0
}
fn has_pressure(&self) -> bool {
let usage_pct =
(100.0 * (1.0 - ((self.avail_bytes as f64) / (self.total_bytes as f64)))) as u64;
/// Is the pressure <0, ie.. has disk usage gone below the target bound?
///
/// This is the condition for dropping out of eviction.
fn no_pressure(&self) -> bool {
self.pressure() <= 0.0
}
let pressures = [
(
"min_avail_bytes",
self.avail_bytes < self.config.min_avail_bytes,
),
(
"max_usage_pct",
usage_pct >= self.config.max_usage_pct.get() as u64,
),
];
fn pressure(&self) -> f64 {
let max_usage = std::cmp::min(
self.total_bytes - self.config.min_avail_bytes,
(self.total_bytes as f64 * (self.config.max_usage_pct.get() as f64 / 100.0)) as u64,
);
let mut target_usage = max_usage;
if let Some(target_avail_bytes) = self.config.target_avail_bytes {
target_usage = std::cmp::min(target_usage, self.total_bytes - target_avail_bytes);
}
if let Some(target_usage_pct) = self.config.target_usage_pct {
target_usage = std::cmp::min(
target_usage,
(self.total_bytes as f64 * (target_usage_pct.get() as f64 / 100.0)) as u64,
);
};
let usage = self.total_bytes - self.avail_bytes;
eprintln!(
"pressure: {} {}, current {}",
target_usage, max_usage, usage
);
if target_usage == max_usage {
// We are configured with a zero sized range: treat anything at+beyond limit as pressure 1.0, else 0.0
if usage >= max_usage {
1.0
} else {
0.0
}
} else if usage <= target_usage {
// No pressure.
0.0
} else {
// We are above target: pressure is the ratio of how much we exceed target to the size of the gap
let range_size = (max_usage - target_usage) as f64;
(usage - target_usage) as f64 / range_size
}
pressures.into_iter().any(|(_, has_pressure)| has_pressure)
}
fn add_available_bytes(&mut self, bytes: u64) {
@@ -791,8 +708,6 @@ mod filesystem_level_usage {
config: &DiskUsageEvictionTaskConfig {
max_usage_pct: Percent::new(85).unwrap(),
min_avail_bytes: 0,
target_avail_bytes: None,
target_usage_pct: None,
period: Duration::MAX,
#[cfg(feature = "testing")]
mock_statvfs: None,
@@ -801,24 +716,24 @@ mod filesystem_level_usage {
avail_bytes: 0,
};
assert!(usage.over_pressure(), "expected pressure at 100%");
assert!(usage.has_pressure(), "expected pressure at 100%");
usage.add_available_bytes(14_000);
assert!(usage.over_pressure(), "expected pressure at 86%");
assert!(usage.has_pressure(), "expected pressure at 86%");
usage.add_available_bytes(999);
assert!(usage.over_pressure(), "expected pressure at 85.001%");
assert!(usage.has_pressure(), "expected pressure at 85.001%");
usage.add_available_bytes(1);
assert!(usage.over_pressure(), "expected pressure at precisely 85%");
assert!(usage.has_pressure(), "expected pressure at precisely 85%");
usage.add_available_bytes(1);
assert!(!usage.over_pressure(), "no pressure at 84.999%");
assert!(!usage.has_pressure(), "no pressure at 84.999%");
usage.add_available_bytes(999);
assert!(!usage.over_pressure(), "no pressure at 84%");
assert!(!usage.has_pressure(), "no pressure at 84%");
usage.add_available_bytes(16_000);
assert!(!usage.over_pressure());
assert!(!usage.has_pressure());
}
}

View File

@@ -306,67 +306,6 @@ paths:
schema:
$ref: "#/components/schemas/ServiceUnavailableError"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_timestamp_of_lsn:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
get:
description: Get timestamp for a given LSN
parameters:
- name: lsn
in: query
required: true
schema:
type: integer
description: A LSN to get the timestamp
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
format: date-time
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Timeline not found, or there is no timestamp information for the given lsn
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/get_lsn_by_timestamp:
parameters:

View File

@@ -2,12 +2,10 @@
//! Management HTTP API
//!
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use futures::TryFutureExt;
use humantime::format_rfc3339;
use hyper::header::CONTENT_TYPE;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
@@ -79,7 +77,7 @@ impl State {
disk_usage_eviction_state: Arc<disk_usage_eviction_task::State>,
deletion_queue_client: DeletionQueueClient,
) -> anyhow::Result<Self> {
let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml", "/metrics"]
let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
@@ -138,7 +136,9 @@ impl From<PageReconstructError> for ApiError {
PageReconstructError::AncestorStopping(_) => {
ApiError::ResourceUnavailable(format!("{pre}").into())
}
PageReconstructError::WalRedo(pre) => ApiError::InternalServerError(pre),
PageReconstructError::WalRedo(pre) => {
ApiError::InternalServerError(anyhow::Error::new(pre))
}
}
}
}
@@ -164,6 +164,9 @@ impl From<TenantStateError> for ApiError {
fn from(tse: TenantStateError) -> ApiError {
match tse {
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid).into()),
TenantStateError::NotActive(_) => {
ApiError::ResourceUnavailable("Tenant not yet active".into())
}
TenantStateError::IsStopping(_) => {
ApiError::ResourceUnavailable("Tenant is stopping".into())
}
@@ -393,9 +396,6 @@ async fn timeline_create_handler(
format!("{err:#}")
))
}
Err(e @ tenant::CreateTimelineError::AncestorNotActive) => {
json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg(e.to_string()))
}
Err(tenant::CreateTimelineError::Other(err)) => Err(ApiError::InternalServerError(err)),
}
}
@@ -504,33 +504,6 @@ async fn get_lsn_by_timestamp_handler(
json_response(StatusCode::OK, result)
}
async fn get_timestamp_of_lsn_handler(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
let lsn_str = must_get_query_param(&request, "lsn")?;
let lsn = Lsn::from_str(&lsn_str)
.with_context(|| format!("Invalid LSN: {lsn_str:?}"))
.map_err(ApiError::BadRequest)?;
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
let result = timeline.get_timestamp_for_lsn(lsn, &ctx).await?;
match result {
Some(time) => {
let time = format_rfc3339(postgres_ffi::from_pg_timestamp(time)).to_string();
json_response(StatusCode::OK, time)
}
None => json_response(StatusCode::NOT_FOUND, ()),
}
}
async fn tenant_attach_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
@@ -599,14 +572,9 @@ async fn tenant_detach_handler(
let state = get_state(&request);
let conf = state.conf;
mgr::detach_tenant(
conf,
tenant_id,
detach_ignored.unwrap_or(false),
&state.deletion_queue_client,
)
.instrument(info_span!("tenant_detach", %tenant_id))
.await?;
mgr::detach_tenant(conf, tenant_id, detach_ignored.unwrap_or(false))
.instrument(info_span!("tenant_detach", %tenant_id))
.await?;
json_response(StatusCode::OK, ())
}
@@ -1063,17 +1031,9 @@ async fn put_tenant_location_config_handler(
// The `Detached` state is special, it doesn't upsert a tenant, it removes
// its local disk content and drops it from memory.
if let LocationConfigMode::Detached = request_data.config.mode {
if let Err(e) = mgr::detach_tenant(conf, tenant_id, true, &state.deletion_queue_client)
mgr::detach_tenant(conf, tenant_id, true)
.instrument(info_span!("tenant_detach", %tenant_id))
.await
{
match e {
TenantStateError::NotFound(_) => {
// This API is idempotent: a NotFound on a detach is fine.
}
_ => return Err(e.into()),
}
}
.await?;
return json_response(StatusCode::OK, ());
}
@@ -1452,22 +1412,10 @@ async fn disk_usage_eviction_run(
}
impl crate::disk_usage_eviction_task::Usage for Usage {
fn over_pressure(&self) -> bool {
fn has_pressure(&self) -> bool {
self.config.evict_bytes > self.freed_bytes
}
fn no_pressure(&self) -> bool {
!self.over_pressure()
}
fn pressure(&self) -> f64 {
if self.over_pressure() {
1.0
} else {
0.0
}
}
fn add_available_bytes(&mut self, bytes: u64) {
self.freed_bytes += bytes;
}
@@ -1721,10 +1669,6 @@ pub fn make_router(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|r| api_handler(r, get_lsn_by_timestamp_handler),
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_timestamp_of_lsn",
|r| api_handler(r, get_timestamp_of_lsn_handler),
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
api_handler(r, timeline_gc_handler)
})

View File

@@ -173,9 +173,6 @@ fn is_walkdir_io_not_found(e: &walkdir::Error) -> bool {
/// delaying is needed.
#[derive(Clone)]
pub struct InitializationOrder {
/// Each initial tenant load task carries this until it is done loading timelines from remote storage
pub initial_tenant_load_remote: Option<utils::completion::Completion>,
/// Each initial tenant load task carries this until completion.
pub initial_tenant_load: Option<utils::completion::Completion>,

View File

@@ -1067,26 +1067,6 @@ pub(crate) static TENANT_TASK_EVENTS: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("Failed to register tenant_task_events metric")
});
pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_START_COUNT: Lazy<IntCounterVec> =
Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_semaphore_wait_start_count",
"Counter for background loop concurrency-limiting semaphore acquire calls started",
&["task"],
)
.unwrap()
});
pub(crate) static BACKGROUND_LOOP_SEMAPHORE_WAIT_FINISH_COUNT: Lazy<IntCounterVec> =
Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_semaphore_wait_finish_count",
"Counter for background loop concurrency-limiting semaphore acquire calls finished",
&["task"],
)
.unwrap()
});
pub(crate) static BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_background_loop_period_overrun_count",

View File

@@ -318,6 +318,15 @@ impl std::ops::Deref for PageWriteGuard<'_> {
}
}
impl AsMut<[u8; PAGE_SZ]> for PageWriteGuard<'_> {
fn as_mut(&mut self) -> &mut [u8; PAGE_SZ] {
match &mut self.state {
PageWriteGuardState::Invalid { inner, _permit } => inner.buf,
PageWriteGuardState::Downgraded => unreachable!(),
}
}
}
impl<'a> PageWriteGuard<'a> {
/// Mark that the buffer contents are now valid.
#[must_use]

View File

@@ -35,7 +35,6 @@ use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::io::StreamReader;
use tokio_util::sync::CancellationToken;
use tracing::field;
use tracing::*;
use utils::id::ConnectionId;
@@ -65,6 +64,69 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream<IO>(pgb: &mut PostgresBackend<IO>) -> impl Stream<Item = io::Result<Bytes>> + '_
where
IO: AsyncRead + AsyncWrite + Unpin,
{
async_stream::try_stream! {
loop {
let msg = tokio::select! {
biased;
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = "pageserver is shutting down";
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
};
match msg {
Ok(Some(message)) => {
let copy_data_bytes = match message {
FeMessage::CopyData(bytes) => bytes,
FeMessage::CopyDone => { break },
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
};
yield copy_data_bytes;
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
}
};
}
}
}
/// Read the end of a tar archive.
///
/// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
@@ -122,7 +184,6 @@ pub async fn libpq_listener_main(
listener: TcpListener,
auth_type: AuthType,
listener_ctx: RequestContext,
cancel: CancellationToken,
) -> anyhow::Result<()> {
listener.set_nonblocking(true)?;
let tokio_listener = tokio::net::TcpListener::from_std(listener)?;
@@ -131,7 +192,7 @@ pub async fn libpq_listener_main(
while let Some(res) = tokio::select! {
biased;
_ = cancel.cancelled() => {
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
None
}
@@ -223,13 +284,7 @@ async fn page_service_conn_main(
// and create a child per-query context when it invokes process_query.
// But it's in a shared crate, so, we store connection_ctx inside PageServerHandler
// and create the per-query context in process_query ourselves.
let mut conn_handler = PageServerHandler::new(
conf,
broker_client,
auth,
connection_ctx,
task_mgr::shutdown_token(),
);
let mut conn_handler = PageServerHandler::new(conf, broker_client, auth, connection_ctx);
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
match pgbackend
@@ -263,10 +318,6 @@ struct PageServerHandler {
/// For each query received over the connection,
/// `process_query` creates a child context from this one.
connection_ctx: RequestContext,
/// A token that should fire when the tenant transitions from
/// attached state, or when the pageserver is shutting down.
cancel: CancellationToken,
}
impl PageServerHandler {
@@ -275,7 +326,6 @@ impl PageServerHandler {
broker_client: storage_broker::BrokerClientChannel,
auth: Option<Arc<JwtAuth>>,
connection_ctx: RequestContext,
cancel: CancellationToken,
) -> Self {
PageServerHandler {
_conf: conf,
@@ -283,91 +333,6 @@ impl PageServerHandler {
auth,
claims: None,
connection_ctx,
cancel,
}
}
/// Wrap PostgresBackend::flush to respect our CancellationToken: it is important to use
/// this rather than naked flush() in order to shut down promptly. Without this, we would
/// block shutdown of a tenant if a postgres client was failing to consume bytes we send
/// in the flush.
async fn flush_cancellable<IO>(&self, pgb: &mut PostgresBackend<IO>) -> Result<(), QueryError>
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
{
tokio::select!(
flush_r = pgb.flush() => {
Ok(flush_r?)
},
_ = self.cancel.cancelled() => {
Err(QueryError::Shutdown)
}
)
}
fn copyin_stream<'a, IO>(
&'a self,
pgb: &'a mut PostgresBackend<IO>,
) -> impl Stream<Item = io::Result<Bytes>> + 'a
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
{
async_stream::try_stream! {
loop {
let msg = tokio::select! {
biased;
_ = self.cancel.cancelled() => {
// We were requested to shut down.
let msg = "pageserver is shutting down";
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, None));
Err(QueryError::Shutdown)
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
};
match msg {
Ok(Some(message)) => {
let copy_data_bytes = match message {
FeMessage::CopyData(bytes) => bytes,
FeMessage::CopyDone => { break },
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
};
yield copy_data_bytes;
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
self.flush_cancellable(pgb).await.map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
}
};
}
}
}
@@ -407,7 +372,7 @@ impl PageServerHandler {
// switch client to COPYBOTH
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
self.flush_cancellable(pgb).await?;
pgb.flush().await?;
let metrics = metrics::SmgrQueryTimePerTimeline::new(&tenant_id, &timeline_id);
@@ -415,10 +380,10 @@ impl PageServerHandler {
let msg = tokio::select! {
biased;
_ = self.cancel.cancelled() => {
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
info!("shutdown request received in page handler");
return Err(QueryError::Shutdown)
break;
}
msg = pgb.read_message() => { msg }
@@ -500,7 +465,7 @@ impl PageServerHandler {
});
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
self.flush_cancellable(pgb).await?;
pgb.flush().await?;
}
Ok(())
}
@@ -543,9 +508,9 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
self.flush_cancellable(pgb).await?;
pgb.flush().await?;
let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb)));
let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb)));
timeline
.import_basebackup_from_tar(
&mut copyin_reader,
@@ -598,8 +563,8 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
self.flush_cancellable(pgb).await?;
let mut copyin_reader = pin!(StreamReader::new(self.copyin_stream(pgb)));
pgb.flush().await?;
let mut copyin_reader = pin!(StreamReader::new(copyin_stream(pgb)));
import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
info!("wal import complete");
@@ -807,7 +772,7 @@ impl PageServerHandler {
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
self.flush_cancellable(pgb).await?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline. Compress if not
// fullbackup. TODO Compress in that case too (tests need to be updated)
@@ -859,7 +824,7 @@ impl PageServerHandler {
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
self.flush_cancellable(pgb).await?;
pgb.flush().await?;
let basebackup_after = started
.elapsed()

View File

@@ -19,7 +19,6 @@ use postgres_ffi::BLCKSZ;
use postgres_ffi::{Oid, TimestampTz, TransactionId};
use serde::{Deserialize, Serialize};
use std::collections::{hash_map, HashMap, HashSet};
use std::ops::ControlFlow;
use std::ops::Range;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
@@ -371,6 +370,7 @@ impl Timeline {
}
}
///
/// Subroutine of find_lsn_for_timestamp(). Returns true, if there are any
/// commits that committed after 'search_timestamp', at LSN 'probe_lsn'.
///
@@ -385,50 +385,6 @@ impl Timeline {
found_larger: &mut bool,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
self.map_all_timestamps(probe_lsn, ctx, |timestamp| {
if timestamp >= search_timestamp {
*found_larger = true;
return ControlFlow::Break(true);
} else {
*found_smaller = true;
}
ControlFlow::Continue(())
})
.await
}
/// Obtain the possible timestamp range for the given lsn.
///
/// If the lsn has no timestamps, returns None. returns `(min, max, median)` if it has timestamps.
pub async fn get_timestamp_for_lsn(
&self,
probe_lsn: Lsn,
ctx: &RequestContext,
) -> Result<Option<TimestampTz>, PageReconstructError> {
let mut max: Option<TimestampTz> = None;
self.map_all_timestamps(probe_lsn, ctx, |timestamp| {
if let Some(max_prev) = max {
max = Some(max_prev.max(timestamp));
} else {
max = Some(timestamp);
}
ControlFlow::Continue(())
})
.await?;
Ok(max)
}
/// Runs the given function on all the timestamps for a given lsn
///
/// The return value is either given by the closure, or set to the `Default`
/// impl's output.
async fn map_all_timestamps<T: Default>(
&self,
probe_lsn: Lsn,
ctx: &RequestContext,
mut f: impl FnMut(TimestampTz) -> ControlFlow<T>,
) -> Result<T, PageReconstructError> {
for segno in self
.list_slru_segments(SlruKind::Clog, probe_lsn, ctx)
.await?
@@ -446,14 +402,16 @@ impl Timeline {
timestamp_bytes.copy_from_slice(&clog_page[BLCKSZ as usize..]);
let timestamp = TimestampTz::from_be_bytes(timestamp_bytes);
match f(timestamp) {
ControlFlow::Break(b) => return Ok(b),
ControlFlow::Continue(()) => (),
if timestamp >= search_timestamp {
*found_larger = true;
return Ok(true);
} else {
*found_smaller = true;
}
}
}
}
Ok(Default::default())
Ok(false)
}
/// Get a list of SLRU segments
@@ -541,23 +499,6 @@ impl Timeline {
self.get(CHECKPOINT_KEY, lsn, ctx).await
}
pub async fn list_aux_files(
&self,
lsn: Lsn,
ctx: &RequestContext,
) -> Result<HashMap<String, Bytes>, PageReconstructError> {
match self.get(AUX_FILES_KEY, lsn, ctx).await {
Ok(buf) => match AuxFilesDirectory::des(&buf).context("deserialization failure") {
Ok(dir) => Ok(dir.files),
Err(e) => Err(PageReconstructError::from(e)),
},
Err(e) => {
warn!("Failed to get info about AUX files: {}", e);
Ok(HashMap::new())
}
}
}
/// Does the same as get_current_logical_size but counted on demand.
/// Used to initialize the logical size tracking on startup.
///
@@ -675,7 +616,6 @@ impl Timeline {
result.add_key(CONTROLFILE_KEY);
result.add_key(CHECKPOINT_KEY);
result.add_key(AUX_FILES_KEY);
Ok(result.to_keyspace())
}
@@ -752,12 +692,6 @@ impl<'a> DatadirModification<'a> {
})?;
self.put(DBDIR_KEY, Value::Image(buf.into()));
// Create AuxFilesDirectory
let buf = AuxFilesDirectory::ser(&AuxFilesDirectory {
files: HashMap::new(),
})?;
self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf)));
let buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory {
xids: HashSet::new(),
})?;
@@ -862,12 +796,6 @@ impl<'a> DatadirModification<'a> {
// 'true', now write the updated 'dbdirs' map back.
let buf = DbDirectory::ser(&dbdir)?;
self.put(DBDIR_KEY, Value::Image(buf.into()));
// Create AuxFilesDirectory as well
let buf = AuxFilesDirectory::ser(&AuxFilesDirectory {
files: HashMap::new(),
})?;
self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf)));
}
if r.is_none() {
// Create RelDirectory
@@ -1192,36 +1120,6 @@ impl<'a> DatadirModification<'a> {
Ok(())
}
pub async fn put_file(
&mut self,
path: &str,
content: &[u8],
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut dir = match self.get(AUX_FILES_KEY, ctx).await {
Ok(buf) => AuxFilesDirectory::des(&buf)?,
Err(e) => {
warn!("Failed to get info about AUX files: {}", e);
AuxFilesDirectory {
files: HashMap::new(),
}
}
};
let path = path.to_string();
if content.is_empty() {
dir.files.remove(&path);
} else {
dir.files.insert(path, Bytes::copy_from_slice(content));
}
self.put(
AUX_FILES_KEY,
Value::Image(Bytes::from(
AuxFilesDirectory::ser(&dir).context("serialize")?,
)),
);
Ok(())
}
///
/// Flush changes accumulated so far to the underlying repository.
///
@@ -1357,11 +1255,6 @@ struct RelDirectory {
rels: HashSet<(Oid, u8)>,
}
#[derive(Debug, Serialize, Deserialize, Default)]
struct AuxFilesDirectory {
files: HashMap<String, Bytes>,
}
#[derive(Debug, Serialize, Deserialize)]
struct RelSizeEntry {
nblocks: u32,
@@ -1410,12 +1303,10 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
// 02 pg_twophase
//
// 03 misc
// Controlfile
// controlfile
// checkpoint
// pg_version
//
// 04 aux files
//
// Below is a full list of the keyspace allocation:
//
// DbDir:
@@ -1453,11 +1344,6 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
//
// Checkpoint:
// 03 00000000 00000000 00000000 00 00000001
//
// AuxFiles:
// 03 00000000 00000000 00000000 00 00000002
//
//-- Section 01: relation data and metadata
const DBDIR_KEY: Key = Key {
@@ -1681,15 +1567,6 @@ const CHECKPOINT_KEY: Key = Key {
field6: 1,
};
const AUX_FILES_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 2,
};
// Reverse mappings for a few Keys.
// These are needed by WAL redo manager.

File diff suppressed because it is too large Load Diff

View File

@@ -31,7 +31,7 @@ use super::{
const SHOULD_RESUME_DELETION_FETCH_MARK_ATTEMPTS: u32 = 3;
#[derive(Debug, thiserror::Error)]
pub(crate) enum DeleteTenantError {
pub enum DeleteTenantError {
#[error("GetTenant {0}")]
Get(#[from] GetTenantError),
@@ -376,7 +376,7 @@ impl DeleteTenantFlow {
Ok(())
}
pub(crate) async fn should_resume_deletion(
pub async fn should_resume_deletion(
conf: &'static PageServerConf,
remote_storage: Option<&GenericRemoteStorage>,
tenant: &Tenant,
@@ -432,7 +432,7 @@ impl DeleteTenantFlow {
// Tenant may not be loadable if we fail late in cleanup_remaining_fs_traces (e g remove timelines dir)
let timelines_path = tenant.conf.timelines_path(&tenant.tenant_id);
if timelines_path.exists() {
tenant.load(init_order, None, ctx).await.context("load")?;
tenant.load(init_order, ctx).await.context("load")?;
}
Self::background(
@@ -458,10 +458,7 @@ impl DeleteTenantFlow {
.await
.expect("cant be stopping or broken");
tenant
.attach(ctx, super::AttachMarkerMode::Expect)
.await
.context("attach")?;
tenant.attach(ctx).await.context("attach")?;
Self::background(
guard,

View File

@@ -354,7 +354,8 @@ mod tests {
}
// Test a large blob that spans multiple pages
let mut large_data = vec![0; 20000];
let mut large_data = Vec::new();
large_data.resize(20000, 0);
thread_rng().fill_bytes(&mut large_data);
let pos_large = file.write_blob(&large_data, &ctx).await?;
let result = file.block_cursor().read_blob(pos_large, &ctx).await?;

View File

@@ -1,7 +1,7 @@
//! This module acts as a switchboard to access different repositories managed by this
//! page server.
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use camino::{Utf8Path, Utf8PathBuf};
use rand::{distributions::Alphanumeric, Rng};
use std::collections::{hash_map, HashMap};
use std::sync::Arc;
@@ -24,11 +24,10 @@ use crate::control_plane_client::{
};
use crate::deletion_queue::DeletionQueueClient;
use crate::task_mgr::{self, TaskKind};
use crate::tenant::config::{AttachmentMode, LocationConf, LocationMode, TenantConfOpt};
use crate::tenant::config::{LocationConf, LocationMode, TenantConfOpt};
use crate::tenant::delete::DeleteTenantFlow;
use crate::tenant::{
create_tenant_files, AttachMarkerMode, AttachedTenantConf, CreateTenantFilesMode, Tenant,
TenantState,
create_tenant_files, AttachedTenantConf, CreateTenantFilesMode, Tenant, TenantState,
};
use crate::{InitializationOrder, IGNORED_TENANT_FILE_NAME, TEMP_FILE_SUFFIX};
@@ -51,7 +50,7 @@ use super::TenantSharedResources;
/// its lifetime, and we can preserve some important safety invariants like `Tenant` always
/// having a properly acquired generation (Secondary doesn't need a generation)
#[derive(Clone)]
pub(crate) enum TenantSlot {
pub enum TenantSlot {
Attached(Arc<Tenant>),
Secondary,
}
@@ -152,49 +151,6 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
static TENANTS: Lazy<RwLock<TenantsMap>> = Lazy::new(|| RwLock::new(TenantsMap::Initializing));
/// Create a directory, including parents. This does no fsyncs and makes
/// no guarantees about the persistence of the resulting metadata: for
/// use when creating dirs for use as cache.
async fn unsafe_create_dir_all(path: &Utf8PathBuf) -> std::io::Result<()> {
let mut dirs_to_create = Vec::new();
let mut path: &Utf8Path = path.as_ref();
// Figure out which directories we need to create.
loop {
let meta = tokio::fs::metadata(path).await;
match meta {
Ok(metadata) if metadata.is_dir() => break,
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::AlreadyExists,
format!("non-directory found in path: {path}"),
));
}
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => return Err(e),
}
dirs_to_create.push(path);
match path.parent() {
Some(parent) => path = parent,
None => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("can't find parent of path '{path}'"),
));
}
}
}
// Create directories from parent to child.
for &path in dirs_to_create.iter().rev() {
tokio::fs::create_dir(path).await?;
}
Ok(())
}
fn emergency_generations(
tenant_confs: &HashMap<TenantId, anyhow::Result<LocationConf>>,
) -> HashMap<TenantId, Generation> {
@@ -250,105 +206,90 @@ async fn init_load_generations(
if resources.remote_storage.is_some() {
resources
.deletion_queue_client
.recover(generations.clone())?;
.recover(generations.clone())
.await?;
}
Ok(Some(generations))
}
/// Given a directory discovered in the pageserver's tenants/ directory, attempt
/// to load a tenant config from it.
///
/// If file is missing, return Ok(None)
fn load_tenant_config(
conf: &'static PageServerConf,
dentry: Utf8DirEntry,
) -> anyhow::Result<Option<(TenantId, anyhow::Result<LocationConf>)>> {
let tenant_dir_path = dentry.path().to_path_buf();
if crate::is_temporary(&tenant_dir_path) {
info!("Found temporary tenant directory, removing: {tenant_dir_path}");
// No need to use safe_remove_tenant_dir_all because this is already
// a temporary path
if let Err(e) = std::fs::remove_dir_all(&tenant_dir_path) {
error!(
"Failed to remove temporary directory '{}': {:?}",
tenant_dir_path, e
);
}
return Ok(None);
}
// This case happens if we crash during attachment before writing a config into the dir
let is_empty = tenant_dir_path
.is_empty_dir()
.with_context(|| format!("Failed to check whether {tenant_dir_path:?} is an empty dir"))?;
if is_empty {
info!("removing empty tenant directory {tenant_dir_path:?}");
if let Err(e) = std::fs::remove_dir(&tenant_dir_path) {
error!(
"Failed to remove empty tenant directory '{}': {e:#}",
tenant_dir_path
)
}
return Ok(None);
}
let tenant_ignore_mark_file = tenant_dir_path.join(IGNORED_TENANT_FILE_NAME);
if tenant_ignore_mark_file.exists() {
info!("Found an ignore mark file {tenant_ignore_mark_file:?}, skipping the tenant");
return Ok(None);
}
let tenant_id = match tenant_dir_path
.file_name()
.unwrap_or_default()
.parse::<TenantId>()
{
Ok(id) => id,
Err(_) => {
warn!("Invalid tenant path (garbage in our repo directory?): {tenant_dir_path}",);
return Ok(None);
}
};
Ok(Some((
tenant_id,
Tenant::load_tenant_config(conf, &tenant_id),
)))
}
/// Initial stage of load: walk the local tenants directory, clean up any temp files,
/// and load configurations for the tenants we found.
///
/// Do this in parallel, because we expect 10k+ tenants, so serial execution can take
/// seconds even on reasonably fast drives.
async fn init_load_tenant_configs(
conf: &'static PageServerConf,
) -> anyhow::Result<HashMap<TenantId, anyhow::Result<LocationConf>>> {
let tenants_dir = conf.tenants_path();
let dentries = tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<Utf8DirEntry>> {
let dir_entries = tenants_dir
.read_dir_utf8()
.with_context(|| format!("Failed to list tenants dir {tenants_dir:?}"))?;
Ok(dir_entries.collect::<Result<Vec<_>, std::io::Error>>()?)
})
.await??;
let mut dir_entries = tenants_dir
.read_dir_utf8()
.with_context(|| format!("Failed to list tenants dir {tenants_dir:?}"))?;
let mut configs = HashMap::new();
let mut join_set = JoinSet::new();
for dentry in dentries {
join_set.spawn_blocking(move || load_tenant_config(conf, dentry));
}
loop {
match dir_entries.next() {
None => break,
Some(Ok(dentry)) => {
let tenant_dir_path = dentry.path().to_path_buf();
if crate::is_temporary(&tenant_dir_path) {
info!("Found temporary tenant directory, removing: {tenant_dir_path}");
// No need to use safe_remove_tenant_dir_all because this is already
// a temporary path
if let Err(e) = fs::remove_dir_all(&tenant_dir_path).await {
error!(
"Failed to remove temporary directory '{}': {:?}",
tenant_dir_path, e
);
}
continue;
}
while let Some(r) = join_set.join_next().await {
if let Some((tenant_id, tenant_config)) = r?? {
configs.insert(tenant_id, tenant_config);
// This case happens if we:
// * crash during attach before creating the attach marker file
// * crash during tenant delete before removing tenant directory
let is_empty = tenant_dir_path.is_empty_dir().with_context(|| {
format!("Failed to check whether {tenant_dir_path:?} is an empty dir")
})?;
if is_empty {
info!("removing empty tenant directory {tenant_dir_path:?}");
if let Err(e) = fs::remove_dir(&tenant_dir_path).await {
error!(
"Failed to remove empty tenant directory '{}': {e:#}",
tenant_dir_path
)
}
continue;
}
let tenant_ignore_mark_file = tenant_dir_path.join(IGNORED_TENANT_FILE_NAME);
if tenant_ignore_mark_file.exists() {
info!("Found an ignore mark file {tenant_ignore_mark_file:?}, skipping the tenant");
continue;
}
let tenant_id = match tenant_dir_path
.file_name()
.unwrap_or_default()
.parse::<TenantId>()
{
Ok(id) => id,
Err(_) => {
warn!(
"Invalid tenant path (garbage in our repo directory?): {tenant_dir_path}",
);
continue;
}
};
configs.insert(tenant_id, Tenant::load_tenant_config(conf, &tenant_id));
}
Some(Err(e)) => {
// An error listing the top level directory indicates serious problem
// with local filesystem: we will fail to load, and fail to start.
anyhow::bail!(e);
}
}
}
Ok(configs)
}
@@ -506,15 +447,7 @@ pub(crate) fn schedule_local_tenant_processing(
"attaching mark file present but no remote storage configured".to_string(),
)
} else {
match Tenant::spawn_attach(
conf,
tenant_id,
resources,
location_conf,
tenants,
AttachMarkerMode::Expect,
ctx,
) {
match Tenant::spawn_attach(conf, tenant_id, resources, location_conf, tenants, ctx) {
Ok(tenant) => tenant,
Err(e) => {
error!("Failed to spawn_attach tenant {tenant_id}, reason: {e:#}");
@@ -549,7 +482,7 @@ pub(crate) fn schedule_local_tenant_processing(
/// management API. For example, it could attach the tenant on a different pageserver.
/// We would then be in split-brain once this pageserver restarts.
#[instrument(skip_all)]
pub(crate) async fn shutdown_all_tenants() {
pub async fn shutdown_all_tenants() {
shutdown_all_tenants0(&TENANTS).await
}
@@ -661,7 +594,7 @@ async fn shutdown_all_tenants0(tenants: &tokio::sync::RwLock<TenantsMap>) {
// caller will log how long we took
}
pub(crate) async fn create_tenant(
pub async fn create_tenant(
conf: &'static PageServerConf,
tenant_conf: TenantConfOpt,
tenant_id: TenantId,
@@ -696,14 +629,14 @@ pub(crate) async fn create_tenant(
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SetNewTenantConfigError {
pub enum SetNewTenantConfigError {
#[error(transparent)]
GetTenant(#[from] GetTenantError),
#[error(transparent)]
Persist(anyhow::Error),
}
pub(crate) async fn set_new_tenant_config(
pub async fn set_new_tenant_config(
conf: &'static PageServerConf,
new_tenant_conf: TenantConfOpt,
tenant_id: TenantId,
@@ -723,7 +656,7 @@ pub(crate) async fn set_new_tenant_config(
Ok(())
}
#[instrument(skip_all, fields(%tenant_id))]
#[instrument(skip_all, fields(tenant_id, new_location_config))]
pub(crate) async fn upsert_location(
conf: &'static PageServerConf,
tenant_id: TenantId,
@@ -762,18 +695,6 @@ pub(crate) async fn upsert_location(
if let Some(tenant) = shutdown_tenant {
let (_guard, progress) = utils::completion::channel();
match tenant.get_attach_mode() {
AttachmentMode::Single | AttachmentMode::Multi => {
// Before we leave our state as the presumed holder of the latest generation,
// flush any outstanding deletions to reduce the risk of leaking objects.
deletion_queue_client.flush_advisory()
}
AttachmentMode::Stale => {
// If we're stale there's not point trying to flush deletions
}
};
info!("Shutting down attached tenant");
match tenant.shutdown(progress, false).await {
Ok(()) => {}
@@ -802,61 +723,36 @@ pub(crate) async fn upsert_location(
}
let new_slot = match &new_location_config.mode {
LocationMode::Secondary(_) => {
let tenant_path = conf.tenant_path(&tenant_id);
// Directory doesn't need to be fsync'd because if we crash it can
// safely be recreated next time this tenant location is configured.
unsafe_create_dir_all(&tenant_path)
.await
.with_context(|| format!("Creating {tenant_path}"))?;
Tenant::persist_tenant_config(conf, &tenant_id, &new_location_config)
.await
.map_err(SetNewTenantConfigError::Persist)?;
TenantSlot::Secondary
}
LocationMode::Secondary(_) => TenantSlot::Secondary,
LocationMode::Attached(_attach_config) => {
// Do a schedule_local_tenant_processing
// FIXME: should avoid doing this disk I/O inside the TenantsMap lock,
// we have the same problem in load_tenant/attach_tenant. Probably
// need a lock in TenantSlot to fix this.
let timelines_path = conf.timelines_path(&tenant_id);
// Directory doesn't need to be fsync'd because we do not depend on
// it to exist after crashes: it may be recreated when tenant is
// re-attached, see https://github.com/neondatabase/neon/issues/5550
unsafe_create_dir_all(&timelines_path)
.await
.with_context(|| format!("Creating {timelines_path}"))?;
Tenant::persist_tenant_config(conf, &tenant_id, &new_location_config)
.await
.map_err(SetNewTenantConfigError::Persist)?;
let tenant = match Tenant::spawn_attach(
let tenant_path = conf.tenant_path(&tenant_id);
let resources = TenantSharedResources {
broker_client,
remote_storage,
deletion_queue_client,
};
let new_tenant = schedule_local_tenant_processing(
conf,
tenant_id,
TenantSharedResources {
broker_client,
remote_storage,
deletion_queue_client,
},
&tenant_path,
AttachedTenantConf::try_from(new_location_config)?,
resources,
None,
&TENANTS,
// The LocationConf API does not use marker files, because we have Secondary
// locations where the directory's existence is not a signal that it contains
// all timelines. See https://github.com/neondatabase/neon/issues/5550
AttachMarkerMode::Ignore,
ctx,
) {
Ok(tenant) => tenant,
Err(e) => {
error!("Failed to spawn_attach tenant {tenant_id}, reason: {e:#}");
Tenant::create_broken_tenant(conf, tenant_id, format!("{e:#}"))
}
};
)
.with_context(|| {
format!("Failed to schedule tenant processing in path {tenant_path:?}")
})?;
TenantSlot::Attached(tenant)
TenantSlot::Attached(new_tenant)
}
};
@@ -864,11 +760,12 @@ pub(crate) async fn upsert_location(
})
.await?;
}
Ok(())
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum GetTenantError {
pub enum GetTenantError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is not active")]
@@ -884,7 +781,7 @@ pub(crate) enum GetTenantError {
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
///
/// This method is cancel-safe.
pub(crate) async fn get_tenant(
pub async fn get_tenant(
tenant_id: TenantId,
active_only: bool,
) -> Result<Arc<Tenant>, GetTenantError> {
@@ -909,7 +806,7 @@ pub(crate) async fn get_tenant(
}
}
pub(crate) async fn delete_tenant(
pub async fn delete_tenant(
conf: &'static PageServerConf,
remote_storage: Option<GenericRemoteStorage>,
tenant_id: TenantId,
@@ -918,7 +815,7 @@ pub(crate) async fn delete_tenant(
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum DeleteTimelineError {
pub enum DeleteTimelineError {
#[error("Tenant {0}")]
Tenant(#[from] GetTenantError),
@@ -926,7 +823,7 @@ pub(crate) enum DeleteTimelineError {
Timeline(#[from] crate::tenant::DeleteTimelineError),
}
pub(crate) async fn delete_timeline(
pub async fn delete_timeline(
tenant_id: TenantId,
timeline_id: TimelineId,
_ctx: &RequestContext,
@@ -937,29 +834,23 @@ pub(crate) async fn delete_timeline(
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum TenantStateError {
pub enum TenantStateError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is stopping")]
IsStopping(TenantId),
#[error("Tenant {0} is not active")]
NotActive(TenantId),
#[error(transparent)]
Other(#[from] anyhow::Error),
}
pub(crate) async fn detach_tenant(
pub async fn detach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
detach_ignored: bool,
deletion_queue_client: &DeletionQueueClient,
) -> Result<(), TenantStateError> {
let tmp_path = detach_tenant0(
conf,
&TENANTS,
tenant_id,
detach_ignored,
deletion_queue_client,
)
.await?;
let tmp_path = detach_tenant0(conf, &TENANTS, tenant_id, detach_ignored).await?;
// Although we are cleaning up the tenant, this task is not meant to be bound by the lifetime of the tenant in memory.
// After a tenant is detached, there are no more task_mgr tasks for that tenant_id.
let task_tenant_id = None;
@@ -984,7 +875,6 @@ async fn detach_tenant0(
tenants: &tokio::sync::RwLock<TenantsMap>,
tenant_id: TenantId,
detach_ignored: bool,
deletion_queue_client: &DeletionQueueClient,
) -> Result<Utf8PathBuf, TenantStateError> {
let tenant_dir_rename_operation = |tenant_id_to_clean| async move {
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
@@ -996,10 +886,6 @@ async fn detach_tenant0(
let removal_result =
remove_tenant_from_memory(tenants, tenant_id, tenant_dir_rename_operation(tenant_id)).await;
// Flush pending deletions, so that they have a good chance of passing validation
// before this tenant is potentially re-attached elsewhere.
deletion_queue_client.flush_advisory();
// Ignored tenants are not present in memory and will bail the removal from memory operation.
// Before returning the error, check for ignored tenant removal case — we only need to clean its local files then.
if detach_ignored && matches!(removal_result, Err(TenantStateError::NotFound(_))) {
@@ -1016,7 +902,7 @@ async fn detach_tenant0(
removal_result
}
pub(crate) async fn load_tenant(
pub async fn load_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
generation: Generation,
@@ -1053,7 +939,7 @@ pub(crate) async fn load_tenant(
Ok(())
}
pub(crate) async fn ignore_tenant(
pub async fn ignore_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
@@ -1081,7 +967,7 @@ async fn ignore_tenant0(
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum TenantMapListError {
pub enum TenantMapListError {
#[error("tenant map is still initiailizing")]
Initializing,
}
@@ -1089,7 +975,7 @@ pub(crate) enum TenantMapListError {
///
/// Get list of tenants, for the mgmt API
///
pub(crate) async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, TenantMapListError> {
pub async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, TenantMapListError> {
let tenants = TENANTS.read().await;
let m = match &*tenants {
TenantsMap::Initializing => return Err(TenantMapListError::Initializing),
@@ -1107,7 +993,7 @@ pub(crate) async fn list_tenants() -> Result<Vec<(TenantId, TenantState)>, Tenan
///
/// Downloading all the tenant data is performed in the background, this merely
/// spawns the background task and returns quickly.
pub(crate) async fn attach_tenant(
pub async fn attach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
generation: Generation,
@@ -1144,7 +1030,7 @@ pub(crate) async fn attach_tenant(
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum TenantMapInsertError {
pub enum TenantMapInsertError {
#[error("tenant map is still initializing")]
StillInitializing,
#[error("tenant map is shutting down")]
@@ -1307,7 +1193,7 @@ use {
utils::http::error::ApiError,
};
pub(crate) async fn immediate_gc(
pub async fn immediate_gc(
tenant_id: TenantId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,

View File

@@ -1419,13 +1419,6 @@ impl RemoteTimelineClient {
}
}
}
pub(crate) fn get_layer_metadata(
&self,
name: &LayerFileName,
) -> anyhow::Result<Option<LayerFileMetadata>> {
self.upload_queue.lock().unwrap().get_layer_metadata(name)
}
}
pub fn remote_timelines_path(tenant_id: &TenantId) -> RemotePath {

View File

@@ -18,7 +18,7 @@ use crate::config::PageServerConf;
use crate::tenant::remote_timeline_client::{remote_layer_path, remote_timelines_path};
use crate::tenant::storage_layer::LayerFileName;
use crate::tenant::timeline::span::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::{Generation, TENANT_DELETED_MARKER_FILE_NAME};
use crate::tenant::Generation;
use remote_storage::{DownloadError, GenericRemoteStorage};
use utils::crashsafe::path_with_suffix_extension;
use utils::id::{TenantId, TimelineId};
@@ -190,12 +190,6 @@ pub async fn list_remote_timelines(
let mut timeline_ids = HashSet::new();
for timeline_remote_storage_key in timelines {
if timeline_remote_storage_key.object_name() == Some(TENANT_DELETED_MARKER_FILE_NAME) {
// A `deleted` key within `timelines/` is a marker file, not a timeline. Ignore it.
// This code will be removed in https://github.com/neondatabase/neon/pull/5580
continue;
}
let object_name = timeline_remote_storage_key.object_name().ok_or_else(|| {
anyhow::anyhow!("failed to get timeline id for remote tenant {tenant_id}")
})?;

View File

@@ -31,7 +31,6 @@ pub(super) async fn upload_index_part<'a>(
fail_point!("before-upload-index", |_| {
bail!("failpoint before-upload-index")
});
pausable_failpoint!("before-upload-index-pausable");
let index_part_bytes =
serde_json::to_vec(&index_part).context("serialize index part file into bytes")?;
@@ -60,8 +59,6 @@ pub(super) async fn upload_timeline_layer<'a>(
bail!("failpoint before-upload-layer")
});
pausable_failpoint!("before-upload-layer-pausable");
let storage_path = remote_path(conf, source_path, generation)?;
let source_file_res = fs::File::open(&source_path).await;
let source_file = match source_file_res {

View File

@@ -226,14 +226,6 @@ impl LayerFileName {
_ => false,
}
}
pub(crate) fn kind(&self) -> &'static str {
use LayerFileName::*;
match self {
Delta(_) => "delta",
Image(_) => "image",
}
}
}
impl fmt::Display for LayerFileName {

View File

@@ -25,7 +25,7 @@ use super::{
};
/// RemoteLayer is a not yet downloaded [`ImageLayer`] or
/// [`DeltaLayer`].
/// [`DeltaLayer`](super::DeltaLayer).
///
/// RemoteLayer might be downloaded on-demand during operations which are
/// allowed download remote layers and during which, it gets replaced with a

View File

@@ -14,73 +14,6 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::completion;
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
// repartitioning in the async context. this should give leave us some workers
// unblocked to be blocked on other work, hopefully easing any outside visible
// effects of restarts.
//
// 6/8 is a guess; previously we ran with unlimited 8 and more from
// spawn_blocking.
(total_threads * 3).checked_div(4).unwrap_or(0),
);
assert_ne!(permits, 0, "we will not be adding in permits later");
assert!(
permits < total_threads,
"need threads avail for shorter work"
);
tokio::sync::Semaphore::new(permits)
});
#[derive(Debug, PartialEq, Eq, Clone, Copy, strum_macros::IntoStaticStr)]
#[strum(serialize_all = "snake_case")]
pub(crate) enum BackgroundLoopKind {
Compaction,
Gc,
Eviction,
ConsumptionMetricsCollectMetrics,
ConsumptionMetricsSyntheticSizeWorker,
}
impl BackgroundLoopKind {
fn as_static_str(&self) -> &'static str {
let s: &'static str = self.into();
s
}
}
pub(crate) enum RateLimitError {
Cancelled,
}
pub(crate) async fn concurrent_background_tasks_rate_limit(
loop_kind: BackgroundLoopKind,
_ctx: &RequestContext,
cancel: &CancellationToken,
) -> Result<impl Drop, RateLimitError> {
crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_START_COUNT
.with_label_values(&[loop_kind.as_static_str()])
.inc();
scopeguard::defer!(
crate::metrics::BACKGROUND_LOOP_SEMAPHORE_WAIT_FINISH_COUNT.with_label_values(&[loop_kind.as_static_str()]).inc();
);
tokio::select! {
permit = CONCURRENT_BACKGROUND_TASKS.acquire() => {
match permit {
Ok(permit) => Ok(permit),
Err(_closed) => unreachable!("we never close the semaphore"),
}
},
_ = cancel.cancelled() => {
Err(RateLimitError::Cancelled)
}
}
}
/// Start per tenant background loops: compaction and gc.
pub fn start_background_loops(
tenant: &Arc<Tenant>,
@@ -183,7 +116,7 @@ async fn compaction_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
}
};
warn_when_period_overrun(started_at.elapsed(), period, BackgroundLoopKind::Compaction);
warn_when_period_overrun(started_at.elapsed(), period, "compaction");
// Sleep
if tokio::time::timeout(sleep_duration, cancel.cancelled())
@@ -251,7 +184,7 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
}
};
warn_when_period_overrun(started_at.elapsed(), period, BackgroundLoopKind::Gc);
warn_when_period_overrun(started_at.elapsed(), period, "gc");
// Sleep
if tokio::time::timeout(sleep_duration, cancel.cancelled())
@@ -325,11 +258,7 @@ pub(crate) async fn random_init_delay(
}
/// Attention: the `task` and `period` beocme labels of a pageserver-wide prometheus metric.
pub(crate) fn warn_when_period_overrun(
elapsed: Duration,
period: Duration,
task: BackgroundLoopKind,
) {
pub(crate) fn warn_when_period_overrun(elapsed: Duration, period: Duration, task: &str) {
// Duration::ZERO will happen because it's the "disable [bgtask]" value.
if elapsed >= period && period != Duration::ZERO {
// humantime does no significant digits clamping whereas Duration's debug is a bit more
@@ -338,11 +267,11 @@ pub(crate) fn warn_when_period_overrun(
warn!(
?elapsed,
period = %humantime::format_duration(period),
?task,
task,
"task iteration took longer than the configured period"
);
crate::metrics::BACKGROUND_LOOP_PERIOD_OVERRUN_COUNT
.with_label_values(&[task.as_static_str(), &format!("{}", period.as_secs())])
.with_label_values(&[task, &format!("{}", period.as_secs())])
.inc();
}
}

View File

@@ -44,7 +44,6 @@ use crate::tenant::storage_layer::delta_layer::DeltaEntry;
use crate::tenant::storage_layer::{
DeltaLayerWriter, ImageLayerWriter, InMemoryLayer, LayerAccessStats, LayerFileName, RemoteLayer,
};
use crate::tenant::tasks::{BackgroundLoopKind, RateLimitError};
use crate::tenant::timeline::logical_size::CurrentLogicalSize;
use crate::tenant::{
layer_map::{LayerMap, SearchResult},
@@ -81,6 +80,7 @@ use crate::repository::GcResult;
use crate::repository::{Key, Value};
use crate::task_mgr;
use crate::task_mgr::TaskKind;
use crate::walredo::WalRedoManager;
use crate::ZERO_PAGE;
use self::delete::DeleteTimelineFlow;
@@ -200,7 +200,7 @@ pub struct Timeline {
last_freeze_ts: RwLock<Instant>,
// WAL redo manager
walredo_mgr: Arc<super::WalRedoManager>,
walredo_mgr: Arc<dyn WalRedoManager + Sync + Send>,
/// Remote storage client.
/// See [`remote_timeline_client`](super::remote_timeline_client) module comment for details.
@@ -370,7 +370,7 @@ pub enum PageReconstructError {
/// An error happened replaying WAL records
#[error(transparent)]
WalRedo(anyhow::Error),
WalRedo(#[from] crate::walredo::WalRedoError),
}
impl std::fmt::Debug for PageReconstructError {
@@ -684,17 +684,37 @@ impl Timeline {
) -> anyhow::Result<()> {
const ROUNDS: usize = 2;
static CONCURRENT_COMPACTIONS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
// repartitioning in the async context. this should give leave us some workers
// unblocked to be blocked on other work, hopefully easing any outside visible
// effects of restarts.
//
// 6/8 is a guess; previously we ran with unlimited 8 and more from
// spawn_blocking.
(total_threads * 3).checked_div(4).unwrap_or(0),
);
assert_ne!(permits, 0, "we will not be adding in permits later");
assert!(
permits < total_threads,
"need threads avail for shorter work"
);
tokio::sync::Semaphore::new(permits)
});
// this wait probably never needs any "long time spent" logging, because we already nag if
// compaction task goes over it's period (20s) which is quite often in production.
let _permit = match super::tasks::concurrent_background_tasks_rate_limit(
BackgroundLoopKind::Compaction,
ctx,
cancel,
)
.await
{
Ok(permit) => permit,
Err(RateLimitError::Cancelled) => return Ok(()),
let _permit = tokio::select! {
permit = CONCURRENT_COMPACTIONS.acquire() => {
permit
},
_ = cancel.cancelled() => {
return Ok(());
}
};
let last_record_lsn = self.get_last_record_lsn();
@@ -1274,23 +1294,7 @@ impl Timeline {
Ok(delta) => Some(delta),
};
// RemoteTimelineClient holds the metadata on layers' remote generations, so
// query it to construct a RemoteLayer.
let layer_metadata = self
.remote_client
.as_ref()
.expect("Eviction is not called without remote storage")
.get_layer_metadata(&local_layer.filename())
.map_err(EvictionError::LayerNotFound)?
.ok_or_else(|| {
EvictionError::LayerNotFound(anyhow::anyhow!("Layer not in remote metadata"))
})?;
if layer_metadata.file_size() != layer_file_size {
return Err(EvictionError::MetadataInconsistency(format!(
"Layer size {layer_file_size} doesn't match remote metadata file size {}",
layer_metadata.file_size()
)));
}
let layer_metadata = LayerFileMetadata::new(layer_file_size, self.generation);
let new_remote_layer = Arc::new(match local_layer.filename() {
LayerFileName::Image(image_name) => RemoteLayer::new_img(
@@ -1369,10 +1373,6 @@ pub(crate) enum EvictionError {
/// different objects in memory.
#[error("layer was no longer part of LayerMap")]
LayerNotFound(#[source] anyhow::Error),
/// This should never happen
#[error("Metadata inconsistency")]
MetadataInconsistency(String),
}
/// Number of times we will compute partition within a checkpoint distance.
@@ -1470,7 +1470,7 @@ impl Timeline {
timeline_id: TimelineId,
tenant_id: TenantId,
generation: Generation,
walredo_mgr: Arc<super::WalRedoManager>,
walredo_mgr: Arc<dyn WalRedoManager + Send + Sync>,
resources: TimelineResources,
pg_version: u32,
initial_logical_size_can_start: Option<completion::Barrier>,
@@ -1699,7 +1699,7 @@ impl Timeline {
disk_consistent_lsn: Lsn,
index_part: Option<IndexPart>,
) -> anyhow::Result<()> {
use init::{Decision::*, Discovered, DismissedLayer};
use init::{Decision::*, Discovered, FutureLayer};
use LayerFileName::*;
let mut guard = self.layers.write().await;
@@ -1715,7 +1715,7 @@ impl Timeline {
// Copy to move into the task we're about to spawn
let generation = self.generation;
let (loaded_layers, needs_cleanup, total_physical_size) = tokio::task::spawn_blocking({
let (loaded_layers, to_sync, total_physical_size) = tokio::task::spawn_blocking({
move || {
let _g = span.entered();
let discovered = init::scan_timeline_dir(&timeline_path)?;
@@ -1764,6 +1764,7 @@ impl Timeline {
);
let mut loaded_layers = Vec::new();
let mut needs_upload = Vec::new();
let mut needs_cleanup = Vec::new();
let mut total_physical_size = 0;
@@ -1784,7 +1785,7 @@ impl Timeline {
}
}
Ok(decision) => decision,
Err(DismissedLayer::Future { local }) => {
Err(FutureLayer { local }) => {
if local.is_some() {
path.push(name.file_name());
init::cleanup_future_layer(&path, &name, disk_consistent_lsn)?;
@@ -1793,13 +1794,6 @@ impl Timeline {
needs_cleanup.push(name);
continue;
}
Err(DismissedLayer::LocalOnly(local)) => {
path.push(name.file_name());
init::cleanup_local_only_file(&path, &name, &local)?;
path.pop();
// this file never existed remotely, we will have to do rework
continue;
}
};
match &name {
@@ -1808,16 +1802,14 @@ impl Timeline {
}
let status = match &decision {
UseLocal(_) => LayerResidenceStatus::Resident,
UseLocal(_) | NeedsUpload(_) => LayerResidenceStatus::Resident,
Evicted(_) | UseRemote { .. } => LayerResidenceStatus::Evicted,
};
tracing::debug!(layer=%name, ?decision, ?status, "applied");
let stats = LayerAccessStats::for_loading_layer(status);
let layer: Arc<dyn PersistentLayer> = match (name, &decision) {
(Delta(d), UseLocal(m)) => {
(Delta(d), UseLocal(m) | NeedsUpload(m)) => {
total_physical_size += m.file_size();
Arc::new(DeltaLayer::new(
conf,
@@ -1828,7 +1820,7 @@ impl Timeline {
stats,
))
}
(Image(i), UseLocal(m)) => {
(Image(i), UseLocal(m) | NeedsUpload(m)) => {
total_physical_size += m.file_size();
Arc::new(ImageLayer::new(
conf,
@@ -1847,9 +1839,17 @@ impl Timeline {
),
};
if let NeedsUpload(m) = decision {
needs_upload.push((layer.clone(), m));
}
loaded_layers.push(layer);
}
Ok((loaded_layers, needs_cleanup, total_physical_size))
Ok((
loaded_layers,
(needs_upload, needs_cleanup),
total_physical_size,
))
}
})
.await
@@ -1861,6 +1861,10 @@ impl Timeline {
guard.initialize_local_layers(loaded_layers, disk_consistent_lsn + 1);
if let Some(rtc) = self.remote_client.as_ref() {
let (needs_upload, needs_cleanup) = to_sync;
for (layer, m) in needs_upload {
rtc.schedule_layer_file_upload(&layer.layer_desc().filename(), &m)?;
}
rtc.schedule_layer_file_deletion(needs_cleanup)?;
rtc.schedule_index_upload_for_file_changes()?;
// Tenant::create_timeline will wait for these uploads to happen before returning, or
@@ -2793,13 +2797,10 @@ impl Timeline {
)
};
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
let old_disk_consistent_lsn = self.disk_consistent_lsn.load();
// The new on-disk layers are now in the layer map. We can remove the
// in-memory layer from the map now. The flushed layer is stored in
// the mapping in `create_delta_layer`.
let metadata = {
{
let mut guard = self.layers.write().await;
if let Some(ref l) = delta_layer_to_add {
@@ -2815,17 +2816,8 @@ impl Timeline {
}
guard.finish_flush_l0_layer(delta_layer_to_add, &frozen_layer);
if disk_consistent_lsn != old_disk_consistent_lsn {
assert!(disk_consistent_lsn > old_disk_consistent_lsn);
self.disk_consistent_lsn.store(disk_consistent_lsn);
// Schedule remote uploads that will reflect our new disk_consistent_lsn
Some(self.schedule_uploads(disk_consistent_lsn, layer_paths_to_upload)?)
} else {
None
}
// release lock on 'layers'
};
}
// FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`,
// a compaction can delete the file and then it won't be available for uploads any more.
@@ -2841,22 +2833,28 @@ impl Timeline {
//
// TODO: This perhaps should be done in 'flush_frozen_layers', after flushing
// *all* the layers, to avoid fsyncing the file multiple times.
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
let old_disk_consistent_lsn = self.disk_consistent_lsn.load();
// If we updated our disk_consistent_lsn, persist the updated metadata to local disk.
if let Some(metadata) = metadata {
save_metadata(self.conf, &self.tenant_id, &self.timeline_id, &metadata)
// If we were able to advance 'disk_consistent_lsn', save it the metadata file.
// After crash, we will restart WAL streaming and processing from that point.
if disk_consistent_lsn != old_disk_consistent_lsn {
assert!(disk_consistent_lsn > old_disk_consistent_lsn);
self.update_metadata_file(disk_consistent_lsn, layer_paths_to_upload)
.await
.context("save_metadata")?;
.context("update_metadata_file")?;
// Also update the in-memory copy
self.disk_consistent_lsn.store(disk_consistent_lsn);
}
Ok(())
}
/// Update metadata file
fn schedule_uploads(
async fn update_metadata_file(
&self,
disk_consistent_lsn: Lsn,
layer_paths_to_upload: HashMap<LayerFileName, LayerFileMetadata>,
) -> anyhow::Result<TimelineMetadata> {
) -> anyhow::Result<()> {
// We can only save a valid 'prev_record_lsn' value on disk if we
// flushed *all* in-memory changes to disk. We only track
// 'prev_record_lsn' in memory for the latest processed record, so we
@@ -2893,6 +2891,10 @@ impl Timeline {
x.unwrap()
));
save_metadata(self.conf, &self.tenant_id, &self.timeline_id, &metadata)
.await
.context("save_metadata")?;
if let Some(remote_client) = &self.remote_client {
for (path, layer_metadata) in layer_paths_to_upload {
remote_client.schedule_layer_file_upload(&path, &layer_metadata)?;
@@ -2900,20 +2902,6 @@ impl Timeline {
remote_client.schedule_index_upload_for_metadata_update(&metadata)?;
}
Ok(metadata)
}
async fn update_metadata_file(
&self,
disk_consistent_lsn: Lsn,
layer_paths_to_upload: HashMap<LayerFileName, LayerFileMetadata>,
) -> anyhow::Result<()> {
let metadata = self.schedule_uploads(disk_consistent_lsn, layer_paths_to_upload)?;
save_metadata(self.conf, &self.tenant_id, &self.timeline_id, &metadata)
.await
.context("save_metadata")?;
Ok(())
}
@@ -4339,7 +4327,6 @@ impl Timeline {
let img = match self
.walredo_mgr
.request_redo(key, request_lsn, data.img, data.records, self.pg_version)
.await
.context("Failed to reconstruct a page image:")
{
Ok(img) => img,

View File

@@ -30,7 +30,6 @@ use crate::{
tenant::{
config::{EvictionPolicy, EvictionPolicyLayerAccessThreshold},
storage_layer::PersistentLayer,
tasks::{BackgroundLoopKind, RateLimitError},
timeline::EvictionError,
LogicalSizeCalculationCause, Tenant,
},
@@ -130,11 +129,7 @@ impl Timeline {
ControlFlow::Continue(()) => (),
}
let elapsed = start.elapsed();
crate::tenant::tasks::warn_when_period_overrun(
elapsed,
p.period,
BackgroundLoopKind::Eviction,
);
crate::tenant::tasks::warn_when_period_overrun(elapsed, p.period, "eviction");
crate::metrics::EVICTION_ITERATION_DURATION
.get_metric_with_label_values(&[
&format!("{}", p.period.as_secs()),
@@ -155,17 +150,6 @@ impl Timeline {
) -> ControlFlow<()> {
let now = SystemTime::now();
let _permit = match crate::tenant::tasks::concurrent_background_tasks_rate_limit(
BackgroundLoopKind::Eviction,
ctx,
cancel,
)
.await
{
Ok(permit) => permit,
Err(RateLimitError::Cancelled) => return ControlFlow::Break(()),
};
// If we evict layers but keep cached values derived from those layers, then
// we face a storm of on-demand downloads after pageserver restart.
// The reason is that the restart empties the caches, and so, the values
@@ -301,10 +285,6 @@ impl Timeline {
warn!(layer = %l, "failed to evict layer: {e}");
stats.not_evictable += 1;
}
Some(Err(EvictionError::MetadataInconsistency(detail))) => {
warn!(layer = %l, "failed to evict layer: {detail}");
stats.not_evictable += 1;
}
}
}
if stats.candidates == stats.not_evictable {

View File

@@ -72,7 +72,7 @@ pub(super) fn scan_timeline_dir(path: &Utf8Path) -> anyhow::Result<Vec<Discovere
}
/// Decision on what to do with a layer file after considering its local and remote metadata.
#[derive(Clone, Debug)]
#[derive(Clone)]
pub(super) enum Decision {
/// The layer is not present locally.
Evicted(LayerFileMetadata),
@@ -84,30 +84,27 @@ pub(super) enum Decision {
},
/// The layer is present locally, and metadata matches.
UseLocal(LayerFileMetadata),
/// The layer is only known locally, it needs to be uploaded.
NeedsUpload(LayerFileMetadata),
}
/// A layer needs to be left out of the layer map.
/// The related layer is is in future compared to disk_consistent_lsn, it must not be loaded.
#[derive(Debug)]
pub(super) enum DismissedLayer {
/// The related layer is is in future compared to disk_consistent_lsn, it must not be loaded.
Future {
/// The local metadata. `None` if the layer is only known through [`IndexPart`].
local: Option<LayerFileMetadata>,
},
/// The layer only exists locally.
///
/// In order to make crash safe updates to layer map, we must dismiss layers which are only
/// found locally or not yet included in the remote `index_part.json`.
LocalOnly(LayerFileMetadata),
pub(super) struct FutureLayer {
/// The local metadata. `None` if the layer is only known through [`IndexPart`].
pub(super) local: Option<LayerFileMetadata>,
}
/// Merges local discoveries and remote [`IndexPart`] to a collection of decisions.
///
/// This function should not gain additional reasons to fail than [`FutureLayer`], consider adding
/// the checks earlier to [`scan_timeline_dir`].
pub(super) fn reconcile(
discovered: Vec<(LayerFileName, u64)>,
index_part: Option<&IndexPart>,
disk_consistent_lsn: Lsn,
generation: Generation,
) -> Vec<(LayerFileName, Result<Decision, DismissedLayer>)> {
) -> Vec<(LayerFileName, Result<Decision, FutureLayer>)> {
use Decision::*;
// name => (local, remote)
@@ -145,19 +142,17 @@ pub(super) fn reconcile(
.into_iter()
.map(|(name, (local, remote))| {
let decision = if name.is_in_future(disk_consistent_lsn) {
Err(DismissedLayer::Future { local })
Err(FutureLayer { local })
} else {
match (local, remote) {
(Some(local), Some(remote)) if local != remote => {
Ok(UseRemote { local, remote })
}
(Some(x), Some(_)) => Ok(UseLocal(x)),
(None, Some(x)) => Ok(Evicted(x)),
(Some(x), None) => Err(DismissedLayer::LocalOnly(x)),
Ok(match (local, remote) {
(Some(local), Some(remote)) if local != remote => UseRemote { local, remote },
(Some(x), Some(_)) => UseLocal(x),
(None, Some(x)) => Evicted(x),
(Some(x), None) => NeedsUpload(x),
(None, None) => {
unreachable!("there must not be any non-local non-remote files")
}
}
})
};
(name, decision)
@@ -197,21 +192,14 @@ pub(super) fn cleanup_future_layer(
name: &LayerFileName,
disk_consistent_lsn: Lsn,
) -> anyhow::Result<()> {
use LayerFileName::*;
let kind = match name {
Delta(_) => "delta",
Image(_) => "image",
};
// future image layers are allowed to be produced always for not yet flushed to disk
// lsns stored in InMemoryLayer.
let kind = name.kind();
tracing::info!("found future {kind} layer {name} disk_consistent_lsn is {disk_consistent_lsn}");
std::fs::remove_file(path)?;
Ok(())
}
pub(super) fn cleanup_local_only_file(
path: &Utf8Path,
name: &LayerFileName,
local: &LayerFileMetadata,
) -> anyhow::Result<()> {
let kind = name.kind();
tracing::info!("found local-only {kind} layer {name}, metadata {local:?}");
std::fs::remove_file(path)?;
crate::tenant::timeline::rename_to_backup(path)?;
Ok(())
}

View File

@@ -203,18 +203,6 @@ impl UploadQueue {
UploadQueue::Stopped(stopped) => Ok(stopped),
}
}
pub(crate) fn get_layer_metadata(
&self,
name: &LayerFileName,
) -> anyhow::Result<Option<LayerFileMetadata>> {
match self {
UploadQueue::Stopped(_) | UploadQueue::Uninitialized => {
anyhow::bail!("queue is in state {}", self.as_str())
}
UploadQueue::Initialized(inner) => Ok(inner.latest_files.get(name).cloned()),
}
}
}
/// An in-progress upload or delete task.

View File

@@ -338,20 +338,11 @@ impl<'a> WalIngest<'a> {
} else if decoded.xl_rmid == pg_constants::RM_LOGICALMSG_ID {
let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK;
if info == pg_constants::XLOG_LOGICAL_MESSAGE {
let xlrec = XlLogicalMessage::decode(&mut buf);
let prefix = std::str::from_utf8(&buf[0..xlrec.prefix_size - 1])?;
let message = &buf[xlrec.prefix_size..xlrec.prefix_size + xlrec.message_size];
if prefix == "neon-test" {
// This is a convenient way to make the WAL ingestion pause at
// particular point in the WAL. For more fine-grained control,
// we could peek into the message and only pause if it contains
// a particular string, for example, but this is enough for now.
crate::failpoint_support::sleep_millis_async!(
"wal-ingest-logical-message-sleep"
);
} else if let Some(path) = prefix.strip_prefix("neon-file:") {
modification.put_file(path, message, ctx).await?;
}
// This is a convenient way to make the WAL ingestion pause at
// particular point in the WAL. For more fine-grained control,
// we could peek into the message and only pause if it contains
// a particular string, for example, but this is enough for now.
crate::failpoint_support::sleep_millis_async!("wal-ingest-logical-message-sleep");
}
}
@@ -468,6 +459,7 @@ impl<'a> WalIngest<'a> {
}
} else if info == pg_constants::XLOG_HEAP_DELETE {
let xlrec = v14::XlHeapDelete::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
}
@@ -535,6 +527,7 @@ impl<'a> WalIngest<'a> {
}
} else if info == pg_constants::XLOG_HEAP_DELETE {
let xlrec = v15::XlHeapDelete::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
}
@@ -602,6 +595,7 @@ impl<'a> WalIngest<'a> {
}
} else if info == pg_constants::XLOG_HEAP_DELETE {
let xlrec = v16::XlHeapDelete::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
}
@@ -777,6 +771,7 @@ impl<'a> WalIngest<'a> {
}
pg_constants::XLOG_NEON_HEAP_DELETE => {
let xlrec = v16::rm_neon::XlNeonHeapDelete::decode(buf);
assert_eq!(0, buf.remaining());
if (xlrec.flags & pg_constants::XLH_DELETE_ALL_VISIBLE_CLEARED) != 0 {
new_heap_blkno = Some(decoded.blocks[0].blkno);
}

View File

@@ -748,26 +748,6 @@ impl XlMultiXactTruncate {
}
}
#[repr(C)]
#[derive(Debug)]
pub struct XlLogicalMessage {
pub db_id: Oid,
pub transactional: bool,
pub prefix_size: usize,
pub message_size: usize,
}
impl XlLogicalMessage {
pub fn decode(buf: &mut Bytes) -> XlLogicalMessage {
XlLogicalMessage {
db_id: buf.get_u32_le(),
transactional: buf.get_u32_le() != 0, // 4-bytes alignment
prefix_size: buf.get_u64_le() as usize,
message_size: buf.get_u64_le() as usize,
}
}
}
/// Main routine to decode a WAL record and figure out which blocks are modified
//
// See xlogrecord.h for details

View File

@@ -18,37 +18,38 @@
//! any WAL records, so that even if an attacker hijacks the Postgres
//! process, he cannot escape out of it.
//!
use anyhow::Context;
use byteorder::{ByteOrder, LittleEndian};
use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::io;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::os::unix::io::AsRawFd;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::prelude::CommandExt;
use std::process::Stdio;
use std::process::{Child, ChildStdin, ChildStdout, Command};
use std::sync::{Arc, Mutex, MutexGuard, RwLock};
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
use std::time::Duration;
use std::time::Instant;
use tokio_util::sync::CancellationToken;
use std::{fs, io};
use tracing::*;
use utils::crashsafe::path_with_suffix_extension;
use utils::{bin_ser::BeSer, id::TenantId, lsn::Lsn, nonblock::set_nonblock};
#[cfg(feature = "testing")]
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::config::PageServerConf;
use crate::metrics::{
WAL_REDO_BYTES_HISTOGRAM, WAL_REDO_RECORDS_HISTOGRAM, WAL_REDO_RECORD_COUNTER, WAL_REDO_TIME,
WAL_REDO_WAIT_TIME,
};
use crate::pgdatadir_mapping::{key_to_rel_block, key_to_slru_block};
use crate::repository::Key;
use crate::task_mgr::BACKGROUND_RUNTIME;
use crate::walrecord::NeonWalRecord;
use crate::{config::PageServerConf, TEMP_FILE_SUFFIX};
use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM;
@@ -65,13 +66,37 @@ use postgres_ffi::BLCKSZ;
/// [See more related comments here](https://github.com/postgres/postgres/blob/99c5852e20a0987eca1c38ba0c09329d4076b6a0/src/include/storage/buf_internals.h#L91).
///
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize)]
pub(crate) struct BufferTag {
pub struct BufferTag {
pub rel: RelTag,
pub blknum: u32,
}
///
/// WAL Redo Manager is responsible for replaying WAL records.
///
/// Callers use the WAL redo manager through this abstract interface,
/// which makes it easy to mock it in tests.
pub trait WalRedoManager: Send + Sync {
/// Apply some WAL records.
///
/// The caller passes an old page image, and WAL records that should be
/// applied over it. The return value is a new page image, after applying
/// the reords.
fn request_redo(
&self,
key: Key,
lsn: Lsn,
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
) -> Result<Bytes, WalRedoError>;
}
struct ProcessInput {
child: NoLeakChild,
stdin: ChildStdin,
stderr_fd: RawFd,
stdout_fd: RawFd,
n_requests: usize,
}
@@ -91,7 +116,13 @@ struct ProcessOutput {
pub struct PostgresRedoManager {
tenant_id: TenantId,
conf: &'static PageServerConf,
redo_process: RwLock<Option<Arc<WalRedoProcess>>>,
/// Counter to separate same sized walredo inputs failing at the same millisecond.
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize,
stdout: Mutex<Option<ProcessOutput>>,
stdin: Mutex<Option<ProcessInput>>,
stderr: Mutex<Option<ChildStderr>>,
}
/// Can this request be served by neon redo functions
@@ -109,27 +140,41 @@ fn can_apply_in_neon(rec: &NeonWalRecord) -> bool {
}
}
/// An error happened in WAL redo
#[derive(Debug, thiserror::Error)]
pub enum WalRedoError {
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("cannot perform WAL redo now")]
InvalidState,
#[error("cannot perform WAL redo for this request")]
InvalidRequest,
#[error("cannot perform WAL redo for this record")]
InvalidRecord,
}
///
/// Public interface of WAL redo manager
///
impl PostgresRedoManager {
impl WalRedoManager for PostgresRedoManager {
///
/// Request the WAL redo manager to apply some WAL records
///
/// The WAL redo is handled by a separate thread, so this just sends a request
/// to the thread and waits for response.
///
/// CANCEL SAFETY: NOT CANCEL SAFE.
pub async fn request_redo(
fn request_redo(
&self,
key: Key,
lsn: Lsn,
base_img: Option<(Lsn, Bytes)>,
records: Vec<(Lsn, NeonWalRecord)>,
pg_version: u32,
) -> anyhow::Result<Bytes> {
) -> Result<Bytes, WalRedoError> {
if records.is_empty() {
anyhow::bail!("invalid WAL redo request with no records");
error!("invalid WAL redo request with no records");
return Err(WalRedoError::InvalidRequest);
}
let base_img_lsn = base_img.as_ref().map(|p| p.0).unwrap_or(Lsn::INVALID);
@@ -152,7 +197,6 @@ impl PostgresRedoManager {
self.conf.wal_redo_timeout,
pg_version,
)
.await
};
img = Some(result?);
@@ -173,7 +217,6 @@ impl PostgresRedoManager {
self.conf.wal_redo_timeout,
pg_version,
)
.await
}
}
}
@@ -187,15 +230,28 @@ impl PostgresRedoManager {
PostgresRedoManager {
tenant_id,
conf,
redo_process: RwLock::new(None),
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize::default(),
stdin: Mutex::new(None),
stdout: Mutex::new(None),
stderr: Mutex::new(None),
}
}
/// Launch process pre-emptively. Should not be needed except for benchmarking.
pub fn launch_process(&self, pg_version: u32) -> anyhow::Result<()> {
let mut proc = self.stdin.lock().unwrap();
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
Ok(())
}
///
/// Process one request for WAL redo using wal-redo postgres
///
#[allow(clippy::too_many_arguments)]
async fn apply_batch_postgres(
fn apply_batch_postgres(
&self,
key: Key,
lsn: Lsn,
@@ -204,45 +260,26 @@ impl PostgresRedoManager {
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
pg_version: u32,
) -> anyhow::Result<Bytes> {
let (rel, blknum) = key_to_rel_block(key).context("invalid record")?;
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
loop {
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// launch the WAL redo process on first use
let proc: Arc<WalRedoProcess> = {
let proc_guard = self.redo_process.read().unwrap();
match &*proc_guard {
None => {
// "upgrade" to write lock to launch the process
drop(proc_guard);
let mut proc_guard = self.redo_process.write().unwrap();
match &*proc_guard {
None => {
let proc = Arc::new(
WalRedoProcess::launch(self.conf, self.tenant_id, pg_version)
.context("launch walredo process")?,
);
*proc_guard = Some(Arc::clone(&proc));
proc
}
Some(proc) => Arc::clone(proc),
}
}
Some(proc) => Arc::clone(proc),
}
};
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = proc
.apply_wal_records(buf_tag, &base_img, records, wal_redo_timeout)
.context("apply_wal_records");
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
@@ -272,50 +309,32 @@ impl PostgresRedoManager {
// next request will launch a new one.
if let Err(e) = result.as_ref() {
error!(
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {} n_attempts={}: {:?}",
n_attempts,
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}: {}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
records.last().map(|p| p.0).unwrap_or(Lsn(0)),
nbytes,
base_img_lsn,
lsn,
n_attempts,
e,
utils::error::report_compact_sources(e),
);
// Avoid concurrent callers hitting the same issue.
// We can't prevent it from happening because we want to enable parallelism.
{
let mut guard = self.redo_process.write().unwrap();
match &*guard {
Some(current_field_value) => {
if Arc::ptr_eq(current_field_value, &proc) {
// We're the first to observe an error from `proc`, it's our job to take it out of rotation.
*guard = None;
}
}
None => {
// Another thread was faster to observe the error, and already took the process out of rotation.
}
}
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
// NB: there may still be other concurrent threads using `proc`.
// The last one will send SIGKILL when the underlying Arc reaches refcount 0.
// NB: it's important to drop(proc) after drop(guard). Otherwise we'd keep
// holding the lock while waiting for the process to exit.
// NB: the drop impl blocks the current threads with a wait() system call for
// the child process. We dropped the `guard` above so that other threads aren't
// affected. But, it's good that the current thread _does_ block to wait.
// If we instead deferred the waiting into the background / to tokio, it could
// happen that if walredo always fails immediately, we spawn processes faster
// than we can SIGKILL & `wait` for them to exit. By doing it the way we do here,
// we limit this risk of run-away to at most $num_runtimes * $num_executor_threads.
// This probably needs revisiting at some later point.
let mut wait_done = proc.stderr_logger_task_done.clone();
drop(proc);
wait_done
.wait_for(|v| *v)
.await
.expect("we use scopeguard to ensure we always send `true` to the channel before dropping the sender");
} else if n_attempts != 0 {
info!(n_attempts, "retried walredo succeeded");
}
@@ -335,7 +354,7 @@ impl PostgresRedoManager {
lsn: Lsn,
base_img: Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
) -> anyhow::Result<Bytes> {
) -> Result<Bytes, WalRedoError> {
let start_time = Instant::now();
let mut page = BytesMut::new();
@@ -344,7 +363,8 @@ impl PostgresRedoManager {
page.extend_from_slice(&fpi[..]);
} else {
// All the current WAL record types that we can handle require a base image.
anyhow::bail!("invalid neon WAL redo request with no base image");
error!("invalid neon WAL redo request with no base image");
return Err(WalRedoError::InvalidRequest);
}
// Apply all the WAL records in the batch
@@ -372,13 +392,14 @@ impl PostgresRedoManager {
page: &mut BytesMut,
_record_lsn: Lsn,
record: &NeonWalRecord,
) -> anyhow::Result<()> {
) -> Result<(), WalRedoError> {
match record {
NeonWalRecord::Postgres {
will_init: _,
rec: _,
} => {
anyhow::bail!("tried to pass postgres wal record to neon WAL redo");
error!("tried to pass postgres wal record to neon WAL redo");
return Err(WalRedoError::InvalidRequest);
}
NeonWalRecord::ClearVisibilityMapFlags {
new_heap_blkno,
@@ -386,7 +407,7 @@ impl PostgresRedoManager {
flags,
} => {
// sanity check that this is modifying the correct relation
let (rel, blknum) = key_to_rel_block(key).context("invalid record")?;
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
assert!(
rel.forknum == VISIBILITYMAP_FORKNUM,
"ClearVisibilityMapFlags record on unexpected rel {}",
@@ -424,7 +445,7 @@ impl PostgresRedoManager {
// same effects as the corresponding Postgres WAL redo function.
NeonWalRecord::ClogSetCommitted { xids, timestamp } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).context("invalid record")?;
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
assert_eq!(
slru_kind,
SlruKind::Clog,
@@ -474,7 +495,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::ClogSetAborted { xids } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).context("invalid record")?;
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
assert_eq!(
slru_kind,
SlruKind::Clog,
@@ -505,7 +526,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::MultixactOffsetCreate { mid, moff } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).context("invalid record")?;
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
assert_eq!(
slru_kind,
SlruKind::MultiXactOffsets,
@@ -538,7 +559,7 @@ impl PostgresRedoManager {
}
NeonWalRecord::MultixactMembersCreate { moff, members } => {
let (slru_kind, segno, blknum) =
key_to_slru_block(key).context("invalid record")?;
key_to_slru_block(key).or(Err(WalRedoError::InvalidRecord))?;
assert_eq!(
slru_kind,
SlruKind::MultiXactMembers,
@@ -618,33 +639,44 @@ impl<C: CommandExt> CloseFileDescriptors for C {
}
}
struct WalRedoProcess {
#[allow(dead_code)]
conf: &'static PageServerConf,
tenant_id: TenantId,
// Some() on construction, only becomes None on Drop.
child: Option<NoLeakChild>,
stdout: Mutex<ProcessOutput>,
stdin: Mutex<ProcessInput>,
stderr_logger_cancel: CancellationToken,
stderr_logger_task_done: tokio::sync::watch::Receiver<bool>,
/// Counter to separate same sized walredo inputs failing at the same millisecond.
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize,
}
impl WalRedoProcess {
impl PostgresRedoManager {
//
// Start postgres binary in special WAL redo mode.
//
#[instrument(skip_all,fields(tenant_id=%tenant_id, pg_version=pg_version))]
#[instrument(skip_all,fields(tenant_id=%self.tenant_id, pg_version=pg_version))]
fn launch(
conf: &'static PageServerConf,
tenant_id: TenantId,
&self,
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> anyhow::Result<Self> {
let pg_bin_dir_path = conf.pg_bin_dir(pg_version).context("pg_bin_dir")?; // TODO these should be infallible.
let pg_lib_dir_path = conf.pg_lib_dir(pg_version).context("pg_lib_dir")?;
) -> Result<(), Error> {
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
Error::new(
e.kind(),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_bin_dir path: {e}")))?;
let pg_lib_dir_path = self
.conf
.pg_lib_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
@@ -665,8 +697,13 @@ impl WalRedoProcess {
// as close-on-exec by default, but that's not enough, since we use
// libraries that directly call libc open without setting that flag.
.close_fds()
.spawn_no_leak_child(tenant_id)
.context("spawn process")?;
.spawn_no_leak_child(self.tenant_id)
.map_err(|e| {
Error::new(
e.kind(),
format!("postgres --wal-redo command failed to start: {}", e),
)
})?;
let mut child = scopeguard::guard(child, |child| {
error!("killing wal-redo-postgres process due to a problem during launch");
@@ -676,6 +713,7 @@ impl WalRedoProcess {
let stdin = child.stdin.take().unwrap();
let stdout = child.stdout.take().unwrap();
let stderr = child.stderr.take().unwrap();
macro_rules! set_nonblock_or_log_err {
($file:ident) => {{
let res = set_nonblock($file.as_raw_fd());
@@ -689,108 +727,39 @@ impl WalRedoProcess {
set_nonblock_or_log_err!(stdout)?;
set_nonblock_or_log_err!(stderr)?;
let mut stderr = tokio::io::unix::AsyncFd::new(stderr).context("AsyncFd::with_interest")?;
// all fallible operations post-spawn are complete, so get rid of the guard
let child = scopeguard::ScopeGuard::into_inner(child);
let stderr_logger_cancel = CancellationToken::new();
let (stderr_logger_task_done_tx, stderr_logger_task_done_rx) =
tokio::sync::watch::channel(false);
tokio::spawn({
let stderr_logger_cancel = stderr_logger_cancel.clone();
async move {
scopeguard::defer! {
debug!("wal-redo-postgres stderr_logger_task finished");
let _ = stderr_logger_task_done_tx.send(true);
}
debug!("wal-redo-postgres stderr_logger_task started");
loop {
// NB: we purposefully don't do a select! for the cancellation here.
// The cancellation would likely cause us to miss stderr messages.
// We can rely on this to return from .await because when we SIGKILL
// the child, the writing end of the stderr pipe gets closed.
match stderr.readable_mut().await {
Ok(mut guard) => {
let mut errbuf = [0; 16384];
let res = guard.try_io(|fd| {
use std::io::Read;
fd.get_mut().read(&mut errbuf)
});
match res {
Ok(Ok(0)) => {
// it closed the stderr pipe
break;
}
Ok(Ok(n)) => {
// The message might not be split correctly into lines here. But this is
// good enough, the important thing is to get the message to the log.
let output = String::from_utf8_lossy(&errbuf[0..n]).to_string();
error!(output, "received output");
},
Ok(Err(e)) => {
error!(error = ?e, "read() error, waiting for cancellation");
stderr_logger_cancel.cancelled().await;
error!(error = ?e, "read() error, cancellation complete");
break;
}
Err(e) => {
let _e: tokio::io::unix::TryIoError = e;
// the read() returned WouldBlock, that's expected
}
}
}
Err(e) => {
error!(error = ?e, "read() error, waiting for cancellation");
stderr_logger_cancel.cancelled().await;
error!(error = ?e, "read() error, cancellation complete");
break;
}
}
}
}.instrument(tracing::info_span!(parent: None, "wal-redo-postgres-stderr", pid = child.id(), tenant_id = %tenant_id, %pg_version))
**input = Some(ProcessInput {
child,
stdout_fd: stdout.as_raw_fd(),
stderr_fd: stderr.as_raw_fd(),
stdin,
n_requests: 0,
});
Ok(Self {
conf,
tenant_id,
child: Some(child),
stdin: Mutex::new(ProcessInput {
stdin,
n_requests: 0,
}),
stdout: Mutex::new(ProcessOutput {
stdout,
pending_responses: VecDeque::new(),
n_processed_responses: 0,
}),
stderr_logger_cancel,
stderr_logger_task_done: stderr_logger_task_done_rx,
#[cfg(feature = "testing")]
dump_sequence: AtomicUsize::default(),
})
}
*self.stdout.lock().unwrap() = Some(ProcessOutput {
stdout,
pending_responses: VecDeque::new(),
n_processed_responses: 0,
});
*self.stderr.lock().unwrap() = Some(stderr);
fn id(&self) -> u32 {
self.child
.as_ref()
.expect("must not call this during Drop")
.id()
Ok(())
}
// Apply given WAL records ('records') over an old page image. Returns
// new page image.
//
#[instrument(skip_all, fields(tenant_id=%self.tenant_id, pid=%self.id()))]
#[instrument(skip_all, fields(tenant_id=%self.tenant_id, pid=%input.as_ref().unwrap().child.id()))]
fn apply_wal_records(
&self,
input: MutexGuard<Option<ProcessInput>>,
tag: BufferTag,
base_img: &Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let input = self.stdin.lock().unwrap();
) -> Result<Bytes, std::io::Error> {
// Serialize all the messages to send the WAL redo process first.
//
// This could be problematic if there are millions of records to replay,
@@ -813,7 +782,10 @@ impl WalRedoProcess {
{
build_apply_record_msg(*lsn, postgres_rec, &mut writebuf);
} else {
anyhow::bail!("tried to pass neon wal record to postgres WAL redo");
return Err(Error::new(
ErrorKind::Other,
"tried to pass neon wal record to postgres WAL redo",
));
}
}
build_get_page_msg(tag, &mut writebuf);
@@ -833,38 +805,77 @@ impl WalRedoProcess {
fn apply_wal_records0(
&self,
writebuf: &[u8],
input: MutexGuard<ProcessInput>,
mut input: MutexGuard<Option<ProcessInput>>,
wal_redo_timeout: Duration,
) -> anyhow::Result<Bytes> {
let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small.
) -> Result<Bytes, std::io::Error> {
let proc = input.as_mut().unwrap();
let mut nwrite = 0usize;
let stdout_fd = proc.stdout_fd;
let mut stdin_pollfds = [PollFd::new(proc.stdin.as_raw_fd(), PollFlags::POLLOUT)];
// Prepare for calling poll()
let mut pollfds = [
PollFd::new(proc.stdin.as_raw_fd(), PollFlags::POLLOUT),
PollFd::new(proc.stderr_fd, PollFlags::POLLIN),
PollFd::new(stdout_fd, PollFlags::POLLIN),
];
// We do two things simultaneously: send the old base image and WAL records to
// the child process's stdin and forward any logging
// information that the child writes to its stderr to the page server's log.
while nwrite < writebuf.len() {
let n = loop {
match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) {
match nix::poll::poll(&mut pollfds[0..2], wal_redo_timeout.as_millis() as i32) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
return Err(Error::new(ErrorKind::Other, "WAL redo timed out"));
}
// If we have some messages in stderr, forward them to the log.
let err_revents = pollfds[1].revents().unwrap();
if err_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
let mut errbuf: [u8; 16384] = [0; 16384];
let mut stderr_guard = self.stderr.lock().unwrap();
let stderr = stderr_guard.as_mut().unwrap();
let len = stderr.read(&mut errbuf)?;
// The message might not be split correctly into lines here. But this is
// good enough, the important thing is to get the message to the log.
if len > 0 {
error!(
"wal-redo-postgres: {}",
String::from_utf8_lossy(&errbuf[0..len])
);
// To make sure we capture all log from the process if it fails, keep
// reading from the stderr, before checking the stdout.
continue;
}
} else if err_revents.contains(PollFlags::POLLHUP) {
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stderr unexpectedly",
));
}
// If 'stdin' is writeable, do write.
let in_revents = stdin_pollfds[0].revents().unwrap();
let in_revents = pollfds[0].revents().unwrap();
if in_revents & (PollFlags::POLLERR | PollFlags::POLLOUT) != PollFlags::empty() {
nwrite += proc.stdin.write(&writebuf[nwrite..])?;
} else if in_revents.contains(PollFlags::POLLHUP) {
// We still have more data to write, but the process closed the pipe.
anyhow::bail!("WAL redo process closed its stdin unexpectedly");
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdin unexpectedly",
));
}
}
let request_no = proc.n_requests;
proc.n_requests += 1;
drop(proc);
drop(input);
// To improve walredo performance we separate sending requests and receiving
// responses. Them are protected by different mutexes (output and input).
@@ -878,8 +889,23 @@ impl WalRedoProcess {
// pending responses ring buffer and truncate all empty elements from the front,
// advancing processed responses number.
let mut output = self.stdout.lock().unwrap();
let mut stdout_pollfds = [PollFd::new(output.stdout.as_raw_fd(), PollFlags::POLLIN)];
let mut output_guard = self.stdout.lock().unwrap();
let output = output_guard.as_mut().unwrap();
if output.stdout.as_raw_fd() != stdout_fd {
// If stdout file descriptor is changed then it means that walredo process is crashed and restarted.
// As far as ProcessInput and ProcessOutout are protected by different mutexes,
// it can happen that we send request to one process and waiting response from another.
// To prevent such situation we compare stdout file descriptors.
// As far as old stdout pipe is destroyed only after new one is created,
// it can not reuse the same file descriptor, so this check is safe.
//
// Cross-read this with the comment in apply_batch_postgres if result.is_err().
// That's where we kill the child process.
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdout unexpectedly",
));
}
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
// We expect the WAL redo process to respond with an 8k page image. We read it
@@ -890,25 +916,52 @@ impl WalRedoProcess {
// We do two things simultaneously: reading response from stdout
// and forward any logging information that the child writes to its stderr to the page server's log.
let n = loop {
match nix::poll::poll(
&mut stdout_pollfds[..],
wal_redo_timeout.as_millis() as i32,
) {
match nix::poll::poll(&mut pollfds[1..3], wal_redo_timeout.as_millis() as i32) {
Err(nix::errno::Errno::EINTR) => continue,
res => break res,
}
}?;
if n == 0 {
anyhow::bail!("WAL redo timed out");
return Err(Error::new(ErrorKind::Other, "WAL redo timed out"));
}
// If we have some messages in stderr, forward them to the log.
let err_revents = pollfds[1].revents().unwrap();
if err_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
let mut errbuf: [u8; 16384] = [0; 16384];
let mut stderr_guard = self.stderr.lock().unwrap();
let stderr = stderr_guard.as_mut().unwrap();
let len = stderr.read(&mut errbuf)?;
// The message might not be split correctly into lines here. But this is
// good enough, the important thing is to get the message to the log.
if len > 0 {
error!(
"wal-redo-postgres: {}",
String::from_utf8_lossy(&errbuf[0..len])
);
// To make sure we capture all log from the process if it fails, keep
// reading from the stderr, before checking the stdout.
continue;
}
} else if err_revents.contains(PollFlags::POLLHUP) {
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stderr unexpectedly",
));
}
// If we have some data in stdout, read it to the result buffer.
let out_revents = stdout_pollfds[0].revents().unwrap();
let out_revents = pollfds[2].revents().unwrap();
if out_revents & (PollFlags::POLLERR | PollFlags::POLLIN) != PollFlags::empty() {
nresult += output.stdout.read(&mut resultbuf[nresult..])?;
} else if out_revents.contains(PollFlags::POLLHUP) {
anyhow::bail!("WAL redo process closed its stdout unexpectedly");
return Err(Error::new(
ErrorKind::BrokenPipe,
"WAL redo process closed its stdout unexpectedly",
));
}
}
output
@@ -994,17 +1047,6 @@ impl WalRedoProcess {
fn record_and_log(&self, _: &[u8]) {}
}
impl Drop for WalRedoProcess {
fn drop(&mut self) {
self.child
.take()
.expect("we only do this once")
.kill_and_wait();
self.stderr_logger_cancel.cancel();
// no way to wait for stderr_logger_task from Drop because that is async only
}
}
/// Wrapper type around `std::process::Child` which guarantees that the child
/// will be killed and waited-for by this process before being dropped.
struct NoLeakChild {
@@ -1083,7 +1125,7 @@ impl Drop for NoLeakChild {
// Offload the kill+wait of the child process into the background.
// If someone stops the runtime, we'll leak the child process.
// We can ignore that case because we only stop the runtime on pageserver exit.
tokio::runtime::Handle::current().spawn(async move {
BACKGROUND_RUNTIME.spawn(async move {
tokio::task::spawn_blocking(move || {
// Intentionally don't inherit the tracing context from whoever is dropping us.
// This thread here is going to outlive of our dropper.
@@ -1152,15 +1194,15 @@ fn build_get_page_msg(tag: BufferTag, buf: &mut Vec<u8>) {
#[cfg(test)]
mod tests {
use super::PostgresRedoManager;
use super::{PostgresRedoManager, WalRedoManager};
use crate::repository::Key;
use crate::{config::PageServerConf, walrecord::NeonWalRecord};
use bytes::Bytes;
use std::str::FromStr;
use utils::{id::TenantId, lsn::Lsn};
#[tokio::test]
async fn short_v14_redo() {
#[test]
fn short_v14_redo() {
let expected = std::fs::read("fixtures/short_v14_redo.page").unwrap();
let h = RedoHarness::new().unwrap();
@@ -1181,14 +1223,13 @@ mod tests {
short_records(),
14,
)
.await
.unwrap();
assert_eq!(&expected, &*page);
}
#[tokio::test]
async fn short_v14_fails_for_wrong_key_but_returns_zero_page() {
#[test]
fn short_v14_fails_for_wrong_key_but_returns_zero_page() {
let h = RedoHarness::new().unwrap();
let page = h
@@ -1208,7 +1249,6 @@ mod tests {
short_records(),
14,
)
.await
.unwrap();
// TODO: there will be some stderr printout, which is forwarded to tracing that could
@@ -1216,22 +1256,6 @@ mod tests {
assert_eq!(page, crate::ZERO_PAGE);
}
#[tokio::test]
async fn test_stderr() {
let h = RedoHarness::new().unwrap();
h
.manager
.request_redo(
Key::from_i128(0),
Lsn::INVALID,
None,
short_records(),
16, /* 16 currently produces stderr output on startup, which adds a nice extra edge */
)
.await
.unwrap_err();
}
#[allow(clippy::octal_escapes)]
fn short_records() -> Vec<(Lsn, NeonWalRecord)> {
vec![
@@ -1260,8 +1284,6 @@ mod tests {
impl RedoHarness {
fn new() -> anyhow::Result<Self> {
crate::tenant::harness::setup_logging();
let repo_dir = camino_tempfile::tempdir()?;
let conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf());
let conf = Box::leak(Box::new(conf));

View File

@@ -20,26 +20,9 @@ SHLIB_LINK_INTERNAL = $(libpq)
SHLIB_LINK = -lcurl
EXTENSION = neon
DATA = neon--1.0.sql
DATA = neon--1.0.sql neon--1.0--1.1.sql
PGFILEDESC = "neon - cloud storage for PostgreSQL"
EXTRA_CLEAN = \
libwalproposer.a
WALPROP_OBJS = \
$(WIN32RES) \
walproposer.o \
neon_utils.o \
walproposer_compat.o
.PHONY: walproposer-lib
walproposer-lib: CPPFLAGS += -DWALPROPOSER_LIB
walproposer-lib: libwalproposer.a;
.PHONY: libwalproposer.a
libwalproposer.a: $(WALPROP_OBJS)
rm -f $@
$(AR) $(AROPT) $@ $^
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)

View File

@@ -741,6 +741,13 @@ NeonProcessUtility(
break;
case T_DropdbStmt:
HandleDropDb(castNode(DropdbStmt, parseTree));
/*
* We do this here to hack around the fact that Postgres performs the drop
* INSIDE of standard_ProcessUtility, which means that if we try to
* abort the drop normally it'll be too late. DROP DATABASE can't be inside
* of a transaction block anyway, so this should be fine to do.
*/
NeonXactCallback(XACT_EVENT_PRE_COMMIT, NULL);
break;
case T_CreateRoleStmt:
HandleCreateRole(castNode(CreateRoleStmt, parseTree));

View File

@@ -32,11 +32,13 @@
#include "storage/latch.h"
#include "storage/ipc.h"
#include "storage/lwlock.h"
#include "utils/builtins.h"
#include "utils/dynahash.h"
#include "utils/guc.h"
#include "storage/fd.h"
#include "storage/pg_shmem.h"
#include "storage/buf_internals.h"
#include "pgstat.h"
/*
* Local file cache is used to temporary store relations pages in local file system.
@@ -65,6 +67,7 @@
typedef struct FileCacheEntry
{
BufferTag key;
uint32 hash;
uint32 offset;
uint32 access_count;
uint32 bitmap[BLOCKS_PER_CHUNK/32];
@@ -76,6 +79,9 @@ typedef struct FileCacheControl
uint64 generation; /* generation is needed to handle correct hash reenabling */
uint32 size; /* size of cache file in chunks */
uint32 used; /* number of used chunks */
uint32 limit; /* shared copy of lfc_size_limit */
uint64 hits;
uint64 misses;
dlist_head lru; /* double linked list for LRU replacement algorithm */
} FileCacheControl;
@@ -91,10 +97,12 @@ static shmem_startup_hook_type prev_shmem_startup_hook;
static shmem_request_hook_type prev_shmem_request_hook;
#endif
void FileCacheMonitorMain(Datum main_arg);
#define LFC_ENABLED() (lfc_ctl->limit != 0)
void PGDLLEXPORT FileCacheMonitorMain(Datum main_arg);
/*
* Local file cache is mandatory and Neon can work without it.
* Local file cache is optional and Neon can work without it.
* In case of any any errors with this cache, we should disable it but to not throw error.
* Also we should allow re-enable it if source of failure (lack of disk space, permissions,...) is fixed.
* All cache content should be invalidated to avoid reading of stale or corrupted data
@@ -102,49 +110,68 @@ void FileCacheMonitorMain(Datum main_arg);
static void
lfc_disable(char const* op)
{
HASH_SEQ_STATUS status;
FileCacheEntry* entry;
elog(WARNING, "Failed to %s local file cache at %s: %m, disabling local file cache", op, lfc_path);
/* Invalidate hash */
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (LFC_ENABLED())
{
HASH_SEQ_STATUS status;
FileCacheEntry* entry;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search_with_hash_value(lfc_hash, &entry->key, entry->hash, HASH_REMOVE, NULL);
}
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->used = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
if (lfc_desc > 0)
{
/* If the reason of error is ENOSPC, then truncation of file may help to reclaim some space */
int rc = ftruncate(lfc_desc, 0);
if (rc < 0)
elog(WARNING, "Failed to truncate local file cache %s: %m", lfc_path);
}
}
LWLockRelease(lfc_lock);
if (lfc_desc > 0)
close(lfc_desc);
lfc_desc = -1;
lfc_size_limit = 0;
}
/* Invalidate hash */
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
hash_search(lfc_hash, &entry->key, HASH_REMOVE, NULL);
memset(entry->bitmap, 0, sizeof entry->bitmap);
}
hash_seq_term(&status);
lfc_ctl->generation += 1;
lfc_ctl->size = 0;
lfc_ctl->used = 0;
dlist_init(&lfc_ctl->lru);
LWLockRelease(lfc_lock);
/*
* This check is done without obtaining lfc_lock, so it is unreliable
*/
static bool
lfc_maybe_disabled(void)
{
return !lfc_ctl || !LFC_ENABLED();
}
static bool
lfc_ensure_opened(void)
{
bool enabled = !lfc_maybe_disabled();
/* Open cache file if not done yet */
if (lfc_desc <= 0)
if (lfc_desc <= 0 && enabled)
{
lfc_desc = BasicOpenFile(lfc_path, O_RDWR|O_CREAT);
lfc_desc = BasicOpenFile(lfc_path, O_RDWR);
if (lfc_desc < 0) {
lfc_disable("open");
return false;
}
}
return true;
return enabled;
}
static void
@@ -163,6 +190,7 @@ lfc_shmem_startup(void)
lfc_ctl = (FileCacheControl*)ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
{
int fd;
uint32 lfc_size = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId)GetNamedLWLockTranche("lfc_lock");
info.keysize = sizeof(BufferTag);
@@ -175,10 +203,22 @@ lfc_shmem_startup(void)
lfc_ctl->generation = 0;
lfc_ctl->size = 0;
lfc_ctl->used = 0;
lfc_ctl->hits = 0;
lfc_ctl->misses = 0;
dlist_init(&lfc_ctl->lru);
/* Remove file cache on restart */
(void)unlink(lfc_path);
/* Recreate file cache on restart */
fd = BasicOpenFile(lfc_path, O_RDWR|O_CREAT|O_TRUNC);
if (fd < 0)
{
elog(WARNING, "Failed to create local file cache %s: %m", lfc_path);
lfc_ctl->limit = 0;
}
else
{
close(fd);
lfc_ctl->limit = SIZE_MB_TO_CHUNKS(lfc_size_limit);
}
}
LWLockRelease(AddinShmemInitLock);
}
@@ -195,6 +235,17 @@ lfc_shmem_request(void)
RequestNamedLWLockTranche("lfc_lock", 1);
}
static bool
is_normal_backend(void)
{
/*
* Stats collector detach shared memory, so we should not try to access shared memory here.
* Parallel workers first assign default value (0), so not perform truncation in parallel workers.
* The Postmaster can handle SIGHUP and it has access to shared memory (UsedShmemSegAddr != NULL), but has no PGPROC.
*/
return lfc_ctl && MyProc && UsedShmemSegAddr && !IsParallelWorker();
}
static bool
lfc_check_limit_hook(int *newval, void **extra, GucSource source)
{
@@ -210,25 +261,15 @@ static void
lfc_change_limit_hook(int newval, void *extra)
{
uint32 new_size = SIZE_MB_TO_CHUNKS(newval);
/*
* Stats collector detach shared memory, so we should not try to access shared memory here.
* Parallel workers first assign default value (0), so not perform truncation in parallel workers.
* The Postmaster can handle SIGHUP and it has access to shared memory (UsedShmemSegAddr != NULL), but has no PGPROC.
*/
if (!lfc_ctl || !MyProc || !UsedShmemSegAddr || IsParallelWorker())
if (!is_normal_backend())
return;
if (!lfc_ensure_opened())
return;
/* Open cache file if not done yet */
if (lfc_desc <= 0)
{
lfc_desc = BasicOpenFile(lfc_path, O_RDWR|O_CREAT);
if (lfc_desc < 0) {
elog(WARNING, "Failed to open file cache %s: %m, disabling file cache", lfc_path);
lfc_size_limit = 0; /* disable file cache */
return;
}
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
while (new_size < lfc_ctl->used && !dlist_is_empty(&lfc_ctl->lru))
{
/* Shrink cache by throwing away least recently accessed chunks and returning their space to file system */
@@ -238,10 +279,12 @@ lfc_change_limit_hook(int newval, void *extra)
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE|FALLOC_FL_KEEP_SIZE, (off_t)victim->offset*BLOCKS_PER_CHUNK*BLCKSZ, BLOCKS_PER_CHUNK*BLCKSZ) < 0)
elog(LOG, "Failed to punch hole in file: %m");
#endif
hash_search(lfc_hash, &victim->key, HASH_REMOVE, NULL);
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
lfc_ctl->used -= 1;
}
lfc_ctl->limit = new_size;
elog(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
}
@@ -255,6 +298,7 @@ lfc_init(void)
if (!process_shared_preload_libraries_in_progress)
elog(ERROR, "Neon module should be loaded via shared_preload_libraries");
DefineCustomIntVariable("neon.max_file_cache_size",
"Maximal size of Neon local file cache",
NULL,
@@ -315,10 +359,10 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
BufferTag tag;
FileCacheEntry* entry;
int chunk_offs = blkno & (BLOCKS_PER_CHUNK-1);
bool found;
bool found = false;
uint32 hash;
if (lfc_size_limit == 0) /* fast exit if file cache is disabled */
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
CopyNRelFileInfoToBufTag(tag, rinfo);
@@ -327,8 +371,11 @@ lfc_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_SHARED);
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
found = entry != NULL && (entry->bitmap[chunk_offs >> 5] & (1 << (chunk_offs & 31))) != 0;
if (LFC_ENABLED())
{
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
found = entry != NULL && (entry->bitmap[chunk_offs >> 5] & (1 << (chunk_offs & 31))) != 0;
}
LWLockRelease(lfc_lock);
return found;
}
@@ -345,7 +392,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
int chunk_offs = blkno & (BLOCKS_PER_CHUNK-1);
uint32 hash;
if (lfc_size_limit == 0) /* fast exit if file cache is disabled */
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
CopyNRelFileInfoToBufTag(tag, rinfo);
@@ -355,6 +402,13 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
return;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, &found);
if (!found)
@@ -405,7 +459,7 @@ lfc_evict(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno)
/*
* Try to read page from local cache.
* Returns true if page is found in local cache.
* In case of error lfc_size_limit is set to zero to disable any further opera-tins with cache.
* In case of error local file cache is disabled (lfc->limit is set to zero).
*/
bool
lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
@@ -420,7 +474,7 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
uint64 generation;
uint32 entry_offset;
if (lfc_size_limit == 0) /* fast exit if file cache is disabled */
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return false;
if (!lfc_ensure_opened())
@@ -432,10 +486,19 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
return false;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_FIND, NULL);
if (entry == NULL || (entry->bitmap[chunk_offs >> 5] & (1 << (chunk_offs & 31))) == 0)
{
/* Page is not cached */
lfc_ctl->misses += 1;
pgBufferUsage.file_cache.misses += 1;
LWLockRelease(lfc_lock);
return false;
}
@@ -456,8 +519,12 @@ lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
/* Place entry to the head of LRU list */
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_ctl->generation == generation)
{
Assert(LFC_ENABLED());
lfc_ctl->hits += 1;
pgBufferUsage.file_cache.hits += 1;
Assert(entry->access_count > 0);
if (--entry->access_count == 0)
dlist_push_tail(&lfc_ctl->lru, &entry->lru_node);
@@ -489,7 +556,7 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
int chunk_offs = blkno & (BLOCKS_PER_CHUNK-1);
uint32 hash;
if (lfc_size_limit == 0) /* fast exit if file cache is disabled */
if (lfc_maybe_disabled()) /* fast exit if file cache is disabled */
return;
if (!lfc_ensure_opened())
@@ -497,12 +564,17 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
tag.forkNum = forkNum;
tag.blockNum = blkno & ~(BLOCKS_PER_CHUNK-1);
CopyNRelFileInfoToBufTag(tag, rinfo);
hash = get_hash_value(lfc_hash, &tag);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
return;
}
entry = hash_search_with_hash_value(lfc_hash, &tag, hash, HASH_ENTER, &found);
if (found)
@@ -521,13 +593,13 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
* there are should be very large number of concurrent IO operations and them are limited by max_connections,
* we prefer not to complicate code and use second approach.
*/
if (lfc_ctl->used >= SIZE_MB_TO_CHUNKS(lfc_size_limit) && !dlist_is_empty(&lfc_ctl->lru))
if (lfc_ctl->used >= lfc_ctl->limit && !dlist_is_empty(&lfc_ctl->lru))
{
/* Cache overflow: evict least recently used chunk */
FileCacheEntry* victim = dlist_container(FileCacheEntry, lru_node, dlist_pop_head_node(&lfc_ctl->lru));
Assert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
hash_search(lfc_hash, &victim->key, HASH_REMOVE, NULL);
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
elog(DEBUG2, "Swap file cache page");
}
else
@@ -536,6 +608,7 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
entry->offset = lfc_ctl->size++; /* allocate new chunk at end of file */
}
entry->access_count = 1;
entry->hash = hash;
memset(entry->bitmap, 0, sizeof entry->bitmap);
}
@@ -557,6 +630,104 @@ lfc_write(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
}
typedef struct
{
TupleDesc tupdesc;
} NeonGetStatsCtx;
#define NUM_NEON_GET_STATS_COLS 2
#define NUM_NEON_GET_STATS_ROWS 3
PG_FUNCTION_INFO_V1(neon_get_stats);
Datum
neon_get_stats(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
NeonGetStatsCtx* fctx;
MemoryContext oldcontext;
TupleDesc tupledesc;
Datum result;
HeapTuple tuple;
char const* key;
uint64 value;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
if (SRF_IS_FIRSTCALL())
{
funcctx = SRF_FIRSTCALL_INIT();
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (NeonGetStatsCtx*) palloc(sizeof(NeonGetStatsCtx));
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(NUM_NEON_GET_STATS_COLS);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "ns_key",
TEXTOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "ns_value",
TEXTOID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
funcctx->max_calls = NUM_NEON_GET_STATS_ROWS;
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
}
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = (NeonGetStatsCtx*) funcctx->user_fctx;
switch (funcctx->call_cntr)
{
case 0:
key = "file_cache_misses";
if (lfc_ctl)
value = lfc_ctl->misses;
break;
case 1:
key = "file_cache_hits";
if (lfc_ctl)
value = lfc_ctl->hits;
break;
case 2:
key = "file_cache_used";
if (lfc_ctl)
value = lfc_ctl->used;
break;
default:
SRF_RETURN_DONE(funcctx);
}
values[0] = PointerGetDatum(cstring_to_text(key));
nulls[0] = false;
if (lfc_ctl)
{
char buf[64];
snprintf(buf, sizeof buf, "%llu", (long long)value);
nulls[1] = false;
values[1] = PointerGetDatum(cstring_to_text(buf));
}
else
nulls[1] = true;
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
}
/*
* Function returning data from the local file cache
* relation node/tablespace/database/blocknum and access_counter
*/
PG_FUNCTION_INFO_V1(local_cache_pages);
/*
* Record structure holding the to be exposed cache data.
*/
@@ -580,11 +751,6 @@ typedef struct
LocalCachePagesRec *record;
} LocalCachePagesContext;
/*
* Function returning data from the local file cache
* relation node/tablespace/database/blocknum and access_counter
*/
PG_FUNCTION_INFO_V1(local_cache_pages);
#define NUM_LOCALCACHE_PAGES_ELEM 7
@@ -653,13 +819,15 @@ local_cache_pages(PG_FUNCTION_ARGS)
LWLockAcquire(lfc_lock, LW_SHARED);
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
if (LFC_ENABLED())
{
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
n_pages += (entry->bitmap[i >> 5] & (1 << (i & 31))) != 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < BLOCKS_PER_CHUNK/32; i++)
n_pages += pg_popcount32(entry->bitmap[i]);
}
}
hash_seq_term(&status);
fctx->record = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
@@ -671,35 +839,33 @@ local_cache_pages(PG_FUNCTION_ARGS)
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
/*
* Scan through all the buffers, saving the relevant fields in the
* fctx->record structure.
*
* We don't hold the partition locks, so we don't get a consistent
* snapshot across all buffers, but we do grab the buffer header
* locks, so the information of each buffer is self-consistent.
*/
n_pages = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
if (n_pages != 0)
{
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
/*
* Scan through all the cache entries, saving the relevant fields in the
* fctx->record structure.
*/
uint32 n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
if (entry->bitmap[i >> 5] & (1 << (i & 31)))
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
fctx->record[n_pages].pageoffs = entry->offset*BLOCKS_PER_CHUNK + i;
fctx->record[n_pages].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n_pages].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n_pages].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n_pages].forknum = entry->key.forkNum;
fctx->record[n_pages].blocknum = entry->key.blockNum + i;
fctx->record[n_pages].accesscount = entry->access_count;
n_pages += 1;
if (entry->bitmap[i >> 5] & (1 << (i & 31)))
{
fctx->record[n].pageoffs = entry->offset*BLOCKS_PER_CHUNK + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
Assert(n_pages == n);
}
hash_seq_term(&status);
Assert(n_pages == funcctx->max_calls);
LWLockRelease(lfc_lock);
}

View File

@@ -0,0 +1,10 @@
\echo Use "ALTER EXTENSION neon UPDATE TO '1.1'" to load this file. \quit
CREATE FUNCTION neon_get_stats()
RETURNS SETOF RECORD
AS 'MODULE_PATHNAME', 'neon_get_stats'
LANGUAGE C PARALLEL SAFE;
-- Create a view for convenient access.
CREATE VIEW neon_stats AS
SELECT P.* FROM neon_get_stats() AS P (ns_key text, ns_value text);

View File

@@ -1,4 +1,4 @@
# neon extension
comment = 'cloud storage for PostgreSQL'
default_version = '1.0'
default_version = '1.1'
module_pathname = '$libdir/neon'

View File

@@ -63,6 +63,7 @@
#include "storage/md.h"
#include "pgstat.h"
#if PG_VERSION_NUM >= 150000
#include "access/xlogutils.h"
#include "access/xlogrecovery.h"
@@ -720,7 +721,7 @@ prefetch_register_buffer(BufferTag tag, bool *force_latest, XLogRecPtr *force_ls
/* use an intermediate PrefetchRequest struct to ensure correct alignment */
req.buftag = tag;
Retry:
entry = prfh_lookup(MyPState->prf_hash, (PrefetchRequest *) &req);
if (entry != NULL)
@@ -857,11 +858,7 @@ prefetch_register_buffer(BufferTag tag, bool *force_latest, XLogRecPtr *force_ls
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
{
if (!page_server->flush())
{
/* Prefetch set is reset in case of error, so we should try to register our request once again */
goto Retry;
}
page_server->flush();
MyPState->ring_flush = MyPState->ring_unused;
}
@@ -1394,6 +1391,12 @@ neon_get_request_lsn(bool *latest, NRelFileInfo rinfo, ForkNumber forknum, Block
elog(DEBUG1, "neon_get_request_lsn GetXLogReplayRecPtr %X/%X request lsn 0 ",
(uint32) ((lsn) >> 32), (uint32) (lsn));
}
else if (am_walsender)
{
*latest = true;
lsn = InvalidXLogRecPtr;
elog(DEBUG1, "am walsender neon_get_request_lsn lsn 0 ");
}
else
{
XLogRecPtr flushlsn;

View File

@@ -79,7 +79,7 @@ static int CompareLsn(const void *a, const void *b);
static char *FormatSafekeeperState(SafekeeperState state);
static void AssertEventsOkForState(uint32 events, Safekeeper *sk);
static uint32 SafekeeperStateDesiredEvents(SafekeeperState state);
static char *FormatEvents(WalProposer *wp, uint32 events);
static char *FormatEvents(uint32 events);
WalProposer *
WalProposerCreate(WalProposerConfig *config, walproposer_api api)
@@ -98,7 +98,7 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
port = strchr(host, ':');
if (port == NULL)
{
walprop_log(FATAL, "port is not specified");
elog(FATAL, "port is not specified");
}
*port++ = '\0';
sep = strchr(port, ',');
@@ -106,11 +106,12 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
*sep++ = '\0';
if (wp->n_safekeepers + 1 >= MAX_SAFEKEEPERS)
{
walprop_log(FATAL, "Too many safekeepers");
elog(FATAL, "Too many safekeepers");
}
wp->safekeeper[wp->n_safekeepers].host = host;
wp->safekeeper[wp->n_safekeepers].port = port;
wp->safekeeper[wp->n_safekeepers].state = SS_OFFLINE;
wp->safekeeper[wp->n_safekeepers].conn = NULL;
wp->safekeeper[wp->n_safekeepers].wp = wp;
{
@@ -121,11 +122,13 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
"host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'",
sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant);
if (written > MAXCONNINFO || written < 0)
walprop_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port);
elog(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port);
}
initStringInfo(&wp->safekeeper[wp->n_safekeepers].outbuf);
wp->api.wal_reader_allocate(&wp->safekeeper[wp->n_safekeepers]);
wp->safekeeper[wp->n_safekeepers].xlogreader = wp->api.wal_reader_allocate();
if (wp->safekeeper[wp->n_safekeepers].xlogreader == NULL)
elog(FATAL, "Failed to allocate xlog reader");
wp->safekeeper[wp->n_safekeepers].flushWrite = false;
wp->safekeeper[wp->n_safekeepers].startStreamingAt = InvalidXLogRecPtr;
wp->safekeeper[wp->n_safekeepers].streamingAt = InvalidXLogRecPtr;
@@ -133,7 +136,7 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
}
if (wp->n_safekeepers < 1)
{
walprop_log(FATAL, "Safekeepers addresses are not specified");
elog(FATAL, "Safekeepers addresses are not specified");
}
wp->quorum = wp->n_safekeepers / 2 + 1;
@@ -141,47 +144,27 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
wp->greetRequest.tag = 'g';
wp->greetRequest.protocolVersion = SK_PROTOCOL_VERSION;
wp->greetRequest.pgVersion = PG_VERSION_NUM;
wp->api.strong_random(wp, &wp->greetRequest.proposerId, sizeof(wp->greetRequest.proposerId));
wp->api.strong_random(&wp->greetRequest.proposerId, sizeof(wp->greetRequest.proposerId));
wp->greetRequest.systemId = wp->config->systemId;
if (!wp->config->neon_timeline)
walprop_log(FATAL, "neon.timeline_id is not provided");
elog(FATAL, "neon.timeline_id is not provided");
if (*wp->config->neon_timeline != '\0' &&
!HexDecodeString(wp->greetRequest.timeline_id, wp->config->neon_timeline, 16))
walprop_log(FATAL, "Could not parse neon.timeline_id, %s", wp->config->neon_timeline);
elog(FATAL, "Could not parse neon.timeline_id, %s", wp->config->neon_timeline);
if (!wp->config->neon_tenant)
walprop_log(FATAL, "neon.tenant_id is not provided");
elog(FATAL, "neon.tenant_id is not provided");
if (*wp->config->neon_tenant != '\0' &&
!HexDecodeString(wp->greetRequest.tenant_id, wp->config->neon_tenant, 16))
walprop_log(FATAL, "Could not parse neon.tenant_id, %s", wp->config->neon_tenant);
elog(FATAL, "Could not parse neon.tenant_id, %s", wp->config->neon_tenant);
wp->greetRequest.timeline = wp->config->pgTimeline;
wp->greetRequest.timeline = wp->api.get_timeline_id();
wp->greetRequest.walSegSize = wp->config->wal_segment_size;
wp->api.init_event_set(wp);
wp->api.init_event_set(wp->n_safekeepers);
return wp;
}
void
WalProposerFree(WalProposer *wp)
{
for (int i = 0; i < wp->n_safekeepers; i++)
{
Safekeeper *sk = &wp->safekeeper[i];
Assert(sk->outbuf.data != NULL);
pfree(sk->outbuf.data);
if (sk->voteResponse.termHistory.entries)
pfree(sk->voteResponse.termHistory.entries);
sk->voteResponse.termHistory.entries = NULL;
}
if (wp->propTermHistory.entries != NULL)
pfree(wp->propTermHistory.entries);
wp->propTermHistory.entries = NULL;
pfree(wp);
}
/*
* Create new AppendRequest message and start sending it. This function is
* called from walsender every time the new WAL is available.
@@ -207,10 +190,10 @@ WalProposerPoll(WalProposer *wp)
Safekeeper *sk = NULL;
int rc = 0;
uint32 events = 0;
TimestampTz now = wp->api.get_current_timestamp(wp);
TimestampTz now = wp->api.get_current_timestamp();
long timeout = TimeToReconnect(wp, now);
rc = wp->api.wait_event_set(wp, timeout, &sk, &events);
rc = wp->api.wait_event_set(timeout, &sk, &events);
/* Exit loop if latch is set (we got new WAL) */
if ((rc == 1 && events & WL_LATCH_SET))
@@ -241,14 +224,14 @@ WalProposerPoll(WalProposer *wp)
*/
if (!wp->config->syncSafekeepers)
{
XLogRecPtr flushed = wp->api.get_flush_rec_ptr(wp);
XLogRecPtr flushed = wp->api.get_flush_rec_ptr();
if (flushed > wp->availableLsn)
break;
}
}
now = wp->api.get_current_timestamp(wp);
now = wp->api.get_current_timestamp();
/* timeout expired: poll state */
if (rc == 0 || TimeToReconnect(wp, now) <= 0)
{
@@ -266,7 +249,7 @@ WalProposerPoll(WalProposer *wp)
/*
* Abandon connection attempts which take too long.
*/
now = wp->api.get_current_timestamp(wp);
now = wp->api.get_current_timestamp();
for (int i = 0; i < wp->n_safekeepers; i++)
{
Safekeeper *sk = &wp->safekeeper[i];
@@ -274,7 +257,7 @@ WalProposerPoll(WalProposer *wp)
if (TimestampDifferenceExceeds(sk->latestMsgReceivedAt, now,
wp->config->safekeeper_connection_timeout))
{
walprop_log(WARNING, "terminating connection to safekeeper '%s:%s' in '%s' state: no messages received during the last %dms or connection attempt took longer than that",
elog(WARNING, "terminating connection to safekeeper '%s:%s' in '%s' state: no messages received during the last %dms or connection attempt took longer than that",
sk->host, sk->port, FormatSafekeeperState(sk->state), wp->config->safekeeper_connection_timeout);
ShutdownConnection(sk);
}
@@ -313,10 +296,10 @@ HackyRemoveWalProposerEvent(Safekeeper *to_remove)
{
WalProposer *wp = to_remove->wp;
/* Remove the existing event set, assign sk->eventPos = -1 */
wp->api.free_event_set(wp);
/* Remove the existing event set */
wp->api.free_event_set();
/* Re-initialize it without adding any safekeeper events */
wp->api.init_event_set(wp);
wp->api.init_event_set(wp->n_safekeepers);
/*
* loop through the existing safekeepers. If they aren't the one we're
@@ -328,11 +311,13 @@ HackyRemoveWalProposerEvent(Safekeeper *to_remove)
uint32 desired_events = WL_NO_EVENTS;
Safekeeper *sk = &wp->safekeeper[i];
sk->eventPos = -1;
if (sk == to_remove)
continue;
/* If this safekeeper isn't offline, add an event for it! */
if (sk->state != SS_OFFLINE)
if (sk->conn != NULL)
{
desired_events = SafekeeperStateDesiredEvents(sk->state);
/* will set sk->eventPos */
@@ -345,7 +330,9 @@ HackyRemoveWalProposerEvent(Safekeeper *to_remove)
static void
ShutdownConnection(Safekeeper *sk)
{
sk->wp->api.conn_finish(sk);
if (sk->conn)
sk->wp->api.conn_finish(sk->conn);
sk->conn = NULL;
sk->state = SS_OFFLINE;
sk->flushWrite = false;
sk->streamingAt = InvalidXLogRecPtr;
@@ -374,16 +361,23 @@ ResetConnection(Safekeeper *sk)
}
/*
* Try to establish new connection, it will update sk->conn.
* Try to establish new connection
*/
wp->api.conn_connect_start(sk);
sk->conn = wp->api.conn_connect_start((char *) &sk->conninfo);
/*
* "If the result is null, then libpq has been unable to allocate a new
* PGconn structure"
*/
if (!sk->conn)
elog(FATAL, "failed to allocate new PGconn object");
/*
* PQconnectStart won't actually start connecting until we run
* PQconnectPoll. Before we do that though, we need to check that it
* didn't immediately fail.
*/
if (wp->api.conn_status(sk) == WP_CONNECTION_BAD)
if (wp->api.conn_status(sk->conn) == WP_CONNECTION_BAD)
{
/*---
* According to libpq docs:
@@ -394,14 +388,15 @@ ResetConnection(Safekeeper *sk)
*
* https://www.postgresql.org/docs/devel/libpq-connect.html#LIBPQ-PQCONNECTSTARTPARAMS
*/
walprop_log(WARNING, "Immediate failure to connect with node '%s:%s':\n\terror: %s",
sk->host, sk->port, wp->api.conn_error_message(sk));
elog(WARNING, "Immediate failure to connect with node '%s:%s':\n\terror: %s",
sk->host, sk->port, wp->api.conn_error_message(sk->conn));
/*
* Even though the connection failed, we still need to clean up the
* object
*/
wp->api.conn_finish(sk);
wp->api.conn_finish(sk->conn);
sk->conn = NULL;
return;
}
@@ -418,10 +413,10 @@ ResetConnection(Safekeeper *sk)
* (see libpqrcv_connect, defined in
* src/backend/replication/libpqwalreceiver/libpqwalreceiver.c)
*/
walprop_log(LOG, "connecting with node %s:%s", sk->host, sk->port);
elog(LOG, "connecting with node %s:%s", sk->host, sk->port);
sk->state = SS_CONNECTING_WRITE;
sk->latestMsgReceivedAt = wp->api.get_current_timestamp(wp);
sk->latestMsgReceivedAt = wp->api.get_current_timestamp();
wp->api.add_safekeeper_event_set(sk, WL_SOCKET_WRITEABLE);
return;
@@ -452,7 +447,7 @@ TimeToReconnect(WalProposer *wp, TimestampTz now)
static void
ReconnectSafekeepers(WalProposer *wp)
{
TimestampTz now = wp->api.get_current_timestamp(wp);
TimestampTz now = wp->api.get_current_timestamp();
if (TimeToReconnect(wp, now) == 0)
{
@@ -472,8 +467,6 @@ ReconnectSafekeepers(WalProposer *wp)
static void
AdvancePollState(Safekeeper *sk, uint32 events)
{
WalProposer *wp = sk->wp;
/*
* Sanity check. We assume further down that the operations don't block
* because the socket is ready.
@@ -488,7 +481,7 @@ AdvancePollState(Safekeeper *sk, uint32 events)
* ResetConnection
*/
case SS_OFFLINE:
walprop_log(FATAL, "Unexpected safekeeper %s:%s state advancement: is offline",
elog(FATAL, "Unexpected safekeeper %s:%s state advancement: is offline",
sk->host, sk->port);
break; /* actually unreachable, but prevents
* -Wimplicit-fallthrough */
@@ -524,7 +517,7 @@ AdvancePollState(Safekeeper *sk, uint32 events)
* requests.
*/
case SS_VOTING:
walprop_log(WARNING, "EOF from node %s:%s in %s state", sk->host,
elog(WARNING, "EOF from node %s:%s in %s state", sk->host,
sk->port, FormatSafekeeperState(sk->state));
ResetConnection(sk);
return;
@@ -553,7 +546,7 @@ AdvancePollState(Safekeeper *sk, uint32 events)
* Idle state for waiting votes from quorum.
*/
case SS_IDLE:
walprop_log(WARNING, "EOF from node %s:%s in %s state", sk->host,
elog(WARNING, "EOF from node %s:%s in %s state", sk->host,
sk->port, FormatSafekeeperState(sk->state));
ResetConnection(sk);
return;
@@ -571,7 +564,7 @@ static void
HandleConnectionEvent(Safekeeper *sk)
{
WalProposer *wp = sk->wp;
WalProposerConnectPollStatusType result = wp->api.conn_connect_poll(sk);
WalProposerConnectPollStatusType result = wp->api.conn_connect_poll(sk->conn);
/* The new set of events we'll wait on, after updating */
uint32 new_events = WL_NO_EVENTS;
@@ -579,9 +572,9 @@ HandleConnectionEvent(Safekeeper *sk)
switch (result)
{
case WP_CONN_POLLING_OK:
walprop_log(LOG, "connected with node %s:%s", sk->host,
elog(LOG, "connected with node %s:%s", sk->host,
sk->port);
sk->latestMsgReceivedAt = wp->api.get_current_timestamp(wp);
sk->latestMsgReceivedAt = wp->api.get_current_timestamp();
/*
* We have to pick some event to update event set. We'll
@@ -603,8 +596,8 @@ HandleConnectionEvent(Safekeeper *sk)
break;
case WP_CONN_POLLING_FAILED:
walprop_log(WARNING, "failed to connect to node '%s:%s': %s",
sk->host, sk->port, wp->api.conn_error_message(sk));
elog(WARNING, "failed to connect to node '%s:%s': %s",
sk->host, sk->port, wp->api.conn_error_message(sk->conn));
/*
* If connecting failed, we don't want to restart the connection
@@ -638,10 +631,10 @@ SendStartWALPush(Safekeeper *sk)
{
WalProposer *wp = sk->wp;
if (!wp->api.conn_send_query(sk, "START_WAL_PUSH"))
if (!wp->api.conn_send_query(sk->conn, "START_WAL_PUSH"))
{
walprop_log(WARNING, "Failed to send 'START_WAL_PUSH' query to safekeeper %s:%s: %s",
sk->host, sk->port, wp->api.conn_error_message(sk));
elog(WARNING, "Failed to send 'START_WAL_PUSH' query to safekeeper %s:%s: %s",
sk->host, sk->port, wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return;
}
@@ -654,7 +647,7 @@ RecvStartWALPushResult(Safekeeper *sk)
{
WalProposer *wp = sk->wp;
switch (wp->api.conn_get_query_result(sk))
switch (wp->api.conn_get_query_result(sk->conn))
{
/*
* Successful result, move on to starting the handshake
@@ -677,8 +670,8 @@ RecvStartWALPushResult(Safekeeper *sk)
break;
case WP_EXEC_FAILED:
walprop_log(WARNING, "Failed to send query to safekeeper %s:%s: %s",
sk->host, sk->port, wp->api.conn_error_message(sk));
elog(WARNING, "Failed to send query to safekeeper %s:%s: %s",
sk->host, sk->port, wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return;
@@ -688,7 +681,7 @@ RecvStartWALPushResult(Safekeeper *sk)
* wrong"
*/
case WP_EXEC_UNEXPECTED_SUCCESS:
walprop_log(WARNING, "Received bad response from safekeeper %s:%s query execution",
elog(WARNING, "Received bad response from safekeeper %s:%s query execution",
sk->host, sk->port);
ShutdownConnection(sk);
return;
@@ -724,7 +717,7 @@ RecvAcceptorGreeting(Safekeeper *sk)
if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->greetResponse))
return;
walprop_log(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port);
elog(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port);
/* Protocol is all good, move to voting. */
sk->state = SS_VOTING;
@@ -744,7 +737,7 @@ RecvAcceptorGreeting(Safekeeper *sk)
if (wp->n_connected == wp->quorum)
{
wp->propTerm++;
walprop_log(LOG, "proposer connected to quorum (%d) safekeepers, propTerm=" INT64_FORMAT, wp->quorum, wp->propTerm);
elog(LOG, "proposer connected to quorum (%d) safekeepers, propTerm=" INT64_FORMAT, wp->quorum, wp->propTerm);
wp->voteRequest = (VoteRequest)
{
@@ -757,7 +750,7 @@ RecvAcceptorGreeting(Safekeeper *sk)
else if (sk->greetResponse.term > wp->propTerm)
{
/* Another compute with higher term is running. */
walprop_log(FATAL, "WAL acceptor %s:%s with term " INT64_FORMAT " rejects our connection request with term " INT64_FORMAT "",
elog(FATAL, "WAL acceptor %s:%s with term " INT64_FORMAT " rejects our connection request with term " INT64_FORMAT "",
sk->host, sk->port,
sk->greetResponse.term, wp->propTerm);
}
@@ -799,7 +792,7 @@ SendVoteRequest(Safekeeper *sk)
WalProposer *wp = sk->wp;
/* We have quorum for voting, send our vote request */
walprop_log(LOG, "requesting vote from %s:%s for term " UINT64_FORMAT, sk->host, sk->port, wp->voteRequest.term);
elog(LOG, "requesting vote from %s:%s for term " UINT64_FORMAT, sk->host, sk->port, wp->voteRequest.term);
/* On failure, logging & resetting is handled */
if (!BlockingWrite(sk, &wp->voteRequest, sizeof(wp->voteRequest), SS_WAIT_VERDICT))
return;
@@ -816,7 +809,7 @@ RecvVoteResponse(Safekeeper *sk)
if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->voteResponse))
return;
walprop_log(LOG,
elog(LOG,
"got VoteResponse from acceptor %s:%s, voteGiven=" UINT64_FORMAT ", epoch=" UINT64_FORMAT ", flushLsn=%X/%X, truncateLsn=%X/%X, timelineStartLsn=%X/%X",
sk->host, sk->port, sk->voteResponse.voteGiven, GetHighestTerm(&sk->voteResponse.termHistory),
LSN_FORMAT_ARGS(sk->voteResponse.flushLsn),
@@ -831,7 +824,7 @@ RecvVoteResponse(Safekeeper *sk)
if ((!sk->voteResponse.voteGiven) &&
(sk->voteResponse.term > wp->propTerm || wp->n_votes < wp->quorum))
{
walprop_log(FATAL, "WAL acceptor %s:%s with term " INT64_FORMAT " rejects our connection request with term " INT64_FORMAT "",
elog(FATAL, "WAL acceptor %s:%s with term " INT64_FORMAT " rejects our connection request with term " INT64_FORMAT "",
sk->host, sk->port,
sk->voteResponse.term, wp->propTerm);
}
@@ -876,19 +869,19 @@ HandleElectedProposer(WalProposer *wp)
*/
if (wp->truncateLsn < wp->propEpochStartLsn)
{
walprop_log(LOG,
elog(LOG,
"start recovery because truncateLsn=%X/%X is not "
"equal to epochStartLsn=%X/%X",
LSN_FORMAT_ARGS(wp->truncateLsn),
LSN_FORMAT_ARGS(wp->propEpochStartLsn));
/* Perform recovery */
if (!wp->api.recovery_download(&wp->safekeeper[wp->donor], wp->greetRequest.timeline, wp->truncateLsn, wp->propEpochStartLsn))
walprop_log(FATAL, "Failed to recover state");
elog(FATAL, "Failed to recover state");
}
else if (wp->config->syncSafekeepers)
{
/* Sync is not needed: just exit */
wp->api.finish_sync_safekeepers(wp, wp->propEpochStartLsn);
wp->api.finish_sync_safekeepers(wp->propEpochStartLsn);
/* unreachable */
}
@@ -989,7 +982,7 @@ DetermineEpochStartLsn(WalProposer *wp)
if (wp->timelineStartLsn != InvalidXLogRecPtr &&
wp->timelineStartLsn != wp->safekeeper[i].voteResponse.timelineStartLsn)
{
walprop_log(WARNING,
elog(WARNING,
"inconsistent timelineStartLsn: current %X/%X, received %X/%X",
LSN_FORMAT_ARGS(wp->timelineStartLsn),
LSN_FORMAT_ARGS(wp->safekeeper[i].voteResponse.timelineStartLsn));
@@ -1005,12 +998,12 @@ DetermineEpochStartLsn(WalProposer *wp)
*/
if (wp->propEpochStartLsn == InvalidXLogRecPtr && !wp->config->syncSafekeepers)
{
wp->propEpochStartLsn = wp->truncateLsn = wp->api.get_redo_start_lsn(wp);
wp->propEpochStartLsn = wp->truncateLsn = wp->api.get_redo_start_lsn();
if (wp->timelineStartLsn == InvalidXLogRecPtr)
{
wp->timelineStartLsn = wp->api.get_redo_start_lsn(wp);
wp->timelineStartLsn = wp->api.get_redo_start_lsn();
}
walprop_log(LOG, "bumped epochStartLsn to the first record %X/%X", LSN_FORMAT_ARGS(wp->propEpochStartLsn));
elog(LOG, "bumped epochStartLsn to the first record %X/%X", LSN_FORMAT_ARGS(wp->propEpochStartLsn));
}
/*
@@ -1037,7 +1030,7 @@ DetermineEpochStartLsn(WalProposer *wp)
wp->propTermHistory.entries[wp->propTermHistory.n_entries - 1].term = wp->propTerm;
wp->propTermHistory.entries[wp->propTermHistory.n_entries - 1].lsn = wp->propEpochStartLsn;
walprop_log(LOG, "got votes from majority (%d) of nodes, term " UINT64_FORMAT ", epochStartLsn %X/%X, donor %s:%s, truncate_lsn %X/%X",
elog(LOG, "got votes from majority (%d) of nodes, term " UINT64_FORMAT ", epochStartLsn %X/%X, donor %s:%s, truncate_lsn %X/%X",
wp->quorum,
wp->propTerm,
LSN_FORMAT_ARGS(wp->propEpochStartLsn),
@@ -1051,7 +1044,7 @@ DetermineEpochStartLsn(WalProposer *wp)
*/
if (!wp->config->syncSafekeepers)
{
WalproposerShmemState *walprop_shared = wp->api.get_shmem_state(wp);
WalproposerShmemState *walprop_shared = wp->api.get_shmem_state();
/*
* Basebackup LSN always points to the beginning of the record (not
@@ -1059,7 +1052,7 @@ DetermineEpochStartLsn(WalProposer *wp)
* Safekeepers don't skip header as they need continious stream of
* data, so correct LSN for comparison.
*/
if (SkipXLogPageHeader(wp, wp->propEpochStartLsn) != wp->api.get_redo_start_lsn(wp))
if (SkipXLogPageHeader(wp, wp->propEpochStartLsn) != wp->api.get_redo_start_lsn())
{
/*
* However, allow to proceed if previously elected leader was me;
@@ -1069,21 +1062,14 @@ DetermineEpochStartLsn(WalProposer *wp)
if (!((dth->n_entries >= 1) && (dth->entries[dth->n_entries - 1].term ==
walprop_shared->mineLastElectedTerm)))
{
walprop_log(PANIC,
elog(PANIC,
"collected propEpochStartLsn %X/%X, but basebackup LSN %X/%X",
LSN_FORMAT_ARGS(wp->propEpochStartLsn),
LSN_FORMAT_ARGS(wp->api.get_redo_start_lsn(wp)));
LSN_FORMAT_ARGS(wp->api.get_redo_start_lsn()));
}
}
walprop_shared->mineLastElectedTerm = wp->propTerm;
}
/*
* WalProposer has just elected itself and initialized history, so
* we can call election callback. Usually it updates truncateLsn to
* fetch WAL for logical replication.
*/
wp->api.after_election(wp);
}
/*
@@ -1154,7 +1140,7 @@ SendProposerElected(Safekeeper *sk)
*/
sk->startStreamingAt = wp->truncateLsn;
walprop_log(WARNING, "empty safekeeper joined cluster as %s:%s, historyStart=%X/%X, sk->startStreamingAt=%X/%X",
elog(WARNING, "empty safekeeper joined cluster as %s:%s, historyStart=%X/%X, sk->startStreamingAt=%X/%X",
sk->host, sk->port, LSN_FORMAT_ARGS(wp->propTermHistory.entries[0].lsn),
LSN_FORMAT_ARGS(sk->startStreamingAt));
}
@@ -1189,7 +1175,7 @@ SendProposerElected(Safekeeper *sk)
msg.timelineStartLsn = wp->timelineStartLsn;
lastCommonTerm = i >= 0 ? wp->propTermHistory.entries[i].term : 0;
walprop_log(LOG,
elog(LOG,
"sending elected msg to node " UINT64_FORMAT " term=" UINT64_FORMAT ", startStreamingAt=%X/%X (lastCommonTerm=" UINT64_FORMAT "), termHistory.n_entries=%u to %s:%s, timelineStartLsn=%X/%X",
sk->greetResponse.nodeId, msg.term, LSN_FORMAT_ARGS(msg.startStreamingAt), lastCommonTerm, msg.termHistory->n_entries, sk->host, sk->port, LSN_FORMAT_ARGS(msg.timelineStartLsn));
@@ -1354,12 +1340,13 @@ SendAppendRequests(Safekeeper *sk)
req = &sk->appendRequest;
PrepareAppendRequest(sk->wp, &sk->appendRequest, sk->streamingAt, endLsn);
walprop_log(DEBUG2, "sending message len %ld beginLsn=%X/%X endLsn=%X/%X commitLsn=%X/%X truncateLsn=%X/%X to %s:%s",
ereport(DEBUG2,
(errmsg("sending message len %ld beginLsn=%X/%X endLsn=%X/%X commitLsn=%X/%X truncateLsn=%X/%X to %s:%s",
req->endLsn - req->beginLsn,
LSN_FORMAT_ARGS(req->beginLsn),
LSN_FORMAT_ARGS(req->endLsn),
LSN_FORMAT_ARGS(req->commitLsn),
LSN_FORMAT_ARGS(wp->truncateLsn), sk->host, sk->port);
LSN_FORMAT_ARGS(wp->truncateLsn), sk->host, sk->port)));
resetStringInfo(&sk->outbuf);
@@ -1369,13 +1356,13 @@ SendAppendRequests(Safekeeper *sk)
/* write the WAL itself */
enlargeStringInfo(&sk->outbuf, req->endLsn - req->beginLsn);
/* wal_read will raise error on failure */
wp->api.wal_read(sk,
wp->api.wal_read(sk->xlogreader,
&sk->outbuf.data[sk->outbuf.len],
req->beginLsn,
req->endLsn - req->beginLsn);
sk->outbuf.len += req->endLsn - req->beginLsn;
writeResult = wp->api.conn_async_write(sk, sk->outbuf.data, sk->outbuf.len);
writeResult = wp->api.conn_async_write(sk->conn, sk->outbuf.data, sk->outbuf.len);
/* Mark current message as sent, whatever the result is */
sk->streamingAt = endLsn;
@@ -1397,9 +1384,9 @@ SendAppendRequests(Safekeeper *sk)
return true;
case PG_ASYNC_WRITE_FAIL:
walprop_log(WARNING, "Failed to send to node %s:%s in %s state: %s",
elog(WARNING, "Failed to send to node %s:%s in %s state: %s",
sk->host, sk->port, FormatSafekeeperState(sk->state),
wp->api.conn_error_message(sk));
wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return false;
default:
@@ -1437,16 +1424,17 @@ RecvAppendResponses(Safekeeper *sk)
if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->appendResponse))
break;
walprop_log(DEBUG2, "received message term=" INT64_FORMAT " flushLsn=%X/%X commitLsn=%X/%X from %s:%s",
ereport(DEBUG2,
(errmsg("received message term=" INT64_FORMAT " flushLsn=%X/%X commitLsn=%X/%X from %s:%s",
sk->appendResponse.term,
LSN_FORMAT_ARGS(sk->appendResponse.flushLsn),
LSN_FORMAT_ARGS(sk->appendResponse.commitLsn),
sk->host, sk->port);
sk->host, sk->port)));
if (sk->appendResponse.term > wp->propTerm)
{
/* Another compute with higher term is running. */
walprop_log(PANIC, "WAL acceptor %s:%s with term " INT64_FORMAT " rejected our request, our term " INT64_FORMAT "",
elog(PANIC, "WAL acceptor %s:%s with term " INT64_FORMAT " rejected our request, our term " INT64_FORMAT "",
sk->host, sk->port,
sk->appendResponse.term, wp->propTerm);
}
@@ -1474,7 +1462,7 @@ RecvAppendResponses(Safekeeper *sk)
/* Parse a PageserverFeedback message, or the PageserverFeedback part of an AppendResponse */
void
ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, PageserverFeedback *rf)
ParsePageserverFeedbackMessage(StringInfo reply_message, PageserverFeedback *rf)
{
uint8 nkeys;
int i;
@@ -1492,7 +1480,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
pq_getmsgint(reply_message, sizeof(int32));
/* read value length */
rf->currentClusterSize = pq_getmsgint64(reply_message);
walprop_log(DEBUG2, "ParsePageserverFeedbackMessage: current_timeline_size %lu",
elog(DEBUG2, "ParsePageserverFeedbackMessage: current_timeline_size %lu",
rf->currentClusterSize);
}
else if ((strcmp(key, "ps_writelsn") == 0) || (strcmp(key, "last_received_lsn") == 0))
@@ -1500,7 +1488,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
pq_getmsgint(reply_message, sizeof(int32));
/* read value length */
rf->last_received_lsn = pq_getmsgint64(reply_message);
walprop_log(DEBUG2, "ParsePageserverFeedbackMessage: last_received_lsn %X/%X",
elog(DEBUG2, "ParsePageserverFeedbackMessage: last_received_lsn %X/%X",
LSN_FORMAT_ARGS(rf->last_received_lsn));
}
else if ((strcmp(key, "ps_flushlsn") == 0) || (strcmp(key, "disk_consistent_lsn") == 0))
@@ -1508,7 +1496,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
pq_getmsgint(reply_message, sizeof(int32));
/* read value length */
rf->disk_consistent_lsn = pq_getmsgint64(reply_message);
walprop_log(DEBUG2, "ParsePageserverFeedbackMessage: disk_consistent_lsn %X/%X",
elog(DEBUG2, "ParsePageserverFeedbackMessage: disk_consistent_lsn %X/%X",
LSN_FORMAT_ARGS(rf->disk_consistent_lsn));
}
else if ((strcmp(key, "ps_applylsn") == 0) || (strcmp(key, "remote_consistent_lsn") == 0))
@@ -1516,7 +1504,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
pq_getmsgint(reply_message, sizeof(int32));
/* read value length */
rf->remote_consistent_lsn = pq_getmsgint64(reply_message);
walprop_log(DEBUG2, "ParsePageserverFeedbackMessage: remote_consistent_lsn %X/%X",
elog(DEBUG2, "ParsePageserverFeedbackMessage: remote_consistent_lsn %X/%X",
LSN_FORMAT_ARGS(rf->remote_consistent_lsn));
}
else if ((strcmp(key, "ps_replytime") == 0) || (strcmp(key, "replytime") == 0))
@@ -1529,7 +1517,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
/* Copy because timestamptz_to_str returns a static buffer */
replyTimeStr = pstrdup(timestamptz_to_str(rf->replytime));
walprop_log(DEBUG2, "ParsePageserverFeedbackMessage: replytime %lu reply_time: %s",
elog(DEBUG2, "ParsePageserverFeedbackMessage: replytime %lu reply_time: %s",
rf->replytime, replyTimeStr);
pfree(replyTimeStr);
@@ -1544,7 +1532,7 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese
* Skip unknown keys to support backward compatibile protocol
* changes
*/
walprop_log(LOG, "ParsePageserverFeedbackMessage: unknown key: %s len %d", key, len);
elog(LOG, "ParsePageserverFeedbackMessage: unknown key: %s len %d", key, len);
pq_getmsgbytes(reply_message, len);
};
}
@@ -1627,7 +1615,7 @@ HandleSafekeeperResponse(WalProposer *wp)
* Advance the replication slot to free up old WAL files. Note that
* slot doesn't exist if we are in syncSafekeepers mode.
*/
wp->api.confirm_wal_streamed(wp, wp->truncateLsn);
wp->api.confirm_wal_streamed(wp->truncateLsn);
}
/*
@@ -1674,7 +1662,7 @@ HandleSafekeeperResponse(WalProposer *wp)
*/
BroadcastAppendRequest(wp);
wp->api.finish_sync_safekeepers(wp, wp->propEpochStartLsn);
wp->api.finish_sync_safekeepers(wp->propEpochStartLsn);
/* unreachable */
}
}
@@ -1689,7 +1677,7 @@ AsyncRead(Safekeeper *sk, char **buf, int *buf_size)
{
WalProposer *wp = sk->wp;
switch (wp->api.conn_async_read(sk, buf, buf_size))
switch (wp->api.conn_async_read(sk->conn, buf, buf_size))
{
case PG_ASYNC_READ_SUCCESS:
return true;
@@ -1699,9 +1687,9 @@ AsyncRead(Safekeeper *sk, char **buf, int *buf_size)
return false;
case PG_ASYNC_READ_FAIL:
walprop_log(WARNING, "Failed to read from node %s:%s in %s state: %s", sk->host,
elog(WARNING, "Failed to read from node %s:%s in %s state: %s", sk->host,
sk->port, FormatSafekeeperState(sk->state),
wp->api.conn_error_message(sk));
wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return false;
}
@@ -1739,12 +1727,12 @@ AsyncReadMessage(Safekeeper *sk, AcceptorProposerMessage *anymsg)
tag = pq_getmsgint64_le(&s);
if (tag != anymsg->tag)
{
walprop_log(WARNING, "unexpected message tag %c from node %s:%s in state %s", (char) tag, sk->host,
elog(WARNING, "unexpected message tag %c from node %s:%s in state %s", (char) tag, sk->host,
sk->port, FormatSafekeeperState(sk->state));
ResetConnection(sk);
return false;
}
sk->latestMsgReceivedAt = wp->api.get_current_timestamp(wp);
sk->latestMsgReceivedAt = wp->api.get_current_timestamp();
switch (tag)
{
case 'g':
@@ -1788,7 +1776,7 @@ AsyncReadMessage(Safekeeper *sk, AcceptorProposerMessage *anymsg)
msg->hs.xmin.value = pq_getmsgint64_le(&s);
msg->hs.catalog_xmin.value = pq_getmsgint64_le(&s);
if (buf_size > APPENDRESPONSE_FIXEDPART_SIZE)
ParsePageserverFeedbackMessage(wp, &s, &msg->rf);
ParsePageserverFeedbackMessage(&s, &msg->rf);
pq_getmsgend(&s);
return true;
}
@@ -1813,11 +1801,11 @@ BlockingWrite(Safekeeper *sk, void *msg, size_t msg_size, SafekeeperState succes
WalProposer *wp = sk->wp;
uint32 events;
if (!wp->api.conn_blocking_write(sk, msg, msg_size))
if (!wp->api.conn_blocking_write(sk->conn, msg, msg_size))
{
walprop_log(WARNING, "Failed to send to node %s:%s in %s state: %s",
elog(WARNING, "Failed to send to node %s:%s in %s state: %s",
sk->host, sk->port, FormatSafekeeperState(sk->state),
wp->api.conn_error_message(sk));
wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return false;
}
@@ -1847,7 +1835,7 @@ AsyncWrite(Safekeeper *sk, void *msg, size_t msg_size, SafekeeperState flush_sta
{
WalProposer *wp = sk->wp;
switch (wp->api.conn_async_write(sk, msg, msg_size))
switch (wp->api.conn_async_write(sk->conn, msg, msg_size))
{
case PG_ASYNC_WRITE_SUCCESS:
return true;
@@ -1862,9 +1850,9 @@ AsyncWrite(Safekeeper *sk, void *msg, size_t msg_size, SafekeeperState flush_sta
wp->api.update_event_set(sk, WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE);
return false;
case PG_ASYNC_WRITE_FAIL:
walprop_log(WARNING, "Failed to send to node %s:%s in %s state: %s",
elog(WARNING, "Failed to send to node %s:%s in %s state: %s",
sk->host, sk->port, FormatSafekeeperState(sk->state),
wp->api.conn_error_message(sk));
wp->api.conn_error_message(sk->conn));
ShutdownConnection(sk);
return false;
default:
@@ -1892,7 +1880,7 @@ AsyncFlush(Safekeeper *sk)
* 1 if unable to send everything yet [call PQflush again]
* -1 if it failed [emit an error]
*/
switch (wp->api.conn_flush(sk))
switch (wp->api.conn_flush(sk->conn))
{
case 0:
/* flush is done */
@@ -1901,9 +1889,9 @@ AsyncFlush(Safekeeper *sk)
/* Nothing to do; try again when the socket's ready */
return false;
case -1:
walprop_log(WARNING, "Failed to flush write to node %s:%s in %s state: %s",
elog(WARNING, "Failed to flush write to node %s:%s in %s state: %s",
sk->host, sk->port, FormatSafekeeperState(sk->state),
wp->api.conn_error_message(sk));
wp->api.conn_error_message(sk->conn));
ResetConnection(sk);
return false;
default:
@@ -1932,11 +1920,11 @@ CompareLsn(const void *a, const void *b)
*
* The strings are intended to be used as a prefix to "state", e.g.:
*
* walprop_log(LOG, "currently in %s state", FormatSafekeeperState(sk->state));
* elog(LOG, "currently in %s state", FormatSafekeeperState(sk->state));
*
* If this sort of phrasing doesn't fit the message, instead use something like:
*
* walprop_log(LOG, "currently in state [%s]", FormatSafekeeperState(sk->state));
* elog(LOG, "currently in state [%s]", FormatSafekeeperState(sk->state));
*/
static char *
FormatSafekeeperState(SafekeeperState state)
@@ -1984,7 +1972,6 @@ FormatSafekeeperState(SafekeeperState state)
static void
AssertEventsOkForState(uint32 events, Safekeeper *sk)
{
WalProposer *wp = sk->wp;
uint32 expected = SafekeeperStateDesiredEvents(sk->state);
/*
@@ -2007,8 +1994,8 @@ AssertEventsOkForState(uint32 events, Safekeeper *sk)
* To give a descriptive message in the case of failure, we use elog
* and then an assertion that's guaranteed to fail.
*/
walprop_log(WARNING, "events %s mismatched for safekeeper %s:%s in state [%s]",
FormatEvents(wp, events), sk->host, sk->port, FormatSafekeeperState(sk->state));
elog(WARNING, "events %s mismatched for safekeeper %s:%s in state [%s]",
FormatEvents(events), sk->host, sk->port, FormatSafekeeperState(sk->state));
Assert(events_ok_for_state);
}
}
@@ -2081,7 +2068,7 @@ SafekeeperStateDesiredEvents(SafekeeperState state)
* The string should not be freed. It should also not be expected to remain the same between
* function calls. */
static char *
FormatEvents(WalProposer *wp, uint32 events)
FormatEvents(uint32 events)
{
static char return_str[8];
@@ -2110,7 +2097,7 @@ FormatEvents(WalProposer *wp, uint32 events)
if (events & (~all_flags))
{
walprop_log(WARNING, "Event formatting found unexpected component %d",
elog(WARNING, "Event formatting found unexpected component %d",
events & (~all_flags));
return_str[6] = '*';
return_str[7] = '\0';

View File

@@ -333,11 +333,24 @@ typedef struct Safekeeper
*/
char conninfo[MAXCONNINFO];
/*
* postgres protocol connection to the WAL acceptor
*
* Equals NULL only when state = SS_OFFLINE. Nonblocking is set once we
* reach SS_ACTIVE; not before.
*/
WalProposerConn *conn;
/*
* Temporary buffer for the message being sent to the safekeeper.
*/
StringInfoData outbuf;
/*
* WAL reader, allocated for each safekeeper.
*/
XLogReaderState *xlogreader;
/*
* Streaming will start here; must be record boundary.
*/
@@ -348,43 +361,13 @@ typedef struct Safekeeper
XLogRecPtr streamingAt; /* current streaming position */
AppendRequestHeader appendRequest; /* request for sending to safekeeper */
int eventPos; /* position in wait event set. Equal to -1 if*
* no event */
SafekeeperState state; /* safekeeper state machine state */
TimestampTz latestMsgReceivedAt; /* when latest msg is received */
AcceptorGreeting greetResponse; /* acceptor greeting */
VoteResponse voteResponse; /* the vote */
AppendResponse appendResponse; /* feedback for master */
/* postgres-specific fields */
#ifndef WALPROPOSER_LIB
/*
* postgres protocol connection to the WAL acceptor
*
* Equals NULL only when state = SS_OFFLINE. Nonblocking is set once we
* reach SS_ACTIVE; not before.
*/
WalProposerConn *conn;
/*
* WAL reader, allocated for each safekeeper.
*/
XLogReaderState *xlogreader;
/*
* Position in wait event set. Equal to -1 if no event
*/
int eventPos;
#endif
/* WalProposer library specifics */
#ifdef WALPROPOSER_LIB
/*
* Buffer for incoming messages. Usually Rust vector is stored here.
* Caller is responsible for freeing the buffer.
*/
StringInfoData inbuf;
#endif
} Safekeeper;
/* Re-exported PostgresPollingStatusType */
@@ -450,7 +433,7 @@ typedef struct walproposer_api
* Get WalproposerShmemState. This is used to store information about last
* elected term.
*/
WalproposerShmemState *(*get_shmem_state) (WalProposer *wp);
WalproposerShmemState *(*get_shmem_state) (void);
/*
* Start receiving notifications about new WAL. This is an infinite loop
@@ -460,63 +443,61 @@ typedef struct walproposer_api
void (*start_streaming) (WalProposer *wp, XLogRecPtr startpos);
/* Get pointer to the latest available WAL. */
XLogRecPtr (*get_flush_rec_ptr) (WalProposer *wp);
XLogRecPtr (*get_flush_rec_ptr) (void);
/* Get current time. */
TimestampTz (*get_current_timestamp) (WalProposer *wp);
TimestampTz (*get_current_timestamp) (void);
/* Get postgres timeline. */
TimeLineID (*get_timeline_id) (void);
/* Current error message, aka PQerrorMessage. */
char *(*conn_error_message) (Safekeeper *sk);
char *(*conn_error_message) (WalProposerConn *conn);
/* Connection status, aka PQstatus. */
WalProposerConnStatusType (*conn_status) (Safekeeper *sk);
WalProposerConnStatusType (*conn_status) (WalProposerConn *conn);
/* Start the connection, aka PQconnectStart. */
void (*conn_connect_start) (Safekeeper *sk);
WalProposerConn *(*conn_connect_start) (char *conninfo);
/* Poll an asynchronous connection, aka PQconnectPoll. */
WalProposerConnectPollStatusType (*conn_connect_poll) (Safekeeper *sk);
WalProposerConnectPollStatusType (*conn_connect_poll) (WalProposerConn *conn);
/* Send a blocking SQL query, aka PQsendQuery. */
bool (*conn_send_query) (Safekeeper *sk, char *query);
bool (*conn_send_query) (WalProposerConn *conn, char *query);
/* Read the query result, aka PQgetResult. */
WalProposerExecStatusType (*conn_get_query_result) (Safekeeper *sk);
WalProposerExecStatusType (*conn_get_query_result) (WalProposerConn *conn);
/* Flush buffer to the network, aka PQflush. */
int (*conn_flush) (Safekeeper *sk);
int (*conn_flush) (WalProposerConn *conn);
/* Close the connection, aka PQfinish. */
void (*conn_finish) (Safekeeper *sk);
void (*conn_finish) (WalProposerConn *conn);
/*
* Try to read CopyData message from the safekeeper, aka PQgetCopyData.
*
* On success, the data is placed in *buf. It is valid until the next call
* to this function.
*/
PGAsyncReadResult (*conn_async_read) (Safekeeper *sk, char **buf, int *amount);
/* Try to read CopyData message, aka PQgetCopyData. */
PGAsyncReadResult (*conn_async_read) (WalProposerConn *conn, char **buf, int *amount);
/* Try to write CopyData message, aka PQputCopyData. */
PGAsyncWriteResult (*conn_async_write) (Safekeeper *sk, void const *buf, size_t size);
PGAsyncWriteResult (*conn_async_write) (WalProposerConn *conn, void const *buf, size_t size);
/* Blocking CopyData write, aka PQputCopyData + PQflush. */
bool (*conn_blocking_write) (Safekeeper *sk, void const *buf, size_t size);
bool (*conn_blocking_write) (WalProposerConn *conn, void const *buf, size_t size);
/* Download WAL from startpos to endpos and make it available locally. */
bool (*recovery_download) (Safekeeper *sk, TimeLineID timeline, XLogRecPtr startpos, XLogRecPtr endpos);
/* Read WAL from disk to buf. */
void (*wal_read) (Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count);
void (*wal_read) (XLogReaderState *state, char *buf, XLogRecPtr startptr, Size count);
/* Allocate WAL reader. */
void (*wal_reader_allocate) (Safekeeper *sk);
XLogReaderState *(*wal_reader_allocate) (void);
/* Deallocate event set. */
void (*free_event_set) (WalProposer *wp);
void (*free_event_set) (void);
/* Initialize event set. */
void (*init_event_set) (WalProposer *wp);
void (*init_event_set) (int n_safekeepers);
/* Update events for an existing safekeeper connection. */
void (*update_event_set) (Safekeeper *sk, uint32 events);
@@ -532,22 +513,22 @@ typedef struct walproposer_api
* events mask to indicate events and sets sk to the safekeeper which has
* an event.
*/
int (*wait_event_set) (WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events);
int (*wait_event_set) (long timeout, Safekeeper **sk, uint32 *events);
/* Read random bytes. */
bool (*strong_random) (WalProposer *wp, void *buf, size_t len);
bool (*strong_random) (void *buf, size_t len);
/*
* Get a basebackup LSN. Used to cross-validate with the latest available
* LSN on the safekeepers.
*/
XLogRecPtr (*get_redo_start_lsn) (WalProposer *wp);
XLogRecPtr (*get_redo_start_lsn) (void);
/*
* Finish sync safekeepers with the given LSN. This function should not
* return and should exit the program.
*/
void (*finish_sync_safekeepers) (WalProposer *wp, XLogRecPtr lsn);
void (*finish_sync_safekeepers) (XLogRecPtr lsn);
/*
* Called after every new message from the safekeeper. Used to propagate
@@ -560,22 +541,7 @@ typedef struct walproposer_api
* Called on peer_horizon_lsn updates. Used to advance replication slot
* and to free up disk space by deleting unnecessary WAL.
*/
void (*confirm_wal_streamed) (WalProposer *wp, XLogRecPtr lsn);
/*
* Write a log message to the internal log processor. This is used only
* when walproposer is compiled as a library. Otherwise, all logging is
* handled by elog().
*/
void (*log_internal) (WalProposer *wp, int level, const char *line);
/*
* Called right after the proposer was elected, but before it started
* recovery and sent ProposerElected message to the safekeepers.
*
* Used by logical replication to update truncateLsn.
*/
void (*after_election) (WalProposer *wp);
void (*confirm_wal_streamed) (XLogRecPtr lsn);
} walproposer_api;
/*
@@ -624,13 +590,6 @@ typedef struct WalProposerConfig
/* Will be passed to safekeepers in greet request. */
uint64 systemId;
/* Will be passed to safekeepers in greet request. */
TimeLineID pgTimeline;
#ifdef WALPROPOSER_LIB
void *callback_data;
#endif
} WalProposerConfig;
@@ -707,16 +666,7 @@ extern WalProposer *WalProposerCreate(WalProposerConfig *config, walproposer_api
extern void WalProposerStart(WalProposer *wp);
extern void WalProposerBroadcast(WalProposer *wp, XLogRecPtr startpos, XLogRecPtr endpos);
extern void WalProposerPoll(WalProposer *wp);
extern void WalProposerFree(WalProposer *wp);
#define WPEVENT 1337 /* special log level for walproposer internal events */
#ifdef WALPROPOSER_LIB
void WalProposerLibLog(WalProposer *wp, int elevel, char *fmt, ...);
#define walprop_log(elevel, ...) WalProposerLibLog(wp, elevel, __VA_ARGS__)
#else
#define walprop_log(elevel, ...) elog(elevel, __VA_ARGS__)
#endif
extern void ParsePageserverFeedbackMessage(StringInfo reply_message,
PageserverFeedback *rf);
#endif /* __NEON_WALPROPOSER_H__ */

View File

@@ -1,192 +0,0 @@
/*
* Contains copied/adapted functions from libpq and some internal postgres functions.
* This is needed to avoid linking to full postgres server installation. This file
* is compiled as a part of libwalproposer static library.
*/
#include <stdio.h>
#include "walproposer.h"
#include "utils/datetime.h"
#include "miscadmin.h"
void ExceptionalCondition(const char *conditionName,
const char *fileName, int lineNumber)
{
fprintf(stderr, "ExceptionalCondition: %s:%d: %s\n",
fileName, lineNumber, conditionName);
fprintf(stderr, "aborting...\n");
exit(1);
}
void
pq_copymsgbytes(StringInfo msg, char *buf, int datalen)
{
if (datalen < 0 || datalen > (msg->len - msg->cursor))
ExceptionalCondition("insufficient data left in message", __FILE__, __LINE__);
memcpy(buf, &msg->data[msg->cursor], datalen);
msg->cursor += datalen;
}
/* --------------------------------
* pq_getmsgint - get a binary integer from a message buffer
*
* Values are treated as unsigned.
* --------------------------------
*/
unsigned int
pq_getmsgint(StringInfo msg, int b)
{
unsigned int result;
unsigned char n8;
uint16 n16;
uint32 n32;
switch (b)
{
case 1:
pq_copymsgbytes(msg, (char *) &n8, 1);
result = n8;
break;
case 2:
pq_copymsgbytes(msg, (char *) &n16, 2);
result = pg_ntoh16(n16);
break;
case 4:
pq_copymsgbytes(msg, (char *) &n32, 4);
result = pg_ntoh32(n32);
break;
default:
fprintf(stderr, "unsupported integer size %d\n", b);
ExceptionalCondition("unsupported integer size", __FILE__, __LINE__);
result = 0; /* keep compiler quiet */
break;
}
return result;
}
/* --------------------------------
* pq_getmsgint64 - get a binary 8-byte int from a message buffer
*
* It is tempting to merge this with pq_getmsgint, but we'd have to make the
* result int64 for all data widths --- that could be a big performance
* hit on machines where int64 isn't efficient.
* --------------------------------
*/
int64
pq_getmsgint64(StringInfo msg)
{
uint64 n64;
pq_copymsgbytes(msg, (char *) &n64, sizeof(n64));
return pg_ntoh64(n64);
}
/* --------------------------------
* pq_getmsgbyte - get a raw byte from a message buffer
* --------------------------------
*/
int
pq_getmsgbyte(StringInfo msg)
{
if (msg->cursor >= msg->len)
ExceptionalCondition("no data left in message", __FILE__, __LINE__);
return (unsigned char) msg->data[msg->cursor++];
}
/* --------------------------------
* pq_getmsgbytes - get raw data from a message buffer
*
* Returns a pointer directly into the message buffer; note this
* may not have any particular alignment.
* --------------------------------
*/
const char *
pq_getmsgbytes(StringInfo msg, int datalen)
{
const char *result;
if (datalen < 0 || datalen > (msg->len - msg->cursor))
ExceptionalCondition("insufficient data left in message", __FILE__, __LINE__);
result = &msg->data[msg->cursor];
msg->cursor += datalen;
return result;
}
/* --------------------------------
* pq_getmsgstring - get a null-terminated text string (with conversion)
*
* May return a pointer directly into the message buffer, or a pointer
* to a palloc'd conversion result.
* --------------------------------
*/
const char *
pq_getmsgstring(StringInfo msg)
{
char *str;
int slen;
str = &msg->data[msg->cursor];
/*
* It's safe to use strlen() here because a StringInfo is guaranteed to
* have a trailing null byte. But check we found a null inside the
* message.
*/
slen = strlen(str);
if (msg->cursor + slen >= msg->len)
ExceptionalCondition("invalid string in message", __FILE__, __LINE__);
msg->cursor += slen + 1;
return str;
}
/* --------------------------------
* pq_getmsgend - verify message fully consumed
* --------------------------------
*/
void
pq_getmsgend(StringInfo msg)
{
if (msg->cursor != msg->len)
ExceptionalCondition("invalid msg format", __FILE__, __LINE__);
}
/*
* Produce a C-string representation of a TimestampTz.
*
* This is mostly for use in emitting messages.
*/
const char *
timestamptz_to_str(TimestampTz t)
{
static char buf[MAXDATELEN + 1];
snprintf(buf, sizeof(buf), "TimestampTz(%ld)", t);
return buf;
}
bool
TimestampDifferenceExceeds(TimestampTz start_time,
TimestampTz stop_time,
int msec)
{
TimestampTz diff = stop_time - start_time;
return (diff >= msec * INT64CONST(1000));
}
void
WalProposerLibLog(WalProposer *wp, int elevel, char *fmt, ...)
{
char buf[1024];
va_list args;
fmt = _(fmt);
va_start(args, fmt);
vsnprintf(buf, sizeof(buf), fmt, args);
va_end(args);
wp->api.log_internal(wp, elevel, buf);
}

View File

@@ -73,8 +73,7 @@ static void walprop_register_bgworker(void);
static void walprop_pg_init_standalone_sync_safekeepers(void);
static void walprop_pg_init_walsender(void);
static void walprop_pg_init_bgworker(void);
static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp);
static TimeLineID walprop_pg_get_timeline_id(void);
static TimestampTz walprop_pg_get_current_timestamp(void);
static void walprop_pg_load_libpqwalreceiver(void);
static process_interrupts_callback_t PrevProcessInterruptsCallback;
@@ -105,7 +104,6 @@ init_walprop_config(bool syncSafekeepers)
walprop_config.systemId = GetSystemIdentifier();
else
walprop_config.systemId = 0;
walprop_config.pgTimeline = walprop_pg_get_timeline_id();
}
/*
@@ -138,7 +136,7 @@ WalProposerMain(Datum main_arg)
walprop_pg_load_libpqwalreceiver();
wp = WalProposerCreate(&walprop_config, walprop_pg);
wp->last_reconnect_attempt = walprop_pg_get_current_timestamp(wp);
wp->last_reconnect_attempt = walprop_pg_get_current_timestamp();
walprop_pg_init_walsender();
WalProposerStart(wp);
@@ -381,7 +379,7 @@ nwp_shmem_startup_hook(void)
}
static WalproposerShmemState *
walprop_pg_get_shmem_state(WalProposer *wp)
walprop_pg_get_shmem_state(void)
{
Assert(walprop_shared != NULL);
return walprop_shared;
@@ -507,7 +505,7 @@ walprop_pg_init_bgworker(void)
}
static XLogRecPtr
walprop_pg_get_flush_rec_ptr(WalProposer *wp)
walprop_pg_get_flush_rec_ptr(void)
{
#if PG_MAJORVERSION_NUM < 15
return GetFlushRecPtr();
@@ -517,7 +515,7 @@ walprop_pg_get_flush_rec_ptr(WalProposer *wp)
}
static TimestampTz
walprop_pg_get_current_timestamp(WalProposer *wp)
walprop_pg_get_current_timestamp(void)
{
return GetCurrentTimestamp();
}
@@ -567,15 +565,15 @@ ensure_nonblocking_status(WalProposerConn *conn, bool is_nonblocking)
/* Exported function definitions */
static char *
walprop_error_message(Safekeeper *sk)
walprop_error_message(WalProposerConn *conn)
{
return PQerrorMessage(sk->conn->pg_conn);
return PQerrorMessage(conn->pg_conn);
}
static WalProposerConnStatusType
walprop_status(Safekeeper *sk)
walprop_status(WalProposerConn *conn)
{
switch (PQstatus(sk->conn->pg_conn))
switch (PQstatus(conn->pg_conn))
{
case CONNECTION_OK:
return WP_CONNECTION_OK;
@@ -586,17 +584,16 @@ walprop_status(Safekeeper *sk)
}
}
static void
walprop_connect_start(Safekeeper *sk)
static WalProposerConn *
walprop_connect_start(char *conninfo)
{
WalProposerConn *conn;
PGconn *pg_conn;
const char *keywords[3];
const char *values[3];
int n;
char *password = neon_auth_token;
Assert(sk->conn == NULL);
/*
* Connect using the given connection string. If the NEON_AUTH_TOKEN
* environment variable was set, use that as the password.
@@ -614,7 +611,7 @@ walprop_connect_start(Safekeeper *sk)
n++;
}
keywords[n] = "dbname";
values[n] = sk->conninfo;
values[n] = conninfo;
n++;
keywords[n] = NULL;
values[n] = NULL;
@@ -622,11 +619,11 @@ walprop_connect_start(Safekeeper *sk)
pg_conn = PQconnectStartParams(keywords, values, 1);
/*
* "If the result is null, then libpq has been unable to allocate a new
* PGconn structure"
* Allocation of a PQconn can fail, and will return NULL. We want to fully
* replicate the behavior of PQconnectStart here.
*/
if (!pg_conn)
elog(FATAL, "failed to allocate new PGconn object");
return NULL;
/*
* And in theory this allocation can fail as well, but it's incredibly
@@ -635,19 +632,20 @@ walprop_connect_start(Safekeeper *sk)
* palloc will exit on failure though, so there's not much we could do if
* it *did* fail.
*/
sk->conn = palloc(sizeof(WalProposerConn));
sk->conn->pg_conn = pg_conn;
sk->conn->is_nonblocking = false; /* connections always start in blocking
conn = palloc(sizeof(WalProposerConn));
conn->pg_conn = pg_conn;
conn->is_nonblocking = false; /* connections always start in blocking
* mode */
sk->conn->recvbuf = NULL;
conn->recvbuf = NULL;
return conn;
}
static WalProposerConnectPollStatusType
walprop_connect_poll(Safekeeper *sk)
walprop_connect_poll(WalProposerConn *conn)
{
WalProposerConnectPollStatusType return_val;
switch (PQconnectPoll(sk->conn->pg_conn))
switch (PQconnectPoll(conn->pg_conn))
{
case PGRES_POLLING_FAILED:
return_val = WP_CONN_POLLING_FAILED;
@@ -684,24 +682,24 @@ walprop_connect_poll(Safekeeper *sk)
}
static bool
walprop_send_query(Safekeeper *sk, char *query)
walprop_send_query(WalProposerConn *conn, char *query)
{
/*
* We need to be in blocking mode for sending the query to run without
* requiring a call to PQflush
*/
if (!ensure_nonblocking_status(sk->conn, false))
if (!ensure_nonblocking_status(conn, false))
return false;
/* PQsendQuery returns 1 on success, 0 on failure */
if (!PQsendQuery(sk->conn->pg_conn, query))
if (!PQsendQuery(conn->pg_conn, query))
return false;
return true;
}
static WalProposerExecStatusType
walprop_get_query_result(Safekeeper *sk)
walprop_get_query_result(WalProposerConn *conn)
{
PGresult *result;
WalProposerExecStatusType return_val;
@@ -710,14 +708,14 @@ walprop_get_query_result(Safekeeper *sk)
char *unexpected_success = NULL;
/* Consume any input that we might be missing */
if (!PQconsumeInput(sk->conn->pg_conn))
if (!PQconsumeInput(conn->pg_conn))
return WP_EXEC_FAILED;
if (PQisBusy(sk->conn->pg_conn))
if (PQisBusy(conn->pg_conn))
return WP_EXEC_NEEDS_INPUT;
result = PQgetResult(sk->conn->pg_conn);
result = PQgetResult(conn->pg_conn);
/*
* PQgetResult returns NULL only if getting the result was successful &
@@ -779,28 +777,24 @@ walprop_get_query_result(Safekeeper *sk)
}
static pgsocket
walprop_socket(Safekeeper *sk)
walprop_socket(WalProposerConn *conn)
{
return PQsocket(sk->conn->pg_conn);
return PQsocket(conn->pg_conn);
}
static int
walprop_flush(Safekeeper *sk)
walprop_flush(WalProposerConn *conn)
{
return (PQflush(sk->conn->pg_conn));
return (PQflush(conn->pg_conn));
}
static void
walprop_finish(Safekeeper *sk)
walprop_finish(WalProposerConn *conn)
{
if (!sk->conn)
return;
if (sk->conn->recvbuf != NULL)
PQfreemem(sk->conn->recvbuf);
PQfinish(sk->conn->pg_conn);
pfree(sk->conn);
sk->conn = NULL;
if (conn->recvbuf != NULL)
PQfreemem(conn->recvbuf);
PQfinish(conn->pg_conn);
pfree(conn);
}
/*
@@ -810,18 +804,18 @@ walprop_finish(Safekeeper *sk)
* to this function.
*/
static PGAsyncReadResult
walprop_async_read(Safekeeper *sk, char **buf, int *amount)
walprop_async_read(WalProposerConn *conn, char **buf, int *amount)
{
int result;
if (sk->conn->recvbuf != NULL)
if (conn->recvbuf != NULL)
{
PQfreemem(sk->conn->recvbuf);
sk->conn->recvbuf = NULL;
PQfreemem(conn->recvbuf);
conn->recvbuf = NULL;
}
/* Call PQconsumeInput so that we have the data we need */
if (!PQconsumeInput(sk->conn->pg_conn))
if (!PQconsumeInput(conn->pg_conn))
{
*amount = 0;
*buf = NULL;
@@ -839,7 +833,7 @@ walprop_async_read(Safekeeper *sk, char **buf, int *amount)
* sometimes be triggered by the server returning an ErrorResponse (which
* also happens to have the effect that the copy is done).
*/
switch (result = PQgetCopyData(sk->conn->pg_conn, &sk->conn->recvbuf, true))
switch (result = PQgetCopyData(conn->pg_conn, &conn->recvbuf, true))
{
case 0:
*amount = 0;
@@ -854,7 +848,7 @@ walprop_async_read(Safekeeper *sk, char **buf, int *amount)
* We can check PQgetResult to make sure that the server
* failed; it'll always result in PGRES_FATAL_ERROR
*/
ExecStatusType status = PQresultStatus(PQgetResult(sk->conn->pg_conn));
ExecStatusType status = PQresultStatus(PQgetResult(conn->pg_conn));
if (status != PGRES_FATAL_ERROR)
elog(FATAL, "unexpected result status %d after failed PQgetCopyData", status);
@@ -875,18 +869,18 @@ walprop_async_read(Safekeeper *sk, char **buf, int *amount)
default:
/* Positive values indicate the size of the returned result */
*amount = result;
*buf = sk->conn->recvbuf;
*buf = conn->recvbuf;
return PG_ASYNC_READ_SUCCESS;
}
}
static PGAsyncWriteResult
walprop_async_write(Safekeeper *sk, void const *buf, size_t size)
walprop_async_write(WalProposerConn *conn, void const *buf, size_t size)
{
int result;
/* If we aren't in non-blocking mode, switch to it. */
if (!ensure_nonblocking_status(sk->conn, true))
if (!ensure_nonblocking_status(conn, true))
return PG_ASYNC_WRITE_FAIL;
/*
@@ -894,7 +888,7 @@ walprop_async_write(Safekeeper *sk, void const *buf, size_t size)
* queued, 0 if it was not queued because of full buffers, or -1 if an
* error occurred
*/
result = PQputCopyData(sk->conn->pg_conn, buf, size);
result = PQputCopyData(conn->pg_conn, buf, size);
/*
* We won't get a result of zero because walproposer always empties the
@@ -922,7 +916,7 @@ walprop_async_write(Safekeeper *sk, void const *buf, size_t size)
* sucessful, 1 if it was unable to send all the data in the send queue
* yet -1 if it failed for some reason
*/
switch (result = PQflush(sk->conn->pg_conn))
switch (result = PQflush(conn->pg_conn))
{
case 0:
return PG_ASYNC_WRITE_SUCCESS;
@@ -940,22 +934,22 @@ walprop_async_write(Safekeeper *sk, void const *buf, size_t size)
* information, refer to the comments there.
*/
static bool
walprop_blocking_write(Safekeeper *sk, void const *buf, size_t size)
walprop_blocking_write(WalProposerConn *conn, void const *buf, size_t size)
{
int result;
/* If we are in non-blocking mode, switch out of it. */
if (!ensure_nonblocking_status(sk->conn, false))
if (!ensure_nonblocking_status(conn, false))
return false;
if ((result = PQputCopyData(sk->conn->pg_conn, buf, size)) == -1)
if ((result = PQputCopyData(conn->pg_conn, buf, size)) == -1)
return false;
Assert(result == 1);
/* Because the connection is non-blocking, flushing returns 0 or -1 */
if ((result = PQflush(sk->conn->pg_conn)) == -1)
if ((result = PQflush(conn->pg_conn)) == -1)
return false;
Assert(result == 0);
@@ -1387,11 +1381,11 @@ XLogWalPropClose(XLogRecPtr recptr)
}
static void
walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count)
walprop_pg_wal_read(XLogReaderState *state, char *buf, XLogRecPtr startptr, Size count)
{
WALReadError errinfo;
if (!WALRead(sk->xlogreader,
if (!WALRead(state,
buf,
startptr,
count,
@@ -1402,38 +1396,31 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count)
}
}
static void
walprop_pg_wal_reader_allocate(Safekeeper *sk)
static XLogReaderState *
walprop_pg_wal_reader_allocate(void)
{
sk->xlogreader = XLogReaderAllocate(wal_segment_size, NULL, XL_ROUTINE(.segment_open = wal_segment_open,.segment_close = wal_segment_close), NULL);
if (sk->xlogreader == NULL)
elog(FATAL, "Failed to allocate xlog reader");
return XLogReaderAllocate(wal_segment_size, NULL, XL_ROUTINE(.segment_open = wal_segment_open,.segment_close = wal_segment_close), NULL);
}
static WaitEventSet *waitEvents;
static void
walprop_pg_free_event_set(WalProposer *wp)
walprop_pg_free_event_set(void)
{
if (waitEvents)
{
FreeWaitEventSet(waitEvents);
waitEvents = NULL;
}
for (int i = 0; i < wp->n_safekeepers; i++)
{
wp->safekeeper[i].eventPos = -1;
}
}
static void
walprop_pg_init_event_set(WalProposer *wp)
walprop_pg_init_event_set(int n_safekeepers)
{
if (waitEvents)
elog(FATAL, "double-initialization of event set");
waitEvents = CreateWaitEventSet(TopMemoryContext, 2 + wp->n_safekeepers);
waitEvents = CreateWaitEventSet(TopMemoryContext, 2 + n_safekeepers);
AddWaitEventToSet(waitEvents, WL_LATCH_SET, PGINVALID_SOCKET,
MyLatch, NULL);
AddWaitEventToSet(waitEvents, WL_EXIT_ON_PM_DEATH, PGINVALID_SOCKET,
@@ -1452,11 +1439,11 @@ walprop_pg_update_event_set(Safekeeper *sk, uint32 events)
static void
walprop_pg_add_safekeeper_event_set(Safekeeper *sk, uint32 events)
{
sk->eventPos = AddWaitEventToSet(waitEvents, events, walprop_socket(sk), NULL, sk);
sk->eventPos = AddWaitEventToSet(waitEvents, events, walprop_socket(sk->conn), NULL, sk);
}
static int
walprop_pg_wait_event_set(WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events)
walprop_pg_wait_event_set(long timeout, Safekeeper **sk, uint32 *events)
{
WaitEvent event = {0};
int rc = 0;
@@ -1512,7 +1499,7 @@ walprop_pg_wait_event_set(WalProposer *wp, long timeout, Safekeeper **sk, uint32
}
static void
walprop_pg_finish_sync_safekeepers(WalProposer *wp, XLogRecPtr lsn)
walprop_pg_finish_sync_safekeepers(XLogRecPtr lsn)
{
fprintf(stdout, "%X/%X\n", LSN_FORMAT_ARGS(lsn));
exit(0);
@@ -1624,7 +1611,7 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, XLogRecPtr commitLsn)
* pageserver.
*/
quorumFeedback.rf.disk_consistent_lsn,
walprop_pg_get_current_timestamp(wp), false);
walprop_pg_get_current_timestamp(), false);
}
CombineHotStanbyFeedbacks(&hsFeedback, wp);
@@ -1641,65 +1628,18 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, XLogRecPtr commitLsn)
}
static void
walprop_pg_confirm_wal_streamed(WalProposer *wp, XLogRecPtr lsn)
walprop_pg_confirm_wal_streamed(XLogRecPtr lsn)
{
if (MyReplicationSlot)
PhysicalConfirmReceivedLocation(lsn);
}
static XLogRecPtr
walprop_pg_get_redo_start_lsn(WalProposer *wp)
{
return GetRedoStartLsn();
}
static bool
walprop_pg_strong_random(WalProposer *wp, void *buf, size_t len)
{
return pg_strong_random(buf, len);
}
static void
walprop_pg_log_internal(WalProposer *wp, int level, const char *line)
{
elog(FATAL, "unexpected log_internal message at level %d: %s", level, line);
}
static void
walprop_pg_after_election(WalProposer *wp)
{
FILE* f;
XLogRecPtr lrRestartLsn;
/* We don't need to do anything in syncSafekeepers mode.*/
if (wp->config->syncSafekeepers)
return;
/*
* If there are active logical replication subscription we need
* to provide enough WAL for their WAL senders based on th position
* of their replication slots.
*/
f = fopen("restart.lsn", "rb");
if (f != NULL && !wp->config->syncSafekeepers)
{
fread(&lrRestartLsn, sizeof(lrRestartLsn), 1, f);
fclose(f);
if (lrRestartLsn != InvalidXLogRecPtr)
{
elog(LOG, "Logical replication restart LSN %X/%X", LSN_FORMAT_ARGS(lrRestartLsn));
/* start from the beginning of the segment to fetch page headers verifed by XLogReader */
lrRestartLsn = lrRestartLsn - XLogSegmentOffset(lrRestartLsn, wal_segment_size);
wp->truncateLsn = Min(wp->truncateLsn, lrRestartLsn);
}
}
}
static const walproposer_api walprop_pg = {
.get_shmem_state = walprop_pg_get_shmem_state,
.start_streaming = walprop_pg_start_streaming,
.get_flush_rec_ptr = walprop_pg_get_flush_rec_ptr,
.get_current_timestamp = walprop_pg_get_current_timestamp,
.get_timeline_id = walprop_pg_get_timeline_id,
.conn_error_message = walprop_error_message,
.conn_status = walprop_status,
.conn_connect_start = walprop_connect_start,
@@ -1719,11 +1659,9 @@ static const walproposer_api walprop_pg = {
.update_event_set = walprop_pg_update_event_set,
.add_safekeeper_event_set = walprop_pg_add_safekeeper_event_set,
.wait_event_set = walprop_pg_wait_event_set,
.strong_random = walprop_pg_strong_random,
.get_redo_start_lsn = walprop_pg_get_redo_start_lsn,
.strong_random = pg_strong_random,
.get_redo_start_lsn = GetRedoStartLsn,
.finish_sync_safekeepers = walprop_pg_finish_sync_safekeepers,
.process_safekeeper_feedback = walprop_pg_process_safekeeper_feedback,
.confirm_wal_streamed = walprop_pg_confirm_wal_streamed,
.log_internal = walprop_pg_log_internal,
.after_election = walprop_pg_after_election,
};

16
poetry.lock generated
View File

@@ -2415,13 +2415,13 @@ files = [
[[package]]
name = "urllib3"
version = "1.26.18"
version = "1.26.17"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
files = [
{file = "urllib3-1.26.18-py2.py3-none-any.whl", hash = "sha256:34b97092d7e0a3a8cf7cd10e386f401b3737364026c45e622aa02903dffe0f07"},
{file = "urllib3-1.26.18.tar.gz", hash = "sha256:f8ecc1bba5667413457c529ab955bf8c67b45db799d159066261719e328580a0"},
{file = "urllib3-1.26.17-py2.py3-none-any.whl", hash = "sha256:94a757d178c9be92ef5539b8840d48dc9cf1b2709c9d6b588232a055c524458b"},
{file = "urllib3-1.26.17.tar.gz", hash = "sha256:24d6a242c28d29af46c3fae832c36db3bbebcc533dd1bb549172cd739c82df21"},
]
[package.extras]
@@ -2488,16 +2488,6 @@ files = [
{file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"},
{file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"},
{file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"},
{file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"},
{file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"},
{file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"},
{file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"},
{file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"},
{file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"},

View File

@@ -3,12 +3,9 @@ mod hacks;
mod link;
pub use link::LinkAuthError;
use tokio_postgres::config::AuthKeys;
use crate::proxy::{handle_try_wake, retry_after};
use crate::{
auth::{self, ClientCredentials},
config::AuthenticationConfig,
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
@@ -18,9 +15,8 @@ use crate::{
};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use tracing::info;
/// A product of successful authentication.
pub struct AuthSuccess<T> {
@@ -120,27 +116,21 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
}
}
pub enum ComputeCredentials {
Password(Vec<u8>),
AuthKeys(AuthKeys),
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
async fn auth_quirks_creds(
async fn auth_quirks(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
if creds.project.is_none() {
// Password will be checked by the compute node later.
return hacks::password_hack(creds, client).await;
return hacks::password_hack(api, extra, creds, client).await;
}
// Password hack should set the project name.
@@ -151,53 +141,11 @@ async fn auth_quirks_creds(
// Currently, we use it for websocket connections (latency).
if allow_cleartext {
// Password will be checked by the compute node later.
return hacks::cleartext_hack(client).await;
return hacks::cleartext_hack(api, extra, creds, client).await;
}
// Finally, proceed with the main auth flow (SCRAM-based).
classic::authenticate(api, extra, creds, client, config).await
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
async fn auth_quirks(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
let auth_stuff = auth_quirks_creds(api, extra, creds, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(extra, creds).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(e.into());
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
match auth_stuff.value {
ComputeCredentials::Password(password) => node.config.password(password),
ComputeCredentials::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
Ok(AuthSuccess {
reported_auth_ok: auth_stuff.reported_auth_ok,
value: node,
})
classic::authenticate(api, extra, creds, client).await
}
impl BackendType<'_, ClientCredentials<'_>> {
@@ -232,7 +180,6 @@ impl BackendType<'_, ClientCredentials<'_>> {
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
use BackendType::*;
@@ -245,7 +192,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
);
let api = api.as_ref();
auth_quirks(api, extra, creds, client, allow_cleartext, config).await?
auth_quirks(api, extra, creds, client, allow_cleartext).await?
}
Postgres(api, creds) => {
info!(
@@ -255,7 +202,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
);
let api = api.as_ref();
auth_quirks(api, extra, creds, client, allow_cleartext, config).await?
auth_quirks(api, extra, creds, client, allow_cleartext).await?
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {

View File

@@ -1,22 +1,23 @@
use super::{AuthSuccess, ComputeCredentials};
use std::ops::ControlFlow;
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
config::AuthenticationConfig,
console::{self, AuthInfo, ConsoleReqExtra},
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
proxy::{handle_try_wake, retry_after},
sasl, scram,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
use tracing::{error, info, warn};
pub(super) async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
config: &'static AuthenticationConfig,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
@@ -41,16 +42,7 @@ pub(super) async fn authenticate(
error
})?;
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
auth_flow.authenticate(),
)
.await
.map_err(|error| {
warn!("error processing scram messages error = authentication timed out, execution time exeeded {} seconds", config.scram_protocol_timeout.as_secs());
auth::io::Error::new(auth::io::ErrorKind::TimedOut, error)
})?
.map_err(|error| {
let auth_outcome = auth_flow.authenticate().await.map_err(|error| {
warn!(?error, "error processing scram messages");
error
})?;
@@ -63,17 +55,38 @@ pub(super) async fn authenticate(
}
};
compute::ScramKeys {
Some(compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
}
})
}
};
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(extra, creds).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(e.into());
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
Ok(AuthSuccess {
reported_auth_ok: false,
value: ComputeCredentials::AuthKeys(tokio_postgres::config::AuthKeys::ScramSha256(
scram_keys,
)),
value: node,
})
}

View File

@@ -1,6 +1,10 @@
use super::{AuthSuccess, ComputeCredentials};
use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
console::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
},
stream,
};
use tokio::io::{AsyncRead, AsyncWrite};
@@ -11,8 +15,11 @@ use tracing::{info, warn};
/// These properties are benefical for serverless JS workers, so we
/// use this mechanism for websocket connections.
pub async fn cleartext_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client)
.begin(auth::CleartextPassword)
@@ -20,19 +27,24 @@ pub async fn cleartext_hack(
.authenticate()
.await?;
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: ComputeCredentials::Password(password),
value: node,
})
}
/// Workaround for clients which don't provide an endpoint (project) name.
/// Very similar to [`cleartext_hack`], but there's a specific password format.
pub async fn password_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<ComputeCredentials>> {
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
@@ -43,9 +55,12 @@ pub async fn password_hack(
info!(project = &payload.endpoint, "received missing parameter");
creds.project = Some(payload.endpoint);
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
reported_auth_ok: false,
value: ComputeCredentials::Password(payload.password),
value: node,
})
}

View File

@@ -1,7 +1,5 @@
use futures::future::Either;
use proxy::auth;
use proxy::config::AuthenticationConfig;
use proxy::config::HttpConfig;
use proxy::console;
use proxy::http;
use proxy::metrics;
@@ -81,15 +79,6 @@ struct ProxyCliArgs {
/// Allow self-signed certificates for compute nodes (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
allow_self_signed_compute: bool,
/// timeout for http connections
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
sql_over_http_timeout: tokio::time::Duration,
/// timeout for scram authentication protocol
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
scram_protocol_timeout: tokio::time::Duration,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
}
#[tokio::main]
@@ -231,20 +220,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
auth::BackendType::Link(Cow::Owned(url))
}
};
let http_config = HttpConfig {
sql_over_http_timeout: args.sql_over_http_timeout,
};
let authentication_config = AuthenticationConfig {
scram_protocol_timeout: args.scram_protocol_timeout,
};
let config = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend,
metric_collection,
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
require_client_ip: args.require_client_ip,
}));
Ok(config)

View File

@@ -1,5 +1,5 @@
use anyhow::{bail, Context};
use dashmap::DashMap;
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use pq_proto::CancelKeyData;
use std::net::SocketAddr;
use tokio::net::TcpStream;
@@ -8,7 +8,7 @@ use tracing::info;
/// Enables serving `CancelRequest`s.
#[derive(Default)]
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
pub struct CancelMap(parking_lot::RwLock<HashMap<CancelKeyData, Option<CancelClosure>>>);
impl CancelMap {
/// Cancel a running query for the corresponding connection.
@@ -16,6 +16,7 @@ impl CancelMap {
// NB: we should immediately release the lock after cloning the token.
let cancel_closure = self
.0
.read()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("query cancellation key not found: {key}"))?;
@@ -39,19 +40,15 @@ impl CancelMap {
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
bail!("query cancellation key already exists: {key}")
}
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
self.0
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.remove(&key);
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
@@ -62,12 +59,12 @@ impl CancelMap {
#[cfg(test)]
fn contains(&self, session: &Session) -> bool {
self.0.contains_key(&session.key)
self.0.read().contains_key(&session.key)
}
#[cfg(test)]
fn is_empty(&self) -> bool {
self.0.is_empty()
self.0.read().is_empty()
}
}
@@ -116,7 +113,10 @@ impl Session<'_> {
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
self.cancel_map
.0
.write()
.insert(self.key, Some(cancel_closure));
self.key
}

View File

@@ -13,9 +13,6 @@ pub struct ProxyConfig {
pub auth_backend: auth::BackendType<'static, ()>,
pub metric_collection: Option<MetricCollectionConfig>,
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub require_client_ip: bool,
}
#[derive(Debug)]
@@ -29,14 +26,6 @@ pub struct TlsConfig {
pub common_names: Option<HashSet<String>>,
}
pub struct HttpConfig {
pub sql_over_http_timeout: tokio::time::Duration,
}
pub struct AuthenticationConfig {
pub scram_protocol_timeout: tokio::time::Duration,
}
impl TlsConfig {
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
self.config.clone()

View File

@@ -89,14 +89,7 @@ pub mod errors {
Self::Console {
status: http::StatusCode::LOCKED,
ref text,
} => {
// written data quota exceeded
// data transfer quota exceeded
// compute time quota exceeded
// logical size quota exceeded
!text.contains("quota exceeded")
&& !text.contains("the limit for current plan reached")
}
} => !text.contains("quota"),
// retry server errors
Self::Console { status, .. } if status.is_server_error() => true,
_ => false,

View File

@@ -8,34 +8,30 @@ use pbkdf2::{
Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
use std::sync::atomic::{self, AtomicUsize};
use std::{collections::HashMap, sync::Arc};
use std::{
fmt,
task::{ready, Poll},
};
use std::{
ops::Deref,
sync::atomic::{self, AtomicUsize},
};
use tokio::time;
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus};
use tokio_postgres::AsyncMessage;
use crate::{
auth, console,
metrics::{Ids, MetricCounter, USAGE_METRICS},
proxy::{LatencyTimer, NUM_DB_CONNECTIONS_CLOSED_COUNTER, NUM_DB_CONNECTIONS_OPENED_COUNTER},
};
use crate::{compute, config};
use crate::proxy::ConnectMechanism;
use tracing::{error, warn, Span};
use tracing::{error, warn};
use tracing::{info, info_span, Instrument};
pub const APP_NAME: &str = "sql_over_http";
const MAX_CONNS_PER_ENDPOINT: usize = 20;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ConnInfo {
pub username: String,
pub dbname: String,
@@ -58,7 +54,7 @@ impl fmt::Display for ConnInfo {
}
struct ConnPoolEntry {
conn: ClientInner,
conn: Client,
_last_access: std::time::Instant,
}
@@ -136,19 +132,12 @@ impl GlobalConnPool {
}
pub async fn get(
self: &Arc<Self>,
&self,
conn_info: &ConnInfo,
force_new: bool,
session_id: uuid::Uuid,
) -> anyhow::Result<Client> {
let mut client: Option<ClientInner> = None;
let mut latency_timer = LatencyTimer::new("http");
let pool = if force_new {
None
} else {
Some((conn_info.clone(), self.clone()))
};
let mut client: Option<Client> = None;
let mut hash_valid = false;
if !force_new {
@@ -192,21 +181,15 @@ impl GlobalConnPool {
let new_client = if let Some(client) = client {
if client.inner.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
connect_to_compute(self.proxy_config, conn_info, session_id).await
} else {
info!("pool: reusing connection '{conn_info}'");
client.session.send(session_id)?;
latency_timer.pool_hit();
latency_timer.success();
return Ok(Client {
inner: Some(client),
span: Span::current(),
pool,
});
return Ok(client);
}
} else {
info!("pool: opening a new connection '{conn_info}'");
connect_to_compute(self.proxy_config, conn_info, session_id, latency_timer).await
connect_to_compute(self.proxy_config, conn_info, session_id).await
};
match &new_client {
@@ -242,14 +225,10 @@ impl GlobalConnPool {
_ => {}
}
new_client.map(|inner| Client {
inner: Some(inner),
span: Span::current(),
pool,
})
new_client
}
fn put(&self, conn_info: &ConnInfo, client: ClientInner) -> anyhow::Result<()> {
pub fn put(&self, conn_info: &ConnInfo, client: Client) -> anyhow::Result<()> {
// We want to hold this open while we return. This ensures that the pool can't close
// while we are in the middle of returning the connection.
let closed = self.closed.read();
@@ -344,7 +323,7 @@ struct TokioMechanism<'a> {
#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
type Connection = ClientInner;
type Connection = Client;
type ConnectError = tokio_postgres::Error;
type Error = anyhow::Error;
@@ -367,8 +346,7 @@ async fn connect_to_compute(
config: &config::ProxyConfig,
conn_info: &ConnInfo,
session_id: uuid::Uuid,
latency_timer: LatencyTimer,
) -> anyhow::Result<ClientInner> {
) -> anyhow::Result<Client> {
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
@@ -407,7 +385,6 @@ async fn connect_to_compute(
node_info,
&extra,
&creds,
latency_timer,
)
.await
}
@@ -417,7 +394,7 @@ async fn connect_to_compute_once(
conn_info: &ConnInfo,
timeout: time::Duration,
mut session: uuid::Uuid,
) -> Result<ClientInner, tokio_postgres::Error> {
) -> Result<Client, tokio_postgres::Error> {
let mut config = (*node_info.config).clone();
let (client, mut connection) = config
@@ -441,138 +418,54 @@ async fn connect_to_compute_once(
};
tokio::spawn(
async move {
NUM_DB_CONNECTIONS_OPENED_COUNTER.with_label_values(&["http"]).inc();
scopeguard::defer! {
NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
poll_fn(move |cx| {
if matches!(rx.has_changed(), Ok(true)) {
session = *rx.borrow_and_update();
info!(%session, "changed session");
}
poll_fn(move |cx| {
if matches!(rx.has_changed(), Ok(true)) {
session = *rx.borrow_and_update();
info!(%session, "changed session");
}
loop {
let message = ready!(connection.poll_message(cx));
loop {
let message = ready!(connection.poll_message(cx));
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session, "unknown message");
}
Some(Err(e)) => {
error!(%session, "connection error: {}", e);
return Poll::Ready(())
}
None => {
info!("connection closed");
return Poll::Ready(())
}
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session, "unknown message");
}
Some(Err(e)) => {
error!(%session, "connection error: {}", e);
return Poll::Ready(())
}
None => {
info!("connection closed");
return Poll::Ready(())
}
}
}).await
}
}
})
.instrument(span)
);
Ok(ClientInner {
Ok(Client {
inner: client,
session: tx,
ids,
})
}
struct ClientInner {
inner: tokio_postgres::Client,
pub struct Client {
pub inner: tokio_postgres::Client,
session: tokio::sync::watch::Sender<uuid::Uuid>,
ids: Ids,
}
impl Client {
pub fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(self.inner.as_ref().unwrap().ids.clone())
}
}
pub struct Client {
span: Span,
inner: Option<ClientInner>,
pool: Option<(ConnInfo, Arc<GlobalConnPool>)>,
}
pub struct Discard<'a> {
pool: &'a mut Option<(ConnInfo, Arc<GlobalConnPool>)>,
}
impl Client {
pub fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
let Self {
inner,
pool,
span: _,
} = self;
(
&mut inner
.as_mut()
.expect("client inner should not be removed")
.inner,
Discard { pool },
)
}
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
self.inner().1.check_idle(status)
}
pub fn discard(&mut self) {
self.inner().1.discard()
}
}
impl Discard<'_> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
if status != ReadyForQueryStatus::Idle {
if let Some((conn_info, _)) = self.pool.take() {
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
}
}
}
pub fn discard(&mut self) {
if let Some((conn_info, _)) = self.pool.take() {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
}
}
}
impl Deref for Client {
type Target = tokio_postgres::Client;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("client inner should not be removed")
.inner
}
}
impl Drop for Client {
fn drop(&mut self) {
let client = self
.inner
.take()
.expect("client inner should not be removed");
if let Some((conn_info, conn_pool)) = self.pool.take() {
let current_span = self.span.clone();
// return connection to the pool
tokio::task::spawn_blocking(move || {
let _span = current_span.enter();
let _ = conn_pool.put(&conn_info, client);
});
}
USAGE_METRICS.register(self.ids.clone())
}
}

View File

@@ -17,18 +17,13 @@ use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::ReadyForQueryStatus;
use tokio_postgres::Row;
use tokio_postgres::Transaction;
use tracing::error;
use tracing::instrument;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::config::HttpConfig;
use crate::proxy::{NUM_CONNECTIONS_ACCEPTED_COUNTER, NUM_CONNECTIONS_CLOSED_COUNTER};
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
@@ -66,18 +61,20 @@ static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
fn json_to_pg_text(json: Vec<Value>) -> Result<Vec<Option<String>>, serde_json::Error> {
json.iter()
.map(|value| {
match value {
// special care for nulls
Value::Null => None,
Value::Null => Ok(None),
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
Value::Bool(_) => serde_json::to_string(value).map(Some),
Value::Number(_) => serde_json::to_string(value).map(Some),
Value::Object(_) => serde_json::to_string(value).map(Some),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
Value::String(s) => Ok(Some(s.to_string())),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
@@ -94,26 +91,29 @@ fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
fn json_array_to_pg_array(value: &Value) -> Result<Option<String>, serde_json::Error> {
match value {
// special care for nulls
Value::Null => None,
Value::Null => Ok(None),
// convert to text with escaping
Value::Bool(_) => serde_json::to_string(value).map(Some),
Value::Number(_) => serde_json::to_string(value).map(Some),
Value::Object(_) => serde_json::to_string(value).map(Some),
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
Value::String(_) => serde_json::to_string(value).map(Some),
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.map(|r| r.map(|v| v.unwrap_or_else(|| "NULL".to_string())))
.collect::<Result<Vec<_>, _>>()?
.join(",");
Some(format!("{{{}}}", vals))
Ok(Some(format!("{{{}}}", vals)))
}
}
}
@@ -188,46 +188,28 @@ pub async fn handle(
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
session_id: uuid::Uuid,
config: &'static HttpConfig,
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.sql_over_http_timeout,
handle_inner(request, sni_hostname, conn_pool, session_id),
)
.await;
let result = handle_inner(request, sni_hostname, conn_pool, session_id).await;
let mut response = match result {
Ok(r) => match r {
Ok(r) => r,
Err(e) => {
let message = format!("{:?}", e);
let code = e.downcast_ref::<tokio_postgres::Error>().and_then(|e| {
e.code()
.map(|s| serde_json::to_value(s.code()).unwrap_or_default())
});
let code = match code {
Some(c) => c,
Ok(r) => r,
Err(e) => {
let message = format!("{:?}", e);
let code = match e.downcast_ref::<tokio_postgres::Error>() {
Some(e) => match e.code() {
Some(e) => serde_json::to_value(e.code()).unwrap(),
None => Value::Null,
};
error!(
?code,
"sql-over-http per-client task finished with an error: {e:#}"
);
// TODO: this shouldn't always be bad request.
json_response(
StatusCode::BAD_REQUEST,
json!({ "message": message, "code": code }),
)?
}
},
Err(_) => {
let message = format!(
"HTTP-Connection timed out, execution time exeeded {} seconds",
config.sql_over_http_timeout.as_secs()
},
None => Value::Null,
};
error!(
?code,
"sql-over-http per-client task finished with an error: {e:#}"
);
error!(message);
// TODO: this shouldn't always be bad request.
json_response(
StatusCode::GATEWAY_TIMEOUT,
json!({ "message": message, "code": StatusCode::GATEWAY_TIMEOUT.as_u16() }),
StatusCode::BAD_REQUEST,
json!({ "message": message, "code": code }),
)?
}
};
@@ -245,13 +227,6 @@ async fn handle_inner(
conn_pool: Arc<GlobalConnPool>,
session_id: uuid::Uuid,
) -> anyhow::Result<Response<Body>> {
NUM_CONNECTIONS_ACCEPTED_COUNTER
.with_label_values(&["http"])
.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&["http"]).inc();
}
//
// Determine the destination and connection params
//
@@ -312,119 +287,83 @@ async fn handle_inner(
// Now execute the query and return the result
//
let mut size = 0;
let result =
match payload {
Payload::Single(stmt) => {
let (status, results) =
query_to_json(&*client, stmt, &mut 0, raw_output, array_mode)
.await
.map_err(|e| {
client.discard();
e
})?;
client.check_idle(status);
results
let result = match payload {
Payload::Single(query) => {
query_to_json(&client.inner, query, &mut size, raw_output, array_mode).await
}
Payload::Batch(batch_query) => {
let mut results = Vec::new();
let mut builder = client.inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
Payload::Batch(statements) => {
let (inner, mut discard) = client.inner();
let mut builder = inner.build_transaction();
if let Some(isolation_level) = txn_isolation_level {
builder = builder.isolation_level(isolation_level);
}
if txn_read_only {
builder = builder.read_only(true);
}
if txn_deferrable {
builder = builder.deferrable(true);
}
let transaction = builder.start().await.map_err(|e| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
e
})?;
let results =
match query_batch(&transaction, statements, &mut size, raw_output, array_mode)
.await
{
Ok(results) => {
let status = transaction.commit().await.map_err(|e| {
// if we cannot commit - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
results
}
Err(err) => {
let status = transaction.rollback().await.map_err(|e| {
// if we cannot rollback - for now don't return connection to pool
// TODO: get a query status from the error
discard.discard();
e
})?;
discard.check_idle(status);
return Err(err);
}
};
if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
json!({ "results": results })
if txn_read_only {
builder = builder.read_only(true);
}
};
if txn_deferrable {
builder = builder.deferrable(true);
}
let transaction = builder.start().await?;
for query in batch_query.queries {
let result =
query_to_json(&transaction, query, &mut size, raw_output, array_mode).await;
match result {
Ok(r) => results.push(r),
Err(e) => {
transaction.rollback().await?;
return Err(e);
}
}
}
transaction.commit().await?;
if txn_read_only {
response = response.header(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
}
if txn_deferrable {
response = response.header(
TXN_DEFERRABLE.clone(),
HeaderValue::try_from(txn_deferrable.to_string())?,
);
}
if let Some(txn_isolation_level) = txn_isolation_level_raw {
response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
}
Ok(json!({ "results": results }))
}
};
let metrics = client.metrics();
// how could this possibly fail
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Body::from(body))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
// count the egress bytes - we miss the TLS and header overhead but oh well...
// moving this later in the stack is going to be a lot of effort and ehhhh
metrics.record_egress(len as u64);
Ok(response)
}
async fn query_batch(
transaction: &Transaction<'_>,
queries: BatchQueryData,
total_size: &mut usize,
raw_output: bool,
array_mode: bool,
) -> anyhow::Result<Vec<Value>> {
let mut results = Vec::with_capacity(queries.queries.len());
let mut current_size = 0;
for stmt in queries.queries {
// TODO: maybe we should check that the transaction bit is set here
let (_, values) =
query_to_json(transaction, stmt, &mut current_size, raw_output, array_mode).await?;
results.push(values);
if allow_pool {
let current_span = tracing::Span::current();
// return connection to the pool
tokio::task::spawn_blocking(move || {
let _span = current_span.enter();
let _ = conn_pool.put(&conn_info, client);
});
}
match result {
Ok(value) => {
// how could this possibly fail
let body = serde_json::to_string(&value).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Body::from(body))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
// count the egress bytes - we miss the TLS and header overhead but oh well...
// moving this later in the stack is going to be a lot of effort and ehhhh
metrics.record_egress(len as u64);
Ok(response)
}
Err(e) => Err(e),
}
*total_size += current_size;
Ok(results)
}
async fn query_to_json<T: GenericClient>(
@@ -433,9 +372,11 @@ async fn query_to_json<T: GenericClient>(
current_size: &mut usize,
raw_output: bool,
array_mode: bool,
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
let query_params = json_to_pg_text(data.params);
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
) -> anyhow::Result<Value> {
let query_params = json_to_pg_text(data.params)?;
let row_stream = client
.query_raw_txt::<String, _>(data.query, query_params)
.await?;
// Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too
@@ -455,8 +396,6 @@ async fn query_to_json<T: GenericClient>(
}
}
let ready = row_stream.ready_status();
// grab the command tag and number of rows affected
let command_tag = row_stream.command_tag().unwrap_or_default();
let mut command_tag_split = command_tag.split(' ');
@@ -497,16 +436,13 @@ async fn query_to_json<T: GenericClient>(
.collect::<Result<Vec<_>, _>>()?;
// resulting JSON format is based on the format of node-postgres result
Ok((
ready,
json!({
"command": command_tag_name,
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
}),
))
Ok(json!({
"command": command_tag_name,
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
}))
}
//
@@ -649,7 +585,7 @@ fn _pg_array_parse(
}
}
}
'}' if !quote => {
'}' => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
@@ -691,22 +627,22 @@ mod tests {
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text(json).unwrap();
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text(json).unwrap();
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text(json).unwrap();
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text(json).unwrap();
assert_eq!(pg_params, vec![None]);
}
@@ -715,7 +651,7 @@ mod tests {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text(vec![json]).unwrap();
assert_eq!(
pg_params,
vec![Some(
@@ -726,21 +662,13 @@ mod tests {
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text(vec![json]).unwrap();
assert_eq!(
pg_params,
vec![Some(
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
);
}
#[test]
@@ -868,23 +796,4 @@ mod tests {
json!([[[1, 2, 3], [4, 5, 6]]])
);
}
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(
pt(r#"{"{\"foo\": 1, \"bar\": 2}"}"#),
json!([{"foo": 1, "bar": 2}])
);
assert_eq!(
pt(r#"{"{\"foo\": 1}", "{\"bar\": 2}"}"#),
json!([{"foo": 1}, {"bar": 2}])
);
assert_eq!(
pt(r#"{{"{\"foo\": 1}", "{\"bar\": 2}"}}"#),
json!([[{"foo": 1}, {"bar": 2}]])
);
}
}

View File

@@ -3,12 +3,8 @@ use crate::{
config::ProxyConfig,
error::io_error,
protocol2::{ProxyProtocolAccept, WithClientIp},
proxy::{
handle_client, ClientMode, NUM_CLIENT_CONNECTION_CLOSED_COUNTER,
NUM_CLIENT_CONNECTION_OPENED_COUNTER,
},
proxy::{handle_client, ClientMode},
};
use anyhow::bail;
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
use hyper::{
@@ -23,6 +19,7 @@ use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use pin_project_lite::pin_project;
use std::{
convert::Infallible,
future::ready,
pin::Pin,
sync::Arc,
@@ -205,14 +202,7 @@ async fn ws_handler(
// TODO: that deserves a refactor as now this function also handles http json client besides websockets.
// Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead.
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
sql_over_http::handle(
request,
sni_hostname,
conn_pool,
session_id,
&config.http_config,
)
.await
sql_over_http::handle(request, sni_hostname, conn_pool, session_id).await
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
@@ -280,36 +270,28 @@ pub async fn task_main(
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<WithClientIp<AddrStream>>| {
let (io, tls) = stream.get_ref();
let client_addr = io.client_addr();
let remote_addr = io.inner.remote_addr();
let peer_addr = io.client_addr().unwrap_or(io.inner.remote_addr());
let sni_name = tls.server_name().map(|s| s.to_string());
let conn_pool = conn_pool.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(hyper::service::service_fn(
move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
Ok::<_, Infallible>(hyper::service::service_fn(move |req: Request<Body>| {
let sni_name = sni_name.clone();
let conn_pool = conn_pool.clone();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
async move {
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
.instrument(info_span!(
"ws-client",
session = %session_id,
%peer_addr,
))
.await
}
},
)))
ws_handler(req, config, conn_pool, cancel_map, session_id, sni_name)
.instrument(info_span!(
"ws-client",
session = %session_id,
%peer_addr,
))
.await
}
}))
}
},
);
@@ -321,41 +303,3 @@ pub async fn task_main(
Ok(())
}
struct MetricService<S> {
inner: S,
}
impl<S> MetricService<S> {
fn new(inner: S) -> MetricService<S> {
NUM_CLIENT_CONNECTION_OPENED_COUNTER
.with_label_values(&["http"])
.inc();
MetricService { inner }
}
}
impl<S> Drop for MetricService<S> {
fn drop(&mut self) {
NUM_CLIENT_CONNECTION_CLOSED_COUNTER
.with_label_values(&["http"])
.inc();
}
}
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
S: hyper::service::Service<Request<ReqBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.inner.call(req)
}
}

View File

@@ -5,9 +5,8 @@ use crate::{
auth::{self, backend::AuthSuccess},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
config::{ProxyConfig, TlsConfig},
console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
http::StatusCode,
metrics::{Ids, USAGE_METRICS},
protocol2::WithClientIp,
stream::{PqStream, Stream},
@@ -15,11 +14,12 @@ use crate::{
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
use metrics::{
exponential_buckets, register_histogram, register_int_counter_vec, Histogram, IntCounterVec,
};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use prometheus::{register_histogram_vec, HistogramVec};
use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
use std::{error::Error, io, ops::ControlFlow, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
@@ -38,121 +38,34 @@ const RETRY_WAIT_EXPONENT_BASE: f64 = std::f64::consts::SQRT_2;
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
pub static NUM_DB_CONNECTIONS_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_opened_db_connections_total",
"Number of opened connections to a database.",
&["protocol"],
)
.unwrap()
});
pub static NUM_DB_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_closed_db_connections_total",
"Number of closed connections to a database.",
&["protocol"],
)
.unwrap()
});
pub static NUM_CLIENT_CONNECTION_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_opened_client_connections_total",
"Number of opened connections from a client.",
&["protocol"],
)
.unwrap()
});
pub static NUM_CLIENT_CONNECTION_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_closed_client_connections_total",
"Number of closed connections from a client.",
&["protocol"],
)
.unwrap()
});
pub static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_accepted_connections_total",
"Number of client connections accepted.",
"Number of TCP client connections accepted.",
&["protocol"],
)
.unwrap()
});
pub static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_closed_connections_total",
"Number of client connections closed.",
"Number of TCP client connections closed.",
&["protocol"],
)
.unwrap()
});
static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
static COMPUTE_CONNECTION_LATENCY: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"proxy_compute_connection_latency_seconds",
"Time it took for proxy to establish a connection to the compute endpoint",
// http/ws/tcp, true/false, true/false, success/failure
// 3 * 2 * 2 * 2 = 24 counters
&["protocol", "cache_miss", "pool_miss", "outcome"],
// largest bucket = 2^16 * 0.5ms = 32s
exponential_buckets(0.0005, 2.0, 16).unwrap(),
)
.unwrap()
});
pub struct LatencyTimer {
start: Instant,
protocol: &'static str,
cache_miss: bool,
pool_miss: bool,
outcome: &'static str,
}
impl LatencyTimer {
pub fn new(protocol: &'static str) -> Self {
Self {
start: Instant::now(),
protocol,
cache_miss: false,
// by default we don't do pooling
pool_miss: true,
// assume failed unless otherwise specified
outcome: "failed",
}
}
pub fn cache_miss(&mut self) {
self.cache_miss = true;
}
pub fn pool_hit(&mut self) {
self.pool_miss = false;
}
pub fn success(mut self) {
self.outcome = "success";
}
}
impl Drop for LatencyTimer {
fn drop(&mut self) {
let duration = self.start.elapsed().as_secs_f64();
COMPUTE_CONNECTION_LATENCY
.with_label_values(&[
self.protocol,
bool_to_str(self.cache_miss),
bool_to_str(self.pool_miss),
self.outcome,
])
.observe(duration)
}
}
static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_total",
@@ -162,15 +75,6 @@ static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
static NUM_WAKEUP_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_connection_failures_breakdown",
"Number of wake-up failures (per kind).",
&["retry", "kind"],
)
.unwrap()
});
static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_io_bytes_per_client",
@@ -210,8 +114,6 @@ pub async fn task_main(
let mut socket = WithClientIp::new(socket);
if let Some(ip) = socket.wait_for_addr().await? {
tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
} else if config.require_client_ip {
bail!("missing required client IP");
}
socket
@@ -306,16 +208,12 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
"handling interactive connection from client"
);
let proto = mode.protocol_label();
NUM_CLIENT_CONNECTION_OPENED_COUNTER
.with_label_values(&[proto])
.inc();
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER
.with_label_values(&[proto])
.with_label_values(&[mode.protocol_label()])
.inc();
scopeguard::defer! {
NUM_CLIENT_CONNECTION_CLOSED_COUNTER.with_label_values(&[proto]).inc();
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[mode.protocol_label()]).inc();
}
let tls = config.tls_config.as_ref();
@@ -350,7 +248,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
mode.allow_self_signed_compute(config),
);
cancel_map
.with_session(|session| client.connect_to_db(session, mode, &config.authentication_config))
.with_session(|session| client.connect_to_db(session, mode.allow_cleartext()))
.await
}
@@ -499,46 +397,6 @@ impl ConnectMechanism for TcpMechanism<'_> {
}
}
const fn bool_to_str(x: bool) -> &'static str {
if x {
"true"
} else {
"false"
}
}
fn report_error(e: &WakeComputeError, retry: bool) {
use crate::console::errors::ApiError;
let retry = bool_to_str(retry);
let kind = match e {
WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
ref text,
}) if text.contains("written data quota exceeded")
|| text.contains("the limit for current plan reached") =>
{
"quota_exceeded"
}
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::LOCKED,
..
}) => "api_console_locked",
WakeComputeError::ApiError(ApiError::Console {
status: StatusCode::BAD_REQUEST,
..
}) => "api_console_bad_request",
WakeComputeError::ApiError(ApiError::Console { status, .. })
if status.is_server_error() =>
{
"api_console_other_server_error"
}
WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
};
NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}
/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it by `&mut`.
#[tracing::instrument(skip_all)]
@@ -547,28 +405,24 @@ pub async fn connect_to_compute<M: ConnectMechanism>(
mut node_info: console::CachedNodeInfo,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
mut latency_timer: LatencyTimer,
) -> Result<M::Connection, M::Error>
where
M::ConnectError: ShouldRetry + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let _timer = COMPUTE_CONNECTION_LATENCY.start_timer();
mechanism.update_connect_config(&mut node_info.config);
// try once
let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
Ok(res) => {
latency_timer.success();
return Ok(res);
}
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
(invalidate_cache(node_info), e)
}
};
latency_timer.cache_miss();
let mut num_retries = 1;
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
@@ -586,12 +440,10 @@ where
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
report_error(&e, false);
return Err(e.into());
}
// failed to wake up but we can continue to retry
Ok(ControlFlow::Continue(e)) => {
report_error(&e, true);
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
// successfully woke up a compute node and can break the wakeup loop
@@ -614,10 +466,7 @@ where
info!("wake_compute success. attempting to connect");
loop {
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
Ok(res) => {
latency_timer.success();
return Ok(res);
}
Ok(res) => return Ok(res),
Err(e) => {
let retriable = e.should_retry(num_retries);
if !retriable {
@@ -833,8 +682,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
mode: ClientMode,
config: &'static AuthenticationConfig,
allow_cleartext: bool,
) -> anyhow::Result<()> {
let Self {
mut stream,
@@ -849,10 +697,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
application_name: params.get("application_name"),
};
let latency_timer = LatencyTimer::new(mode.protocol_label());
let auth_result = match creds
.authenticate(&extra, &mut stream, mode.allow_cleartext(), config)
.authenticate(&extra, &mut stream, allow_cleartext)
.await
{
Ok(auth_result) => auth_result,
@@ -874,23 +720,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(
&TcpMechanism { params },
node_info,
&extra,
&creds,
latency_timer,
)
.or_else(|e| stream.throw_error(e))
.await?;
let proto = mode.protocol_label();
NUM_DB_CONNECTIONS_OPENED_COUNTER
.with_label_values(&[proto])
.inc();
scopeguard::defer! {
NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
}
let mut node = connect_to_compute(&TcpMechanism { params }, node_info, &extra, &creds)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the

View File

@@ -450,7 +450,7 @@ async fn connect_to_compute_success() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
mechanism.verify();
@@ -461,7 +461,7 @@ async fn connect_to_compute_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
mechanism.verify();
@@ -473,7 +473,7 @@ async fn connect_to_compute_non_retry_1() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap_err();
mechanism.verify();
@@ -485,7 +485,7 @@ async fn connect_to_compute_non_retry_2() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
mechanism.verify();
@@ -501,7 +501,7 @@ async fn connect_to_compute_non_retry_3() {
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap_err();
mechanism.verify();
@@ -513,7 +513,7 @@ async fn wake_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap();
mechanism.verify();
@@ -525,7 +525,7 @@ async fn wake_non_retry() {
use ConnectAction::*;
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
let (cache, extra, creds) = helper_create_connect_info(&mechanism);
connect_to_compute(&mechanism, cache, &extra, &creds, LatencyTimer::new("test"))
connect_to_compute(&mechanism, cache, &extra, &creds)
.await
.unwrap_err();
mechanism.verify();

Some files were not shown because too many files have changed in this diff Show More