Compare commits

..

4 Commits

Author SHA1 Message Date
John Spray
b58e9ef05b tests: add test_image_layer_reads 2024-09-27 18:09:57 +01:00
John Spray
c2c9530ab7 hack: log layer accesses 2024-09-27 16:49:59 +01:00
John Spray
c62f45fff4 hack: always repartition 2024-09-27 16:49:42 +01:00
John Spray
6a9d51b41f pageserver: unit test for case of LayerMap::search at same LSN as image layer 2024-09-27 13:39:03 +01:00
103 changed files with 2248 additions and 4393 deletions

View File

@@ -3,23 +3,19 @@ name: Prepare benchmarking databases by restoring dumps
on:
workflow_call:
# no inputs needed
defaults:
run:
shell: bash -euxo pipefail {0}
jobs:
setup-databases:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
strategy:
fail-fast: false
matrix:
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
platform: [ aws-rds-postgres, aws-aurora-serverless-v2-postgres, neon ]
database: [ clickbench, tpch, userexample ]
env:
LD_LIBRARY_PATH: /tmp/neon/pg_install/v16/lib
PLATFORM: ${{ matrix.platform }}
@@ -27,10 +23,7 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
@@ -39,13 +32,13 @@ jobs:
run: |
case "${PLATFORM}" in
neon)
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_CAPTEST_CONNSTR }}
;;
aws-rds-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_POSTGRES_CONNSTR }}
;;
aws-aurora-serverless-v2-postgres)
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
CONNSTR=${{ secrets.BENCHMARK_RDS_AURORA_CONNSTR }}
;;
*)
echo >&2 "Unknown PLATFORM=${PLATFORM}"
@@ -53,17 +46,10 @@ jobs:
;;
esac
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -71,23 +57,23 @@ jobs:
path: /tmp/neon/
prefix: latest
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
- name: Create benchmark_restore_status table if it does not exist
env:
BENCHMARK_CONNSTR: ${{ steps.set-up-prep-connstr.outputs.connstr }}
DATABASE_NAME: ${{ matrix.database }}
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# to avoid a race condition of multiple jobs trying to create the table at the same time,
# we use an advisory lock
run: |
${PG_BINARIES}/psql "${{ env.BENCHMARK_CONNSTR }}" -c "
SELECT pg_advisory_lock(4711);
SELECT pg_advisory_lock(4711);
CREATE TABLE IF NOT EXISTS benchmark_restore_status (
databasename text primary key,
restore_done boolean
);
SELECT pg_advisory_unlock(4711);
"
- name: Check if restore is already done
id: check-restore-done
env:
@@ -121,7 +107,7 @@ jobs:
DATABASE_NAME: ${{ matrix.database }}
run: |
mkdir -p /tmp/dumps
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
aws s3 cp s3://neon-github-dev/performance/pgdumps/$DATABASE_NAME/$DATABASE_NAME.pg_dump /tmp/dumps/
- name: Replace database name in connection string
if: steps.check-restore-done.outputs.skip != 'true'
@@ -140,17 +126,17 @@ jobs:
else
new_connstr="${base_connstr}/${DATABASE_NAME}"
fi
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
echo "database_connstr=${new_connstr}" >> $GITHUB_OUTPUT
- name: Restore dump
if: steps.check-restore-done.outputs.skip != 'true'
env:
DATABASE_NAME: ${{ matrix.database }}
DATABASE_CONNSTR: ${{ steps.replace-dbname.outputs.database_connstr }}
# the following works only with larger computes:
# the following works only with larger computes:
# PGOPTIONS: "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7"
# we add the || true because:
# the dumps were created with Neon and contain neon extensions that are not
# the dumps were created with Neon and contain neon extensions that are not
# available in RDS, so we will always report an error, but we can ignore it
run: |
${PG_BINARIES}/pg_restore --clean --if-exists --no-owner --jobs=4 \

View File

@@ -236,7 +236,9 @@ jobs:
# run pageserver tests with different settings
for io_engine in std-fs tokio-epoll-uring ; do
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
for io_buffer_alignment in 0 1 512 ; do
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT=$io_buffer_alignment ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
done
done
# Run separate tests for real S3

View File

@@ -12,6 +12,7 @@ on:
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 3 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually
inputs:
region_id:
@@ -58,7 +59,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
matrix:
@@ -67,10 +68,12 @@ jobs:
PLATFORM: "neon-staging"
region_id: ${{ github.event.inputs.region_id || 'aws-us-east-2' }}
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- DEFAULT_PG_VERSION: 16
PLATFORM: "azure-staging"
region_id: 'azure-eastus2'
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "300"
TEST_PG_BENCH_SCALES_MATRIX: "10,100"
@@ -83,10 +86,7 @@ jobs:
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: ${{ matrix.IMAGE }}
options: --init
steps:
@@ -164,10 +164,6 @@ jobs:
replication-tests:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 16
@@ -178,21 +174,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
@@ -280,7 +267,7 @@ jobs:
region_id_default=${{ env.DEFAULT_REGION_ID }}
runner_default='["self-hosted", "us-east-2", "x64"]'
runner_azure='["self-hosted", "eastus2", "x64"]'
image_default="neondatabase/build-tools:pinned"
image_default="369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned"
matrix='{
"pg_version" : [
16
@@ -357,7 +344,7 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
@@ -384,7 +371,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
- name: Configure AWS credentials # necessary on Azure runners
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -505,15 +492,17 @@ jobs:
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
id-token: write # Required for OIDC authentication in azure runners
strategy:
fail-fast: false
matrix:
include:
- PLATFORM: "neonvm-captest-pgvector"
RUNNER: [ self-hosted, us-east-2, x64 ]
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
- PLATFORM: "azure-captest-pgvector"
RUNNER: [ self-hosted, eastus2, x64 ]
IMAGE: neondatabase/build-tools:pinned
env:
TEST_PG_BENCH_DURATIONS_MATRIX: "15m"
@@ -522,16 +511,13 @@ jobs:
DEFAULT_PG_VERSION: 16
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
LD_LIBRARY_PATH: /home/nonroot/pg/usr/lib/x86_64-linux-gnu
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: ${{ matrix.PLATFORM }}
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: ${{ matrix.IMAGE }}
options: --init
steps:
@@ -541,26 +527,17 @@ jobs:
# instead of using Neon artifacts containing pgbench
- name: Install postgresql-16 where pytest expects it
run: |
# Just to make it easier to test things locally on macOS (with arm64)
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
cd /home/nonroot
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.0-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110+2_${arch}.deb"
dpkg -x libpq5_17.0-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+2_${arch}.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+2_${arch}.deb pg
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/libpq5_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.4-1.pgdg110%2B1_amd64.deb
wget -q https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.4-1.pgdg110%2B1_amd64.deb
dpkg -x libpq5_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-client-16_16.4-1.pgdg110+1_amd64.deb pg
dpkg -x postgresql-16_16.4-1.pgdg110+1_amd64.deb pg
mkdir -p /tmp/neon/pg_install/v16/bin
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu /tmp/neon/pg_install/v16/lib
LD_LIBRARY_PATH="/home/nonroot/pg/usr/lib/$(uname -m)-linux-gnu:${LD_LIBRARY_PATH:-}"
export LD_LIBRARY_PATH
echo "LD_LIBRARY_PATH=${LD_LIBRARY_PATH}" >> ${GITHUB_ENV}
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/psql /tmp/neon/pg_install/v16/bin/psql
ln -s /home/nonroot/pg/usr/lib/x86_64-linux-gnu /tmp/neon/pg_install/v16/lib
/tmp/neon/pg_install/v16/bin/pgbench --version
/tmp/neon/pg_install/v16/bin/psql --version
@@ -582,7 +559,7 @@ jobs:
echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT
- name: Configure AWS credentials
- name: Configure AWS credentials # necessary on Azure runners to read/write from/to S3
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -643,10 +620,6 @@ jobs:
# *_CLICKBENCH_CONNSTR: Genuine ClickBench DB with ~100M rows
# *_CLICKBENCH_10M_CONNSTR: DB with the first 10M rows of ClickBench DB
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, pgbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -665,22 +638,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -751,10 +714,6 @@ jobs:
#
# *_TPCH_S10_CONNSTR: DB generated with scale factor 10 (~10 GB)
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, clickbench-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -772,22 +731,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
@@ -857,10 +806,6 @@ jobs:
user-examples-compare:
if: ${{ !cancelled() && (github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null) }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
needs: [ generate-matrices, tpch-compare, prepare_AWS_RDS_databases ]
strategy:
@@ -877,22 +822,12 @@ jobs:
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:

View File

@@ -773,7 +773,7 @@ jobs:
matrix:
version: [ v14, v15, v16, v17 ]
env:
VM_BUILDER_VERSION: v0.35.0
VM_BUILDER_VERSION: v0.29.3
steps:
- uses: actions/checkout@v4
@@ -1190,9 +1190,10 @@ jobs:
files_to_promote+=("s3://${BUCKET}/${s3_key}")
for pg_version in v14 v15 v16 v17; do
# TODO Add v17
for pg_version in v14 v15 v16; do
# We run less tests for debug builds, so we don't need to promote them
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v17" ] ; }; then
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v16" ] ; }; then
continue
fi

218
Cargo.lock generated
View File

@@ -90,9 +90,9 @@ dependencies = [
[[package]]
name = "anstyle"
version = "1.0.8"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
checksum = "41ed9a86bf92ae6580e0a31281f65a1b1d867c0cc68d5346e2ae128dddfa6a7d"
[[package]]
name = "anstyle-parse"
@@ -1223,7 +1223,6 @@ dependencies = [
"notify",
"num_cpus",
"opentelemetry",
"opentelemetry_sdk",
"postgres",
"regex",
"remote_storage",
@@ -1322,7 +1321,6 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"futures",
"humantime",
"humantime-serde",
"hyper 0.14.30",
@@ -1877,9 +1875,9 @@ dependencies = [
[[package]]
name = "env_logger"
version = "0.10.2"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
checksum = "85cdab6a89accf66733ad5a1693a4dcced6aeff64602b634530dd73c1f3ee9f0"
dependencies = [
"humantime",
"is-terminal",
@@ -3369,82 +3367,102 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "opentelemetry"
version = "0.24.0"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c365a63eec4f55b7efeceb724f1336f26a9cf3427b70e59e2cd2a5b947fba96"
checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54"
dependencies = [
"futures-core",
"futures-sink",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"opentelemetry_api",
"opentelemetry_sdk",
]
[[package]]
name = "opentelemetry-http"
version = "0.13.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad31e9de44ee3538fb9d64fe3376c1362f406162434609e79aea2a41a0af78ab"
checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b"
dependencies = [
"async-trait",
"bytes",
"http 1.1.0",
"opentelemetry",
"reqwest 0.12.4",
"http 0.2.9",
"opentelemetry_api",
"reqwest 0.11.19",
]
[[package]]
name = "opentelemetry-otlp"
version = "0.17.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b925a602ffb916fb7421276b86756027b37ee708f9dce2dbdcc51739f07e727"
checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
dependencies = [
"async-trait",
"futures-core",
"http 1.1.0",
"opentelemetry",
"http 0.2.9",
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry-semantic-conventions",
"opentelemetry_api",
"opentelemetry_sdk",
"prost 0.13.3",
"reqwest 0.12.4",
"prost",
"reqwest 0.11.19",
"thiserror",
"tokio",
"tonic",
]
[[package]]
name = "opentelemetry-proto"
version = "0.7.0"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30ee9f20bff9c984511a02f082dc8ede839e4a9bf15cc2487c8d6fea5ad850d9"
checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb"
dependencies = [
"opentelemetry",
"opentelemetry_api",
"opentelemetry_sdk",
"prost 0.13.3",
"tonic 0.12.2",
"prost",
"tonic",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.16.0"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cefe0543875379e47eb5f1e68ff83f45cc41366a92dfd0d073d513bf68e9a05"
checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269"
dependencies = [
"opentelemetry",
]
[[package]]
name = "opentelemetry_api"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b"
dependencies = [
"futures-channel",
"futures-util",
"indexmap 1.9.3",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
"urlencoding",
]
[[package]]
name = "opentelemetry_sdk"
version = "0.24.1"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "692eac490ec80f24a17828d49b40b60f5aeaccdfe6a503f939713afd22bc28df"
checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026"
dependencies = [
"async-trait",
"crossbeam-channel",
"futures-channel",
"futures-executor",
"futures-util",
"glob",
"once_cell",
"opentelemetry",
"opentelemetry_api",
"ordered-float 3.9.2",
"percent-encoding",
"rand 0.8.5",
"regex",
"serde_json",
"thiserror",
"tokio",
@@ -3460,6 +3478,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "ordered-float"
version = "3.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc"
dependencies = [
"num-traits",
]
[[package]]
name = "ordered-multimap"
version = "0.7.3"
@@ -4202,17 +4229,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
dependencies = [
"bytes",
"prost-derive 0.11.9",
]
[[package]]
name = "prost"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f"
dependencies = [
"bytes",
"prost-derive 0.13.3",
"prost-derive",
]
[[package]]
@@ -4229,7 +4246,7 @@ dependencies = [
"multimap",
"petgraph",
"prettyplease 0.1.25",
"prost 0.11.9",
"prost",
"prost-types",
"regex",
"syn 1.0.109",
@@ -4250,26 +4267,13 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "prost-derive"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.12.1",
"proc-macro2",
"quote",
"syn 2.0.52",
]
[[package]]
name = "prost-types"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
dependencies = [
"prost 0.11.9",
"prost",
]
[[package]]
@@ -4292,7 +4296,6 @@ dependencies = [
"camino-tempfile",
"chrono",
"clap",
"compute_api",
"consumption_metrics",
"dashmap",
"ecdsa 0.16.9",
@@ -4366,6 +4369,7 @@ dependencies = [
"tokio-tungstenite",
"tokio-util",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
"tracing-utils",
"try-lock",
@@ -4810,9 +4814,9 @@ dependencies = [
[[package]]
name = "reqwest-tracing"
version = "0.5.3"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfdd9bfa64c72233d8dd99ab7883efcdefe9e16d46488ecb9228b71a2e2ceb45"
checksum = "b253954a1979e02eabccd7e9c3d61d8f86576108baa160775e7f160bb4e800a3"
dependencies = [
"anyhow",
"async-trait",
@@ -5697,9 +5701,9 @@ dependencies = [
"metrics",
"once_cell",
"parking_lot 0.12.1",
"prost 0.11.9",
"prost",
"tokio",
"tonic 0.9.2",
"tonic",
"tonic-build",
"tracing",
"utils",
@@ -6023,7 +6027,7 @@ checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09"
dependencies = [
"byteorder",
"integer-encoding",
"ordered-float",
"ordered-float 2.10.1",
]
[[package]]
@@ -6125,9 +6129,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.38.1"
version = "1.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
dependencies = [
"backtrace",
"bytes",
@@ -6169,9 +6173,9 @@ dependencies = [
[[package]]
name = "tokio-macros"
version = "2.3.0"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
@@ -6346,7 +6350,7 @@ dependencies = [
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost 0.11.9",
"prost",
"rustls-native-certs 0.6.2",
"rustls-pemfile 1.0.2",
"tokio",
@@ -6358,27 +6362,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "tonic"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6f6ba989e4b2c58ae83d862d3a3e27690b6e3ae630d0deb59f3697f32aa88ad"
dependencies = [
"async-trait",
"base64 0.22.1",
"bytes",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"percent-encoding",
"pin-project",
"prost 0.13.3",
"tokio-stream",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic-build"
version = "0.9.2"
@@ -6426,10 +6409,11 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52"
[[package]]
name = "tracing"
version = "0.1.40"
version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"log",
"pin-project-lite",
"tracing-attributes",
@@ -6449,9 +6433,9 @@ dependencies = [
[[package]]
name = "tracing-attributes"
version = "0.1.27"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
checksum = "0f57e3ca2a01450b1a921183a9c9cbfda207fd822cef4ccb00a65402cbba7a74"
dependencies = [
"proc-macro2",
"quote",
@@ -6460,9 +6444,9 @@ dependencies = [
[[package]]
name = "tracing-core"
version = "0.1.32"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54"
checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a"
dependencies = [
"once_cell",
"valuable",
@@ -6480,22 +6464,21 @@ dependencies = [
[[package]]
name = "tracing-log"
version = "0.2.0"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
checksum = "78ddad33d2d10b1ed7eb9d1f518a5674713876e97e5bb9b7345a7984fbb4f922"
dependencies = [
"lazy_static",
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-opentelemetry"
version = "0.25.0"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9784ed4da7d921bc8df6963f8c80a0e4ce34ba6ba76668acadd3edbd985ff3b"
checksum = "75327c6b667828ddc28f5e3f169036cb793c3f588d83bf0f262a7f062ffed3c8"
dependencies = [
"js-sys",
"once_cell",
"opentelemetry",
"opentelemetry_sdk",
@@ -6504,7 +6487,6 @@ dependencies = [
"tracing-core",
"tracing-log",
"tracing-subscriber",
"web-time",
]
[[package]]
@@ -6519,9 +6501,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77"
dependencies = [
"matchers",
"once_cell",
@@ -6545,7 +6527,6 @@ dependencies = [
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
"opentelemetry_sdk",
"tokio",
"tracing",
"tracing-opentelemetry",
@@ -7001,16 +6982,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.25.2"
@@ -7280,6 +7251,7 @@ dependencies = [
"chrono",
"clap",
"clap_builder",
"crossbeam-utils",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
@@ -7312,12 +7284,13 @@ dependencies = [
"once_cell",
"parquet",
"proc-macro2",
"prost 0.11.9",
"prost",
"quote",
"rand 0.8.5",
"regex",
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
"reqwest 0.11.19",
"reqwest 0.12.4",
"rustls 0.21.11",
"scopeguard",
@@ -7338,9 +7311,12 @@ dependencies = [
"tokio-rustls 0.24.0",
"tokio-util",
"toml_edit",
"tonic",
"tower",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",

View File

@@ -116,10 +116,9 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.24"
opentelemetry_sdk = "0.24"
opentelemetry-otlp = { version = "0.17", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.16"
opentelemetry = "0.20.0"
opentelemetry-otlp = { version = "0.13.0", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.12.0"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -132,7 +131,7 @@ rand = "0.8"
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_24"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_20"] }
reqwest-middleware = "0.3.0"
reqwest-retry = "0.5"
routerify = "3"
@@ -178,8 +177,8 @@ toml_edit = "0.22"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
tower-service = "0.3.2"
tracing = "0.1"
tracing-error = "0.2"
tracing-opentelemetry = "0.25"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.21.0"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"
twox-hash = { version = "1.6.3", default-features = false }

View File

@@ -42,7 +42,6 @@ COPY --from=pg-build /home/nonroot/pg_install/v17/lib pg_i
COPY --chown=nonroot . .
ARG ADDITIONAL_RUSTFLAGS
ENV _RJEM_MALLOC_CONF="thp:never"
RUN set -e \
&& PQ_LIB_DIR=$(pwd)/pg_install/v${STABLE_PG_VERSION}/lib RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment ${ADDITIONAL_RUSTFLAGS}" cargo build \
--bin pg_sni_router \

View File

@@ -13,9 +13,6 @@ RUN useradd -ms /bin/bash nonroot -b /home
SHELL ["/bin/bash", "-c"]
# System deps
#
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
&& apt update \
&& apt install -y \
@@ -27,7 +24,6 @@ RUN set -e \
cmake \
curl \
flex \
gdb \
git \
gnupg \
gzip \

View File

@@ -871,28 +871,6 @@ RUN case "${PG_VERSION}" in "v17") \
cargo pgrx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
#########################################################################################
#
# Layer "pg-session-jwt-build"
# Compile "pg_session_jwt" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-session-jwt-build
ARG PG_VERSION
RUN case "${PG_VERSION}" in "v17") \
echo "pg_session_jwt does not yet have a release that supports pg17" && exit 0;; \
esac && \
wget https://github.com/neondatabase/pg_session_jwt/archive/ff0a72440e8ff584dab24b3f9b7c00c56c660b8e.tar.gz -O pg_session_jwt.tar.gz && \
echo "1fbb2b5a339263bcf6daa847fad8bccbc0b451cea6a62e6d3bf232b0087f05cb pg_session_jwt.tar.gz" | sha256sum --check && \
mkdir pg_session_jwt-src && cd pg_session_jwt-src && tar xzf ../pg_session_jwt.tar.gz --strip-components=1 -C . && \
sed -i 's/pgrx = "=0.11.3"/pgrx = { version = "=0.11.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
cargo pgrx install --release
# it's needed to enable extension because it uses untrusted C language
# sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_session_jwt.control && \
# echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_session_jwt.control
#########################################################################################
#
# Layer "wal2json-build"
@@ -989,7 +967,6 @@ COPY --from=timescaledb-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hint-plan-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-cron-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-pgx-ulid-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-session-jwt-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
@@ -1281,7 +1258,7 @@ RUN apt update && \
libxml2 \
libxslt1.1 \
libzstd1 \
libcurl4 \
libcurl4-openssl-dev \
locales \
procps \
ca-certificates \

View File

@@ -11,10 +11,6 @@ commands:
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/resize-swap'
- name: chmod-set-disk-quota
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/set-disk-quota'
- name: pgbouncer
user: postgres
sysvInitAction: respawn
@@ -34,12 +30,11 @@ commands:
shutdownHook: |
su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
files:
- filename: compute_ctl-sudoers
- filename: compute_ctl-resize-swap
content: |
# Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap
# and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD),
# regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota
# as root without requiring entering a password (NOPASSWD), regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
@@ -105,7 +100,7 @@ merge: |
&& apt install --no-install-recommends -y \
sudo \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
COPY compute_ctl-sudoers /etc/sudoers.d/compute_ctl-sudoers
COPY compute_ctl-resize-swap /etc/sudoers.d/compute_ctl-resize-swap
COPY cgconfig.conf /etc/cgconfig.conf

View File

@@ -21,7 +21,6 @@ nix.workspace = true
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
opentelemetry_sdk.workspace = true
postgres.workspace = true
regex.workspace = true
serde_json.workspace = true

View File

@@ -44,7 +44,6 @@ use std::{thread, time::Duration};
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Arg;
use compute_tools::disk_quota::set_disk_quota;
use compute_tools::lsn_lease::launch_lsn_lease_bg_task_for_static;
use signal_hook::consts::{SIGQUIT, SIGTERM};
use signal_hook::{consts::SIGINT, iterator::Signals};
@@ -152,7 +151,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
let spec_json = matches.get_one::<String>("spec");
let spec_path = matches.get_one::<String>("spec-path");
let resize_swap_on_bind = matches.get_flag("resize-swap-on-bind");
let set_disk_quota_for_fs = matches.get_one::<String>("set-disk-quota-for-fs");
Ok(ProcessCliResult {
connstr,
@@ -163,7 +161,6 @@ fn process_cli(matches: &clap::ArgMatches) -> Result<ProcessCliResult> {
spec_json,
spec_path,
resize_swap_on_bind,
set_disk_quota_for_fs,
})
}
@@ -176,7 +173,6 @@ struct ProcessCliResult<'clap> {
spec_json: Option<&'clap String>,
spec_path: Option<&'clap String>,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<&'clap String>,
}
fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
@@ -218,7 +214,7 @@ fn startup_context_from_env() -> Option<opentelemetry::ContextGuard> {
}
if !startup_tracing_carrier.is_empty() {
use opentelemetry::propagation::TextMapPropagator;
use opentelemetry_sdk::propagation::TraceContextPropagator;
use opentelemetry::sdk::propagation::TraceContextPropagator;
let guard = TraceContextPropagator::new()
.extract(&startup_tracing_carrier)
.attach();
@@ -297,7 +293,6 @@ fn wait_spec(
pgbin,
ext_remote_storage,
resize_swap_on_bind,
set_disk_quota_for_fs,
http_port,
..
}: ProcessCliResult,
@@ -378,7 +373,6 @@ fn wait_spec(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs: set_disk_quota_for_fs.cloned(),
})
}
@@ -387,7 +381,6 @@ struct WaitSpecResult {
// passed through from ProcessCliResult
http_port: u16,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<String>,
}
fn start_postgres(
@@ -397,7 +390,6 @@ fn start_postgres(
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs,
}: WaitSpecResult,
) -> Result<(Option<PostgresHandle>, StartPostgresResult)> {
// We got all we need, update the state.
@@ -411,7 +403,6 @@ fn start_postgres(
);
// before we release the mutex, fetch the swap size (if any) for later.
let swap_size_bytes = state.pspec.as_ref().unwrap().spec.swap_size_bytes;
let disk_quota_bytes = state.pspec.as_ref().unwrap().spec.disk_quota_bytes;
drop(state);
// Launch remaining service threads
@@ -431,8 +422,8 @@ fn start_postgres(
// OOM-killed during startup because swap wasn't available yet.
match resize_swap(size_bytes) {
Ok(()) => {
let size_mib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_mib, "resized swap");
let size_gib = size_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%size_bytes, %size_gib, "resized swap");
}
Err(err) => {
let err = err.context("failed to resize swap");
@@ -441,29 +432,10 @@ fn start_postgres(
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
compute.set_failed_status(err);
delay_exit = true;
}
}
}
// Set disk quota if the compute spec says so
if let (Some(disk_quota_bytes), Some(disk_quota_fs_mountpoint)) =
(disk_quota_bytes, set_disk_quota_for_fs)
{
match set_disk_quota(disk_quota_bytes, &disk_quota_fs_mountpoint) {
Ok(()) => {
let size_mib = disk_quota_bytes as f32 / (1 << 20) as f32; // just for more coherent display.
info!(%disk_quota_bytes, %size_mib, "set disk quota");
}
Err(err) => {
let err = err.context("failed to set disk quota");
error!("{err:#}");
// Mark compute startup as failed; don't try to start postgres, and report this
// error to the control plane when it next asks.
prestartup_failed = true;
compute.set_failed_status(err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
compute.state_changed.notify_all();
delay_exit = true;
}
}
@@ -478,7 +450,16 @@ fn start_postgres(
Ok(pg) => Some(pg),
Err(err) => {
error!("could not start the compute node: {:#}", err);
compute.set_failed_status(err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{:?}", err));
state.status = ComputeStatus::Failed;
// Notify others that Postgres failed to start. In case of configuring the
// empty compute, it's likely that API handler is still waiting for compute
// state change. With this we will notify it that compute is in Failed state,
// so control plane will know about it earlier and record proper error instead
// of timeout.
compute.state_changed.notify_all();
drop(state); // unlock
delay_exit = true;
None
}
@@ -769,11 +750,6 @@ fn cli() -> clap::Command {
.long("resize-swap-on-bind")
.action(clap::ArgAction::SetTrue),
)
.arg(
Arg::new("set-disk-quota-for-fs")
.long("set-disk-quota-for-fs")
.value_name("SET_DISK_QUOTA_FOR_FS")
)
}
/// When compute_ctl is killed, send also termination signal to sync-safekeepers

View File

@@ -10,7 +10,6 @@ use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::{Condvar, Mutex, RwLock};
use std::thread;
use std::time::Duration;
use std::time::Instant;
use anyhow::{Context, Result};
@@ -306,13 +305,6 @@ impl ComputeNode {
self.state_changed.notify_all();
}
pub fn set_failed_status(&self, err: anyhow::Error) {
let mut state = self.state.lock().unwrap();
state.error = Some(format!("{err:?}"));
state.status = ComputeStatus::Failed;
self.state_changed.notify_all();
}
pub fn get_status(&self) -> ComputeStatus {
self.state.lock().unwrap().status
}
@@ -718,7 +710,7 @@ impl ComputeNode {
info!("running initdb");
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
Command::new(initdb_bin)
.args(["--pgdata", pgdata])
.args(["-D", pgdata])
.output()
.expect("cannot start initdb process");
@@ -1131,9 +1123,6 @@ impl ComputeNode {
//
// Use that as a default location and pattern, except macos where core dumps are written
// to /cores/ directory by default.
//
// With default Linux settings, the core dump file is called just "core", so check for
// that too.
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
@@ -1145,17 +1134,8 @@ impl ComputeNode {
let files = fs::read_dir(core_dump_dir)?;
let cores = files.filter_map(|entry| {
let entry = entry.ok()?;
let is_core_dump = match entry.file_name().to_str()? {
n if n.starts_with("core.") => true,
"core" => true,
_ => false,
};
if is_core_dump {
Some(entry.path())
} else {
None
}
let _ = entry.file_name().to_str()?.strip_prefix("core.")?;
Some(entry.path())
});
// Print backtrace for each core dump
@@ -1406,36 +1386,6 @@ LIMIT 100",
}
Ok(remote_ext_metrics)
}
/// Waits until current thread receives a state changed notification and
/// the pageserver connection strings has changed.
///
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageserver_connstr = state
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
unchanged
})
.unwrap();
if !unchanged {
info!("Pageserver config changed");
}
}
}
pub fn forward_termination_signal() {

View File

@@ -11,17 +11,9 @@ use crate::compute::ComputeNode;
fn configurator_main_loop(compute: &Arc<ComputeNode>) {
info!("waiting for reconfiguration requests");
loop {
let mut state = compute.state.lock().unwrap();
let state = compute.state.lock().unwrap();
let mut state = compute.state_changed.wait(state).unwrap();
// We have to re-check the status after re-acquiring the lock because it could be that
// the status has changed while we were waiting for the lock, and we might not need to
// wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e.
// we are waiting for a condition variable that will never be signaled.
if state.status != ComputeStatus::ConfigurationPending {
state = compute.state_changed.wait(state).unwrap();
}
// Re-check the status after waking up
if state.status == ComputeStatus::ConfigurationPending {
info!("got configuration request");
state.status = ComputeStatus::Configuration;

View File

@@ -1,25 +0,0 @@
use anyhow::Context;
pub const DISK_QUOTA_BIN: &str = "/neonvm/bin/set-disk-quota";
/// If size_bytes is 0, it disables the quota. Otherwise, it sets filesystem quota to size_bytes.
/// `fs_mountpoint` should point to the mountpoint of the filesystem where the quota should be set.
pub fn set_disk_quota(size_bytes: u64, fs_mountpoint: &str) -> anyhow::Result<()> {
let size_kb = size_bytes / 1024;
// run `/neonvm/bin/set-disk-quota {size_kb} {mountpoint}`
let child_result = std::process::Command::new("/usr/bin/sudo")
.arg(DISK_QUOTA_BIN)
.arg(size_kb.to_string())
.arg(fs_mountpoint)
.spawn();
child_result
.context("spawn() failed")
.and_then(|mut child| child.wait().context("wait() failed"))
.and_then(|status| match status.success() {
true => Ok(()),
false => Err(anyhow::anyhow!("process exited with {status}")),
})
// wrap any prior error with the overall context that we couldn't run the command
.with_context(|| format!("could not run `/usr/bin/sudo {DISK_QUOTA_BIN}`"))
}

View File

@@ -10,7 +10,6 @@ pub mod http;
pub mod logger;
pub mod catalog;
pub mod compute;
pub mod disk_quota;
pub mod extension_server;
pub mod lsn_lease;
mod migration;

View File

@@ -1,3 +1,4 @@
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::prelude::*;
@@ -22,7 +23,8 @@ pub fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
.with_writer(std::io::stderr);
// Initialize OpenTelemetry
let otlp_layer = tracing_utils::init_tracing_without_runtime("compute_ctl");
let otlp_layer =
tracing_utils::init_tracing_without_runtime("compute_ctl").map(OpenTelemetryLayer::new);
// Put it all together
tracing_subscriber::registry()

View File

@@ -57,10 +57,10 @@ fn lsn_lease_bg_task(
.max(valid_duration / 2);
info!(
"Request succeeded, sleeping for {} seconds",
"Succeeded, sleeping for {} seconds",
sleep_duration.as_secs()
);
compute.wait_timeout_while_pageserver_connstr_unchanged(sleep_duration);
thread::sleep(sleep_duration);
}
}
@@ -89,7 +89,10 @@ fn acquire_lsn_lease_with_retry(
.map(|connstr| {
let mut config = postgres::Config::from_str(connstr).expect("Invalid connstr");
if let Some(storage_auth_token) = &spec.storage_auth_token {
info!("Got storage auth token from spec file");
config.password(storage_auth_token.clone());
} else {
info!("Storage auth token not set");
}
config
})
@@ -105,11 +108,9 @@ fn acquire_lsn_lease_with_retry(
bail!("Permanent error: lease could not be obtained, LSN is behind the GC cutoff");
}
Err(e) => {
warn!("Failed to acquire lsn lease: {e} (attempt {attempts})");
warn!("Failed to acquire lsn lease: {e} (attempt {attempts}");
compute.wait_timeout_while_pageserver_connstr_unchanged(Duration::from_millis(
retry_period_ms as u64,
));
thread::sleep(Duration::from_millis(retry_period_ms as u64));
retry_period_ms *= 1.5;
retry_period_ms = retry_period_ms.min(MAX_RETRY_PERIOD_MS);
}

View File

@@ -9,7 +9,6 @@ anyhow.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
futures.workspace = true
humantime.workspace = true
nix.workspace = true
once_cell.workspace = true

File diff suppressed because it is too large Load Diff

View File

@@ -1,94 +0,0 @@
//! Branch mappings for convenience
use std::collections::HashMap;
use std::fs;
use std::path::Path;
use anyhow::{bail, Context};
use serde::{Deserialize, Serialize};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
/// Keep human-readable aliases in memory (and persist them to config XXX), to hide tenant/timeline hex strings from the user.
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct BranchMappings {
/// Default tenant ID to use with the 'neon_local' command line utility, when
/// --tenant_id is not explicitly specified. This comes from the branches.
pub default_tenant_id: Option<TenantId>,
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
pub mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
}
impl BranchMappings {
pub fn register_branch_mapping(
&mut self,
branch_name: String,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> anyhow::Result<()> {
let existing_values = self.mappings.entry(branch_name.clone()).or_default();
let existing_ids = existing_values
.iter()
.find(|(existing_tenant_id, _)| existing_tenant_id == &tenant_id);
if let Some((_, old_timeline_id)) = existing_ids {
if old_timeline_id == &timeline_id {
Ok(())
} else {
bail!("branch '{branch_name}' is already mapped to timeline {old_timeline_id}, cannot map to another timeline {timeline_id}");
}
} else {
existing_values.push((tenant_id, timeline_id));
Ok(())
}
}
pub fn get_branch_timeline_id(
&self,
branch_name: &str,
tenant_id: TenantId,
) -> Option<TimelineId> {
// If it looks like a timeline ID, return it as it is
if let Ok(timeline_id) = branch_name.parse::<TimelineId>() {
return Some(timeline_id);
}
self.mappings
.get(branch_name)?
.iter()
.find(|(mapped_tenant_id, _)| mapped_tenant_id == &tenant_id)
.map(|&(_, timeline_id)| timeline_id)
.map(TimelineId::from)
}
pub fn timeline_name_mappings(&self) -> HashMap<TenantTimelineId, String> {
self.mappings
.iter()
.flat_map(|(name, tenant_timelines)| {
tenant_timelines.iter().map(|&(tenant_id, timeline_id)| {
(TenantTimelineId::new(tenant_id, timeline_id), name.clone())
})
})
.collect()
}
pub fn persist(&self, path: &Path) -> anyhow::Result<()> {
let content = &toml::to_string_pretty(self)?;
fs::write(path, content).with_context(|| {
format!(
"Failed to write branch information into path '{}'",
path.display()
)
})
}
pub fn load(path: &Path) -> anyhow::Result<BranchMappings> {
let branches_file_contents = fs::read_to_string(path)?;
Ok(toml::from_str(branches_file_contents.as_str())?)
}
}

View File

@@ -561,7 +561,6 @@ impl Endpoint {
operation_uuid: None,
features: self.features.clone(),
swap_size_bytes: None,
disk_quota_bytes: None,
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -113,7 +113,7 @@ impl SafekeeperNode {
pub async fn start(
&self,
extra_opts: &[String],
extra_opts: Vec<String>,
retry_timeout: &Duration,
) -> anyhow::Result<()> {
print!(
@@ -196,7 +196,7 @@ impl SafekeeperNode {
]);
}
args.extend_from_slice(extra_opts);
args.extend(extra_opts);
background_process::start_process(
&format!("safekeeper-{id}"),

View File

@@ -347,7 +347,7 @@ impl StorageController {
if !tokio::fs::try_exists(&pg_data_path).await? {
let initdb_args = [
"--pgdata",
"-D",
pg_data_path.as_ref(),
"--username",
&username(),

View File

@@ -50,16 +50,6 @@ pub struct ComputeSpec {
#[serde(default)]
pub swap_size_bytes: Option<u64>,
/// If compute_ctl was passed `--set-disk-quota-for-fs`, a value of `Some(_)` instructs
/// compute_ctl to run `/neonvm/bin/set-disk-quota` with the given size and fs, when the
/// spec is first received.
///
/// Both this field and `--set-disk-quota-for-fs` are required, so that the control plane's
/// spec generation doesn't need to be aware of the actual compute it's running on, while
/// guaranteeing gradual rollout of disk quota.
#[serde(default)]
pub disk_quota_bytes: Option<u64>,
/// Expected cluster state at the end of transition process.
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,
@@ -278,22 +268,6 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Configured the local-proxy application with the relevant JWKS and roles it should
/// use for authorizing connect requests using JWT.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LocalProxySpec {
pub jwks: Vec<JwksSettings>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwksSettings {
pub id: String,
pub role_names: Vec<String>,
pub jwks_url: String,
pub provider_name: String,
pub jwt_audience: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -984,7 +984,6 @@ pub fn short_error(e: &QueryError) -> String {
}
fn log_query_error(query: &str, e: &QueryError) {
// If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {

View File

@@ -93,9 +93,9 @@ impl Conf {
);
let output = self
.new_pg_command("initdb")?
.arg("--pgdata")
.arg("-D")
.arg(&self.datadir)
.args(["--username", "postgres", "--no-instructions", "--no-sync"])
.args(["-U", "postgres", "--no-instructions", "--no-sync"])
.output()?;
debug!("initdb output: {:?}", output);
ensure!(

View File

@@ -6,14 +6,12 @@ license.workspace = true
[dependencies]
hyper.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry = { workspace = true, features=["rt-tokio"] }
opentelemetry-otlp = { workspace = true, default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tracing.workspace = true
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
[dev-dependencies]
tracing-subscriber.workspace = true # For examples in docs

View File

@@ -10,6 +10,7 @@
//!
//! ```rust,no_run
//! use tracing_subscriber::prelude::*;
//! use tracing_opentelemetry::OpenTelemetryLayer;
//!
//! #[tokio::main]
//! async fn main() {
@@ -21,7 +22,7 @@
//! .with_writer(std::io::stderr);
//!
//! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces
//! let otlp_layer = tracing_utils::init_tracing("my_application").await;
//! let otlp_layer = tracing_utils::init_tracing("my_application").await.map(OpenTelemetryLayer::new);
//!
//! // Put it all together
//! tracing_subscriber::registry()
@@ -34,14 +35,14 @@
#![deny(unsafe_code)]
#![deny(clippy::undocumented_unsafe_blocks)]
pub mod http;
use opentelemetry::trace::TracerProvider;
use opentelemetry::sdk::Resource;
use opentelemetry::KeyValue;
use opentelemetry_sdk::Resource;
use tracing::Subscriber;
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::Layer;
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_otlp::{OTEL_EXPORTER_OTLP_ENDPOINT, OTEL_EXPORTER_OTLP_TRACES_ENDPOINT};
pub use tracing_opentelemetry::OpenTelemetryLayer;
pub mod http;
/// Set up OpenTelemetry exporter, using configuration from environment variables.
///
@@ -70,10 +71,7 @@ use tracing_subscriber::Layer;
///
/// This doesn't block, but is marked as 'async' to hint that this must be called in
/// asynchronous execution context.
pub async fn init_tracing<S>(service_name: &str) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
pub async fn init_tracing(service_name: &str) -> Option<opentelemetry::sdk::trace::Tracer> {
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -82,10 +80,9 @@ where
/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing
/// tasks.
pub fn init_tracing_without_runtime<S>(service_name: &str) -> Option<impl Layer<S>>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
pub fn init_tracing_without_runtime(
service_name: &str,
) -> Option<opentelemetry::sdk::trace::Tracer> {
if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) {
return None;
};
@@ -116,36 +113,54 @@ where
Some(init_tracing_internal(service_name.to_string()))
}
fn init_tracing_internal<S>(service_name: String) -> impl Layer<S>
where
S: Subscriber + for<'span> LookupSpan<'span>,
{
// Sets up exporter from the OTEL_EXPORTER_* environment variables.
let exporter = opentelemetry_otlp::new_exporter().http();
fn init_tracing_internal(service_name: String) -> opentelemetry::sdk::trace::Tracer {
// Set up exporter from the OTEL_EXPORTER_* environment variables
let mut exporter = opentelemetry_otlp::new_exporter().http().with_env();
// TODO: opentelemetry::global::set_error_handler() with custom handler that
// bypasses default tracing layers, but logs regular looking log
// messages.
// XXX opentelemetry-otlp v0.18.0 has a bug in how it uses the
// OTEL_EXPORTER_OTLP_ENDPOINT env variable. According to the
// OpenTelemetry spec at
// <https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md#endpoint-urls-for-otlphttp>,
// the full exporter URL is formed by appending "/v1/traces" to the value
// of OTEL_EXPORTER_OTLP_ENDPOINT. However, opentelemetry-otlp only does
// that with the grpc-tonic exporter. Other exporters, like the HTTP
// exporter, use the URL from OTEL_EXPORTER_OTLP_ENDPOINT as is, without
// appending "/v1/traces".
//
// See https://github.com/open-telemetry/opentelemetry-rust/pull/950
//
// Work around that by checking OTEL_EXPORTER_OTLP_ENDPOINT, and setting
// the endpoint url with the "/v1/traces" path ourselves. If the bug is
// fixed in a later version, we can remove this code. But if we don't
// remember to remove this, it won't do any harm either, as the crate will
// just ignore the OTEL_EXPORTER_OTLP_ENDPOINT setting when the endpoint
// is set directly with `with_endpoint`.
if std::env::var(OTEL_EXPORTER_OTLP_TRACES_ENDPOINT).is_err() {
if let Ok(mut endpoint) = std::env::var(OTEL_EXPORTER_OTLP_ENDPOINT) {
if !endpoint.ends_with('/') {
endpoint.push('/');
}
endpoint.push_str("v1/traces");
exporter = exporter.with_endpoint(endpoint);
}
}
// Propagate trace information in the standard W3C TraceContext format.
opentelemetry::global::set_text_map_propagator(
opentelemetry_sdk::propagation::TraceContextPropagator::new(),
opentelemetry::sdk::propagation::TraceContextPropagator::new(),
);
let tracer = opentelemetry_otlp::new_pipeline()
opentelemetry_otlp::new_pipeline()
.tracing()
.with_exporter(exporter)
.with_trace_config(opentelemetry_sdk::trace::Config::default().with_resource(
Resource::new(vec![KeyValue::new(
.with_trace_config(
opentelemetry::sdk::trace::config().with_resource(Resource::new(vec![KeyValue::new(
opentelemetry_semantic_conventions::resource::SERVICE_NAME,
service_name,
)]),
))
.install_batch(opentelemetry_sdk::runtime::Tokio)
)])),
)
.install_batch(opentelemetry::runtime::Tokio)
.expect("could not initialize opentelemetry exporter")
.tracer("global");
tracing_opentelemetry::layer().with_tracer(tracer)
}
// Shutdown trace pipeline gracefully, so that it has a chance to send any

View File

@@ -736,22 +736,4 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_init_lsn_lease(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<LsnLease> {
let uri = format!(
"{}/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/lsn_lease",
self.mgmt_api_endpoint,
);
self.request(Method::POST, &uri, LsnLeaseRequest { lsn })
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -15,7 +15,7 @@ use clap::{Arg, ArgAction, Command};
use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::config::PageserverIdentity;
use pageserver::controller_upcall_client::ControllerUpcallClient;
use pageserver::control_plane_client::ControlPlaneClient;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::{COMPUTE_REQUEST_RUNTIME, WALRECEIVER_RUNTIME};
@@ -396,7 +396,7 @@ fn start_pageserver(
// Set up deletion queue
let (deletion_queue, deletion_workers) = DeletionQueue::new(
remote_storage.clone(),
ControllerUpcallClient::new(conf, &shutdown_pageserver),
ControlPlaneClient::new(conf, &shutdown_pageserver),
conf,
);
if let Some(deletion_workers) = deletion_workers {

View File

@@ -17,12 +17,9 @@ use utils::{backoff, failpoint_support, generation::Generation, id::NodeId};
use crate::{config::PageServerConf, virtual_file::on_fatal_io_error};
use pageserver_api::config::NodeMetadata;
/// The Pageserver's client for using the storage controller upcall API: this is a small API
/// for dealing with generations (see docs/rfcs/025-generation-numbers.md).
///
/// The server presenting this API may either be the storage controller or some other
/// service (such as the Neon control plane) providing a store of generation numbers.
pub struct ControllerUpcallClient {
/// The Pageserver's client for using the control plane API: this is a small subset
/// of the overall control plane API, for dealing with generations (see docs/rfcs/025-generation-numbers.md)
pub struct ControlPlaneClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
@@ -48,7 +45,7 @@ pub trait ControlPlaneGenerationsApi {
) -> impl Future<Output = Result<HashMap<TenantShardId, bool>, RetryForeverError>> + Send;
}
impl ControllerUpcallClient {
impl ControlPlaneClient {
/// A None return value indicates that the input `conf` object does not have control
/// plane API enabled.
pub fn new(conf: &'static PageServerConf, cancel: &CancellationToken) -> Option<Self> {
@@ -117,7 +114,7 @@ impl ControllerUpcallClient {
}
}
impl ControlPlaneGenerationsApi for ControllerUpcallClient {
impl ControlPlaneGenerationsApi for ControlPlaneClient {
/// Block until we get a successful response, or error out if we are shut down
async fn re_attach(
&self,
@@ -219,38 +216,29 @@ impl ControlPlaneGenerationsApi for ControllerUpcallClient {
.join("validate")
.expect("Failed to build validate path");
// When sending validate requests, break them up into chunks so that we
// avoid possible edge cases of generating any HTTP requests that
// require database I/O across many thousands of tenants.
let mut result: HashMap<TenantShardId, bool> = HashMap::with_capacity(tenants.len());
for tenant_chunk in (tenants).chunks(128) {
let request = ValidateRequest {
tenants: tenant_chunk
.iter()
.map(|(id, generation)| ValidateRequestTenant {
id: *id,
gen: (*generation).into().expect(
"Generation should always be valid for a Tenant doing deletions",
),
})
.collect(),
};
let request = ValidateRequest {
tenants: tenants
.into_iter()
.map(|(id, gen)| ValidateRequestTenant {
id,
gen: gen
.into()
.expect("Generation should always be valid for a Tenant doing deletions"),
})
.collect(),
};
failpoint_support::sleep_millis_async!(
"control-plane-client-validate-sleep",
&self.cancel
);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
}
let response: ValidateResponse =
self.retry_http_forever(&re_attach_path, request).await?;
for rt in response.tenants {
result.insert(rt.id, rt.valid);
}
failpoint_support::sleep_millis_async!("control-plane-client-validate-sleep", &self.cancel);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
}
Ok(result.into_iter().collect())
let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?;
Ok(response
.tenants
.into_iter()
.map(|rt| (rt.id, rt.valid))
.collect())
}
}

View File

@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::metrics;
use crate::tenant::remote_timeline_client::remote_layer_path;
use crate::tenant::remote_timeline_client::remote_timeline_path;
@@ -622,7 +622,7 @@ impl DeletionQueue {
/// If remote_storage is None, then the returned workers will also be None.
pub fn new<C>(
remote_storage: GenericRemoteStorage,
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
conf: &'static PageServerConf,
) -> (Self, Option<DeletionQueueWorkers<C>>)
where
@@ -662,7 +662,7 @@ impl DeletionQueue {
conf,
backend_rx,
executor_tx,
controller_upcall_client,
control_plane_client,
lsn_table.clone(),
cancel.clone(),
),
@@ -704,7 +704,7 @@ mod test {
use tokio::task::JoinHandle;
use crate::{
controller_upcall_client::RetryForeverError,
control_plane_client::RetryForeverError,
repository::Key,
tenant::{harness::TenantHarness, storage_layer::DeltaLayerName},
};

View File

@@ -25,8 +25,8 @@ use tracing::info;
use tracing::warn;
use crate::config::PageServerConf;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::controller_upcall_client::RetryForeverError;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::control_plane_client::RetryForeverError;
use crate::metrics;
use crate::virtual_file::MaybeFatalIo;
@@ -61,7 +61,7 @@ where
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
// Client for calling into control plane API for validation of deletes
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
// DeletionLists which are waiting generation validation. Not safe to
// execute until [`validate`] has processed them.
@@ -94,7 +94,7 @@ where
conf: &'static PageServerConf,
rx: tokio::sync::mpsc::Receiver<ValidatorQueueMessage>,
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
cancel: CancellationToken,
) -> Self {
@@ -102,7 +102,7 @@ where
conf,
rx,
tx,
controller_upcall_client,
control_plane_client,
lsn_table,
pending_lists: Vec::new(),
validated_lists: Vec::new(),
@@ -145,8 +145,8 @@ where
return Ok(());
}
let tenants_valid = if let Some(controller_upcall_client) = &self.controller_upcall_client {
match controller_upcall_client
let tenants_valid = if let Some(control_plane_client) = &self.control_plane_client {
match control_plane_client
.validate(tenant_generations.iter().map(|(k, v)| (*k, *v)).collect())
.await
{

View File

@@ -56,7 +56,6 @@ use utils::http::endpoint::request_span;
use utils::http::request::must_parse_query_param;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::deletion_queue::DeletionQueueClient;
use crate::pgdatadir_mapping::LsnForTimestamp;
@@ -81,6 +80,7 @@ use crate::tenant::timeline::CompactionError;
use crate::tenant::timeline::Timeline;
use crate::tenant::GetTimelineError;
use crate::tenant::{LogicalSizeCalculationCause, PageReconstructError};
use crate::{config::PageServerConf, tenant::mgr};
use crate::{disk_usage_eviction_task, tenant};
use pageserver_api::models::{
StatusResponse, TenantConfigRequest, TenantInfo, TimelineCreateRequest, TimelineGcRequest,
@@ -824,7 +824,7 @@ async fn get_lsn_by_timestamp_handler(
let lease = if with_lease {
timeline
.init_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.make_lsn_lease(lsn, timeline.get_lsn_lease_length_for_ts(), &ctx)
.inspect_err(|_| {
warn!("fail to grant a lease to {}", lsn);
})
@@ -1692,18 +1692,9 @@ async fn lsn_lease_handler(
let timeline =
active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id)
.await?;
let result = async {
timeline
.init_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| {
ApiError::InternalServerError(
e.context(format!("invalid lsn lease request at {lsn}")),
)
})
}
.instrument(info_span!("init_lsn_lease", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), %timeline_id))
.await?;
let result = timeline
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), &ctx)
.map_err(|e| ApiError::InternalServerError(e.context("lsn lease http handler")))?;
json_response(StatusCode::OK, result)
}
@@ -1719,13 +1710,8 @@ async fn timeline_gc_handler(
let gc_req: TimelineGcRequest = json_request(&mut request).await?;
let state = get_state(&request);
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let gc_result = state
.tenant_manager
.immediate_gc(tenant_shard_id, timeline_id, gc_req, cancel, &ctx)
.await?;
let gc_result = mgr::immediate_gc(tenant_shard_id, timeline_id, gc_req, cancel, &ctx).await?;
json_response(StatusCode::OK, gc_result)
}

View File

@@ -6,7 +6,7 @@ pub mod basebackup;
pub mod config;
pub mod consumption_metrics;
pub mod context;
pub mod controller_upcall_client;
pub mod control_plane_client;
pub mod deletion_queue;
pub mod disk_usage_eviction_task;
pub mod http;

View File

@@ -8,8 +8,6 @@ use metrics::{
};
use once_cell::sync::Lazy;
use pageserver_api::shard::TenantShardId;
use postgres_backend::{is_expected_io_error, QueryError};
use pq_proto::framed::ConnectionError;
use strum::{EnumCount, VariantNames};
use strum_macros::{IntoStaticStr, VariantNames};
use tracing::warn;
@@ -1510,7 +1508,6 @@ static COMPUTE_STARTUP_BUCKETS: Lazy<[f64; 28]> = Lazy::new(|| {
pub(crate) struct BasebackupQueryTime {
ok: Histogram,
error: Histogram,
client_error: Histogram,
}
pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|| {
@@ -1524,7 +1521,6 @@ pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|
BasebackupQueryTime {
ok: vec.get_metric_with_label_values(&["ok"]).unwrap(),
error: vec.get_metric_with_label_values(&["error"]).unwrap(),
client_error: vec.get_metric_with_label_values(&["client_error"]).unwrap(),
}
});
@@ -1561,7 +1557,7 @@ impl BasebackupQueryTime {
}
impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
pub(crate) fn observe<T>(self, res: &Result<T, QueryError>) {
pub(crate) fn observe<T, E>(self, res: &Result<T, E>) {
let elapsed = self.start.elapsed();
let ex_throttled = self
.ctx
@@ -1580,15 +1576,10 @@ impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
elapsed
}
};
// If you want to change categorize of a specific error, also change it in `log_query_error`.
let metric = match res {
Ok(_) => &self.parent.ok,
Err(QueryError::Disconnected(ConnectionError::Io(io_error)))
if is_expected_io_error(io_error) =>
{
&self.parent.client_error
}
Err(_) => &self.parent.error,
let metric = if res.is_ok() {
&self.parent.ok
} else {
&self.parent.error
};
metric.observe(ex_throttled.as_secs_f64());
}

View File

@@ -273,20 +273,10 @@ async fn page_service_conn_main(
info!("Postgres client disconnected ({io_error})");
Ok(())
} else {
let tenant_id = conn_handler.timeline_handles.tenant_id();
Err(io_error).context(format!(
"Postgres connection error for tenant_id={:?} client at peer_addr={}",
tenant_id, peer_addr
))
Err(io_error).context("Postgres connection error")
}
}
other => {
let tenant_id = conn_handler.timeline_handles.tenant_id();
other.context(format!(
"Postgres query error for tenant_id={:?} client peer_addr={}",
tenant_id, peer_addr
))
}
other => other.context("Postgres query error"),
}
}
@@ -350,10 +340,6 @@ impl TimelineHandles {
}
})
}
fn tenant_id(&self) -> Option<TenantId> {
self.wrapper.tenant_id.get().copied()
}
}
pub(crate) struct TenantManagerWrapper {
@@ -833,7 +819,7 @@ impl PageServerHandler {
set_tracing_field_shard_id(&timeline);
let lease = timeline
.renew_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.make_lsn_lease(lsn, timeline.get_lsn_lease_length(), ctx)
.inspect_err(|e| {
warn!("{e}");
})
@@ -1011,6 +997,7 @@ impl PageServerHandler {
)
.await?;
tracing::info!("get_rel_page_at_lsn: {lsn}");
let page = timeline
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), ctx)
.await?;

View File

@@ -21,7 +21,6 @@ use futures::stream::FuturesUnordered;
use futures::StreamExt;
use pageserver_api::models;
use pageserver_api::models::AuxFilePolicy;
use pageserver_api::models::LsnLease;
use pageserver_api::models::TimelineArchivalState;
use pageserver_api::models::TimelineState;
use pageserver_api::models::TopTenantShardItem;
@@ -183,54 +182,27 @@ pub struct TenantSharedResources {
pub(super) struct AttachedTenantConf {
tenant_conf: TenantConfOpt,
location: AttachedLocationConfig,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<tokio::time::Instant>,
}
impl AttachedTenantConf {
fn new(tenant_conf: TenantConfOpt, location: AttachedLocationConfig) -> Self {
// Sets a deadline before which we cannot proceed to GC due to lsn lease.
//
// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
// length, we guarantee that all the leases we granted before will have a chance to renew
// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
let lsn_lease_deadline = if location.attach_mode == AttachmentMode::Single {
Some(
tokio::time::Instant::now()
+ tenant_conf
.lsn_lease_length
.unwrap_or(LsnLease::DEFAULT_LENGTH),
)
} else {
// We don't use `lsn_lease_deadline` to delay GC in AttachedMulti and AttachedStale
// because we don't do GC in these modes.
None
};
Self {
tenant_conf,
location,
lsn_lease_deadline,
}
}
fn try_from(location_conf: LocationConf) -> anyhow::Result<Self> {
match &location_conf.mode {
LocationMode::Attached(attach_conf) => {
Ok(Self::new(location_conf.tenant_conf, *attach_conf))
}
LocationMode::Attached(attach_conf) => Ok(Self {
tenant_conf: location_conf.tenant_conf,
location: *attach_conf,
}),
LocationMode::Secondary(_) => {
anyhow::bail!("Attempted to construct AttachedTenantConf from a LocationConf in secondary mode")
}
}
}
fn is_gc_blocked_by_lsn_lease_deadline(&self) -> bool {
self.lsn_lease_deadline
.map(|d| tokio::time::Instant::now() < d)
.unwrap_or(false)
}
}
struct TimelinePreload {
timeline_id: TimelineId,
@@ -1850,11 +1822,6 @@ impl Tenant {
info!("Skipping GC in location state {:?}", conf.location);
return Ok(GcResult::default());
}
if conf.is_gc_blocked_by_lsn_lease_deadline() {
info!("Skipping GC because lsn lease deadline is not reached");
return Ok(GcResult::default());
}
}
let _guard = match self.gc_block.start().await {
@@ -2663,8 +2630,6 @@ impl Tenant {
Arc::new(AttachedTenantConf {
tenant_conf: new_tenant_conf.clone(),
location: inner.location,
// Attached location is not changed, no need to update lsn lease deadline.
lsn_lease_deadline: inner.lsn_lease_deadline,
})
});
@@ -3922,9 +3887,9 @@ async fn run_initdb(
let _permit = INIT_DB_SEMAPHORE.acquire().await;
let initdb_command = tokio::process::Command::new(&initdb_bin_path)
.args(["--pgdata", initdb_target_dir.as_ref()])
.args(["--username", &conf.superuser])
.args(["--encoding", "utf8"])
.args(["-D", initdb_target_dir.as_ref()])
.args(["-U", &conf.superuser])
.args(["-E", "utf8"])
.arg("--no-instructions")
.arg("--no-sync")
.env_clear()
@@ -4496,17 +4461,13 @@ mod tests {
tline.freeze_and_flush().await.map_err(|e| e.into())
}
#[tokio::test(start_paused = true)]
#[tokio::test]
async fn test_prohibit_branch_creation_on_garbage_collected_data() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_prohibit_branch_creation_on_garbage_collected_data")
.await?
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let tline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await?;
@@ -7283,17 +7244,9 @@ mod tests {
Ok(())
}
#[tokio::test(start_paused = true)]
#[tokio::test]
async fn test_lsn_lease() -> anyhow::Result<()> {
let (tenant, ctx) = TenantHarness::create("test_lsn_lease")
.await
.unwrap()
.load()
.await;
// Advance to the lsn lease deadline so that GC is not blocked by
// initial transition into AttachedSingle.
tokio::time::advance(tenant.get_lsn_lease_length()).await;
tokio::time::resume();
let (tenant, ctx) = TenantHarness::create("test_lsn_lease").await?.load().await;
let key = Key::from_hex("010000000033333333444444445500000000").unwrap();
let end_lsn = Lsn(0x100);
@@ -7321,33 +7274,24 @@ mod tests {
let leased_lsns = [0x30, 0x50, 0x70];
let mut leases = Vec::new();
leased_lsns.iter().for_each(|n| {
leases.push(
timeline
.init_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)
.expect("lease request should succeed"),
);
let _: anyhow::Result<_> = leased_lsns.iter().try_for_each(|n| {
leases.push(timeline.make_lsn_lease(Lsn(*n), timeline.get_lsn_lease_length(), &ctx)?);
Ok(())
});
let updated_lease_0 = timeline
.renew_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)
.expect("lease renewal should succeed");
assert_eq!(
updated_lease_0.valid_until, leases[0].valid_until,
" Renewing with shorter lease should not change the lease."
);
// Renewing with shorter lease should not change the lease.
let updated_lease_0 =
timeline.make_lsn_lease(Lsn(leased_lsns[0]), Duration::from_secs(0), &ctx)?;
assert_eq!(updated_lease_0.valid_until, leases[0].valid_until);
let updated_lease_1 = timeline
.renew_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)
.expect("lease renewal should succeed");
assert!(
updated_lease_1.valid_until > leases[1].valid_until,
"Renewing with a long lease should renew lease with later expiration time."
);
// Renewing with a long lease should renew lease with later expiration time.
let updated_lease_1 = timeline.make_lsn_lease(
Lsn(leased_lsns[1]),
timeline.get_lsn_lease_length() * 2,
&ctx,
)?;
assert!(updated_lease_1.valid_until > leases[1].valid_until);
// Force set disk consistent lsn so we can get the cutoff at `end_lsn`.
info!(
@@ -7364,8 +7308,7 @@ mod tests {
&CancellationToken::new(),
&ctx,
)
.await
.unwrap();
.await?;
// Keeping everything <= Lsn(0x80) b/c leases:
// 0/10: initdb layer
@@ -7379,16 +7322,13 @@ mod tests {
// Make lease on a already GC-ed LSN.
// 0/80 does not have a valid lease + is below latest_gc_cutoff
assert!(Lsn(0x80) < *timeline.get_latest_gc_cutoff_lsn());
timeline
.init_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx)
.expect_err("lease request on GC-ed LSN should fail");
let res = timeline.make_lsn_lease(Lsn(0x80), timeline.get_lsn_lease_length(), &ctx);
assert!(res.is_err());
// Should still be able to renew a currently valid lease
// Assumption: original lease to is still valid for 0/50.
// (use `Timeline::init_lsn_lease` for testing so it always does validation)
timeline
.init_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)
.expect("lease renewal with validation should succeed");
let _ =
timeline.make_lsn_lease(Lsn(leased_lsns[1]), timeline.get_lsn_lease_length(), &ctx)?;
Ok(())
}

View File

@@ -1,12 +1,29 @@
use std::collections::HashMap;
use utils::id::TimelineId;
use std::{collections::HashMap, time::Duration};
use super::remote_timeline_client::index::GcBlockingReason;
use tokio::time::Instant;
use utils::id::TimelineId;
type Storage = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
type TimelinesBlocked = HashMap<TimelineId, enumset::EnumSet<GcBlockingReason>>;
/// GcBlock provides persistent (per-timeline) gc blocking.
#[derive(Default)]
struct Storage {
timelines_blocked: TimelinesBlocked,
/// The deadline before which we are blocked from GC so that
/// leases have a chance to be renewed.
lsn_lease_deadline: Option<Instant>,
}
impl Storage {
fn is_blocked_by_lsn_lease_deadline(&self) -> bool {
self.lsn_lease_deadline
.map(|d| Instant::now() < d)
.unwrap_or(false)
}
}
/// GcBlock provides persistent (per-timeline) gc blocking and facilitates transient time based gc
/// blocking.
#[derive(Default)]
pub(crate) struct GcBlock {
/// The timelines which have current reasons to block gc.
@@ -49,6 +66,17 @@ impl GcBlock {
}
}
/// Sets a deadline before which we cannot proceed to GC due to lsn lease.
///
/// We do this as the leases mapping are not persisted to disk. By delaying GC by lease
/// length, we guarantee that all the leases we granted before will have a chance to renew
/// when we run GC for the first time after restart / transition from AttachedMulti to AttachedSingle.
pub(super) fn set_lsn_lease_deadline(&self, lsn_lease_length: Duration) {
let deadline = Instant::now() + lsn_lease_length;
let mut g = self.reasons.lock().unwrap();
g.lsn_lease_deadline = Some(deadline);
}
/// Describe the current gc blocking reasons.
///
/// TODO: make this json serializable.
@@ -74,7 +102,7 @@ impl GcBlock {
) -> anyhow::Result<bool> {
let (added, uploaded) = {
let mut g = self.reasons.lock().unwrap();
let set = g.entry(timeline.timeline_id).or_default();
let set = g.timelines_blocked.entry(timeline.timeline_id).or_default();
let added = set.insert(reason);
// LOCK ORDER: intentionally hold the lock, see self.reasons.
@@ -105,7 +133,7 @@ impl GcBlock {
let (remaining_blocks, uploaded) = {
let mut g = self.reasons.lock().unwrap();
match g.entry(timeline.timeline_id) {
match g.timelines_blocked.entry(timeline.timeline_id) {
Entry::Occupied(mut oe) => {
let set = oe.get_mut();
set.remove(reason);
@@ -119,7 +147,7 @@ impl GcBlock {
}
}
let remaining_blocks = g.len();
let remaining_blocks = g.timelines_blocked.len();
// LOCK ORDER: intentionally hold the lock while scheduling; see self.reasons
let uploaded = timeline
@@ -144,11 +172,11 @@ impl GcBlock {
pub(crate) fn before_delete(&self, timeline: &super::Timeline) {
let unblocked = {
let mut g = self.reasons.lock().unwrap();
if g.is_empty() {
if g.timelines_blocked.is_empty() {
return;
}
g.remove(&timeline.timeline_id);
g.timelines_blocked.remove(&timeline.timeline_id);
BlockingReasons::clean_and_summarize(g).is_none()
};
@@ -159,10 +187,11 @@ impl GcBlock {
}
/// Initialize with the non-deleted timelines of this tenant.
pub(crate) fn set_scanned(&self, scanned: Storage) {
pub(crate) fn set_scanned(&self, scanned: TimelinesBlocked) {
let mut g = self.reasons.lock().unwrap();
assert!(g.is_empty());
g.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
assert!(g.timelines_blocked.is_empty());
g.timelines_blocked
.extend(scanned.into_iter().filter(|(_, v)| !v.is_empty()));
if let Some(reasons) = BlockingReasons::clean_and_summarize(g) {
tracing::info!(summary=?reasons, "initialized with gc blocked");
@@ -176,6 +205,7 @@ pub(super) struct Guard<'a> {
#[derive(Debug)]
pub(crate) struct BlockingReasons {
tenant_blocked_by_lsn_lease_deadline: bool,
timelines: usize,
reasons: enumset::EnumSet<GcBlockingReason>,
}
@@ -184,8 +214,8 @@ impl std::fmt::Display for BlockingReasons {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} timelines block for {:?}",
self.timelines, self.reasons
"tenant_blocked_by_lsn_lease_deadline: {}, {} timelines block for {:?}",
self.tenant_blocked_by_lsn_lease_deadline, self.timelines, self.reasons
)
}
}
@@ -193,13 +223,15 @@ impl std::fmt::Display for BlockingReasons {
impl BlockingReasons {
fn clean_and_summarize(mut g: std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
let mut reasons = enumset::EnumSet::empty();
g.retain(|_key, value| {
g.timelines_blocked.retain(|_key, value| {
reasons = reasons.union(*value);
!value.is_empty()
});
if !g.is_empty() {
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if !g.timelines_blocked.is_empty() || blocked_by_lsn_lease_deadline {
Some(BlockingReasons {
timelines: g.len(),
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
reasons,
})
} else {
@@ -208,14 +240,17 @@ impl BlockingReasons {
}
fn summarize(g: &std::sync::MutexGuard<'_, Storage>) -> Option<Self> {
if g.is_empty() {
let blocked_by_lsn_lease_deadline = g.is_blocked_by_lsn_lease_deadline();
if g.timelines_blocked.is_empty() && !blocked_by_lsn_lease_deadline {
None
} else {
let reasons = g
.timelines_blocked
.values()
.fold(enumset::EnumSet::empty(), |acc, next| acc.union(*next));
Some(BlockingReasons {
timelines: g.len(),
tenant_blocked_by_lsn_lease_deadline: blocked_by_lsn_lease_deadline,
timelines: g.timelines_blocked.len(),
reasons,
})
}

View File

@@ -1470,4 +1470,52 @@ mod tests {
LayerVisibilityHint::Visible
));
}
/// Exercise edge case of querying at exactly the LSN of an image layer
#[test]
fn layer_search_at_image_lsn() {
let tenant_id = TenantId::generate();
let tenant_shard_id = TenantShardId::unsharded(tenant_id);
let timeline_id = TimelineId::generate();
let last_record_lsn = Lsn::from_hex("00000000DEADBEEF").unwrap();
let mut layer_map = LayerMap::default();
let mut updates = layer_map.batch_update();
let image_layer = PersistentLayerDesc {
key_range: Key::from_i128(0)..Key::from_i128(i128::MAX),
lsn_range: PersistentLayerDesc::image_layer_lsn_range(last_record_lsn),
tenant_shard_id,
timeline_id,
is_delta: false,
file_size: 123,
};
let delta_layer = PersistentLayerDesc {
key_range: Key::from_i128(0)..Key::from_i128(i128::MAX),
lsn_range: Lsn(0)..Lsn(0xdead0000),
tenant_shard_id,
timeline_id,
is_delta: true,
file_size: 123,
};
updates.insert_historic(image_layer.clone());
updates.insert_historic(delta_layer);
updates.flush();
// FIXME: according to the search() docstring, it searches for layers with start LSNs _less then_
// `end_lsn` -- i.e. it's correct that if you ask for exactly the LSN of an image layer, it shouldn't hit
// it. However, the way that page_service calls it is to take the last_record_lsn of a Timeline
// and pass that directly into LayerMap::search().
let searched = layer_map
.search(Key::from_i128(12345), last_record_lsn)
.unwrap();
// We searched at the LSN of the image layer: we should hit it
assert_eq!(searched.layer.as_ref(), &image_layer);
}
}

View File

@@ -30,8 +30,8 @@ use utils::{backoff, completion, crashsafe};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::controller_upcall_client::{
ControlPlaneGenerationsApi, ControllerUpcallClient, RetryForeverError,
use crate::control_plane_client::{
ControlPlaneClient, ControlPlaneGenerationsApi, RetryForeverError,
};
use crate::deletion_queue::DeletionQueueClient;
use crate::http::routes::ACTIVE_TENANT_TIMEOUT;
@@ -122,7 +122,7 @@ pub(crate) enum ShardSelector {
Known(ShardIndex),
}
/// A convenience for use with the re_attach ControllerUpcallClient function: rather
/// A convenience for use with the re_attach ControlPlaneClient function: rather
/// than the serializable struct, we build this enum that encapsulates
/// the invariant that attached tenants always have generations.
///
@@ -219,11 +219,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef<Utf8Path>) -> std::io::Result<U
+ TEMP_FILE_SUFFIX;
let tmp_path = path_with_suffix_extension(&path, &rand_suffix);
fs::rename(path.as_ref(), &tmp_path).await?;
fs::File::open(parent)
.await?
.sync_all()
.await
.maybe_fatal_err("safe_rename_tenant_dir")?;
fs::File::open(parent).await?.sync_all().await?;
Ok(tmp_path)
}
@@ -345,7 +341,7 @@ async fn init_load_generations(
"Emergency mode! Tenants will be attached unsafely using their last known generation"
);
emergency_generations(tenant_confs)
} else if let Some(client) = ControllerUpcallClient::new(conf, cancel) {
} else if let Some(client) = ControlPlaneClient::new(conf, cancel) {
info!("Calling control plane API to re-attach tenants");
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
match client.re_attach(conf).await {
@@ -953,6 +949,12 @@ impl TenantManager {
(LocationMode::Attached(attach_conf), Some(TenantSlot::Attached(tenant))) => {
match attach_conf.generation.cmp(&tenant.generation) {
Ordering::Equal => {
if attach_conf.attach_mode == AttachmentMode::Single {
tenant
.gc_block
.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
}
// A transition from Attached to Attached in the same generation, we may
// take our fast path and just provide the updated configuration
// to the tenant.
@@ -2197,82 +2199,6 @@ impl TenantManager {
Ok((wanted_bytes, shard_count as u32))
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), %timeline_id))]
pub(crate) async fn immediate_gc(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,
cancel: CancellationToken,
ctx: &RequestContext,
) -> Result<GcResult, ApiError> {
let tenant = {
let guard = self.tenants.read().unwrap();
guard
.get(&tenant_shard_id)
.cloned()
.with_context(|| format!("tenant {tenant_shard_id}"))
.map_err(|e| ApiError::NotFound(e.into()))?
};
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
// Use tenant's pitr setting
let pitr = tenant.get_pitr_interval();
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
// Run in task_mgr to avoid race with tenant_detach operation
let ctx: RequestContext =
ctx.detached_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let _gate_guard = tenant.gate.enter().map_err(|_| ApiError::ShuttingDown)?;
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)
.await;
// FIXME: `gc_iteration` can return an error for multiple reasons; we should handle it
// better once the types support it.
#[cfg(feature = "testing")]
{
// we need to synchronize with drop completion for python tests without polling for
// log messages
if let Ok(result) = result.as_mut() {
let mut js = tokio::task::JoinSet::new();
for layer in std::mem::take(&mut result.doomed_layers) {
js.spawn(layer.wait_drop());
}
tracing::info!(
total = js.len(),
"starting to wait for the gc'd layers to be dropped"
);
while let Some(res) = js.join_next().await {
res.expect("wait_drop should not panic");
}
}
let timeline = tenant.get_timeline(timeline_id, false).ok();
let rtc = timeline.as_ref().map(|x| &x.remote_client);
if let Some(rtc) = rtc {
// layer drops schedule actions on remote timeline client to actually do the
// deletions; don't care about the shutdown error, just exit fast
drop(rtc.wait_completion().await);
}
}
result.map_err(|e| match e {
GcError::TenantCancelled | GcError::TimelineCancelled => ApiError::ShuttingDown,
GcError::TimelineNotFound => {
ApiError::NotFound(anyhow::anyhow!("Timeline not found").into())
}
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
}
#[derive(Debug, thiserror::Error)]
@@ -2417,7 +2343,7 @@ enum TenantSlotDropError {
/// Errors that can happen any time we are walking the tenant map to try and acquire
/// the TenantSlot for a particular tenant.
#[derive(Debug, thiserror::Error)]
pub(crate) enum TenantMapError {
pub enum TenantMapError {
// Tried to read while initializing
#[error("tenant map is still initializing")]
StillInitializing,
@@ -2447,7 +2373,7 @@ pub(crate) enum TenantMapError {
/// The `old_value` may be dropped before the SlotGuard is dropped, by calling
/// `drop_old_value`. It is an error to call this without shutting down
/// the conents of `old_value`.
pub(crate) struct SlotGuard {
pub struct SlotGuard {
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
upserted: bool,
@@ -2840,6 +2766,81 @@ use {
utils::http::error::ApiError,
};
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), %timeline_id))]
pub(crate) async fn immediate_gc(
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
gc_req: TimelineGcRequest,
cancel: CancellationToken,
ctx: &RequestContext,
) -> Result<GcResult, ApiError> {
let tenant = {
let guard = TENANTS.read().unwrap();
guard
.get(&tenant_shard_id)
.cloned()
.with_context(|| format!("tenant {tenant_shard_id}"))
.map_err(|e| ApiError::NotFound(e.into()))?
};
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
// Use tenant's pitr setting
let pitr = tenant.get_pitr_interval();
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
// Run in task_mgr to avoid race with tenant_detach operation
let ctx: RequestContext =
ctx.detached_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let _gate_guard = tenant.gate.enter().map_err(|_| ApiError::ShuttingDown)?;
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)
.await;
// FIXME: `gc_iteration` can return an error for multiple reasons; we should handle it
// better once the types support it.
#[cfg(feature = "testing")]
{
// we need to synchronize with drop completion for python tests without polling for
// log messages
if let Ok(result) = result.as_mut() {
let mut js = tokio::task::JoinSet::new();
for layer in std::mem::take(&mut result.doomed_layers) {
js.spawn(layer.wait_drop());
}
tracing::info!(
total = js.len(),
"starting to wait for the gc'd layers to be dropped"
);
while let Some(res) = js.join_next().await {
res.expect("wait_drop should not panic");
}
}
let timeline = tenant.get_timeline(timeline_id, false).ok();
let rtc = timeline.as_ref().map(|x| &x.remote_client);
if let Some(rtc) = rtc {
// layer drops schedule actions on remote timeline client to actually do the
// deletions; don't care about the shutdown error, just exit fast
drop(rtc.wait_completion().await);
}
}
result.map_err(|e| match e {
GcError::TenantCancelled | GcError::TimelineCancelled => ApiError::ShuttingDown,
GcError::TimelineNotFound => {
ApiError::NotFound(anyhow::anyhow!("Timeline not found").into())
}
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;

View File

@@ -178,7 +178,6 @@ async fn download_object<'a>(
destination_file
.flush()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("flush source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -186,7 +185,6 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
@@ -234,7 +232,6 @@ async fn download_object<'a>(
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;

View File

@@ -433,6 +433,7 @@ impl ReadableLayer {
reconstruct_state: &mut ValuesReconstructState,
ctx: &RequestContext,
) -> Result<(), GetVectoredError> {
tracing::info!("get_values_reconstruct_data: {:?}", self.id());
match self {
ReadableLayer::PersistentLayer(layer) => {
layer

View File

@@ -40,11 +40,11 @@ use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadPlanner,
VectoredReadCoalesceMode, VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::virtual_file::{self, VirtualFile};
use crate::{walrecord, TEMP_FILE_SUFFIX};
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
use anyhow::{anyhow, bail, ensure, Context, Result};
@@ -589,9 +589,7 @@ impl DeltaLayerWriterInner {
);
// fsync the file
file.sync_all()
.await
.maybe_fatal_err("delta_layer sync_all")?;
file.sync_all().await?;
trace!("created delta layer {}", self.path);
@@ -1135,7 +1133,7 @@ impl DeltaLayerInner {
ctx: &RequestContext,
) -> anyhow::Result<usize> {
use crate::tenant::vectored_blob_io::{
BlobMeta, ChunkedVectoredReadBuilder, VectoredReadExtended,
BlobMeta, VectoredReadBuilder, VectoredReadExtended,
};
use futures::stream::TryStreamExt;
@@ -1185,8 +1183,8 @@ impl DeltaLayerInner {
let mut prev: Option<(Key, Lsn, BlobRef)> = None;
let mut read_builder: Option<ChunkedVectoredReadBuilder> = None;
let align = virtual_file::get_io_buffer_alignment();
let mut read_builder: Option<VectoredReadBuilder> = None;
let read_mode = VectoredReadCoalesceMode::get();
let max_read_size = self
.max_vectored_read_bytes
@@ -1230,12 +1228,12 @@ impl DeltaLayerInner {
{
None
} else {
read_builder.replace(ChunkedVectoredReadBuilder::new(
read_builder.replace(VectoredReadBuilder::new(
offsets.start.pos(),
offsets.end.pos(),
meta,
max_read_size,
align,
read_mode,
))
}
} else {

View File

@@ -41,7 +41,7 @@ use crate::tenant::vectored_blob_io::{
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::virtual_file::{self, MaybeFatalIo, VirtualFile};
use crate::virtual_file::{self, VirtualFile};
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
@@ -889,9 +889,7 @@ impl ImageLayerWriterInner {
// set inner.file here. The first read will have to re-open it.
// fsync the file
file.sync_all()
.await
.maybe_fatal_err("image_layer sync_all")?;
file.sync_all().await?;
trace!("created image layer {}", self.path);

View File

@@ -330,6 +330,7 @@ async fn gc_loop(tenant: Arc<Tenant>, cancel: CancellationToken) {
RequestContext::todo_child(TaskKind::GarbageCollector, DownloadBehavior::Download);
let mut first = true;
tenant.gc_block.set_lsn_lease_deadline(tenant.get_lsn_lease_length());
loop {
tokio::select! {
_ = cancel.cancelled() => {

View File

@@ -66,7 +66,6 @@ use std::{
use crate::{
aux_file::AuxFileSizeEstimator,
tenant::{
config::AttachmentMode,
layer_map::{LayerMap, SearchResult},
metadata::TimelineMetadata,
storage_layer::{inmemory_layer::IndexEntry, PersistentLayerDesc},
@@ -1325,38 +1324,16 @@ impl Timeline {
Ok(())
}
/// Initializes an LSN lease. The function will return an error if the requested LSN is less than the `latest_gc_cutoff_lsn`.
pub(crate) fn init_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, true, ctx)
}
/// Renews a lease at a particular LSN. The requested LSN is not validated against the `latest_gc_cutoff_lsn` when we are in the grace period.
pub(crate) fn renew_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
self.make_lsn_lease(lsn, length, false, ctx)
}
/// Obtains a temporary lease blocking garbage collection for the given LSN.
///
/// If we are in `AttachedSingle` mode and is not blocked by the lsn lease deadline, this function will error
/// if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is no existing request present.
///
/// If there is an existing lease in the map, the lease will be renewed only if the request extends the lease.
/// The returned lease is therefore the maximum between the existing lease and the requesting lease.
fn make_lsn_lease(
/// This function will error if the requesting LSN is less than the `latest_gc_cutoff_lsn` and there is also
/// no existing lease to renew. If there is an existing lease in the map, the lease will be renewed only if
/// the request extends the lease. The returned lease is therefore the maximum between the existing lease and
/// the requesting lease.
pub(crate) fn make_lsn_lease(
&self,
lsn: Lsn,
length: Duration,
init: bool,
_ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
let lease = {
@@ -1370,8 +1347,8 @@ impl Timeline {
let entry = gc_info.leases.entry(lsn);
match entry {
Entry::Occupied(mut occupied) => {
let lease = {
if let Entry::Occupied(mut occupied) = entry {
let existing_lease = occupied.get_mut();
if valid_until > existing_lease.valid_until {
existing_lease.valid_until = valid_until;
@@ -1383,28 +1360,20 @@ impl Timeline {
}
existing_lease.clone()
}
Entry::Vacant(vacant) => {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff) if we are in AttachedSingle and
// not blocked by the lsn lease deadline.
let validate = {
let conf = self.tenant_conf.load();
conf.location.attach_mode == AttachmentMode::Single
&& !conf.is_gc_blocked_by_lsn_lease_deadline()
};
if init || validate {
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
} else {
// Reject already GC-ed LSN (lsn < latest_gc_cutoff)
let latest_gc_cutoff_lsn = self.get_latest_gc_cutoff_lsn();
if lsn < *latest_gc_cutoff_lsn {
bail!("tried to request a page version that was garbage collected. requested at {} gc cutoff {}", lsn, *latest_gc_cutoff_lsn);
}
let dt: DateTime<Utc> = valid_until.into();
info!("lease created, valid until {}", dt);
vacant.insert(LsnLease { valid_until }).clone()
entry.or_insert(LsnLease { valid_until }).clone()
}
}
};
lease
};
Ok(lease)
@@ -1981,6 +1950,8 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.lsn_lease_length)
}
// TODO(yuchen): remove unused flag after implementing https://github.com/neondatabase/neon/issues/8072
#[allow(unused)]
pub(crate) fn get_lsn_lease_length_for_ts(&self) -> Duration {
let tenant_conf = self.tenant_conf.load();
tenant_conf
@@ -3885,21 +3856,21 @@ impl Timeline {
)));
}
let distance = lsn.0 - partition_lsn.0;
if *partition_lsn != Lsn(0)
&& distance <= self.repartition_threshold
&& !flags.contains(CompactFlags::ForceRepartition)
{
debug!(
distance,
threshold = self.repartition_threshold,
"no repartitioning needed"
);
return Ok((
(dense_partition.clone(), sparse_partition.clone()),
*partition_lsn,
));
}
// let distance = lsn.0 - partition_lsn.0;
// if *partition_lsn != Lsn(0)
// && distance <= self.repartition_threshold
// && !flags.contains(CompactFlags::ForceRepartition)
// {
// debug!(
// distance,
// threshold = self.repartition_threshold,
// "no repartitioning needed"
// );
// return Ok((
// (dense_partition.clone(), sparse_partition.clone()),
// *partition_lsn,
// ));
// }
let (dense_ks, sparse_ks) = self.collect_keyspace(lsn, ctx).await?;
let dense_partitioning = dense_ks.partition(&self.shard_identity, partition_size);
@@ -5808,6 +5779,7 @@ impl<'a> TimelineWriter<'a> {
/// the 'lsn' or anything older. The previous last record LSN is stored alongside
/// the latest and can be read.
pub(crate) fn finish_write(&self, new_lsn: Lsn) {
tracing::info!("finish_write @ {new_lsn}");
self.tl.finish_write(new_lsn);
}

View File

@@ -364,6 +364,10 @@ impl Timeline {
// 3. Create new image layers for partitions that have been modified
// "enough". Skip image layer creation if L0 compaction cannot keep up.
if fully_compacted {
tracing::info!(
"create_image_layers @ {lsn} (latest {})",
self.get_last_record_lsn()
);
let image_layers = self
.create_image_layers(
&partitioning,

View File

@@ -185,7 +185,171 @@ pub(crate) enum VectoredReadExtended {
No,
}
/// A vectored read builder that tries to coalesce all reads that fits in a chunk.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum VectoredReadCoalesceMode {
/// Only coalesce exactly adjacent reads.
AdjacentOnly,
/// In addition to adjacent reads, also consider reads whose corresponding
/// `end` and `start` offsets reside at the same chunk.
Chunked(usize),
}
impl VectoredReadCoalesceMode {
/// [`AdjacentVectoredReadBuilder`] is used if alignment requirement is 0,
/// whereas [`ChunkedVectoredReadBuilder`] is used for alignment requirement 1 and higher.
pub(crate) fn get() -> Self {
let align = virtual_file::get_io_buffer_alignment_raw();
if align == 0 {
VectoredReadCoalesceMode::AdjacentOnly
} else {
VectoredReadCoalesceMode::Chunked(align)
}
}
}
pub(crate) enum VectoredReadBuilder {
Adjacent(AdjacentVectoredReadBuilder),
Chunked(ChunkedVectoredReadBuilder),
}
impl VectoredReadBuilder {
fn new_impl(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: Option<usize>,
mode: VectoredReadCoalesceMode,
) -> Self {
match mode {
VectoredReadCoalesceMode::AdjacentOnly => Self::Adjacent(
AdjacentVectoredReadBuilder::new(start_offset, end_offset, meta, max_read_size),
),
VectoredReadCoalesceMode::Chunked(chunk_size) => {
Self::Chunked(ChunkedVectoredReadBuilder::new(
start_offset,
end_offset,
meta,
max_read_size,
chunk_size,
))
}
}
}
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
mode: VectoredReadCoalesceMode,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), mode)
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
mode: VectoredReadCoalesceMode,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, mode)
}
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.extend(start, end, meta),
VectoredReadBuilder::Chunked(builder) => builder.extend(start, end, meta),
}
}
pub(crate) fn build(self) -> VectoredRead {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.build(),
VectoredReadBuilder::Chunked(builder) => builder.build(),
}
}
pub(crate) fn size(&self) -> usize {
match self {
VectoredReadBuilder::Adjacent(builder) => builder.size(),
VectoredReadBuilder::Chunked(builder) => builder.size(),
}
}
}
pub(crate) struct AdjacentVectoredReadBuilder {
/// Start offset of the read.
start: u64,
// End offset of the read.
end: u64,
/// Start offset and metadata for each blob in this read
blobs_at: VecMap<u64, BlobMeta>,
max_read_size: Option<usize>,
}
impl AdjacentVectoredReadBuilder {
/// Start building a new vectored read.
///
/// Note that by design, this does not check against reading more than `max_read_size` to
/// support reading larger blobs than the configuration value. The builder will be single use
/// however after that.
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: Option<usize>,
) -> Self {
let mut blobs_at = VecMap::default();
blobs_at
.append(start_offset, meta)
.expect("First insertion always succeeds");
Self {
start: start_offset,
end: end_offset,
blobs_at,
max_read_size,
}
}
/// Attempt to extend the current read with a new blob if the start
/// offset matches with the current end of the vectored read
/// and the resuting size is below the max read size
pub(crate) fn extend(&mut self, start: u64, end: u64, meta: BlobMeta) -> VectoredReadExtended {
tracing::trace!(start, end, "trying to extend");
let size = (end - start) as usize;
let not_limited_by_max_read_size = {
if let Some(max_read_size) = self.max_read_size {
self.size() + size <= max_read_size
} else {
true
}
};
if self.end == start && not_limited_by_max_read_size {
self.end = end;
self.blobs_at
.append(start, meta)
.expect("LSNs are ordered within vectored reads");
return VectoredReadExtended::Yes;
}
VectoredReadExtended::No
}
pub(crate) fn size(&self) -> usize {
(self.end - self.start) as usize
}
pub(crate) fn build(self) -> VectoredRead {
VectoredRead {
start: self.start,
end: self.end,
blobs_at: self.blobs_at,
}
}
}
pub(crate) struct ChunkedVectoredReadBuilder {
/// Start block number
start_blk_no: usize,
@@ -209,7 +373,7 @@ impl ChunkedVectoredReadBuilder {
/// Note that by design, this does not check against reading more than `max_read_size` to
/// support reading larger blobs than the configuration value. The builder will be single use
/// however after that.
fn new_impl(
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
@@ -232,25 +396,6 @@ impl ChunkedVectoredReadBuilder {
}
}
pub(crate) fn new(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
max_read_size: usize,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, Some(max_read_size), align)
}
pub(crate) fn new_streaming(
start_offset: u64,
end_offset: u64,
meta: BlobMeta,
align: usize,
) -> Self {
Self::new_impl(start_offset, end_offset, meta, None, align)
}
/// Attempts to extend the current read with a new blob if the new blob resides in the same or the immediate next chunk.
///
/// The resulting size also must be below the max read size.
@@ -329,17 +474,17 @@ pub struct VectoredReadPlanner {
max_read_size: usize,
align: usize,
mode: VectoredReadCoalesceMode,
}
impl VectoredReadPlanner {
pub fn new(max_read_size: usize) -> Self {
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
Self {
blobs: BTreeMap::new(),
prev: None,
max_read_size,
align,
mode,
}
}
@@ -400,7 +545,7 @@ impl VectoredReadPlanner {
}
pub fn finish(self) -> Vec<VectoredRead> {
let mut current_read_builder: Option<ChunkedVectoredReadBuilder> = None;
let mut current_read_builder: Option<VectoredReadBuilder> = None;
let mut reads = Vec::new();
for (key, blobs_for_key) in self.blobs {
@@ -413,12 +558,12 @@ impl VectoredReadPlanner {
};
if extended == VectoredReadExtended::No {
let next_read_builder = ChunkedVectoredReadBuilder::new(
let next_read_builder = VectoredReadBuilder::new(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.max_read_size,
self.align,
self.mode,
);
let prev_read_builder = current_read_builder.replace(next_read_builder);
@@ -543,7 +688,7 @@ impl<'a> VectoredBlobReader<'a> {
/// `handle` gets called and when the current key would just exceed the read_size and
/// max_cnt constraints.
pub struct StreamingVectoredReadPlanner {
read_builder: Option<ChunkedVectoredReadBuilder>,
read_builder: Option<VectoredReadBuilder>,
// Arguments for previous blob passed into [`StreamingVectoredReadPlanner::handle`]
prev: Option<(Key, Lsn, u64)>,
/// Max read size per batch. This is not a strict limit. If there are [0, 100) and [100, 200), while the `max_read_size` is 150,
@@ -554,21 +699,21 @@ pub struct StreamingVectoredReadPlanner {
/// Size of the current batch
cnt: usize,
align: usize,
mode: VectoredReadCoalesceMode,
}
impl StreamingVectoredReadPlanner {
pub fn new(max_read_size: u64, max_cnt: usize) -> Self {
assert!(max_cnt > 0);
assert!(max_read_size > 0);
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
Self {
read_builder: None,
prev: None,
max_cnt,
max_read_size,
cnt: 0,
align,
mode,
}
}
@@ -617,11 +762,11 @@ impl StreamingVectoredReadPlanner {
}
None => {
self.read_builder = {
Some(ChunkedVectoredReadBuilder::new_streaming(
Some(VectoredReadBuilder::new_streaming(
start_offset,
end_offset,
BlobMeta { key, lsn },
self.align,
self.mode,
))
};
}
@@ -947,7 +1092,7 @@ mod tests {
let reserved_bytes = blobs.iter().map(|bl| bl.len()).max().unwrap() * 2 + 16;
let mut buf = BytesMut::with_capacity(reserved_bytes);
let align = virtual_file::get_io_buffer_alignment();
let mode = VectoredReadCoalesceMode::get();
let vectored_blob_reader = VectoredBlobReader::new(&file);
let meta = BlobMeta {
key: Key::MIN,
@@ -959,8 +1104,7 @@ mod tests {
if idx + 1 == offsets.len() {
continue;
}
let read_builder =
ChunkedVectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, align);
let read_builder = VectoredReadBuilder::new(*offset, *end, meta, 16 * 4096, mode);
let read = read_builder.build();
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
assert_eq!(result.blobs.len(), 1);

View File

@@ -466,7 +466,6 @@ impl VirtualFile {
&[]
};
utils::crashsafe::overwrite(&final_path, &tmp_path, content)
.maybe_fatal_err("crashsafe_overwrite")
})
.await
.expect("blocking task is never aborted")
@@ -476,7 +475,7 @@ impl VirtualFile {
pub async fn sync_all(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_all(file_guard).await;
res.maybe_fatal_err("sync_all")
res
})
}
@@ -484,7 +483,7 @@ impl VirtualFile {
pub async fn sync_data(&self) -> Result<(), Error> {
with_file!(self, StorageIoOperation::Fsync, |file_guard| {
let (_file_guard, res) = io_engine::get().sync_data(file_guard).await;
res.maybe_fatal_err("sync_data")
res
})
}
@@ -1148,9 +1147,7 @@ pub fn init(num_slots: usize, engine: IoEngineKind, io_buffer_alignment: usize)
panic!("virtual_file::init called twice");
}
if set_io_buffer_alignment(io_buffer_alignment).is_err() {
panic!(
"IO buffer alignment needs to be a power of two and greater than 512, got {io_buffer_alignment}"
);
panic!("IO buffer alignment ({io_buffer_alignment}) is not a power of two");
}
io_engine::init(engine);
crate::metrics::virtual_file_descriptor_cache::SIZE_MAX.set(num_slots as u64);
@@ -1177,16 +1174,14 @@ fn get_open_files() -> &'static OpenFiles {
static IO_BUFFER_ALIGNMENT: AtomicUsize = AtomicUsize::new(DEFAULT_IO_BUFFER_ALIGNMENT);
/// Returns true if the alignment is a power of two and is greater or equal to 512.
fn is_valid_io_buffer_alignment(align: usize) -> bool {
align.is_power_of_two() && align >= 512
/// Returns true if `x` is zero or a power of two.
fn is_zero_or_power_of_two(x: usize) -> bool {
(x == 0) || ((x & (x - 1)) == 0)
}
/// Sets IO buffer alignment requirement. Returns error if the alignment requirement is
/// not a power of two or less than 512 bytes.
#[allow(unused)]
pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
if is_valid_io_buffer_alignment(align) {
if is_zero_or_power_of_two(align) {
IO_BUFFER_ALIGNMENT.store(align, std::sync::atomic::Ordering::Relaxed);
Ok(())
} else {
@@ -1194,19 +1189,19 @@ pub(crate) fn set_io_buffer_alignment(align: usize) -> Result<(), usize> {
}
}
/// Gets the io buffer alignment.
/// Gets the io buffer alignment requirement. Returns 0 if there is no requirement specified.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
/// This function should be used to check the raw config value.
pub(crate) fn get_io_buffer_alignment_raw() -> usize {
let align = IO_BUFFER_ALIGNMENT.load(std::sync::atomic::Ordering::Relaxed);
if cfg!(test) {
let env_var_name = "NEON_PAGESERVER_UNIT_TEST_IO_BUFFER_ALIGNMENT";
if let Some(test_align) = utils::env::var(env_var_name) {
if is_valid_io_buffer_alignment(test_align) {
if is_zero_or_power_of_two(test_align) {
test_align
} else {
panic!("IO buffer alignment needs to be a power of two and greater than 512, got {test_align}");
panic!("IO buffer alignment ({test_align}) is not a power of two");
}
} else {
align
@@ -1216,6 +1211,14 @@ pub(crate) fn get_io_buffer_alignment() -> usize {
}
}
/// Gets the io buffer alignment requirement. Returns 1 if the alignment config is set to zero.
///
/// This function should be used for getting the actual alignment value to use.
pub(crate) fn get_io_buffer_alignment() -> usize {
let align = get_io_buffer_alignment_raw();
align.max(1)
}
#[cfg(test)]
mod tests {
use crate::context::DownloadBehavior;

View File

@@ -1,6 +1,8 @@
# neon extension
comment = 'cloud storage for PostgreSQL'
default_version = '1.5'
# TODO: bump default version to 1.5, after we are certain that we don't
# need to rollback the compute image
default_version = '1.4'
module_pathname = '$libdir/neon'
relocatable = true
trusted = true

View File

@@ -1473,33 +1473,11 @@ walprop_pg_wal_read(Safekeeper *sk, char *buf, XLogRecPtr startptr, Size count,
{
NeonWALReadResult res;
#if PG_MAJORVERSION_NUM >= 17
if (!sk->wp->config->syncSafekeepers)
{
Size rbytes;
rbytes = WALReadFromBuffers(buf, startptr, count,
walprop_pg_get_timeline_id());
startptr += rbytes;
count -= rbytes;
}
#endif
if (count == 0)
{
res = NEON_WALREAD_SUCCESS;
}
else
{
Assert(count > 0);
/* Now read the remaining WAL from the WAL file */
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
}
res = NeonWALRead(sk->xlogreader,
buf,
startptr,
count,
walprop_pg_get_timeline_id());
if (res == NEON_WALREAD_SUCCESS)
{

View File

@@ -24,7 +24,6 @@ bytes = { workspace = true, features = ["serde"] }
camino.workspace = true
chrono.workspace = true
clap.workspace = true
compute_api.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
@@ -82,6 +81,7 @@ tokio-postgres-rustls.workspace = true
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
tracing-opentelemetry.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

View File

@@ -80,14 +80,6 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
}
impl std::fmt::Display for Backend<'_, (), ()> {
@@ -565,7 +557,7 @@ mod tests {
stream::{PqStream, Stream},
};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
use super::{auth_quirks, AuthRateLimiter};
struct Auth {
ips: Vec<IpPattern>,
@@ -593,14 +585,6 @@ mod tests {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
@@ -611,15 +595,12 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -1,6 +1,5 @@
use std::{
future::Future,
marker::PhantomData,
sync::Arc,
time::{Duration, SystemTime},
};
@@ -9,14 +8,11 @@ use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{de::Visitor, Deserialize, Deserializer};
use serde::{Deserialize, Deserializer};
use signature::Verifier;
use tokio::time::Instant;
use crate::{
context::RequestMonitoring, http::parse_json_body_with_limit, intern::RoleNameInt, EndpointId,
RoleName,
};
use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
@@ -31,6 +27,7 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
@@ -38,11 +35,10 @@ pub(crate) struct AuthRule {
pub(crate) id: String,
pub(crate) jwks_url: url::Url,
pub(crate) audience: Option<String>,
pub(crate) role_names: Vec<RoleNameInt>,
}
#[derive(Default)]
pub struct JwkCache {
pub(crate) struct JwkCache {
client: reqwest::Client,
map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
@@ -58,28 +54,18 @@ pub(crate) struct JwkCacheEntry {
}
impl JwkCacheEntry {
fn find_jwk_and_audience(
&self,
key_id: &str,
role_name: &RoleName,
) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets
.values()
// make sure our requested role has access to the key set
.filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
// try and find the requested key-id in the key set
.find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets.values().find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
}
}
struct KeySet {
jwks: jose_jwk::JwkSet,
audience: Option<String>,
role_names: Vec<RoleNameInt>,
}
impl KeySet {
@@ -120,6 +106,7 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
auth_rules: &F,
) -> anyhow::Result<Arc<JwkCacheEntry>> {
// double check that no one beat us to updating the cache.
@@ -132,10 +119,11 @@ impl JwkCacheEntryLock {
}
}
let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
let rules = auth_rules
.fetch_auth_rules(ctx, endpoint, role_name)
.await?;
let mut key_sets =
ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
@@ -148,15 +136,14 @@ impl JwkCacheEntryLock {
Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"),
Ok(r) => {
let resp: http::Response<reqwest::Body> = r.into();
match parse_json_body_with_limit::<jose_jwk::JwkSet, _>(
PhantomData,
match parse_json_body_with_limit::<jose_jwk::JwkSet>(
resp.into_body(),
MAX_JWK_BODY_SIZE,
)
.await
{
Err(e) => {
tracing::warn!(url=?rule.jwks_url, error=%e, "could not decode JWKs");
tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs");
}
Ok(jwks) => {
key_sets.insert(
@@ -164,7 +151,6 @@ impl JwkCacheEntryLock {
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
@@ -187,6 +173,7 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
let now = Instant::now();
@@ -196,7 +183,9 @@ impl JwkCacheEntryLock {
let Some(cached) = guard else {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
};
let last_update = now.duration_since(cached.last_retrieved);
@@ -207,7 +196,9 @@ impl JwkCacheEntryLock {
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
}
// every 5 minutes we should spawn a job to eagerly update the token.
@@ -221,7 +212,7 @@ impl JwkCacheEntryLock {
let ctx = ctx.clone();
tokio::spawn(async move {
if let Err(e) = entry
.renew_jwks(permit, &ctx, &client, endpoint, &fetch)
.renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
.await
{
tracing::warn!(error=?e, "could not fetch JWKs in background job");
@@ -241,7 +232,7 @@ impl JwkCacheEntryLock {
jwt: &str,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: &RoleName,
role_name: RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
@@ -263,22 +254,30 @@ impl JwkCacheEntryLock {
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
ensure!(header.typ == "JWT");
let kid = header.key_id.context("missing key id")?;
let mut guard = self
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
.await?;
// get the key from the JWKs if possible. If not, wait for the keys to update.
let (jwk, expected_audience) = loop {
match guard.find_jwk_and_audience(kid, role_name) {
match guard.find_jwk_and_audience(kid) {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
guard = self
.renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
.renew_jwks(
permit,
ctx,
client,
endpoint.clone(),
role_name.clone(),
fetch,
)
.await?;
}
_ => {
@@ -297,7 +296,7 @@ impl JwkCacheEntryLock {
verify_ec_signature(header_payload.as_bytes(), &sig, key)?;
}
jose_jwk::Key::Rsa(key) => {
verify_rsa_signature(header_payload.as_bytes(), &sig, key, &header.algorithm)?;
verify_rsa_signature(header_payload.as_bytes(), &sig, key, &jwk.prm.alg)?;
}
key => bail!("unsupported key type {key:?}"),
};
@@ -309,24 +308,23 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
if let Some(aud) = expected_audience {
ensure!(
payload.audience.0.iter().any(|s| s == aud),
"invalid JWT token audience"
);
match (expected_audience, payload.audience) {
// check the audience matches
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY, "JWT token has expired");
ensure!(now < exp + CLOCK_SKEW_LEEWAY);
}
if let Some(nbf) = payload.not_before {
ensure!(
nbf < now + CLOCK_SKEW_LEEWAY,
"JWT token is not yet ready to use"
);
ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
}
Ok(())
@@ -338,7 +336,7 @@ impl JwkCache {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: &RoleName,
role_name: RoleName,
fetch: &F,
jwt: &str,
) -> Result<(), anyhow::Error> {
@@ -379,7 +377,7 @@ fn verify_rsa_signature(
data: &[u8],
sig: &[u8],
key: &jose_jwk::Rsa,
alg: &jose_jwa::Algorithm,
alg: &Option<jose_jwa::Algorithm>,
) -> anyhow::Result<()> {
use jose_jwa::{Algorithm, Signing};
use rsa::{
@@ -390,7 +388,7 @@ fn verify_rsa_signature(
let key = RsaPublicKey::try_from(key).map_err(|_| anyhow::anyhow!("invalid RSA key"))?;
match alg {
Algorithm::Signing(Signing::Rs256) => {
Some(Algorithm::Signing(Signing::Rs256)) => {
let key = VerifyingKey::<sha2::Sha256>::new(key);
let sig = Signature::try_from(sig)?;
key.verify(data, &sig)?;
@@ -404,6 +402,9 @@ fn verify_rsa_signature(
/// <https://datatracker.ietf.org/doc/html/rfc7515#section-4.1>
#[derive(serde::Deserialize, serde::Serialize)]
struct JwtHeader<'a> {
/// must be "JWT"
#[serde(rename = "typ")]
typ: &'a str,
/// must be a supported alg
#[serde(rename = "alg")]
algorithm: jose_jwa::Algorithm,
@@ -413,12 +414,11 @@ struct JwtHeader<'a> {
}
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
#[derive(serde::Deserialize, Debug)]
#[allow(dead_code)]
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
/// Audience - Recipient for which the JWT is intended
#[serde(rename = "aud", default)]
audience: OneOrMany,
#[serde(rename = "aud")]
audience: Option<&'a str>,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
@@ -441,59 +441,6 @@ struct JwtPayload<'a> {
session_id: Option<&'a str>,
}
/// `OneOrMany` supports parsing either a single item or an array of items.
///
/// Needed for <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1.3>
///
/// > The "aud" (audience) claim identifies the recipients that the JWT is
/// > intended for. Each principal intended to process the JWT MUST
/// > identify itself with a value in the audience claim. If the principal
/// > processing the claim does not identify itself with a value in the
/// > "aud" claim when this claim is present, then the JWT MUST be
/// > rejected. In the general case, the "aud" value is **an array of case-
/// > sensitive strings**, each containing a StringOrURI value. In the
/// > special case when the JWT has one audience, the "aud" value MAY be a
/// > **single case-sensitive string** containing a StringOrURI value. The
/// > interpretation of audience values is generally application specific.
/// > Use of this claim is OPTIONAL.
#[derive(Default, Debug)]
struct OneOrMany(Vec<String>);
impl<'de> Deserialize<'de> for OneOrMany {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct OneOrManyVisitor;
impl<'de> Visitor<'de> for OneOrManyVisitor {
type Value = OneOrMany;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single string or an array of strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(OneOrMany(vec![v.to_owned()]))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut v = vec![];
while let Some(s) = seq.next_element()? {
v.push(s);
}
Ok(OneOrMany(v))
}
}
deserializer.deserialize_any(OneOrManyVisitor)
}
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
@@ -587,6 +534,7 @@ mod tests {
key: jose_jwk::Key::Ec(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Es256)),
..Default::default()
},
};
@@ -600,6 +548,7 @@ mod tests {
key: jose_jwk::Key::Rsa(pk),
prm: jose_jwk::Parameters {
kid: Some(kid),
alg: Some(jose_jwa::Algorithm::Signing(jose_jwa::Signing::Rs256)),
..Default::default()
},
};
@@ -608,6 +557,7 @@ mod tests {
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
let header = JwtHeader {
typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid),
};
@@ -622,7 +572,7 @@ mod tests {
format!("{header}.{body}")
}
fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
use p256::ecdsa::{Signature, SigningKey};
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
@@ -710,6 +660,11 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), ec1);
let jwt4 = new_ec_jwt("4".into(), ec2);
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
@@ -751,98 +706,47 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
let client = reqwest::Client::new();
#[derive(Clone)]
struct Fetch(SocketAddr, Vec<RoleNameInt>);
struct Fetch(SocketAddr);
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
_role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
role_names: self.1.clone(),
},
])
}
}
let role_name1 = RoleName::from("anonymous");
let role_name2 = RoleName::from("authenticated");
let fetch = Fetch(
addr,
vec![
RoleNameInt::from(&role_name1),
RoleNameInt::from(&role_name2),
],
);
let role_name = RoleName::from("user");
let endpoint = EndpointId::from("ep");
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), &ec1);
let jwt4 = new_ec_jwt("4".into(), &ec2);
// had the wrong kid, therefore will have the wrong ecdsa signature
let bad_jwt = new_ec_jwt("3".into(), &ec2);
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&bad_jwt,
&client,
endpoint.clone(),
&role_name1,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("signature error"));
let err = jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&jwt1,
&client,
endpoint.clone(),
&bad_role_name,
&fetch,
)
.await
.unwrap_err();
assert!(err.to_string().contains("jwk not found"));
let tokens = [jwt1, jwt2, jwt3, jwt4];
let role_names = [role_name1, role_name2];
for role in &role_names {
for token in &tokens {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
token,
&client,
endpoint.clone(),
role,
&fetch,
)
.await
.unwrap();
}
for token in [jwt1, jwt2, jwt3, jwt4] {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&token,
&client,
endpoint.clone(),
role_name.clone(),
&Fetch(addr),
)
.await
.unwrap();
}
}
}

View File

@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{collections::HashMap, net::SocketAddr};
use anyhow::Context;
use arc_swap::ArcSwapOption;
@@ -10,19 +10,21 @@ use crate::{
NodeInfo,
},
context::RequestMonitoring,
intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
EndpointId,
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
EndpointId, RoleName,
};
use super::jwt::{AuthRule, FetchAuthRules};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
pub struct LocalBackend {
pub(crate) jwks_cache: JwkCache,
pub(crate) node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
@@ -46,17 +48,26 @@ impl LocalBackend {
#[derive(Clone, Copy)]
pub(crate) struct StaticAuthRules;
pub static JWKS_ROLE_MAP: ArcSwapOption<EndpointJwksResponse> = ArcSwapOption::const_empty();
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
#[derive(Debug, Clone)]
pub struct JwksRoleSettings {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
}
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()
.and_then(|m| m.roles.get(&role_name))
.context("JWKs settings for this role were not configured")?;
let mut rules = vec![];
for setting in &role_mappings.jwks {
@@ -64,7 +75,6 @@ impl FetchAuthRules for StaticAuthRules {
id: setting.id.clone(),
jwks_url: setting.jwks_url.clone(),
audience: setting.jwt_audience.clone(),
role_names: setting.role_names.clone(),
});
}

View File

@@ -1,38 +1,34 @@
use std::{net::SocketAddr, pin::pin, str::FromStr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
path::{Path, PathBuf},
pin::pin,
sync::Arc,
time::Duration,
};
use anyhow::{bail, ensure, Context};
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use anyhow::{bail, ensure};
use dashmap::DashMap;
use futures::future::Either;
use futures::{future::Either, FutureExt};
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{
locks::ApiLocks,
messages::{EndpointJwksResponse, JwksSettings},
},
console::{locks::ApiLocks, messages::JwksRoleMapping},
http::health_server::AppMetrics,
intern::RoleNameInt,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
RoleName,
};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::{net::TcpListener, sync::Notify, task::JoinSet};
use tokio::{net::TcpListener, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{pid_file, project_build_tag, project_git_version, sentry_init::init_sentry};
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
@@ -76,12 +72,9 @@ struct LocalProxyCliArgs {
/// Address of the postgres server
#[clap(long, default_value = "127.0.0.1:5432")]
compute: SocketAddr,
/// Path of the local proxy config file
/// File address of the local proxy config file
#[clap(long, default_value = "./localproxy.json")]
config_path: Utf8PathBuf,
/// Path of the local proxy PID file
#[clap(long, default_value = "./localproxy.pid")]
pid_path: Utf8PathBuf,
config_path: PathBuf,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -133,24 +126,6 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
//
// This also claims a "lock" that makes sure only one instance
// of local-proxy runs at a time.
let _process_guard = loop {
match pid_file::claim_for_current_process(&args.pid_path) {
Ok(guard) => break guard,
Err(e) => {
// compute-ctl might have tried to read the pid-file to let us
// know about some config change. We should try again.
error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
};
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
@@ -164,30 +139,12 @@ async fn main() -> anyhow::Result<()> {
16,
));
// write the process ID to a file so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
let pid = std::process::id();
info!("process running in PID {pid}");
std::fs::write(args.pid_path, format!("{pid}\n")).context("writing PID to file")?;
refresh_config(args.config_path.clone()).await;
let mut maintenance_tasks = JoinSet::new();
let refresh_config_notify = Arc::new(Notify::new());
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), {
let refresh_config_notify = Arc::clone(&refresh_config_notify);
move || {
refresh_config_notify.notify_one();
}
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
refresh_config(args.config_path.clone()).map(Ok)
}));
// trigger the first config load **after** setting up the signal hook
// to avoid the race condition where:
// 1. No config file registered when local-proxy starts up
// 2. The config file is written but the signal hook is not yet received
// 3. local-proxy completes startup but has no config loaded, despite there being a registerd config.
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));
maintenance_tasks.spawn(proxy::http::health_server::task_main(
metrics_listener,
AppMetrics {
@@ -270,17 +227,14 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
allow_self_signed_compute: false,
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
@@ -291,84 +245,81 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
async fn refresh_config(path: PathBuf) {
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
}
async fn refresh_config_inner(path: &Utf8Path) -> anyhow::Result<()> {
async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
let mut settings = None;
for jwks in data.jwks {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
for mapping in data.roles.values_mut() {
for jwks in &mut mapping.jwks {
ensure!(
jwks.jwks_url.has_authority()
&& (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks.jwks_url
.host()
.is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
// clear username, password and ports
jwks.jwks_url.set_username("").expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
jwks.jwks_url.set_password(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
jwks.jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
}
jwks_set.push(JwksSettings {
id: jwks.id,
jwks_url,
provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
// clear query params
jwks.jwks_url.set_fragment(None);
jwks.jwks_url.query_pairs_mut().clear().finish();
if jwks.jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks.jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
ensure!(
*pr == jwks.project_id,
"inconsistent project IDs configured"
);
ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
}
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some((project_id, branch_id)) = settings {
JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings {
roles: data.roles,
project_id,
branch_id,
})));
}
Ok(())
}

View File

@@ -133,7 +133,9 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || {}));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async {
Ok(())
}));
// the signal task cant ever succeed.
// the main task can error, or can succeed on cancellation.

View File

@@ -8,7 +8,6 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
@@ -18,7 +17,6 @@ use proxy::config::AuthenticationConfig;
use proxy::config::CacheOptions;
use proxy::config::HttpConfig;
use proxy::config::ProjectInfoCacheOptions;
use proxy::config::ProxyProtocolV2;
use proxy::console;
use proxy::context::parquet::ParquetUploadArgs;
use proxy::http;
@@ -104,9 +102,6 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -149,6 +144,9 @@ struct ProxyCliArgs {
/// size of the threadpool for password hashing
#[clap(long, default_value_t = 4)]
scram_thread_pool_size: u8,
/// Require that all incoming requests have a Proxy Protocol V2 packet **and** have an IP address associated.
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
require_client_ip: bool,
/// Disable dynamic rate limiter and store the metrics to ensure its production behaviour.
#[clap(long, default_value_t = true, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
disable_dynamic_rate_limiter: bool,
@@ -231,11 +229,6 @@ struct ProxyCliArgs {
/// Configure if this is a private access proxy for the POC: In that case the proxy will ignore the IP allowlist
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_private_access_proxy: bool,
/// Configure whether all incoming requests have a Proxy Protocol V2 packet.
// TODO(conradludgate): switch default to rejected or required once we've updated all deployments
#[clap(value_enum, long, default_value_t = ProxyProtocolV2::Supported)]
proxy_protocol_v2: ProxyProtocolV2,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -297,7 +290,6 @@ async fn main() -> anyhow::Result<()> {
build_tag: BUILD_TAG,
});
proxy::jemalloc::inspect_thp()?;
let jemalloc = match proxy::jemalloc::MetricRecorder::new() {
Ok(t) => Some(t),
Err(e) => {
@@ -390,27 +382,9 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default();
@@ -456,17 +430,21 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
@@ -483,7 +461,10 @@ async fn main() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(proxy::handle_signals(
cancellation_token.clone(),
|| async { Ok(()) },
));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -696,7 +677,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
accept_websockets: true,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -711,15 +692,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
};
let config = Box::leak(Box::new(ProxyConfig {
@@ -729,7 +707,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
allow_self_signed_compute: args.allow_self_signed_compute,
http_config,
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
require_client_ip: args.require_client_ip,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,

View File

@@ -1,8 +1,5 @@
use crate::{
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -10,7 +7,6 @@ use crate::{
Host,
};
use anyhow::{bail, ensure, Context, Ok};
use clap::ValueEnum;
use itertools::Itertools;
use remote_storage::RemoteStorageConfig;
use rustls::{
@@ -34,7 +30,7 @@ pub struct ProxyConfig {
pub allow_self_signed_compute: bool,
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub require_client_ip: bool,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
@@ -42,16 +38,6 @@ pub struct ProxyConfig {
pub connect_to_compute_retry_config: RetryConfig,
}
#[derive(Copy, Clone, Debug, ValueEnum, PartialEq)]
pub enum ProxyProtocolV2 {
/// Connection will error if PROXY protocol v2 header is missing
Required,
/// Connection will parse PROXY protocol v2 header, but accept the connection if it's missing.
Supported,
/// Connection will error if PROXY protocol v2 header is provided
Rejected,
}
#[derive(Debug)]
pub struct MetricCollectionConfig {
pub endpoint: reqwest::Url,
@@ -81,9 +67,6 @@ pub struct AuthenticationConfig {
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
}
impl TlsConfig {
@@ -267,26 +250,18 @@ impl CertResolver {
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
bail!("Failed to parse common name from certificate")
};
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context("Failed to parse common name from certificate")?;
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

View File

@@ -1,11 +1,13 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::proxy::retry::CouldRetry;
use crate::RoleName;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
@@ -346,6 +348,11 @@ impl ColdStartInfo {
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
pub jwks: Vec<JwksSettings>,
@@ -354,10 +361,11 @@ pub struct EndpointJwksResponse {
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
pub id: String,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
pub jwks_url: url::Url,
pub provider_name: String,
pub jwt_audience: Option<String>,
pub role_names: Vec<RoleNameInt>,
}
#[cfg(test)]

View File

@@ -5,10 +5,7 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{
auth::{
backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
@@ -19,7 +16,7 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, EndpointId,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
@@ -337,12 +334,6 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
@@ -352,7 +343,6 @@ pub(crate) trait Api {
}
#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
@@ -396,20 +386,6 @@ impl Api for ConsoleBackend {
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -576,13 +552,3 @@ impl WakeComputePermit {
res
}
}
impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}

View File

@@ -4,9 +4,7 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
@@ -120,39 +118,6 @@ impl Api {
})
}
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}
drop(client);
connection.await??;
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
@@ -220,14 +185,6 @@ impl super::Api for Api {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -7,33 +7,27 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::{jwt::AuthRule, ComputeUserInfo},
auth::backend::ComputeUserInfo,
compute,
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
console::messages::{ColdStartInfo, Reason},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey, EndpointId,
scram, EndpointCacheKey,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
jwt: String,
}
impl Api {
@@ -44,9 +38,7 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
Self {
endpoint,
caches,
@@ -79,9 +71,9 @@ impl Api {
async {
let request = self
.endpoint
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
@@ -133,61 +125,6 @@ impl Api {
.await
}
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -198,7 +135,7 @@ impl Api {
async {
let mut request_builder = self
.endpoint
.get_path("proxy_wake_compute")
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
@@ -325,15 +262,6 @@ impl super::Api for Api {
))
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -6,10 +6,11 @@ pub mod health_server;
use std::time::Duration;
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper1::body::Body;
use serde::de::DeserializeSeed;
use serde::de::DeserializeOwned;
pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
@@ -85,17 +86,9 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex paths queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
let mut url = self.endpoint.clone();
f(&mut url);
url.path_segments_mut().push(path);
self.client.get(url.into_inner())
}
@@ -112,21 +105,10 @@ impl Endpoint {
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError<E> {
#[error("could not read the HTTP body: {0}")]
Read(E),
#[error("could not parse the HTTP body: {0}")]
Parse(#[from] serde_json::Error),
#[error("could not parse the HTTP body: content length exceeds limit of {0} bytes")]
LengthExceeded(usize),
}
pub(crate) async fn parse_json_body_with_limit<D, E>(
seed: impl for<'de> DeserializeSeed<'de, Value = D>,
mut b: impl Body<Data = Bytes, Error = E> + Unpin,
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
limit: usize,
) -> Result<D, ReadPayloadError<E>> {
) -> anyhow::Result<D> {
// We could use `b.limited().collect().await.to_bytes()` here
// but this ends up being slightly more efficient as far as I can tell.
@@ -134,25 +116,20 @@ pub(crate) async fn parse_json_body_with_limit<D, E>(
// in reqwest, this value is influenced by the Content-Length header.
let lower_bound = match usize::try_from(b.size_hint().lower()) {
Ok(bound) if bound <= limit => bound,
_ => return Err(ReadPayloadError::LengthExceeded(limit)),
_ => bail!("Content length exceeds limit of {limit} bytes"),
};
let mut bytes = Vec::with_capacity(lower_bound);
while let Some(frame) = b
.frame()
.await
.transpose()
.map_err(ReadPayloadError::Read)?
{
while let Some(frame) = b.frame().await.transpose()? {
if let Ok(data) = frame.into_data() {
if bytes.len() + data.len() > limit {
return Err(ReadPayloadError::LengthExceeded(limit));
bail!("Content length exceeds limit of {limit} bytes")
}
bytes.extend_from_slice(&data);
}
}
Ok(seed.deserialize(&mut serde_json::Deserializer::from_slice(&bytes))?)
Ok(serde_json::from_slice::<D>(&bytes)?)
}
#[cfg(test)]
@@ -167,7 +144,7 @@ mod tests {
// Validate that this pattern makes sense.
let req = endpoint
.get_path("frobnicate")
.get("frobnicate")
.query(&[
("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all
@@ -185,7 +162,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new());
let req = endpoint
.get_path("frobnicate")
.get("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())])
.build()?;

View File

@@ -130,14 +130,14 @@ impl<Id: InternId> Default for StringInterner<Id> {
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct RoleNameTag;
pub(crate) struct RoleNameTag;
impl InternId for RoleNameTag {
fn get_interner() -> &'static StringInterner<Self> {
static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub type RoleNameInt = InternedString<RoleNameTag>;
pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
impl From<&RoleName> for RoleNameInt {
fn from(value: &RoleName) -> Self {
RoleNameTag::get_interner().get_or_intern(value)

View File

@@ -9,8 +9,7 @@ use measured::{
text::TextEncoder,
LabelGroup, MetricGroup,
};
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version, Access, AsName, Name};
use tracing::info;
use tikv_jemalloc_ctl::{config, epoch, epoch_mib, stats, version};
pub struct MetricRecorder {
epoch: epoch_mib,
@@ -115,10 +114,3 @@ jemalloc_gauge!(mapped, mapped_mib);
jemalloc_gauge!(metadata, metadata_mib);
jemalloc_gauge!(resident, resident_mib);
jemalloc_gauge!(retained, retained_mib);
pub fn inspect_thp() -> Result<(), tikv_jemalloc_ctl::Error> {
let opt_thp: &Name = c"opt.thp".to_bytes_with_nul().name();
let s: &str = opt_thp.read()?;
info!("jemalloc opt.thp {s}");
Ok(())
}

View File

@@ -82,7 +82,7 @@
impl_trait_overcaptures,
)]
use std::convert::Infallible;
use std::{convert::Infallible, future::Future};
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
@@ -117,12 +117,13 @@ pub mod usage_metrics;
pub mod waiters;
/// Handle unix signals appropriately.
pub async fn handle_signals<F>(
pub async fn handle_signals<F, Fut>(
token: CancellationToken,
mut refresh_config: F,
) -> anyhow::Result<Infallible>
where
F: FnMut(),
F: FnMut() -> Fut,
Fut: Future<Output = anyhow::Result<()>>,
{
use tokio::signal::unix::{signal, SignalKind};
@@ -135,7 +136,7 @@ where
// Hangup is commonly used for config reload.
_ = hangup.recv() => {
warn!("received SIGHUP");
refresh_config();
refresh_config().await?;
}
// Shut down the whole application.
_ = interrupt.recv() => {

View File

@@ -1,3 +1,4 @@
use tracing_opentelemetry::OpenTelemetryLayer;
use tracing_subscriber::{
filter::{EnvFilter, LevelFilter},
prelude::*,
@@ -22,7 +23,9 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
.with_writer(std::io::stderr)
.with_target(false);
let otlp_layer = tracing_utils::init_tracing("proxy").await;
let otlp_layer = tracing_utils::init_tracing("proxy")
.await
.map(OpenTelemetryLayer::new);
tracing_subscriber::registry()
.with(env_filter)

View File

@@ -10,7 +10,6 @@ pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;
use crate::config::ProxyProtocolV2;
use crate::{
auth,
cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
@@ -94,19 +93,15 @@ pub async fn task_main(
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Ok((socket, Some(addr))) => (socket, addr.ip()),
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
Ok((_socket, None)) if config.require_client_ip => {
error!("missing required client IP");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
Ok((socket, None)) => (socket, peer_addr.ip()),
};

View File

@@ -525,10 +525,6 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {

View File

@@ -43,13 +43,6 @@ impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
// rayon would be nice here, but yielding in rayon does not work well afaict.
if n_workers == 0 {
return Arc::new(Self {
runtime: None,
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
}
Arc::new_cyclic(|pool| {
let pool = pool.clone();
let worker_id = AtomicUsize::new(0);

View File

@@ -5,7 +5,6 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod sql_over_http;
@@ -20,8 +19,7 @@ use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
@@ -83,28 +81,7 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -365,7 +342,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -409,7 +386,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -432,7 +409,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Empty::new().map_err(|x| match x {}).boxed())
.body(Full::new(Bytes::new()))
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,8 +1,6 @@
use std::{io, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use crate::{
@@ -29,13 +27,9 @@ use crate::{
Host,
};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -109,44 +103,32 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::Console(console, ()) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(())
crate::auth::Backend::Console(_, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
crate::auth::Backend::Local(cache) => {
cache
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
user_info.user.clone(),
&StaticAuthRules,
&jwt,
jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
}
}
@@ -192,55 +174,14 @@ impl PoolingBackend {
)
.await
}
// Wake up the destination if needed
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to postgres in compute")]
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -252,20 +193,11 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -278,8 +210,7 @@ impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -293,8 +224,7 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
@@ -306,7 +236,7 @@ impl CouldRetry for HttpConnError {
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
@@ -314,38 +244,6 @@ impl ShouldRetryWakeCompute for HttpConnError {
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
}
impl UserFacingError for LocalProxyConnError {
fn to_string_client(&self) -> String {
"Could not establish HTTP connection to the database".to_string()
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
@@ -395,99 +293,3 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
// assumption: host is an ip address so this should not actually perform any requests.
// todo: add that assumption as a guarantee in the control-plane API.
let mut addrs = lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?;
let mut last_err = None;
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
};
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}

View File

@@ -1,342 +0,0 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
// TODO(conrad):
// either we should open more connections depending on stream count
// (not exposed by hyper, need our own counter)
// or we can change this to an Option rather than a VecDeque.
//
// Opening more connections to the same db because we run out of streams
// seems somewhat redundant though.
//
// Probably we should run a semaphore and just the single conn. TBD.
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
loop {
let conn = conns.pop_front()?;
if !conn.conn.is_closed() {
conns.push_back(conn.clone());
return Some(conn);
}
}
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
let Self {
conns,
global_connections_count,
..
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
removed > 0
}
}
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.conns.len() as i64);
}
}
}
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, client.aux))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, aux)
}
pub(crate) struct Client {
pub(crate) inner: Send,
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_body_util::Full;
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,24 +64,17 @@ struct HttpErrorBody {
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
}
}
@@ -90,14 +83,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -1,534 +1,18 @@
use std::fmt;
use std::marker::PhantomData;
use std::ops::Range;
use itertools::Itertools;
use serde::de;
use serde::de::DeserializeSeed;
use serde::Deserialize;
use serde::Deserializer;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
use super::sql_over_http::BatchQueryData;
use super::sql_over_http::Payload;
use super::sql_over_http::QueryData;
#[derive(Clone, Copy)]
pub struct Slice {
pub start: u32,
pub len: u32,
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
}
impl Slice {
pub fn into_range(self) -> Range<usize> {
let start = self.start as usize;
let end = start + self.len as usize;
start..end
}
}
#[derive(Default)]
pub struct Arena {
pub str_arena: String,
pub params_arena: Vec<Option<Slice>>,
}
impl Arena {
fn alloc_str(&mut self, s: &str) -> Slice {
let start = self.str_arena.len() as u32;
let len = s.len() as u32;
self.str_arena.push_str(s);
Slice { start, len }
}
}
pub struct SerdeArena<'a, T> {
pub arena: &'a mut Arena,
pub _t: PhantomData<T>,
}
impl<'a, T> SerdeArena<'a, T> {
fn alloc_str(&mut self, s: &str) -> Slice {
self.arena.alloc_str(s)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Vec<QueryData>> {
type Value = Vec<QueryData>;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor<'a>(SerdeArena<'a, Vec<QueryData>>);
impl<'a, 'de> de::Visitor<'de> for VecVisitor<'a> {
type Value = Vec<QueryData>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut values = Vec::new();
while let Some(value) = seq.next_element_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<QueryData>,
})? {
values.push(value);
}
Ok(values)
}
}
d.deserialize_seq(VecVisitor(self))
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Slice> {
type Value = Slice;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, Slice>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Slice;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E>(mut self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(self.0.alloc_str(v))
}
}
d.deserialize_str(Visitor(self))
}
}
enum States {
Empty,
HasQueries(Vec<QueryData>),
HasPartialQueryData {
query: Option<Slice>,
params: Option<Slice>,
#[allow(clippy::option_option)]
array_mode: Option<Option<bool>>,
},
}
enum Field {
Queries,
Query,
Params,
ArrayMode,
Ignore,
}
struct FieldVisitor;
impl<'de> de::Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
r#"a JSON object string of either "query", "params", "arrayMode", or "queries"."#,
)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
self.visit_bytes(v.as_bytes())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
match v {
b"queries" => Ok(Field::Queries),
b"query" => Ok(Field::Query),
b"params" => Ok(Field::Params),
b"arrayMode" => Ok(Field::ArrayMode),
_ => Ok(Field::Ignore),
}
}
}
impl<'de> Deserialize<'de> for Field {
#[inline]
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
d.deserialize_identifier(FieldVisitor)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, QueryData> {
type Value = QueryData;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, QueryData>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = QueryData;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Query => {
let (params, array_mode) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
}
Field::Params => {
let (query, array_mode) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(Slice { start, len }),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Queries | Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasQueries(_) => unreachable!(),
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
}),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(
d,
"QueryData",
&["query", "params", "arrayMode"],
Visitor(self),
)
}
}
impl<'a, 'de> DeserializeSeed<'de> for SerdeArena<'a, Payload> {
type Value = Payload;
fn deserialize<D>(self, d: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
struct Visitor<'a>(SerdeArena<'a, Payload>);
impl<'a, 'de> de::Visitor<'de> for Visitor<'a> {
type Value = Payload;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(
"a json object containing either a query object, or a list of query objects",
)
}
#[inline]
fn visit_map<A>(self, mut m: A) -> Result<Self::Value, A::Error>
where
A: de::MapAccess<'de>,
{
let mut state = States::Empty;
while let Some(key) = m.next_key()? {
match key {
Field::Queries => match state {
States::Empty => {
state = States::HasQueries(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Vec<QueryData>>,
})?);
}
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::duplicate_field("queries"))
}
States::HasPartialQueryData { .. } => {
return Err(<A::Error as de::Error>::unknown_field(
"queries",
&["query", "params", "arrayMode"],
))
}
},
Field::Query => {
let (params, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"query",
&["queries"],
))
}
States::HasPartialQueryData { query: Some(_), .. } => {
return Err(<A::Error as de::Error>::duplicate_field("query"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query: None,
params,
array_mode,
} => (params, array_mode),
};
state = States::HasPartialQueryData {
query: Some(m.next_value_seed(SerdeArena {
arena: &mut *self.0.arena,
_t: PhantomData::<Slice>,
})?),
params,
array_mode,
};
}
Field::Params => {
let (query, array_mode) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"params",
&["queries"],
))
}
States::HasPartialQueryData {
params: Some(_), ..
} => {
return Err(<A::Error as de::Error>::duplicate_field("params"))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params: None,
array_mode,
} => (query, array_mode),
};
let params = m.next_value::<PgText>()?.value;
let start = self.0.arena.params_arena.len() as u32;
let len = params.len() as u32;
for param in params {
match param {
Some(s) => {
let s = self.0.arena.alloc_str(&s);
self.0.arena.params_arena.push(Some(s));
}
None => self.0.arena.params_arena.push(None),
}
}
state = States::HasPartialQueryData {
query,
params: Some(Slice { start, len }),
array_mode,
};
}
Field::ArrayMode => {
let (query, params) = match state {
States::HasQueries(_) => {
return Err(<A::Error as de::Error>::unknown_field(
"arrayMode",
&["queries"],
))
}
States::HasPartialQueryData {
array_mode: Some(_),
..
} => {
return Err(<A::Error as de::Error>::duplicate_field(
"arrayMode",
))
}
States::Empty => (None, None),
States::HasPartialQueryData {
query,
params,
array_mode: None,
} => (query, params),
};
state = States::HasPartialQueryData {
query,
params,
array_mode: Some(m.next_value()?),
};
}
Field::Ignore => {
let _ = m.next_value::<de::IgnoredAny>()?;
}
}
}
match state {
States::HasQueries(queries) => Ok(Payload::Batch(BatchQueryData { queries })),
States::HasPartialQueryData {
query: Some(query),
params: Some(params),
array_mode,
} => Ok(Payload::Single(QueryData {
query,
params,
array_mode: array_mode.unwrap_or_default(),
})),
States::Empty | States::HasPartialQueryData { query: None, .. } => {
Err(<A::Error as de::Error>::missing_field("query"))
}
States::HasPartialQueryData { params: None, .. } => {
Err(<A::Error as de::Error>::missing_field("params"))
}
}
}
}
Deserializer::deserialize_struct(
d,
"Payload",
&["queries", "query", "params", "arrayMode"],
Visitor(self),
)
}
}
struct PgText {
value: Vec<Option<String>>,
}
impl<'de> Deserialize<'de> for PgText {
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor;
impl<'de> de::Visitor<'de> for VecVisitor {
type Value = Vec<Option<String>>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of postgres parameters")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Vec<Option<String>>, A::Error>
where
A: de::SeqAccess<'de>,
{
let mut values = Vec::new();
// TODO: consider avoiding the allocations for json::Value here.
while let Some(value) = seq.next_element()? {
values.push(json_value_to_pg_text(value));
}
Ok(values)
}
}
let value = d.deserialize_seq(VecVisitor)?;
Ok(PgText { value })
}
}
fn json_value_to_pg_text(value: Value) -> Option<String> {
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
@@ -537,10 +21,10 @@ fn json_value_to_pg_text(value: Value) -> Option<String> {
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s),
Value::String(s) => Some(s.to_string()),
// special care for arrays
Value::Array(arr) => Some(json_array_to_pg_array(arr)),
Value::Array(_) => json_array_to_pg_array(value),
}
}
@@ -552,17 +36,7 @@ fn json_value_to_pg_text(value: Value) -> Option<String> {
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(arr: Vec<Value>) -> String {
let vals = arr
.into_iter()
.map(json_array_value_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.join(",");
format!("{{{vals}}}")
}
fn json_array_value_to_pg_array(value: Value) -> Option<String> {
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
@@ -570,10 +44,19 @@ fn json_array_value_to_pg_array(value: Value) -> Option<String> {
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_value_to_pg_array(Value::String(v.to_string())),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
// recurse into array
Value::Array(arr) => Some(json_array_to_pg_array(arr)),
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
Some(format!("{{{vals}}}"))
}
}
}
@@ -778,22 +261,24 @@ mod tests {
#[test]
fn test_atomic_types_to_pg_params() {
let pg_params = json_value_to_pg_text(Value::Bool(true));
assert_eq!(pg_params, Some("true".to_owned()));
let pg_params = json_value_to_pg_text(Value::Bool(false));
assert_eq!(pg_params, Some("false".to_owned()));
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);
let json = Value::Number(serde_json::Number::from(42));
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, Some("42".to_owned()));
let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);
let json = Value::String("foo\"".to_string());
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, Some("foo\"".to_owned()));
let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);
let json = Value::Null;
let pg_params = json_value_to_pg_text(json);
assert_eq!(pg_params, None);
let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
assert_eq!(pg_params, vec![None]);
}
#[test]
@@ -801,27 +286,31 @@ mod tests {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_value_to_pg_text(json);
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
Some("{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned())
vec![Some(
"{true,false,NULL,\"NULL\",42,\"foo\",\"bar\\\"-\\\\\"}".to_owned()
)]
);
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_value_to_pg_text(json);
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
Some("{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned())
vec![Some(
"{{true,false},{NULL,42},{\"foo\",\"bar\\\"-\\\\\"}}".to_owned()
)]
);
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_value_to_pg_text(json);
let pg_params = json_to_pg_text(vec![json]);
assert_eq!(
pg_params,
Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]
);
}

View File

@@ -1,4 +1,3 @@
use std::marker::PhantomData;
use std::pin::pin;
use std::sync::Arc;
@@ -9,8 +8,6 @@ use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -22,6 +19,8 @@ use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
use tokio::time;
use tokio_postgres::error::DbError;
use tokio_postgres::error::ErrorPosition;
@@ -39,51 +38,52 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
use crate::error::ErrorKind;
use crate::error::ReportableError;
use crate::error::UserFacingError;
use crate::http::parse_json_body_with_limit;
use crate::metrics::HttpDirection;
use crate::metrics::Metrics;
use crate::proxy::run_until_cancelled;
use crate::proxy::NeonOptions;
use crate::serverless::json::Arena;
use crate::serverless::json::SerdeArena;
use crate::serverless::backend::HttpConnError;
use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::HttpConnError;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
use super::conn_pool::ConnInfo;
use super::conn_pool::ConnInfoWithAuth;
use super::http_util::json_response;
use super::json::json_to_pg_text;
use super::json::pg_text_row_to_json;
use super::json::JsonConversionError;
use super::json::Slice;
pub(crate) struct QueryData {
pub(crate) query: Slice,
pub(crate) params: Slice,
pub(crate) array_mode: Option<bool>,
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
}
pub(crate) struct BatchQueryData {
pub(crate) queries: Vec<QueryData>,
#[derive(serde::Deserialize)]
struct BatchQueryData {
queries: Vec<QueryData>,
}
pub(crate) enum Payload {
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
}
@@ -98,6 +98,15 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
@@ -114,8 +123,8 @@ pub(crate) enum ConnInfoError {
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing password")]
MissingPassword,
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
@@ -124,14 +133,6 @@ pub(crate) enum ConnInfoError {
MalformedEndpoint,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
@@ -145,7 +146,6 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
@@ -181,32 +181,21 @@ fn get_conn_info(
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.ok_or(ConnInfoError::MissingPassword)?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
return Err(ConnInfoError::MissingPassword);
};
let endpoint = match connection_url.host() {
@@ -258,7 +247,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -290,7 +279,7 @@ pub(crate) async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -367,7 +356,7 @@ pub(crate) async fn handle(
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpError {
#[error("{0}")]
ReadPayload(ReadPayloadError),
ReadPayload(#[from] ReadPayloadError),
#[error("{0}")]
ConnectCompute(#[from] HttpConnError),
#[error("{0}")]
@@ -421,9 +410,9 @@ impl UserFacingError for SqlOverHttpError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(hyper1::Error),
Read(#[from] hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
Parse(serde_json::Error),
Parse(#[from] serde_json::Error),
}
impl ReportableError for ReadPayloadError {
@@ -435,18 +424,6 @@ impl ReportableError for ReadPayloadError {
}
}
impl From<crate::http::ReadPayloadError<hyper1::Error>> for SqlOverHttpError {
fn from(value: crate::http::ReadPayloadError<hyper1::Error>) -> Self {
match value {
crate::http::ReadPayloadError::Read(e) => Self::ReadPayload(ReadPayloadError::Read(e)),
crate::http::ReadPayloadError::Parse(e) => {
Self::ReadPayload(ReadPayloadError::Parse(e))
}
crate::http::ReadPayloadError::LengthExceeded(x) => Self::RequestTooLarge(x as u64),
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -527,7 +504,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -537,58 +514,26 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
auth,
backend,
)
.await
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let (parts, body) = request.into_parts();
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in
|| parts.headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
|| headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
let parsed_headers = HttpHeaders::try_parse(&parts.headers)?;
let parsed_headers = HttpHeaders::try_parse(headers)?;
let request_content_length = match body.size_hint().upper() {
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
None => config.http_config.max_request_size_bytes + 1,
};
@@ -606,53 +551,38 @@ async fn handle_db_inner(
));
}
let fetch_and_process_request = Box::pin(async move {
let mut arena = Arena::default();
let seed = SerdeArena {
arena: &mut arena,
_t: PhantomData::<Payload>,
};
let payload = parse_json_body_with_limit(
seed,
body,
config.http_config.max_request_size_bytes as usize,
)
.await?;
Ok::<(Arena, Payload), SqlOverHttpError>((arena, payload)) // Adjust error type accordingly
});
let fetch_and_process_request = Box::pin(
async {
let body = request.into_body().collect().await?.to_bytes();
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
);
let authenticate_and_connect = Box::pin(
async {
let keys = match auth {
let keys = match &conn_info.auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
&conn_info.conn_info.user_info,
pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
.await?
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
@@ -662,7 +592,7 @@ async fn handle_db_inner(
.map_err(SqlOverHttpError::from),
);
let ((mut arena, payload), mut client) = match run_until_cancelled(
let (payload, mut client) = match run_until_cancelled(
// Run both operations in parallel
try_join(
pin!(fetch_and_process_request),
@@ -676,9 +606,6 @@ async fn handle_db_inner(
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
};
arena.params_arena.shrink_to_fit();
arena.str_arena.shrink_to_fit();
let mut response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json");
@@ -686,7 +613,7 @@ async fn handle_db_inner(
// Now execute the query and return the result.
let json_output = match payload {
Payload::Single(stmt) => {
stmt.process(config, &arena, cancel, &mut client, parsed_headers)
stmt.process(config, cancel, &mut client, parsed_headers)
.await?
}
Payload::Batch(statements) => {
@@ -704,27 +631,16 @@ async fn handle_db_inner(
}
statements
.process(config, &arena, cancel, &mut client, parsed_headers)
.process(config, cancel, &mut client, parsed_headers)
.await?
}
};
info!(
str_len = arena.str_arena.len(),
params = arena.params_arena.len(),
response = json_output.len(),
"data size"
);
let metrics = client.metrics();
let len = json_output.len();
let response = response
.body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
.body(Full::new(Bytes::from(json_output)))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -740,70 +656,10 @@ async fn handle_db_inner(
Ok(response)
}
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&AUTHORIZATION,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
let (mut parts, body) = request.into_parts();
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
// these instead just to ensure they remain normalised.
for &h in HEADERS_TO_FORWARD {
if let Some(hv) = parts.headers.remove(h) {
req = req.header(h, hv);
}
}
let req = req
.body(body)
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
let _metrics = client.metrics();
Ok(client
.inner
.send_request(req)
.await
.map_err(LocalProxyConnError::from)
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
}
impl QueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -812,14 +668,7 @@ impl QueryData {
let cancel_token = inner.cancel_token();
let res = match select(
pin!(query_to_json(
config,
arena,
&*inner,
self,
&mut 0,
parsed_headers
)),
pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
pin!(cancel.cancelled()),
)
.await
@@ -827,7 +676,10 @@ impl QueryData {
// The query successfully completed.
Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
discard.check_idle(status);
Ok(results)
let json_output =
serde_json::to_string(&results).expect("json serialization should not fail");
Ok(json_output)
}
// The query failed with an error
Either::Left((Err(e), __not_yet_cancelled)) => {
@@ -853,9 +705,7 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -882,7 +732,6 @@ impl BatchQueryData {
async fn process(
self,
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
client: &mut Client<tokio_postgres::Client>,
parsed_headers: HttpHeaders,
@@ -909,7 +758,6 @@ impl BatchQueryData {
let json_output = match query_batch(
config,
arena,
cancel.child_token(),
&transaction,
self,
@@ -954,20 +802,16 @@ impl BatchQueryData {
async fn query_batch(
config: &'static ProxyConfig,
arena: &Arena,
cancel: CancellationToken,
transaction: &Transaction<'_>,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
let mut comma = false;
let mut results = r#"{"results":["#.to_string();
let mut results = Vec::with_capacity(queries.queries.len());
let mut current_size = 0;
for stmt in queries.queries {
let query = pin!(query_to_json(
config,
arena,
transaction,
stmt,
&mut current_size,
@@ -978,11 +822,7 @@ async fn query_batch(
match res {
// TODO: maybe we should check that the transaction bit is set here
Either::Left((Ok((_, values)), _cancelled)) => {
if comma {
results.push(',');
}
results.push_str(&values);
comma = true;
results.push(values);
}
Either::Left((Err(e), _cancelled)) => {
return Err(e);
@@ -993,28 +833,22 @@ async fn query_batch(
}
}
results.push_str("]}");
let results = json!({ "results": results });
let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
Ok(results)
Ok(json_output)
}
async fn query_to_json<T: GenericClient>(
config: &'static ProxyConfig,
arena: &Arena,
client: &T,
data: QueryData,
current_size: &mut usize,
parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, String), SqlOverHttpError> {
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
info!("executing query");
let query_params = arena.params_arena[data.params.into_range()]
.iter()
.map(|p| p.map(|p| &arena.str_arena[p.into_range()]));
let query = &arena.str_arena[data.query.into_range()];
let mut row_stream = std::pin::pin!(client.query_raw_txt(query, query_params).await?);
let query_params = data.params;
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
info!("finished executing query");
// Manually drain the stream into a vector to leave row_stream hanging
@@ -1083,13 +917,12 @@ async fn query_to_json<T: GenericClient>(
// Resulting JSON format is based on the format of node-postgres result.
let results = json!({
"command": command_tag_name,
"command": command_tag_name.to_string(),
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
})
.to_string();
});
Ok((ready, results))
}

View File

@@ -374,16 +374,14 @@ type JoinTaskRes = Result<anyhow::Result<()>, JoinError>;
async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
// fsync the datadir to make sure we have a consistent state on disk.
if !conf.no_sync {
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
}
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
info!("starting safekeeper WAL service on {}", conf.listen_pg_addr);
let pg_listener = tcp_listener::bind(conf.listen_pg_addr.clone()).map_err(|e| {

View File

@@ -9,7 +9,7 @@ use crate::walproposer_sim::{
pub mod walproposer_sim;
// Generates 500 random seeds and runs a schedule for each of them.
// Generates 2000 random seeds and runs a schedule for each of them.
// If you see this test fail, please report the last seed to the
// @safekeeper team.
#[test]
@@ -17,7 +17,7 @@ fn test_random_schedules() -> anyhow::Result<()> {
let clock = init_logger();
let mut config = TestConfig::new(Some(clock));
for _ in 0..500 {
for _ in 0..2000 {
let seed: u64 = rand::thread_rng().gen();
config.network = generate_network_opts(seed);

View File

@@ -572,7 +572,30 @@ impl Reconciler {
// During a live migration it is unhelpful to proceed if we couldn't notify compute: if we detach
// the origin without notifying compute, we will render the tenant unavailable.
self.compute_notify_blocking(&origin_ps).await?;
let mut notify_attempts = 0;
while let Err(e) = self.compute_notify().await {
match e {
NotifyError::Fatal(_) => return Err(ReconcileError::Notify(e)),
NotifyError::ShuttingDown => return Err(ReconcileError::Cancel),
_ => {
tracing::warn!(
"Live migration blocked by compute notification error, retrying: {e}"
);
}
}
exponential_backoff(
notify_attempts,
// Generous waits: control plane operations which might be blocking us usually complete on the order
// of hundreds to thousands of milliseconds, so no point busy polling.
1.0,
10.0,
&self.cancel,
)
.await;
notify_attempts += 1;
}
pausable_failpoint!("reconciler-live-migrate-post-notify");
// Downgrade the origin to secondary. If the tenant's policy is PlacementPolicy::Attached(0), then
@@ -846,117 +869,6 @@ impl Reconciler {
Ok(())
}
}
/// Keep trying to notify the compute indefinitely, only dropping out if:
/// - the node `origin` becomes unavailable -> Ok(())
/// - the node `origin` no longer has our tenant shard attached -> Ok(())
/// - our cancellation token fires -> Err(ReconcileError::Cancelled)
///
/// This is used during live migration, where we do not wish to detach
/// an origin location until the compute definitely knows about the new
/// location.
///
/// In cases where the origin node becomes unavailable, we return success, indicating
/// to the caller that they should continue irrespective of whether the compute was notified,
/// because the origin node is unusable anyway. Notification will be retried later via the
/// [`Self::compute_notify_failure`] flag.
async fn compute_notify_blocking(&mut self, origin: &Node) -> Result<(), ReconcileError> {
let mut notify_attempts = 0;
while let Err(e) = self.compute_notify().await {
match e {
NotifyError::Fatal(_) => return Err(ReconcileError::Notify(e)),
NotifyError::ShuttingDown => return Err(ReconcileError::Cancel),
_ => {
tracing::warn!(
"Live migration blocked by compute notification error, retrying: {e}"
);
}
}
// Did the origin pageserver become unavailable?
if !origin.is_available() {
tracing::info!("Giving up on compute notification because {origin} is unavailable");
break;
}
// Does the origin pageserver still host the shard we are interested in? We should only
// continue waiting for compute notification to be acked if the old location is still usable.
let tenant_shard_id = self.tenant_shard_id;
match origin
.with_client_retries(
|client| async move { client.get_location_config(tenant_shard_id).await },
&self.service_config.jwt_token,
1,
3,
Duration::from_secs(5),
&self.cancel,
)
.await
{
Some(Ok(Some(location_conf))) => {
if matches!(
location_conf.mode,
LocationConfigMode::AttachedMulti
| LocationConfigMode::AttachedSingle
| LocationConfigMode::AttachedStale
) {
tracing::debug!(
"Still attached to {origin}, will wait & retry compute notification"
);
} else {
tracing::info!(
"Giving up on compute notification because {origin} is in state {:?}",
location_conf.mode
);
return Ok(());
}
// Fall through
}
Some(Ok(None)) => {
tracing::info!(
"No longer attached to {origin}, giving up on compute notification"
);
return Ok(());
}
Some(Err(e)) => {
match e {
mgmt_api::Error::Cancelled => {
tracing::info!(
"Giving up on compute notification because {origin} is unavailable"
);
return Ok(());
}
mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _) => {
tracing::info!(
"No longer attached to {origin}, giving up on compute notification"
);
return Ok(());
}
e => {
// Other API errors are unexpected here.
tracing::warn!("Unexpected error checking location on {origin}: {e}");
// Fall through, we will retry compute notification.
}
}
}
None => return Err(ReconcileError::Cancel),
};
exponential_backoff(
notify_attempts,
// Generous waits: control plane operations which might be blocking us usually complete on the order
// of hundreds to thousands of milliseconds, so no point busy polling.
1.0,
10.0,
&self.cancel,
)
.await;
notify_attempts += 1;
}
Ok(())
}
}
/// We tweak the externally-set TenantConfig while configuring

View File

@@ -4974,12 +4974,7 @@ impl Service {
{
let mut nodes_mut = (**nodes).clone();
if let Some(mut removed_node) = nodes_mut.remove(&node_id) {
// Ensure that any reconciler holding an Arc<> to this node will
// drop out when trying to RPC to it (setting Offline state sets the
// cancellation token on the Node object).
removed_node.set_availability(NodeAvailability::Offline);
}
nodes_mut.remove(&node_id);
*nodes = Arc::new(nodes_mut);
}
}

View File

@@ -4,7 +4,7 @@ use std::time::Duration;
use crate::checks::{list_timeline_blobs, BlobDataParseResult};
use crate::metadata_stream::{stream_tenant_timelines, stream_tenants};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, MAX_RETRIES};
use crate::{init_remote, BucketConfig, NodeKind, RootTarget, TenantShardTimelineId};
use futures_util::{StreamExt, TryStreamExt};
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver::tenant::remote_timeline_client::{parse_remote_index_path, remote_layer_path};
@@ -18,7 +18,6 @@ use serde::Serialize;
use storage_controller_client::control_api;
use tokio_util::sync::CancellationToken;
use tracing::{info_span, Instrument};
use utils::backoff;
use utils::generation::Generation;
use utils::id::{TenantId, TenantTimelineId};
@@ -327,25 +326,15 @@ async fn maybe_delete_index(
}
// All validations passed: erase the object
let cancel = CancellationToken::new();
match backoff::retry(
|| remote_client.delete(&obj.key, &cancel),
|_| false,
3,
MAX_RETRIES as u32,
"maybe_delete_index",
&cancel,
)
.await
match remote_client
.delete(&obj.key, &CancellationToken::new())
.await
{
None => {
unreachable!("Using a dummy cancellation token");
}
Some(Ok(_)) => {
Ok(_) => {
tracing::info!("Successfully deleted index");
summary.indices_deleted += 1;
}
Some(Err(e)) => {
Err(e) => {
tracing::warn!("Failed to delete index: {e}");
summary.remote_storage_errors += 1;
}

View File

@@ -950,6 +950,9 @@ class NeonEnv:
safekeepers - An array containing objects representing the safekeepers
pg_bin - pg_bin.run() can be used to execute Postgres client binaries,
like psql or pg_dump
initial_tenant - tenant ID of the initial tenant created in the repository
neon_cli - can be used to run the 'neon' CLI tool
@@ -3297,8 +3300,6 @@ class PgBin:
@pytest.fixture(scope="function")
def pg_bin(test_output_dir: Path, pg_distrib_dir: Path, pg_version: PgVersion) -> PgBin:
"""pg_bin.run() can be used to execute Postgres client binaries, like psql or pg_dump"""
return PgBin(test_output_dir, pg_distrib_dir, pg_version)
@@ -3310,7 +3311,7 @@ class VanillaPostgres(PgProtocol):
self.pg_bin = pg_bin
self.running = False
if init:
self.pg_bin.run_capture(["initdb", "--pgdata", str(pgdatadir)])
self.pg_bin.run_capture(["initdb", "-D", str(pgdatadir)])
self.configure([f"port = {port}\n"])
def enable_tls(self):

View File

@@ -8,7 +8,6 @@ import requests
from fixtures.common_types import Lsn, TenantId, TenantTimelineId, TimelineId
from fixtures.log_helper import log
from fixtures.metrics import Metrics, MetricsGetter, parse_metrics
from fixtures.utils import wait_until
# Walreceiver as returned by sk's timeline status endpoint.
@@ -162,16 +161,6 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter):
walreceivers=walreceivers,
)
# Get timeline_start_lsn, waiting until it's nonzero. It is a way to ensure
# that the timeline is fully initialized at the safekeeper.
def get_non_zero_timeline_start_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
def timeline_start_lsn_non_zero() -> Lsn:
s = self.timeline_status(tenant_id, timeline_id).timeline_start_lsn
assert s > Lsn(0)
return s
return wait_until(30, 1, timeline_start_lsn_non_zero)
def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn:
return self.timeline_status(tenant_id, timeline_id).commit_lsn

View File

@@ -56,20 +56,32 @@ class Workload:
with ENDPOINT_LOCK:
self._endpoint.reconfigure()
def endpoint(self, pageserver_id: Optional[int] = None) -> Endpoint:
def go_readonly(self):
self.stop()
self._endpoint = self.make_endpoint(readonly=True, pageserver_id=None)
self._endpoint.start(pageserver_id=None)
def make_endpoint(self, readonly: bool, pageserver_id: Optional[int] = None) -> Endpoint:
# We may be running alongside other Workloads for different tenants. Full TTID is
# obnoxiously long for use here, but a cut-down version is still unique enough for tests.
endpoint_id = f"ep-workload-{str(self.tenant_id)[0:4]}-{str(self.timeline_id)[0:4]}"
if readonly:
self._endpoint_opts["hot_standby"] = True
return self.env.endpoints.create(
self.branch_name,
tenant_id=self.tenant_id,
pageserver_id=pageserver_id,
endpoint_id=endpoint_id,
**self._endpoint_opts,
)
def endpoint(self, pageserver_id: Optional[int] = None) -> Endpoint:
with ENDPOINT_LOCK:
if self._endpoint is None:
self._endpoint = self.env.endpoints.create(
self.branch_name,
tenant_id=self.tenant_id,
pageserver_id=pageserver_id,
endpoint_id=endpoint_id,
**self._endpoint_opts,
)
self._endpoint = self.make_endpoint(pageserver_id=pageserver_id, readonly=False)
self._endpoint.start(pageserver_id=pageserver_id)
else:
self._endpoint.reconfigure(pageserver_id=pageserver_id)

View File

@@ -53,7 +53,7 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
env = neon_simple_env
pageserver_http_client = env.pageserver.http_client()
tenant, timeline_main = env.neon_cli.create_tenant(
tenant, _ = env.neon_cli.create_tenant(
conf={
# disable background GC
"gc_period": "0s",
@@ -70,7 +70,8 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
}
)
endpoint_main = env.endpoints.create_start("main", tenant_id=tenant)
timeline_main = env.neon_cli.create_timeline("test_main", tenant_id=tenant)
endpoint_main = env.endpoints.create_start("test_main", tenant_id=tenant)
main_cur = endpoint_main.connect().cursor()
@@ -91,7 +92,7 @@ def test_branch_and_gc(neon_simple_env: NeonEnv, build_type: str):
pageserver_http_client.timeline_gc(tenant, timeline_main, lsn2 - lsn1 + 1024)
env.neon_cli.create_branch(
"test_branch", ancestor_branch_name="main", ancestor_start_lsn=lsn1, tenant_id=tenant
"test_branch", "test_main", tenant_id=tenant, ancestor_start_lsn=lsn1
)
endpoint_branch = env.endpoints.create_start("test_branch", tenant_id=tenant)

View File

@@ -11,6 +11,7 @@ from fixtures.neon_fixtures import (
generate_uploads_and_deletions,
)
from fixtures.pageserver.http import PageserverApiException
from fixtures.pageserver.utils import wait_for_last_record_lsn
from fixtures.utils import wait_until
from fixtures.workload import Workload
@@ -412,3 +413,42 @@ def test_image_layer_compression(neon_env_builder: NeonEnvBuilder, enabled: bool
f"SELECT count(*) FROM foo WHERE id={v} and val=repeat('abcde{v:0>3}', 500)"
)
assert res[0][0] == 1
def test_image_layer_reads(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
env.pageserver.http_client().set_tenant_config(
tenant_id,
{
"compaction_period": "0s",
},
)
workload = Workload(env, tenant_id, timeline_id)
workload.init()
workload.write_rows(256)
workload.validate()
workload.go_readonly()
commit_lsn = env.safekeepers[0].http_client().get_commit_lsn(tenant_id, timeline_id)
wait_for_last_record_lsn(env.pageserver.http_client(), tenant_id, timeline_id, commit_lsn)
log.info(f"Ingested up to commit_lsn {commit_lsn}")
env.pageserver.http_client().timeline_compact(
tenant_id, timeline_id, force_image_layer_creation=True
)
# Uncomment this checkpoint, and the logs will show getpage requests hitting the image layers we
# just created. However, without the checkpoint, getpage requests will hit one InMemoryLayer and
# one persistent delta layer.
# env.pageserver.http_client().timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True)
# This should send getpage requests at the same LSN where we just created image layers
workload.validate()
# Nothing should have written in the meantime
assert commit_lsn == env.safekeepers[0].http_client().get_commit_lsn(tenant_id, timeline_id)

View File

@@ -21,7 +21,7 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.pageserver.utils import (
timeline_delete_wait_completed,
)
from fixtures.pg_version import PgVersion
from fixtures.pg_version import PgVersion, skip_on_postgres
from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage
from fixtures.workload import Workload
@@ -156,6 +156,9 @@ ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.order(after="test_create_snapshot")
@skip_on_postgres(
PgVersion.V17, "There are no snapshots yet"
) # TODO: revert this once we have snapshots
def test_backward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
@@ -203,6 +206,9 @@ def test_backward_compatibility(
@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.order(after="test_create_snapshot")
@skip_on_postgres(
PgVersion.V17, "There are no snapshots yet"
) # TODO: revert this once we have snapshots
def test_forward_compatibility(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
@@ -252,7 +258,7 @@ def test_forward_compatibility(
# not using env.pageserver.version because it was initialized before
prev_pageserver_version_str = env.get_binary_version("pageserver")
prev_pageserver_version_match = re.search(
"Neon page server git(?:-env)?:(.*) failpoints: (.*), features: (.*)",
"Neon page server git-env:(.*) failpoints: (.*), features: (.*)",
prev_pageserver_version_str,
)
if prev_pageserver_version_match is not None:
@@ -263,12 +269,12 @@ def test_forward_compatibility(
)
# does not include logs from previous runs
assert not env.pageserver.log_contains(f"git(-env)?:{prev_pageserver_version}")
assert not env.pageserver.log_contains("git-env:" + prev_pageserver_version)
env.start()
# ensure the specified pageserver is running
assert env.pageserver.log_contains(f"git(-env)?:{prev_pageserver_version}")
assert env.pageserver.log_contains("git-env:" + prev_pageserver_version)
check_neon_works(
env,

View File

@@ -31,7 +31,9 @@ def helper_compare_timeline_list(
)
)
timelines_cli = env.neon_cli.list_timelines(initial_tenant)
timelines_cli = env.neon_cli.list_timelines()
assert timelines_cli == env.neon_cli.list_timelines(initial_tenant)
cli_timeline_ids = sorted([timeline_id for (_, timeline_id) in timelines_cli])
assert timelines_api == cli_timeline_ids

View File

@@ -24,7 +24,7 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder):
# IMPORTANT:
# If the version has changed, the test should be updated.
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.5",)
assert cur.fetchone() == ("1.4",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
res = cur.fetchall()
log.info(res)
@@ -48,7 +48,7 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
# IMPORTANT:
# If the version has changed, the test should be updated.
# Ensure that the default version is also updated in the neon.control file
assert cur.fetchone() == ("1.5",)
assert cur.fetchone() == ("1.4",)
cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE")
all_versions = ["1.5", "1.4", "1.3", "1.2", "1.1", "1.0"]
current_version = "1.5"

View File

@@ -549,14 +549,6 @@ def test_multi_attach(
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
# Instruct the storage controller to not interfere with our low level configuration
# of the pageserver's attachment states. Otherwise when it sees nodes go offline+return,
# it would send its own requests that would conflict with the test's.
env.storage_controller.tenant_policy_update(tenant_id, {"scheduling": "Stop"})
env.storage_controller.allowed_errors.extend(
[".*Scheduling is disabled by policy Stop.*", ".*Skipping reconcile for policy Stop.*"]
)
# Initially, the tenant will be attached to the first pageserver (first is default in our test harness)
wait_until(10, 0.2, lambda: assert_tenant_state(http_clients[0], tenant_id, "Active"))
_detail = http_clients[0].timeline_detail(tenant_id, timeline_id)

View File

@@ -174,7 +174,8 @@ def test_pageserver_chaos(
"checkpoint_distance": "5000000",
}
)
endpoint = env.endpoints.create_start("main", tenant_id=tenant)
env.neon_cli.create_timeline("test_pageserver_chaos", tenant_id=tenant)
endpoint = env.endpoints.create_start("test_pageserver_chaos", tenant_id=tenant)
# Create table, and insert some rows. Make it big enough that it doesn't fit in
# shared_buffers, otherwise the SELECT after restart will just return answer

View File

@@ -27,7 +27,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
env.pageserver.allowed_errors.extend(
[
".*basebackup .* failed: invalid basebackup lsn.*",
".*/lsn_lease.*invalid lsn lease request.*",
".*page_service.*handle_make_lsn_lease.*.*tried to request a page version that was garbage collected",
]
)
@@ -108,7 +108,7 @@ def test_readonly_node(neon_simple_env: NeonEnv):
assert cur.fetchone() == (1,)
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid lsn lease request"):
with pytest.raises(Exception, match="invalid basebackup lsn"):
# compute node startup with invalid LSN should fail
env.endpoints.create_start(
branch_name="main",
@@ -167,23 +167,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
)
return last_flush_lsn
def trigger_gc_and_select(env: NeonEnv, ep_static: Endpoint):
"""
Trigger GC manually on all pageservers. Then run an `SELECT` query.
"""
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Insert some records on main branch
with env.endpoints.create_start("main") as ep_main:
with ep_main.cursor() as cur:
@@ -210,31 +193,25 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder):
generate_updates_on_main(env, ep_main, i, end=100)
trigger_gc_and_select(env, ep_static)
# Trigger GC
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()
gc_result = client.timeline_gc(shard, env.initial_timeline, 0)
log.info(f"{gc_result=}")
# Trigger Pageserver restarts
for ps in env.pageservers:
ps.stop()
# Static compute should have at least one lease request failure due to connection.
time.sleep(LSN_LEASE_LENGTH / 2)
ps.start()
assert (
gc_result["layers_removed"] == 0
), "No layers should be removed, old layers are guarded by leases."
trigger_gc_and_select(env, ep_static)
# Reconfigure pageservers
env.pageservers[0].stop()
env.storage_controller.node_configure(
env.pageservers[0].id, {"availability": "Offline"}
)
env.storage_controller.reconcile_until_idle()
trigger_gc_and_select(env, ep_static)
with ep_static.cursor() as cur:
cur.execute("SELECT count(*) FROM t0")
assert cur.fetchone() == (ROW_COUNT,)
# Do some update so we can increment latest_gc_cutoff
generate_updates_on_main(env, ep_main, i, end=100)
# Wait for the existing lease to expire.
time.sleep(LSN_LEASE_LENGTH + 1)
time.sleep(LSN_LEASE_LENGTH)
# Now trigger GC again, layers should be removed.
for shard, ps in tenant_get_shards(env, env.initial_tenant):
client = ps.http_client()

View File

@@ -567,149 +567,6 @@ def test_storage_controller_compute_hook(
env.storage_controller.consistency_check()
def test_storage_controller_stuck_compute_hook(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address,
):
"""
Test the migration process's behavior when the compute hook does not enable it to proceed
"""
neon_env_builder.num_pageservers = 2
(host, port) = httpserver_listen_address
neon_env_builder.control_plane_compute_hook_api = f"http://{host}:{port}/notify"
handle_params = {"status": 200}
notifications = []
def handler(request: Request):
status = handle_params["status"]
log.info(f"Notify request[{status}]: {request}")
notifications.append(request.json)
return Response(status=status)
httpserver.expect_request("/notify", method="PUT").respond_with_handler(handler)
# Start running
env = neon_env_builder.init_start(initial_tenant_conf={"lsn_lease_length": "0s"})
# Initial notification from tenant creation
assert len(notifications) == 1
expect: Dict[str, Union[List[Dict[str, int]], str, None, int]] = {
"tenant_id": str(env.initial_tenant),
"stripe_size": None,
"shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}],
}
assert notifications[0] == expect
# Do a migration while the compute hook is returning 423 status
tenant_id = env.initial_tenant
origin_pageserver = env.get_tenant_pageserver(tenant_id)
dest_ps_id = [p.id for p in env.pageservers if p.id != origin_pageserver.id][0]
dest_pageserver = env.get_pageserver(dest_ps_id)
shard_0_id = TenantShardId(tenant_id, 0, 0)
NOTIFY_BLOCKED_LOG = ".*Live migration blocked.*"
env.storage_controller.allowed_errors.extend(
[
NOTIFY_BLOCKED_LOG,
".*Failed to notify compute.*",
".*Reconcile error.*Cancelled",
".*Reconcile error.*Control plane tenant busy",
]
)
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
# We expect the controller to hit the 423 (locked) and retry. Migration shouldn't complete until that
# status is cleared.
handle_params["status"] = 423
migrate_fut = executor.submit(
env.storage_controller.tenant_shard_migrate, shard_0_id, dest_ps_id
)
def logged_stuck():
env.storage_controller.assert_log_contains(NOTIFY_BLOCKED_LOG)
wait_until(10, 0.25, logged_stuck)
contains_r = env.storage_controller.log_contains(NOTIFY_BLOCKED_LOG)
assert contains_r is not None # Appease mypy
(_, log_cursor) = contains_r
assert migrate_fut.running()
# Permit the compute hook to proceed
handle_params["status"] = 200
migrate_fut.result(timeout=10)
# Advance log cursor past the last 'stuck' message (we already waited for one, but
# there could be more than one)
while True:
contains_r = env.storage_controller.log_contains(NOTIFY_BLOCKED_LOG, offset=log_cursor)
if contains_r is None:
break
else:
(_, log_cursor) = contains_r
# Now, do a migration in the opposite direction
handle_params["status"] = 423
migrate_fut = executor.submit(
env.storage_controller.tenant_shard_migrate, shard_0_id, origin_pageserver.id
)
def logged_stuck_again():
env.storage_controller.assert_log_contains(NOTIFY_BLOCKED_LOG, offset=log_cursor)
wait_until(10, 0.25, logged_stuck_again)
assert migrate_fut.running()
# This time, the compute hook remains stuck, but we mark the origin node offline: this should
# also allow the migration to complete -- we only wait for the compute hook as long as we think
# the old location is still usable for computes.
# This is a regression test for issue https://github.com/neondatabase/neon/issues/8901
dest_pageserver.stop()
env.storage_controller.node_configure(dest_ps_id, {"availability": "Offline"})
try:
migrate_fut.result(timeout=10)
except StorageControllerApiException as e:
# The reconciler will fail because it can't detach from the origin: the important
# thing is that it finishes, rather than getting stuck in the compute notify loop.
assert "Reconcile error" in str(e)
# A later background reconciliation will clean up and leave things in a neat state, even
# while the compute hook is still blocked
try:
env.storage_controller.reconcile_all()
except StorageControllerApiException as e:
# We expect that the reconciler will do its work, but be unable to fully succeed
# because it can't send a compute notification. It will complete, but leave
# the internal flag set for "retry compute notification later"
assert "Control plane tenant busy" in str(e)
# Confirm that we are AttachedSingle on the node we last called the migrate API for
loc = origin_pageserver.http_client().tenant_get_location(shard_0_id)
assert loc["mode"] == "AttachedSingle"
# When the origin node comes back, it should get cleaned up
dest_pageserver.start()
try:
env.storage_controller.reconcile_all()
except StorageControllerApiException as e:
# Compute hook is still blocked: reconciler will configure PS but not fully succeed
assert "Control plane tenant busy" in str(e)
with pytest.raises(PageserverApiException, match="Tenant shard not found"):
dest_pageserver.http_client().tenant_get_location(shard_0_id)
# Once the compute hook is unblocked, we should be able to get into a totally
# quiescent state again
handle_params["status"] = 200
env.storage_controller.reconcile_until_idle()
env.storage_controller.consistency_check()
def test_storage_controller_debug_apis(neon_env_builder: NeonEnvBuilder):
"""
Verify that occasional-use debug APIs work as expected. This is a lightweight test

View File

@@ -27,15 +27,20 @@ def test_empty_tenant_size(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_configs()
env.start()
(tenant_id, timeline_id) = env.neon_cli.create_tenant()
(tenant_id, _) = env.neon_cli.create_tenant()
http_client = env.pageserver.http_client()
initial_size = http_client.tenant_size(tenant_id)
# we should never have zero, because there should be the initdb "changes"
assert initial_size > 0, "initial implementation returns ~initdb tenant_size"
main_branch_name = "main"
branch_name, main_timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
assert branch_name == main_branch_name
endpoint = env.endpoints.create_start(
"main",
main_branch_name,
tenant_id=tenant_id,
config_lines=["autovacuum=off", "checkpoint_timeout=10min"],
)
@@ -49,7 +54,7 @@ def test_empty_tenant_size(neon_env_builder: NeonEnvBuilder):
# The transaction above will make the compute generate a checkpoint.
# In turn, the pageserver persists the checkpoint. This should only be
# one key with a size of a couple hundred bytes.
wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id)
wait_for_last_flush_lsn(env, endpoint, tenant_id, main_timeline_id)
size = http_client.tenant_size(tenant_id)
assert size >= initial_size and size - initial_size < 1024
@@ -301,8 +306,7 @@ def test_single_branch_get_tenant_size_grows(
env = neon_env_builder.init_start(initial_tenant_conf=tenant_config)
tenant_id = env.initial_tenant
timeline_id = env.initial_timeline
branch_name = "main"
branch_name, timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
http_client = env.pageserver.http_client()
@@ -512,8 +516,7 @@ def test_get_tenant_size_with_multiple_branches(
env.pageserver.allowed_errors.append(".*InternalServerError\\(No such file or directory.*")
tenant_id = env.initial_tenant
main_timeline_id = env.initial_timeline
main_branch_name = "main"
main_branch_name, main_timeline_id = env.neon_cli.list_timelines(tenant_id)[0]
http_client = env.pageserver.http_client()

View File

@@ -71,9 +71,10 @@ def test_tenants_many(neon_env_builder: NeonEnvBuilder):
"checkpoint_distance": "5000000",
}
)
env.neon_cli.create_timeline("test_tenants_many", tenant_id=tenant)
endpoint = env.endpoints.create_start(
"main",
"test_tenants_many",
tenant_id=tenant,
)
tenants_endpoints.append((tenant, endpoint))

View File

@@ -638,7 +638,7 @@ def test_delete_timeline_client_hangup(neon_env_builder: NeonEnvBuilder):
wait_until(50, 0.1, first_request_finished)
# check that the timeline is gone
wait_timeline_detail_404(ps_http, env.initial_tenant, child_timeline_id, iterations=10)
wait_timeline_detail_404(ps_http, env.initial_tenant, child_timeline_id, iterations=2)
def test_timeline_delete_works_for_remote_smoke(

View File

@@ -45,7 +45,10 @@ def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool
tenant_after = http.tenant_status(env.initial_tenant)
assert tenant_before != tenant_after
gc_blocking = tenant_after["gc_blocking"]
assert gc_blocking == "BlockingReasons { timelines: 1, reasons: EnumSet(Manual) }"
assert (
gc_blocking
== "BlockingReasons { tenant_blocked_by_lsn_lease_deadline: false, timelines: 1, reasons: EnumSet(Manual) }"
)
wait_for_another_gc_round()
pss.assert_log_contains(gc_skipped_line)

Some files were not shown because too many files have changed in this diff Show More