mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 23:20:40 +00:00
Compare commits
20 Commits
yuchen/dir
...
auth-broke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3f7d0d3f1 | ||
|
|
0724df1d3f | ||
|
|
4d47049b00 | ||
|
|
5687384a8e | ||
|
|
249f5ea17d | ||
|
|
6abcc1f298 | ||
|
|
3e97cf0d6e | ||
|
|
054ef4988b | ||
|
|
5202cd75b5 | ||
|
|
f475dac0e6 | ||
|
|
a4100373e5 | ||
|
|
040d8cf4f6 | ||
|
|
75bfd57e01 | ||
|
|
4bc2686dee | ||
|
|
8e7d2aab76 | ||
|
|
2703abccc7 | ||
|
|
76515cdae3 | ||
|
|
08c7f933a3 | ||
|
|
4ad3aa7c96 | ||
|
|
9c59e3b4b9 |
46
.github/workflows/_benchmarking_preparation.yml
vendored
46
.github/workflows/_benchmarking_preparation.yml
vendored
@@ -3,23 +3,19 @@ name: Prepare benchmarking databases by restoring dumps
|
||||
on:
|
||||
workflow_call:
|
||||
# no inputs needed
|
||||
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -euxo pipefail {0}
|
||||
|
||||
jobs:
|
||||
setup-databases:
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
|
||||
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
|
||||
database: [ clickbench, tpch, userexample ]
|
||||
|
||||
|
||||
env:
|
||||
LD_LIBRARY_PATH: /tmp/neon/pg_install/v16/lib
|
||||
PLATFORM: ${{ matrix.platform }}
|
||||
@@ -27,10 +23,7 @@ jobs:
|
||||
|
||||
runs-on: [ self-hosted, us-east-2, x64 ]
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
@@ -39,13 +32,13 @@ jobs:
|
||||
run: |
|
||||
case "${PLATFORM}" in
|
||||
neon)
|
||||
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
|
||||
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
|
||||
;;
|
||||
aws-rds-postgres)
|
||||
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
|
||||
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
|
||||
;;
|
||||
aws-aurora-serverless-v2-postgres)
|
||||
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
|
||||
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
|
||||
;;
|
||||
*)
|
||||
echo >&2 "Unknown PLATFORM=${PLATFORM}"
|
||||
@@ -53,17 +46,10 @@ jobs:
|
||||
;;
|
||||
esac
|
||||
|
||||
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
|
||||
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
|
||||
role-duration-seconds: 18000 # 5 hours
|
||||
|
||||
- name: Download Neon artifact
|
||||
uses: ./.github/actions/download
|
||||
with:
|
||||
@@ -71,23 +57,23 @@ jobs:
|
||||
path: /tmp/neon/
|
||||
prefix: latest
|
||||
|
||||
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
|
||||
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
|
||||
- name: Create benchmark_restore_status table if it does not exist
|
||||
env:
|
||||
BENCHMARK_CONNSTR: ${{ steps.set-up-prep-connstr.outputs.connstr }}
|
||||
DATABASE_NAME: ${{ matrix.database }}
|
||||
# to avoid a race condition of multiple jobs trying to create the table at the same time,
|
||||
# to avoid a race condition of multiple jobs trying to create the table at the same time,
|
||||
# we use an advisory lock
|
||||
run: |
|
||||
${PG_BINARIES}/psql "${{ env.BENCHMARK_CONNSTR }}" -c "
|
||||
SELECT pg_advisory_lock(4711);
|
||||
SELECT pg_advisory_lock(4711);
|
||||
CREATE TABLE IF NOT EXISTS benchmark_restore_status (
|
||||
databasename text primary key,
|
||||
restore_done boolean
|
||||
);
|
||||
SELECT pg_advisory_unlock(4711);
|
||||
"
|
||||
|
||||
|
||||
- name: Check if restore is already done
|
||||
id: check-restore-done
|
||||
env:
|
||||
@@ -121,7 +107,7 @@ jobs:
|
||||
DATABASE_NAME: ${{ matrix.database }}
|
||||
run: |
|
||||
mkdir -p /tmp/dumps
|
||||
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
|
||||
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
|
||||
|
||||
- name: Replace database name in connection string
|
||||
if: steps.check-restore-done.outputs.skip != 'true'
|
||||
@@ -140,17 +126,17 @@ jobs:
|
||||
else
|
||||
new_connstr="${base_connstr}/${DATABASE_NAME}"
|
||||
fi
|
||||
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
|
||||
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Restore dump
|
||||
if: steps.check-restore-done.outputs.skip != 'true'
|
||||
env:
|
||||
DATABASE_NAME: ${{ matrix.database }}
|
||||
DATABASE_CONNSTR: ${{ steps.replace-dbname.outputs.database_connstr }}
|
||||
# the following works only with larger computes:
|
||||
# the following works only with larger computes:
|
||||
# PGOPTIONS: "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7"
|
||||
# we add the || true because:
|
||||
# the dumps were created with Neon and contain neon extensions that are not
|
||||
# the dumps were created with Neon and contain neon extensions that are not
|
||||
# available in RDS, so we will always report an error, but we can ignore it
|
||||
run: |
|
||||
${PG_BINARIES}/pg_restore --clean --if-exists --no-owner --jobs=4 \
|
||||
|
||||
@@ -236,7 +236,9 @@ jobs:
|
||||
|
||||
# run pageserver tests with different settings
|
||||
for io_engine in std-fs tokio-epoll-uring ; do
|
||||
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
|
||||
for io_buffer_alignment in 0 1 512 ; do
|
||||
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT=$io_buffer_alignment ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
|
||||
done
|
||||
done
|
||||
|
||||
# Run separate tests for real S3
|
||||
|
||||
119
.github/workflows/benchmarking.yml
vendored
119
.github/workflows/benchmarking.yml
vendored
@@ -12,6 +12,7 @@ on:
|
||||
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
|
||||
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
|
||||
- cron: '0 3 * * *' # run once a day, timezone is utc
|
||||
|
||||
workflow_dispatch: # adds ability to run this manually
|
||||
inputs:
|
||||
region_id:
|
||||
@@ -58,7 +59,7 @@ jobs:
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
id-token: write # Required for OIDC authentication in azure runners
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -67,10 +68,12 @@ jobs:
|
||||
PLATFORM: "neon-staging"
|
||||
region_id: ${{ github.event.inputs.region_id || 'aws-us-east-2' }}
|
||||
RUNNER: [ self-hosted, us-east-2, x64 ]
|
||||
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
- DEFAULT_PG_VERSION: 16
|
||||
PLATFORM: "azure-staging"
|
||||
region_id: 'azure-eastus2'
|
||||
RUNNER: [ self-hosted, eastus2, x64 ]
|
||||
IMAGE: neondatabase/build-tools:pinned
|
||||
env:
|
||||
TEST_PG_BENCH_DURATIONS_MATRIX: "300"
|
||||
TEST_PG_BENCH_SCALES_MATRIX: "10,100"
|
||||
@@ -83,10 +86,7 @@ jobs:
|
||||
|
||||
runs-on: ${{ matrix.RUNNER }}
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: ${{ matrix.IMAGE }}
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
@@ -164,10 +164,6 @@ jobs:
|
||||
|
||||
replication-tests:
|
||||
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
env:
|
||||
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
|
||||
DEFAULT_PG_VERSION: 16
|
||||
@@ -178,21 +174,12 @@ jobs:
|
||||
|
||||
runs-on: [ self-hosted, us-east-2, x64 ]
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
|
||||
role-duration-seconds: 18000 # 5 hours
|
||||
|
||||
- name: Download Neon artifact
|
||||
uses: ./.github/actions/download
|
||||
@@ -280,7 +267,7 @@ jobs:
|
||||
region_id_default=${{ env.DEFAULT_REGION_ID }}
|
||||
runner_default='["self-hosted", "us-east-2", "x64"]'
|
||||
runner_azure='["self-hosted", "eastus2", "x64"]'
|
||||
image_default="neondatabase/build-tools:pinned"
|
||||
image_default="369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned"
|
||||
matrix='{
|
||||
"pg_version" : [
|
||||
16
|
||||
@@ -357,7 +344,7 @@ jobs:
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
id-token: write # Required for OIDC authentication in azure runners
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -384,7 +371,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
- name: Configure AWS credentials # necessary on Azure runners
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
@@ -505,15 +492,17 @@ jobs:
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
id-token: write # Required for OIDC authentication in azure runners
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- PLATFORM: "neonvm-captest-pgvector"
|
||||
RUNNER: [ self-hosted, us-east-2, x64 ]
|
||||
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
- PLATFORM: "azure-captest-pgvector"
|
||||
RUNNER: [ self-hosted, eastus2, x64 ]
|
||||
IMAGE: neondatabase/build-tools:pinned
|
||||
|
||||
env:
|
||||
TEST_PG_BENCH_DURATIONS_MATRIX: "15m"
|
||||
@@ -522,16 +511,13 @@ jobs:
|
||||
DEFAULT_PG_VERSION: 16
|
||||
TEST_OUTPUT: /tmp/test_output
|
||||
BUILD_TYPE: remote
|
||||
|
||||
LD_LIBRARY_PATH: /home/nonroot/pg/usr/lib/x86_64-linux-gnu
|
||||
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
|
||||
PLATFORM: ${{ matrix.PLATFORM }}
|
||||
|
||||
runs-on: ${{ matrix.RUNNER }}
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: ${{ matrix.IMAGE }}
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
@@ -541,26 +527,17 @@ jobs:
|
||||
# instead of using Neon artifacts containing pgbench
|
||||
- name: Install postgresql-16 where pytest expects it
|
||||
run: |
|
||||
# Just to make it easier to test things locally on macOS (with arm64)
|
||||
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
|
||||
|
||||
cd /home/nonroot
|
||||
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.0-1.pgdg110+1_${arch}.deb"
|
||||
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb"
|
||||
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110+2_${arch}.deb"
|
||||
dpkg -x libpq5_17.0-1.pgdg110+1_${arch}.deb pg
|
||||
dpkg -x postgresql-16_16.4-1.pgdg110+2_${arch}.deb pg
|
||||
dpkg -x postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb pg
|
||||
|
||||
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/libpq5_16.4-1.pgdg110%2B1_amd64.deb
|
||||
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110%2B1_amd64.deb
|
||||
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110%2B1_amd64.deb
|
||||
dpkg -x libpq5_16.4-1.pgdg110+1_amd64.deb pg
|
||||
dpkg -x postgresql-client-16_16.4-1.pgdg110+1_amd64.deb pg
|
||||
dpkg -x postgresql-16_16.4-1.pgdg110+1_amd64.deb pg
|
||||
mkdir -p /tmp/neon/pg_install/v16/bin
|
||||
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
|
||||
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
|
||||
ln -s /home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu /tmp/neon/pg_install/v16/lib
|
||||
|
||||
LD_LIBRARY_PATH="/home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu:${LD_LIBRARY_PATH}"
|
||||
export LD_LIBRARY_PATH
|
||||
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> ${GITHUB_ENV}
|
||||
|
||||
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
|
||||
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
|
||||
ln -s /home/nonroot/pg/usr/lib/x86_64-linux-gnu /tmp/neon/pg_install/v16/lib
|
||||
/tmp/neon/pg_install/v16/bin/pgbench --version
|
||||
/tmp/neon/pg_install/v16/bin/psql --version
|
||||
|
||||
@@ -582,7 +559,7 @@ jobs:
|
||||
|
||||
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Configure AWS credentials
|
||||
- name: Configure AWS credentials # necessary on Azure runners to read/write from/to S3
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
@@ -643,10 +620,6 @@ jobs:
|
||||
# *_CLICKBENCH_CONNSTR: Genuine ClickBench DB with ~100M rows
|
||||
# *_CLICKBENCH_10M_CONNSTR: DB with the first 10M rows of ClickBench DB
|
||||
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
needs: [ generate-matrices, pgbench-compare, prepare_AWS_RDS_databases ]
|
||||
|
||||
strategy:
|
||||
@@ -665,22 +638,12 @@ jobs:
|
||||
|
||||
runs-on: [ self-hosted, us-east-2, x64 ]
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
|
||||
role-duration-seconds: 18000 # 5 hours
|
||||
|
||||
- name: Download Neon artifact
|
||||
uses: ./.github/actions/download
|
||||
with:
|
||||
@@ -751,10 +714,6 @@ jobs:
|
||||
#
|
||||
# *_TPCH_S10_CONNSTR: DB generated with scale factor 10 (~10 GB)
|
||||
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
needs: [ generate-matrices, clickbench-compare, prepare_AWS_RDS_databases ]
|
||||
|
||||
strategy:
|
||||
@@ -772,22 +731,12 @@ jobs:
|
||||
|
||||
runs-on: [ self-hosted, us-east-2, x64 ]
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
|
||||
role-duration-seconds: 18000 # 5 hours
|
||||
|
||||
- name: Download Neon artifact
|
||||
uses: ./.github/actions/download
|
||||
with:
|
||||
@@ -857,10 +806,6 @@ jobs:
|
||||
|
||||
user-examples-compare:
|
||||
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
|
||||
permissions:
|
||||
contents: write
|
||||
statuses: write
|
||||
id-token: write # aws-actions/configure-aws-credentials
|
||||
needs: [ generate-matrices, tpch-compare, prepare_AWS_RDS_databases ]
|
||||
|
||||
strategy:
|
||||
@@ -877,22 +822,12 @@ jobs:
|
||||
|
||||
runs-on: [ self-hosted, us-east-2, x64 ]
|
||||
container:
|
||||
image: neondatabase/build-tools:pinned
|
||||
credentials:
|
||||
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
aws-region: eu-central-1
|
||||
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
|
||||
role-duration-seconds: 18000 # 5 hours
|
||||
|
||||
- name: Download Neon artifact
|
||||
uses: ./.github/actions/download
|
||||
with:
|
||||
|
||||
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1321,7 +1321,6 @@ dependencies = [
|
||||
"clap",
|
||||
"comfy-table",
|
||||
"compute_api",
|
||||
"futures",
|
||||
"humantime",
|
||||
"humantime-serde",
|
||||
"hyper 0.14.30",
|
||||
|
||||
@@ -13,9 +13,6 @@ RUN useradd -ms /bin/bash nonroot -b /home
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# System deps
|
||||
#
|
||||
# 'gdb' is included so that we get backtraces of core dumps produced in
|
||||
# regression tests
|
||||
RUN set -e \
|
||||
&& apt update \
|
||||
&& apt install -y \
|
||||
@@ -27,7 +24,6 @@ RUN set -e \
|
||||
cmake \
|
||||
curl \
|
||||
flex \
|
||||
gdb \
|
||||
git \
|
||||
gnupg \
|
||||
gzip \
|
||||
|
||||
@@ -11,10 +11,6 @@ commands:
|
||||
user: root
|
||||
sysvInitAction: sysinit
|
||||
shell: 'chmod 711 /neonvm/bin/resize-swap'
|
||||
- name: chmod-set-disk-quota
|
||||
user: root
|
||||
sysvInitAction: sysinit
|
||||
shell: 'chmod 711 /neonvm/bin/set-disk-quota'
|
||||
- name: pgbouncer
|
||||
user: postgres
|
||||
sysvInitAction: respawn
|
||||
@@ -34,12 +30,11 @@ commands:
|
||||
shutdownHook: |
|
||||
su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
|
||||
files:
|
||||
- filename: compute_ctl-sudoers
|
||||
- filename: compute_ctl-resize-swap
|
||||
content: |
|
||||
# Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap
|
||||
# and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD),
|
||||
# regardless of hostname (ALL)
|
||||
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota
|
||||
# as root without requiring entering a password (NOPASSWD), regardless of hostname (ALL)
|
||||
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap
|
||||
- filename: cgconfig.conf
|
||||
content: |
|
||||
# Configuration for cgroups in VM compute nodes
|
||||
@@ -105,7 +100,7 @@ merge: |
|
||||
&& apt install --no-install-recommends -y \
|
||||
sudo \
|
||||
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
|
||||
COPY compute_ctl-sudoers /etc/sudoers.d/compute_ctl-sudoers
|
||||
COPY compute_ctl-resize-swap /etc/sudoers.d/compute_ctl-resize-swap
|
||||
|
||||
COPY cgconfig.conf /etc/cgconfig.conf
|
||||
|
||||
|
||||
@@ -44,7 +44,6 @@ use std::{thread, time::Duration};
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Utc;
|
||||
use clap::Arg;
|
||||
use compute_tools::disk_quota::set_disk_quota;
|
||||
use compute_tools::lsn_lease::launch_lsn_lease_bg_task_for_static;
|
||||
use signal_hook::consts::{SIGQUIT, SIGTERM};
|
||||
use signal_hook::{consts::SIGINT, iterator::Signals};
|
||||
@@ -152,7 +151,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
|
||||
let spec_json = matches.get_one::<String>("spec");
|
||||
let spec_path = matches.get_one::<String>("spec-path");
|
||||
let resize_swap_on_bind = matches.get_flag("resize-swap-on-bind");
|
||||
let set_disk_quota_for_fs = matches.get_one::<String>("set-disk-quota-for-fs");
|
||||
|
||||
Ok(ProcessCliResult {
|
||||
connstr,
|
||||
@@ -163,7 +161,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
|
||||
spec_json,
|
||||
spec_path,
|
||||
resize_swap_on_bind,
|
||||
set_disk_quota_for_fs,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -176,7 +173,6 @@ struct ProcessCliResult<'clap> {
|
||||
spec_json: Option<&'clap String>,
|
||||
spec_path: Option<&'clap String>,
|
||||
resize_swap_on_bind: bool,
|
||||
set_disk_quota_for_fs: Option<&'clap String>,
|
||||
}
|
||||
|
||||
fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
|
||||
@@ -297,7 +293,6 @@ fn wait_spec(
|
||||
pgbin,
|
||||
ext_remote_storage,
|
||||
resize_swap_on_bind,
|
||||
set_disk_quota_for_fs,
|
||||
http_port,
|
||||
..
|
||||
}: ProcessCliResult,
|
||||
@@ -378,7 +373,6 @@ fn wait_spec(
|
||||
compute,
|
||||
http_port,
|
||||
resize_swap_on_bind,
|
||||
set_disk_quota_for_fs: set_disk_quota_for_fs.cloned(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -387,7 +381,6 @@ struct WaitSpecResult {
|
||||
// passed through from ProcessCliResult
|
||||
http_port: u16,
|
||||
resize_swap_on_bind: bool,
|
||||
set_disk_quota_for_fs: Option<String>,
|
||||
}
|
||||
|
||||
fn start_postgres(
|
||||
@@ -397,7 +390,6 @@ fn start_postgres(
|
||||
compute,
|
||||
http_port,
|
||||
resize_swap_on_bind,
|
||||
set_disk_quota_for_fs,
|
||||
}: WaitSpecResult,
|
||||
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
|
||||
// We got all we need, update the state.
|
||||
@@ -411,7 +403,6 @@ fn start_postgres(
|
||||
);
|
||||
// before we release the mutex, fetch the swap size (if any) for later.
|
||||
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
|
||||
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
|
||||
drop(state);
|
||||
|
||||
// Launch remaining service threads
|
||||
@@ -431,8 +422,8 @@ fn start_postgres(
|
||||
// OOM-killed during startup because swap wasn't available yet.
|
||||
match resize_swap(size_bytes) {
|
||||
Ok(()) => {
|
||||
let size_mib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
|
||||
info!(%size_bytes, %size_mib, "resized swap");
|
||||
let size_gib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
|
||||
info!(%size_bytes, %size_gib, "resized swap");
|
||||
}
|
||||
Err(err) => {
|
||||
let err = err.context("failed to resize swap");
|
||||
@@ -441,29 +432,10 @@ fn start_postgres(
|
||||
// Mark compute startup as failed; don't try to start postgres, and report this
|
||||
// error to the control plane when it next asks.
|
||||
prestartup_failed = true;
|
||||
compute.set_failed_status(err);
|
||||
delay_exit = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set disk quota if the compute spec says so
|
||||
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) =
|
||||
(disk_quota_bytes, set_disk_quota_for_fs)
|
||||
{
|
||||
match set_disk_quota(disk_quota_bytes, &disk_quota_fs_mountpoint) {
|
||||
Ok(()) => {
|
||||
let size_mib = disk_quota_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
|
||||
info!(%disk_quota_bytes, %size_mib, "set disk quota");
|
||||
}
|
||||
Err(err) => {
|
||||
let err = err.context("failed to set disk quota");
|
||||
error!("{err:#}");
|
||||
|
||||
// Mark compute startup as failed; don't try to start postgres, and report this
|
||||
// error to the control plane when it next asks.
|
||||
prestartup_failed = true;
|
||||
compute.set_failed_status(err);
|
||||
let mut state = compute.state.lock().unwrap();
|
||||
state.error = Some(format!("{err:?}"));
|
||||
state.status = ComputeStatus::Failed;
|
||||
compute.state_changed.notify_all();
|
||||
delay_exit = true;
|
||||
}
|
||||
}
|
||||
@@ -478,7 +450,16 @@ fn start_postgres(
|
||||
Ok(pg) => Some(pg),
|
||||
Err(err) => {
|
||||
error!("could not start the compute node: {:#}", err);
|
||||
compute.set_failed_status(err);
|
||||
let mut state = compute.state.lock().unwrap();
|
||||
state.error = Some(format!("{:?}", err));
|
||||
state.status = ComputeStatus::Failed;
|
||||
// Notify others that Postgres failed to start. In case of configuring the
|
||||
// empty compute, it's likely that API handler is still waiting for compute
|
||||
// state change. With this we will notify it that compute is in Failed state,
|
||||
// so control plane will know about it earlier and record proper error instead
|
||||
// of timeout.
|
||||
compute.state_changed.notify_all();
|
||||
drop(state); // unlock
|
||||
delay_exit = true;
|
||||
None
|
||||
}
|
||||
@@ -769,11 +750,6 @@ fn cli() -> clap::Command {
|
||||
.long("resize-swap-on-bind")
|
||||
.action(clap::ArgAction::SetTrue),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("set-disk-quota-for-fs")
|
||||
.long("set-disk-quota-for-fs")
|
||||
.value_name("SET_DISK_QUOTA_FOR_FS")
|
||||
)
|
||||
}
|
||||
|
||||
/// When compute_ctl is killed, send also termination signal to sync-safekeepers
|
||||
|
||||
@@ -10,7 +10,6 @@ use std::sync::atomic::AtomicU32;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::sync::{Condvar, Mutex, RwLock};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
@@ -306,13 +305,6 @@ impl ComputeNode {
|
||||
self.state_changed.notify_all();
|
||||
}
|
||||
|
||||
pub fn set_failed_status(&self, err: anyhow::Error) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
state.error = Some(format!("{err:?}"));
|
||||
state.status = ComputeStatus::Failed;
|
||||
self.state_changed.notify_all();
|
||||
}
|
||||
|
||||
pub fn get_status(&self) -> ComputeStatus {
|
||||
self.state.lock().unwrap().status
|
||||
}
|
||||
@@ -718,7 +710,7 @@ impl ComputeNode {
|
||||
info!("running initdb");
|
||||
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
|
||||
Command::new(initdb_bin)
|
||||
.args(["--pgdata", pgdata])
|
||||
.args(["-D", pgdata])
|
||||
.output()
|
||||
.expect("cannot start initdb process");
|
||||
|
||||
@@ -1131,9 +1123,6 @@ impl ComputeNode {
|
||||
//
|
||||
// Use that as a default location and pattern, except macos where core dumps are written
|
||||
// to /cores/ directory by default.
|
||||
//
|
||||
// With default Linux settings, the core dump file is called just "core", so check for
|
||||
// that too.
|
||||
pub fn check_for_core_dumps(&self) -> Result<()> {
|
||||
let core_dump_dir = match std::env::consts::OS {
|
||||
"macos" => Path::new("/cores/"),
|
||||
@@ -1145,17 +1134,8 @@ impl ComputeNode {
|
||||
let files = fs::read_dir(core_dump_dir)?;
|
||||
let cores = files.filter_map(|entry| {
|
||||
let entry = entry.ok()?;
|
||||
|
||||
let is_core_dump = match entry.file_name().to_str()? {
|
||||
n if n.starts_with("core.") => true,
|
||||
"core" => true,
|
||||
_ => false,
|
||||
};
|
||||
if is_core_dump {
|
||||
Some(entry.path())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
let _ = entry.file_name().to_str()?.strip_prefix("core.")?;
|
||||
Some(entry.path())
|
||||
});
|
||||
|
||||
// Print backtrace for each core dump
|
||||
@@ -1406,36 +1386,6 @@ LIMIT 100",
|
||||
}
|
||||
Ok(remote_ext_metrics)
|
||||
}
|
||||
|
||||
/// Waits until current thread receives a state changed notification and
|
||||
/// the pageserver connection strings has changed.
|
||||
///
|
||||
/// The operation will time out after a specified duration.
|
||||
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
|
||||
let state = self.state.lock().unwrap();
|
||||
let old_pageserver_connstr = state
|
||||
.pspec
|
||||
.as_ref()
|
||||
.expect("spec must be set")
|
||||
.pageserver_connstr
|
||||
.clone();
|
||||
let mut unchanged = true;
|
||||
let _ = self
|
||||
.state_changed
|
||||
.wait_timeout_while(state, duration, |s| {
|
||||
let pageserver_connstr = &s
|
||||
.pspec
|
||||
.as_ref()
|
||||
.expect("spec must be set")
|
||||
.pageserver_connstr;
|
||||
unchanged = pageserver_connstr == &old_pageserver_connstr;
|
||||
unchanged
|
||||
})
|
||||
.unwrap();
|
||||
if !unchanged {
|
||||
info!("Pageserver config changed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_termination_signal() {
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
use anyhow::Context;
|
||||
|
||||
pub const DISK_QUOTA_BIN: &str = "/neonvm/bin/set-disk-quota";
|
||||
|
||||
/// If size_bytes is 0, it disables the quota. Otherwise, it sets filesystem quota to size_bytes.
|
||||
/// `fs_mountpoint` should point to the mountpoint of the filesystem where the quota should be set.
|
||||
pub fn set_disk_quota(size_bytes: u64, fs_mountpoint: &str) -> anyhow::Result<()> {
|
||||
let size_kb = size_bytes / 1024;
|
||||
// run `/neonvm/bin/set-disk-quota {size_kb} {mountpoint}`
|
||||
let child_result = std::process::Command::new("/usr/bin/sudo")
|
||||
.arg(DISK_QUOTA_BIN)
|
||||
.arg(size_kb.to_string())
|
||||
.arg(fs_mountpoint)
|
||||
.spawn();
|
||||
|
||||
child_result
|
||||
.context("spawn() failed")
|
||||
.and_then(|mut child| child.wait().context("wait() failed"))
|
||||
.and_then(|status| match status.success() {
|
||||
true => Ok(()),
|
||||
false => Err(anyhow::anyhow!("process exited with {status}")),
|
||||
})
|
||||
// wrap any prior error with the overall context that we couldn't run the command
|
||||
.with_context(|| format!("could not run `/usr/bin/sudo {DISK_QUOTA_BIN}`"))
|
||||
}
|
||||
@@ -10,7 +10,6 @@ pub mod http;
|
||||
pub mod logger;
|
||||
pub mod catalog;
|
||||
pub mod compute;
|
||||
pub mod disk_quota;
|
||||
pub mod extension_server;
|
||||
pub mod lsn_lease;
|
||||
mod migration;
|
||||
|
||||
@@ -57,10 +57,10 @@ fn lsn_lease_bg_task(
|
||||
.max(valid_duration / 2);
|
||||
|
||||
info!(
|
||||
"Request succeeded, sleeping for {} seconds",
|
||||
"Succeeded, sleeping for {} seconds",
|
||||
sleep_duration.as_secs()
|
||||
);
|
||||
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
|
||||
thread::sleep(sleep_duration);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,10 @@ fn acquire_lsn_lease_with_retry(
|
||||
.map(|connstr| {
|
||||
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
|
||||
if let Some(storage_auth_token) = &spec.storage_auth_token {
|
||||
info!("Got storage auth token from spec file");
|
||||
config.password(storage_auth_token.clone());
|
||||
} else {
|
||||
info!("Storage auth token not set");
|
||||
}
|
||||
config
|
||||
})
|
||||
@@ -105,11 +108,9 @@ fn acquire_lsn_lease_with_retry(
|
||||
bail!("Permanent error: lease could not be obtained, LSN is behind the GC cutoff");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
|
||||
warn!("Failed to acquire lsn lease: {e} (attempt {attempts}");
|
||||
|
||||
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
|
||||
retry_period_ms as u64,
|
||||
));
|
||||
thread::sleep(Duration::from_millis(retry_period_ms as u64));
|
||||
retry_period_ms *= 1.5;
|
||||
retry_period_ms = retry_period_ms.min(MAX_RETRY_PERIOD_MS);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ anyhow.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
comfy-table.workspace = true
|
||||
futures.workspace = true
|
||||
humantime.workspace = true
|
||||
nix.workspace = true
|
||||
once_cell.workspace = true
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,94 +0,0 @@
|
||||
//! Branch mappings for convenience
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use utils::id::{TenantId, TenantTimelineId, TimelineId};
|
||||
|
||||
/// Keep human-readable aliases in memory (and persist them to config XXX), to hide tenant/timeline hex strings from the user.
|
||||
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(default, deny_unknown_fields)]
|
||||
pub struct BranchMappings {
|
||||
/// Default tenant ID to use with the 'neon_local' command line utility, when
|
||||
/// --tenant_id is not explicitly specified. This comes from the branches.
|
||||
pub default_tenant_id: Option<TenantId>,
|
||||
|
||||
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
|
||||
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
|
||||
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
|
||||
pub mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
|
||||
}
|
||||
|
||||
impl BranchMappings {
|
||||
pub fn register_branch_mapping(
|
||||
&mut self,
|
||||
branch_name: String,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> anyhow::Result<()> {
|
||||
let existing_values = self.mappings.entry(branch_name.clone()).or_default();
|
||||
|
||||
let existing_ids = existing_values
|
||||
.iter()
|
||||
.find(|(existing_tenant_id, _)| existing_tenant_id == &tenant_id);
|
||||
|
||||
if let Some((_, old_timeline_id)) = existing_ids {
|
||||
if old_timeline_id == &timeline_id {
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("branch '{branch_name}' is already mapped to timeline {old_timeline_id}, cannot map to another timeline {timeline_id}");
|
||||
}
|
||||
} else {
|
||||
existing_values.push((tenant_id, timeline_id));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_branch_timeline_id(
|
||||
&self,
|
||||
branch_name: &str,
|
||||
tenant_id: TenantId,
|
||||
) -> Option<TimelineId> {
|
||||
// If it looks like a timeline ID, return it as it is
|
||||
if let Ok(timeline_id) = branch_name.parse::<TimelineId>() {
|
||||
return Some(timeline_id);
|
||||
}
|
||||
|
||||
self.mappings
|
||||
.get(branch_name)?
|
||||
.iter()
|
||||
.find(|(mapped_tenant_id, _)| mapped_tenant_id == &tenant_id)
|
||||
.map(|&(_, timeline_id)| timeline_id)
|
||||
.map(TimelineId::from)
|
||||
}
|
||||
|
||||
pub fn timeline_name_mappings(&self) -> HashMap<TenantTimelineId, String> {
|
||||
self.mappings
|
||||
.iter()
|
||||
.flat_map(|(name, tenant_timelines)| {
|
||||
tenant_timelines.iter().map(|&(tenant_id, timeline_id)| {
|
||||
(TenantTimelineId::new(tenant_id, timeline_id), name.clone())
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
|
||||
let content = &toml::to_string_pretty(self)?;
|
||||
fs::write(path, content).with_context(|| {
|
||||
format!(
|
||||
"Failed to write branch information into path '{}'",
|
||||
path.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load(path: &Path) -> anyhow::Result<BranchMappings> {
|
||||
let branches_file_contents = fs::read_to_string(path)?;
|
||||
Ok(toml::from_str(branches_file_contents.as_str())?)
|
||||
}
|
||||
}
|
||||
@@ -561,7 +561,6 @@ impl Endpoint {
|
||||
operation_uuid: None,
|
||||
features: self.features.clone(),
|
||||
swap_size_bytes: None,
|
||||
disk_quota_bytes: None,
|
||||
cluster: Cluster {
|
||||
cluster_id: None, // project ID: not used
|
||||
name: None, // project name: not used
|
||||
|
||||
@@ -113,7 +113,7 @@ impl SafekeeperNode {
|
||||
|
||||
pub async fn start(
|
||||
&self,
|
||||
extra_opts: &[String],
|
||||
extra_opts: Vec<String>,
|
||||
retry_timeout: &Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
print!(
|
||||
@@ -196,7 +196,7 @@ impl SafekeeperNode {
|
||||
]);
|
||||
}
|
||||
|
||||
args.extend_from_slice(extra_opts);
|
||||
args.extend(extra_opts);
|
||||
|
||||
background_process::start_process(
|
||||
&format!("safekeeper-{id}"),
|
||||
|
||||
@@ -347,7 +347,7 @@ impl StorageController {
|
||||
|
||||
if !tokio::fs::try_exists(&pg_data_path).await? {
|
||||
let initdb_args = [
|
||||
"--pgdata",
|
||||
"-D",
|
||||
pg_data_path.as_ref(),
|
||||
"--username",
|
||||
&username(),
|
||||
|
||||
@@ -50,16 +50,6 @@ pub struct ComputeSpec {
|
||||
#[serde(default)]
|
||||
pub swap_size_bytes: Option<u64>,
|
||||
|
||||
/// If compute_ctl was passed `--set-disk-quota-for-fs`, a value of `Some(_)` instructs
|
||||
/// compute_ctl to run `/neonvm/bin/set-disk-quota` with the given size and fs, when the
|
||||
/// spec is first received.
|
||||
///
|
||||
/// Both this field and `--set-disk-quota-for-fs` are required, so that the control plane's
|
||||
/// spec generation doesn't need to be aware of the actual compute it's running on, while
|
||||
/// guaranteeing gradual rollout of disk quota.
|
||||
#[serde(default)]
|
||||
pub disk_quota_bytes: Option<u64>,
|
||||
|
||||
/// Expected cluster state at the end of transition process.
|
||||
pub cluster: Cluster,
|
||||
pub delta_operations: Option<Vec<DeltaOp>>,
|
||||
|
||||
@@ -104,7 +104,7 @@ pub struct ConfigToml {
|
||||
pub image_compression: ImageCompressionAlgorithm,
|
||||
pub ephemeral_bytes_per_memory_kb: usize,
|
||||
pub l0_flush: Option<crate::models::L0FlushConfig>,
|
||||
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
|
||||
pub virtual_file_direct_io: crate::models::virtual_file::DirectIoMode,
|
||||
pub io_buffer_alignment: usize,
|
||||
}
|
||||
|
||||
@@ -381,7 +381,7 @@ impl Default for ConfigToml {
|
||||
image_compression: (DEFAULT_IMAGE_COMPRESSION),
|
||||
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
|
||||
l0_flush: None,
|
||||
virtual_file_io_mode: None,
|
||||
virtual_file_direct_io: crate::models::virtual_file::DirectIoMode::default(),
|
||||
|
||||
io_buffer_alignment: DEFAULT_IO_BUFFER_ALIGNMENT,
|
||||
|
||||
|
||||
@@ -972,6 +972,8 @@ pub struct TopTenantShardsResponse {
|
||||
}
|
||||
|
||||
pub mod virtual_file {
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
@@ -992,51 +994,50 @@ pub mod virtual_file {
|
||||
}
|
||||
|
||||
/// Direct IO modes for a pageserver.
|
||||
#[derive(
|
||||
Copy,
|
||||
Clone,
|
||||
PartialEq,
|
||||
Eq,
|
||||
Hash,
|
||||
strum_macros::EnumString,
|
||||
strum_macros::Display,
|
||||
serde_with::DeserializeFromStr,
|
||||
serde_with::SerializeDisplay,
|
||||
Debug,
|
||||
)]
|
||||
#[strum(serialize_all = "kebab-case")]
|
||||
#[repr(u8)]
|
||||
pub enum IoMode {
|
||||
/// Uses buffered IO.
|
||||
Buffered,
|
||||
/// Uses direct IO, error out if the operation fails.
|
||||
#[cfg(target_os = "linux")]
|
||||
Direct,
|
||||
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
|
||||
#[serde(tag = "mode", rename_all = "kebab-case", deny_unknown_fields)]
|
||||
pub enum DirectIoMode {
|
||||
/// Direct IO disabled (uses usual buffered IO).
|
||||
#[default]
|
||||
Disabled,
|
||||
/// Direct IO disabled (performs checks and perf simulations).
|
||||
Evaluate {
|
||||
/// Alignment check level
|
||||
alignment_check: DirectIoAlignmentCheckLevel,
|
||||
/// Latency padded for performance simulation.
|
||||
latency_padding: DirectIoLatencyPadding,
|
||||
},
|
||||
/// Direct IO enabled.
|
||||
Enabled {
|
||||
/// Actions to perform on alignment error.
|
||||
on_alignment_error: DirectIoOnAlignmentErrorAction,
|
||||
},
|
||||
}
|
||||
|
||||
impl IoMode {
|
||||
#[cfg(target_os = "linux")]
|
||||
pub const fn preferred() -> Self {
|
||||
Self::Direct
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
pub const fn preferred() -> Self {
|
||||
Self::Buffered
|
||||
}
|
||||
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum DirectIoAlignmentCheckLevel {
|
||||
#[default]
|
||||
Error,
|
||||
Log,
|
||||
None,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for IoMode {
|
||||
type Error = u8;
|
||||
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum DirectIoOnAlignmentErrorAction {
|
||||
Error,
|
||||
#[default]
|
||||
FallbackToBuffered,
|
||||
}
|
||||
|
||||
fn try_from(value: u8) -> Result<Self, Self::Error> {
|
||||
Ok(match value {
|
||||
v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
|
||||
#[cfg(target_os = "linux")]
|
||||
v if v == (IoMode::Direct as u8) => IoMode::Direct,
|
||||
x => return Err(x),
|
||||
})
|
||||
}
|
||||
#[derive(Debug, PartialEq, Eq, Clone, serde::Deserialize, serde::Serialize, Default)]
|
||||
#[serde(tag = "type", rename_all = "kebab-case")]
|
||||
pub enum DirectIoLatencyPadding {
|
||||
/// Pad virtual file operations with IO to a fake file.
|
||||
FakeFileRW { path: PathBuf },
|
||||
#[default]
|
||||
None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -93,9 +93,9 @@ impl Conf {
|
||||
);
|
||||
let output = self
|
||||
.new_pg_command("initdb")?
|
||||
.arg("--pgdata")
|
||||
.arg("-D")
|
||||
.arg(&self.datadir)
|
||||
.args(["--username", "postgres", "--no-instructions", "--no-sync"])
|
||||
.args(["-U", "postgres", "--no-instructions", "--no-sync"])
|
||||
.output()?;
|
||||
debug!("initdb output: {:?}", output);
|
||||
ensure!(
|
||||
|
||||
@@ -164,10 +164,12 @@ fn criterion_benchmark(c: &mut Criterion) {
|
||||
let conf: &'static PageServerConf = Box::leak(Box::new(
|
||||
pageserver::config::PageServerConf::dummy_conf(temp_dir.path().to_path_buf()),
|
||||
));
|
||||
|
||||
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
virtual_file::init(16384, virtual_file::io_engine_for_bench(), align);
|
||||
page_cache::init(conf.page_cache_size, align);
|
||||
virtual_file::init(
|
||||
16384,
|
||||
virtual_file::io_engine_for_bench(),
|
||||
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
|
||||
);
|
||||
page_cache::init(conf.page_cache_size);
|
||||
|
||||
{
|
||||
let mut group = c.benchmark_group("ingest-small-values");
|
||||
|
||||
@@ -550,19 +550,6 @@ impl Client {
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
/// Configs io mode at runtime.
|
||||
pub async fn put_io_mode(
|
||||
&self,
|
||||
mode: &pageserver_api::models::virtual_file::IoMode,
|
||||
) -> Result<()> {
|
||||
let uri = format!("{}/v1/io_mode", self.mgmt_api_endpoint);
|
||||
self.request(Method::PUT, uri, mode)
|
||||
.await?
|
||||
.json()
|
||||
.await
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
pub async fn get_utilization(&self) -> Result<PageserverUtilization> {
|
||||
let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint);
|
||||
self.get(uri)
|
||||
@@ -749,22 +736,4 @@ impl Client {
|
||||
.await
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
pub async fn timeline_init_lsn_lease(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline_id: TimelineId,
|
||||
lsn: Lsn,
|
||||
) -> Result<LsnLease> {
|
||||
let uri = format!(
|
||||
"{}/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/lsn_lease",
|
||||
self.mgmt_api_endpoint,
|
||||
);
|
||||
|
||||
self.request(Method::POST, &uri, LsnLeaseRequest { lsn })
|
||||
.await?
|
||||
.json()
|
||||
.await
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,10 +151,13 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> {
|
||||
let max_holes = cmd.max_holes.unwrap_or(DEFAULT_MAX_HOLES);
|
||||
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
|
||||
|
||||
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
// Initialize virtual_file (file desriptor cache) and page cache which are needed to access layer persistent B-Tree.
|
||||
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
|
||||
pageserver::page_cache::init(100, align);
|
||||
pageserver::virtual_file::init(
|
||||
10,
|
||||
virtual_file::api::IoEngineKind::StdFs,
|
||||
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
|
||||
);
|
||||
pageserver::page_cache::init(100);
|
||||
|
||||
let mut total_delta_layers = 0usize;
|
||||
let mut total_image_layers = 0usize;
|
||||
|
||||
@@ -59,9 +59,8 @@ pub(crate) enum LayerCmd {
|
||||
|
||||
async fn read_delta_file(path: impl AsRef<Path>, ctx: &RequestContext) -> Result<()> {
|
||||
let path = Utf8Path::from_path(path.as_ref()).expect("non-Unicode path");
|
||||
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
|
||||
page_cache::init(100, align);
|
||||
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, 1);
|
||||
page_cache::init(100);
|
||||
let file = VirtualFile::open(path, ctx).await?;
|
||||
let file_id = page_cache::next_file_id();
|
||||
let block_reader = FileBlockReader::new(&file, file_id);
|
||||
@@ -191,10 +190,12 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> {
|
||||
new_tenant_id,
|
||||
new_timeline_id,
|
||||
} => {
|
||||
let align = pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
|
||||
pageserver::virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
|
||||
pageserver::page_cache::init(100, align);
|
||||
pageserver::virtual_file::init(
|
||||
10,
|
||||
virtual_file::api::IoEngineKind::StdFs,
|
||||
pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT,
|
||||
);
|
||||
pageserver::page_cache::init(100);
|
||||
|
||||
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
|
||||
|
||||
|
||||
@@ -205,9 +205,12 @@ fn read_pg_control_file(control_file_path: &Utf8Path) -> anyhow::Result<()> {
|
||||
|
||||
async fn print_layerfile(path: &Utf8Path) -> anyhow::Result<()> {
|
||||
// Basic initialization of things that don't change after startup
|
||||
let align = DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
virtual_file::init(10, virtual_file::api::IoEngineKind::StdFs, align);
|
||||
page_cache::init(100, align);
|
||||
virtual_file::init(
|
||||
10,
|
||||
virtual_file::api::IoEngineKind::StdFs,
|
||||
DEFAULT_IO_BUFFER_ALIGNMENT,
|
||||
);
|
||||
page_cache::init(100);
|
||||
let ctx = RequestContext::new(TaskKind::DebugTool, DownloadBehavior::Error);
|
||||
dump_layerfile_from_path(path, true, &ctx).await
|
||||
}
|
||||
|
||||
@@ -63,10 +63,6 @@ pub(crate) struct Args {
|
||||
#[clap(long)]
|
||||
set_io_alignment: Option<usize>,
|
||||
|
||||
/// Before starting the benchmark, live-reconfigure the pageserver to use specified io mode (buffered vs. direct).
|
||||
#[clap(long)]
|
||||
set_io_mode: Option<pageserver_api::models::virtual_file::IoMode>,
|
||||
|
||||
targets: Option<Vec<TenantTimelineId>>,
|
||||
}
|
||||
|
||||
@@ -137,10 +133,6 @@ async fn main_impl(
|
||||
mgmt_api_client.put_io_alignment(align).await?;
|
||||
}
|
||||
|
||||
if let Some(mode) = &args.set_io_mode {
|
||||
mgmt_api_client.put_io_mode(mode).await?;
|
||||
}
|
||||
|
||||
// discover targets
|
||||
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
|
||||
&mgmt_api_client,
|
||||
|
||||
@@ -125,7 +125,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
// after setting up logging, log the effective IO engine choice and read path implementations
|
||||
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
|
||||
info!(?conf.virtual_file_io_mode, "starting with virtual_file Direct IO settings");
|
||||
info!(?conf.virtual_file_direct_io, "starting with virtual_file Direct IO settings");
|
||||
info!(?conf.io_buffer_alignment, "starting with setting for IO buffer alignment");
|
||||
|
||||
// The tenants directory contains all the pageserver local disk state.
|
||||
@@ -173,7 +173,7 @@ fn main() -> anyhow::Result<()> {
|
||||
conf.virtual_file_io_engine,
|
||||
conf.io_buffer_alignment,
|
||||
);
|
||||
page_cache::init(conf.page_cache_size, conf.io_buffer_alignment);
|
||||
page_cache::init(conf.page_cache_size);
|
||||
|
||||
start_pageserver(launch_ts, conf).context("Failed to start pageserver")?;
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ pub struct PageServerConf {
|
||||
pub l0_flush: crate::l0_flush::L0FlushConfig,
|
||||
|
||||
/// Direct IO settings
|
||||
pub virtual_file_io_mode: virtual_file::IoMode,
|
||||
pub virtual_file_direct_io: virtual_file::DirectIoMode,
|
||||
|
||||
pub io_buffer_alignment: usize,
|
||||
}
|
||||
@@ -325,7 +325,7 @@ impl PageServerConf {
|
||||
image_compression,
|
||||
ephemeral_bytes_per_memory_kb,
|
||||
l0_flush,
|
||||
virtual_file_io_mode,
|
||||
virtual_file_direct_io,
|
||||
concurrent_tenant_warmup,
|
||||
concurrent_tenant_size_logical_size_queries,
|
||||
virtual_file_io_engine,
|
||||
@@ -368,6 +368,7 @@ impl PageServerConf {
|
||||
max_vectored_read_bytes,
|
||||
image_compression,
|
||||
ephemeral_bytes_per_memory_kb,
|
||||
virtual_file_direct_io,
|
||||
io_buffer_alignment,
|
||||
|
||||
// ------------------------------------------------------------
|
||||
@@ -407,7 +408,6 @@ impl PageServerConf {
|
||||
l0_flush: l0_flush
|
||||
.map(crate::l0_flush::L0FlushConfig::from)
|
||||
.unwrap_or_default(),
|
||||
virtual_file_io_mode: virtual_file_io_mode.unwrap_or(virtual_file::IoMode::preferred()),
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------
|
||||
|
||||
@@ -17,7 +17,6 @@ use hyper::header;
|
||||
use hyper::StatusCode;
|
||||
use hyper::{Body, Request, Response, Uri};
|
||||
use metrics::launch_timestamp::LaunchTimestamp;
|
||||
use pageserver_api::models::virtual_file::IoMode;
|
||||
use pageserver_api::models::AuxFilePolicy;
|
||||
use pageserver_api::models::DownloadRemoteLayersTaskSpawnRequest;
|
||||
use pageserver_api::models::IngestAuxFilesRequest;
|
||||
@@ -825,7 +824,7 @@ async fn get_lsn_by_timestamp_handler(
|
||||
|
||||
let lease = if with_lease {
|
||||
timeline
|
||||
.init_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
|
||||
.make_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
|
||||
.inspect_err(|_| {
|
||||
warn!("fail to grant a lease to {}", lsn);
|
||||
})
|
||||
@@ -1693,18 +1692,9 @@ async fn lsn_lease_handler(
|
||||
let timeline =
|
||||
active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id)
|
||||
.await?;
|
||||
|
||||
let result = async {
|
||||
timeline
|
||||
.init_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
|
||||
.map_err(|e| {
|
||||
ApiError::InternalServerError(
|
||||
e.context(format!("invalid lsn lease request at {lsn}")),
|
||||
)
|
||||
})
|
||||
}
|
||||
.instrument(info_span!("init_lsn_lease", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %timeline_id))
|
||||
.await?;
|
||||
let result = timeline
|
||||
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
|
||||
.map_err(|e| ApiError::InternalServerError(e.context("lsn lease http handler")))?;
|
||||
|
||||
json_response(StatusCode::OK, result)
|
||||
}
|
||||
@@ -2382,16 +2372,6 @@ async fn put_io_alignment_handler(
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
async fn put_io_mode_handler(
|
||||
mut r: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
check_permission(&r, None)?;
|
||||
let mode: IoMode = json_request(&mut r).await?;
|
||||
crate::virtual_file::set_io_mode(mode);
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
/// Polled by control plane.
|
||||
///
|
||||
/// See [`crate::utilization`].
|
||||
@@ -3082,7 +3062,6 @@ pub fn make_router(
|
||||
.put("/v1/io_alignment", |r| {
|
||||
api_handler(r, put_io_alignment_handler)
|
||||
})
|
||||
.put("/v1/io_mode", |r| api_handler(r, put_io_mode_handler))
|
||||
.put(
|
||||
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/force_aux_policy_switch",
|
||||
|r| api_handler(r, force_aux_policy_switch_handler),
|
||||
|
||||
@@ -82,7 +82,6 @@ use once_cell::sync::OnceCell;
|
||||
use crate::{
|
||||
context::RequestContext,
|
||||
metrics::{page_cache_eviction_metrics, PageCacheSizeMetrics},
|
||||
virtual_file::{self, dio::IoBufferMut},
|
||||
};
|
||||
|
||||
static PAGE_CACHE: OnceCell<PageCache> = OnceCell::new();
|
||||
@@ -91,8 +90,8 @@ const TEST_PAGE_CACHE_SIZE: usize = 50;
|
||||
///
|
||||
/// Initialize the page cache. This must be called once at page server startup.
|
||||
///
|
||||
pub fn init(size: usize, align: usize) {
|
||||
if PAGE_CACHE.set(PageCache::new(size, align)).is_err() {
|
||||
pub fn init(size: usize) {
|
||||
if PAGE_CACHE.set(PageCache::new(size)).is_err() {
|
||||
panic!("page cache already initialized");
|
||||
}
|
||||
}
|
||||
@@ -107,12 +106,7 @@ pub fn get() -> &'static PageCache {
|
||||
// page cache is usable in unit tests.
|
||||
//
|
||||
if cfg!(test) {
|
||||
PAGE_CACHE.get_or_init(|| {
|
||||
PageCache::new(
|
||||
TEST_PAGE_CACHE_SIZE,
|
||||
virtual_file::get_io_buffer_alignment(),
|
||||
)
|
||||
})
|
||||
PAGE_CACHE.get_or_init(|| PageCache::new(TEST_PAGE_CACHE_SIZE))
|
||||
} else {
|
||||
PAGE_CACHE.get().expect("page cache not initialized")
|
||||
}
|
||||
@@ -643,11 +637,13 @@ impl PageCache {
|
||||
/// Initialize a new page cache
|
||||
///
|
||||
/// This should be called only once at page server startup.
|
||||
fn new(num_pages: usize, align: usize) -> Self {
|
||||
fn new(num_pages: usize) -> Self {
|
||||
assert!(num_pages > 0, "page cache size must be > 0");
|
||||
|
||||
let page_buffer =
|
||||
IoBufferMut::with_capacity_aligned_zeroed(num_pages * PAGE_SZ, align).leak();
|
||||
// We could use Vec::leak here, but that potentially also leaks
|
||||
// uninitialized reserved capacity. With into_boxed_slice and Box::leak
|
||||
// this is avoided.
|
||||
let page_buffer = Box::leak(vec![0u8; num_pages * PAGE_SZ].into_boxed_slice());
|
||||
|
||||
let size_metrics = &crate::metrics::PAGE_CACHE_SIZE;
|
||||
size_metrics.max_bytes.set_page_sz(num_pages);
|
||||
|
||||
@@ -273,20 +273,10 @@ async fn page_service_conn_main(
|
||||
info!("Postgres client disconnected ({io_error})");
|
||||
Ok(())
|
||||
} else {
|
||||
let tenant_id = conn_handler.timeline_handles.tenant_id();
|
||||
Err(io_error).context(format!(
|
||||
"Postgres connection error for tenant_id={:?} client at peer_addr={}",
|
||||
tenant_id, peer_addr
|
||||
))
|
||||
Err(io_error).context("Postgres connection error")
|
||||
}
|
||||
}
|
||||
other => {
|
||||
let tenant_id = conn_handler.timeline_handles.tenant_id();
|
||||
other.context(format!(
|
||||
"Postgres query error for tenant_id={:?} client peer_addr={}",
|
||||
tenant_id, peer_addr
|
||||
))
|
||||
}
|
||||
other => other.context("Postgres query error"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -350,10 +340,6 @@ impl TimelineHandles {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn tenant_id(&self) -> Option<TenantId> {
|
||||
self.wrapper.tenant_id.get().copied()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct TenantManagerWrapper {
|
||||
@@ -833,7 +819,7 @@ impl PageServerHandler {
|
||||
set_tracing_field_shard_id(&timeline);
|
||||
|
||||
let lease = timeline
|
||||
.renew_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
|
||||
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
|
||||
.inspect_err(|e| {
|
||||
warn!("{e}");
|
||||
})
|
||||
|
||||
@@ -21,7 +21,6 @@ use futures::stream::FuturesUnordered;
|
||||
use futures::StreamExt;
|
||||
use pageserver_api::models;
|
||||
use pageserver_api::models::AuxFilePolicy;
|
||||
use pageserver_api::models::LsnLease;
|
||||
use pageserver_api::models::TimelineArchivalState;
|
||||
use pageserver_api::models::TimelineState;
|
||||
use pageserver_api::models::TopTenantShardItem;
|
||||
@@ -183,54 +182,27 @@ pub struct TenantSharedResources {
|
||||
pub(super) struct AttachedTenantConf {
|
||||
tenant_conf: TenantConfOpt,
|
||||
location: AttachedLocationConfig,
|
||||
/// The deadline before which we are blocked from GC so that
|
||||
/// leases have a chance to be renewed.
|
||||
lsn_lease_deadline: Option<tokio::time::Instant>,
|
||||
}
|
||||
|
||||
impl AttachedTenantConf {
|
||||
fn new(tenant_conf: TenantConfOpt, location: AttachedLocationConfig) -> Self {
|
||||
// Sets a deadline before which we cannot proceed to GC due to lsn lease.
|
||||
//
|
||||
// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
|
||||
// length, we guarantee that all the leases we granted before will have a chance to renew
|
||||
// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
|
||||
let lsn_lease_deadline = if location.attach_mode == AttachmentMode::Single {
|
||||
Some(
|
||||
tokio::time::Instant::now()
|
||||
+ tenant_conf
|
||||
.lsn_lease_length
|
||||
.unwrap_or(LsnLease::DEFAULT_LENGTH),
|
||||
)
|
||||
} else {
|
||||
// We don't use `lsn_lease_deadline` to delay GC in AttachedMulti and AttachedStale
|
||||
// because we don't do GC in these modes.
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
tenant_conf,
|
||||
location,
|
||||
lsn_lease_deadline,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_from(location_conf: LocationConf) -> anyhow::Result<Self> {
|
||||
match &location_conf.mode {
|
||||
LocationMode::Attached(attach_conf) => {
|
||||
Ok(Self::new(location_conf.tenant_conf, *attach_conf))
|
||||
}
|
||||
LocationMode::Attached(attach_conf) => Ok(Self {
|
||||
tenant_conf: location_conf.tenant_conf,
|
||||
location: *attach_conf,
|
||||
}),
|
||||
LocationMode::Secondary(_) => {
|
||||
anyhow::bail!("Attempted to construct AttachedTenantConf from a LocationConf in secondary mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_gc_blocked_by_lsn_lease_deadline(&self) -> bool {
|
||||
self.lsn_lease_deadline
|
||||
.map(|d| tokio::time::Instant::now() < d)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
struct TimelinePreload {
|
||||
timeline_id: TimelineId,
|
||||
@@ -1850,11 +1822,6 @@ impl Tenant {
|
||||
info!("Skipping GC in location state {:?}", conf.location);
|
||||
return Ok(GcResult::default());
|
||||
}
|
||||
|
||||
if conf.is_gc_blocked_by_lsn_lease_deadline() {
|
||||
info!("Skipping GC because lsn lease deadline is not reached");
|
||||
return Ok(GcResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
let _guard = match self.gc_block.start().await {
|
||||
@@ -2663,8 +2630,6 @@ impl Tenant {
|
||||
Arc::new(AttachedTenantConf {
|
||||
tenant_conf: new_tenant_conf.clone(),
|
||||
location: inner.location,
|
||||
// Attached location is not changed, no need to update lsn lease deadline.
|
||||
lsn_lease_deadline: inner.lsn_lease_deadline,
|
||||
})
|
||||
});
|
||||
|
||||
@@ -3922,9 +3887,9 @@ async fn run_initdb(
|
||||
let _permit = INIT_DB_SEMAPHORE.acquire().await;
|
||||
|
||||
let initdb_command = tokio::process::Command::new(&initdb_bin_path)
|
||||
.args(["--pgdata", initdb_target_dir.as_ref()])
|
||||
.args(["--username", &conf.superuser])
|
||||
.args(["--encoding", "utf8"])
|
||||
.args(["-D", initdb_target_dir.as_ref()])
|
||||
.args(["-U", &conf.superuser])
|
||||
.args(["-E", "utf8"])
|
||||
.arg("--no-instructions")
|
||||
.arg("--no-sync")
|
||||
.env_clear()
|
||||
@@ -4496,17 +4461,13 @@ mod tests {
|
||||
tline.freeze_and_flush().await.map_err(|e| e.into())
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
#[tokio::test]
|
||||
async fn test_prohibit_branch_creation_on_garbage_collected_data() -> anyhow::Result<()> {
|
||||
let (tenant, ctx) =
|
||||
TenantHarness::create("test_prohibit_branch_creation_on_garbage_collected_data")
|
||||
.await?
|
||||
.load()
|
||||
.await;
|
||||
// Advance to the lsn lease deadline so that GC is not blocked by
|
||||
// initial transition into AttachedSingle.
|
||||
tokio::time::advance(tenant.get_lsn_lease_length()).await;
|
||||
tokio::time::resume();
|
||||
let tline = tenant
|
||||
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
|
||||
.await?;
|
||||
@@ -7283,17 +7244,9 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
#[tokio::test]
|
||||
async fn test_lsn_lease() -> anyhow::Result<()> {
|
||||
let (tenant, ctx) = TenantHarness::create("test_lsn_lease")
|
||||
.await
|
||||
.unwrap()
|
||||
.load()
|
||||
.await;
|
||||
// Advance to the lsn lease deadline so that GC is not blocked by
|
||||
// initial transition into AttachedSingle.
|
||||
tokio::time::advance(tenant.get_lsn_lease_length()).await;
|
||||
tokio::time::resume();
|
||||
let (tenant, ctx) = TenantHarness::create("test_lsn_lease").await?.load().await;
|
||||
let key = Key::from_hex("010000000033333333444444445500000000").unwrap();
|
||||
|
||||
let end_lsn = Lsn(0x100);
|
||||
@@ -7321,33 +7274,24 @@ mod tests {
|
||||
|
||||
let leased_lsns = [0x30, 0x50, 0x70];
|
||||
let mut leases = Vec::new();
|
||||
leased_lsns.iter().for_each(|n| {
|
||||
leases.push(
|
||||
timeline
|
||||
.init_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)
|
||||
.expect("lease request should succeed"),
|
||||
);
|
||||
let _: anyhow::Result<_> = leased_lsns.iter().try_for_each(|n| {
|
||||
leases.push(timeline.make_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)?);
|
||||
Ok(())
|
||||
});
|
||||
|
||||
let updated_lease_0 = timeline
|
||||
.renew_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)
|
||||
.expect("lease renewal should succeed");
|
||||
assert_eq!(
|
||||
updated_lease_0.valid_until, leases[0].valid_until,
|
||||
" Renewing with shorter lease should not change the lease."
|
||||
);
|
||||
// Renewing with shorter lease should not change the lease.
|
||||
let updated_lease_0 =
|
||||
timeline.make_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)?;
|
||||
assert_eq!(updated_lease_0.valid_until, leases[0].valid_until);
|
||||
|
||||
let updated_lease_1 = timeline
|
||||
.renew_lsn_lease(
|
||||
Lsn(leased_lsns[1]),
|
||||
timeline.get_lsn_lease_length() * 2,
|
||||
&ctx,
|
||||
)
|
||||
.expect("lease renewal should succeed");
|
||||
assert!(
|
||||
updated_lease_1.valid_until > leases[1].valid_until,
|
||||
"Renewing with a long lease should renew lease with later expiration time."
|
||||
);
|
||||
// Renewing with a long lease should renew lease with later expiration time.
|
||||
let updated_lease_1 = timeline.make_lsn_lease(
|
||||
Lsn(leased_lsns[1]),
|
||||
timeline.get_lsn_lease_length() * 2,
|
||||
&ctx,
|
||||
)?;
|
||||
|
||||
assert!(updated_lease_1.valid_until > leases[1].valid_until);
|
||||
|
||||
// Force set disk consistent lsn so we can get the cutoff at `end_lsn`.
|
||||
info!(
|
||||
@@ -7364,8 +7308,7 @@ mod tests {
|
||||
&CancellationToken::new(),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
|
||||
// Keeping everything <= Lsn(0x80) b/c leases:
|
||||
// 0/10: initdb layer
|
||||
@@ -7379,16 +7322,13 @@ mod tests {
|
||||
// Make lease on a already GC-ed LSN.
|
||||
// 0/80 does not have a valid lease + is below latest_gc_cutoff
|
||||
assert!(Lsn(0x80) < *timeline.get_latest_gc_cutoff_lsn());
|
||||
timeline
|
||||
.init_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx)
|
||||
.expect_err("lease request on GC-ed LSN should fail");
|
||||
let res = timeline.make_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx);
|
||||
assert!(res.is_err());
|
||||
|
||||
// Should still be able to renew a currently valid lease
|
||||
// Assumption: original lease to is still valid for 0/50.
|
||||
// (use `Timeline::init_lsn_lease` for testing so it always does validation)
|
||||
timeline
|
||||
.init_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)
|
||||
.expect("lease renewal with validation should succeed");
|
||||
let _ =
|
||||
timeline.make_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
use super::storage_layer::delta_layer::{Adapter, DeltaLayerInner};
|
||||
use crate::context::RequestContext;
|
||||
use crate::page_cache::{self, FileId, PageReadGuard, PageWriteGuard, ReadBufResult, PAGE_SZ};
|
||||
#[cfg(test)]
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use crate::virtual_file::VirtualFile;
|
||||
use bytes::Bytes;
|
||||
use std::ops::Deref;
|
||||
@@ -42,7 +40,7 @@ pub enum BlockLease<'a> {
|
||||
#[cfg(test)]
|
||||
Arc(std::sync::Arc<[u8; PAGE_SZ]>),
|
||||
#[cfg(test)]
|
||||
IoBufferMut(IoBufferMut),
|
||||
Vec(Vec<u8>),
|
||||
}
|
||||
|
||||
impl From<PageReadGuard<'static>> for BlockLease<'static> {
|
||||
@@ -69,7 +67,7 @@ impl<'a> Deref for BlockLease<'a> {
|
||||
#[cfg(test)]
|
||||
BlockLease::Arc(v) => v.deref(),
|
||||
#[cfg(test)]
|
||||
BlockLease::IoBufferMut(v) => {
|
||||
BlockLease::Vec(v) => {
|
||||
TryFrom::try_from(&v[..]).expect("caller must ensure that v has PAGE_SZ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,8 +6,6 @@ use crate::config::PageServerConf;
|
||||
use crate::context::RequestContext;
|
||||
use crate::page_cache;
|
||||
use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File;
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
|
||||
use crate::virtual_file::owned_buffers_io::slice::SliceMutExt;
|
||||
use crate::virtual_file::owned_buffers_io::util::size_tracking_writer;
|
||||
use crate::virtual_file::owned_buffers_io::write::Buffer;
|
||||
@@ -86,7 +84,7 @@ impl Drop for EphemeralFile {
|
||||
fn drop(&mut self) {
|
||||
// unlink the file
|
||||
// we are clear to do this, because we have entered a gate
|
||||
let path = self.buffered_writer.as_inner().as_inner().path();
|
||||
let path = &self.buffered_writer.as_inner().as_inner().path;
|
||||
let res = std::fs::remove_file(path);
|
||||
if let Err(e) = res {
|
||||
if e.kind() != std::io::ErrorKind::NotFound {
|
||||
@@ -109,16 +107,15 @@ impl EphemeralFile {
|
||||
self.page_cache_file_id
|
||||
}
|
||||
|
||||
pub(crate) async fn load_to_buf(&self, ctx: &RequestContext) -> Result<IoBufferMut, io::Error> {
|
||||
pub(crate) async fn load_to_vec(&self, ctx: &RequestContext) -> Result<Vec<u8>, io::Error> {
|
||||
let size = self.len().into_usize();
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let buf = IoBufferMut::with_capacity_aligned(size, align);
|
||||
let (slice, nread) = self.read_exact_at_eof_ok(0, buf.slice_full(), ctx).await?;
|
||||
let vec = Vec::with_capacity(size);
|
||||
let (slice, nread) = self.read_exact_at_eof_ok(0, vec.slice_full(), ctx).await?;
|
||||
assert_eq!(nread, size);
|
||||
let buf = slice.into_inner();
|
||||
assert_eq!(buf.len(), nread);
|
||||
assert_eq!(buf.capacity(), size, "we shouldn't be reallocating");
|
||||
Ok(buf)
|
||||
let vec = slice.into_inner();
|
||||
assert_eq!(vec.len(), nread);
|
||||
assert_eq!(vec.capacity(), size, "we shouldn't be reallocating");
|
||||
Ok(vec)
|
||||
}
|
||||
|
||||
/// Returns the offset at which the first byte of the input was written, for use
|
||||
@@ -161,7 +158,7 @@ impl EphemeralFile {
|
||||
}
|
||||
|
||||
impl super::storage_layer::inmemory_layer::vectored_dio_read::File for EphemeralFile {
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: tokio_epoll_uring::IoBufMut + Send>(
|
||||
&'b self,
|
||||
start: u64,
|
||||
dst: tokio_epoll_uring::Slice<B>,
|
||||
@@ -346,10 +343,9 @@ mod tests {
|
||||
}
|
||||
|
||||
assert!(file.len() as usize == write_nbytes);
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
for i in 0..write_nbytes {
|
||||
assert_eq!(value_offsets[i], i.into_u64());
|
||||
let buf = IoBufferMut::with_capacity_aligned(1, align);
|
||||
let buf = Vec::with_capacity(1);
|
||||
let (buf_slice, nread) = file
|
||||
.read_exact_at_eof_ok(i.into_u64(), buf.slice_full(), &ctx)
|
||||
.await
|
||||
@@ -360,7 +356,7 @@ mod tests {
|
||||
}
|
||||
|
||||
let file_contents =
|
||||
std::fs::read(file.buffered_writer.as_inner().as_inner().path()).unwrap();
|
||||
std::fs::read(&file.buffered_writer.as_inner().as_inner().path).unwrap();
|
||||
assert_eq!(file_contents, &content[0..cap]);
|
||||
|
||||
let buffer_contents = file.buffered_writer.inspect_buffer();
|
||||
@@ -389,14 +385,14 @@ mod tests {
|
||||
|
||||
// assert the state is as this test expects it to be
|
||||
assert_eq!(
|
||||
&file.load_to_buf(&ctx).await.unwrap()[..],
|
||||
&file.load_to_vec(&ctx).await.unwrap(),
|
||||
&content[0..cap + cap / 2]
|
||||
);
|
||||
let md = file
|
||||
.buffered_writer
|
||||
.as_inner()
|
||||
.as_inner()
|
||||
.path()
|
||||
.path
|
||||
.metadata()
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
@@ -444,17 +440,13 @@ mod tests {
|
||||
let (buf, nread) = file
|
||||
.read_exact_at_eof_ok(
|
||||
start.into_u64(),
|
||||
IoBufferMut::with_capacity_aligned(
|
||||
len,
|
||||
virtual_file::get_io_buffer_alignment(),
|
||||
)
|
||||
.slice_full(),
|
||||
Vec::with_capacity(len).slice_full(),
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(nread, len);
|
||||
assert_eq!(&buf.into_inner()[..], &content[start..(start + len)]);
|
||||
assert_eq!(&buf.into_inner(), &content[start..(start + len)]);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,12 +1,29 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use utils::id::TimelineId;
|
||||
use std::{collections::HashMap, time::Duration};
|
||||
|
||||
use super::remote_timeline_client::index::GcBlockingReason;
|
||||
use tokio::time::Instant;
|
||||
use utils::id::TimelineId;
|
||||
|
||||
type Storage = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
|
||||
type TimelinesBlocked = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
|
||||
|
||||
/// GcBlock provides persistent (per-timeline) gc blocking.
|
||||
#[derive(Default)]
|
||||
struct Storage {
|
||||
timelines_blocked: TimelinesBlocked,
|
||||
/// The deadline before which we are blocked from GC so that
|
||||
/// leases have a chance to be renewed.
|
||||
lsn_lease_deadline: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
fn is_blocked_by_lsn_lease_deadline(&self) -> bool {
|
||||
self.lsn_lease_deadline
|
||||
.map(|d| Instant::now() < d)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// GcBlock provides persistent (per-timeline) gc blocking and facilitates transient time based gc
|
||||
/// blocking.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct GcBlock {
|
||||
/// The timelines which have current reasons to block gc.
|
||||
@@ -49,6 +66,17 @@ impl GcBlock {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets a deadline before which we cannot proceed to GC due to lsn lease.
|
||||
///
|
||||
/// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
|
||||
/// length, we guarantee that all the leases we granted before will have a chance to renew
|
||||
/// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
|
||||
pub(super) fn set_lsn_lease_deadline(&self, lsn_lease_length: Duration) {
|
||||
let deadline = Instant::now() + lsn_lease_length;
|
||||
let mut g = self.reasons.lock().unwrap();
|
||||
g.lsn_lease_deadline = Some(deadline);
|
||||
}
|
||||
|
||||
/// Describe the current gc blocking reasons.
|
||||
///
|
||||
/// TODO: make this json serializable.
|
||||
@@ -74,7 +102,7 @@ impl GcBlock {
|
||||
) -> anyhow::Result<bool> {
|
||||
let (added, uploaded) = {
|
||||
let mut g = self.reasons.lock().unwrap();
|
||||
let set = g.entry(timeline.timeline_id).or_default();
|
||||
let set = g.timelines_blocked.entry(timeline.timeline_id).or_default();
|
||||
let added = set.insert(reason);
|
||||
|
||||
// LOCK ORDER: intentionally hold the lock, see self.reasons.
|
||||
@@ -105,7 +133,7 @@ impl GcBlock {
|
||||
|
||||
let (remaining_blocks, uploaded) = {
|
||||
let mut g = self.reasons.lock().unwrap();
|
||||
match g.entry(timeline.timeline_id) {
|
||||
match g.timelines_blocked.entry(timeline.timeline_id) {
|
||||
Entry::Occupied(mut oe) => {
|
||||
let set = oe.get_mut();
|
||||
set.remove(reason);
|
||||
@@ -119,7 +147,7 @@ impl GcBlock {
|
||||
}
|
||||
}
|
||||
|
||||
let remaining_blocks = g.len();
|
||||
let remaining_blocks = g.timelines_blocked.len();
|
||||
|
||||
// LOCK ORDER: intentionally hold the lock while scheduling; see self.reasons
|
||||
let uploaded = timeline
|
||||
@@ -144,11 +172,11 @@ impl GcBlock {
|
||||
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
|
||||
let unblocked = {
|
||||
let mut g = self.reasons.lock().unwrap();
|
||||
if g.is_empty() {
|
||||
if g.timelines_blocked.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
g.remove(&timeline.timeline_id);
|
||||
g.timelines_blocked.remove(&timeline.timeline_id);
|
||||
|
||||
BlockingReasons::clean_and_summarize(g).is_none()
|
||||
};
|
||||
@@ -159,10 +187,11 @@ impl GcBlock {
|
||||
}
|
||||
|
||||
/// Initialize with the non-deleted timelines of this tenant.
|
||||
pub(crate) fn set_scanned(&self, scanned: Storage) {
|
||||
pub(crate) fn set_scanned(&self, scanned: TimelinesBlocked) {
|
||||
let mut g = self.reasons.lock().unwrap();
|
||||
assert!(g.is_empty());
|
||||
g.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
|
||||
assert!(g.timelines_blocked.is_empty());
|
||||
g.timelines_blocked
|
||||
.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
|
||||
|
||||
if let Some(reasons) = BlockingReasons::clean_and_summarize(g) {
|
||||
tracing::info!(summary=?reasons, "initialized with gc blocked");
|
||||
@@ -176,6 +205,7 @@ pub(super) struct Guard<'a> {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct BlockingReasons {
|
||||
tenant_blocked_by_lsn_lease_deadline: bool,
|
||||
timelines: usize,
|
||||
reasons: enumset::EnumSet<GcBlockingReason>,
|
||||
}
|
||||
@@ -184,8 +214,8 @@ impl std::fmt::Display for BlockingReasons {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{} timelines block for {:?}",
|
||||
self.timelines, self.reasons
|
||||
"tenant_blocked_by_lsn_lease_deadline: {}, {} timelines block for {:?}",
|
||||
self.tenant_blocked_by_lsn_lease_deadline, self.timelines, self.reasons
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -193,13 +223,15 @@ impl std::fmt::Display for BlockingReasons {
|
||||
impl BlockingReasons {
|
||||
fn clean_and_summarize(mut g: std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
|
||||
let mut reasons = enumset::EnumSet::empty();
|
||||
g.retain(|_key, value| {
|
||||
g.timelines_blocked.retain(|_key, value| {
|
||||
reasons = reasons.union(*value);
|
||||
!value.is_empty()
|
||||
});
|
||||
if !g.is_empty() {
|
||||
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
|
||||
if !g.timelines_blocked.is_empty() || blocked_by_lsn_lease_deadline {
|
||||
Some(BlockingReasons {
|
||||
timelines: g.len(),
|
||||
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
|
||||
timelines: g.timelines_blocked.len(),
|
||||
reasons,
|
||||
})
|
||||
} else {
|
||||
@@ -208,14 +240,17 @@ impl BlockingReasons {
|
||||
}
|
||||
|
||||
fn summarize(g: &std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
|
||||
if g.is_empty() {
|
||||
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
|
||||
if g.timelines_blocked.is_empty() && !blocked_by_lsn_lease_deadline {
|
||||
None
|
||||
} else {
|
||||
let reasons = g
|
||||
.timelines_blocked
|
||||
.values()
|
||||
.fold(enumset::EnumSet::empty(), |acc, next| acc.union(*next));
|
||||
Some(BlockingReasons {
|
||||
timelines: g.len(),
|
||||
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
|
||||
timelines: g.timelines_blocked.len(),
|
||||
reasons,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -219,11 +219,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
|
||||
+ TEMP_FILE_SUFFIX;
|
||||
let tmp_path = path_with_suffix_extension(&path, &rand_suffix);
|
||||
fs::rename(path.as_ref(), &tmp_path).await?;
|
||||
fs::File::open(parent)
|
||||
.await?
|
||||
.sync_all()
|
||||
.await
|
||||
.maybe_fatal_err("safe_rename_tenant_dir")?;
|
||||
fs::File::open(parent).await?.sync_all().await?;
|
||||
Ok(tmp_path)
|
||||
}
|
||||
|
||||
@@ -953,6 +949,12 @@ impl TenantManager {
|
||||
(LocationMode::Attached(attach_conf), Some(TenantSlot::Attached(tenant))) => {
|
||||
match attach_conf.generation.cmp(&tenant.generation) {
|
||||
Ordering::Equal => {
|
||||
if attach_conf.attach_mode == AttachmentMode::Single {
|
||||
tenant
|
||||
.gc_block
|
||||
.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
|
||||
}
|
||||
|
||||
// A transition from Attached to Attached in the same generation, we may
|
||||
// take our fast path and just provide the updated configuration
|
||||
// to the tenant.
|
||||
|
||||
@@ -178,7 +178,6 @@ async fn download_object<'a>(
|
||||
destination_file
|
||||
.flush()
|
||||
.await
|
||||
.maybe_fatal_err("download_object sync_all")
|
||||
.with_context(|| format!("flush source file at {dst_path}"))
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
@@ -186,7 +185,6 @@ async fn download_object<'a>(
|
||||
destination_file
|
||||
.sync_all()
|
||||
.await
|
||||
.maybe_fatal_err("download_object sync_all")
|
||||
.with_context(|| format!("failed to fsync source file at {dst_path}"))
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
@@ -234,7 +232,6 @@ async fn download_object<'a>(
|
||||
destination_file
|
||||
.sync_all()
|
||||
.await
|
||||
.maybe_fatal_err("download_object sync_all")
|
||||
.with_context(|| format!("failed to fsync source file at {dst_path}"))
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
|
||||
@@ -40,15 +40,15 @@ use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
|
||||
use crate::tenant::timeline::GetVectoredError;
|
||||
use crate::tenant::vectored_blob_io::{
|
||||
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
|
||||
VectoredReadPlanner,
|
||||
VectoredReadCoalesceMode, VectoredReadPlanner,
|
||||
};
|
||||
use crate::tenant::PageReconstructError;
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
|
||||
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
|
||||
use crate::virtual_file::{self, VirtualFile};
|
||||
use crate::{walrecord, TEMP_FILE_SUFFIX};
|
||||
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use futures::StreamExt;
|
||||
use itertools::Itertools;
|
||||
@@ -572,7 +572,7 @@ impl DeltaLayerWriterInner {
|
||||
ensure!(
|
||||
metadata.len() <= S3_UPLOAD_LIMIT,
|
||||
"Created delta layer file at {} of size {} above limit {S3_UPLOAD_LIMIT}!",
|
||||
file.path(),
|
||||
file.path,
|
||||
metadata.len()
|
||||
);
|
||||
|
||||
@@ -589,9 +589,7 @@ impl DeltaLayerWriterInner {
|
||||
);
|
||||
|
||||
// fsync the file
|
||||
file.sync_all()
|
||||
.await
|
||||
.maybe_fatal_err("delta_layer sync_all")?;
|
||||
file.sync_all().await?;
|
||||
|
||||
trace!("created delta layer {}", self.path);
|
||||
|
||||
@@ -790,7 +788,7 @@ impl DeltaLayerInner {
|
||||
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let file = VirtualFile::open_v2(path, ctx)
|
||||
let file = VirtualFile::open(path, ctx)
|
||||
.await
|
||||
.context("open layer file")?;
|
||||
|
||||
@@ -991,8 +989,7 @@ impl DeltaLayerInner {
|
||||
.0
|
||||
.into();
|
||||
let buf_size = Self::get_min_read_buffer_size(&reads, max_vectored_read_bytes);
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mut buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
|
||||
let mut buf = Some(BytesMut::with_capacity(buf_size));
|
||||
|
||||
// Note that reads are processed in reverse order (from highest key+lsn).
|
||||
// This is the order that `ReconstructState` requires such that it can
|
||||
@@ -1011,7 +1008,7 @@ impl DeltaLayerInner {
|
||||
blob_meta.key,
|
||||
PageReconstructError::Other(anyhow!(
|
||||
"Failed to read blobs from virtual file {}: {}",
|
||||
self.file.path(),
|
||||
self.file.path,
|
||||
kind
|
||||
)),
|
||||
);
|
||||
@@ -1019,7 +1016,7 @@ impl DeltaLayerInner {
|
||||
|
||||
// We have "lost" the buffer since the lower level IO api
|
||||
// doesn't return the buffer on error. Allocate a new one.
|
||||
buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
|
||||
buf = Some(BytesMut::with_capacity(buf_size));
|
||||
|
||||
continue;
|
||||
}
|
||||
@@ -1037,7 +1034,7 @@ impl DeltaLayerInner {
|
||||
meta.meta.key,
|
||||
PageReconstructError::Other(anyhow!(e).context(format!(
|
||||
"Failed to decompress blob from virtual file {}",
|
||||
self.file.path(),
|
||||
self.file.path,
|
||||
))),
|
||||
);
|
||||
|
||||
@@ -1055,7 +1052,7 @@ impl DeltaLayerInner {
|
||||
meta.meta.key,
|
||||
PageReconstructError::Other(anyhow!(e).context(format!(
|
||||
"Failed to deserialize blob from virtual file {}",
|
||||
self.file.path(),
|
||||
self.file.path,
|
||||
))),
|
||||
);
|
||||
|
||||
@@ -1136,7 +1133,7 @@ impl DeltaLayerInner {
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<usize> {
|
||||
use crate::tenant::vectored_blob_io::{
|
||||
BlobMeta, ChunkedVectoredReadBuilder, VectoredReadExtended,
|
||||
BlobMeta, VectoredReadBuilder, VectoredReadExtended,
|
||||
};
|
||||
use futures::stream::TryStreamExt;
|
||||
|
||||
@@ -1186,15 +1183,15 @@ impl DeltaLayerInner {
|
||||
|
||||
let mut prev: Option<(Key, Lsn, BlobRef)> = None;
|
||||
|
||||
let mut read_builder: Option<ChunkedVectoredReadBuilder> = None;
|
||||
let mut read_builder: Option<VectoredReadBuilder> = None;
|
||||
let read_mode = VectoredReadCoalesceMode::get();
|
||||
|
||||
let max_read_size = self
|
||||
.max_vectored_read_bytes
|
||||
.map(|x| x.0.get())
|
||||
.unwrap_or(8192);
|
||||
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mut buffer = Some(IoBufferMut::with_capacity_aligned(max_read_size, align));
|
||||
let mut buffer = Some(BytesMut::with_capacity(max_read_size));
|
||||
|
||||
// FIXME: buffering of DeltaLayerWriter
|
||||
let mut per_blob_copy = Vec::new();
|
||||
@@ -1231,12 +1228,12 @@ impl DeltaLayerInner {
|
||||
{
|
||||
None
|
||||
} else {
|
||||
read_builder.replace(ChunkedVectoredReadBuilder::new(
|
||||
read_builder.replace(VectoredReadBuilder::new(
|
||||
offsets.start.pos(),
|
||||
offsets.end.pos(),
|
||||
meta,
|
||||
max_read_size,
|
||||
align,
|
||||
read_mode,
|
||||
))
|
||||
}
|
||||
} else {
|
||||
@@ -1553,12 +1550,12 @@ impl<'a> DeltaLayerIterator<'a> {
|
||||
let vectored_blob_reader = VectoredBlobReader::new(&self.delta_layer.file);
|
||||
let mut next_batch = std::collections::VecDeque::new();
|
||||
let buf_size = plan.size();
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
|
||||
let buf = BytesMut::with_capacity(buf_size);
|
||||
let blobs_buf = vectored_blob_reader
|
||||
.read_blobs(&plan, buf, self.ctx)
|
||||
.await?;
|
||||
let view = BufView::new_slice(&blobs_buf.buf);
|
||||
let frozen_buf = blobs_buf.buf.freeze();
|
||||
let view = BufView::new_bytes(frozen_buf);
|
||||
for meta in blobs_buf.blobs.iter() {
|
||||
let blob_read = meta.read(&view).await?;
|
||||
let value = Value::des(&blob_read)?;
|
||||
@@ -1933,9 +1930,7 @@ pub(crate) mod test {
|
||||
&vectored_reads,
|
||||
constants::MAX_VECTORED_READ_BYTES,
|
||||
);
|
||||
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mut buf = Some(IoBufferMut::with_capacity_aligned(buf_size, align));
|
||||
let mut buf = Some(BytesMut::with_capacity(buf_size));
|
||||
|
||||
for read in vectored_reads {
|
||||
let blobs_buf = vectored_blob_reader
|
||||
|
||||
@@ -40,12 +40,11 @@ use crate::tenant::vectored_blob_io::{
|
||||
VectoredReadPlanner,
|
||||
};
|
||||
use crate::tenant::PageReconstructError;
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
|
||||
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
|
||||
use crate::virtual_file::{self, VirtualFile};
|
||||
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use bytes::Bytes;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use hex;
|
||||
use itertools::Itertools;
|
||||
@@ -389,7 +388,7 @@ impl ImageLayerInner {
|
||||
max_vectored_read_bytes: Option<MaxVectoredReadBytes>,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Self> {
|
||||
let file = VirtualFile::open_v2(path, ctx)
|
||||
let file = VirtualFile::open(path, ctx)
|
||||
.await
|
||||
.context("open layer file")?;
|
||||
let file_id = page_cache::next_file_id();
|
||||
@@ -543,15 +542,14 @@ impl ImageLayerInner {
|
||||
.await?;
|
||||
|
||||
let vectored_blob_reader = VectoredBlobReader::new(&self.file);
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mut key_count = 0;
|
||||
for read in plan.into_iter() {
|
||||
let buf_size = read.size();
|
||||
|
||||
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
|
||||
let buf = BytesMut::with_capacity(buf_size);
|
||||
let blobs_buf = vectored_blob_reader.read_blobs(&read, buf, ctx).await?;
|
||||
|
||||
let view = BufView::new_slice(&blobs_buf.buf);
|
||||
let frozen_buf = blobs_buf.buf.freeze();
|
||||
let view = BufView::new_bytes(frozen_buf);
|
||||
|
||||
for meta in blobs_buf.blobs.iter() {
|
||||
let img_buf = meta.read(&view).await?;
|
||||
@@ -599,13 +597,13 @@ impl ImageLayerInner {
|
||||
);
|
||||
}
|
||||
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
|
||||
let buf = BytesMut::with_capacity(buf_size);
|
||||
let res = vectored_blob_reader.read_blobs(&read, buf, ctx).await;
|
||||
|
||||
match res {
|
||||
Ok(blobs_buf) => {
|
||||
let view = BufView::new_slice(&blobs_buf.buf);
|
||||
let frozen_buf = blobs_buf.buf.freeze();
|
||||
let view = BufView::new_bytes(frozen_buf);
|
||||
for meta in blobs_buf.blobs.iter() {
|
||||
let img_buf = meta.read(&view).await;
|
||||
|
||||
@@ -616,7 +614,7 @@ impl ImageLayerInner {
|
||||
meta.meta.key,
|
||||
PageReconstructError::Other(anyhow!(e).context(format!(
|
||||
"Failed to decompress blob from virtual file {}",
|
||||
self.file.path(),
|
||||
self.file.path,
|
||||
))),
|
||||
);
|
||||
|
||||
@@ -637,7 +635,7 @@ impl ImageLayerInner {
|
||||
blob_meta.key,
|
||||
PageReconstructError::from(anyhow!(
|
||||
"Failed to read blobs from virtual file {}: {}",
|
||||
self.file.path(),
|
||||
self.file.path,
|
||||
kind
|
||||
)),
|
||||
);
|
||||
@@ -891,9 +889,7 @@ impl ImageLayerWriterInner {
|
||||
// set inner.file here. The first read will have to re-open it.
|
||||
|
||||
// fsync the file
|
||||
file.sync_all()
|
||||
.await
|
||||
.maybe_fatal_err("image_layer sync_all")?;
|
||||
file.sync_all().await?;
|
||||
|
||||
trace!("created image layer {}", self.path);
|
||||
|
||||
@@ -1041,12 +1037,12 @@ impl<'a> ImageLayerIterator<'a> {
|
||||
let vectored_blob_reader = VectoredBlobReader::new(&self.image_layer.file);
|
||||
let mut next_batch = std::collections::VecDeque::new();
|
||||
let buf_size = plan.size();
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let buf = IoBufferMut::with_capacity_aligned(buf_size, align);
|
||||
let buf = BytesMut::with_capacity(buf_size);
|
||||
let blobs_buf = vectored_blob_reader
|
||||
.read_blobs(&plan, buf, self.ctx)
|
||||
.await?;
|
||||
let view = BufView::new_slice(&blobs_buf.buf);
|
||||
let frozen_buf = blobs_buf.buf.freeze();
|
||||
let view = BufView::new_bytes(frozen_buf);
|
||||
for meta in blobs_buf.blobs.iter() {
|
||||
let img_buf = meta.read(&view).await?;
|
||||
next_batch.push_back((
|
||||
|
||||
@@ -809,9 +809,9 @@ impl InMemoryLayer {
|
||||
|
||||
match l0_flush_global_state {
|
||||
l0_flush::Inner::Direct { .. } => {
|
||||
let file_contents = inner.file.load_to_buf(ctx).await?;
|
||||
let file_contents: Vec<u8> = inner.file.load_to_vec(ctx).await?;
|
||||
|
||||
let file_contents = Bytes::copy_from_slice(&file_contents[..]);
|
||||
let file_contents = Bytes::from(file_contents);
|
||||
|
||||
for (key, vec_map) in inner.index.iter() {
|
||||
// Write all page versions
|
||||
|
||||
@@ -9,7 +9,6 @@ use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice};
|
||||
use crate::{
|
||||
assert_u64_eq_usize::{U64IsUsize, UsizeIsU64},
|
||||
context::RequestContext,
|
||||
virtual_file::{self, dio::IoBufferMut, owned_buffers_io::io_buf_aligned::IoBufAlignedMut},
|
||||
};
|
||||
|
||||
/// The file interface we require. At runtime, this is a [`crate::tenant::ephemeral_file::EphemeralFile`].
|
||||
@@ -25,7 +24,7 @@ pub trait File: Send {
|
||||
/// [`std::io::ErrorKind::UnexpectedEof`] error if the file is shorter than `start+dst.len()`.
|
||||
///
|
||||
/// No guarantees are made about the remaining bytes in `dst` in case of a short read.
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>(
|
||||
&'b self,
|
||||
start: u64,
|
||||
dst: Slice<B>,
|
||||
@@ -228,12 +227,7 @@ where
|
||||
|
||||
// Execute physical reads and fill the logical read buffers
|
||||
// TODO: pipelined reads; prefetch;
|
||||
let get_io_buffer = |nchunks| {
|
||||
IoBufferMut::with_capacity_aligned(
|
||||
nchunks * DIO_CHUNK_SIZE,
|
||||
virtual_file::get_io_buffer_alignment(),
|
||||
)
|
||||
};
|
||||
let get_io_buffer = |nchunks| Vec::with_capacity(nchunks * DIO_CHUNK_SIZE);
|
||||
for PhysicalRead {
|
||||
start_chunk_no,
|
||||
nchunks,
|
||||
@@ -465,10 +459,7 @@ mod tests {
|
||||
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
|
||||
let file = InMemoryFile::new_random(10);
|
||||
let test_read = |pos, len| {
|
||||
let buf = IoBufferMut::with_capacity_aligned_zeroed(
|
||||
len,
|
||||
virtual_file::get_io_buffer_alignment(),
|
||||
);
|
||||
let buf = vec![0; len];
|
||||
let fut = file.read_exact_at_eof_ok(pos, buf.slice_full(), &ctx);
|
||||
use futures::FutureExt;
|
||||
let (slice, nread) = fut
|
||||
@@ -479,9 +470,9 @@ mod tests {
|
||||
buf.truncate(nread);
|
||||
buf
|
||||
};
|
||||
assert_eq!(&test_read(0, 1), &file.content[0..1]);
|
||||
assert_eq!(&test_read(1, 2), &file.content[1..3]);
|
||||
assert_eq!(&test_read(9, 2), &file.content[9..]);
|
||||
assert_eq!(test_read(0, 1), &file.content[0..1]);
|
||||
assert_eq!(test_read(1, 2), &file.content[1..3]);
|
||||
assert_eq!(test_read(9, 2), &file.content[9..]);
|
||||
assert!(test_read(10, 2).is_empty());
|
||||
assert!(test_read(11, 2).is_empty());
|
||||
}
|
||||
@@ -618,7 +609,7 @@ mod tests {
|
||||
}
|
||||
|
||||
impl<'x> File for RecorderFile<'x> {
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufAlignedMut + Send>(
|
||||
async fn read_exact_at_eof_ok<'a, 'b, B: IoBufMut + Send>(
|
||||
&'b self,
|
||||
start: u64,
|
||||
dst: Slice<B>,
|
||||
@@ -791,7 +782,7 @@ mod tests {
|
||||
2048, 1024 => Err("foo".to_owned()),
|
||||
};
|
||||
|
||||
let buf = IoBufferMut::with_capacity_aligned(512, 512);
|
||||
let buf = Vec::with_capacity(512);
|
||||
let (buf, nread) = mock_file
|
||||
.read_exact_at_eof_ok(0, buf.slice_full(), &ctx)
|
||||
.await
|
||||
@@ -799,7 +790,7 @@ mod tests {
|
||||
assert_eq!(nread, 512);
|
||||
assert_eq!(&buf.into_inner()[..nread], &[0; 512]);
|
||||
|
||||
let buf = IoBufferMut::with_capacity_aligned(512, 512);
|
||||
let buf = Vec::with_capacity(512);
|
||||
let (buf, nread) = mock_file
|
||||
.read_exact_at_eof_ok(512, buf.slice_full(), &ctx)
|
||||
.await
|
||||
@@ -807,7 +798,7 @@ mod tests {
|
||||
assert_eq!(nread, 512);
|
||||
assert_eq!(&buf.into_inner()[..nread], &[1; 512]);
|
||||
|
||||
let buf = IoBufferMut::with_capacity_aligned(512, 512);
|
||||
let buf = Vec::with_capacity(512);
|
||||
let (buf, nread) = mock_file
|
||||
.read_exact_at_eof_ok(1024, buf.slice_full(), &ctx)
|
||||
.await
|
||||
@@ -815,7 +806,7 @@ mod tests {
|
||||
assert_eq!(nread, 10);
|
||||
assert_eq!(&buf.into_inner()[..nread], &[2; 10]);
|
||||
|
||||
let buf = IoBufferMut::with_capacity_aligned(1024, 512);
|
||||
let buf = Vec::with_capacity(1024);
|
||||
let err = mock_file
|
||||
.read_exact_at_eof_ok(2048, buf.slice_full(), &ctx)
|
||||
.await
|
||||
|
||||
@@ -330,6 +330,7 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
|
||||
RequestContext::todo_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
|
||||
|
||||
let mut first = true;
|
||||
tenant.gc_block.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = cancel.cancelled() => {
|
||||
|
||||
@@ -66,7 +66,6 @@ use std::{
|
||||
use crate::{
|
||||
aux_file::AuxFileSizeEstimator,
|
||||
tenant::{
|
||||
config::AttachmentMode,
|
||||
layer_map::{LayerMap, SearchResult},
|
||||
metadata::TimelineMetadata,
|
||||
storage_layer::{inmemory_layer::IndexEntry, PersistentLayerDesc},
|
||||
@@ -1325,38 +1324,16 @@ impl Timeline {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initializes an LSN lease. The function will return an error if the requested LSN is less than the `latest_gc_cutoff_lsn`.
|
||||
pub(crate) fn init_lsn_lease(
|
||||
&self,
|
||||
lsn: Lsn,
|
||||
length: Duration,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<LsnLease> {
|
||||
self.make_lsn_lease(lsn, length, true, ctx)
|
||||
}
|
||||
|
||||
/// Renews a lease at a particular LSN. The requested LSN is not validated against the `latest_gc_cutoff_lsn` when we are in the grace period.
|
||||
pub(crate) fn renew_lsn_lease(
|
||||
&self,
|
||||
lsn: Lsn,
|
||||
length: Duration,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<LsnLease> {
|
||||
self.make_lsn_lease(lsn, length, false, ctx)
|
||||
}
|
||||
|
||||
/// Obtains a temporary lease blocking garbage collection for the given LSN.
|
||||
///
|
||||
/// If we are in `AttachedSingle` mode and is not blocked by the lsn lease deadline, this function will error
|
||||
/// if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is no existing request present.
|
||||
///
|
||||
/// If there is an existing lease in the map, the lease will be renewed only if the request extends the lease.
|
||||
/// The returned lease is therefore the maximum between the existing lease and the requesting lease.
|
||||
fn make_lsn_lease(
|
||||
/// This function will error if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is also
|
||||
/// no existing lease to renew. If there is an existing lease in the map, the lease will be renewed only if
|
||||
/// the request extends the lease. The returned lease is therefore the maximum between the existing lease and
|
||||
/// the requesting lease.
|
||||
pub(crate) fn make_lsn_lease(
|
||||
&self,
|
||||
lsn: Lsn,
|
||||
length: Duration,
|
||||
init: bool,
|
||||
_ctx: &RequestContext,
|
||||
) -> anyhow::Result<LsnLease> {
|
||||
let lease = {
|
||||
@@ -1370,8 +1347,8 @@ impl Timeline {
|
||||
|
||||
let entry = gc_info.leases.entry(lsn);
|
||||
|
||||
match entry {
|
||||
Entry::Occupied(mut occupied) => {
|
||||
let lease = {
|
||||
if let Entry::Occupied(mut occupied) = entry {
|
||||
let existing_lease = occupied.get_mut();
|
||||
if valid_until > existing_lease.valid_until {
|
||||
existing_lease.valid_until = valid_until;
|
||||
@@ -1383,28 +1360,20 @@ impl Timeline {
|
||||
}
|
||||
|
||||
existing_lease.clone()
|
||||
}
|
||||
Entry::Vacant(vacant) => {
|
||||
// Reject already GC-ed LSN (lsn < latest_gc_cutoff) if we are in AttachedSingle and
|
||||
// not blocked by the lsn lease deadline.
|
||||
let validate = {
|
||||
let conf = self.tenant_conf.load();
|
||||
conf.location.attach_mode == AttachmentMode::Single
|
||||
&& !conf.is_gc_blocked_by_lsn_lease_deadline()
|
||||
};
|
||||
|
||||
if init || validate {
|
||||
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
|
||||
if lsn < *latest_gc_cutoff_lsn {
|
||||
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
|
||||
}
|
||||
} else {
|
||||
// Reject already GC-ed LSN (lsn < latest_gc_cutoff)
|
||||
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
|
||||
if lsn < *latest_gc_cutoff_lsn {
|
||||
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
|
||||
}
|
||||
|
||||
let dt: DateTime<Utc> = valid_until.into();
|
||||
info!("lease created, valid until {}", dt);
|
||||
vacant.insert(LsnLease { valid_until }).clone()
|
||||
entry.or_insert(LsnLease { valid_until }).clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
lease
|
||||
};
|
||||
|
||||
Ok(lease)
|
||||
@@ -1981,6 +1950,8 @@ impl Timeline {
|
||||
.unwrap_or(self.conf.default_tenant_conf.lsn_lease_length)
|
||||
}
|
||||
|
||||
// TODO(yuchen): remove unused flag after implementing https://github.com/neondatabase/neon/issues/8072
|
||||
#[allow(unused)]
|
||||
pub(crate) fn get_lsn_lease_length_for_ts(&self) -> Duration {
|
||||
let tenant_conf = self.tenant_conf.load();
|
||||
tenant_conf
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::ops::Deref;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pageserver_api::key::Key;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_epoll_uring::BoundedBuf;
|
||||
@@ -27,7 +27,6 @@ use utils::vec_map::VecMap;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::tenant::blob_io::{BYTE_UNCOMPRESSED, BYTE_ZSTD, LEN_COMPRESSION_BIT_MASK};
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use crate::virtual_file::{self, VirtualFile};
|
||||
|
||||
/// Metadata bundled with the start and end offset of a blob.
|
||||
@@ -159,7 +158,7 @@ impl std::fmt::Display for VectoredBlob {
|
||||
/// Return type of [`VectoredBlobReader::read_blobs`]
|
||||
pub struct VectoredBlobsBuf {
|
||||
/// Buffer for all blobs in this read
|
||||
pub buf: IoBufferMut,
|
||||
pub buf: BytesMut,
|
||||
/// Offsets into the buffer and metadata for all blobs in this read
|
||||
pub blobs: Vec<VectoredBlob>,
|
||||
}
|
||||
@@ -186,7 +185,171 @@ pub(crate) enum VectoredReadExtended {
|
||||
No,
|
||||
}
|
||||
|
||||
/// A vectored read builder that tries to coalesce all reads that fits in a chunk.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub enum VectoredReadCoalesceMode {
|
||||
/// Only coalesce exactly adjacent reads.
|
||||
AdjacentOnly,
|
||||
/// In addition to adjacent reads, also consider reads whose corresponding
|
||||
/// `end` and `start` offsets reside at the same chunk.
|
||||
Chunked(usize),
|
||||
}
|
||||
|
||||
impl VectoredReadCoalesceMode {
|
||||
/// [`AdjacentVectoredReadBuilder`] is used if alignment requirement is 0,
|
||||
/// whereas [`ChunkedVectoredReadBuilder`] is used for alignment requirement 1 and higher.
|
||||
pub(crate) fn get() -> Self {
|
||||
let align = virtual_file::get_io_buffer_alignment_raw();
|
||||
if align == 0 {
|
||||
VectoredReadCoalesceMode::AdjacentOnly
|
||||
} else {
|
||||
VectoredReadCoalesceMode::Chunked(align)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum VectoredReadBuilder {
|
||||
Adjacent(AdjacentVectoredReadBuilder),
|
||||
Chunked(ChunkedVectoredReadBuilder),
|
||||
}
|
||||
|
||||
impl VectoredReadBuilder {
|
||||
fn new_impl(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
max_read_size: Option<usize>,
|
||||
mode: VectoredReadCoalesceMode,
|
||||
) -> Self {
|
||||
match mode {
|
||||
VectoredReadCoalesceMode::AdjacentOnly => Self::Adjacent(
|
||||
AdjacentVectoredReadBuilder::new(start_offset, end_offset, meta, max_read_size),
|
||||
),
|
||||
VectoredReadCoalesceMode::Chunked(chunk_size) => {
|
||||
Self::Chunked(ChunkedVectoredReadBuilder::new(
|
||||
start_offset,
|
||||
end_offset,
|
||||
meta,
|
||||
max_read_size,
|
||||
chunk_size,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
max_read_size: usize,
|
||||
mode: VectoredReadCoalesceMode,
|
||||
) -> Self {
|
||||
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), mode)
|
||||
}
|
||||
|
||||
pub(crate) fn new_streaming(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
mode: VectoredReadCoalesceMode,
|
||||
) -> Self {
|
||||
Self::new_impl(start_offset, end_offset, meta, None, mode)
|
||||
}
|
||||
|
||||
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
|
||||
match self {
|
||||
VectoredReadBuilder::Adjacent(builder) => builder.extend(start, end, meta),
|
||||
VectoredReadBuilder::Chunked(builder) => builder.extend(start, end, meta),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn build(self) -> VectoredRead {
|
||||
match self {
|
||||
VectoredReadBuilder::Adjacent(builder) => builder.build(),
|
||||
VectoredReadBuilder::Chunked(builder) => builder.build(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn size(&self) -> usize {
|
||||
match self {
|
||||
VectoredReadBuilder::Adjacent(builder) => builder.size(),
|
||||
VectoredReadBuilder::Chunked(builder) => builder.size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct AdjacentVectoredReadBuilder {
|
||||
/// Start offset of the read.
|
||||
start: u64,
|
||||
// End offset of the read.
|
||||
end: u64,
|
||||
/// Start offset and metadata for each blob in this read
|
||||
blobs_at: VecMap<u64, BlobMeta>,
|
||||
max_read_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl AdjacentVectoredReadBuilder {
|
||||
/// Start building a new vectored read.
|
||||
///
|
||||
/// Note that by design, this does not check against reading more than `max_read_size` to
|
||||
/// support reading larger blobs than the configuration value. The builder will be single use
|
||||
/// however after that.
|
||||
pub(crate) fn new(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
max_read_size: Option<usize>,
|
||||
) -> Self {
|
||||
let mut blobs_at = VecMap::default();
|
||||
blobs_at
|
||||
.append(start_offset, meta)
|
||||
.expect("First insertion always succeeds");
|
||||
|
||||
Self {
|
||||
start: start_offset,
|
||||
end: end_offset,
|
||||
blobs_at,
|
||||
max_read_size,
|
||||
}
|
||||
}
|
||||
/// Attempt to extend the current read with a new blob if the start
|
||||
/// offset matches with the current end of the vectored read
|
||||
/// and the resuting size is below the max read size
|
||||
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
|
||||
tracing::trace!(start, end, "trying to extend");
|
||||
let size = (end - start) as usize;
|
||||
let not_limited_by_max_read_size = {
|
||||
if let Some(max_read_size) = self.max_read_size {
|
||||
self.size() + size <= max_read_size
|
||||
} else {
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
if self.end == start && not_limited_by_max_read_size {
|
||||
self.end = end;
|
||||
self.blobs_at
|
||||
.append(start, meta)
|
||||
.expect("LSNs are ordered within vectored reads");
|
||||
|
||||
return VectoredReadExtended::Yes;
|
||||
}
|
||||
|
||||
VectoredReadExtended::No
|
||||
}
|
||||
|
||||
pub(crate) fn size(&self) -> usize {
|
||||
(self.end - self.start) as usize
|
||||
}
|
||||
|
||||
pub(crate) fn build(self) -> VectoredRead {
|
||||
VectoredRead {
|
||||
start: self.start,
|
||||
end: self.end,
|
||||
blobs_at: self.blobs_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ChunkedVectoredReadBuilder {
|
||||
/// Start block number
|
||||
start_blk_no: usize,
|
||||
@@ -210,7 +373,7 @@ impl ChunkedVectoredReadBuilder {
|
||||
/// Note that by design, this does not check against reading more than `max_read_size` to
|
||||
/// support reading larger blobs than the configuration value. The builder will be single use
|
||||
/// however after that.
|
||||
fn new_impl(
|
||||
pub(crate) fn new(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
@@ -233,25 +396,6 @@ impl ChunkedVectoredReadBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
max_read_size: usize,
|
||||
align: usize,
|
||||
) -> Self {
|
||||
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align)
|
||||
}
|
||||
|
||||
pub(crate) fn new_streaming(
|
||||
start_offset: u64,
|
||||
end_offset: u64,
|
||||
meta: BlobMeta,
|
||||
align: usize,
|
||||
) -> Self {
|
||||
Self::new_impl(start_offset, end_offset, meta, None, align)
|
||||
}
|
||||
|
||||
/// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk.
|
||||
///
|
||||
/// The resulting size also must be below the max read size.
|
||||
@@ -330,17 +474,17 @@ pub struct VectoredReadPlanner {
|
||||
|
||||
max_read_size: usize,
|
||||
|
||||
align: usize,
|
||||
mode: VectoredReadCoalesceMode,
|
||||
}
|
||||
|
||||
impl VectoredReadPlanner {
|
||||
pub fn new(max_read_size: usize) -> Self {
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mode = VectoredReadCoalesceMode::get();
|
||||
Self {
|
||||
blobs: BTreeMap::new(),
|
||||
prev: None,
|
||||
max_read_size,
|
||||
align,
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,7 +545,7 @@ impl VectoredReadPlanner {
|
||||
}
|
||||
|
||||
pub fn finish(self) -> Vec<VectoredRead> {
|
||||
let mut current_read_builder: Option<ChunkedVectoredReadBuilder> = None;
|
||||
let mut current_read_builder: Option<VectoredReadBuilder> = None;
|
||||
let mut reads = Vec::new();
|
||||
|
||||
for (key, blobs_for_key) in self.blobs {
|
||||
@@ -414,12 +558,12 @@ impl VectoredReadPlanner {
|
||||
};
|
||||
|
||||
if extended == VectoredReadExtended::No {
|
||||
let next_read_builder = ChunkedVectoredReadBuilder::new(
|
||||
let next_read_builder = VectoredReadBuilder::new(
|
||||
start_offset,
|
||||
end_offset,
|
||||
BlobMeta { key, lsn },
|
||||
self.max_read_size,
|
||||
self.align,
|
||||
self.mode,
|
||||
);
|
||||
|
||||
let prev_read_builder = current_read_builder.replace(next_read_builder);
|
||||
@@ -461,7 +605,7 @@ impl<'a> VectoredBlobReader<'a> {
|
||||
pub async fn read_blobs(
|
||||
&self,
|
||||
read: &VectoredRead,
|
||||
buf: IoBufferMut,
|
||||
buf: BytesMut,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<VectoredBlobsBuf, std::io::Error> {
|
||||
assert!(read.size() > 0);
|
||||
@@ -544,7 +688,7 @@ impl<'a> VectoredBlobReader<'a> {
|
||||
/// `handle` gets called and when the current key would just exceed the read_size and
|
||||
/// max_cnt constraints.
|
||||
pub struct StreamingVectoredReadPlanner {
|
||||
read_builder: Option<ChunkedVectoredReadBuilder>,
|
||||
read_builder: Option<VectoredReadBuilder>,
|
||||
// Arguments for previous blob passed into [`StreamingVectoredReadPlanner::handle`]
|
||||
prev: Option<(Key, Lsn, u64)>,
|
||||
/// Max read size per batch. This is not a strict limit. If there are [0, 100) and [100, 200), while the `max_read_size` is 150,
|
||||
@@ -555,21 +699,21 @@ pub struct StreamingVectoredReadPlanner {
|
||||
/// Size of the current batch
|
||||
cnt: usize,
|
||||
|
||||
align: usize,
|
||||
mode: VectoredReadCoalesceMode,
|
||||
}
|
||||
|
||||
impl StreamingVectoredReadPlanner {
|
||||
pub fn new(max_read_size: u64, max_cnt: usize) -> Self {
|
||||
assert!(max_cnt > 0);
|
||||
assert!(max_read_size > 0);
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mode = VectoredReadCoalesceMode::get();
|
||||
Self {
|
||||
read_builder: None,
|
||||
prev: None,
|
||||
max_cnt,
|
||||
max_read_size,
|
||||
cnt: 0,
|
||||
align,
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,11 +762,11 @@ impl StreamingVectoredReadPlanner {
|
||||
}
|
||||
None => {
|
||||
self.read_builder = {
|
||||
Some(ChunkedVectoredReadBuilder::new_streaming(
|
||||
Some(VectoredReadBuilder::new_streaming(
|
||||
start_offset,
|
||||
end_offset,
|
||||
BlobMeta { key, lsn },
|
||||
self.align,
|
||||
self.mode,
|
||||
))
|
||||
};
|
||||
}
|
||||
@@ -946,10 +1090,9 @@ mod tests {
|
||||
|
||||
// Multiply by two (compressed data might need more space), and add a few bytes for the header
|
||||
let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16;
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mut buf = IoBufferMut::with_capacity_aligned(reserved_bytes, align);
|
||||
let mut buf = BytesMut::with_capacity(reserved_bytes);
|
||||
|
||||
let align = virtual_file::get_io_buffer_alignment();
|
||||
let mode = VectoredReadCoalesceMode::get();
|
||||
let vectored_blob_reader = VectoredBlobReader::new(&file);
|
||||
let meta = BlobMeta {
|
||||
key: Key::MIN,
|
||||
@@ -961,8 +1104,7 @@ mod tests {
|
||||
if idx + 1 == offsets.len() {
|
||||
continue;
|
||||
}
|
||||
let read_builder =
|
||||
ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align);
|
||||
let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, mode);
|
||||
let read = read_builder.build();
|
||||
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
|
||||
assert_eq!(result.blobs.len(), 1);
|
||||
|
||||
@@ -17,21 +17,16 @@ use crate::metrics::{StorageIoOperation, STORAGE_IO_SIZE, STORAGE_IO_TIME_METRIC
|
||||
use crate::page_cache::{PageWriteGuard, PAGE_SZ};
|
||||
use crate::tenant::TENANTS_SEGMENT_NAME;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
#[cfg(test)]
|
||||
use dio::IoBufferMut;
|
||||
use once_cell::sync::OnceCell;
|
||||
use owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
|
||||
use owned_buffers_io::io_buf_ext::FullSlice;
|
||||
use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use std::fs::File;
|
||||
use std::io::{Error, ErrorKind, Seek, SeekFrom};
|
||||
#[cfg(target_os = "linux")]
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice};
|
||||
|
||||
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
|
||||
use tokio::time::Instant;
|
||||
|
||||
@@ -43,11 +38,10 @@ pub use io_engine::FeatureTestResult as IoEngineFeatureTestResult;
|
||||
mod metadata;
|
||||
mod open_options;
|
||||
use self::owned_buffers_io::write::OwnedAsyncWriter;
|
||||
pub(crate) use api::IoMode;
|
||||
pub(crate) use api::DirectIoMode;
|
||||
pub(crate) use io_engine::IoEngineKind;
|
||||
pub(crate) use metadata::Metadata;
|
||||
pub(crate) use open_options::*;
|
||||
pub(crate) mod dio;
|
||||
|
||||
pub(crate) mod owned_buffers_io {
|
||||
//! Abstractions for IO with owned buffers.
|
||||
@@ -59,7 +53,6 @@ pub(crate) mod owned_buffers_io {
|
||||
//! but for the time being we're proving out the primitives in the neon.git repo
|
||||
//! for faster iteration.
|
||||
|
||||
pub(crate) mod io_buf_aligned;
|
||||
pub(crate) mod io_buf_ext;
|
||||
pub(crate) mod slice;
|
||||
pub(crate) mod write;
|
||||
@@ -68,176 +61,6 @@ pub(crate) mod owned_buffers_io {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum VirtualFile {
|
||||
Buffered(VirtualFileInner),
|
||||
Direct(VirtualFileInner),
|
||||
}
|
||||
|
||||
impl VirtualFile {
|
||||
fn inner(&self) -> &VirtualFileInner {
|
||||
match self {
|
||||
Self::Buffered(file) => file,
|
||||
Self::Direct(file) => file,
|
||||
}
|
||||
}
|
||||
|
||||
fn inner_mut(&mut self) -> &mut VirtualFileInner {
|
||||
match self {
|
||||
Self::Buffered(file) => file,
|
||||
Self::Direct(file) => file,
|
||||
}
|
||||
}
|
||||
|
||||
fn into_inner(self) -> VirtualFileInner {
|
||||
match self {
|
||||
Self::Buffered(file) => file,
|
||||
Self::Direct(file) => file,
|
||||
}
|
||||
}
|
||||
/// Open a file in read-only mode. Like File::open.
|
||||
pub async fn open<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let file = VirtualFileInner::open(path, ctx).await?;
|
||||
Ok(Self::Buffered(file))
|
||||
}
|
||||
|
||||
/// Open a file in read-only mode. Like File::open.
|
||||
///
|
||||
/// `O_DIRECT` will be enabled base on `virtual_file_io_mode`.
|
||||
pub async fn open_v2<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await
|
||||
}
|
||||
|
||||
pub async fn create<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let file = VirtualFileInner::create(path, ctx).await?;
|
||||
Ok(Self::Buffered(file))
|
||||
}
|
||||
|
||||
pub async fn create_v2<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Self, std::io::Error> {
|
||||
VirtualFile::open_with_options_v2(
|
||||
path.as_ref(),
|
||||
OpenOptions::new().write(true).create(true).truncate(true),
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn open_with_options<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
open_options: &OpenOptions,
|
||||
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
|
||||
Ok(Self::Buffered(file))
|
||||
}
|
||||
|
||||
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
open_options: &mut OpenOptions, // Uses `&mut` here to add `O_DIRECT`.
|
||||
ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
|
||||
) -> Result<Self, std::io::Error> {
|
||||
let file = match get_io_mode() {
|
||||
IoMode::Buffered => {
|
||||
let file = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
|
||||
Self::Buffered(file)
|
||||
}
|
||||
#[cfg(target_os = "linux")]
|
||||
IoMode::Direct => {
|
||||
let file = VirtualFileInner::open_with_options(
|
||||
path,
|
||||
open_options.custom_flags(nix::libc::O_DIRECT),
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
Self::Direct(file)
|
||||
}
|
||||
};
|
||||
Ok(file)
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Utf8Path {
|
||||
self.inner().path.as_path()
|
||||
}
|
||||
|
||||
pub async fn crashsafe_overwrite<B: BoundedBuf<Buf = Buf> + Send, Buf: IoBuf + Send>(
|
||||
final_path: Utf8PathBuf,
|
||||
tmp_path: Utf8PathBuf,
|
||||
content: B,
|
||||
) -> std::io::Result<()> {
|
||||
VirtualFileInner::crashsafe_overwrite(final_path, tmp_path, content).await
|
||||
}
|
||||
|
||||
pub async fn sync_all(&self) -> Result<(), Error> {
|
||||
self.inner().sync_all().await
|
||||
}
|
||||
|
||||
pub async fn sync_data(&self) -> Result<(), Error> {
|
||||
self.inner().sync_data().await
|
||||
}
|
||||
|
||||
pub async fn metadata(&self) -> Result<Metadata, Error> {
|
||||
self.inner().metadata().await
|
||||
}
|
||||
|
||||
pub fn remove(self) {
|
||||
self.into_inner().remove();
|
||||
}
|
||||
|
||||
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
|
||||
self.inner_mut().seek(pos).await
|
||||
}
|
||||
|
||||
pub async fn read_exact_at<Buf>(
|
||||
&self,
|
||||
slice: Slice<Buf>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Slice<Buf>, Error>
|
||||
where
|
||||
Buf: IoBufAlignedMut + Send,
|
||||
{
|
||||
self.inner().read_exact_at(slice, offset, ctx).await
|
||||
}
|
||||
|
||||
pub async fn read_exact_at_page(
|
||||
&self,
|
||||
page: PageWriteGuard<'static>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<PageWriteGuard<'static>, Error> {
|
||||
self.inner().read_exact_at_page(page, offset, ctx).await
|
||||
}
|
||||
|
||||
pub async fn write_all_at<Buf: IoBuf + Send>(
|
||||
&self,
|
||||
buf: FullSlice<Buf>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> (FullSlice<Buf>, Result<(), Error>) {
|
||||
self.inner().write_all_at(buf, offset, ctx).await
|
||||
}
|
||||
|
||||
pub async fn write_all<Buf: IoBuf + Send>(
|
||||
&mut self,
|
||||
buf: FullSlice<Buf>,
|
||||
ctx: &RequestContext,
|
||||
) -> (FullSlice<Buf>, Result<usize, Error>) {
|
||||
self.inner_mut().write_all(buf, ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A virtual file descriptor. You can use this just like std::fs::File, but internally
|
||||
/// the underlying file is closed if the system is low on file descriptors,
|
||||
@@ -254,7 +77,7 @@ impl VirtualFile {
|
||||
/// 'tag' field is used to detect whether the handle still is valid or not.
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct VirtualFileInner {
|
||||
pub struct VirtualFile {
|
||||
/// Lazy handle to the global file descriptor cache. The slot that this points to
|
||||
/// might contain our File, or it may be empty, or it may contain a File that
|
||||
/// belongs to a different VirtualFile.
|
||||
@@ -527,12 +350,12 @@ macro_rules! with_file {
|
||||
}};
|
||||
}
|
||||
|
||||
impl VirtualFileInner {
|
||||
impl VirtualFile {
|
||||
/// Open a file in read-only mode. Like File::open.
|
||||
pub async fn open<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<VirtualFileInner, std::io::Error> {
|
||||
) -> Result<VirtualFile, std::io::Error> {
|
||||
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
|
||||
}
|
||||
|
||||
@@ -541,7 +364,7 @@ impl VirtualFileInner {
|
||||
pub async fn create<P: AsRef<Utf8Path>>(
|
||||
path: P,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<VirtualFileInner, std::io::Error> {
|
||||
) -> Result<VirtualFile, std::io::Error> {
|
||||
Self::open_with_options(
|
||||
path.as_ref(),
|
||||
OpenOptions::new().write(true).create(true).truncate(true),
|
||||
@@ -559,7 +382,7 @@ impl VirtualFileInner {
|
||||
path: P,
|
||||
open_options: &OpenOptions,
|
||||
_ctx: &RequestContext, /* TODO: carry a pointer to the metrics in the RequestContext instead of the parsing https://github.com/neondatabase/neon/issues/6107 */
|
||||
) -> Result<VirtualFileInner, std::io::Error> {
|
||||
) -> Result<VirtualFile, std::io::Error> {
|
||||
let path_ref = path.as_ref();
|
||||
let path_str = path_ref.to_string();
|
||||
let parts = path_str.split('/').collect::<Vec<&str>>();
|
||||
@@ -590,7 +413,7 @@ impl VirtualFileInner {
|
||||
open_options.open(path_ref.as_std_path()).await?
|
||||
});
|
||||
|
||||
// Strip all options other than read and write (O_DIRECT).
|
||||
// Strip all options other than read and write.
|
||||
//
|
||||
// It would perhaps be nicer to check just for the read and write flags
|
||||
// explicitly, but OpenOptions doesn't contain any functions to read flags,
|
||||
@@ -600,7 +423,7 @@ impl VirtualFileInner {
|
||||
reopen_options.create_new(false);
|
||||
reopen_options.truncate(false);
|
||||
|
||||
let vfile = VirtualFileInner {
|
||||
let vfile = VirtualFile {
|
||||
handle: RwLock::new(handle),
|
||||
pos: 0,
|
||||
path: path_ref.to_path_buf(),
|
||||
@@ -643,7 +466,6 @@ impl VirtualFileInner {
|
||||
&[]
|
||||
};
|
||||
utils::crashsafe::overwrite(&final_path, &tmp_path, content)
|
||||
.maybe_fatal_err("crashsafe_overwrite")
|
||||
})
|
||||
.await
|
||||
.expect("blocking task is never aborted")
|
||||
@@ -653,7 +475,7 @@ impl VirtualFileInner {
|
||||
pub async fn sync_all(&self) -> Result<(), Error> {
|
||||
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
|
||||
let (_file_guard, res) = io_engine::get().sync_all(file_guard).await;
|
||||
res.maybe_fatal_err("sync_all")
|
||||
res
|
||||
})
|
||||
}
|
||||
|
||||
@@ -661,7 +483,7 @@ impl VirtualFileInner {
|
||||
pub async fn sync_data(&self) -> Result<(), Error> {
|
||||
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
|
||||
let (_file_guard, res) = io_engine::get().sync_data(file_guard).await;
|
||||
res.maybe_fatal_err("sync_data")
|
||||
res
|
||||
})
|
||||
}
|
||||
|
||||
@@ -781,7 +603,7 @@ impl VirtualFileInner {
|
||||
ctx: &RequestContext,
|
||||
) -> Result<Slice<Buf>, Error>
|
||||
where
|
||||
Buf: IoBufAlignedMut + Send,
|
||||
Buf: IoBufMut + Send,
|
||||
{
|
||||
let assert_we_return_original_bounds = if cfg!(debug_assertions) {
|
||||
Some((slice.stable_ptr() as usize, slice.bytes_total()))
|
||||
@@ -1211,36 +1033,18 @@ impl tokio_epoll_uring::IoFd for FileGuard {
|
||||
|
||||
#[cfg(test)]
|
||||
impl VirtualFile {
|
||||
pub(crate) async fn read_blk(
|
||||
&self,
|
||||
blknum: u32,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
|
||||
self.inner().read_blk(blknum, ctx).await
|
||||
}
|
||||
|
||||
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
|
||||
self.inner_mut().read_to_end(buf, ctx).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl VirtualFileInner {
|
||||
pub(crate) async fn read_blk(
|
||||
&self,
|
||||
blknum: u32,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<crate::tenant::block_io::BlockLease<'_>, std::io::Error> {
|
||||
use crate::page_cache::PAGE_SZ;
|
||||
let align = get_io_buffer_alignment();
|
||||
let slice = IoBufferMut::with_capacity_aligned(PAGE_SZ, align).slice_full();
|
||||
let slice = Vec::with_capacity(PAGE_SZ).slice_full();
|
||||
assert_eq!(slice.bytes_total(), PAGE_SZ);
|
||||
let slice = self
|
||||
.read_exact_at(slice, blknum as u64 * (PAGE_SZ as u64), ctx)
|
||||
.await?;
|
||||
Ok(crate::tenant::block_io::BlockLease::IoBufferMut(
|
||||
slice.into_inner(),
|
||||
))
|
||||
Ok(crate::tenant::block_io::BlockLease::Vec(slice.into_inner()))
|
||||
}
|
||||
|
||||
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
|
||||
@@ -1262,7 +1066,7 @@ impl VirtualFileInner {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for VirtualFileInner {
|
||||
impl Drop for VirtualFile {
|
||||
/// If a VirtualFile is dropped, close the underlying file if it was open.
|
||||
fn drop(&mut self) {
|
||||
let handle = self.handle.get_mut();
|
||||
@@ -1343,9 +1147,7 @@ pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize)
|
||||
panic!("virtual_file::init called twice");
|
||||
}
|
||||
if set_io_buffer_alignment(io_buffer_alignment).is_err() {
|
||||
panic!(
|
||||
"IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}"
|
||||
);
|
||||
panic!("IO buffer alignment ({io_buffer_alignment}) is not a power of two");
|
||||
}
|
||||
io_engine::init(engine);
|
||||
crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64);
|
||||
@@ -1372,16 +1174,14 @@ fn get_open_files() -> &'static OpenFiles {
|
||||
|
||||
static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT);
|
||||
|
||||
/// Returns true if the alignment is a power of two and is greater or equal to 512.
|
||||
fn is_valid_io_buffer_alignment(align: usize) -> bool {
|
||||
align.is_power_of_two() && align >= 512
|
||||
/// Returns true if `x` is zero or a power of two.
|
||||
fn is_zero_or_power_of_two(x: usize) -> bool {
|
||||
(x == 0) || ((x & (x - 1)) == 0)
|
||||
}
|
||||
|
||||
/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is
|
||||
/// not a power of two or less than 512 bytes.
|
||||
#[allow(unused)]
|
||||
pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
|
||||
if is_valid_io_buffer_alignment(align) {
|
||||
if is_zero_or_power_of_two(align) {
|
||||
IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed);
|
||||
Ok(())
|
||||
} else {
|
||||
@@ -1389,19 +1189,19 @@ pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the io buffer alignment.
|
||||
/// Gets the io buffer alignment requirement. Returns 0 if there is no requirement specified.
|
||||
///
|
||||
/// This function should be used for getting the actual alignment value to use.
|
||||
pub(crate) fn get_io_buffer_alignment() -> usize {
|
||||
/// This function should be used to check the raw config value.
|
||||
pub(crate) fn get_io_buffer_alignment_raw() -> usize {
|
||||
let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
if cfg!(test) {
|
||||
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT";
|
||||
if let Some(test_align) = utils::env::var(env_var_name) {
|
||||
if is_valid_io_buffer_alignment(test_align) {
|
||||
if is_zero_or_power_of_two(test_align) {
|
||||
test_align
|
||||
} else {
|
||||
panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}");
|
||||
panic!("IO buffer alignment ({test_align}) is not a power of two");
|
||||
}
|
||||
} else {
|
||||
align
|
||||
@@ -1411,22 +1211,20 @@ pub(crate) fn get_io_buffer_alignment() -> usize {
|
||||
}
|
||||
}
|
||||
|
||||
static IO_MODE: AtomicU8 = AtomicU8::new(IoMode::preferred() as u8);
|
||||
|
||||
pub(crate) fn set_io_mode(mode: IoMode) {
|
||||
IO_MODE.store(mode as u8, std::sync::atomic::Ordering::Relaxed);
|
||||
/// Gets the io buffer alignment requirement. Returns 1 if the alignment config is set to zero.
|
||||
///
|
||||
/// This function should be used for getting the actual alignment value to use.
|
||||
pub(crate) fn get_io_buffer_alignment() -> usize {
|
||||
let align = get_io_buffer_alignment_raw();
|
||||
align.max(1)
|
||||
}
|
||||
|
||||
pub(crate) fn get_io_mode() -> IoMode {
|
||||
IoMode::try_from(IO_MODE.load(Ordering::Relaxed)).unwrap()
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::context::DownloadBehavior;
|
||||
use crate::task_mgr::TaskKind;
|
||||
|
||||
use super::*;
|
||||
use dio::IoBufferMut;
|
||||
use owned_buffers_io::io_buf_ext::IoBufExt;
|
||||
use owned_buffers_io::slice::SliceMutExt;
|
||||
use rand::seq::SliceRandom;
|
||||
@@ -1450,10 +1248,10 @@ mod tests {
|
||||
impl MaybeVirtualFile {
|
||||
async fn read_exact_at(
|
||||
&self,
|
||||
mut slice: tokio_epoll_uring::Slice<IoBufferMut>,
|
||||
mut slice: tokio_epoll_uring::Slice<Vec<u8>>,
|
||||
offset: u64,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<tokio_epoll_uring::Slice<IoBufferMut>, Error> {
|
||||
) -> Result<tokio_epoll_uring::Slice<Vec<u8>>, Error> {
|
||||
match self {
|
||||
MaybeVirtualFile::VirtualFile(file) => file.read_exact_at(slice, offset, ctx).await,
|
||||
MaybeVirtualFile::File(file) => {
|
||||
@@ -1521,13 +1319,11 @@ mod tests {
|
||||
len: usize,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<String, Error> {
|
||||
let slice = IoBufferMut::with_capacity_aligned(len, 512).slice_full();
|
||||
let slice = Vec::with_capacity(len).slice_full();
|
||||
assert_eq!(slice.bytes_total(), len);
|
||||
let slice = self.read_exact_at(slice, pos, ctx).await?;
|
||||
let buf = slice.into_inner();
|
||||
assert_eq!(buf.len(), len);
|
||||
let mut vec = Vec::with_capacity(buf.len());
|
||||
vec.extend_from_slice(&buf);
|
||||
let vec = slice.into_inner();
|
||||
assert_eq!(vec.len(), len);
|
||||
Ok(String::from_utf8(vec).unwrap())
|
||||
}
|
||||
}
|
||||
@@ -1716,7 +1512,6 @@ mod tests {
|
||||
const VIRTUAL_FILES: usize = 100;
|
||||
const THREADS: usize = 100;
|
||||
const SAMPLE: [u8; SIZE] = [0xADu8; SIZE];
|
||||
let align = super::get_io_buffer_alignment();
|
||||
|
||||
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
|
||||
let testdir = crate::config::PageServerConf::test_repo_dir("vfile_concurrency");
|
||||
@@ -1732,7 +1527,7 @@ mod tests {
|
||||
// Open the file many times.
|
||||
let mut files = Vec::new();
|
||||
for _ in 0..VIRTUAL_FILES {
|
||||
let f = VirtualFileInner::open_with_options(
|
||||
let f = VirtualFile::open_with_options(
|
||||
&test_file_path,
|
||||
OpenOptions::new().read(true),
|
||||
&ctx,
|
||||
@@ -1753,7 +1548,7 @@ mod tests {
|
||||
let files = files.clone();
|
||||
let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error);
|
||||
let hdl = rt.spawn(async move {
|
||||
let mut buf = IoBufferMut::with_capacity_aligned_zeroed(SIZE, align);
|
||||
let mut buf = vec![0u8; SIZE];
|
||||
let mut rng = rand::rngs::OsRng;
|
||||
for _ in 1..1000 {
|
||||
let f = &files[rng.gen_range(0..files.len())];
|
||||
@@ -1762,7 +1557,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap()
|
||||
.into_inner();
|
||||
assert!(&buf == SAMPLE.as_slice());
|
||||
assert!(buf == SAMPLE);
|
||||
}
|
||||
});
|
||||
hdls.push(hdl);
|
||||
@@ -1784,7 +1579,7 @@ mod tests {
|
||||
let path = testdir.join("myfile");
|
||||
let tmp_path = testdir.join("myfile.tmp");
|
||||
|
||||
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
|
||||
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
|
||||
@@ -1793,7 +1588,7 @@ mod tests {
|
||||
assert!(!tmp_path.exists());
|
||||
drop(file);
|
||||
|
||||
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
|
||||
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"bar".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
|
||||
@@ -1816,7 +1611,7 @@ mod tests {
|
||||
std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap();
|
||||
assert!(tmp_path.exists());
|
||||
|
||||
VirtualFileInner::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
|
||||
VirtualFile::crashsafe_overwrite(path.clone(), tmp_path.clone(), b"foo".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -1,434 +0,0 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use core::slice;
|
||||
use std::{
|
||||
alloc::{self, Layout},
|
||||
cmp,
|
||||
mem::{ManuallyDrop, MaybeUninit},
|
||||
ops::{Deref, DerefMut},
|
||||
ptr::{addr_of_mut, NonNull},
|
||||
};
|
||||
|
||||
use bytes::buf::UninitSlice;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct IoBufferPtr(*mut u8);
|
||||
|
||||
// SAFETY: We gurantees no one besides `IoBufferPtr` itself has the raw pointer.
|
||||
unsafe impl Send for IoBufferPtr {}
|
||||
|
||||
/// An aligned buffer type used for I/O.
|
||||
#[derive(Debug)]
|
||||
pub struct IoBufferMut {
|
||||
ptr: IoBufferPtr,
|
||||
capacity: usize,
|
||||
len: usize,
|
||||
align: usize,
|
||||
}
|
||||
|
||||
impl IoBufferMut {
|
||||
/// Constructs a new, empty `IoBufferMut` with at least the specified capacity and alignment.
|
||||
///
|
||||
/// The buffer will be able to hold at most `capacity` elements and will never resize.
|
||||
///
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the new capacity exceeds `isize::MAX` _bytes_, or if the following alignment requirement is not met:
|
||||
/// * `align` must not be zero,
|
||||
///
|
||||
/// * `align` must be a power of two,
|
||||
///
|
||||
/// * `capacity`, when rounded up to the nearest multiple of `align`,
|
||||
/// must not overflow isize (i.e., the rounded value must be
|
||||
/// less than or equal to `isize::MAX`).
|
||||
pub fn with_capacity_aligned(capacity: usize, align: usize) -> Self {
|
||||
let layout = Layout::from_size_align(capacity, align).expect("Invalid layout");
|
||||
|
||||
// SAFETY: Making an allocation with a sized and aligned layout. The memory is manually freed with the same layout.
|
||||
let ptr = unsafe {
|
||||
let ptr = alloc::alloc(layout);
|
||||
if ptr.is_null() {
|
||||
alloc::handle_alloc_error(layout);
|
||||
}
|
||||
IoBufferPtr(ptr)
|
||||
};
|
||||
|
||||
IoBufferMut {
|
||||
ptr,
|
||||
capacity,
|
||||
len: 0,
|
||||
align,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_capacity_aligned_zeroed(capacity: usize, align: usize) -> Self {
|
||||
use bytes::BufMut;
|
||||
let mut buf = Self::with_capacity_aligned(capacity, align);
|
||||
buf.put_bytes(0, capacity);
|
||||
buf.len = capacity;
|
||||
buf
|
||||
}
|
||||
|
||||
/// Returns the total number of bytes the buffer can hold.
|
||||
#[inline]
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
|
||||
/// Returns the alignment of the buffer.
|
||||
#[inline]
|
||||
pub fn align(&self) -> usize {
|
||||
self.align
|
||||
}
|
||||
|
||||
/// Returns the number of bytes in the buffer, also referred to as its 'length'.
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Force the length of the buffer to `new_len`.
|
||||
#[inline]
|
||||
unsafe fn set_len(&mut self, new_len: usize) {
|
||||
debug_assert!(new_len <= self.capacity());
|
||||
self.len = new_len;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn as_ptr(&self) -> *const u8 {
|
||||
self.ptr.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn as_mut_ptr(&mut self) -> *mut u8 {
|
||||
self.ptr.0
|
||||
}
|
||||
|
||||
/// Extracts a slice containing the entire buffer.
|
||||
///
|
||||
/// Equivalent to `&s[..]`.
|
||||
#[inline]
|
||||
fn as_slice(&self) -> &[u8] {
|
||||
// SAFETY: The pointer is valid and `len` bytes are initialized.
|
||||
unsafe { slice::from_raw_parts(self.as_ptr(), self.len) }
|
||||
}
|
||||
|
||||
/// Extracts a mutable slice of the entire buffer.
|
||||
///
|
||||
/// Equivalent to `&mut s[..]`.
|
||||
fn as_mut_slice(&mut self) -> &mut [u8] {
|
||||
// SAFETY: The pointer is valid and `len` bytes are initialized.
|
||||
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.len) }
|
||||
}
|
||||
|
||||
/// Drops the all the contents of the buffer, setting its length to `0`.
|
||||
#[inline]
|
||||
pub fn clear(&mut self) {
|
||||
self.len = 0;
|
||||
}
|
||||
|
||||
/// Reserves capacity for at least `additional` more bytes to be inserted
|
||||
/// in the given `IoBufferMut`. The collection may reserve more space to
|
||||
/// speculatively avoid frequent reallocations. After calling `reserve`,
|
||||
/// capacity will be greater than or equal to `self.len() + additional`.
|
||||
/// Does nothing if capacity is already sufficient.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the new capacity exceeds `isize::MAX` _bytes_.
|
||||
pub fn reserve(&mut self, additional: usize) {
|
||||
if additional > self.capacity() - self.len() {
|
||||
self.reserve_inner(additional);
|
||||
}
|
||||
}
|
||||
|
||||
/// Shortens the buffer, keeping the first len bytes.
|
||||
pub fn truncate(&mut self, len: usize) {
|
||||
if len > self.len {
|
||||
return;
|
||||
}
|
||||
self.len = len;
|
||||
}
|
||||
|
||||
fn reserve_inner(&mut self, additional: usize) {
|
||||
let Some(required_cap) = self.len().checked_add(additional) else {
|
||||
capacity_overflow()
|
||||
};
|
||||
|
||||
let old_capacity = self.capacity();
|
||||
let align = self.align();
|
||||
// This guarantees exponential growth. The doubling cannot overflow
|
||||
// because `cap <= isize::MAX` and the type of `cap` is `usize`.
|
||||
let cap = cmp::max(old_capacity * 2, required_cap);
|
||||
|
||||
if !is_valid_alloc(cap) {
|
||||
capacity_overflow()
|
||||
}
|
||||
let new_layout = Layout::from_size_align(cap, self.align()).expect("Invalid layout");
|
||||
|
||||
let old_ptr = self.as_mut_ptr();
|
||||
|
||||
// SAFETY: old allocation was allocated with std::alloc::alloc with the same layout,
|
||||
// and we panics on null pointer.
|
||||
let (ptr, cap) = unsafe {
|
||||
let old_layout = Layout::from_size_align_unchecked(old_capacity, align);
|
||||
let ptr = alloc::realloc(old_ptr, old_layout, new_layout.size());
|
||||
if ptr.is_null() {
|
||||
alloc::handle_alloc_error(new_layout);
|
||||
}
|
||||
(IoBufferPtr(ptr), cap)
|
||||
};
|
||||
|
||||
self.ptr = ptr;
|
||||
self.capacity = cap;
|
||||
}
|
||||
|
||||
pub fn leak<'a>(self) -> &'a mut [u8] {
|
||||
let mut buf = ManuallyDrop::new(self);
|
||||
// SAFETY: leaking the buffer as intended.
|
||||
unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.len) }
|
||||
}
|
||||
}
|
||||
|
||||
fn capacity_overflow() -> ! {
|
||||
panic!("capacity overflow")
|
||||
}
|
||||
|
||||
// We need to guarantee the following:
|
||||
// * We don't ever allocate `> isize::MAX` byte-size objects.
|
||||
// * We don't overflow `usize::MAX` and actually allocate too little.
|
||||
//
|
||||
// On 64-bit we just need to check for overflow since trying to allocate
|
||||
// `> isize::MAX` bytes will surely fail. On 32-bit and 16-bit we need to add
|
||||
// an extra guard for this in case we're running on a platform which can use
|
||||
// all 4GB in user-space, e.g., PAE or x32.
|
||||
#[inline]
|
||||
fn is_valid_alloc(alloc_size: usize) -> bool {
|
||||
!(usize::BITS < 64 && alloc_size > isize::MAX as usize)
|
||||
}
|
||||
|
||||
impl Drop for IoBufferMut {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: memory was allocated with std::alloc::alloc with the same layout.
|
||||
unsafe {
|
||||
alloc::dealloc(
|
||||
self.as_mut_ptr(),
|
||||
Layout::from_size_align_unchecked(self.capacity, self.align),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for IoBufferMut {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
self.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMut<[u8]> for IoBufferMut {
|
||||
fn as_mut(&mut self) -> &mut [u8] {
|
||||
self.as_mut_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<[u8]> for IoBufferMut {
|
||||
fn eq(&self, other: &[u8]) -> bool {
|
||||
self.as_slice().eq(other)
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for IoBufferMut {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
impl DerefMut for IoBufferMut {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.as_mut_slice()
|
||||
}
|
||||
}
|
||||
|
||||
/// SAFETY: When advancing the internal cursor, the caller needs to make sure the bytes advcanced past have been initialized.
|
||||
unsafe impl bytes::BufMut for IoBufferMut {
|
||||
#[inline]
|
||||
fn remaining_mut(&self) -> usize {
|
||||
// Although a `Vec` can have at most isize::MAX bytes, we never want to grow `IoBufferMut`.
|
||||
// Thus, it can have at most `self.capacity` bytes.
|
||||
self.capacity() - self.len()
|
||||
}
|
||||
|
||||
// SAFETY: Caller needs to make sure the bytes being advanced past have been initialized.
|
||||
#[inline]
|
||||
unsafe fn advance_mut(&mut self, cnt: usize) {
|
||||
let len: usize = self.len();
|
||||
let remaining = self.remaining_mut();
|
||||
|
||||
if remaining < cnt {
|
||||
panic_advance(cnt, remaining);
|
||||
}
|
||||
|
||||
// Addition will not overflow since the sum is at most the capacity.
|
||||
self.set_len(len + cnt);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
|
||||
let cap = self.capacity();
|
||||
let len = self.len();
|
||||
|
||||
// SAFETY: Since `self.ptr` is valid for `cap` bytes, `self.ptr.add(len)` must be
|
||||
// valid for `cap - len` bytes. The subtraction will not underflow since
|
||||
// `len <= cap`.
|
||||
unsafe { UninitSlice::from_raw_parts_mut(self.as_mut_ptr().add(len), cap - len) }
|
||||
}
|
||||
}
|
||||
|
||||
/// Panic with a nice error message.
|
||||
#[cold]
|
||||
fn panic_advance(idx: usize, len: usize) -> ! {
|
||||
panic!(
|
||||
"advance out of bounds: the len is {} but advancing by {}",
|
||||
len, idx
|
||||
);
|
||||
}
|
||||
|
||||
/// Safety: [`IoBufferMut`] has exclusive ownership of the io buffer,
|
||||
/// and the location remains stable even if [`Self`] is moved.
|
||||
unsafe impl tokio_epoll_uring::IoBuf for IoBufferMut {
|
||||
fn stable_ptr(&self) -> *const u8 {
|
||||
self.as_ptr()
|
||||
}
|
||||
|
||||
fn bytes_init(&self) -> usize {
|
||||
self.len()
|
||||
}
|
||||
|
||||
fn bytes_total(&self) -> usize {
|
||||
self.capacity()
|
||||
}
|
||||
}
|
||||
|
||||
// SAFETY: See above.
|
||||
unsafe impl tokio_epoll_uring::IoBufMut for IoBufferMut {
|
||||
fn stable_mut_ptr(&mut self) -> *mut u8 {
|
||||
self.as_mut_ptr()
|
||||
}
|
||||
|
||||
unsafe fn set_init(&mut self, init_len: usize) {
|
||||
if self.len() < init_len {
|
||||
self.set_len(init_len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_with_capacity_aligned() {
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
|
||||
assert_eq!(v.len(), 0);
|
||||
assert_eq!(v.capacity(), ALIGN * 4);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
|
||||
let v = IoBufferMut::with_capacity_aligned(ALIGN / 2, ALIGN);
|
||||
assert_eq!(v.len(), 0);
|
||||
assert_eq!(v.capacity(), ALIGN / 2);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_capacity_aligned_zeroed() {
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let v = IoBufferMut::with_capacity_aligned_zeroed(ALIGN, ALIGN);
|
||||
assert_eq!(v.len(), ALIGN);
|
||||
assert_eq!(v.capacity(), ALIGN);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
assert_eq!(&v[..], &[0; ALIGN])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reserve() {
|
||||
use bytes::BufMut;
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let mut v = IoBufferMut::with_capacity_aligned(ALIGN, ALIGN);
|
||||
let capacity = v.capacity();
|
||||
v.reserve(capacity);
|
||||
assert_eq!(v.capacity(), capacity);
|
||||
let data = [b'a'; ALIGN];
|
||||
v.put(&data[..]);
|
||||
v.reserve(capacity);
|
||||
assert!(v.capacity() >= capacity * 2);
|
||||
assert_eq!(&v[..], &data[..]);
|
||||
let capacity = v.capacity();
|
||||
v.clear();
|
||||
v.reserve(capacity);
|
||||
assert_eq!(capacity, v.capacity());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bytes_put() {
|
||||
use bytes::BufMut;
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let mut v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
|
||||
let x = [b'a'; ALIGN];
|
||||
|
||||
for _ in 0..2 {
|
||||
for _ in 0..4 {
|
||||
v.put(&x[..]);
|
||||
}
|
||||
assert_eq!(v.len(), ALIGN * 4);
|
||||
assert_eq!(v.capacity(), ALIGN * 4);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
v.clear()
|
||||
}
|
||||
assert_eq!(v.len(), 0);
|
||||
assert_eq!(v.capacity(), ALIGN * 4);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_bytes_put_panic() {
|
||||
use bytes::BufMut;
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let mut v = IoBufferMut::with_capacity_aligned(ALIGN * 4, ALIGN);
|
||||
let x = [b'a'; ALIGN];
|
||||
for _ in 0..5 {
|
||||
v.put_slice(&x[..]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_io_buf_put_slice() {
|
||||
use tokio_epoll_uring::BoundedBufMut;
|
||||
const ALIGN: usize = 4 * 1024;
|
||||
let mut v = IoBufferMut::with_capacity_aligned(ALIGN, ALIGN);
|
||||
let x = [b'a'; ALIGN];
|
||||
|
||||
for _ in 0..2 {
|
||||
v.put_slice(&x[..]);
|
||||
assert_eq!(v.len(), ALIGN);
|
||||
assert_eq!(v.capacity(), ALIGN);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
v.clear()
|
||||
}
|
||||
assert_eq!(v.len(), 0);
|
||||
assert_eq!(v.capacity(), ALIGN);
|
||||
assert_eq!(v.align(), ALIGN);
|
||||
assert_eq!(v.as_ptr().align_offset(ALIGN), 0);
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use tokio_epoll_uring::IoBufMut;
|
||||
|
||||
use crate::virtual_file::{dio::IoBufferMut, PageWriteGuardBuf};
|
||||
|
||||
pub trait IoBufAlignedMut: IoBufMut {}
|
||||
|
||||
impl IoBufAlignedMut for IoBufferMut {}
|
||||
|
||||
impl IoBufAlignedMut for PageWriteGuardBuf {}
|
||||
@@ -1,6 +1,5 @@
|
||||
//! See [`FullSlice`].
|
||||
|
||||
use crate::virtual_file::dio::IoBufferMut;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::ops::{Deref, Range};
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
|
||||
@@ -77,4 +76,3 @@ macro_rules! impl_io_buf_ext {
|
||||
impl_io_buf_ext!(Bytes);
|
||||
impl_io_buf_ext!(BytesMut);
|
||||
impl_io_buf_ext!(Vec<u8>);
|
||||
impl_io_buf_ext!(IoBufferMut);
|
||||
|
||||
@@ -1473,33 +1473,11 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count,
|
||||
{
|
||||
NeonWALReadResult res;
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 17
|
||||
if (!sk->wp->config->syncSafekeepers)
|
||||
{
|
||||
Size rbytes;
|
||||
rbytes = WALReadFromBuffers(buf, startptr, count,
|
||||
walprop_pg_get_timeline_id());
|
||||
|
||||
startptr += rbytes;
|
||||
count -= rbytes;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (count == 0)
|
||||
{
|
||||
res = NEON_WALREAD_SUCCESS;
|
||||
}
|
||||
else
|
||||
{
|
||||
Assert(count > 0);
|
||||
|
||||
/* Now read the remaining WAL from the WAL file */
|
||||
res = NeonWALRead(sk->xlogreader,
|
||||
buf,
|
||||
startptr,
|
||||
count,
|
||||
walprop_pg_get_timeline_id());
|
||||
}
|
||||
res = NeonWALRead(sk->xlogreader,
|
||||
buf,
|
||||
startptr,
|
||||
count,
|
||||
walprop_pg_get_timeline_id());
|
||||
|
||||
if (res == NEON_WALREAD_SUCCESS)
|
||||
{
|
||||
|
||||
@@ -565,7 +565,7 @@ mod tests {
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
|
||||
use super::{auth_quirks, AuthRateLimiter};
|
||||
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
|
||||
|
||||
struct Auth {
|
||||
ips: Vec<IpPattern>,
|
||||
@@ -611,12 +611,15 @@ mod tests {
|
||||
}
|
||||
|
||||
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(1),
|
||||
scram_protocol_timeout: std::time::Duration::from_secs(5),
|
||||
rate_limiter_enabled: true,
|
||||
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: false,
|
||||
});
|
||||
|
||||
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
future::Future,
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
@@ -8,7 +9,10 @@ use anyhow::{bail, ensure, Context};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use dashmap::DashMap;
|
||||
use jose_jwk::crypto::KeyInfo;
|
||||
use serde::{Deserialize, Deserializer};
|
||||
use serde::{
|
||||
de::{DeserializeSeed, IgnoredAny, Visitor},
|
||||
Deserializer,
|
||||
};
|
||||
use signature::Verifier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
@@ -33,6 +37,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
|
||||
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct AuthRule {
|
||||
pub(crate) id: String,
|
||||
pub(crate) jwks_url: url::Url,
|
||||
@@ -303,35 +308,21 @@ impl JwkCacheEntryLock {
|
||||
}
|
||||
key => bail!("unsupported key type {key:?}"),
|
||||
};
|
||||
tracing::debug!("JWT signature valid");
|
||||
|
||||
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
|
||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
||||
let validator = JwtValidator {
|
||||
expected_audience,
|
||||
current_time: SystemTime::now(),
|
||||
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
|
||||
};
|
||||
|
||||
match (expected_audience, payload.audience) {
|
||||
// check the audience matches
|
||||
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
|
||||
// the audience is expected but is missing
|
||||
(Some(_), None) => bail!("invalid JWT token audience"),
|
||||
// we don't care for the audience field
|
||||
(None, _) => {}
|
||||
}
|
||||
let payload = validator
|
||||
.deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
|
||||
|
||||
let now = SystemTime::now();
|
||||
|
||||
if let Some(exp) = payload.expiration {
|
||||
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
|
||||
}
|
||||
|
||||
if let Some(nbf) = payload.not_before {
|
||||
ensure!(
|
||||
nbf < now + CLOCK_SKEW_LEEWAY,
|
||||
"JWT token is not yet ready to use"
|
||||
);
|
||||
}
|
||||
tracing::debug!(?payload, "JWT claims valid");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -419,37 +410,184 @@ struct JwtHeader<'a> {
|
||||
key_id: Option<&'a str>,
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug)]
|
||||
struct JwtPayload<'a> {
|
||||
/// Audience - Recipient for which the JWT is intended
|
||||
#[serde(rename = "aud")]
|
||||
audience: Option<&'a str>,
|
||||
/// Expiration - Time after which the JWT expires
|
||||
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
|
||||
expiration: Option<SystemTime>,
|
||||
/// Not before - Time after which the JWT expires
|
||||
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
|
||||
not_before: Option<SystemTime>,
|
||||
|
||||
// the following entries are only extracted for the sake of debug logging.
|
||||
/// Issuer of the JWT
|
||||
#[serde(rename = "iss")]
|
||||
issuer: Option<&'a str>,
|
||||
/// Subject of the JWT (the user)
|
||||
#[serde(rename = "sub")]
|
||||
subject: Option<&'a str>,
|
||||
/// Unique token identifier
|
||||
#[serde(rename = "jti")]
|
||||
jwt_id: Option<&'a str>,
|
||||
/// Unique session identifier
|
||||
#[serde(rename = "sid")]
|
||||
session_id: Option<&'a str>,
|
||||
struct JwtValidator<'a> {
|
||||
expected_audience: Option<&'a str>,
|
||||
current_time: SystemTime,
|
||||
clock_skew_leeway: Duration,
|
||||
}
|
||||
|
||||
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
|
||||
let d = <Option<u64>>::deserialize(d)?;
|
||||
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
|
||||
impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
|
||||
type Value = JwtPayload<'de>;
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for JwtValidator<'_> {
|
||||
type Value = JwtPayload<'de>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a JWT payload")
|
||||
}
|
||||
|
||||
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::MapAccess<'de>,
|
||||
{
|
||||
let mut payload = JwtPayload {
|
||||
issuer: None,
|
||||
subject: None,
|
||||
jwt_id: None,
|
||||
session_id: None,
|
||||
};
|
||||
|
||||
let mut aud = false;
|
||||
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
"iss" if payload.issuer.is_none() => {
|
||||
payload.issuer = Some(map.next_value()?);
|
||||
}
|
||||
"sub" if payload.subject.is_none() => {
|
||||
payload.subject = Some(map.next_value()?);
|
||||
}
|
||||
"jit" if payload.jwt_id.is_none() => {
|
||||
payload.jwt_id = Some(map.next_value()?);
|
||||
}
|
||||
"sid" if payload.session_id.is_none() => {
|
||||
payload.session_id = Some(map.next_value()?);
|
||||
}
|
||||
"exp" => {
|
||||
let exp = map.next_value::<u64>()?;
|
||||
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
|
||||
|
||||
if self.current_time > exp + self.clock_skew_leeway {
|
||||
return Err(serde::de::Error::custom("JWT token has expired"));
|
||||
}
|
||||
}
|
||||
"nbf" => {
|
||||
let nbf = map.next_value::<u64>()?;
|
||||
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
|
||||
|
||||
if self.current_time + self.clock_skew_leeway < nbf {
|
||||
return Err(serde::de::Error::custom(
|
||||
"JWT token is not yet ready to use",
|
||||
));
|
||||
}
|
||||
}
|
||||
"aud" => {
|
||||
if let Some(expected_audience) = self.expected_audience {
|
||||
map.next_value_seed(AudienceValidator { expected_audience })?;
|
||||
aud = true;
|
||||
} else {
|
||||
map.next_value::<IgnoredAny>()?;
|
||||
}
|
||||
}
|
||||
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
|
||||
}
|
||||
}
|
||||
|
||||
if self.expected_audience.is_some() && !aud {
|
||||
return Err(serde::de::Error::custom("invalid JWT token audience"));
|
||||
}
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_map(self)
|
||||
}
|
||||
}
|
||||
|
||||
struct AudienceValidator<'a> {
|
||||
expected_audience: &'a str,
|
||||
}
|
||||
|
||||
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
|
||||
type Value = ();
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for AudienceValidator<'_> {
|
||||
type Value = ();
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a single string or an array of strings")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
if self.expected_audience == v {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(E::custom("invalid JWT token audience"))
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
while let Some(v) = seq.next_element_seed(SingleAudienceValidator {
|
||||
expected_audience: self.expected_audience,
|
||||
})? {
|
||||
if v {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(serde::de::Error::custom("invalid JWT token audience"))
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_any(self)
|
||||
}
|
||||
}
|
||||
|
||||
struct SingleAudienceValidator<'a> {
|
||||
expected_audience: &'a str,
|
||||
}
|
||||
|
||||
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
|
||||
type Value = bool;
|
||||
|
||||
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
|
||||
type Value = bool;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a single audience string")
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(self.expected_audience == v)
|
||||
}
|
||||
}
|
||||
deserializer.deserialize_any(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
|
||||
// the following entries are only extracted for the sake of debug logging.
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
struct JwtPayload<'a> {
|
||||
/// Issuer of the JWT
|
||||
issuer: Option<Cow<'a, str>>,
|
||||
/// Subject of the JWT (the user)
|
||||
subject: Option<Cow<'a, str>>,
|
||||
/// Unique token identifier
|
||||
jwt_id: Option<Cow<'a, str>>,
|
||||
/// Unique session identifier
|
||||
session_id: Option<Cow<'a, str>>,
|
||||
}
|
||||
|
||||
struct JwkRenewalPermit<'a> {
|
||||
@@ -530,6 +668,8 @@ mod tests {
|
||||
use hyper_util::rt::TokioIo;
|
||||
use rand::rngs::OsRng;
|
||||
use rsa::pkcs8::DecodePrivateKey;
|
||||
use serde::Serialize;
|
||||
use serde_json::json;
|
||||
use signature::Signer;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
@@ -562,18 +702,36 @@ mod tests {
|
||||
}
|
||||
|
||||
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let body = typed_json::json! {{
|
||||
"exp": now + 3600,
|
||||
"nbf": now,
|
||||
"aud": ["audience1", "neon", "audience2"],
|
||||
"sub": "user1",
|
||||
"sid": "session1",
|
||||
"jti": "token1",
|
||||
"iss": "neon-testing",
|
||||
}};
|
||||
build_custom_jwt_payload(kid, body, sig)
|
||||
}
|
||||
|
||||
fn build_custom_jwt_payload(
|
||||
kid: String,
|
||||
body: impl Serialize,
|
||||
sig: jose_jwa::Signing,
|
||||
) -> String {
|
||||
let header = JwtHeader {
|
||||
typ: "JWT",
|
||||
algorithm: jose_jwa::Algorithm::Signing(sig),
|
||||
key_id: Some(&kid),
|
||||
};
|
||||
let body = typed_json::json! {{
|
||||
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
|
||||
}};
|
||||
|
||||
let header =
|
||||
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
|
||||
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{header}.{body}")
|
||||
}
|
||||
@@ -588,6 +746,16 @@ mod tests {
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
|
||||
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
|
||||
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
|
||||
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
|
||||
|
||||
format!("{payload}.{sig}")
|
||||
}
|
||||
|
||||
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
|
||||
use rsa::pkcs1v15::SigningKey;
|
||||
use rsa::signature::SignatureEncoding;
|
||||
@@ -659,37 +827,34 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
-----END PRIVATE KEY-----
|
||||
";
|
||||
|
||||
#[tokio::test]
|
||||
async fn renew() {
|
||||
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("3".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("4".into());
|
||||
#[derive(Clone)]
|
||||
struct Fetch(Vec<AuthRule>);
|
||||
|
||||
let foo_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
Ok(self.0.clone())
|
||||
}
|
||||
}
|
||||
|
||||
async fn jwks_server(
|
||||
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
|
||||
) -> SocketAddr {
|
||||
let router = Arc::new(router);
|
||||
let service = service_fn(move |req| {
|
||||
let foo_jwks = foo_jwks.clone();
|
||||
let bar_jwks = bar_jwks.clone();
|
||||
let router = Arc::clone(&router);
|
||||
async move {
|
||||
let jwks = match req.uri().path() {
|
||||
"/foo" => &foo_jwks,
|
||||
"/bar" => &bar_jwks,
|
||||
_ => {
|
||||
return Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new()));
|
||||
}
|
||||
};
|
||||
let body = serde_json::to_vec(jwks).unwrap();
|
||||
Response::builder()
|
||||
.status(200)
|
||||
.body(Full::new(Bytes::from(body)))
|
||||
match router(req.uri().path()) {
|
||||
Some(body) => Response::builder()
|
||||
.status(200)
|
||||
.body(Full::new(Bytes::from(body))),
|
||||
None => Response::builder()
|
||||
.status(404)
|
||||
.body(Full::new(Bytes::new())),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -704,84 +869,61 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
addr
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Fetch(SocketAddr, Vec<RoleNameInt>);
|
||||
#[tokio::test]
|
||||
async fn check_jwt_happy_path() {
|
||||
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
|
||||
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
|
||||
let (ec1, jwk3) = new_ec_jwk("ec1".into());
|
||||
let (ec2, jwk4) = new_ec_jwk("ec2".into());
|
||||
|
||||
impl FetchAuthRules for Fetch {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
Ok(vec![
|
||||
AuthRule {
|
||||
id: "foo".to_owned(),
|
||||
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: self.1.clone(),
|
||||
},
|
||||
AuthRule {
|
||||
id: "bar".to_owned(),
|
||||
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: self.1.clone(),
|
||||
},
|
||||
])
|
||||
}
|
||||
}
|
||||
let foo_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk1, jwk3],
|
||||
};
|
||||
let bar_jwks = jose_jwk::JwkSet {
|
||||
keys: vec![jwk2, jwk4],
|
||||
};
|
||||
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
|
||||
"/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role_name1 = RoleName::from("anonymous");
|
||||
let role_name2 = RoleName::from("authenticated");
|
||||
|
||||
let fetch = Fetch(
|
||||
addr,
|
||||
vec![
|
||||
RoleNameInt::from(&role_name1),
|
||||
RoleNameInt::from(&role_name2),
|
||||
],
|
||||
);
|
||||
let roles = vec![
|
||||
RoleNameInt::from(&role_name1),
|
||||
RoleNameInt::from(&role_name2),
|
||||
];
|
||||
let rules = vec![
|
||||
AuthRule {
|
||||
id: "foo".to_owned(),
|
||||
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: roles.clone(),
|
||||
},
|
||||
AuthRule {
|
||||
id: "bar".to_owned(),
|
||||
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: roles.clone(),
|
||||
},
|
||||
];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let endpoint = EndpointId::from("ep");
|
||||
|
||||
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
|
||||
|
||||
let jwt1 = new_rsa_jwt("1".into(), rs1);
|
||||
let jwt2 = new_rsa_jwt("2".into(), rs2);
|
||||
let jwt3 = new_ec_jwt("3".into(), &ec1);
|
||||
let jwt4 = new_ec_jwt("4".into(), &ec2);
|
||||
|
||||
// had the wrong kid, therefore will have the wrong ecdsa signature
|
||||
let bad_jwt = new_ec_jwt("3".into(), &ec2);
|
||||
// this role_name is not accepted
|
||||
let bad_role_name = RoleName::from("cloud_admin");
|
||||
|
||||
let err = jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
&bad_jwt,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
&role_name1,
|
||||
&fetch,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("signature error"));
|
||||
|
||||
let err = jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
&jwt1,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
&bad_role_name,
|
||||
&fetch,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(err.to_string().contains("jwk not found"));
|
||||
let jwt1 = new_rsa_jwt("rs1".into(), rs1);
|
||||
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
|
||||
let jwt3 = new_ec_jwt("ec1".into(), &ec1);
|
||||
let jwt4 = new_ec_jwt("ec2".into(), &ec2);
|
||||
|
||||
let tokens = [jwt1, jwt2, jwt3, jwt4];
|
||||
let role_names = [role_name1, role_name2];
|
||||
@@ -790,15 +932,194 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
|
||||
jwk_cache
|
||||
.check_jwt(
|
||||
&RequestMonitoring::test(),
|
||||
token,
|
||||
&client,
|
||||
endpoint.clone(),
|
||||
role,
|
||||
&fetch,
|
||||
token,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_invalid_signature() {
|
||||
let (_, jwk) = new_ec_jwk("1".into());
|
||||
let (key, _) = new_ec_jwk("1".into());
|
||||
|
||||
// has a matching kid, but signed by the wrong key
|
||||
let bad_jwt = new_ec_jwt("1".into(), &key);
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = jwk_cache
|
||||
.check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string().contains("signature error"),
|
||||
"expected \"signature error\", got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_unknown_role() {
|
||||
let (key, jwk) = new_rsa_jwk(RS1, "1".into());
|
||||
let jwt = new_rsa_jwt("1".into(), key);
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: None,
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
// this role_name is not accepted
|
||||
let bad_role_name = RoleName::from("cloud_admin");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
let err = jwk_cache
|
||||
.check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("jwk not found"),
|
||||
"expected \"jwk not found\", got {err:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_jwt_invalid_claims() {
|
||||
let (key, jwk) = new_ec_jwk("1".into());
|
||||
|
||||
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
|
||||
let jwks_addr = jwks_server(move |path| match path {
|
||||
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
|
||||
_ => None,
|
||||
})
|
||||
.await;
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
struct Test {
|
||||
body: serde_json::Value,
|
||||
error: &'static str,
|
||||
}
|
||||
|
||||
let table = vec![
|
||||
Test {
|
||||
body: json! {{
|
||||
"nbf": now + 60,
|
||||
"aud": "neon",
|
||||
}},
|
||||
error: "JWT token is not yet ready to use",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"exp": now - 60,
|
||||
"aud": ["neon"],
|
||||
}},
|
||||
error: "JWT token has expired",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": [],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": "foo",
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": ["foo"],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
Test {
|
||||
body: json! {{
|
||||
"aud": ["foo", "bar"],
|
||||
}},
|
||||
error: "invalid JWT token audience",
|
||||
},
|
||||
];
|
||||
|
||||
let role = RoleName::from("authenticated");
|
||||
|
||||
let rules = vec![AuthRule {
|
||||
id: String::new(),
|
||||
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
|
||||
audience: Some("neon".to_string()),
|
||||
role_names: vec![RoleNameInt::from(&role)],
|
||||
}];
|
||||
|
||||
let fetch = Fetch(rules);
|
||||
let jwk_cache = JwkCache::default();
|
||||
|
||||
let ep = EndpointId::from("ep");
|
||||
|
||||
let ctx = RequestMonitoring::test();
|
||||
for test in table {
|
||||
let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
|
||||
|
||||
match jwk_cache
|
||||
.check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
|
||||
.await
|
||||
{
|
||||
Err(err) if err.to_string().contains(test.error) => {}
|
||||
Err(err) => {
|
||||
panic!("expected {:?}, got {err:?}", test.error)
|
||||
}
|
||||
Ok(()) => {
|
||||
panic!("expected {:?}, got ok", test.error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,17 +14,15 @@ use crate::{
|
||||
EndpointId,
|
||||
};
|
||||
|
||||
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
|
||||
use super::jwt::{AuthRule, FetchAuthRules};
|
||||
|
||||
pub struct LocalBackend {
|
||||
pub(crate) jwks_cache: JwkCache,
|
||||
pub(crate) node_info: NodeInfo,
|
||||
}
|
||||
|
||||
impl LocalBackend {
|
||||
pub fn new(postgres_addr: SocketAddr) -> Self {
|
||||
LocalBackend {
|
||||
jwks_cache: JwkCache::default(),
|
||||
node_info: NodeInfo {
|
||||
config: {
|
||||
let mut cfg = ConnCfg::new();
|
||||
|
||||
@@ -6,7 +6,10 @@ use compute_api::spec::LocalProxySpec;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::Either;
|
||||
use proxy::{
|
||||
auth::backend::local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
auth::backend::{
|
||||
jwt::JwkCache,
|
||||
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
},
|
||||
cancellation::CancellationHandlerMain,
|
||||
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
||||
console::{
|
||||
@@ -267,12 +270,15 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
allow_self_signed_compute: false,
|
||||
http_config,
|
||||
authentication_config: AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool: ThreadPool::new(0),
|
||||
scram_protocol_timeout: Duration::from_secs(10),
|
||||
rate_limiter_enabled: false,
|
||||
rate_limiter: BucketRateLimiter::new(vec![]),
|
||||
rate_limit_ip_subnet: 64,
|
||||
ip_allowlist_check_enabled: true,
|
||||
is_auth_broker: false,
|
||||
accept_jwts: true,
|
||||
},
|
||||
require_client_ip: false,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
|
||||
@@ -8,6 +8,7 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
|
||||
use aws_config::Region;
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
@@ -102,6 +103,9 @@ struct ProxyCliArgs {
|
||||
default_value = "http://localhost:3000/authenticate_proxy_request/"
|
||||
)]
|
||||
auth_endpoint: String,
|
||||
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
|
||||
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
|
||||
is_auth_broker: bool,
|
||||
/// path to TLS key for client postgres connections
|
||||
///
|
||||
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
|
||||
@@ -382,9 +386,27 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Starting mgmt on {mgmt_address}");
|
||||
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
|
||||
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
let proxy_listener = TcpListener::bind(proxy_address).await?;
|
||||
let proxy_listener = if !args.is_auth_broker {
|
||||
let proxy_address: SocketAddr = args.proxy.parse()?;
|
||||
info!("Starting proxy on {proxy_address}");
|
||||
|
||||
Some(TcpListener::bind(proxy_address).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
Some(TcpListener::bind(serverless_address).await?)
|
||||
} else if args.is_auth_broker {
|
||||
bail!("wss arg must be present for auth-broker")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let cancel_map = CancelMap::default();
|
||||
@@ -430,21 +452,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
if let Some(serverless_address) = args.wss {
|
||||
let serverless_address: SocketAddr = serverless_address.parse()?;
|
||||
info!("Starting wss on {serverless_address}");
|
||||
let serverless_listener = TcpListener::bind(serverless_address).await?;
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
serverless_listener,
|
||||
@@ -674,7 +692,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
accept_websockets: true,
|
||||
accept_websockets: !args.is_auth_broker,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
@@ -689,12 +707,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
auth::{
|
||||
self,
|
||||
backend::{jwt::JwkCache, AuthRateLimiter},
|
||||
},
|
||||
console::locks::ApiLocks,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
@@ -67,6 +70,9 @@ pub struct AuthenticationConfig {
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
pub rate_limit_ip_subnet: u8,
|
||||
pub ip_allowlist_check_enabled: bool,
|
||||
pub jwks_cache: JwkCache,
|
||||
pub is_auth_broker: bool,
|
||||
pub accept_jwts: bool,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
@@ -250,18 +256,26 @@ impl CertResolver {
|
||||
|
||||
let common_name = pem.subject().to_string();
|
||||
|
||||
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
|
||||
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
|
||||
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
|
||||
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
|
||||
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
|
||||
// of cutting off '*.' parts.
|
||||
let common_name = if common_name.starts_with("CN=*.") {
|
||||
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
|
||||
// We need to get the canonical name for this certificate so we can match them against any domain names
|
||||
// seen within the proxy codebase.
|
||||
//
|
||||
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
|
||||
// We need to remove the wildcard prefix for the purposes of certificate selection.
|
||||
//
|
||||
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
|
||||
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
|
||||
//
|
||||
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
|
||||
// validation, so let's we can continue with any common-name
|
||||
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
|
||||
s.to_string()
|
||||
} else if let Some(s) = common_name.strip_prefix("CN=") {
|
||||
s.to_string()
|
||||
} else {
|
||||
common_name.strip_prefix("CN=").map(|s| s.to_string())
|
||||
}
|
||||
.context("Failed to parse common name from certificate")?;
|
||||
bail!("Failed to parse common name from certificate")
|
||||
};
|
||||
|
||||
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use std::{
|
||||
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
|
||||
any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
|
||||
sync::OnceLock,
|
||||
};
|
||||
|
||||
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
|
||||
@@ -16,12 +17,21 @@ pub struct StringInterner<Id> {
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
|
||||
#[derive(PartialEq, Clone, Copy, Eq, Hash)]
|
||||
pub struct InternedString<Id> {
|
||||
inner: Spur,
|
||||
_id: PhantomData<Id>,
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_tuple("InternedString")
|
||||
.field(&type_name::<Id>())
|
||||
.field(&self.as_str())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.as_str().fmt(f)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
mod backend;
|
||||
pub mod cancel_set;
|
||||
mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod sql_over_http;
|
||||
@@ -19,7 +20,8 @@ use anyhow::Context;
|
||||
use futures::future::{select, Either};
|
||||
use futures::TryFutureExt;
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
use hyper1::body::Incoming;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::server::conn::auto::Builder;
|
||||
@@ -81,7 +83,28 @@ pub async fn task_main(
|
||||
}
|
||||
});
|
||||
|
||||
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let http_conn_pool = Arc::clone(&http_conn_pool);
|
||||
tokio::spawn(async move {
|
||||
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
|
||||
});
|
||||
}
|
||||
|
||||
// shutdown the connection pool
|
||||
tokio::spawn({
|
||||
let cancellation_token = cancellation_token.clone();
|
||||
let http_conn_pool = http_conn_pool.clone();
|
||||
async move {
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
@@ -342,7 +365,7 @@ async fn request_handler(
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
.get("host")
|
||||
@@ -386,7 +409,7 @@ async fn request_handler(
|
||||
);
|
||||
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestMonitoring::new(
|
||||
session_id,
|
||||
@@ -409,7 +432,7 @@ async fn request_handler(
|
||||
)
|
||||
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
|
||||
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
|
||||
.body(Full::new(Bytes::new()))
|
||||
.body(Empty::new().map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))
|
||||
} else {
|
||||
json_response(StatusCode::BAD_REQUEST, "query is not supported")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use std::{io, sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
@@ -27,9 +29,13 @@ use crate::{
|
||||
Host,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
use super::{
|
||||
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||
http_conn_pool::{self, poll_http2_client},
|
||||
};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -103,32 +109,44 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: &str,
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
jwt: String,
|
||||
) -> Result<(), AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
crate::auth::Backend::Console(_, ()) => {
|
||||
Err(AuthError::auth_failed("JWT login is not yet supported"))
|
||||
crate::auth::Backend::Console(console, ()) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&**console,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
|
||||
"JWT login over web auth proxy is not supported",
|
||||
)),
|
||||
crate::auth::Backend::Local(cache) => {
|
||||
cache
|
||||
crate::auth::Backend::Local(_) => {
|
||||
config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
user_info.endpoint.clone(),
|
||||
&user_info.user,
|
||||
&StaticAuthRules,
|
||||
jwt,
|
||||
&jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
|
||||
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -174,14 +192,55 @@ impl PoolingBackend {
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// Wake up the destination if needed
|
||||
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
) -> Result<http_conn_pool::Client, HttpConnError> {
|
||||
info!("pool: looking for an existing connection");
|
||||
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|()| ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
self.config.wake_compute_retry_config,
|
||||
self.config.connect_to_compute_retry_config,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
#[error("could not connection to compute")]
|
||||
ConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to postgres in compute")]
|
||||
PostgresConnectionError(#[from] tokio_postgres::Error),
|
||||
#[error("could not connection to local-proxy in compute")]
|
||||
LocalProxyConnectionError(#[from] LocalProxyConnError),
|
||||
|
||||
#[error("could not get auth info")]
|
||||
GetAuthInfo(#[from] GetAuthInfoError),
|
||||
@@ -193,11 +252,20 @@ pub(crate) enum HttpConnError {
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum LocalProxyConnError {
|
||||
#[error("error with connection to local-proxy")]
|
||||
Io(#[source] std::io::Error),
|
||||
#[error("could not establish h2 connection")]
|
||||
H2(#[from] hyper1::Error),
|
||||
}
|
||||
|
||||
impl ReportableError for HttpConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
|
||||
HttpConnError::ConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
|
||||
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
|
||||
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
|
||||
HttpConnError::AuthError(a) => a.get_error_kind(),
|
||||
HttpConnError::WakeCompute(w) => w.get_error_kind(),
|
||||
@@ -210,7 +278,8 @@ impl UserFacingError for HttpConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
|
||||
HttpConnError::ConnectionError(p) => p.to_string(),
|
||||
HttpConnError::PostgresConnectionError(p) => p.to_string(),
|
||||
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
|
||||
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
|
||||
HttpConnError::AuthError(c) => c.to_string_client(),
|
||||
HttpConnError::WakeCompute(c) => c.to_string_client(),
|
||||
@@ -224,7 +293,8 @@ impl UserFacingError for HttpConnError {
|
||||
impl CouldRetry for HttpConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
|
||||
HttpConnError::ConnectionClosedAbruptly(_) => false,
|
||||
HttpConnError::GetAuthInfo(_) => false,
|
||||
HttpConnError::AuthError(_) => false,
|
||||
@@ -236,7 +306,7 @@ impl CouldRetry for HttpConnError {
|
||||
impl ShouldRetryWakeCompute for HttpConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
|
||||
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
|
||||
// we never checked cache validity
|
||||
HttpConnError::TooManyConnectionAttempts(_) => false,
|
||||
_ => true,
|
||||
@@ -244,6 +314,38 @@ impl ShouldRetryWakeCompute for HttpConnError {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for LocalProxyConnError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => ErrorKind::Compute,
|
||||
LocalProxyConnError::H2(_) => ErrorKind::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for LocalProxyConnError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Could not establish HTTP connection to the database".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for LocalProxyConnError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
fn should_retry_wake_compute(&self) -> bool {
|
||||
match self {
|
||||
LocalProxyConnError::Io(_) => false,
|
||||
LocalProxyConnError::H2(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
conn_info: ConnInfo,
|
||||
@@ -293,3 +395,99 @@ impl ConnectMechanism for TokioMechanism {
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
struct HyperMechanism {
|
||||
pool: Arc<http_conn_pool::GlobalConnPool>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for HyperMechanism {
|
||||
type Connection = http_conn_pool::Client;
|
||||
type ConnectError = HttpConnError;
|
||||
type Error = HttpConnError;
|
||||
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
node_info: &CachedNodeInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
let host = node_info.config.get_host()?;
|
||||
let permit = self.locks.get_permit(&host).await?;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
|
||||
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
|
||||
let res = connect_http2(&host, 10432, timeout).await;
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
Ok(poll_http2_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
|
||||
}
|
||||
|
||||
async fn connect_http2(
|
||||
host: &str,
|
||||
port: u16,
|
||||
timeout: Duration,
|
||||
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
|
||||
// assumption: host is an ip address so this should not actually perform any requests.
|
||||
// todo: add that assumption as a guarantee in the control-plane API.
|
||||
let mut addrs = lookup_host((host, port))
|
||||
.await
|
||||
.map_err(LocalProxyConnError::Io)?;
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
let stream = loop {
|
||||
let Some(addr) = addrs.next() else {
|
||||
return Err(last_err.unwrap_or_else(|| {
|
||||
LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"could not resolve any addresses",
|
||||
))
|
||||
}));
|
||||
};
|
||||
|
||||
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
|
||||
Ok(Ok(stream)) => {
|
||||
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
|
||||
break stream;
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
last_err = Some(LocalProxyConnError::Io(e));
|
||||
}
|
||||
Err(e) => {
|
||||
last_err = Some(LocalProxyConnError::Io(io::Error::new(
|
||||
io::ErrorKind::TimedOut,
|
||||
e,
|
||||
)));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(Duration::from_secs(20))
|
||||
.keep_alive_while_idle(true)
|
||||
.keep_alive_timeout(Duration::from_secs(5))
|
||||
.handshake(TokioIo::new(stream))
|
||||
.await?;
|
||||
|
||||
Ok((client, connection))
|
||||
}
|
||||
|
||||
335
proxy/src/serverless/http_conn_pool.rs
Normal file
335
proxy/src/serverless/http_conn_pool.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
use dashmap::DashMap;
|
||||
use hyper1::client::conn::http2;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo};
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::atomic::{self, AtomicUsize};
|
||||
use std::{sync::Arc, sync::Weak};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, EndpointCacheKey};
|
||||
|
||||
use tracing::{debug, error};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
|
||||
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
|
||||
pub(crate) type Connect =
|
||||
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConnPoolEntry {
|
||||
conn: Send,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
// Per-endpoint connection pool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool {
|
||||
conns: VecDeque<ConnPoolEntry>,
|
||||
_guard: HttpEndpointPoolsGuard<'static>,
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl EndpointConnPool {
|
||||
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
|
||||
let Self { conns, .. } = self;
|
||||
|
||||
let conn = conns.pop_front()?;
|
||||
conns.push_back(conn.clone());
|
||||
Some(conn)
|
||||
}
|
||||
|
||||
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
conns,
|
||||
global_connections_count,
|
||||
..
|
||||
} = self;
|
||||
|
||||
let old_len = conns.len();
|
||||
conns.retain(|conn| conn.conn_id != conn_id);
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
removed > 0
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for EndpointConnPool {
|
||||
fn drop(&mut self) {
|
||||
if !self.conns.is_empty() {
|
||||
self.global_connections_count
|
||||
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.conns.len() as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct GlobalConnPool {
|
||||
// endpoint -> per-endpoint connection pool
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// Number of endpoint-connection pools
|
||||
///
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
/// It's only used for diagnostics.
|
||||
global_pool_size: AtomicUsize,
|
||||
|
||||
/// Total number of connections in the pool
|
||||
global_connections_count: Arc<AtomicUsize>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
impl GlobalConnPool {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
let shards = config.pool_options.pool_shards;
|
||||
Arc::new(Self {
|
||||
global_pool: DashMap::with_shard_amount(shards),
|
||||
global_pool_size: AtomicUsize::new(0),
|
||||
config,
|
||||
global_connections_count: Arc::new(AtomicUsize::new(0)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn shutdown(&self) {
|
||||
// drops all strong references to endpoint-pools
|
||||
self.global_pool.clear();
|
||||
}
|
||||
|
||||
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
|
||||
let epoch = self.config.pool_options.gc_epoch;
|
||||
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
let shard = rng.gen_range(0..self.global_pool.shards().len());
|
||||
self.gc(shard);
|
||||
}
|
||||
}
|
||||
|
||||
fn gc(&self, shard: usize) {
|
||||
debug!(shard, "pool: performing epoch reclamation");
|
||||
|
||||
// acquire a random shard lock
|
||||
let mut shard = self.global_pool.shards()[shard].write();
|
||||
|
||||
let timer = Metrics::get()
|
||||
.proxy
|
||||
.http_pool_reclaimation_lag_seconds
|
||||
.start_timer();
|
||||
let current_len = shard.len();
|
||||
let mut clients_removed = 0;
|
||||
shard.retain(|endpoint, x| {
|
||||
// if the current endpoint pool is unique (no other strong or weak references)
|
||||
// then it is currently not in use by any connections.
|
||||
if let Some(pool) = Arc::get_mut(x.get_mut()) {
|
||||
let EndpointConnPool { conns, .. } = pool.get_mut();
|
||||
|
||||
let old_len = conns.len();
|
||||
|
||||
conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
clients_removed += removed;
|
||||
|
||||
// we only remove this pool if it has no active connections
|
||||
if conns.is_empty() {
|
||||
info!("pool: discarding pool for endpoint {endpoint}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
});
|
||||
|
||||
let new_len = shard.len();
|
||||
drop(shard);
|
||||
timer.observe();
|
||||
|
||||
// Do logging outside of the lock.
|
||||
if clients_removed > 0 {
|
||||
let size = self
|
||||
.global_connections_count
|
||||
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
|
||||
- clients_removed;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(clients_removed as i64);
|
||||
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
|
||||
}
|
||||
let removed = current_len - new_len;
|
||||
|
||||
if removed > 0 {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_sub(removed, atomic::Ordering::Relaxed)
|
||||
- removed;
|
||||
info!("pool: performed global pool gc. size now {global_pool_size}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Option<Client> {
|
||||
let endpoint = conn_info.endpoint_cache_key()?;
|
||||
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
|
||||
let client = endpoint_pool.write().get_conn_entry()?;
|
||||
|
||||
if client.conn.is_closed() {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return None;
|
||||
}
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
Some(Client::new(client.conn, client.aux))
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(
|
||||
self: &Arc<Self>,
|
||||
endpoint: &EndpointCacheKey,
|
||||
) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
}
|
||||
|
||||
// slow path
|
||||
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
|
||||
conns: VecDeque::new(),
|
||||
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
|
||||
global_connections_count: self.global_connections_count.clone(),
|
||||
}));
|
||||
|
||||
// find or create a pool for this endpoint
|
||||
let mut created = false;
|
||||
let pool = self
|
||||
.global_pool
|
||||
.entry(endpoint.clone())
|
||||
.or_insert_with(|| {
|
||||
created = true;
|
||||
new_pool
|
||||
})
|
||||
.clone();
|
||||
|
||||
// log new global pool size
|
||||
if created {
|
||||
let global_pool_size = self
|
||||
.global_pool_size
|
||||
.fetch_add(1, atomic::Ordering::Relaxed)
|
||||
+ 1;
|
||||
info!(
|
||||
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
|
||||
);
|
||||
}
|
||||
|
||||
pool
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_http2_client(
|
||||
global_pool: Arc<GlobalConnPool>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
|
||||
let pool = match conn_info.endpoint_cache_key() {
|
||||
Some(endpoint) => {
|
||||
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
|
||||
|
||||
pool.write().conns.push_back(ConnPoolEntry {
|
||||
conn: client.clone(),
|
||||
conn_id,
|
||||
aux: aux.clone(),
|
||||
});
|
||||
|
||||
Arc::downgrade(&pool)
|
||||
}
|
||||
None => Weak::new(),
|
||||
};
|
||||
|
||||
// let idle = global_pool.get_idle_timeout();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let res = connection.await;
|
||||
match res {
|
||||
Ok(()) => info!("connection closed"),
|
||||
Err(e) => error!(%session_id, "connection error: {}", e),
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.write().remove_conn(conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
Client::new(client, aux)
|
||||
}
|
||||
|
||||
pub(crate) struct Client {
|
||||
pub(crate) inner: Send,
|
||||
aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
|
||||
Self { inner, aux }
|
||||
}
|
||||
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: self.aux.endpoint_id,
|
||||
branch_id: self.aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -5,13 +5,13 @@ use bytes::Bytes;
|
||||
|
||||
use anyhow::Context;
|
||||
use http::{Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Full};
|
||||
|
||||
use serde::Serialize;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
/// Like [`ApiError::into_response`]
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
|
||||
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
match this {
|
||||
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
format!("{err:#?}"), // use debug printing so that we give the cause
|
||||
@@ -64,17 +64,24 @@ struct HttpErrorBody {
|
||||
|
||||
impl HttpErrorBody {
|
||||
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
|
||||
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn response_from_msg_and_status(
|
||||
msg: String,
|
||||
status: StatusCode,
|
||||
) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
HttpErrorBody { msg }.to_response(status)
|
||||
}
|
||||
|
||||
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
|
||||
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
|
||||
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
|
||||
Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
|
||||
.body(
|
||||
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@@ -83,14 +90,14 @@ impl HttpErrorBody {
|
||||
pub(crate) fn json_response<T: Serialize>(
|
||||
status: StatusCode,
|
||||
data: T,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let json = serde_json::to_string(&data)
|
||||
.context("Failed to serialize JSON response")
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let response = Response::builder()
|
||||
.status(status)
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
.body(Full::new(Bytes::from(json)))
|
||||
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ use futures::future::Either;
|
||||
use futures::StreamExt;
|
||||
use futures::TryFutureExt;
|
||||
use http::header::AUTHORIZATION;
|
||||
use http::Method;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::Full;
|
||||
use hyper1::body::Body;
|
||||
@@ -38,9 +40,11 @@ use url::Url;
|
||||
use urlencoding;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -56,6 +60,7 @@ use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::LocalProxyConnError;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
@@ -123,8 +128,8 @@ pub(crate) enum ConnInfoError {
|
||||
MissingUsername,
|
||||
#[error("invalid username: {0}")]
|
||||
InvalidUsername(#[from] std::string::FromUtf8Error),
|
||||
#[error("missing password")]
|
||||
MissingPassword,
|
||||
#[error("missing authentication credentials: {0}")]
|
||||
MissingCredentials(Credentials),
|
||||
#[error("missing hostname")]
|
||||
MissingHostname,
|
||||
#[error("invalid hostname: {0}")]
|
||||
@@ -133,6 +138,14 @@ pub(crate) enum ConnInfoError {
|
||||
MalformedEndpoint,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum Credentials {
|
||||
#[error("required password")]
|
||||
Password,
|
||||
#[error("required authorization bearer token in JWT format")]
|
||||
BearerJwt,
|
||||
}
|
||||
|
||||
impl ReportableError for ConnInfoError {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
ErrorKind::User
|
||||
@@ -146,6 +159,7 @@ impl UserFacingError for ConnInfoError {
|
||||
}
|
||||
|
||||
fn get_conn_info(
|
||||
config: &'static AuthenticationConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
headers: &HeaderMap,
|
||||
tls: Option<&TlsConfig>,
|
||||
@@ -181,21 +195,32 @@ fn get_conn_info(
|
||||
ctx.set_user(username.clone());
|
||||
|
||||
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
|
||||
if !config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
}
|
||||
|
||||
let auth = auth
|
||||
.to_str()
|
||||
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
|
||||
AuthData::Jwt(
|
||||
auth.strip_prefix("Bearer ")
|
||||
.ok_or(ConnInfoError::MissingPassword)?
|
||||
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
|
||||
.into(),
|
||||
)
|
||||
} else if let Some(pass) = connection_url.password() {
|
||||
// wrong credentials provided
|
||||
if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
}
|
||||
|
||||
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
})
|
||||
} else if config.accept_jwts {
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
|
||||
} else {
|
||||
return Err(ConnInfoError::MissingPassword);
|
||||
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
|
||||
};
|
||||
|
||||
let endpoint = match connection_url.host() {
|
||||
@@ -247,7 +272,7 @@ pub(crate) async fn handle(
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<Response<Full<Bytes>>, ApiError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
|
||||
let result = handle_inner(cancel, config, &ctx, request, backend).await;
|
||||
|
||||
let mut response = match result {
|
||||
@@ -279,7 +304,7 @@ pub(crate) async fn handle(
|
||||
|
||||
let mut message = e.to_string_client();
|
||||
let db_error = match &e {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
@@ -504,7 +529,7 @@ async fn handle_inner(
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
let _requeset_gauge = Metrics::get()
|
||||
.proxy
|
||||
.connection_requests
|
||||
@@ -514,18 +539,50 @@ async fn handle_inner(
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// TLS config should be there.
|
||||
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
|
||||
let conn_info = get_conn_info(
|
||||
&config.authentication_config,
|
||||
ctx,
|
||||
request.headers(),
|
||||
config.tls_config.as_ref(),
|
||||
)?;
|
||||
info!(
|
||||
user = conn_info.conn_info.user_info.user.as_str(),
|
||||
"credentials"
|
||||
);
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
cancel,
|
||||
config,
|
||||
ctx,
|
||||
request,
|
||||
conn_info.conn_info,
|
||||
auth,
|
||||
backend,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_db_inner(
|
||||
cancel: CancellationToken,
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
auth: AuthData,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
let headers = request.headers();
|
||||
|
||||
// Allow connection pooling only if explicitly requested
|
||||
// or if we have decided that http pool is no longer opt-in
|
||||
let allow_pool = !config.http_config.pool_options.opt_in
|
||||
@@ -563,26 +620,36 @@ async fn handle_inner(
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let keys = match &conn_info.auth {
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.conn_info.user_info,
|
||||
pw,
|
||||
&conn_info.user_info,
|
||||
&pw,
|
||||
)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
|
||||
.await?
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?;
|
||||
|
||||
ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
@@ -640,7 +707,11 @@ async fn handle_inner(
|
||||
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
.body(
|
||||
Full::new(Bytes::from(json_output))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -656,6 +727,65 @@ async fn handle_inner(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
&AUTHORIZATION,
|
||||
&CONN_STRING,
|
||||
&RAW_TEXT_OUTPUT,
|
||||
&ARRAY_MODE,
|
||||
&TXN_ISOLATION_LEVEL,
|
||||
&TXN_READ_ONLY,
|
||||
&TXN_DEFERRABLE,
|
||||
];
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
jwt: String,
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await
|
||||
.map_err(HttpConnError::from)?;
|
||||
|
||||
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
|
||||
|
||||
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
|
||||
|
||||
let (mut parts, body) = request.into_parts();
|
||||
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
|
||||
|
||||
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
|
||||
// these instead just to ensure they remain normalised.
|
||||
for &h in HEADERS_TO_FORWARD {
|
||||
if let Some(hv) = parts.headers.remove(h) {
|
||||
req = req.header(h, hv);
|
||||
}
|
||||
}
|
||||
|
||||
let req = req
|
||||
.body(body)
|
||||
.expect("all headers and params received via hyper should be valid for request");
|
||||
|
||||
// todo: map body to count egress
|
||||
let _metrics = client.metrics();
|
||||
|
||||
Ok(client
|
||||
.inner
|
||||
.send_request(req)
|
||||
.await
|
||||
.map_err(LocalProxyConnError::from)
|
||||
.map_err(HttpConnError::from)?
|
||||
.map(|b| b.boxed()))
|
||||
}
|
||||
|
||||
impl QueryData {
|
||||
async fn process(
|
||||
self,
|
||||
@@ -705,7 +835,9 @@ impl QueryData {
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
let db_error = match &error {
|
||||
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
|
||||
SqlOverHttpError::ConnectCompute(
|
||||
HttpConnError::PostgresConnectionError(e),
|
||||
)
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
@@ -9,7 +9,7 @@ use crate::walproposer_sim::{
|
||||
|
||||
pub mod walproposer_sim;
|
||||
|
||||
// Generates 500 random seeds and runs a schedule for each of them.
|
||||
// Generates 2000 random seeds and runs a schedule for each of them.
|
||||
// If you see this test fail, please report the last seed to the
|
||||
// @safekeeper team.
|
||||
#[test]
|
||||
@@ -17,7 +17,7 @@ fn test_random_schedules() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
|
||||
for _ in 0..500 {
|
||||
for _ in 0..2000 {
|
||||
let seed: u64 = rand::thread_rng().gen();
|
||||
config.network = generate_network_opts(seed);
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ from _pytest.fixtures import FixtureRequest
|
||||
from psycopg2.extensions import connection as PgConnection
|
||||
from psycopg2.extensions import cursor as PgCursor
|
||||
from psycopg2.extensions import make_dsn, parse_dsn
|
||||
from pytest_httpserver import HTTPServer
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from fixtures import overlayfs
|
||||
@@ -401,7 +402,6 @@ class NeonEnvBuilder:
|
||||
safekeeper_extra_opts: Optional[list[str]] = None,
|
||||
storage_controller_port_override: Optional[int] = None,
|
||||
pageserver_io_buffer_alignment: Optional[int] = None,
|
||||
pageserver_virtual_file_io_mode: Optional[str] = None,
|
||||
):
|
||||
self.repo_dir = repo_dir
|
||||
self.rust_log_override = rust_log_override
|
||||
@@ -441,9 +441,9 @@ class NeonEnvBuilder:
|
||||
|
||||
self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine
|
||||
|
||||
self.pageserver_default_tenant_config_compaction_algorithm: Optional[
|
||||
Dict[str, Any]
|
||||
] = pageserver_default_tenant_config_compaction_algorithm
|
||||
self.pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]] = (
|
||||
pageserver_default_tenant_config_compaction_algorithm
|
||||
)
|
||||
if self.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||
log.debug(
|
||||
f"Overriding pageserver default compaction algorithm to {self.pageserver_default_tenant_config_compaction_algorithm}"
|
||||
@@ -456,7 +456,6 @@ class NeonEnvBuilder:
|
||||
self.storage_controller_port_override = storage_controller_port_override
|
||||
|
||||
self.pageserver_io_buffer_alignment = pageserver_io_buffer_alignment
|
||||
self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode
|
||||
|
||||
assert test_name.startswith(
|
||||
"test_"
|
||||
@@ -1030,7 +1029,6 @@ class NeonEnv:
|
||||
self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine
|
||||
self.pageserver_aux_file_policy = config.pageserver_aux_file_policy
|
||||
self.pageserver_io_buffer_alignment = config.pageserver_io_buffer_alignment
|
||||
self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode
|
||||
|
||||
# Create the neon_local's `NeonLocalInitConf`
|
||||
cfg: Dict[str, Any] = {
|
||||
@@ -1075,9 +1073,9 @@ class NeonEnv:
|
||||
ps_cfg["virtual_file_io_engine"] = self.pageserver_virtual_file_io_engine
|
||||
if config.pageserver_default_tenant_config_compaction_algorithm is not None:
|
||||
tenant_config = ps_cfg.setdefault("tenant_config", {})
|
||||
tenant_config[
|
||||
"compaction_algorithm"
|
||||
] = config.pageserver_default_tenant_config_compaction_algorithm
|
||||
tenant_config["compaction_algorithm"] = (
|
||||
config.pageserver_default_tenant_config_compaction_algorithm
|
||||
)
|
||||
|
||||
if self.pageserver_remote_storage is not None:
|
||||
ps_cfg["remote_storage"] = remote_storage_to_toml_dict(
|
||||
@@ -1094,10 +1092,7 @@ class NeonEnv:
|
||||
for key, value in override.items():
|
||||
ps_cfg[key] = value
|
||||
|
||||
if self.pageserver_io_buffer_alignment is not None:
|
||||
ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment
|
||||
if self.pageserver_virtual_file_io_mode is not None:
|
||||
ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode
|
||||
ps_cfg["io_buffer_alignment"] = self.pageserver_io_buffer_alignment
|
||||
|
||||
# Create a corresponding NeonPageserver object
|
||||
self.pageservers.append(
|
||||
@@ -1123,9 +1118,9 @@ class NeonEnv:
|
||||
if config.auth_enabled:
|
||||
sk_cfg["auth_enabled"] = True
|
||||
if self.safekeepers_remote_storage is not None:
|
||||
sk_cfg[
|
||||
"remote_storage"
|
||||
] = self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
||||
sk_cfg["remote_storage"] = (
|
||||
self.safekeepers_remote_storage.to_toml_inline_table().strip()
|
||||
)
|
||||
self.safekeepers.append(
|
||||
Safekeeper(env=self, id=id, port=port, extra_opts=config.safekeeper_extra_opts)
|
||||
)
|
||||
@@ -1336,7 +1331,6 @@ def neon_simple_env(
|
||||
pageserver_aux_file_policy: Optional[AuxFileStore],
|
||||
pageserver_default_tenant_config_compaction_algorithm: Optional[Dict[str, Any]],
|
||||
pageserver_io_buffer_alignment: Optional[int],
|
||||
pageserver_virtual_file_io_mode: Optional[str],
|
||||
) -> Iterator[NeonEnv]:
|
||||
"""
|
||||
Simple Neon environment, with no authentication and no safekeepers.
|
||||
@@ -1363,7 +1357,6 @@ def neon_simple_env(
|
||||
pageserver_aux_file_policy=pageserver_aux_file_policy,
|
||||
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
|
||||
pageserver_io_buffer_alignment=pageserver_io_buffer_alignment,
|
||||
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
|
||||
) as builder:
|
||||
env = builder.init_start()
|
||||
|
||||
@@ -1388,7 +1381,6 @@ def neon_env_builder(
|
||||
pageserver_aux_file_policy: Optional[AuxFileStore],
|
||||
record_property: Callable[[str, object], None],
|
||||
pageserver_io_buffer_alignment: Optional[int],
|
||||
pageserver_virtual_file_io_mode: Optional[str],
|
||||
) -> Iterator[NeonEnvBuilder]:
|
||||
"""
|
||||
Fixture to create a Neon environment for test.
|
||||
@@ -1424,7 +1416,6 @@ def neon_env_builder(
|
||||
pageserver_aux_file_policy=pageserver_aux_file_policy,
|
||||
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
|
||||
pageserver_io_buffer_alignment=pageserver_io_buffer_alignment,
|
||||
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
|
||||
) as builder:
|
||||
yield builder
|
||||
# Propogate `preserve_database_files` to make it possible to use in other fixtures,
|
||||
@@ -3321,7 +3312,7 @@ class VanillaPostgres(PgProtocol):
|
||||
self.pg_bin = pg_bin
|
||||
self.running = False
|
||||
if init:
|
||||
self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)])
|
||||
self.pg_bin.run_capture(["initdb", "-D", str(pgdatadir)])
|
||||
self.configure([f"port = {port}\n"])
|
||||
|
||||
def enable_tls(self):
|
||||
@@ -3582,6 +3573,20 @@ class NeonProxy(PgProtocol):
|
||||
]
|
||||
return args
|
||||
|
||||
class AuthBroker(AuthBackend):
|
||||
def __init__(self, endpoint: str):
|
||||
self.endpoint = endpoint
|
||||
|
||||
def extra_args(self) -> list[str]:
|
||||
args = [
|
||||
# Console auth backend params
|
||||
*["--auth-backend", "console"],
|
||||
*["--auth-endpoint", self.endpoint],
|
||||
*["--sql-over-http-pool-opt-in", "false"],
|
||||
*["--is-auth-broker"],
|
||||
]
|
||||
return args
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Postgres(AuthBackend):
|
||||
pg_conn_url: str
|
||||
@@ -3610,7 +3615,7 @@ class NeonProxy(PgProtocol):
|
||||
metric_collection_interval: Optional[str] = None,
|
||||
):
|
||||
host = "127.0.0.1"
|
||||
domain = "proxy.localtest.me" # resolves to 127.0.0.1
|
||||
domain = "ep-foo-bar-1234.localtest.me" # resolves to 127.0.0.1
|
||||
super().__init__(dsn=auth_backend.default_conn_url, host=domain, port=proxy_port)
|
||||
|
||||
self.domain = domain
|
||||
@@ -3896,6 +3901,50 @@ def static_proxy(
|
||||
yield proxy
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def static_auth_broker(
|
||||
vanilla_pg: VanillaPostgres,
|
||||
port_distributor: PortDistributor,
|
||||
neon_binpath: Path,
|
||||
test_output_dir: Path,
|
||||
httpserver: HTTPServer,
|
||||
) -> Iterator[NeonProxy]:
|
||||
"""Neon proxy that routes directly to vanilla postgres."""
|
||||
|
||||
auth_endpoint = httpserver.url_for("/cplane")
|
||||
|
||||
port = vanilla_pg.default_options["port"]
|
||||
host = vanilla_pg.default_options["host"]
|
||||
|
||||
httpserver.expect_request("/cplane/proxy_wake_compute").respond_with_json(
|
||||
{
|
||||
"address": f"{host}:{port}",
|
||||
"aux": {
|
||||
"endpoint_id": "ep-foo-bar-1234",
|
||||
"branch_id": "br-foo-bar",
|
||||
"project_id": "foo-bar",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
proxy_port = port_distributor.get_port()
|
||||
mgmt_port = port_distributor.get_port()
|
||||
http_port = port_distributor.get_port()
|
||||
external_http_port = port_distributor.get_port()
|
||||
|
||||
with NeonProxy(
|
||||
neon_binpath=neon_binpath,
|
||||
test_output_dir=test_output_dir,
|
||||
proxy_port=proxy_port,
|
||||
http_port=http_port,
|
||||
mgmt_port=mgmt_port,
|
||||
external_http_port=external_http_port,
|
||||
auth_backend=NeonProxy.AuthBroker(auth_endpoint),
|
||||
) as proxy:
|
||||
proxy.start()
|
||||
yield proxy
|
||||
|
||||
|
||||
class Endpoint(PgProtocol, LogUtils):
|
||||
"""An object representing a Postgres compute endpoint managed by the control plane."""
|
||||
|
||||
|
||||
@@ -39,11 +39,6 @@ def pageserver_io_buffer_alignment() -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def pageserver_virtual_file_io_mode() -> Optional[str]:
|
||||
return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_MODE")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def pageserver_aux_file_policy() -> Optional[AuxFileStore]:
|
||||
return None
|
||||
|
||||
71
test_runner/regress/test_auth_broker.py
Normal file
71
test_runner/regress/test_auth_broker.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import psycopg2
|
||||
import pytest
|
||||
import requests
|
||||
from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres
|
||||
|
||||
GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'"
|
||||
|
||||
|
||||
def test_sql_over_http(static_auth_broker: NeonProxy):
|
||||
static_auth_broker.safe_psql("create role http with login password 'http' superuser")
|
||||
|
||||
def q(sql: str, params: Optional[List[Any]] = None) -> Any:
|
||||
params = params or []
|
||||
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||
data=json.dumps({"query": sql, "params": params}),
|
||||
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
|
||||
verify=str(static_proxy.test_output_dir / "proxy.crt"),
|
||||
)
|
||||
assert response.status_code == 200, response.text
|
||||
return response.json()
|
||||
|
||||
rows = q("select 42 as answer")["rows"]
|
||||
assert rows == [{"answer": 42}]
|
||||
|
||||
rows = q("select $1 as answer", [42])["rows"]
|
||||
assert rows == [{"answer": "42"}]
|
||||
|
||||
rows = q("select $1 * 1 as answer", [42])["rows"]
|
||||
assert rows == [{"answer": 42}]
|
||||
|
||||
rows = q("select $1::int[] as answer", [[1, 2, 3]])["rows"]
|
||||
assert rows == [{"answer": [1, 2, 3]}]
|
||||
|
||||
rows = q("select $1::json->'a' as answer", [{"a": {"b": 42}}])["rows"]
|
||||
assert rows == [{"answer": {"b": 42}}]
|
||||
|
||||
rows = q("select $1::jsonb[] as answer", [[{}]])["rows"]
|
||||
assert rows == [{"answer": [{}]}]
|
||||
|
||||
rows = q("select $1::jsonb[] as answer", [[{"foo": 1}, {"bar": 2}]])["rows"]
|
||||
assert rows == [{"answer": [{"foo": 1}, {"bar": 2}]}]
|
||||
|
||||
rows = q("select * from pg_class limit 1")["rows"]
|
||||
assert len(rows) == 1
|
||||
|
||||
res = q("create table t(id serial primary key, val int)")
|
||||
assert res["command"] == "CREATE"
|
||||
assert res["rowCount"] is None
|
||||
|
||||
res = q("insert into t(val) values (10), (20), (30) returning id")
|
||||
assert res["command"] == "INSERT"
|
||||
assert res["rowCount"] == 3
|
||||
assert res["rows"] == [{"id": 1}, {"id": 2}, {"id": 3}]
|
||||
|
||||
res = q("select * from t")
|
||||
assert res["command"] == "SELECT"
|
||||
assert res["rowCount"] == 3
|
||||
|
||||
res = q("drop table t")
|
||||
assert res["command"] == "DROP"
|
||||
assert res["rowCount"] is None
|
||||
|
||||
@@ -27,7 +27,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*basebackup .* failed: invalid basebackup lsn.*",
|
||||
".*/lsn_lease.*invalid lsn lease request.*",
|
||||
".*page_service.*handle_make_lsn_lease.*.*tried to request a page version that was garbage collected",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -108,7 +108,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
|
||||
assert cur.fetchone() == (1,)
|
||||
|
||||
# Create node at pre-initdb lsn
|
||||
with pytest.raises(Exception, match="invalid lsn lease request"):
|
||||
with pytest.raises(Exception, match="invalid basebackup lsn"):
|
||||
# compute node startup with invalid LSN should fail
|
||||
env.endpoints.create_start(
|
||||
branch_name="main",
|
||||
@@ -167,23 +167,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
|
||||
)
|
||||
return last_flush_lsn
|
||||
|
||||
def trigger_gc_and_select(env: NeonEnv, ep_static: Endpoint):
|
||||
"""
|
||||
Trigger GC manually on all pageservers. Then run an `SELECT` query.
|
||||
"""
|
||||
for shard, ps in tenant_get_shards(env, env.initial_tenant):
|
||||
client = ps.http_client()
|
||||
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
|
||||
log.info(f"{gc_result=}")
|
||||
|
||||
assert (
|
||||
gc_result["layers_removed"] == 0
|
||||
), "No layers should be removed, old layers are guarded by leases."
|
||||
|
||||
with ep_static.cursor() as cur:
|
||||
cur.execute("SELECT count(*) FROM t0")
|
||||
assert cur.fetchone() == (ROW_COUNT,)
|
||||
|
||||
# Insert some records on main branch
|
||||
with env.endpoints.create_start("main") as ep_main:
|
||||
with ep_main.cursor() as cur:
|
||||
@@ -210,31 +193,25 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
|
||||
|
||||
generate_updates_on_main(env, ep_main, i, end=100)
|
||||
|
||||
trigger_gc_and_select(env, ep_static)
|
||||
# Trigger GC
|
||||
for shard, ps in tenant_get_shards(env, env.initial_tenant):
|
||||
client = ps.http_client()
|
||||
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
|
||||
log.info(f"{gc_result=}")
|
||||
|
||||
# Trigger Pageserver restarts
|
||||
for ps in env.pageservers:
|
||||
ps.stop()
|
||||
# Static compute should have at least one lease request failure due to connection.
|
||||
time.sleep(LSN_LEASE_LENGTH / 2)
|
||||
ps.start()
|
||||
assert (
|
||||
gc_result["layers_removed"] == 0
|
||||
), "No layers should be removed, old layers are guarded by leases."
|
||||
|
||||
trigger_gc_and_select(env, ep_static)
|
||||
|
||||
# Reconfigure pageservers
|
||||
env.pageservers[0].stop()
|
||||
env.storage_controller.node_configure(
|
||||
env.pageservers[0].id, {"availability": "Offline"}
|
||||
)
|
||||
env.storage_controller.reconcile_until_idle()
|
||||
|
||||
trigger_gc_and_select(env, ep_static)
|
||||
with ep_static.cursor() as cur:
|
||||
cur.execute("SELECT count(*) FROM t0")
|
||||
assert cur.fetchone() == (ROW_COUNT,)
|
||||
|
||||
# Do some update so we can increment latest_gc_cutoff
|
||||
generate_updates_on_main(env, ep_main, i, end=100)
|
||||
|
||||
# Wait for the existing lease to expire.
|
||||
time.sleep(LSN_LEASE_LENGTH + 1)
|
||||
time.sleep(LSN_LEASE_LENGTH)
|
||||
# Now trigger GC again, layers should be removed.
|
||||
for shard, ps in tenant_get_shards(env, env.initial_tenant):
|
||||
client = ps.http_client()
|
||||
|
||||
@@ -45,7 +45,10 @@ def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool
|
||||
tenant_after = http.tenant_status(env.initial_tenant)
|
||||
assert tenant_before != tenant_after
|
||||
gc_blocking = tenant_after["gc_blocking"]
|
||||
assert gc_blocking == "BlockingReasons { timelines: 1, reasons: EnumSet(Manual) }"
|
||||
assert (
|
||||
gc_blocking
|
||||
== "BlockingReasons { tenant_blocked_by_lsn_lease_deadline: false, timelines: 1, reasons: EnumSet(Manual) }"
|
||||
)
|
||||
|
||||
wait_for_another_gc_round()
|
||||
pss.assert_log_contains(gc_skipped_line)
|
||||
|
||||
@@ -772,7 +772,7 @@ class ProposerPostgres(PgProtocol):
|
||||
def initdb(self):
|
||||
"""Run initdb"""
|
||||
|
||||
args = ["initdb", "--username", "cloud_admin", "--pgdata", self.pg_data_dir_path()]
|
||||
args = ["initdb", "-U", "cloud_admin", "-D", self.pg_data_dir_path()]
|
||||
self.pg_bin.run(args)
|
||||
|
||||
def start(self):
|
||||
|
||||
Reference in New Issue
Block a user