Compare commits

..

20 Commits

Author SHA1 Message Date
Conrad Ludgate
f3f7d0d3f1 zero-copy jwt claim validation 2024-09-30 12:47:07 +01:00
Conrad Ludgate
0724df1d3f stash 2024-09-29 20:29:26 +01:00
Conrad Ludgate
4d47049b00 split up jwt tests 2024-09-27 16:31:49 +01:00
Conrad Ludgate
5687384a8e remove deref impl 2024-09-27 11:43:34 +01:00
Conrad Ludgate
249f5ea17d cleaner local-proxy conn error code 2024-09-27 11:43:34 +01:00
Conrad Ludgate
6abcc1f298 add explicit panic reason 2024-09-27 11:43:34 +01:00
Conrad Ludgate
3e97cf0d6e refine missing credentials error 2024-09-27 11:43:34 +01:00
Conrad Ludgate
054ef4988b update certification comment 2024-09-27 11:43:34 +01:00
Conrad Ludgate
5202cd75b5 only forward expected headers 2024-09-27 11:43:34 +01:00
Conrad Ludgate
f475dac0e6 keepalive while idle 2024-09-27 11:43:34 +01:00
Conrad Ludgate
a4100373e5 fix common name parsing 2024-09-27 11:43:34 +01:00
Conrad Ludgate
040d8cf4f6 fix common name parsing 2024-09-27 11:43:34 +01:00
Conrad Ludgate
75bfd57e01 add authbroker cli flag and fix http2 ka 2024-09-27 11:43:34 +01:00
Conrad Ludgate
4bc2686dee small tweaks 2024-09-27 11:43:34 +01:00
Conrad Ludgate
8e7d2aab76 put it all together 2024-09-27 11:43:34 +01:00
Conrad Ludgate
2703abccc7 start on http2 local proxy connection pool 2024-09-27 11:43:34 +01:00
Conrad Ludgate
76515cdae3 split out auth info from conn info, return the jwt as the auth keys 2024-09-27 11:43:34 +01:00
Conrad Ludgate
08c7f933a3 add support for console backend jwt 2024-09-27 11:43:34 +01:00
Conrad Ludgate
4ad3aa7c96 update doc comment for get_with_url 2024-09-27 10:24:50 +01:00
Conrad Ludgate
9c59e3b4b9 proxy: add jwks endpoint to control plane and mock providers 2024-09-27 10:24:43 +01:00
69 changed files with 2924 additions and 2735 deletions

View File

@@ -3,23 +3,19 @@ name: Prepare benchmarking databases by restoring dumps
on:
workflow_call:
# no inputs needed
defaults:
run:
shell: bash -euxo pipefail {0}
jobs:
setup-databases:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
matrix:
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
database: [ clickbench, tpch, userexample ]
env:
LD_LIBRARY_PATH: /tmp/neon/pg_install/v16/lib
PLATFORM: ${{ matrix.platform }}
@@ -27,10 +23,7 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
@@ -39,13 +32,13 @@ jobs:
run: |
case "${PLATFORM}" in
neon)
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
;;
aws-rds-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
;;
aws-aurora-serverless-v2-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
;;
*)
echo >&2 "Unknown PLATFORM=${PLATFORM}"
@@ -53,17 +46,10 @@ jobs:
;;
esac
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -71,23 +57,23 @@ jobs:
path: /tmp/neon/
prefix: latest
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
- name: Create benchmark_restore_status table if it does not exist
env:
BENCHMARK_CONNSTR: ${{ steps.set-up-prep-connstr.outputs.connstr }}
DATABASE_NAME: ${{ matrix.database }}
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# we use an advisory lock
run: |
${PG_BINARIES}/psql "${{ env.BENCHMARK_CONNSTR }}" -c "
SELECT pg_advisory_lock(4711);
SELECT pg_advisory_lock(4711);
CREATE TABLE IF NOT EXISTS benchmark_restore_status (
databasename text primary key,
restore_done boolean
);
SELECT pg_advisory_unlock(4711);
"
- name: Check if restore is already done
id: check-restore-done
env:
@@ -121,7 +107,7 @@ jobs:
DATABASE_NAME: ${{ matrix.database }}
run: |
mkdir -p /tmp/dumps
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
- name: Replace database name in connection string
if: steps.check-restore-done.outputs.skip != 'true'
@@ -140,17 +126,17 @@ jobs:
else
new_connstr="${base_connstr}/${DATABASE_NAME}"
fi
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
- name: Restore dump
if: steps.check-restore-done.outputs.skip != 'true'
env:
DATABASE_NAME: ${{ matrix.database }}
DATABASE_CONNSTR: ${{ steps.replace-dbname.outputs.database_connstr }}
# the following works only with larger computes:
# the following works only with larger computes:
# PGOPTIONS: "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7"
# we add the || true because:
# the dumps were created with Neon and contain neon extensions that are not
# the dumps were created with Neon and contain neon extensions that are not
# available in RDS, so we will always report an error, but we can ignore it
run: |
${PG_BINARIES}/pg_restore --clean --if-exists --no-owner --jobs=4 \

View File

@@ -236,7 +236,9 @@ jobs:
# run pageserver tests with different settings
for io_engine in std-fs tokio-epoll-uring ; do
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
for io_buffer_alignment in 0 1 512 ; do
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT=$io_buffer_alignment ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
done
done
# Run separate tests for real S3

View File

@@ -12,6 +12,7 @@ on:
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 3 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually
inputs:
region_id:
@@ -58,7 +59,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
matrix:
@@ -67,10 +68,12 @@ jobs:
PLATFORM: "neon-staging"
region_id: ${{ github.event.inputs.region_id || 'aws-us-east-2' }}
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- DEFAULT_PG_VERSION: 16
PLATFORM: "azure-staging"
region_id: 'azure-eastus2'
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "300"
TEST_PG_BENCH_SCALES_MATRIX: "10,100"
@@ -83,10 +86,7 @@ jobs:
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: ${{ matrix.IMAGE }}
options: --init
steps:
@@ -164,10 +164,6 @@ jobs:
replication-tests:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 16
@@ -178,21 +174,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
@@ -280,7 +267,7 @@ jobs:
region_id_default=${{ env.DEFAULT_REGION_ID }}
runner_default='["self-hosted", "us-east-2", "x64"]'
runner_azure='["self-hosted", "eastus2", "x64"]'
image_default="neondatabase/build-tools:pinned"
image_default="369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned"
matrix='{
"pg_version" : [
16
@@ -357,7 +344,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
@@ -384,7 +371,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
- name: Configure AWS credentials # necessary on Azure runners
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -505,15 +492,17 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
matrix:
include:
- PLATFORM: "neonvm-captest-pgvector"
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- PLATFORM: "azure-captest-pgvector"
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "15m"
@@ -522,16 +511,13 @@ jobs:
DEFAULT_PG_VERSION: 16
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
LD_LIBRARY_PATH: /home/nonroot/pg/usr/lib/x86_64-linux-gnu
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: ${{ matrix.PLATFORM }}
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: ${{ matrix.IMAGE }}
options: --init
steps:
@@ -541,26 +527,17 @@ jobs:
# instead of using Neon artifacts containing pgbench
- name: Install postgresql-16 where pytest expects it
run: |
# Just to make it easier to test things locally on macOS (with arm64)
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
cd /home/nonroot
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.0-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110+2_${arch}.deb"
dpkg -x libpq5_17.0-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+2_${arch}.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb pg
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/libpq5_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110%2B1_amd64.deb
dpkg -x libpq5_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+1_amd64.deb pg
mkdir -p /tmp/neon/pg_install/v16/bin
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu /tmp/neon/pg_install/v16/lib
LD_LIBRARY_PATH="/home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu:${LD_LIBRARY_PATH}"
export LD_LIBRARY_PATH
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> ${GITHUB_ENV}
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/x86_64-linux-gnu /tmp/neon/pg_install/v16/lib
/tmp/neon/pg_install/v16/bin/pgbench --version
/tmp/neon/pg_install/v16/bin/psql --version
@@ -582,7 +559,7 @@ jobs:
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- name: Configure AWS credentials
- name: Configure AWS credentials # necessary on Azure runners to read/write from/to S3
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -643,10 +620,6 @@ jobs:
# *_CLICKBENCH_CONNSTR: Genuine ClickBench DB with ~100M rows
# *_CLICKBENCH_10M_CONNSTR: DB with the first 10M rows of ClickBench DB
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, pgbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -665,22 +638,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -751,10 +714,6 @@ jobs:
#
# *_TPCH_S10_CONNSTR: DB generated with scale factor 10 (~10 GB)
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, clickbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -772,22 +731,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -857,10 +806,6 @@ jobs:
user-examples-compare:
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, tpch-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -877,22 +822,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:

1
Cargo.lock generated
View File

@@ -1321,7 +1321,6 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"futures",
"humantime",
"humantime-serde",
"hyper 0.14.30",

View File

@@ -13,9 +13,6 @@ RUN useradd -ms /bin/bash nonroot -b /home
SHELL ["/bin/bash", "-c"]
# System deps
#
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
&& apt update \
&& apt install -y \
@@ -27,7 +24,6 @@ RUN set -e \
cmake \
curl \
flex \
gdb \
git \
gnupg \
gzip \

View File

@@ -11,10 +11,6 @@ commands:
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/resize-swap'
- name: chmod-set-disk-quota
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/set-disk-quota'
- name: pgbouncer
user: postgres
sysvInitAction: respawn
@@ -34,12 +30,11 @@ commands:
shutdownHook: |
su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
files:
- filename: compute_ctl-sudoers
- filename: compute_ctl-resize-swap
content: |
# Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap
# and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD),
# regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota
# as root without requiring entering a password (NOPASSWD), regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
@@ -105,7 +100,7 @@ merge: |
&& apt install --no-install-recommends -y \
sudo \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
COPY compute_ctl-sudoers /etc/sudoers.d/compute_ctl-sudoers
COPY compute_ctl-resize-swap /etc/sudoers.d/compute_ctl-resize-swap
COPY cgconfig.conf /etc/cgconfig.conf

View File

@@ -44,7 +44,6 @@ use std::{thread, time::Duration};
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Arg;
use compute_tools::disk_quota::set_disk_quota;
use compute_tools::lsn_lease::launch_lsn_lease_bg_task_for_static;
use signal_hook::consts::{SIGQUIT, SIGTERM};
use signal_hook::{consts::SIGINT, iterator::Signals};
@@ -152,7 +151,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
let spec_json = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
let resize_swap_on_bind = matches.get_flag("resize-swap-on-bind");
let set_disk_quota_for_fs = matches.get_one::<String>("set-disk-quota-for-fs");
Ok(ProcessCliResult {
connstr,
@@ -163,7 +161,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
spec_json,
spec_path,
resize_swap_on_bind,
set_disk_quota_for_fs,
})
}
@@ -176,7 +173,6 @@ struct ProcessCliResult<'clap> {
spec_json: Option<&'clap String>,
spec_path: Option<&'clap String>,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<&'clap String>,
}
fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
@@ -297,7 +293,6 @@ fn wait_spec(
pgbin,
ext_remote_storage,
resize_swap_on_bind,
set_disk_quota_for_fs,
http_port,
..
}: ProcessCliResult,
@@ -378,7 +373,6 @@ fn wait_spec(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs: set_disk_quota_for_fs.cloned(),
})
}
@@ -387,7 +381,6 @@ struct WaitSpecResult {
// passed through from ProcessCliResult
http_port: u16,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<String>,
}
fn start_postgres(
@@ -397,7 +390,6 @@ fn start_postgres(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs,
}: WaitSpecResult,
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
// We got all we need, update the state.
@@ -411,7 +403,6 @@ fn start_postgres(
);
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -431,8 +422,8 @@ fn start_postgres(
// OOM-killed during startup because swap wasn't available yet.
match resize_swap(size_bytes) {
Ok(()) => {
let size_mib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_mib, "resized swap");
let size_gib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_gib, "resized swap");
}
Err(err) => {
let err = err.context("failed to resize swap");
@@ -441,29 +432,10 @@ fn start_postgres(
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
compute.set_failed_status(err);
delay_exit = true;
}
}
}
// Set disk quota if the compute spec says so
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) =
(disk_quota_bytes, set_disk_quota_for_fs)
{
match set_disk_quota(disk_quota_bytes, &disk_quota_fs_mountpoint) {
Ok(()) => {
let size_mib = disk_quota_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%disk_quota_bytes, %size_mib, "set disk quota");
}
Err(err) => {
let err = err.context("failed to set disk quota");
error!("{err:#}");
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
compute.set_failed_status(err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
compute.state_changed.notify_all();
delay_exit = true;
}
}
@@ -478,7 +450,16 @@ fn start_postgres(
Ok(pg) => Some(pg),
Err(err) => {
error!("could not start the compute node: {:#}", err);
compute.set_failed_status(err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{:?}", err));
state.status = ComputeStatus::Failed;
// Notify others that Postgres failed to start. In case of configuring the
// empty compute, it's likely that API handler is still waiting for compute
// state change. With this we will notify it that compute is in Failed state,
// so control plane will know about it earlier and record proper error instead
// of timeout.
compute.state_changed.notify_all();
drop(state); // unlock
delay_exit = true;
None
}
@@ -769,11 +750,6 @@ fn cli() -> clap::Command {
.long("resize-swap-on-bind")
.action(clap::ArgAction::SetTrue),
)
.arg(
Arg::new("set-disk-quota-for-fs")
.long("set-disk-quota-for-fs")
.value_name("SET_DISK_QUOTA_FOR_FS")
)
}
/// When compute_ctl is killed, send also termination signal to sync-safekeepers

View File

@@ -10,7 +10,6 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::{Condvar, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use std::time::Instant;
use anyhow::{Context, Result};
@@ -306,13 +305,6 @@ impl ComputeNode {
self.state_changed.notify_all();
}
/// Marks the compute as failed and wakes everyone waiting on a state change.
///
/// The error is stored in its `Debug` rendering in `state.error`, the status
/// is switched to `ComputeStatus::Failed`, and all waiters on `state_changed`
/// are notified.
///
/// Panics if the state mutex is poisoned.
pub fn set_failed_status(&self, err: anyhow::Error) {
    // Render the error up front so the critical section only does assignments.
    let rendered = format!("{err:?}");

    let mut guard = self.state.lock().unwrap();
    guard.status = ComputeStatus::Failed;
    guard.error = Some(rendered);
    self.state_changed.notify_all();
}
pub fn get_status(&self) -> ComputeStatus {
self.state.lock().unwrap().status
}
@@ -718,7 +710,7 @@ impl ComputeNode {
info!("running initdb");
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
Command::new(initdb_bin)
.args(["--pgdata", pgdata])
.args(["-D", pgdata])
.output()
.expect("cannot start initdb process");
@@ -1131,9 +1123,6 @@ impl ComputeNode {
//
// Use that as a default location and pattern, except macos where core dumps are written
// to /cores/ directory by default.
//
// With default Linux settings, the core dump file is called just "core", so check for
// that too.
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
@@ -1145,17 +1134,8 @@ impl ComputeNode {
let files = fs::read_dir(core_dump_dir)?;
let cores = files.filter_map(|entry| {
let entry = entry.ok()?;
let is_core_dump = match entry.file_name().to_str()? {
n if n.starts_with("core.") => true,
"core" => true,
_ => false,
};
if is_core_dump {
Some(entry.path())
} else {
None
}
let _ = entry.file_name().to_str()?.strip_prefix("core.")?;
Some(entry.path())
});
// Print backtrace for each core dump
@@ -1406,36 +1386,6 @@ LIMIT 100",
}
Ok(remote_ext_metrics)
}
/// Blocks the current thread until the pageserver connection string stored in
/// the compute state differs from the value it had when this method was
/// called, or until `duration` has elapsed, whichever happens first.
///
/// Logs a message when the wait ended because the connection string changed.
///
/// Panics if no spec has been set, or if the state mutex is poisoned.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
    let guard = self.state.lock().unwrap();

    // Snapshot the starting value; the wait predicate compares against it.
    let initial_connstr = guard
        .pspec
        .as_ref()
        .expect("spec must be set")
        .pageserver_connstr
        .clone();

    // Tracks the predicate's last verdict so we can tell a change from a timeout.
    let mut still_same = true;
    let wait_result = self
        .state_changed
        .wait_timeout_while(guard, duration, |current| {
            let spec = current.pspec.as_ref().expect("spec must be set");
            still_same = &spec.pageserver_connstr == &initial_connstr;
            still_same
        });

    // Unwrap to surface mutex poisoning, then drop the guard right away.
    let _ = wait_result.unwrap();

    if still_same {
        return;
    }
    info!("Pageserver config changed");
}
}
pub fn forward_termination_signal() {

View File

@@ -1,25 +0,0 @@
use anyhow::Context;
pub const DISK_QUOTA_BIN: &str = "/neonvm/bin/set-disk-quota";
/// Sets (or disables) the filesystem quota by running
/// `/usr/bin/sudo /neonvm/bin/set-disk-quota {size_kb} {fs_mountpoint}`.
///
/// `size_bytes` is converted to whole kilobytes with integer division, so the
/// helper receives `size_bytes / 1024` (any remainder is dropped). A value of
/// 0 disables the quota; any other value sets the quota to that size.
/// `fs_mountpoint` should point to the mountpoint of the filesystem where the
/// quota should be set.
///
/// Returns an error if the process cannot be spawned or waited on, or if it
/// exits unsuccessfully.
pub fn set_disk_quota(size_bytes: u64, fs_mountpoint: &str) -> anyhow::Result<()> {
    let size_kb = size_bytes / 1024;

    // Run the helper to completion; each step carries its own error context.
    let run_helper = || -> anyhow::Result<()> {
        let mut child = std::process::Command::new("/usr/bin/sudo")
            .arg(DISK_QUOTA_BIN)
            .arg(size_kb.to_string())
            .arg(fs_mountpoint)
            .spawn()
            .context("spawn() failed")?;

        let status = child.wait().context("wait() failed")?;
        if status.success() {
            Ok(())
        } else {
            Err(anyhow::anyhow!("process exited with {status}"))
        }
    };

    // wrap any prior error with the overall context that we couldn't run the command
    run_helper().with_context(|| format!("could not run `/usr/bin/sudo {DISK_QUOTA_BIN}`"))
}

View File

@@ -10,7 +10,6 @@ pub mod http;
pub mod logger;
pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod lsn_lease;
mod migration;

View File

@@ -57,10 +57,10 @@ fn lsn_lease_bg_task(
.max(valid_duration / 2);
info!(
"Request succeeded, sleeping for {} seconds",
"Succeeded, sleeping for {} seconds",
sleep_duration.as_secs()
);
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
thread::sleep(sleep_duration);
}
}
@@ -89,7 +89,10 @@ fn acquire_lsn_lease_with_retry(
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
info!("Got storage auth token from spec file");
config.password(storage_auth_token.clone());
} else {
info!("Storage auth token not set");
}
config
})
@@ -105,11 +108,9 @@ fn acquire_lsn_lease_with_retry(
bail!("Permanent error: lease could not be obtained, LSN is behind the GC cutoff");
}
Err(e) => {
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
warn!("Failed to acquire lsn lease: {e} (attempt {attempts}");
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
retry_period_ms as u64,
));
thread::sleep(Duration::from_millis(retry_period_ms as u64));
retry_period_ms *= 1.5;
retry_period_ms = retry_period_ms.min(MAX_RETRY_PERIOD_MS);
}

View File

@@ -9,7 +9,6 @@ anyhow.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
futures.workspace = true
humantime.workspace = true
nix.workspace = true
once_cell.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -1,94 +0,0 @@
//! Branch mappings for convenience
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use anyhow::{bail, Context};
use serde::{Deserialize, Serialize};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
/// Human-readable branch-name aliases for tenant/timeline IDs, kept in memory to hide
/// tenant/timeline hex strings from the user.
///
/// The mappings are (de)serialized as TOML; unknown fields are rejected and missing
/// fields fall back to their defaults.
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct BranchMappings {
/// Default tenant ID to use with the 'neon_local' command line utility, when
/// --tenant_id is not explicitly specified. This comes from the branches.
pub default_tenant_id: Option<TenantId>,
/// Branch name -> list of `(tenant, timeline)` pairs registered under that name.
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
pub mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
}
impl BranchMappings {
pub fn register_branch_mapping(
&mut self,
branch_name: String,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> anyhow::Result<()> {
let existing_values = self.mappings.entry(branch_name.clone()).or_default();
let existing_ids = existing_values
.iter()
.find(|(existing_tenant_id, _)| existing_tenant_id == &tenant_id);
if let Some((_, old_timeline_id)) = existing_ids {
if old_timeline_id == &timeline_id {
Ok(())
} else {
bail!("branch '{branch_name}' is already mapped to timeline {old_timeline_id}, cannot map to another timeline {timeline_id}");
}
} else {
existing_values.push((tenant_id, timeline_id));
Ok(())
}
}
pub fn get_branch_timeline_id(
&self,
branch_name: &str,
tenant_id: TenantId,
) -> Option<TimelineId> {
// If it looks like a timeline ID, return it as it is
if let Ok(timeline_id) = branch_name.parse::<TimelineId>() {
return Some(timeline_id);
}
self.mappings
.get(branch_name)?
.iter()
.find(|(mapped_tenant_id, _)| mapped_tenant_id == &tenant_id)
.map(|&(_, timeline_id)| timeline_id)
.map(TimelineId::from)
}
pub fn timeline_name_mappings(&self) -> HashMap<TenantTimelineId, String> {
self.mappings
.iter()
.flat_map(|(name, tenant_timelines)| {
tenant_timelines.iter().map(|&(tenant_id, timeline_id)| {
(TenantTimelineId::new(tenant_id, timeline_id), name.clone())
})
})
.collect()
}
pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
let content = &toml::to_string_pretty(self)?;
fs::write(path, content).with_context(|| {
format!(
"Failed to write branch information into path '{}'",
path.display()
)
})
}
pub fn load(path: &Path) -> anyhow::Result<BranchMappings> {
let branches_file_contents = fs::read_to_string(path)?;
Ok(toml::from_str(branches_file_contents.as_str())?)
}
}

View File

@@ -561,7 +561,6 @@ impl Endpoint {
operation_uuid: None,
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -113,7 +113,7 @@ impl SafekeeperNode {
pub async fn start(
&self,
extra_opts: &[String],
extra_opts: Vec<String>,
retry_timeout: &Duration,
) -> anyhow::Result<()> {
print!(
@@ -196,7 +196,7 @@ impl SafekeeperNode {
]);
}
args.extend_from_slice(extra_opts);
args.extend(extra_opts);
background_process::start_process(
&format!("safekeeper-{id}"),

View File

@@ -347,7 +347,7 @@ impl StorageController {
if !tokio::fs::try_exists(&pg_data_path).await? {
let initdb_args = [
"--pgdata",
"-D",
pg_data_path.as_ref(),
"--username",
&username(),

View File

@@ -50,16 +50,6 @@ pub struct ComputeSpec {
#[serde(default)]
pub swap_size_bytes: Option<u64>,
/// If compute_ctl was passed `--set-disk-quota-for-fs`, a value of `Some(_)` instructs
/// compute_ctl to run `/neonvm/bin/set-disk-quota` with the given size and fs, when the
/// spec is first received.
///
/// Both this field and `--set-disk-quota-for-fs` are required, so that the control plane's
/// spec generation doesn't need to be aware of the actual compute it's running on, while
/// guaranteeing gradual rollout of disk quota.
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,

View File

@@ -104,7 +104,7 @@ pub struct ConfigToml {
pub image_compression: ImageCompressionAlgorithm,
pub ephemeral_bytes_per_memory_kb: usize,
pub l0_flush: Option<crate::models::L0FlushConfig>,
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
pub virtual_file_direct_io: crate::models::virtual_file::DirectIoMode,
pub io_buffer_alignment: usize,
}
@@ -381,7 +381,7 @@ impl Default for ConfigToml {
image_compression: (DEFAULT_IMAGE_COMPRESSION),
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
l0_flush: None,
virtual_file_io_mode: None,
virtual_file_direct_io: crate::models::virtual_file::DirectIoMode::default(),
io_buffer_alignment: DEFAULT_IO_BUFFER_ALIGNMENT,

View File

@@ -972,6 +972,8 @@ pub struct TopTenantShardsResponse {
}
pub mod virtual_file {
use std::path::PathBuf;
#[derive(
Copy,
Clone,
@@ -992,51 +994,50 @@ pub mod virtual_file {
}
/// Direct IO modes for a pageserver.
#[derive(
Copy,
Clone,
PartialEq,
Eq,
Hash,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
pub enum IoMode {
/// Uses buffered IO.
Buffered,
/// Uses direct IO, error out if the operation fails.
#[cfg(target_os = "linux")]
Direct,
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)]
pub enum DirectIoMode {
/// Direct IO disabled (uses usual buffered IO).
#[default]
Disabled,
/// Direct IO disabled (performs checks and perf simulations).
Evaluate {
/// Alignment check level
alignment_check: DirectIoAlignmentCheckLevel,
/// Latency padded for performance simulation.
latency_padding: DirectIoLatencyPadding,
},
/// Direct IO enabled.
Enabled {
/// Actions to perform on alignment error.
on_alignment_error: DirectIoOnAlignmentErrorAction,
},
}
impl IoMode {
#[cfg(target_os = "linux")]
pub const fn preferred() -> Self {
Self::Direct
}
#[cfg(target_os = "macos")]
pub const fn preferred() -> Self {
Self::Buffered
}
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum DirectIoAlignmentCheckLevel {
#[default]
Error,
Log,
None,
}
impl TryFrom<u8> for IoMode {
type Error = u8;
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum DirectIoOnAlignmentErrorAction {
Error,
#[default]
FallbackToBuffered,
}
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
#[cfg(target_os = "linux")]
v if v == (IoMode::Direct as u8) => IoMode::Direct,
x => return Err(x),
})
}
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum DirectIoLatencyPadding {
/// Pad virtual file operations with IO to a fake file.
FakeFileRW { path: PathBuf },
#[default]
None,
}
}

View File

@@ -93,9 +93,9 @@ impl Conf {
);
let output = self
.new_pg_command("initdb")?
.arg("--pgdata")
.arg("-D")
.arg(&self.datadir)
.args(["--username", "postgres", "--no-instructions", "--no-sync"])
.args(["-U", "postgres", "--no-instructions", "--no-sync"])
.output()?;
debug!("initdb output: {:?}", output);
ensure!(

View File

@@ -164,10 +164,12 @@ fn criterion_benchmark(c: &mut Criterion) {
let conf: &'static PageServerConf = Box::leak(Box::new(
pageserver::config::PageServerConf::dummy_conf(temp_dir.path().to_path_buf()),
));
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
virtual_file::init(16384, virtual_file::io_engine_for_bench(), align);
page_cache::init(conf.page_cache_size, align);
virtual_file::init(
16384,
virtual_file::io_engine_for_bench(),
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
page_cache::init(conf.page_cache_size);
{
let mut group = c.benchmark_group("ingest-small-values");

View File

@@ -550,19 +550,6 @@ impl Client {
.map_err(Error::ReceiveBody)
}
/// Configs io mode at runtime.
pub async fn put_io_mode(
&self,
mode: &pageserver_api::models::virtual_file::IoMode,
) -> Result<()> {
let uri = format!("{}/v1/io_mode", self.mgmt_api_endpoint);
self.request(Method::PUT, uri, mode)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn get_utilization(&self) -> Result<PageserverUtilization> {
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
self.get(uri)
@@ -749,22 +736,4 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_init_lsn_lease(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<LsnLease> {
let uri = format!(
"{}/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/lsn_lease",
self.mgmt_api_endpoint,
);
self.request(Method::POST, &uri, LsnLeaseRequest { lsn })
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -151,10 +151,13 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> {
let max_holes = cmd.max_holes.unwrap_or(DEFAULT_MAX_HOLES);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
// Initialize virtual_file (file descriptor cache) and page cache which are needed to access layer persistent B-Tree.
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
pageserver::page_cache::init(100, align);
pageserver::virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
pageserver::page_cache::init(100);
let mut total_delta_layers = 0usize;
let mut total_image_layers = 0usize;

View File

@@ -59,9 +59,8 @@ pub(crate) enum LayerCmd {
async fn read_delta_file(path: impl AsRef<Path>, ctx: &RequestContext) -> Result<()> {
let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path");
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
page_cache::init(100, align);
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, 1);
page_cache::init(100);
let file = VirtualFile::open(path, ctx).await?;
let file_id = page_cache::next_file_id();
let block_reader = FileBlockReader::new(&file, file_id);
@@ -191,10 +190,12 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> {
new_tenant_id,
new_timeline_id,
} => {
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
pageserver::page_cache::init(100, align);
pageserver::virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
);
pageserver::page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);

View File

@@ -205,9 +205,12 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> {
async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> {
// Basic initialization of things that don't change after startup
let align = DEFAULT_IO_BUFFER_ALIGNMENT;
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
page_cache::init(100, align);
virtual_file::init(
10,
virtual_file::api::IoEngineKind::StdFs,
DEFAULT_IO_BUFFER_ALIGNMENT,
);
page_cache::init(100);
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
dump_layerfile_from_path(path, true, &ctx).await
}

View File

@@ -63,10 +63,6 @@ pub(crate) struct Args {
#[clap(long)]
set_io_alignment: Option<usize>,
/// Before starting the benchmark, live-reconfigure the pageserver to use specified io mode (buffered vs. direct).
#[clap(long)]
set_io_mode: Option<pageserver_api::models::virtual_file::IoMode>,
targets: Option<Vec<TenantTimelineId>>,
}
@@ -137,10 +133,6 @@ async fn main_impl(
mgmt_api_client.put_io_alignment(align).await?;
}
if let Some(mode) = &args.set_io_mode {
mgmt_api_client.put_io_mode(mode).await?;
}
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,

View File

@@ -125,7 +125,7 @@ fn main() -> anyhow::Result<()> {
// after setting up logging, log the effective IO engine choice and read path implementations
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
info!(?conf.virtual_file_io_mode, "starting with virtual_file Direct IO settings");
info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings");
info!(?conf.io_buffer_alignment, "starting with setting for IO buffer alignment");
// The tenants directory contains all the pageserver local disk state.
@@ -173,7 +173,7 @@ fn main() -> anyhow::Result<()> {
conf.virtual_file_io_engine,
conf.io_buffer_alignment,
);
page_cache::init(conf.page_cache_size, conf.io_buffer_alignment);
page_cache::init(conf.page_cache_size);
start_pageserver(launch_ts, conf).context("Failed to start pageserver")?;

View File

@@ -174,7 +174,7 @@ pub struct PageServerConf {
pub l0_flush: crate::l0_flush::L0FlushConfig,
/// Direct IO settings
pub virtual_file_io_mode: virtual_file::IoMode,
pub virtual_file_direct_io: virtual_file::DirectIoMode,
pub io_buffer_alignment: usize,
}
@@ -325,7 +325,7 @@ impl PageServerConf {
image_compression,
ephemeral_bytes_per_memory_kb,
l0_flush,
virtual_file_io_mode,
virtual_file_direct_io,
concurrent_tenant_warmup,
concurrent_tenant_size_logical_size_queries,
virtual_file_io_engine,
@@ -368,6 +368,7 @@ impl PageServerConf {
max_vectored_read_bytes,
image_compression,
ephemeral_bytes_per_memory_kb,
virtual_file_direct_io,
io_buffer_alignment,
// ------------------------------------------------------------
@@ -407,7 +408,6 @@ impl PageServerConf {
l0_flush: l0_flush
.map(crate::l0_flush::L0FlushConfig::from)
.unwrap_or_default(),
virtual_file_io_mode: virtual_file_io_mode.unwrap_or(virtual_file::IoMode::preferred()),
};
// ------------------------------------------------------------

View File

@@ -17,7 +17,6 @@ use hyper::header;
use hyper::StatusCode;
use hyper::{Body, Request, Response, Uri};
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::models::virtual_file::IoMode;
use pageserver_api::models::AuxFilePolicy;
use pageserver_api::models::DownloadRemoteLayersTaskSpawnRequest;
use pageserver_api::models::IngestAuxFilesRequest;
@@ -825,7 +824,7 @@ async fn get_lsn_by_timestamp_handler(
let lease = if with_lease {
timeline
.init_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.make_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.inspect_err(|_| {
warn!("fail to grant a lease to {}", lsn);
})
@@ -1693,18 +1692,9 @@ async fn lsn_lease_handler(
let timeline =
active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id)
.await?;
let result = async {
timeline
.init_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| {
ApiError::InternalServerError(
e.context(format!("invalid lsn lease request at {lsn}")),
)
})
}
.instrument(info_span!("init_lsn_lease", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %timeline_id))
.await?;
let result = timeline
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| ApiError::InternalServerError(e.context("lsn lease http handler")))?;
json_response(StatusCode::OK, result)
}
@@ -2382,16 +2372,6 @@ async fn put_io_alignment_handler(
json_response(StatusCode::OK, ())
}
/// HTTP handler that applies the `IoMode` carried in the JSON request body.
///
/// Runs `check_permission` on the request first, then passes the decoded mode to
/// `crate::virtual_file::set_io_mode` and answers `200 OK` with an empty JSON body.
async fn put_io_mode_handler(
    mut r: Request<Body>,
    _cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
    check_permission(&r, None)?;

    let requested_mode: IoMode = json_request(&mut r).await?;
    crate::virtual_file::set_io_mode(requested_mode);

    json_response(StatusCode::OK, ())
}
/// Polled by control plane.
///
/// See [`crate::utilization`].
@@ -3082,7 +3062,6 @@ pub fn make_router(
.put("/v1/io_alignment", |r| {
api_handler(r, put_io_alignment_handler)
})
.put("/v1/io_mode", |r| api_handler(r, put_io_mode_handler))
.put(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/force_aux_policy_switch",
|r| api_handler(r, force_aux_policy_switch_handler),

View File

@@ -82,7 +82,6 @@ use once_cell::sync::OnceCell;
use crate::{
context::RequestContext,
metrics::{page_cache_eviction_metrics, PageCacheSizeMetrics},
virtual_file::{self, dio::IoBufferMut},
};
static PAGE_CACHE: OnceCell<PageCache> = OnceCell::new();
@@ -91,8 +90,8 @@ const TEST_PAGE_CACHE_SIZE: usize = 50;
///
/// Initialize the page cache. This must be called once at page server startup.
///
pub fn init(size: usize, align: usize) {
if PAGE_CACHE.set(PageCache::new(size, align)).is_err() {
pub fn init(size: usize) {
if PAGE_CACHE.set(PageCache::new(size)).is_err() {
panic!("page cache already initialized");
}
}
@@ -107,12 +106,7 @@ pub fn get() -> &'static PageCache {
// page cache is usable in unit tests.
//
if cfg!(test) {
PAGE_CACHE.get_or_init(|| {
PageCache::new(
TEST_PAGE_CACHE_SIZE,
virtual_file::get_io_buffer_alignment(),
)
})
PAGE_CACHE.get_or_init(|| PageCache::new(TEST_PAGE_CACHE_SIZE))
} else {
PAGE_CACHE.get().expect("page cache not initialized")
}
@@ -643,11 +637,13 @@ impl PageCache {
/// Initialize a new page cache
///
/// This should be called only once at page server startup.
fn new(num_pages: usize, align: usize) -> Self {
fn new(num_pages: usize) -> Self {
assert!(num_pages > 0, "page cache size must be > 0");
let page_buffer =
IoBufferMut::with_capacity_aligned_zeroed(num_pages * PAGE_SZ, align).leak();
// We could use Vec::leak here, but that potentially also leaks
// uninitialized reserved capacity. With into_boxed_slice and Box::leak
// this is avoided.
let page_buffer = Box::leak(vec![0u8; num_pages * PAGE_SZ].into_boxed_slice());
let size_metrics = &crate::metrics::PAGE_CACHE_SIZE;
size_metrics.max_bytes.set_page_sz(num_pages);

View File

@@ -273,20 +273,10 @@ async fn page_service_conn_main(
info!("Postgres client disconnected ({io_error})");
Ok(())
} else {
let tenant_id = conn_handler.timeline_handles.tenant_id();
Err(io_error).context(format!(
"Postgres connection error for tenant_id={:?} client at peer_addr={}",
tenant_id, peer_addr
))
Err(io_error).context("Postgres connection error")
}
}
other => {
let tenant_id = conn_handler.timeline_handles.tenant_id();
other.context(format!(
"Postgres query error for tenant_id={:?} client peer_addr={}",
tenant_id, peer_addr
))
}
other => other.context("Postgres query error"),
}
}
@@ -350,10 +340,6 @@ impl TimelineHandles {
}
})
}
/// Returns a copy of the tenant id held by the wrapper, or `None` if none is available.
fn tenant_id(&self) -> Option<TenantId> {
    let recorded = self.wrapper.tenant_id.get();
    recorded.copied()
}
}
pub(crate) struct TenantManagerWrapper {
@@ -833,7 +819,7 @@ impl PageServerHandler {
set_tracing_field_shard_id(&timeline);
let lease = timeline
.renew_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.inspect_err(|e| {
warn!("{e}");
})

View File

@@ -21,7 +21,6 @@ use futures::stream::FuturesUnordered;
use futures::StreamExt;
use pageserver_api::models;
use pageserver_api::models::AuxFilePolicy;
use pageserver_api::models::LsnLease;
use pageserver_api::models::TimelineArchivalState;
use pageserver_api::models::TimelineState;
use pageserver_api::models::TopTenantShardItem;
@@ -183,54 +182,27 @@ pub struct TenantSharedResources {
pub(super) struct AttachedTenantConf {
tenant_conf: TenantConfOpt,
location: AttachedLocationConfig,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<tokio::time::Instant>,
}
impl AttachedTenantConf {
/// Creates the attached-tenant configuration and derives the initial lsn lease deadline.
///
/// The lease mapping is not persisted to disk. Delaying GC by one lease length
/// guarantees that every lease granted earlier gets a chance to be renewed before the
/// first GC after a restart or a transition from AttachedMulti to AttachedSingle.
fn new(tenant_conf: TenantConfOpt, location: AttachedLocationConfig) -> Self {
    // Only AttachedSingle gets a deadline: AttachedMulti and AttachedStale do not run GC,
    // so there is nothing to delay in those modes.
    let lsn_lease_deadline = (location.attach_mode == AttachmentMode::Single).then(|| {
        tokio::time::Instant::now()
            + tenant_conf
                .lsn_lease_length
                .unwrap_or(LsnLease::DEFAULT_LENGTH)
    });

    Self {
        tenant_conf,
        location,
        lsn_lease_deadline,
    }
}
fn try_from(location_conf: LocationConf) -> anyhow::Result<Self> {
match &location_conf.mode {
LocationMode::Attached(attach_conf) => {
Ok(Self::new(location_conf.tenant_conf, *attach_conf))
}
LocationMode::Attached(attach_conf) => Ok(Self {
tenant_conf: location_conf.tenant_conf,
location: *attach_conf,
}),
LocationMode::Secondary(_) => {
anyhow::bail!("Attempted to construct AttachedTenantConf from a LocationConf in secondary mode")
}
}
}
/// Returns `true` while a lsn lease deadline is set and has not been reached yet.
fn is_gc_blocked_by_lsn_lease_deadline(&self) -> bool {
    match self.lsn_lease_deadline {
        Some(deadline) => tokio::time::Instant::now() < deadline,
        // No deadline was ever set, so nothing blocks GC here.
        None => false,
    }
}
}
struct TimelinePreload {
timeline_id: TimelineId,
@@ -1850,11 +1822,6 @@ impl Tenant {
info!("Skipping GC in location state {:?}", conf.location);
return Ok(GcResult::default());
}
if conf.is_gc_blocked_by_lsn_lease_deadline() {
info!("Skipping GC because lsn lease deadline is not reached");
return Ok(GcResult::default());
}
}
let _guard = match self.gc_block.start().await {
@@ -2663,8 +2630,6 @@ impl Tenant {
Arc::new(AttachedTenantConf {
tenant_conf: new_tenant_conf.clone(),
location: inner.location,
// Attached location is not changed, no need to update lsn lease deadline.
lsn_lease_deadline: inner.lsn_lease_deadline,
})
});
@@ -3922,9 +3887,9 @@ async fn run_initdb(
let _permit = INIT_DB_SEMAPHORE.acquire().await;
let initdb_command = tokio::process::Command::new(&initdb_bin_path)
.args(["--pgdata", initdb_target_dir.as_ref()])
.args(["--username", &conf.superuser])
.args(["--encoding", "utf8"])
.args(["-D", initdb_target_dir.as_ref()])
.args(["-U", &conf.superuser])
.args(["-E", "utf8"])
.arg("--no-instructions")
.arg("--no-sync")
.env_clear()
@@ -4496,17 +4461,13 @@ mod tests {
tline.freeze_and_flush().await.map_err(|e| e.into())
}
#[tokio::test(start_paused = true)]
#[tokio::test]
async fn test_prohibit_branch_creation_on_garbage_collected_data() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_prohibit_branch_creation_on_garbage_collected_data")
.await?
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let tline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await?;
@@ -7283,17 +7244,9 @@ mod tests {
Ok(())
}
#[tokio::test(start_paused = true)]
#[tokio::test]
async fn test_lsn_lease() -> anyhow::Result<()> {
let (tenant, ctx) = TenantHarness::create("test_lsn_lease")
.await
.unwrap()
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let (tenant, ctx) = TenantHarness::create("test_lsn_lease").await?.load().await;
let key = Key::from_hex("010000000033333333444444445500000000").unwrap();
let end_lsn = Lsn(0x100);
@@ -7321,33 +7274,24 @@ mod tests {
let leased_lsns = [0x30, 0x50, 0x70];
let mut leases = Vec::new();
leased_lsns.iter().for_each(|n| {
leases.push(
timeline
.init_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)
.expect("lease request should succeed"),
);
let _: anyhow::Result<_> = leased_lsns.iter().try_for_each(|n| {
leases.push(timeline.make_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)?);
Ok(())
});
let updated_lease_0 = timeline
.renew_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)
.expect("lease renewal should succeed");
assert_eq!(
updated_lease_0.valid_until, leases[0].valid_until,
" Renewing with shorter lease should not change the lease."
);
// Renewing with shorter lease should not change the lease.
let updated_lease_0 =
timeline.make_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)?;
assert_eq!(updated_lease_0.valid_until, leases[0].valid_until);
let updated_lease_1 = timeline
.renew_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)
.expect("lease renewal should succeed");
assert!(
updated_lease_1.valid_until > leases[1].valid_until,
"Renewing with a long lease should renew lease with later expiration time."
);
// Renewing with a long lease should renew lease with later expiration time.
let updated_lease_1 = timeline.make_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)?;
assert!(updated_lease_1.valid_until > leases[1].valid_until);
// Force set disk consistent lsn so we can get the cutoff at `end_lsn`.
info!(
@@ -7364,8 +7308,7 @@ mod tests {
&CancellationToken::new(),
&ctx,
)
.await
.unwrap();
.await?;
// Keeping everything <= Lsn(0x80) b/c leases:
// 0/10: initdb layer
@@ -7379,16 +7322,13 @@ mod tests {
// Make lease on an already GC-ed LSN.
// 0/80 does not have a valid lease + is below latest_gc_cutoff
assert!(Lsn(0x80) < *timeline.get_latest_gc_cutoff_lsn());
timeline
.init_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx)
.expect_err("lease request on GC-ed LSN should fail");
let res = timeline.make_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx);
assert!(res.is_err());
// Should still be able to renew a currently valid lease
// Assumption: the original lease is still valid for 0/50.
// (use `Timeline::init_lsn_lease` for testing so it always does validation)
timeline
.init_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)
.expect("lease renewal with validation should succeed");
let _ =
timeline.make_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)?;
Ok(())
}

View File

@@ -5,8 +5,6 @@
use super::storage_layer::delta_layer::{Adapter, DeltaLayerInner};
use crate::context::RequestContext;
use crate::page_cache::{self, FileId, PageReadGuard, PageWriteGuard, ReadBufResult, PAGE_SZ};
#[cfg(test)]
use crate::virtual_file::dio::IoBufferMut;
use crate::virtual_file::VirtualFile;
use bytes::Bytes;
use std::ops::Deref;
@@ -42,7 +40,7 @@ pub enum BlockLease<'a> {
#[cfg(test)]
Arc(std::sync::Arc<[u8; PAGE_SZ]>),
#[cfg(test)]
IoBufferMut(IoBufferMut),
Vec(Vec<u8>),
}
impl From<PageReadGuard<'static>> for BlockLease<'static> {
@@ -69,7 +67,7 @@ impl<'a> Deref for BlockLease<'a> {
#[cfg(test)]
BlockLease::Arc(v) => v.deref(),
#[cfg(test)]
BlockLease::IoBufferMut(v) => {
BlockLease::Vec(v) => {
TryFrom::try_from(&v[..]).expect("caller must ensure that v has PAGE_SZ")
}
}

View File

@@ -6,8 +6,6 @@ use crate::config::PageServerConf;
use crate::context::RequestContext;
use crate::page_cache;
use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File;
use crate::virtual_file::dio::IoBufferMut;
use crate::virtual_file::owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
use crate::virtual_file::owned_buffers_io::slice::SliceMutExt;
use crate::virtual_file::owned_buffers_io::util::size_tracking_writer;
use crate::virtual_file::owned_buffers_io::write::Buffer;
@@ -86,7 +84,7 @@ impl Drop for EphemeralFile {
fn drop(&mut self) {
// unlink the file
// we are clear to do this, because we have entered a gate
let path = self.buffered_writer.as_inner().as_inner().path();
let path = &self.buffered_writer.as_inner().as_inner().path;
let res = std::fs::remove_file(path);
if let Err(e) = res {
if e.kind() != std::io::ErrorKind::NotFound {
@@ -109,16 +107,15 @@ impl EphemeralFile {
self.page_cache_file_id
}
pub(crate) async fn load_to_buf(&self, ctx: &RequestContext) -> Result<IoBufferMut, io::Error> {
pub(crate) async fn load_to_vec(&self, ctx: &RequestContext) -> Result<Vec<u8>, io::Error> {
let size = self.len().into_usize();
let align = virtual_file::get_io_buffer_alignment();
let buf = IoBufferMut::with_capacity_aligned(size, align);
let (slice, nread) = self.read_exact_at_eof_ok(0, buf.slice_full(), ctx).await?;
let vec = Vec::with_capacity(size);
let (slice, nread) = self.read_exact_at_eof_ok(0, vec.slice_full(), ctx).await?;
assert_eq!(nread, size);
let buf = slice.into_inner();
assert_eq!(buf.len(), nread);
assert_eq!(buf.capacity(), size, "we shouldn't be reallocating");
Ok(buf)
let vec = slice.into_inner();
assert_eq!(vec.len(), nread);
assert_eq!(vec.capacity(), size, "we shouldn't be reallocating");
Ok(vec)
}
/// Returns the offset at which the first byte of the input was written, for use
@@ -161,7 +158,7 @@ impl EphemeralFile {
}
impl super::storage_layer::inmemory_layer::vectored_dio_read::File for EphemeralFile {
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
async fn read_exact_at_eof_ok<'a, 'b, B: tokio_epoll_uring::IoBufMut + Send>(
&'b self,
start: u64,
dst: tokio_epoll_uring::Slice<B>,
@@ -346,10 +343,9 @@ mod tests {
}
assert!(file.len() as usize == write_nbytes);
let align = virtual_file::get_io_buffer_alignment();
for i in 0..write_nbytes {
assert_eq!(value_offsets[i], i.into_u64());
let buf = IoBufferMut::with_capacity_aligned(1, align);
let buf = Vec::with_capacity(1);
let (buf_slice, nread) = file
.read_exact_at_eof_ok(i.into_u64(), buf.slice_full(), &ctx)
.await
@@ -360,7 +356,7 @@ mod tests {
}
let file_contents =
std::fs::read(file.buffered_writer.as_inner().as_inner().path()).unwrap();
std::fs::read(&file.buffered_writer.as_inner().as_inner().path).unwrap();
assert_eq!(file_contents, &content[0..cap]);
let buffer_contents = file.buffered_writer.inspect_buffer();
@@ -389,14 +385,14 @@ mod tests {
// assert the state is as this test expects it to be
assert_eq!(
&file.load_to_buf(&ctx).await.unwrap()[..],
&file.load_to_vec(&ctx).await.unwrap(),
&content[0..cap + cap / 2]
);
let md = file
.buffered_writer
.as_inner()
.as_inner()
.path()
.path
.metadata()
.unwrap();
assert_eq!(
@@ -444,17 +440,13 @@ mod tests {
let (buf, nread) = file
.read_exact_at_eof_ok(
start.into_u64(),
IoBufferMut::with_capacity_aligned(
len,
virtual_file::get_io_buffer_alignment(),
)
.slice_full(),
Vec::with_capacity(len).slice_full(),
ctx,
)
.await
.unwrap();
assert_eq!(nread, len);
assert_eq!(&buf.into_inner()[..], &content[start..(start + len)]);
assert_eq!(&buf.into_inner(), &content[start..(start + len)]);
}
};

View File

@@ -1,12 +1,29 @@
use std::collections::HashMap;
use utils::id::TimelineId;
use std::{collections::HashMap, time::Duration};
use super::remote_timeline_client::index::GcBlockingReason;
use tokio::time::Instant;
use utils::id::TimelineId;
type Storage = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
type TimelinesBlocked = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
/// GcBlock provides persistent (per-timeline) gc blocking.
/// Gc blocking state: the per-timeline blocking reasons plus an optional time-based
/// deadline tied to lsn leases.
#[derive(Default)]
struct Storage {
/// Blocking reasons recorded per timeline.
timelines_blocked: TimelinesBlocked,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<Instant>,
}
impl Storage {
    /// Returns `true` while a lsn lease deadline is set and has not been reached yet.
    fn is_blocked_by_lsn_lease_deadline(&self) -> bool {
        match self.lsn_lease_deadline {
            Some(deadline) => Instant::now() < deadline,
            // Without a deadline there is no time-based block.
            None => false,
        }
    }
}
/// GcBlock provides persistent (per-timeline) gc blocking and facilitates transient time based gc
/// blocking.
#[derive(Default)]
pub(crate) struct GcBlock {
/// The timelines which have current reasons to block gc.
@@ -49,6 +66,17 @@ impl GcBlock {
}
}
/// Blocks GC on lsn lease grounds until `lsn_lease_length` from now.
///
/// The lease mapping is not persisted to disk. Delaying GC by one lease length
/// guarantees that every lease granted earlier gets a chance to be renewed before the
/// first GC after a restart or a transition from AttachedMulti to AttachedSingle.
pub(super) fn set_lsn_lease_deadline(&self, lsn_lease_length: Duration) {
    // Compute the deadline before taking the lock so the lock is held only for the store.
    let blocked_until = Some(Instant::now() + lsn_lease_length);
    let mut storage = self.reasons.lock().unwrap();
    storage.lsn_lease_deadline = blocked_until;
}
/// Describe the current gc blocking reasons.
///
/// TODO: make this json serializable.
@@ -74,7 +102,7 @@ impl GcBlock {
) -> anyhow::Result<bool> {
let (added, uploaded) = {
let mut g = self.reasons.lock().unwrap();
let set = g.entry(timeline.timeline_id).or_default();
let set = g.timelines_blocked.entry(timeline.timeline_id).or_default();
let added = set.insert(reason);
// LOCK ORDER: intentionally hold the lock, see self.reasons.
@@ -105,7 +133,7 @@ impl GcBlock {
let (remaining_blocks, uploaded) = {
let mut g = self.reasons.lock().unwrap();
match g.entry(timeline.timeline_id) {
match g.timelines_blocked.entry(timeline.timeline_id) {
Entry::Occupied(mut oe) => {
let set = oe.get_mut();
set.remove(reason);
@@ -119,7 +147,7 @@ impl GcBlock {
}
}
let remaining_blocks = g.len();
let remaining_blocks = g.timelines_blocked.len();
// LOCK ORDER: intentionally hold the lock while scheduling; see self.reasons
let uploaded = timeline
@@ -144,11 +172,11 @@ impl GcBlock {
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
let unblocked = {
let mut g = self.reasons.lock().unwrap();
if g.is_empty() {
if g.timelines_blocked.is_empty() {
return;
}
g.remove(&timeline.timeline_id);
g.timelines_blocked.remove(&timeline.timeline_id);
BlockingReasons::clean_and_summarize(g).is_none()
};
@@ -159,10 +187,11 @@ impl GcBlock {
}
/// Initialize with the non-deleted timelines of this tenant.
pub(crate) fn set_scanned(&self, scanned: Storage) {
pub(crate) fn set_scanned(&self, scanned: TimelinesBlocked) {
let mut g = self.reasons.lock().unwrap();
assert!(g.is_empty());
g.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
assert!(g.timelines_blocked.is_empty());
g.timelines_blocked
.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
if let Some(reasons) = BlockingReasons::clean_and_summarize(g) {
tracing::info!(summary=?reasons, "initialized with gc blocked");
@@ -176,6 +205,7 @@ pub(super) struct Guard<'a> {
#[derive(Debug)]
pub(crate) struct BlockingReasons {
tenant_blocked_by_lsn_lease_deadline: bool,
timelines: usize,
reasons: enumset::EnumSet<GcBlockingReason>,
}
@@ -184,8 +214,8 @@ impl std::fmt::Display for BlockingReasons {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} timelines block for {:?}",
self.timelines, self.reasons
"tenant_blocked_by_lsn_lease_deadline: {}, {} timelines block for {:?}",
self.tenant_blocked_by_lsn_lease_deadline, self.timelines, self.reasons
)
}
}
@@ -193,13 +223,15 @@ impl std::fmt::Display for BlockingReasons {
impl BlockingReasons {
fn clean_and_summarize(mut g: std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
let mut reasons = enumset::EnumSet::empty();
g.retain(|_key, value| {
g.timelines_blocked.retain(|_key, value| {
reasons = reasons.union(*value);
!value.is_empty()
});
if !g.is_empty() {
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if !g.timelines_blocked.is_empty() || blocked_by_lsn_lease_deadline {
Some(BlockingReasons {
timelines: g.len(),
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
reasons,
})
} else {
@@ -208,14 +240,17 @@ impl BlockingReasons {
}
fn summarize(g: &std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
if g.is_empty() {
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if g.timelines_blocked.is_empty() && !blocked_by_lsn_lease_deadline {
None
} else {
let reasons = g
.timelines_blocked
.values()
.fold(enumset::EnumSet::empty(), |acc, next| acc.union(*next));
Some(BlockingReasons {
timelines: g.len(),
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
reasons,
})
}

View File

@@ -219,11 +219,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
+ TEMP_FILE_SUFFIX;
let tmp_path = path_with_suffix_extension(&path, &rand_suffix);
fs::rename(path.as_ref(), &tmp_path).await?;
fs::File::open(parent)
.await?
.sync_all()
.await
.maybe_fatal_err("safe_rename_tenant_dir")?;
fs::File::open(parent).await?.sync_all().await?;
Ok(tmp_path)
}
@@ -953,6 +949,12 @@ impl TenantManager {
(LocationMode::Attached(attach_conf), Some(TenantSlot::Attached(tenant))) => {
match attach_conf.generation.cmp(&tenant.generation) {
Ordering::Equal => {
if attach_conf.attach_mode == AttachmentMode::Single {
tenant
.gc_block
.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
}
// A transition from Attached to Attached in the same generation, we may
// take our fast path and just provide the updated configuration
// to the tenant.

View File

@@ -178,7 +178,6 @@ async fn download_object<'a>(
destination_file
.flush()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("flush source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -186,7 +185,6 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -234,7 +232,6 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;

View File

@@ -40,15 +40,15 @@ use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadPlanner,
VectoredReadCoalesceMode, VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::dio::IoBufferMut;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::virtual_file::{self, VirtualFile};
use crate::{walrecord, TEMP_FILE_SUFFIX};
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::BytesMut;
use camino::{Utf8Path, Utf8PathBuf};
use futures::StreamExt;
use itertools::Itertools;
@@ -572,7 +572,7 @@ impl DeltaLayerWriterInner {
ensure!(
metadata.len() <= S3_UPLOAD_LIMIT,
"Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!",
file.path(),
file.path,
metadata.len()
);
@@ -589,9 +589,7 @@ impl DeltaLayerWriterInner {
);
// fsync the file
file.sync_all()
.await
.maybe_fatal_err("delta_layer sync_all")?;
file.sync_all().await?;
trace!("created delta layer {}", self.path);
@@ -790,7 +788,7 @@ impl DeltaLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
let file = VirtualFile::open_v2(path, ctx)
let file = VirtualFile::open(path, ctx)
.await
.context("open layer file")?;
@@ -991,8 +989,7 @@ impl DeltaLayerInner {
.0
.into();
let buf_size = Self::get_min_read_buffer_size(&reads, max_vectored_read_bytes);
let align = virtual_file::get_io_buffer_alignment();
let mut buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
let mut buf = Some(BytesMut::with_capacity(buf_size));
// Note that reads are processed in reverse order (from highest key+lsn).
// This is the order that `ReconstructState` requires such that it can
@@ -1011,7 +1008,7 @@ impl DeltaLayerInner {
blob_meta.key,
PageReconstructError::Other(anyhow!(
"Failed to read blobs from virtual file {}: {}",
self.file.path(),
self.file.path,
kind
)),
);
@@ -1019,7 +1016,7 @@ impl DeltaLayerInner {
// We have "lost" the buffer since the lower level IO api
// doesn't return the buffer on error. Allocate a new one.
buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
buf = Some(BytesMut::with_capacity(buf_size));
continue;
}
@@ -1037,7 +1034,7 @@ impl DeltaLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path(),
self.file.path,
))),
);
@@ -1055,7 +1052,7 @@ impl DeltaLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to deserialize blob from virtual file {}",
self.file.path(),
self.file.path,
))),
);
@@ -1136,7 +1133,7 @@ impl DeltaLayerInner {
ctx: &RequestContext,
) -> anyhow::Result<usize> {
use crate::tenant::vectored_blob_io::{
BlobMeta, ChunkedVectoredReadBuilder, VectoredReadExtended,
BlobMeta, VectoredReadBuilder, VectoredReadExtended,
};
use futures::stream::TryStreamExt;
@@ -1186,15 +1183,15 @@ impl DeltaLayerInner {
let mut prev: Option<(Key, Lsn, BlobRef)> = None;
let mut read_builder: Option<ChunkedVectoredReadBuilder> = None;
let mut read_builder: Option<VectoredReadBuilder> = None;
let read_mode = VectoredReadCoalesceMode::get();
let max_read_size = self
.max_vectored_read_bytes
.map(|x| x.0.get())
.unwrap_or(8192);
let align = virtual_file::get_io_buffer_alignment();
let mut buffer = Some(IoBufferMut::with_capacity_aligned(max_read_size, align));
let mut buffer = Some(BytesMut::with_capacity(max_read_size));
// FIXME: buffering of DeltaLayerWriter
let mut per_blob_copy = Vec::new();
@@ -1231,12 +1228,12 @@ impl DeltaLayerInner {
{
None
} else {
read_builder.replace(ChunkedVectoredReadBuilder::new(
read_builder.replace(VectoredReadBuilder::new(
offsets.start.pos(),
offsets.end.pos(),
meta,
max_read_size,
align,
read_mode,
))
}
} else {
@@ -1553,12 +1550,12 @@ impl<'a> DeltaLayerIterator<'a> {
let vectored_blob_reader = VectoredBlobReader::new(&self.delta_layer.file);
let mut next_batch = std::collections::VecDeque::new();
let buf_size = plan.size();
let align = virtual_file::get_io_buffer_alignment();
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
let buf = BytesMut::with_capacity(buf_size);
let blobs_buf = vectored_blob_reader
.read_blobs(&plan, buf, self.ctx)
.await?;
let view = BufView::new_slice(&blobs_buf.buf);
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let blob_read = meta.read(&view).await?;
let value = Value::des(&blob_read)?;
@@ -1933,9 +1930,7 @@ pub(crate) mod test {
&vectored_reads,
constants::MAX_VECTORED_READ_BYTES,
);
let align = virtual_file::get_io_buffer_alignment();
let mut buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
let mut buf = Some(BytesMut::with_capacity(buf_size));
for read in vectored_reads {
let blobs_buf = vectored_blob_reader

View File

@@ -40,12 +40,11 @@ use crate::tenant::vectored_blob_io::{
VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::dio::IoBufferMut;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::virtual_file::{self, VirtualFile};
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use camino::{Utf8Path, Utf8PathBuf};
use hex;
use itertools::Itertools;
@@ -389,7 +388,7 @@ impl ImageLayerInner {
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
let file = VirtualFile::open_v2(path, ctx)
let file = VirtualFile::open(path, ctx)
.await
.context("open layer file")?;
let file_id = page_cache::next_file_id();
@@ -543,15 +542,14 @@ impl ImageLayerInner {
.await?;
let vectored_blob_reader = VectoredBlobReader::new(&self.file);
let align = virtual_file::get_io_buffer_alignment();
let mut key_count = 0;
for read in plan.into_iter() {
let buf_size = read.size();
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
let buf = BytesMut::with_capacity(buf_size);
let blobs_buf = vectored_blob_reader.read_blobs(&read, buf, ctx).await?;
let view = BufView::new_slice(&blobs_buf.buf);
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await?;
@@ -599,13 +597,13 @@ impl ImageLayerInner {
);
}
let align = virtual_file::get_io_buffer_alignment();
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
let buf = BytesMut::with_capacity(buf_size);
let res = vectored_blob_reader.read_blobs(&read, buf, ctx).await;
match res {
Ok(blobs_buf) => {
let view = BufView::new_slice(&blobs_buf.buf);
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await;
@@ -616,7 +614,7 @@ impl ImageLayerInner {
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path(),
self.file.path,
))),
);
@@ -637,7 +635,7 @@ impl ImageLayerInner {
blob_meta.key,
PageReconstructError::from(anyhow!(
"Failed to read blobs from virtual file {}: {}",
self.file.path(),
self.file.path,
kind
)),
);
@@ -891,9 +889,7 @@ impl ImageLayerWriterInner {
// set inner.file here. The first read will have to re-open it.
// fsync the file
file.sync_all()
.await
.maybe_fatal_err("image_layer sync_all")?;
file.sync_all().await?;
trace!("created image layer {}", self.path);
@@ -1041,12 +1037,12 @@ impl<'a> ImageLayerIterator<'a> {
let vectored_blob_reader = VectoredBlobReader::new(&self.image_layer.file);
let mut next_batch = std::collections::VecDeque::new();
let buf_size = plan.size();
let align = virtual_file::get_io_buffer_alignment();
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
let buf = BytesMut::with_capacity(buf_size);
let blobs_buf = vectored_blob_reader
.read_blobs(&plan, buf, self.ctx)
.await?;
let view = BufView::new_slice(&blobs_buf.buf);
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await?;
next_batch.push_back((

View File

@@ -809,9 +809,9 @@ impl InMemoryLayer {
match l0_flush_global_state {
l0_flush::Inner::Direct { .. } => {
let file_contents = inner.file.load_to_buf(ctx).await?;
let file_contents: Vec<u8> = inner.file.load_to_vec(ctx).await?;
let file_contents = Bytes::copy_from_slice(&file_contents[..]);
let file_contents = Bytes::from(file_contents);
for (key, vec_map) in inner.index.iter() {
// Write all page versions

View File

@@ -9,7 +9,6 @@ use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice};
use crate::{
assert_u64_eq_usize::{U64IsUsize, UsizeIsU64},
context::RequestContext,
virtual_file::{self, dio::IoBufferMut, owned_buffers_io::io_buf_aligned::IoBufAlignedMut},
};
/// The file interface we require. At runtime, this is a [`crate::tenant::ephemeral_file::EphemeralFile`].
@@ -25,7 +24,7 @@ pub trait File: Send {
/// [`std::io::ErrorKind::UnexpectedEof`] error if the file is shorter than `start+dst.len()`.
///
/// No guarantees are made about the remaining bytes in `dst` in case of a short read.
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>(
&'b self,
start: u64,
dst: Slice<B>,
@@ -228,12 +227,7 @@ where
// Execute physical reads and fill the logical read buffers
// TODO: pipelined reads; prefetch;
let get_io_buffer = |nchunks| {
IoBufferMut::with_capacity_aligned(
nchunks * DIO_CHUNK_SIZE,
virtual_file::get_io_buffer_alignment(),
)
};
let get_io_buffer = |nchunks| Vec::with_capacity(nchunks * DIO_CHUNK_SIZE);
for PhysicalRead {
start_chunk_no,
nchunks,
@@ -465,10 +459,7 @@ mod tests {
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
let file = InMemoryFile::new_random(10);
let test_read = |pos, len| {
let buf = IoBufferMut::with_capacity_aligned_zeroed(
len,
virtual_file::get_io_buffer_alignment(),
);
let buf = vec![0; len];
let fut = file.read_exact_at_eof_ok(pos, buf.slice_full(), &ctx);
use futures::FutureExt;
let (slice, nread) = fut
@@ -479,9 +470,9 @@ mod tests {
buf.truncate(nread);
buf
};
assert_eq!(&test_read(0, 1), &file.content[0..1]);
assert_eq!(&test_read(1, 2), &file.content[1..3]);
assert_eq!(&test_read(9, 2), &file.content[9..]);
assert_eq!(test_read(0, 1), &file.content[0..1]);
assert_eq!(test_read(1, 2), &file.content[1..3]);
assert_eq!(test_read(9, 2), &file.content[9..]);
assert!(test_read(10, 2).is_empty());
assert!(test_read(11, 2).is_empty());
}
@@ -618,7 +609,7 @@ mod tests {
}
impl<'x> File for RecorderFile<'x> {
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>(
&'b self,
start: u64,
dst: Slice<B>,
@@ -791,7 +782,7 @@ mod tests {
2048, 1024 => Err("foo".to_owned()),
};
let buf = IoBufferMut::with_capacity_aligned(512, 512);
let buf = Vec::with_capacity(512);
let (buf, nread) = mock_file
.read_exact_at_eof_ok(0, buf.slice_full(), &ctx)
.await
@@ -799,7 +790,7 @@ mod tests {
assert_eq!(nread, 512);
assert_eq!(&buf.into_inner()[..nread], &[0; 512]);
let buf = IoBufferMut::with_capacity_aligned(512, 512);
let buf = Vec::with_capacity(512);
let (buf, nread) = mock_file
.read_exact_at_eof_ok(512, buf.slice_full(), &ctx)
.await
@@ -807,7 +798,7 @@ mod tests {
assert_eq!(nread, 512);
assert_eq!(&buf.into_inner()[..nread], &[1; 512]);
let buf = IoBufferMut::with_capacity_aligned(512, 512);
let buf = Vec::with_capacity(512);
let (buf, nread) = mock_file
.read_exact_at_eof_ok(1024, buf.slice_full(), &ctx)
.await
@@ -815,7 +806,7 @@ mod tests {
assert_eq!(nread, 10);
assert_eq!(&buf.into_inner()[..nread], &[2; 10]);
let buf = IoBufferMut::with_capacity_aligned(1024, 512);
let buf = Vec::with_capacity(1024);
let err = mock_file
.read_exact_at_eof_ok(2048, buf.slice_full(), &ctx)
.await

View File

@@ -330,6 +330,7 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
RequestContext::todo_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let mut first = true;
tenant.gc_block.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
loop {
tokio::select! {
_ = cancel.cancelled() => {

View File

@@ -66,7 +66,6 @@ use std::{
use crate::{
aux_file::AuxFileSizeEstimator,
tenant::{
config::AttachmentMode,
layer_map::{LayerMap, SearchResult},
metadata::TimelineMetadata,
storage_layer::{inmemory_layer::IndexEntry, PersistentLayerDesc},
@@ -1325,38 +1324,16 @@ impl Timeline {
Ok(())
}
/// Initializes an LSN lease. The function will return an error if the requested LSN is less than the `latest_gc_cutoff_lsn`.
pub(crate) fn init_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, true, ctx)
}
/// Renews a lease at a particular LSN. The requested LSN is not validated against the `latest_gc_cutoff_lsn` when we are in the grace period.
pub(crate) fn renew_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, false, ctx)
}
/// Obtains a temporary lease blocking garbage collection for the given LSN.
///
/// If we are in `AttachedSingle` mode and are not blocked by the lsn lease deadline, this function will error
/// if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is no existing request present.
///
/// If there is an existing lease in the map, the lease will be renewed only if the request extends the lease.
/// The returned lease is therefore the maximum between the existing lease and the requesting lease.
fn make_lsn_lease(
/// This function will error if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is also
/// no existing lease to renew. If there is an existing lease in the map, the lease will be renewed only if
/// the request extends the lease. The returned lease is therefore the maximum between the existing lease and
/// the requesting lease.
pub(crate) fn make_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
init: bool,
_ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
let lease = {
@@ -1370,8 +1347,8 @@ impl Timeline {
let entry = gc_info.leases.entry(lsn);
match entry {
Entry::Occupied(mut occupied) => {
let lease = {
if let Entry::Occupied(mut occupied) = entry {
let existing_lease = occupied.get_mut();
if valid_until > existing_lease.valid_until {
existing_lease.valid_until = valid_until;
@@ -1383,28 +1360,20 @@ impl Timeline {
}
existing_lease.clone()
}
Entry::Vacant(vacant) => {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff) if we are in AttachedSingle and
// not blocked by the lsn lease deadline.
let validate = {
let conf = self.tenant_conf.load();
conf.location.attach_mode == AttachmentMode::Single
&& !conf.is_gc_blocked_by_lsn_lease_deadline()
};
if init || validate {
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
} else {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff)
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
let dt: DateTime<Utc> = valid_until.into();
info!("lease created, valid until {}", dt);
vacant.insert(LsnLease { valid_until }).clone()
entry.or_insert(LsnLease { valid_until }).clone()
}
}
};
lease
};
Ok(lease)
@@ -1981,6 +1950,8 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.lsn_lease_length)
}
// TODO(yuchen): remove unused flag after implementing https://github.com/neondatabase/neon/issues/8072
#[allow(unused)]
pub(crate) fn get_lsn_lease_length_for_ts(&self) -> Duration {
let tenant_conf = self.tenant_conf.load();
tenant_conf

View File

@@ -18,7 +18,7 @@
use std::collections::BTreeMap;
use std::ops::Deref;
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use pageserver_api::key::Key;
use tokio::io::AsyncWriteExt;
use tokio_epoll_uring::BoundedBuf;
@@ -27,7 +27,6 @@ use utils::vec_map::VecMap;
use crate::context::RequestContext;
use crate::tenant::blob_io::{BYTE_UNCOMPRESSED, BYTE_ZSTD, LEN_COMPRESSION_BIT_MASK};
use crate::virtual_file::dio::IoBufferMut;
use crate::virtual_file::{self, VirtualFile};
/// Metadata bundled with the start and end offset of a blob.
@@ -159,7 +158,7 @@ impl std::fmt::Display for VectoredBlob {
/// Return type of [`VectoredBlobReader::read_blobs`]
pub struct VectoredBlobsBuf {
/// Buffer for all blobs in this read
pub buf: IoBufferMut,
pub buf: BytesMut,
/// Offsets into the buffer and metadata for all blobs in this read
pub blobs: Vec<VectoredBlob>,
}
@@ -186,7 +185,171 @@ pub(crate) enum VectoredReadExtended {
No,
}
/// A vectored read builder that tries to coalesce all reads that fit in a chunk.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
/// Selects how reads are coalesced, i.e. which concrete builder
/// `VectoredReadBuilder` wraps when it is constructed.
pub enum VectoredReadCoalesceMode {
/// Only coalesce exactly adjacent reads.
AdjacentOnly,
/// In addition to adjacent reads, also consider reads whose corresponding
/// `end` and `start` offsets reside at the same chunk.
///
/// The payload is the chunk size handed to `ChunkedVectoredReadBuilder::new`.
Chunked(usize),
}
impl VectoredReadCoalesceMode {
    /// Derives the coalescing mode from the raw IO buffer alignment setting.
    ///
    /// A raw alignment of 0 yields [`VectoredReadCoalesceMode::AdjacentOnly`], which selects
    /// [`AdjacentVectoredReadBuilder`]. Any other value yields
    /// [`VectoredReadCoalesceMode::Chunked`] carrying that value, which selects
    /// [`ChunkedVectoredReadBuilder`].
    pub(crate) fn get() -> Self {
        match virtual_file::get_io_buffer_alignment_raw() {
            0 => Self::AdjacentOnly,
            chunk_size => Self::Chunked(chunk_size),
        }
    }
}
/// Builder for a single vectored read, dispatching to one of two coalescing strategies.
pub(crate) enum VectoredReadBuilder {
/// Built for `VectoredReadCoalesceMode::AdjacentOnly`.
Adjacent(AdjacentVectoredReadBuilder),
/// Built for `VectoredReadCoalesceMode::Chunked`.
Chunked(ChunkedVectoredReadBuilder),
}
impl VectoredReadBuilder {
    /// Shared constructor behind [`Self::new`] and [`Self::new_streaming`].
    ///
    /// `mode` decides which concrete builder is created; `max_read_size` is forwarded
    /// to it unchanged (`None` means the caller did not supply a limit).
    fn new_impl(
        start_offset: u64,
        end_offset: u64,
        meta: BlobMeta,
        max_read_size: Option<usize>,
        mode: VectoredReadCoalesceMode,
    ) -> Self {
        if let VectoredReadCoalesceMode::Chunked(chunk_size) = mode {
            let chunked = ChunkedVectoredReadBuilder::new(
                start_offset,
                end_offset,
                meta,
                max_read_size,
                chunk_size,
            );
            return Self::Chunked(chunked);
        }

        // Remaining mode: `VectoredReadCoalesceMode::AdjacentOnly`.
        let adjacent =
            AdjacentVectoredReadBuilder::new(start_offset, end_offset, meta, max_read_size);
        Self::Adjacent(adjacent)
    }

    /// Starts a builder whose first blob spans `start_offset..end_offset`,
    /// passing `max_read_size` on as the size limit.
    pub(crate) fn new(
        start_offset: u64,
        end_offset: u64,
        meta: BlobMeta,
        max_read_size: usize,
        mode: VectoredReadCoalesceMode,
    ) -> Self {
        let limit = Some(max_read_size);
        Self::new_impl(start_offset, end_offset, meta, limit, mode)
    }

    /// Same as [`Self::new`], but without a size limit.
    pub(crate) fn new_streaming(
        start_offset: u64,
        end_offset: u64,
        meta: BlobMeta,
        mode: VectoredReadCoalesceMode,
    ) -> Self {
        let no_limit = None;
        Self::new_impl(start_offset, end_offset, meta, no_limit, mode)
    }

    /// Forwards to the wrapped builder's `extend` and returns its verdict.
    pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
        match self {
            Self::Adjacent(inner) => inner.extend(start, end, meta),
            Self::Chunked(inner) => inner.extend(start, end, meta),
        }
    }

    /// Consumes the builder and produces the [`VectoredRead`] of the wrapped builder.
    pub(crate) fn build(self) -> VectoredRead {
        match self {
            Self::Adjacent(inner) => inner.build(),
            Self::Chunked(inner) => inner.build(),
        }
    }

    /// Size of the read as reported by the wrapped builder.
    pub(crate) fn size(&self) -> usize {
        match self {
            Self::Adjacent(inner) => inner.size(),
            Self::Chunked(inner) => inner.size(),
        }
    }
}
/// A vectored read builder that only coalesces blobs which start exactly
/// where the read built so far ends.
pub(crate) struct AdjacentVectoredReadBuilder {
/// Start offset of the read.
start: u64,
/// End offset of the read.
end: u64,
/// Start offset and metadata for each blob in this read
blobs_at: VecMap<u64, BlobMeta>,
/// Upper bound on the read size checked by `extend`; `None` disables the check.
max_read_size: Option<usize>,
}
impl AdjacentVectoredReadBuilder {
    /// Begin a new vectored read that initially covers a single blob.
    ///
    /// The first blob is accepted without comparing it against `max_read_size`, so a blob
    /// larger than the configured value can still be read. Such a builder is effectively
    /// single use, since no further blob will fit.
    pub(crate) fn new(
        start_offset: u64,
        end_offset: u64,
        meta: BlobMeta,
        max_read_size: Option<usize>,
    ) -> Self {
        let mut first_blob = VecMap::default();
        first_blob
            .append(start_offset, meta)
            .expect("First insertion always succeeds");

        Self {
            start: start_offset,
            end: end_offset,
            blobs_at: first_blob,
            max_read_size,
        }
    }

    /// Try to append a blob to the read.
    ///
    /// The blob is accepted only if it starts exactly at the current end of the read
    /// and the resulting size stays within the max read size (when one is set).
    pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
        tracing::trace!(start, end, "trying to extend");

        let blob_len = (end - start) as usize;
        let fits_in_limit = match self.max_read_size {
            Some(limit) => self.size() + blob_len <= limit,
            None => true,
        };

        if self.end != start || !fits_in_limit {
            return VectoredReadExtended::No;
        }

        self.end = end;
        self.blobs_at
            .append(start, meta)
            .expect("LSNs are ordered within vectored reads");
        VectoredReadExtended::Yes
    }

    /// Number of bytes between the start and the end of the read built so far.
    pub(crate) fn size(&self) -> usize {
        let span = self.end - self.start;
        span as usize
    }

    /// Finish building, handing the accumulated range and blobs to a [`VectoredRead`].
    pub(crate) fn build(self) -> VectoredRead {
        let Self {
            start,
            end,
            blobs_at,
            ..
        } = self;
        VectoredRead {
            start,
            end,
            blobs_at,
        }
    }
}
pub(crate) struct ChunkedVectoredReadBuilder {
/// Start block number
start_blk_no: usize,
@@ -210,7 +373,7 @@ impl ChunkedVectoredReadBuilder {
/// Note that by design, this does not check against reading more than `max_read_size` to
/// support reading larger blobs than the configuration value. The builder will be single use
/// however after that.
fn new_impl(
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
@@ -233,25 +396,6 @@ impl ChunkedVectoredReadBuilder {
}
}
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align)
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, align)
}
/// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk.
///
/// The resulting size also must be below the max read size.
@@ -330,17 +474,17 @@ pub struct VectoredReadPlanner {
max_read_size: usize,
align: usize,
mode: VectoredReadCoalesceMode,
}
impl VectoredReadPlanner {
pub fn new(max_read_size: usize) -> Self {
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
Self {
blobs: BTreeMap::new(),
prev: None,
max_read_size,
align,
mode,
}
}
@@ -401,7 +545,7 @@ impl VectoredReadPlanner {
}
pub fn finish(self) -> Vec<VectoredRead> {
let mut current_read_builder: Option<ChunkedVectoredReadBuilder> = None;
let mut current_read_builder: Option<VectoredReadBuilder> = None;
let mut reads = Vec::new();
for (key, blobs_for_key) in self.blobs {
@@ -414,12 +558,12 @@ impl VectoredReadPlanner {
};
if extended == VectoredReadExtended::No {
let next_read_builder = ChunkedVectoredReadBuilder::new(
let next_read_builder = VectoredReadBuilder::new(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.max_read_size,
self.align,
self.mode,
);
let prev_read_builder = current_read_builder.replace(next_read_builder);
@@ -461,7 +605,7 @@ impl<'a> VectoredBlobReader<'a> {
pub async fn read_blobs(
&self,
read: &VectoredRead,
buf: IoBufferMut,
buf: BytesMut,
ctx: &RequestContext,
) -> Result<VectoredBlobsBuf, std::io::Error> {
assert!(read.size() > 0);
@@ -544,7 +688,7 @@ impl<'a> VectoredBlobReader<'a> {
/// `handle` gets called and when the current key would just exceed the read_size and
/// max_cnt constraints.
pub struct StreamingVectoredReadPlanner {
read_builder: Option<ChunkedVectoredReadBuilder>,
read_builder: Option<VectoredReadBuilder>,
// Arguments for previous blob passed into [`StreamingVectoredReadPlanner::handle`]
prev: Option<(Key, Lsn, u64)>,
/// Max read size per batch. This is not a strict limit. If there are [0, 100) and [100, 200), while the `max_read_size` is 150,
@@ -555,21 +699,21 @@ pub struct StreamingVectoredReadPlanner {
/// Size of the current batch
cnt: usize,
align: usize,
mode: VectoredReadCoalesceMode,
}
impl StreamingVectoredReadPlanner {
pub fn new(max_read_size: u64, max_cnt: usize) -> Self {
assert!(max_cnt > 0);
assert!(max_read_size > 0);
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
Self {
read_builder: None,
prev: None,
max_cnt,
max_read_size,
cnt: 0,
align,
mode,
}
}
@@ -618,11 +762,11 @@ impl StreamingVectoredReadPlanner {
}
None => {
self.read_builder = {
Some(ChunkedVectoredReadBuilder::new_streaming(
Some(VectoredReadBuilder::new_streaming(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.align,
self.mode,
))
};
}
@@ -946,10 +1090,9 @@ mod tests {
// Multiply by two (compressed data might need more space), and add a few bytes for the header
let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16;
let align = virtual_file::get_io_buffer_alignment();
let mut buf = IoBufferMut::with_capacity_aligned(reserved_bytes, align);
let mut buf = BytesMut::with_capacity(reserved_bytes);
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
let vectored_blob_reader = VectoredBlobReader::new(&file);
let meta = BlobMeta {
key: Key::MIN,
@@ -961,8 +1104,7 @@ mod tests {
if idx + 1 == offsets.len() {
continue;
}
let read_builder =
ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align);
let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, mode);
let read = read_builder.build();
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
assert_eq!(result.blobs.len(), 1);

View File

@@ -17,21 +17,16 @@ use crate::metrics::{StorageIoOperation, STORAGE_IO_SIZE, STORAGE_IO_TIME_METRIC
use crate::page_cache::{PageWriteGuard, PAGE_SZ};
use crate::tenant::TENANTS_SEGMENT_NAME;
use camino::{Utf8Path, Utf8PathBuf};
#[cfg(test)]
use dio::IoBufferMut;
use once_cell::sync::OnceCell;
use owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
use owned_buffers_io::io_buf_ext::FullSlice;
use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
use pageserver_api::shard::TenantShardId;
use std::fs::File;
use std::io::{Error, ErrorKind, Seek, SeekFrom};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
@@ -43,11 +38,10 @@ pub use io_engine::FeatureTestResult as IoEngineFeatureTestResult;
mod metadata;
mod open_options;
use self::owned_buffers_io::write::OwnedAsyncWriter;
pub(crate) use api::IoMode;
pub(crate) use api::DirectIoMode;
pub(crate) use io_engine::IoEngineKind;
pub(crate) use metadata::Metadata;
pub(crate) use open_options::*;
pub(crate) mod dio;
pub(crate) mod owned_buffers_io {
//! Abstractions for IO with owned buffers.
@@ -59,7 +53,6 @@ pub(crate) mod owned_buffers_io {
//! but for the time being we're proving out the primitives in the neon.git repo
//! for faster iteration.
pub(crate) mod io_buf_aligned;
pub(crate) mod io_buf_ext;
pub(crate) mod slice;
pub(crate) mod write;
@@ -68,176 +61,6 @@ pub(crate) mod owned_buffers_io {
}
}
/// Wrapper around `VirtualFileInner` whose variant records which IO mode the file was opened with.
#[derive(Debug)]
pub enum VirtualFile {
/// Produced by `open`, `create`, `open_with_options`, and by
/// `open_with_options_v2` when `get_io_mode()` is `IoMode::Buffered`.
Buffered(VirtualFileInner),
/// Produced by `open_with_options_v2` when `get_io_mode()` is `IoMode::Direct`;
/// the file is then opened with the `O_DIRECT` custom flag.
Direct(VirtualFileInner),
}
impl VirtualFile {
fn inner(&self) -> &VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
fn inner_mut(&mut self) -> &mut VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
fn into_inner(self) -> VirtualFileInner {
match self {
Self::Buffered(file) => file,
Self::Direct(file) => file,
}
}
/// Open a file in read-only mode. Like File::open.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let file = VirtualFileInner::open(path, ctx).await?;
Ok(Self::Buffered(file))
}
/// Open a file in read-only mode. Like File::open.
///
/// `O_DIRECT` will be enabled base on `virtual_file_io_mode`.
pub async fn open_v2<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let file = VirtualFileInner::create(path, ctx).await?;
Ok(Self::Buffered(file))
}
pub async fn create_v2<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
VirtualFile::open_with_options_v2(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
ctx,
)
.await
}
pub async fn open_with_options<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(Self::Buffered(file))
}
/// Opens `path` with the given options, picking the IO mode from `get_io_mode()`.
///
/// * `IoMode::Buffered`: `open_options` is used unchanged and the file is returned as
///   `Self::Buffered`.
/// * `IoMode::Direct` (this arm is only compiled on Linux): `O_DIRECT` is set on
///   `open_options` through `custom_flags`, and the file is returned as `Self::Direct`.
///
/// `open_options` is mutated in the direct case, which is why it is taken by `&mut`.
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
open_options: &mut OpenOptions, // Uses `&mut` here to add `O_DIRECT`.
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<Self, std::io::Error> {
let file = match get_io_mode() {
IoMode::Buffered => {
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Self::Buffered(file)
}
#[cfg(target_os = "linux")]
IoMode::Direct => {
// NOTE(review): presumably `custom_flags` replaces any custom flags the caller set
// earlier rather than OR-ing them in — confirm against the `OpenOptions` in use.
let file = VirtualFileInner::open_with_options(
path,
open_options.custom_flags(nix::libc::O_DIRECT),
ctx,
)
.await?;
Self::Direct(file)
}
};
Ok(file)
}
/// Returns the path stored on the inner file.
pub fn path(&self) -> &Utf8Path {
self.inner().path.as_path()
}
/// Forwards to `VirtualFileInner::crashsafe_overwrite` with the same arguments.
pub async fn crashsafe_overwrite<B: BoundedBuf<Buf = Buf> + Send, Buf: IoBuf + Send>(
final_path: Utf8PathBuf,
tmp_path: Utf8PathBuf,
content: B,
) -> std::io::Result<()> {
VirtualFileInner::crashsafe_overwrite(final_path, tmp_path, content).await
}
/// Forwards to `sync_all` on the inner file.
pub async fn sync_all(&self) -> Result<(), Error> {
self.inner().sync_all().await
}
/// Forwards to `sync_data` on the inner file.
pub async fn sync_data(&self) -> Result<(), Error> {
self.inner().sync_data().await
}
/// Forwards to `metadata` on the inner file.
pub async fn metadata(&self) -> Result<Metadata, Error> {
self.inner().metadata().await
}
/// Consumes the wrapper and calls `remove` on the inner file.
pub fn remove(self) {
self.into_inner().remove();
}
/// Forwards to `seek` on the inner file and returns its result.
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
self.inner_mut().seek(pos).await
}
/// Forwards to `read_exact_at` on the inner file, passing `slice`, `offset` and `ctx` through.
///
/// The buffer type must implement `IoBufAlignedMut`.
pub async fn read_exact_at<Buf>(
&self,
slice: Slice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> Result<Slice<Buf>, Error>
where
Buf: IoBufAlignedMut + Send,
{
self.inner().read_exact_at(slice, offset, ctx).await
}
/// Forwards to `read_exact_at_page` on the inner file; the page guard is passed in and
/// handed back on success.
pub async fn read_exact_at_page(
&self,
page: PageWriteGuard<'static>,
offset: u64,
ctx: &RequestContext,
) -> Result<PageWriteGuard<'static>, Error> {
self.inner().read_exact_at_page(page, offset, ctx).await
}
/// Forwards to `write_all_at` on the inner file.
///
/// The buffer is returned alongside the result, so the caller gets it back on both
/// success and failure.
pub async fn write_all_at<Buf: IoBuf + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
self.inner().write_all_at(buf, offset, ctx).await
}
/// Forwards to `write_all` on the inner file.
///
/// The buffer is returned alongside the result, so the caller gets it back on both
/// success and failure.
pub async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<usize, Error>) {
self.inner_mut().write_all(buf, ctx).await
}
}
///
/// A virtual file descriptor. You can use this just like std::fs::File, but internally
/// the underlying file is closed if the system is low on file descriptors,
@@ -254,7 +77,7 @@ impl VirtualFile {
/// 'tag' field is used to detect whether the handle still is valid or not.
///
#[derive(Debug)]
pub struct VirtualFileInner {
pub struct VirtualFile {
/// Lazy handle to the global file descriptor cache. The slot that this points to
/// might contain our File, or it may be empty, or it may contain a File that
/// belongs to a different VirtualFile.
@@ -527,12 +350,12 @@ macro_rules! with_file {
}};
}
impl VirtualFileInner {
impl VirtualFile {
/// Open a file in read-only mode. Like File::open.
pub async fn open<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
) -> Result<VirtualFile, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
@@ -541,7 +364,7 @@ impl VirtualFileInner {
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
) -> Result<VirtualFile, std::io::Error> {
Self::open_with_options(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
@@ -559,7 +382,7 @@ impl VirtualFileInner {
path: P,
open_options: &OpenOptions,
_ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
) -> Result<VirtualFileInner, std::io::Error> {
) -> Result<VirtualFile, std::io::Error> {
let path_ref = path.as_ref();
let path_str = path_ref.to_string();
let parts = path_str.split('/').collect::<Vec<&str>>();
@@ -590,7 +413,7 @@ impl VirtualFileInner {
open_options.open(path_ref.as_std_path()).await?
});
// Strip all options other than read and write (O_DIRECT).
// Strip all options other than read and write.
//
// It would perhaps be nicer to check just for the read and write flags
// explicitly, but OpenOptions doesn't contain any functions to read flags,
@@ -600,7 +423,7 @@ impl VirtualFileInner {
reopen_options.create_new(false);
reopen_options.truncate(false);
let vfile = VirtualFileInner {
let vfile = VirtualFile {
handle: RwLock::new(handle),
pos: 0,
path: path_ref.to_path_buf(),
@@ -643,7 +466,6 @@ impl VirtualFileInner {
&[]
};
utils::crashsafe::overwrite(&final_path, &tmp_path, content)
.maybe_fatal_err("crashsafe_overwrite")
})
.await
.expect("blocking task is never aborted")
@@ -653,7 +475,7 @@ impl VirtualFileInner {
pub async fn sync_all(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_all(file_guard).await;
res.maybe_fatal_err("sync_all")
res
})
}
@@ -661,7 +483,7 @@ impl VirtualFileInner {
pub async fn sync_data(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_data(file_guard).await;
res.maybe_fatal_err("sync_data")
res
})
}
@@ -781,7 +603,7 @@ impl VirtualFileInner {
ctx: &RequestContext,
) -> Result<Slice<Buf>, Error>
where
Buf: IoBufAlignedMut + Send,
Buf: IoBufMut + Send,
{
let assert_we_return_original_bounds = if cfg!(debug_assertions) {
Some((slice.stable_ptr() as usize, slice.bytes_total()))
@@ -1211,36 +1033,18 @@ impl tokio_epoll_uring::IoFd for FileGuard {
#[cfg(test)]
impl VirtualFile {
pub(crate) async fn read_blk(
&self,
blknum: u32,
ctx: &RequestContext,
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
self.inner().read_blk(blknum, ctx).await
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
self.inner_mut().read_to_end(buf, ctx).await
}
}
#[cfg(test)]
impl VirtualFileInner {
pub(crate) async fn read_blk(
&self,
blknum: u32,
ctx: &RequestContext,
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
use crate::page_cache::PAGE_SZ;
let align = get_io_buffer_alignment();
let slice = IoBufferMut::with_capacity_aligned(PAGE_SZ, align).slice_full();
let slice = Vec::with_capacity(PAGE_SZ).slice_full();
assert_eq!(slice.bytes_total(), PAGE_SZ);
let slice = self
.read_exact_at(slice, blknum as u64 * (PAGE_SZ as u64), ctx)
.await?;
Ok(crate::tenant::block_io::BlockLease::IoBufferMut(
slice.into_inner(),
))
Ok(crate::tenant::block_io::BlockLease::Vec(slice.into_inner()))
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
@@ -1262,7 +1066,7 @@ impl VirtualFileInner {
}
}
impl Drop for VirtualFileInner {
impl Drop for VirtualFile {
/// If a VirtualFile is dropped, close the underlying file if it was open.
fn drop(&mut self) {
let handle = self.handle.get_mut();
@@ -1343,9 +1147,7 @@ pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize)
panic!("virtual_file::init called twice");
}
if set_io_buffer_alignment(io_buffer_alignment).is_err() {
panic!(
"IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}"
);
panic!("IO buffer alignment ({io_buffer_alignment}) is not a power of two");
}
io_engine::init(engine);
crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64);
@@ -1372,16 +1174,14 @@ fn get_open_files() -> &'static OpenFiles {
static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT);
/// Returns true if the alignment is a power of two and is greater or equal to 512.
fn is_valid_io_buffer_alignment(align: usize) -> bool {
align.is_power_of_two() && align >= 512
/// Returns true if `x` is zero or a power of two.
fn is_zero_or_power_of_two(x: usize) -> bool {
(x == 0) || ((x & (x - 1)) == 0)
}
/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is
/// not a power of two or less than 512 bytes.
#[allow(unused)]
pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
if is_valid_io_buffer_alignment(align) {
if is_zero_or_power_of_two(align) {
IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed);
Ok(())
} else {
@@ -1389,19 +1189,19 @@ pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
}
}
/// Gets the io buffer alignment.
/// Gets the io buffer alignment requirement. Returns 0 if there is no requirement specified.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
/// This function should be used to check the raw config value.
pub(crate) fn get_io_buffer_alignment_raw() -> usize {
let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed);
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT";
if let Some(test_align) = utils::env::var(env_var_name) {
if is_valid_io_buffer_alignment(test_align) {
if is_zero_or_power_of_two(test_align) {
test_align
} else {
panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}");
panic!("IO buffer alignment ({test_align}) is not a power of two");
}
} else {
align
@@ -1411,22 +1211,20 @@ pub(crate) fn get_io_buffer_alignment() -> usize {
}
}
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::preferred() as u8);
pub(crate) fn set_io_mode(mode: IoMode) {
IO_MODE.store(mode as u8, std::sync::atomic::Ordering::Relaxed);
/// Gets the io buffer alignment requirement. Returns 1 if the alignment config is set to zero.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
let align = get_io_buffer_alignment_raw();
align.max(1)
}
pub(crate) fn get_io_mode() -> IoMode {
IoMode::try_from(IO_MODE.load(Ordering::Relaxed)).unwrap()
}
#[cfg(test)]
mod tests {
use crate::context::DownloadBehavior;
use crate::task_mgr::TaskKind;
use super::*;
use dio::IoBufferMut;
use owned_buffers_io::io_buf_ext::IoBufExt;
use owned_buffers_io::slice::SliceMutExt;
use rand::seq::SliceRandom;
@@ -1450,10 +1248,10 @@ mod tests {
impl MaybeVirtualFile {
async fn read_exact_at(
&self,
mut slice: tokio_epoll_uring::Slice<IoBufferMut>,
mut slice: tokio_epoll_uring::Slice<Vec<u8>>,
offset: u64,
ctx: &RequestContext,
) -> Result<tokio_epoll_uring::Slice<IoBufferMut>, Error> {
) -> Result<tokio_epoll_uring::Slice<Vec<u8>>, Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => file.read_exact_at(slice, offset, ctx).await,
MaybeVirtualFile::File(file) => {
@@ -1521,13 +1319,11 @@ mod tests {
len: usize,
ctx: &RequestContext,
) -> Result<String, Error> {
let slice = IoBufferMut::with_capacity_aligned(len, 512).slice_full();
let slice = Vec::with_capacity(len).slice_full();
assert_eq!(slice.bytes_total(), len);
let slice = self.read_exact_at(slice, pos, ctx).await?;
let buf = slice.into_inner();
assert_eq!(buf.len(), len);
let mut vec = Vec::with_capacity(buf.len());
vec.extend_from_slice(&buf);
let vec = slice.into_inner();
assert_eq!(vec.len(), len);
Ok(String::from_utf8(vec).unwrap())
}
}
@@ -1716,7 +1512,6 @@ mod tests {
const VIRTUAL_FILES: usize = 100;
const THREADS: usize = 100;
const SAMPLE: [u8; SIZE] = [0xADu8; SIZE];
let align = super::get_io_buffer_alignment();
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
let testdir = crate::config::PageServerConf::test_repo_dir("vfile_concurrency");
@@ -1732,7 +1527,7 @@ mod tests {
// Open the file many times.
let mut files = Vec::new();
for _ in 0..VIRTUAL_FILES {
let f = VirtualFileInner::open_with_options(
let f = VirtualFile::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
&ctx,
@@ -1753,7 +1548,7 @@ mod tests {
let files = files.clone();
let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error);
let hdl = rt.spawn(async move {
let mut buf = IoBufferMut::with_capacity_aligned_zeroed(SIZE, align);
let mut buf = vec![0u8; SIZE];
let mut rng = rand::rngs::OsRng;
for _ in 1..1000 {
let f = &files[rng.gen_range(0..files.len())];
@@ -1762,7 +1557,7 @@ mod tests {
.await
.unwrap()
.into_inner();
assert!(&buf == SAMPLE.as_slice());
assert!(buf == SAMPLE);
}
});
hdls.push(hdl);
@@ -1784,7 +1579,7 @@ mod tests {
let path = testdir.join("myfile");
let tmp_path = testdir.join("myfile.tmp");
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
@@ -1793,7 +1588,7 @@ mod tests {
assert!(!tmp_path.exists());
drop(file);
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
@@ -1816,7 +1611,7 @@ mod tests {
std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap();
assert!(tmp_path.exists());
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
.await
.unwrap();

View File

@@ -1,434 +0,0 @@
#![allow(unused)]
use core::slice;
use std::{
alloc::{self, Layout},
cmp,
mem::{ManuallyDrop, MaybeUninit},
ops::{Deref, DerefMut},
ptr::{addr_of_mut, NonNull},
};
use bytes::buf::UninitSlice;
/// Newtype around the raw allocation pointer, so that `Send` can be implemented for it.
#[derive(Debug)]
struct IoBufferPtr(*mut u8);
// SAFETY: We guarantee that no one besides `IoBufferPtr` itself has the raw pointer.
unsafe impl Send for IoBufferPtr {}
/// An aligned buffer type used for I/O.
#[derive(Debug)]
pub struct IoBufferMut {
/// Start of the allocation obtained from `std::alloc`.
ptr: IoBufferPtr,
/// Size of the allocation in bytes.
capacity: usize,
/// Number of initialized bytes at the start of the allocation.
len: usize,
/// Alignment the allocation was requested with; needed again to grow or free it.
align: usize,
}
impl IoBufferMut {
/// Constructs a new, empty `IoBufferMut` with exactly the specified capacity and alignment.
///
/// The buffer holds `capacity` bytes. It does not grow when written through `bytes::BufMut`
/// (see `remaining_mut`); it only grows through an explicit call to `reserve`.
///
/// # Panics
///
/// Panics with "Invalid layout" if the following requirements are not met:
/// * `align` must not be zero,
///
/// * `align` must be a power of two,
///
/// * `capacity`, when rounded up to the nearest multiple of `align`,
/// must not overflow isize (i.e., the rounded value must be
/// less than or equal to `isize::MAX`).
///
/// NOTE(review): `capacity == 0` is not rejected here, but `std::alloc::alloc` documents
/// zero-sized layouts as undefined behavior — confirm that no caller passes 0.
pub fn with_capacity_aligned(capacity: usize, align: usize) -> Self {
let layout = Layout::from_size_align(capacity, align).expect("Invalid layout");
// SAFETY: Making an allocation with a sized and aligned layout. The memory is manually freed with the same layout.
let ptr = unsafe {
let ptr = alloc::alloc(layout);
if ptr.is_null() {
alloc::handle_alloc_error(layout);
}
IoBufferPtr(ptr)
};
IoBufferMut {
ptr,
capacity,
len: 0,
align,
}
}
/// Constructs an `IoBufferMut` of the given capacity and alignment whose whole capacity is
/// filled with zero bytes, so that `len() == capacity`.
///
/// Panics under the same conditions as `with_capacity_aligned`.
pub fn with_capacity_aligned_zeroed(capacity: usize, align: usize) -> Self {
use bytes::BufMut;
let mut buf = Self::with_capacity_aligned(capacity, align);
buf.put_bytes(0, capacity);
// NOTE(review): `put_bytes` presumably already advanced `len` to `capacity` through
// `advance_mut`, which would make this assignment redundant — confirm before removing.
buf.len = capacity;
buf
}
/// Returns the total number of bytes the buffer can hold.
#[inline]
pub fn capacity(&self) -> usize {
self.capacity
}
/// Returns the alignment of the buffer.
#[inline]
pub fn align(&self) -> usize {
self.align
}
/// Returns the number of bytes in the buffer, also referred to as its 'length'.
#[inline]
pub fn len(&self) -> usize {
self.len
}
/// Force the length of the buffer to `new_len`.
///
/// # Safety
///
/// `new_len` must not exceed `capacity()` (only checked in debug builds), and the first
/// `new_len` bytes must be initialized, because `as_slice` exposes them as `&[u8]`.
#[inline]
unsafe fn set_len(&mut self, new_len: usize) {
debug_assert!(new_len <= self.capacity());
self.len = new_len;
}
#[inline]
fn as_ptr(&self) -> *const u8 {
self.ptr.0
}
#[inline]
fn as_mut_ptr(&mut self) -> *mut u8 {
self.ptr.0
}
/// Extracts a slice containing the entire buffer.
///
/// Equivalent to `&s[..]`.
#[inline]
fn as_slice(&self) -> &[u8] {
// SAFETY: The pointer is valid and `len` bytes are initialized.
unsafe { slice::from_raw_parts(self.as_ptr(), self.len) }
}
/// Extracts a mutable slice of the entire buffer.
///
/// Equivalent to `&mut s[..]`.
fn as_mut_slice(&mut self) -> &mut [u8] {
// SAFETY: The pointer is valid and `len` bytes are initialized.
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
}
/// Drops the all the contents of the buffer, setting its length to `0`.
#[inline]
pub fn clear(&mut self) {
self.len = 0;
}
/// Reserves capacity for at least `additional` more bytes to be inserted
/// in the given `IoBufferMut`. The collection may reserve more space to
/// speculatively avoid frequent reallocations. After calling `reserve`,
/// capacity will be greater than or equal to `self.len() + additional`.
/// Does nothing if capacity is already sufficient.
///
/// # Panics
///
/// Panics if the new capacity exceeds `isize::MAX` _bytes_.
pub fn reserve(&mut self, additional: usize) {
if additional > self.capacity() - self.len() {
self.reserve_inner(additional);
}
}
/// Shortens the buffer to its first `len` bytes.
///
/// A `len` larger than the current length leaves the buffer untouched.
pub fn truncate(&mut self, len: usize) {
    if len <= self.len {
        self.len = len;
    }
}
/// Grows the allocation so that at least `self.len() + additional` bytes fit.
///
/// The new capacity is the larger of twice the old capacity and the required capacity;
/// the alignment is kept.
///
/// # Panics
///
/// Panics with "capacity overflow" if the required capacity overflows `usize` or is
/// rejected by `is_valid_alloc`, and with "Invalid layout" if `Layout::from_size_align`
/// rejects the new size.
fn reserve_inner(&mut self, additional: usize) {
let Some(required_cap) = self.len().checked_add(additional) else {
capacity_overflow()
};
let old_capacity = self.capacity();
let align = self.align();
// This guarantees exponential growth. The doubling cannot overflow
// because `cap <= isize::MAX` and the type of `cap` is `usize`.
let cap = cmp::max(old_capacity * 2, required_cap);
if !is_valid_alloc(cap) {
capacity_overflow()
}
let new_layout = Layout::from_size_align(cap, self.align()).expect("Invalid layout");
let old_ptr = self.as_mut_ptr();
// SAFETY: the old allocation was made with `std::alloc::alloc` using the same layout,
// and a null pointer is handed to `handle_alloc_error` instead of being used.
let (ptr, cap) = unsafe {
let old_layout = Layout::from_size_align_unchecked(old_capacity, align);
let ptr = alloc::realloc(old_ptr, old_layout, new_layout.size());
if ptr.is_null() {
alloc::handle_alloc_error(new_layout);
}
(IoBufferPtr(ptr), cap)
};
self.ptr = ptr;
self.capacity = cap;
}
/// Consumes the buffer without running its destructor and returns the initialized bytes
/// (`len` of them) as a mutable slice with a caller-chosen lifetime.
///
/// The allocation is never freed, because `Drop` is suppressed via `ManuallyDrop`.
pub fn leak<'a>(self) -> &'a mut [u8] {
let mut buf = ManuallyDrop::new(self);
// SAFETY: leaking the buffer as intended.
unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len) }
}
}
/// Diverges with a "capacity overflow" panic; used when a requested capacity cannot be represented.
fn capacity_overflow() -> ! {
panic!("capacity overflow")
}
// Two things must hold for every allocation size we hand to the allocator:
// * it never exceeds `isize::MAX` bytes, and
// * it never wraps around `usize::MAX`, which would allocate too little.
//
// On 64-bit targets an overflow check alone is enough, since an allocation of
// more than `isize::MAX` bytes will surely fail anyway. On 32-bit and 16-bit
// targets an explicit guard is needed, because a platform may be able to use
// all 4GB in user-space (e.g., PAE or x32).
#[inline]
fn is_valid_alloc(alloc_size: usize) -> bool {
    usize::BITS >= 64 || alloc_size <= isize::MAX as usize
}
impl Drop for IoBufferMut {
fn drop(&mut self) {
// SAFETY: memory was allocated with std::alloc::alloc with the same layout.
unsafe {
alloc::dealloc(
self.as_mut_ptr(),
Layout::from_size_align_unchecked(self.capacity, self.align),
)
}
}
}
impl AsRef<[u8]> for IoBufferMut {
fn as_ref(&self) -> &[u8] {
self.as_slice()
}
}
impl AsMut<[u8]> for IoBufferMut {
fn as_mut(&mut self) -> &mut [u8] {
self.as_mut_slice()
}
}
impl PartialEq<[u8]> for IoBufferMut {
fn eq(&self, other: &[u8]) -> bool {
self.as_slice().eq(other)
}
}
impl Deref for IoBufferMut {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.as_slice()
}
}
impl DerefMut for IoBufferMut {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut_slice()
}
}
/// SAFETY: When advancing the internal cursor, the caller needs to make sure the bytes advanced past have been initialized.
unsafe impl bytes::BufMut for IoBufferMut {
#[inline]
fn remaining_mut(&self) -> usize {
// Although a `Vec` can have at most isize::MAX bytes, we never want to grow `IoBufferMut`
// implicitly while writing. Thus, it can take at most `self.capacity() - self.len()` more bytes.
self.capacity() - self.len()
}
// SAFETY: Caller needs to make sure the bytes being advanced past have been initialized.
// Panics (via `panic_advance`) if `cnt` exceeds the remaining capacity.
#[inline]
unsafe fn advance_mut(&mut self, cnt: usize) {
let len: usize = self.len();
let remaining = self.remaining_mut();
if remaining < cnt {
panic_advance(cnt, remaining);
}
// Addition will not overflow since the sum is at most the capacity.
self.set_len(len + cnt);
}
#[inline]
fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
let cap = self.capacity();
let len = self.len();
// SAFETY: Since `self.ptr` is valid for `cap` bytes, `self.ptr.add(len)` must be
// valid for `cap - len` bytes. The subtraction will not underflow since
// `len <= cap`.
unsafe { UninitSlice::from_raw_parts_mut(self.as_mut_ptr().add(len), cap - len) }
}
}
/// Panic with a nice error message.
#[cold]
fn panic_advance(idx: usize, len: usize) -> ! {
panic!(
"advance out of bounds: the len is {} but advancing by {}",
len, idx
);
}
/// Safety: [`IoBufferMut`] has exclusive ownership of the io buffer,
/// and the location remains stable even if [`Self`] is moved.
unsafe impl tokio_epoll_uring::IoBuf for IoBufferMut {
fn stable_ptr(&self) -> *const u8 {
self.as_ptr()
}
fn bytes_init(&self) -> usize {
self.len()
}
fn bytes_total(&self) -> usize {
self.capacity()
}
}
// SAFETY: See above.
unsafe impl tokio_epoll_uring::IoBufMut for IoBufferMut {
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.as_mut_ptr()
}
unsafe fn set_init(&mut self, init_len: usize) {
if self.len() < init_len {
self.set_len(init_len);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_with_capacity_aligned() {
const ALIGN: usize = 4 * 1024;
let v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
let v = IoBufferMut::with_capacity_aligned(ALIGN / 2, ALIGN);
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN / 2);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
#[test]
fn test_with_capacity_aligned_zeroed() {
const ALIGN: usize = 4 * 1024;
let v = IoBufferMut::with_capacity_aligned_zeroed(ALIGN, ALIGN);
assert_eq!(v.len(), ALIGN);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
assert_eq!(&v[..], &[0; ALIGN])
}
#[test]
fn test_reserve() {
use bytes::BufMut;
const ALIGN: usize = 4 * 1024;
let mut v = IoBufferMut::with_capacity_aligned(ALIGN, ALIGN);
let capacity = v.capacity();
v.reserve(capacity);
assert_eq!(v.capacity(), capacity);
let data = [b'a'; ALIGN];
v.put(&data[..]);
v.reserve(capacity);
assert!(v.capacity() >= capacity * 2);
assert_eq!(&v[..], &data[..]);
let capacity = v.capacity();
v.clear();
v.reserve(capacity);
assert_eq!(capacity, v.capacity());
}
#[test]
fn test_bytes_put() {
use bytes::BufMut;
const ALIGN: usize = 4 * 1024;
let mut v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
let x = [b'a'; ALIGN];
for _ in 0..2 {
for _ in 0..4 {
v.put(&x[..]);
}
assert_eq!(v.len(), ALIGN * 4);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
v.clear()
}
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN * 4);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
#[test]
#[should_panic]
fn test_bytes_put_panic() {
use bytes::BufMut;
const ALIGN: usize = 4 * 1024;
let mut v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
let x = [b'a'; ALIGN];
for _ in 0..5 {
v.put_slice(&x[..]);
}
}
#[test]
fn test_io_buf_put_slice() {
use tokio_epoll_uring::BoundedBufMut;
const ALIGN: usize = 4 * 1024;
let mut v = IoBufferMut::with_capacity_aligned(ALIGN, ALIGN);
let x = [b'a'; ALIGN];
for _ in 0..2 {
v.put_slice(&x[..]);
assert_eq!(v.len(), ALIGN);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
v.clear()
}
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), ALIGN);
assert_eq!(v.align(), ALIGN);
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
}
}

View File

@@ -1,11 +0,0 @@
#![allow(unused)]
use tokio_epoll_uring::IoBufMut;
use crate::virtual_file::{dio::IoBufferMut, PageWriteGuardBuf};
/// Marker trait (no methods of its own) on top of `IoBufMut`.
///
/// NOTE(review): presumably implemented only for buffer types whose memory meets the
/// IO alignment requirement — confirm the intended contract with the implementors.
pub trait IoBufAlignedMut: IoBufMut {}
impl IoBufAlignedMut for IoBufferMut {}
impl IoBufAlignedMut for PageWriteGuardBuf {}

View File

@@ -1,6 +1,5 @@
//! See [`FullSlice`].
use crate::virtual_file::dio::IoBufferMut;
use bytes::{Bytes, BytesMut};
use std::ops::{Deref, Range};
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
@@ -77,4 +76,3 @@ macro_rules! impl_io_buf_ext {
impl_io_buf_ext!(Bytes);
impl_io_buf_ext!(BytesMut);
impl_io_buf_ext!(Vec<u8>);
impl_io_buf_ext!(IoBufferMut);

View File

@@ -1473,33 +1473,11 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count,
{
NeonWALReadResult res;
#if PG_MAJORVERSION_NUM >= 17
if (!sk->wp->config->syncSafekeepers)
{
Size rbytes;
rbytes = WALReadFromBuffers(buf, startptr, count,
walprop_pg_get_timeline_id());
startptr += rbytes;
count -= rbytes;
}
#endif
if (count == 0)
{
res = NEON_WALREAD_SUCCESS;
}
else
{
Assert(count > 0);
/* Now read the remaining WAL from the WAL file */
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
}
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
if (res == NEON_WALREAD_SUCCESS)
{

View File

@@ -565,7 +565,7 @@ mod tests {
stream::{PqStream, Stream},
};
use super::{auth_quirks, AuthRateLimiter};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
struct Auth {
ips: Vec<IpPattern>,
@@ -611,12 +611,15 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -1,4 +1,5 @@
use std::{
borrow::Cow,
future::Future,
sync::Arc,
time::{Duration, SystemTime},
@@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{Deserialize, Deserializer};
use serde::{
de::{DeserializeSeed, IgnoredAny, Visitor},
Deserializer,
};
use signature::Verifier;
use tokio::time::Instant;
@@ -33,6 +37,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
#[derive(Debug, Clone)]
pub(crate) struct AuthRule {
pub(crate) id: String,
pub(crate) jwks_url: url::Url,
@@ -303,35 +308,21 @@ impl JwkCacheEntryLock {
}
key => bail!("unsupported key type {key:?}"),
};
tracing::debug!("JWT signature valid");
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
.context("Provided authentication token is not a valid JWT encoding")?;
tracing::debug!(?payload, "JWT signature valid with claims");
let validator = JwtValidator {
expected_audience,
current_time: SystemTime::now(),
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
};
match (expected_audience, payload.audience) {
// check the audience matches
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
}
let payload = validator
.deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
}
if let Some(nbf) = payload.not_before {
ensure!(
nbf < now + CLOCK_SKEW_LEEWAY,
"JWT token is not yet ready to use"
);
}
tracing::debug!(?payload, "JWT claims valid");
Ok(())
}
@@ -419,37 +410,184 @@ struct JwtHeader<'a> {
key_id: Option<&'a str>,
}
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
/// Audience - Recipient for which the JWT is intended
#[serde(rename = "aud")]
audience: Option<&'a str>,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
/// Issuer of the JWT
#[serde(rename = "iss")]
issuer: Option<&'a str>,
/// Subject of the JWT (the user)
#[serde(rename = "sub")]
subject: Option<&'a str>,
/// Unique token identifier
#[serde(rename = "jti")]
jwt_id: Option<&'a str>,
/// Unique session identifier
#[serde(rename = "sid")]
session_id: Option<&'a str>,
/// Deserialization seed that validates a JWT payload's claims while it is being parsed.
struct JwtValidator<'a> {
/// Audience the token must carry; `None` disables the audience check.
expected_audience: Option<&'a str>,
/// Reference time that the `exp` and `nbf` claims are compared against.
current_time: SystemTime,
/// Tolerance applied to the `exp` and `nbf` comparisons.
clock_skew_leeway: Duration,
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
type Value = JwtPayload<'de>;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for JwtValidator<'_> {
type Value = JwtPayload<'de>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a JWT payload")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut payload = JwtPayload {
issuer: None,
subject: None,
jwt_id: None,
session_id: None,
};
let mut aud = false;
while let Some(key) = map.next_key()? {
match key {
"iss" if payload.issuer.is_none() => {
payload.issuer = Some(map.next_value()?);
}
"sub" if payload.subject.is_none() => {
payload.subject = Some(map.next_value()?);
}
"jit" if payload.jwt_id.is_none() => {
payload.jwt_id = Some(map.next_value()?);
}
"sid" if payload.session_id.is_none() => {
payload.session_id = Some(map.next_value()?);
}
"exp" => {
let exp = map.next_value::<u64>()?;
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
if self.current_time > exp + self.clock_skew_leeway {
return Err(serde::de::Error::custom("JWT token has expired"));
}
}
"nbf" => {
let nbf = map.next_value::<u64>()?;
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
if self.current_time + self.clock_skew_leeway < nbf {
return Err(serde::de::Error::custom(
"JWT token is not yet ready to use",
));
}
}
"aud" => {
if let Some(expected_audience) = self.expected_audience {
map.next_value_seed(AudienceValidator { expected_audience })?;
aud = true;
} else {
map.next_value::<IgnoredAny>()?;
}
}
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
}
}
if self.expected_audience.is_some() && !aud {
return Err(serde::de::Error::custom("invalid JWT token audience"));
}
Ok(payload)
}
}
deserializer.deserialize_map(self)
}
}
/// Deserialization seed that checks an `aud` claim value (a single string or an array of
/// strings) against the expected audience.
struct AudienceValidator<'a> {
/// Audience that must be present in the claim.
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
    type Value = ();

    /// Accepts the `aud` claim if it is the expected audience, or an array of strings
    /// that contains it; otherwise fails with "invalid JWT token audience".
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        impl<'de> Visitor<'de> for AudienceValidator<'_> {
            type Value = ();

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a single string or an array of strings")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                if self.expected_audience == v {
                    Ok(())
                } else {
                    Err(E::custom("invalid JWT token audience"))
                }
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // Always drain the whole sequence, even after a match. Returning early
                // leaves unread elements behind, and the JSON deserializer then fails
                // the enclosing array with a "trailing characters" error, so a match in
                // any position but the last would be rejected.
                let mut matched = false;
                while let Some(is_match) = seq.next_element_seed(SingleAudienceValidator {
                    expected_audience: self.expected_audience,
                })? {
                    matched |= is_match;
                }

                if matched {
                    Ok(())
                } else {
                    Err(serde::de::Error::custom("invalid JWT token audience"))
                }
            }
        }

        deserializer.deserialize_any(self)
    }
}
/// Deserialization seed for one element of an `aud` array; yields whether the element
/// equals the expected audience.
struct SingleAudienceValidator<'a> {
/// Audience to compare the element against.
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
    type Value = bool;

    /// Reads one string and reports whether it is the expected audience.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
            type Value = bool;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a single audience string")
            }

            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let is_expected = v == self.expected_audience;
                Ok(is_expected)
            }
        }

        let visitor = self;
        deserializer.deserialize_any(visitor)
    }
}
/// Registered claims kept from a validated JWT payload.
///
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
// the following entries are only extracted for the sake of debug logging.
// Every field is optional and stored as a `Cow`, so a value can be borrowed
// from the input rather than copied.
// `dead_code` is allowed because the fields are only read through `Debug`.
#[derive(Debug)]
#[allow(dead_code)]
struct JwtPayload<'a> {
/// Issuer of the JWT
// NOTE(review): presumably filled from the `iss` claim — confirm against the visitor.
issuer: Option<Cow<'a, str>>,
/// Subject of the JWT (the user)
// NOTE(review): presumably filled from the `sub` claim — confirm against the visitor.
subject: Option<Cow<'a, str>>,
/// Unique token identifier
// NOTE(review): presumably filled from the `jti` claim — confirm against the visitor.
jwt_id: Option<Cow<'a, str>>,
/// Unique session identifier
// NOTE(review): presumably filled from a `sid` claim — confirm against the visitor.
session_id: Option<Cow<'a, str>>,
}
struct JwkRenewalPermit<'a> {
@@ -530,6 +668,8 @@ mod tests {
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey;
use serde::Serialize;
use serde_json::json;
use signature::Signer;
use tokio::net::TcpListener;
@@ -562,18 +702,36 @@ mod tests {
}
/// Builds the unsigned `header.body` part of a well-formed test JWT.
///
/// The claims are valid from the current time for one hour and carry three
/// audiences plus subject, session, token-id and issuer values.
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap();
    let issued_at = since_epoch.as_secs();
    let expires_at = issued_at + 3600;

    let claims = typed_json::json! {{
        "exp": expires_at,
        "nbf": issued_at,
        "aud": ["audience1", "neon", "audience2"],
        "sub": "user1",
        "sid": "session1",
        "jti": "token1",
        "iss": "neon-testing",
    }};

    build_custom_jwt_payload(kid, claims, sig)
}
fn build_custom_jwt_payload(
kid: String,
body: impl Serialize,
sig: jose_jwa::Signing,
) -> String {
let header = JwtHeader {
typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid),
};
let body = typed_json::json! {{
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
}};
let header =
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
format!("{header}.{body}")
}
@@ -588,6 +746,16 @@ mod tests {
format!("{payload}.{sig}")
}
/// Creates an ES256-signed JWT with caller-supplied claims.
///
/// The signing input is produced by `build_custom_jwt_payload`; the signature
/// is appended as unpadded URL-safe base64.
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
    use p256::ecdsa::{Signature, SigningKey};

    let signing_input = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
    let signer = SigningKey::from(key);
    let signature: Signature = signer.sign(signing_input.as_bytes());
    let encoded_signature = base64::encode_config(signature.to_bytes(), URL_SAFE_NO_PAD);

    format!("{signing_input}.{encoded_signature}")
}
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
use rsa::pkcs1v15::SigningKey;
use rsa::signature::SignatureEncoding;
@@ -659,37 +827,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
-----END PRIVATE KEY-----
";
#[tokio::test]
async fn renew() {
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
#[derive(Clone)]
struct Fetch(Vec<AuthRule>);
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
let bar_jwks = jose_jwk::JwkSet {
keys: vec![jwk2, jwk4],
};
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(self.0.clone())
}
}
async fn jwks_server(
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
) -> SocketAddr {
let router = Arc::new(router);
let service = service_fn(move |req| {
let foo_jwks = foo_jwks.clone();
let bar_jwks = bar_jwks.clone();
let router = Arc::clone(&router);
async move {
let jwks = match req.uri().path() {
"/foo" => &foo_jwks,
"/bar" => &bar_jwks,
_ => {
return Response::builder()
.status(404)
.body(Full::new(Bytes::new()));
}
};
let body = serde_json::to_vec(jwks).unwrap();
Response::builder()
.status(200)
.body(Full::new(Bytes::from(body)))
match router(req.uri().path()) {
Some(body) => Response::builder()
.status(200)
.body(Full::new(Bytes::from(body))),
None => Response::builder()
.status(404)
.body(Full::new(Bytes::new())),
}
}
});
@@ -704,84 +869,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
}
});
let client = reqwest::Client::new();
addr
}
#[derive(Clone)]
struct Fetch(SocketAddr, Vec<RoleNameInt>);
#[tokio::test]
async fn check_jwt_happy_path() {
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
let (ec1, jwk3) = new_ec_jwk("ec1".into());
let (ec2, jwk4) = new_ec_jwk("ec2".into());
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
])
}
}
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
let bar_jwks = jose_jwk::JwkSet {
keys: vec![jwk2, jwk4],
};
let jwks_addr = jwks_server(move |path| match path {
"/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
"/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
_ => None,
})
.await;
let role_name1 = RoleName::from("anonymous");
let role_name2 = RoleName::from("authenticated");
let fetch = Fetch(
addr,
vec![
RoleNameInt::from(&role_name1),
RoleNameInt::from(&role_name2),
],
);
let roles = vec![
RoleNameInt::from(&role_name1),
RoleNameInt::from(&role_name2),
];
let rules = vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let endpoint = EndpointId::from("ep");
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), &ec1);
let jwt4 = new_ec_jwt("4".into(), &ec2);
// had the wrong kid, therefore will have the wrong ecdsa signature
let bad_jwt = new_ec_jwt("3".into(), &ec2);
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&bad_jwt,
&client,
endpoint.clone(),
&role_name1,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("signature error"));
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&jwt1,
&client,
endpoint.clone(),
&bad_role_name,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("jwk not found"));
let jwt1 = new_rsa_jwt("rs1".into(), rs1);
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
let jwt3 = new_ec_jwt("ec1".into(), &ec1);
let jwt4 = new_ec_jwt("ec2".into(), &ec2);
let tokens = [jwt1, jwt2, jwt3, jwt4];
let role_names = [role_name1, role_name2];
@@ -790,15 +932,194 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
token,
&client,
endpoint.clone(),
role,
&fetch,
token,
)
.await
.unwrap();
}
}
}
/// A token whose `kid` matches a published key but which was signed by a
/// different private key must be rejected with a signature error.
#[tokio::test]
async fn check_jwt_invalid_signature() {
    // Two independent key pairs registered under the same key id: the public
    // half of the first is served, the private half of the second signs.
    let (_, published_jwk) = new_ec_jwk("1".into());
    let (unrelated_key, _) = new_ec_jwk("1".into());

    // has a matching kid, but signed by the wrong key
    let bad_jwt = new_ec_jwt("1".into(), &unrelated_key);

    let jwks = jose_jwk::JwkSet {
        keys: vec![published_jwk],
    };
    let jwks_addr =
        jwks_server(move |path| (path == "/").then(|| serde_json::to_vec(&jwks).unwrap())).await;

    let role = RoleName::from("authenticated");
    let rule = AuthRule {
        id: String::new(),
        jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
        audience: None,
        role_names: vec![RoleNameInt::from(&role)],
    };
    let fetch = Fetch(vec![rule]);

    let jwk_cache = JwkCache::default();
    let ctx = RequestMonitoring::test();
    let result = jwk_cache
        .check_jwt(&ctx, EndpointId::from("ep"), &role, &fetch, &bad_jwt)
        .await;

    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("signature error"),
        "expected \"signature error\", got {err:?}"
    );
}
/// A correctly signed token presented for a role that no auth rule lists
/// must be rejected with "jwk not found".
#[tokio::test]
async fn check_jwt_unknown_role() {
    let (private_key, public_jwk) = new_rsa_jwk(RS1, "1".into());
    let jwt = new_rsa_jwt("1".into(), private_key);

    let jwks = jose_jwk::JwkSet {
        keys: vec![public_jwk],
    };
    let jwks_addr =
        jwks_server(move |path| (path == "/").then(|| serde_json::to_vec(&jwks).unwrap())).await;

    // The only rule accepts the "authenticated" role.
    let allowed_role = RoleName::from("authenticated");
    let rule = AuthRule {
        id: String::new(),
        jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
        audience: None,
        role_names: vec![RoleNameInt::from(&allowed_role)],
    };
    let fetch = Fetch(vec![rule]);

    let jwk_cache = JwkCache::default();

    // this role_name is not accepted
    let bad_role_name = RoleName::from("cloud_admin");

    let ctx = RequestMonitoring::test();
    let result = jwk_cache
        .check_jwt(&ctx, EndpointId::from("ep"), &bad_role_name, &fetch, &jwt)
        .await;

    let err = result.unwrap_err();
    assert!(
        err.to_string().contains("jwk not found"),
        "expected \"jwk not found\", got {err:?}"
    );
}
/// Table-driven check that tokens with bad `nbf`, `exp` or `aud` claims are
/// rejected with the expected error message.
///
/// The single auth rule requires the audience "neon".
#[tokio::test]
async fn check_jwt_invalid_claims() {
    let (key, jwk) = new_ec_jwk("1".into());

    let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
    let jwks_addr =
        jwks_server(move |path| (path == "/").then(|| serde_json::to_vec(&jwks).unwrap())).await;

    let now = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap()
        .as_secs();

    // (claims, substring expected in the resulting error)
    let table: Vec<(serde_json::Value, &'static str)> = vec![
        // not valid yet
        (
            json!({
                "nbf": now + 60,
                "aud": "neon",
            }),
            "JWT token is not yet ready to use",
        ),
        // already expired
        (
            json!({
                "exp": now - 60,
                "aud": ["neon"],
            }),
            "JWT token has expired",
        ),
        // audience missing entirely
        (json!({}), "invalid JWT token audience"),
        // audience present but empty
        (
            json!({
                "aud": [],
            }),
            "invalid JWT token audience",
        ),
        // wrong audience, single-string form
        (
            json!({
                "aud": "foo",
            }),
            "invalid JWT token audience",
        ),
        // wrong audience, one-element array
        (
            json!({
                "aud": ["foo"],
            }),
            "invalid JWT token audience",
        ),
        // wrong audiences, multi-element array
        (
            json!({
                "aud": ["foo", "bar"],
            }),
            "invalid JWT token audience",
        ),
    ];

    let role = RoleName::from("authenticated");
    let rule = AuthRule {
        id: String::new(),
        jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
        audience: Some("neon".to_string()),
        role_names: vec![RoleNameInt::from(&role)],
    };
    let fetch = Fetch(vec![rule]);

    let jwk_cache = JwkCache::default();
    let ep = EndpointId::from("ep");
    let ctx = RequestMonitoring::test();

    for (body, error) in table {
        let jwt = new_custom_ec_jwt("1".into(), &key, body);
        let outcome = jwk_cache
            .check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
            .await;

        match outcome {
            Ok(()) => {
                panic!("expected {:?}, got ok", error)
            }
            Err(err) if !err.to_string().contains(error) => {
                panic!("expected {:?}, got {err:?}", error)
            }
            Err(_) => {}
        }
    }
}
}

View File

@@ -14,17 +14,15 @@ use crate::{
EndpointId,
};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
use super::jwt::{AuthRule, FetchAuthRules};
pub struct LocalBackend {
pub(crate) jwks_cache: JwkCache,
pub(crate) node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();

View File

@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
use dashmap::DashMap;
use futures::future::Either;
use proxy::{
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP},
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
allow_self_signed_compute: false,
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),

View File

@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
@@ -102,6 +103,9 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -382,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default();
@@ -430,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
@@ -674,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: true,
accept_websockets: !args.is_auth_broker,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -689,12 +707,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
};
let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -1,5 +1,8 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
console::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -67,6 +70,9 @@ pub struct AuthenticationConfig {
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
}
impl TlsConfig {
@@ -250,18 +256,26 @@ impl CertResolver {
let common_name = pem.subject().to_string();
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so we can continue with any common name.
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
} else {
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context("Failed to parse common name from certificate")?;
bail!("Failed to parse common name from certificate")
};
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

View File

@@ -1,5 +1,6 @@
use std::{
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
sync::OnceLock,
};
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
@@ -16,12 +17,21 @@ pub struct StringInterner<Id> {
_id: PhantomData<Id>,
}
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
#[derive(PartialEq, Clone, Copy, Eq, Hash)]
pub struct InternedString<Id> {
inner: Spur,
_id: PhantomData<Id>,
}
/// Debug output shows the id type name and the resolved string.
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id_type = type_name::<Id>();
        let value = self.as_str();

        let mut tuple = f.debug_tuple("InternedString");
        tuple.field(&id_type);
        tuple.field(&value);
        tuple.finish()
    }
}
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.as_str().fmt(f)

View File

@@ -5,6 +5,7 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod sql_over_http;
@@ -19,7 +20,8 @@ use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
@@ -81,7 +83,28 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -342,7 +365,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -386,7 +409,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -409,7 +432,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new()))
.body(Empty::new().map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,6 +1,8 @@
use std::{sync::Arc, time::Duration};
use std::{io, sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use crate::{
@@ -27,9 +29,13 @@ use crate::{
Host,
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -103,32 +109,44 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
jwt: String,
) -> Result<(), AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::Console(_, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
crate::auth::Backend::Console(console, ()) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(())
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(cache) => {
cache
crate::auth::Backend::Local(_) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&StaticAuthRules,
jwt,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
}
}
}
@@ -174,14 +192,55 @@ impl PoolingBackend {
)
.await
}
/// Returns an HTTP/2 client for the local-proxy of the endpoint in `conn_info`.
///
/// A pooled connection is reused when one is available. Otherwise a new
/// connection is opened through `connect_to_compute` using `HyperMechanism`,
/// with the configured wake-compute and connect retry settings.
///
/// Errors are reported as `HttpConnError`.
// Wake up the destination if needed
//
// `conn_id` has to be declared as an (initially empty) span field here:
// `Span::record` silently ignores fields that were not declared when the span
// was created, so the `record("conn_id", ..)` call below would otherwise be a no-op.
#[tracing::instrument(
    fields(pid = tracing::field::Empty, conn_id = tracing::field::Empty),
    skip_all
)]
pub(crate) async fn connect_to_local_proxy(
    &self,
    ctx: &RequestMonitoring,
    conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
    info!("pool: looking for an existing connection");
    if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
        return Ok(client);
    }

    let conn_id = uuid::Uuid::new_v4();
    tracing::Span::current().record("conn_id", display(conn_id));
    info!(%conn_id, "pool: opening a new connection '{conn_info}'");

    // No credential keys are attached here; only the user info is forwarded.
    // This must be built before `conn_info` is moved into the mechanism below.
    let backend = self
        .config
        .auth_backend
        .as_ref()
        .map(|()| ComputeCredentials {
            info: conn_info.user_info.clone(),
            keys: crate::auth::backend::ComputeCredentialKeys::None,
        });

    crate::proxy::connect_compute::connect_to_compute(
        ctx,
        &HyperMechanism {
            conn_id,
            conn_info,
            pool: self.http_conn_pool.clone(),
            locks: &self.config.connect_compute_locks,
        },
        &backend,
        false, // do not allow self signed compute for http flow
        self.config.wake_compute_retry_config,
        self.config.connect_to_compute_retry_config,
    )
    .await
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to postgres in compute")]
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
/// Failure while establishing an HTTP/2 connection to local-proxy.
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
/// Transport-level failure. `connect_http2` uses this for address lookup,
/// TCP connect, connect timeout and `set_nodelay` errors.
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
/// Error reported by hyper; `connect_http2` produces it when the HTTP/2
/// handshake fails (converted via `From<hyper1::Error>`).
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::ConnectionError(e) => e.could_retry(),
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
}
}
impl ReportableError for LocalProxyConnError {
    /// Every local-proxy connection failure is attributed to the compute side.
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            LocalProxyConnError::Io(_) | LocalProxyConnError::H2(_) => ErrorKind::Compute,
        }
    }
}
impl UserFacingError for LocalProxyConnError {
    /// Returns one fixed message for every variant, so the underlying
    /// error text is not exposed to the client.
    fn to_string_client(&self) -> String {
        String::from("Could not establish HTTP connection to the database")
    }
}
impl CouldRetry for LocalProxyConnError {
    /// No local-proxy connection error is currently considered retryable.
    fn could_retry(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            LocalProxyConnError::Io(_) | LocalProxyConnError::H2(_) => false,
        }
    }
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
    /// A local-proxy connection error never triggers another wake-compute attempt.
    fn should_retry_wake_compute(&self) -> bool {
        // Exhaustive on purpose: a new variant must decide explicitly.
        match self {
            LocalProxyConnError::Io(_) | LocalProxyConnError::H2(_) => false,
        }
    }
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
/// `ConnectMechanism` that opens HTTP/2 connections and hands them to the
/// HTTP connection pool.
struct HyperMechanism {
// pool handed to `poll_http2_client` together with each new connection
pool: Arc<http_conn_pool::GlobalConnPool>,
// connection parameters of the request this mechanism was created for
conn_info: ConnInfo,
// identifier generated by the caller for the connection being opened
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
    type Connection = http_conn_pool::Client;
    type ConnectError = HttpConnError;
    type Error = HttpConnError;

    /// Opens one HTTP/2 connection to the host of `node_info` and hands it to
    /// `poll_http2_client`.
    ///
    /// A per-host permit is held while connecting, and the latency timer is
    /// paused for the duration of the connection attempt.
    async fn connect_once(
        &self,
        ctx: &RequestMonitoring,
        node_info: &CachedNodeInfo,
        timeout: Duration,
    ) -> Result<Self::Connection, Self::ConnectError> {
        let host = node_info.config.get_host()?;
        let permit = self.locks.get_permit(&host).await?;

        let attempt = {
            // The guard is dropped at the end of this scope, i.e. right after
            // the connection attempt has finished.
            let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);

            // let port = node_info.config.get_ports().first().unwrap_or_else(10432);
            connect_http2(&host, 10432, timeout).await
        };

        let (client, connection) = permit.release_result(attempt)?;

        let pooled = poll_http2_client(
            self.pool.clone(),
            ctx,
            &self.conn_info,
            client,
            connection,
            self.conn_id,
            node_info.aux.clone(),
        );
        Ok(pooled)
    }

    /// Nothing to adjust for HTTP connections.
    fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
/// Establishes a TCP connection to `host:port` and performs the HTTP/2
/// handshake on it.
///
/// Every resolved address is tried in order, each attempt bounded by
/// `timeout`. The first successful stream is used; if none succeeds, the
/// error of the last attempt is returned (or an `InvalidInput` error when the
/// lookup produced no addresses).
///
/// Returns the request sender together with the connection future.
async fn connect_http2(
    host: &str,
    port: u16,
    timeout: Duration,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
    // assumption: host is an ip address so this should not actually perform any requests.
    // todo: add that assumption as a guarantee in the control-plane API.
    let candidates = lookup_host((host, port))
        .await
        .map_err(LocalProxyConnError::Io)?;

    let mut last_err = None;
    let mut connected = None;

    for addr in candidates {
        match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
            Ok(Ok(tcp)) => {
                tcp.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
                connected = Some(tcp);
                break;
            }
            Ok(Err(io_err)) => {
                last_err = Some(LocalProxyConnError::Io(io_err));
            }
            Err(elapsed) => {
                let timed_out = io::Error::new(io::ErrorKind::TimedOut, elapsed);
                last_err = Some(LocalProxyConnError::Io(timed_out));
            }
        }
    }

    let Some(stream) = connected else {
        return Err(last_err.unwrap_or_else(|| {
            LocalProxyConnError::Io(io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve any addresses",
            ))
        }));
    };

    let mut builder = hyper1::client::conn::http2::Builder::new(TokioExecutor::new());
    builder
        .timer(TokioTimer::new())
        .keep_alive_interval(Duration::from_secs(20))
        .keep_alive_while_idle(true)
        .keep_alive_timeout(Duration::from_secs(5));

    let (client, connection) = builder.handshake(TokioIo::new(stream)).await?;
    Ok((client, connection))
}

View File

@@ -0,0 +1,335 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
/// Request-sending handle of a pooled HTTP/2 connection; requests carry a
/// `hyper1::body::Incoming` body.
// NOTE: inside this module the alias takes the name `Send`, so an unqualified
// `Send` here refers to this type rather than the marker trait.
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
/// The HTTP/2 connection object over a tokio `TcpStream`. It is awaited by the
/// task spawned in `poll_http2_client` until the connection finishes.
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
/// A single pooled HTTP/2 connection.
///
/// The entry is `Clone` because `get_conn_entry` hands out a clone while the
/// entry itself stays in the pool.
#[derive(Clone)]
struct ConnPoolEntry {
// Handle used to send requests on the connection.
conn: Send,
// Identifier used to find and remove this entry once the connection ends.
conn_id: uuid::Uuid,
// Endpoint/branch identifiers passed on to `Client` for usage metrics.
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
// NOTE(review): no such limit is enforced by the code in this file — confirm
// where (or whether) `max_conns_per_endpoint` is applied to this pool.
pub(crate) struct EndpointConnPool {
// Pooled connections; `get_conn_entry` cycles through them front to back.
conns: VecDeque<ConnPoolEntry>,
// Metrics guard obtained from `http_endpoint_pools` when the pool is created;
// held for the lifetime of the pool.
_guard: HttpEndpointPoolsGuard<'static>,
// Counter shared with the owning `GlobalConnPool`; decremented whenever
// entries are removed from `conns` or the pool is dropped.
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
    /// Pick the next connection in round-robin order.
    ///
    /// The entry at the front of the queue is moved to the back and a copy of
    /// it is returned, so the connection stays pooled while it is in use.
    /// Returns `None` when the pool holds no connections.
    fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
        let picked = self.conns.front()?.clone();
        // front -> back, everything else shifts one slot forward
        self.conns.rotate_left(1);
        Some(picked)
    }

    /// Remove every entry whose id equals `conn_id`.
    ///
    /// The shared connection counter and the `http_pool_opened_connections`
    /// gauge are lowered by the number of removed entries. Returns `true` if
    /// at least one entry was removed.
    fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
        let before = self.conns.len();
        self.conns.retain(|entry| entry.conn_id != conn_id);
        let dropped = before - self.conns.len();

        if dropped == 0 {
            return false;
        }

        self.global_connections_count
            .fetch_sub(dropped, atomic::Ordering::Relaxed);
        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by(dropped as i64);
        true
    }
}
impl Drop for EndpointConnPool {
    /// Release the accounting for every connection still held by this pool:
    /// both the shared connection counter and the
    /// `http_pool_opened_connections` gauge are lowered by the number of
    /// remaining entries.
    fn drop(&mut self) {
        let remaining = self.conns.len();
        if remaining == 0 {
            // nothing was pooled, so there is nothing to un-count
            return;
        }

        self.global_connections_count
            .fetch_sub(remaining, atomic::Ordering::Relaxed);
        Metrics::get()
            .proxy
            .http_pool_opened_connections
            .get_metric()
            .dec_by(remaining as i64);
    }
}
/// Process-wide pool of HTTP/2 connections, keyed by endpoint.
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
// Static HTTP configuration; this file reads `pool_options.pool_shards` and
// `pool_options.gc_epoch` from it.
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
/// Create an empty pool whose map is split into `config.pool_options.pool_shards`
/// shards. Both counters start at zero.
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
/// Remove every endpoint pool from the map.
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
/// Run garbage collection forever.
///
/// The tick period is `gc_epoch` divided by the number of shards, and each
/// tick collects one shard chosen at random with `rng`. This future never
/// completes on its own.
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
/// Collect the shard with index `shard`.
///
/// Closed connections are dropped from every endpoint pool that nothing else
/// references, and pools that end up empty are removed from the map. The
/// counters, the gauge and the log lines are updated after the shard lock
/// has been released.
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
// Times the work done under the shard lock; observed right after the lock is dropped.
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
// (the task spawned by `poll_http2_client` holds a weak reference for as
// long as it runs, so such pools are left untouched here)
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
// `fetch_sub` returns the value before the subtraction, hence the extra
// `- clients_removed` to obtain the new total for the log line.
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
// Number of endpoint pools that were discarded from this shard.
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
/// Try to reuse a pooled connection for `conn_info`.
///
/// Returns `None` when `conn_info` has no endpoint cache key, when the
/// endpoint's pool is empty, or when the connection that was picked is
/// already closed. On a hit the connection id is recorded on the current
/// tracing span, the cold-start info is set to `HttpPoolHit` and
/// `ctx.success()` is called.
///
/// Note that the endpoint pool is created if it does not exist yet, even
/// when the lookup ends up being a miss.
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
if client.conn.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return None;
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, client.aux))
}
/// Return the pool for `endpoint`, inserting an empty one if none exists.
///
/// `global_pool_size` is incremented (and a log line emitted) only when this
/// call actually inserted the pool.
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
// The candidate pool is built up front; it is simply dropped if another
// caller inserted a pool for this endpoint in the meantime.
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
// let idle = global_pool.get_idle_timeout();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, aux)
}
/// Handle to an HTTP/2 connection, returned both for pooled connections
/// (`GlobalConnPool::get`) and for new ones (`poll_http2_client`).
pub(crate) struct Client {
// Request sender for the underlying connection.
pub(crate) inner: Send,
// Endpoint/branch identifiers used by `metrics()`.
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::Full;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,17 +64,24 @@ struct HttpErrorBody {
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.unwrap()
}
}
@@ -83,14 +90,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -8,6 +8,8 @@ use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -38,9 +40,11 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing password")]
MissingPassword,
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
MalformedEndpoint,
}
/// The kind of credential a request was expected to carry.
///
/// Used as the payload of `ConnInfoError::MissingCredentials`, whose message
/// embeds this type's `Display` text after "missing authentication credentials: ".
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
// A password was expected.
#[error("required password")]
Password,
// An `Authorization: Bearer <jwt>` header was expected.
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
@@ -181,21 +195,32 @@ fn get_conn_info(
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingPassword)?
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingPassword);
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint = match connection_url.host() {
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -504,7 +529,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -514,18 +539,50 @@ async fn handle_inner(
"handling interactive connection from client"
);
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
auth,
backend,
)
.await
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let headers = request.headers();
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in
@@ -563,26 +620,36 @@ async fn handle_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = match &conn_info.auth {
let keys = match auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.conn_info.user_info,
pw,
&conn_info.user_info,
&pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
.await?
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
}
};
let client = backend
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
@@ -640,7 +707,11 @@ async fn handle_inner(
let len = json_output.len();
let response = response
.body(Full::new(Bytes::from(json_output)))
.body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -656,6 +727,65 @@ async fn handle_inner(
Ok(response)
}
// Allow-list of request headers that `handle_auth_broker_inner` copies onto the
// request it sends to the local proxy; all other headers are not forwarded.
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&AUTHORIZATION,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
/// Serve a request in auth-broker mode.
///
/// The JWT is validated first, then a connection to the local proxy is
/// obtained and the incoming request is re-issued over it as
/// `POST http://proxy.local/sql`. Only the headers listed in
/// `HEADERS_TO_FORWARD` are copied across; the request body is passed through
/// as-is. The local proxy's response is returned with its body boxed.
///
/// # Errors
///
/// Fails if JWT validation fails, if the local proxy connection cannot be
/// established, or if sending the request fails.
async fn handle_auth_broker_inner(
    config: &'static ProxyConfig,
    ctx: &RequestMonitoring,
    request: Request<Incoming>,
    conn_info: ConnInfo,
    jwt: String,
    backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
    let auth_config = &config.authentication_config;
    backend
        .authenticate_with_jwt(ctx, auth_config, &conn_info.user_info, jwt)
        .await
        .map_err(HttpConnError::from)?;

    let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;

    let (mut parts, body) = request.into_parts();
    let builder = Request::builder()
        .method(Method::POST)
        .uri(::http::Uri::from_static("http://proxy.local/sql"));

    // todo(conradludgate): maybe auth-broker should parse these and re-serialize
    // these instead just to ensure they remain normalised.
    let builder = HEADERS_TO_FORWARD
        .iter()
        .fold(builder, |acc, &name| match parts.headers.remove(name) {
            Some(value) => acc.header(name, value),
            None => acc,
        });

    let forwarded = builder
        .body(body)
        .expect("all headers and params received via hyper should be valid for request");

    // todo: map body to count egress
    let _metrics = client.metrics();

    let response = client
        .inner
        .send_request(forwarded)
        .await
        .map_err(LocalProxyConnError::from)
        .map_err(HttpConnError::from)?;

    Ok(response.map(|incoming| incoming.boxed()))
}
impl QueryData {
async fn process(
self,
@@ -705,7 +835,9 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};

View File

@@ -9,7 +9,7 @@ use crate::walproposer_sim::{
pub mod walproposer_sim;
// Generates 500 random seeds and runs a schedule for each of them.
// Generates 2000 random seeds and runs a schedule for each of them.
// If you see this test fail, please report the last seed to the
// @safekeeper team.
#[test]
@@ -17,7 +17,7 @@ fn test_random_schedules() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
for _ in 0..500 {
for _ in 0..2000 {
let seed: u64 = rand::thread_rng().gen();
config.network = generate_network_opts(seed);

View File

@@ -56,6 +56,7 @@ from _pytest.fixtures import FixtureRequest
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import cursor as PgCursor
from psycopg2.extensions import make_dsn, parse_dsn
from pytest_httpserver import HTTPServer
from urllib3.util.retry import Retry
from fixtures import overlayfs
@@ -401,7 +402,6 @@ class NeonEnvBuilder:
safekeeper_extra_opts: Optional[list[str]] = None,
storage_controller_port_override: Optional[int] = None,
pageserver_io_buffer_alignment: Optional[int] = None,
pageserver_virtual_file_io_mode: Optional[str] = None,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -441,9 +441,9 @@ class NeonEnvBuilder:
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
self.pageserver_default_tenant_config_compaction_algorithm: Optional[
Dict[str, Any]
] = pageserver_default_tenant_config_compaction_algorithm
self.pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = (
pageserver_default_tenant_config_compaction_algorithm
)
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
log.debug(
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
@@ -456,7 +456,6 @@ class NeonEnvBuilder:
self.storage_controller_port_override = storage_controller_port_override
self.pageserver_io_buffer_alignment = pageserver_io_buffer_alignment
self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode
assert test_name.startswith(
"test_"
@@ -1030,7 +1029,6 @@ class NeonEnv:
self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine
self.pageserver_aux_file_policy = config.pageserver_aux_file_policy
self.pageserver_io_buffer_alignment = config.pageserver_io_buffer_alignment
self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode
# Create the neon_local's `NeonLocalInitConf`
cfg: Dict[str, Any] = {
@@ -1075,9 +1073,9 @@ class NeonEnv:
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
if config.pageserver_default_tenant_config_compaction_algorithm is not None:
tenant_config = ps_cfg.setdefault("tenant_config", {})
tenant_config[
"compaction_algorithm"
] = config.pageserver_default_tenant_config_compaction_algorithm
tenant_config["compaction_algorithm"] = (
config.pageserver_default_tenant_config_compaction_algorithm
)
if self.pageserver_remote_storage is not None:
ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
@@ -1094,10 +1092,7 @@ class NeonEnv:
for key, value in override.items():
ps_cfg[key] = value
if self.pageserver_io_buffer_alignment is not None:
ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment
if self.pageserver_virtual_file_io_mode is not None:
ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode
ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment
# Create a corresponding NeonPageserver object
self.pageservers.append(
@@ -1123,9 +1118,9 @@ class NeonEnv:
if config.auth_enabled:
sk_cfg["auth_enabled"] = True
if self.safekeepers_remote_storage is not None:
sk_cfg[
"remote_storage"
] = self.safekeepers_remote_storage.to_toml_inline_table().strip()
sk_cfg["remote_storage"] = (
self.safekeepers_remote_storage.to_toml_inline_table().strip()
)
self.safekeepers.append(
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
)
@@ -1336,7 +1331,6 @@ def neon_simple_env(
pageserver_aux_file_policy: Optional[AuxFileStore],
pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]],
pageserver_io_buffer_alignment: Optional[int],
pageserver_virtual_file_io_mode: Optional[str],
) -> Iterator[NeonEnv]:
"""
Simple Neon environment, with no authentication and no safekeepers.
@@ -1363,7 +1357,6 @@ def neon_simple_env(
pageserver_aux_file_policy=pageserver_aux_file_policy,
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
pageserver_io_buffer_alignment=pageserver_io_buffer_alignment,
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
) as builder:
env = builder.init_start()
@@ -1388,7 +1381,6 @@ def neon_env_builder(
pageserver_aux_file_policy: Optional[AuxFileStore],
record_property: Callable[[str, object], None],
pageserver_io_buffer_alignment: Optional[int],
pageserver_virtual_file_io_mode: Optional[str],
) -> Iterator[NeonEnvBuilder]:
"""
Fixture to create a Neon environment for test.
@@ -1424,7 +1416,6 @@ def neon_env_builder(
pageserver_aux_file_policy=pageserver_aux_file_policy,
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
pageserver_io_buffer_alignment=pageserver_io_buffer_alignment,
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
) as builder:
yield builder
# Propagate `preserve_database_files` to make it possible to use in other fixtures,
@@ -3321,7 +3312,7 @@ class VanillaPostgres(PgProtocol):
self.pg_bin = pg_bin
self.running = False
if init:
self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)])
self.pg_bin.run_capture(["initdb", "-D", str(pgdatadir)])
self.configure([f"port = {port}\n"])
def enable_tls(self):
@@ -3582,6 +3573,20 @@ class NeonProxy(PgProtocol):
]
return args
class AuthBroker(AuthBackend):
    """Auth backend that starts the proxy in auth-broker mode against a console endpoint."""

    def __init__(self, endpoint: str):
        # URL passed to the proxy via `--auth-endpoint`.
        self.endpoint = endpoint

    def extra_args(self) -> list[str]:
        """Return the command-line flags that select this backend."""
        # Console auth backend params, as flag/value pairs; the last flag takes no value.
        return [
            "--auth-backend",
            "console",
            "--auth-endpoint",
            self.endpoint,
            "--sql-over-http-pool-opt-in",
            "false",
            "--is-auth-broker",
        ]
@dataclass(frozen=True)
class Postgres(AuthBackend):
pg_conn_url: str
@@ -3610,7 +3615,7 @@ class NeonProxy(PgProtocol):
metric_collection_interval: Optional[str] = None,
):
host = "127.0.0.1"
domain = "proxy.localtest.me" # resolves to 127.0.0.1
domain = "ep-foo-bar-1234.localtest.me" # resolves to 127.0.0.1
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
self.domain = domain
@@ -3896,6 +3901,50 @@ def static_proxy(
yield proxy
@pytest.fixture(scope="function")
def static_auth_broker(
    vanilla_pg: VanillaPostgres,
    port_distributor: PortDistributor,
    neon_binpath: Path,
    test_output_dir: Path,
    httpserver: HTTPServer,
) -> Iterator[NeonProxy]:
    """
    Neon proxy running as an auth broker against a mocked control plane.

    The mock answers `/cplane/proxy_wake_compute` with the address taken from
    the vanilla postgres default options and a fixed set of endpoint, branch
    and project ids.
    """
    control_plane_url = httpserver.url_for("/cplane")

    pg_port = vanilla_pg.default_options["port"]
    pg_host = vanilla_pg.default_options["host"]

    wake_compute_reply = {
        "address": f"{pg_host}:{pg_port}",
        "aux": {
            "endpoint_id": "ep-foo-bar-1234",
            "branch_id": "br-foo-bar",
            "project_id": "foo-bar",
        },
    }
    httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(wake_compute_reply)

    # Ports are requested in a fixed order: proxy, mgmt, http, external http.
    proxy_port = port_distributor.get_port()
    mgmt_port = port_distributor.get_port()
    http_port = port_distributor.get_port()
    external_http_port = port_distributor.get_port()

    broker = NeonProxy(
        neon_binpath=neon_binpath,
        test_output_dir=test_output_dir,
        proxy_port=proxy_port,
        http_port=http_port,
        mgmt_port=mgmt_port,
        external_http_port=external_http_port,
        auth_backend=NeonProxy.AuthBroker(control_plane_url),
    )
    with broker as proxy:
        proxy.start()
        yield proxy
class Endpoint(PgProtocol, LogUtils):
"""An object representing a Postgres compute endpoint managed by the control plane."""

View File

@@ -39,11 +39,6 @@ def pageserver_io_buffer_alignment() -> Optional[int]:
return None
@pytest.fixture(scope="function", autouse=True)
def pageserver_virtual_file_io_mode() -> Optional[str]:
return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_MODE")
@pytest.fixture(scope="function", autouse=True)
def pageserver_aux_file_policy() -> Optional[AuxFileStore]:
return None

View File

@@ -0,0 +1,71 @@
import asyncio
import json
import subprocess
import time
import urllib.parse
from typing import Any, List, Optional, Tuple
import psycopg2
import pytest
import requests
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
def test_sql_over_http(static_auth_broker: NeonProxy):
    """
    Run a series of SQL-over-HTTP queries through the auth broker and check the
    JSON results: scalar and typed parameters, arrays, json/jsonb values, and
    the `command` / `rowCount` fields of DDL and DML statements.
    """
    static_auth_broker.safe_psql("create role http with login password 'http' superuser")

    # The helper previously referred to an undefined `static_proxy`; everything
    # must go through the fixture this test actually receives.
    proxy = static_auth_broker

    # NOTE(review): this authenticates with a password in the connection string
    # while the proxy runs with `--is-auth-broker` — confirm which credential
    # type the broker is expected to accept here.
    def q(sql: str, params: Optional[List[Any]] = None) -> Any:
        """POST a single query to the `/sql` endpoint and return the decoded JSON body."""
        params = params or []
        connstr = f"postgresql://http:http@{proxy.domain}:{proxy.proxy_port}/postgres"
        response = requests.post(
            f"https://{proxy.domain}:{proxy.external_http_port}/sql",
            data=json.dumps({"query": sql, "params": params}),
            headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
            verify=str(proxy.test_output_dir / "proxy.crt"),
        )
        assert response.status_code == 200, response.text
        return response.json()

    rows = q("select 42 as answer")["rows"]
    assert rows == [{"answer": 42}]

    # an untyped parameter comes back as text
    rows = q("select $1 as answer", [42])["rows"]
    assert rows == [{"answer": "42"}]

    rows = q("select $1 * 1 as answer", [42])["rows"]
    assert rows == [{"answer": 42}]

    rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"]
    assert rows == [{"answer": [1, 2, 3]}]

    rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
    assert rows == [{"answer": {"b": 42}}]

    rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
    assert rows == [{"answer": [{}]}]

    rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
    assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]

    rows = q("select * from pg_class limit 1")["rows"]
    assert len(rows) == 1

    res = q("create table t(id serial primary key, val int)")
    assert res["command"] == "CREATE"
    assert res["rowCount"] is None

    res = q("insert into t(val) values (10), (20), (30) returning id")
    assert res["command"] == "INSERT"
    assert res["rowCount"] == 3
    assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}]

    res = q("select * from t")
    assert res["command"] == "SELECT"
    assert res["rowCount"] == 3

    res = q("drop table t")
    assert res["command"] == "DROP"
    assert res["rowCount"] is None

View File

@@ -27,7 +27,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
env.pageserver.allowed_errors.extend(
[
".*basebackup .* failed: invalid basebackup lsn.*",
".*/lsn_lease.*invalid lsn lease request.*",
".*page_service.*handle_make_lsn_lease.*.*tried to request a page version that was garbage collected",
]
)
@@ -108,7 +108,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
assert cur.fetchone() == (1,)
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid lsn lease request"):
with pytest.raises(Exception, match="invalid basebackup lsn"):
# compute node startup with invalid LSN should fail
env.endpoints.create_start(
branch_name="main",
@@ -167,23 +167,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
)
return last_flush_lsn
def trigger_gc_and_select(env: NeonEnv, ep_static: Endpoint):
"""
Trigger GC manually on all pageservers. Then run an `SELECT` query.
"""
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Insert some records on main branch
with env.endpoints.create_start("main") as ep_main:
with ep_main.cursor() as cur:
@@ -210,31 +193,25 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
generate_updates_on_main(env, ep_main, i, end=100)
trigger_gc_and_select(env, ep_static)
# Trigger GC
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
# Trigger Pageserver restarts
for ps in env.pageservers:
ps.stop()
# Static compute should have at least one lease request failure due to connection.
time.sleep(LSN_LEASE_LENGTH / 2)
ps.start()
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
trigger_gc_and_select(env, ep_static)
# Reconfigure pageservers
env.pageservers[0].stop()
env.storage_controller.node_configure(
env.pageservers[0].id, {"availability": "Offline"}
)
env.storage_controller.reconcile_until_idle()
trigger_gc_and_select(env, ep_static)
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Do some update so we can increment latest_gc_cutoff
generate_updates_on_main(env, ep_main, i, end=100)
# Wait for the existing lease to expire.
time.sleep(LSN_LEASE_LENGTH + 1)
time.sleep(LSN_LEASE_LENGTH)
# Now trigger GC again, layers should be removed.
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()

View File

@@ -45,7 +45,10 @@ def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool
tenant_after = http.tenant_status(env.initial_tenant)
assert tenant_before != tenant_after
gc_blocking = tenant_after["gc_blocking"]
assert gc_blocking == "BlockingReasons { timelines: 1, reasons: EnumSet(Manual) }"
assert (
gc_blocking
== "BlockingReasons { tenant_blocked_by_lsn_lease_deadline: false, timelines: 1, reasons: EnumSet(Manual) }"
)
wait_for_another_gc_round()
pss.assert_log_contains(gc_skipped_line)

View File

@@ -772,7 +772,7 @@ class ProposerPostgres(PgProtocol):
def initdb(self):
"""Run initdb"""
args = ["initdb", "--username", "cloud_admin", "--pgdata", self.pg_data_dir_path()]
args = ["initdb", "-U", "cloud_admin", "-D", self.pg_data_dir_path()]
self.pg_bin.run(args)
def start(self):