Compare commits

..

4 Commits

Author SHA1 Message Date
Dmitry Rodionov
39fbac354f WIP 2023-02-27 14:38:46 +02:00
Shany Pozin
84a8089ae7 cr feedback 2023-02-27 13:00:13 +02:00
Shany Pozin
6ce1638df2 Add tracing for incoming request 2023-02-27 12:38:56 +02:00
Shany Pozin
cf4965f95b Add UUID header to mgmt API 2023-02-27 12:24:15 +02:00
203 changed files with 4925 additions and 10320 deletions

View File

@@ -91,15 +91,6 @@
tags:
- pageserver
# used in `pageserver.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
- name: upload systemd service definition
ansible.builtin.template:
src: systemd/pageserver.service
@@ -127,7 +118,7 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/pageservers
tags:
- pageserver
@@ -162,15 +153,6 @@
tags:
- safekeeper
# used in `safekeeper.service` template
- name: learn current availability_zone
shell:
cmd: "curl -s http://169.254.169.254/latest/meta-data/placement/availability-zone"
register: ec2_availability_zone
- set_fact:
ec2_availability_zone={{ ec2_availability_zone.stdout }}
# in the future safekeepers should discover pageservers byself
# but currently use first pageserver that was discovered
- name: set first pageserver var for safekeepers
@@ -206,6 +188,6 @@
cmd: |
INSTANCE_ID=$(curl -s http://169.254.169.254/latest/meta-data/instance-id)
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/$INSTANCE_ID | jq '.version = {{ current_version }}' > /tmp/new_version
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
curl -sfS -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -X POST -d@/tmp/new_version {{ console_mgmt_base_url }}/management/api/v2/safekeepers
tags:
- safekeeper

View File

@@ -27,8 +27,6 @@ storage:
ansible_host: i-0cd8d316ecbb715be
pageserver-1.eu-central-1.aws.neon.tech:
ansible_host: i-090044ed3d383fef0
pageserver-2.eu-central-1.aws.neon.tech:
ansible_host: i-033584edf3f4b6742
safekeepers:
hosts:

View File

@@ -26,7 +26,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/pageservers -d@/tmp/payload | jq -r '.id')
# init pageserver
sudo -u pageserver /usr/local/bin/pageserver -c "id=${ID}" -c "pg_distrib_dir='/usr/local'" --init -D /storage/pageserver/data

View File

@@ -25,7 +25,7 @@ EOF
if ! curl -sf -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers/${INSTANCE_ID} -o /dev/null; then
# not registered, so register it now
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" -H "Content-Type: application/json" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
ID=$(curl -sf -X POST -H "Authorization: Bearer {{ CONSOLE_API_TOKEN }}" {{ console_mgmt_base_url }}/management/api/v2/safekeepers -d@/tmp/payload | jq -r '.id')
# init safekeeper
sudo -u safekeeper /usr/local/bin/safekeeper --id ${ID} --init -D /storage/safekeeper/data
fi

View File

@@ -8,14 +8,6 @@ storage:
pg_distrib_dir: /usr/local
metric_collection_endpoint: http://neon-internal-api.aws.neon.build/billing/api/v1/usage_events
metric_collection_interval: 10min
disk_usage_based_eviction:
max_usage_pct: 80
# TODO: learn typical resident-size growth rate [GiB/minute] and configure
# min_avail_bytes such that we have X minutes of headroom.
min_avail_bytes: 0
# We assume that the worst-case growth rate is small enough that we can
# catch above-threshold conditions by checking every 10s.
period: "10s"
tenant_config:
eviction_policy:
kind: "LayerAccessThreshold"

View File

@@ -8,14 +8,6 @@ storage:
pg_distrib_dir: /usr/local
metric_collection_endpoint: http://neon-internal-api.aws.neon.build/billing/api/v1/usage_events
metric_collection_interval: 10min
disk_usage_based_eviction:
max_usage_pct: 80
# TODO: learn typical resident-size growth rate [GiB/minute] and configure
# min_avail_bytes such that we have X minutes of headroom.
min_avail_bytes: 0
# We assume that the worst-case growth rate is small enough that we can
# catch above-threshold conditions by checking every 10s.
period: "10s"
tenant_config:
eviction_policy:
kind: "LayerAccessThreshold"

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=pageserver
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/pageserver LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_PAGESERVER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -c "availability_zone='{{ ec2_availability_zone }}'" -D /storage/pageserver/data
ExecStart=/usr/local/bin/pageserver -c "pg_distrib_dir='/usr/local'" -c "listen_pg_addr='0.0.0.0:6400'" -c "listen_http_addr='0.0.0.0:9898'" -c "broker_endpoint='{{ broker_endpoint }}'" -D /storage/pageserver/data
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -6,7 +6,7 @@ After=network.target auditd.service
Type=simple
User=safekeeper
Environment=RUST_BACKTRACE=1 NEON_REPO_DIR=/storage/safekeeper/data LD_LIBRARY_PATH=/usr/local/v14/lib SENTRY_DSN={{ SENTRY_URL_SAFEKEEPER }} SENTRY_ENVIRONMENT={{ sentry_environment }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}' --availability-zone={{ ec2_availability_zone }}
ExecStart=/usr/local/bin/safekeeper -l {{ inventory_hostname }}{{ hostname_suffix }}:6500 --listen-http {{ inventory_hostname }}{{ hostname_suffix }}:7676 -D /storage/safekeeper/data --broker-endpoint={{ broker_endpoint }} --remote-storage='{bucket_name="{{bucket_name}}", bucket_region="{{bucket_region}}", prefix_in_bucket="{{ safekeeper_s3_prefix }}"}'
ExecReload=/bin/kill -HUP $MAINPID
KillMode=mixed
KillSignal=SIGINT

View File

@@ -1,22 +1,6 @@
# Helm chart values for neon-proxy-scram.
# This is a YAML-formatted file.
deploymentStrategy:
type: RollingUpdate
rollingUpdate:
maxSurge: 100%
maxUnavailable: 50%
# Delay the kill signal by 7 days (7 * 24 * 60 * 60)
# The pod(s) will stay in Terminating, keeps the existing connections
# but doesn't receive new ones
containerLifecycle:
preStop:
exec:
command: ["/bin/sh", "-c", "sleep 604800"]
terminationGracePeriodSeconds: 604800
image:
repository: neondatabase/neon

View File

@@ -74,12 +74,15 @@ jobs:
- name: Install Python deps
run: ./scripts/pysync
- name: Run ruff to ensure code format
run: poetry run ruff .
- name: Run isort to ensure code format
run: poetry run isort --diff --check .
- name: Run black to ensure code format
run: poetry run black --diff --check .
- name: Run flake8 to ensure code format
run: poetry run flake8 .
- name: Run mypy to check types
run: poetry run mypy .
@@ -548,48 +551,6 @@ jobs:
- name: Cleanup ECR folder
run: rm -rf ~/.ecr
neon-image-depot:
# For testing this will run side-by-side for a few merges.
# This action is not really optimized yet, but gets the job done
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
permissions:
contents: read
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 0
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Install Crane & ECR helper
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Build and push
uses: depot/build-push-action@v1
with:
# if no depot.json file is at the root of your repo, you must specify the project id
project: nrdv0s4kcs
push: true
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]

73
Cargo.lock generated
View File

@@ -851,7 +851,6 @@ dependencies = [
"futures",
"hyper",
"notify",
"num_cpus",
"opentelemetry",
"postgres",
"regex",
@@ -914,7 +913,6 @@ dependencies = [
"once_cell",
"pageserver_api",
"postgres",
"postgres_backend",
"postgres_connection",
"regex",
"reqwest",
@@ -2456,7 +2454,6 @@ dependencies = [
"postgres",
"postgres-protocol",
"postgres-types",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
@@ -2474,7 +2471,6 @@ dependencies = [
"strum",
"strum_macros",
"svg_fmt",
"sync_wrapper",
"tempfile",
"tenant_size_model",
"thiserror",
@@ -2680,28 +2676,6 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "postgres_backend"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"futures",
"once_cell",
"pq_proto",
"rustls",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tracing",
"workspace_hack",
]
[[package]]
name = "postgres_connection"
version = "0.1.0"
@@ -2749,7 +2723,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
name = "pq_proto"
version = "0.1.0"
dependencies = [
"byteorder",
"anyhow",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2924,7 +2898,6 @@ dependencies = [
"opentelemetry",
"parking_lot",
"pin-project-lite",
"postgres_backend",
"pq_proto",
"prometheus",
"rand",
@@ -3094,6 +3067,15 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.14"
@@ -3304,6 +3286,15 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-split"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
dependencies = [
"rustls",
]
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3325,7 +3316,6 @@ dependencies = [
"async-trait",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"const_format",
"crc32c",
@@ -3335,11 +3325,11 @@ dependencies = [
"humantime",
"hyper",
"metrics",
"nix",
"once_cell",
"parking_lot",
"postgres",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"pq_proto",
"regex",
@@ -3859,15 +3849,16 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.4.0"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
dependencies = [
"cfg-if",
"fastrand",
"libc",
"redox_syscall",
"rustix",
"windows-sys 0.42.0",
"remove_dir_all",
"winapi",
]
[[package]]
@@ -4524,7 +4515,7 @@ dependencies = [
"byteorder",
"bytes",
"criterion",
"futures",
"git-version",
"heapless",
"hex",
"hex-literal",
@@ -4533,10 +4524,12 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-project-lite",
"pq_proto",
"rand",
"regex",
"routerify",
"rustls",
"rustls-pemfile",
"rustls-split",
"sentry",
"serde",
"serde_json",
@@ -4547,6 +4540,7 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"tracing-subscriber",
"url",
@@ -4850,19 +4844,15 @@ name = "workspace_hack"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"crossbeam-utils",
"digest",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
"futures-sink",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
@@ -4887,7 +4877,6 @@ dependencies = [
"socket2",
"syn",
"tokio",
"tokio-rustls",
"tokio-util",
"tonic",
"tower",

View File

@@ -64,7 +64,6 @@ md5 = "0.7.0"
memoffset = "0.8"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.18.0"
@@ -134,7 +133,6 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
@@ -152,7 +150,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
criterion = "0.4"
rcgen = "0.10"
rstest = "0.16"
tempfile = "3.4"
tempfile = "3.2"
tonic-build = "0.8"
# This is only needed for proxy's tests.

View File

@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --chown=nonroot . .
COPY . .
# Show build caching stats to check if it was used in the end.
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.

View File

@@ -225,81 +225,6 @@ RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_hashids.control
#########################################################################################
#
# Layer "rum-pg-build"
# compile rum extension
#
#########################################################################################
FROM build-deps AS rum-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/postgrespro/rum/archive/refs/tags/1.3.13.tar.gz -O rum.tar.gz && \
mkdir rum-src && cd rum-src && tar xvzf ../rum.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/rum.control
#########################################################################################
#
# Layer "pgtap-pg-build"
# compile pgTAP extension
#
#########################################################################################
FROM build-deps AS pgtap-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgtap.tar.gz && \
mkdir pgtap-src && cd pgtap-src && tar xvzf ../pgtap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
#########################################################################################
#
# Layer "prefix-pg-build"
# compile Prefix extension
#
#########################################################################################
FROM build-deps AS prefix-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
#########################################################################################
#
# Layer "hll-pg-build"
# compile hll extension
#
#########################################################################################
FROM build-deps AS hll-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
#########################################################################################
#
# Layer "plpgsql-check-pg-build"
# compile plpgsql_check extension
#
#########################################################################################
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
#########################################################################################
#
# Layer "rust extensions"
@@ -323,7 +248,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
cargo install --locked --version 0.7.3 cargo-pgx && \
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
@@ -337,11 +262,11 @@ USER root
FROM rust-extensions-build AS pg-jsonschema-pg-build
# there is no release tag yet, but we need it due to the superuser fix in the control file
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
cd pg_jsonschema && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_jsonschema.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
#########################################################################################
@@ -353,32 +278,13 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421e
FROM rust-extensions-build AS pg-graphql-pg-build
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
# same 1.1 version we've used before.
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
cd pg_graphql && \
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
cd pg_graphql && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
#########################################################################################
#
# Layer "pg-tiktoken-build"
# Compile "pg_tiktoken" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-tiktoken-pg-build
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
cd pg_tiktoken && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -396,23 +302,13 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon \
-s install && \
make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon_utils \
-s install
#########################################################################################
@@ -467,7 +363,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU and plpgsql_check)
# libicu67, locales for collations (including ICU)
# libossp-uuid16 for extension ossp-uuid
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2

View File

@@ -1,70 +1,32 @@
# Note: this file *mostly* just builds on Dockerfile.compute-node
ARG SRC_IMAGE
ARG VM_INFORMANT_VERSION=v0.1.14
# on libcgroup update, make sure to check bootstrap.sh for changes
ARG LIBCGROUP_VERSION=v2.0.3
ARG VM_INFORMANT_VERSION=v0.1.6
# Pull VM informant, to copy from later
# Pull VM informant and set up inittab
FROM neondatabase/vm-informant:$VM_INFORMANT_VERSION as informant
# Build cgroup-tools
#
# At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
# libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-informant
# requires cgroup v2, so we'll build cgroup-tools ourselves.
FROM debian:bullseye-slim as libcgroup-builder
ARG LIBCGROUP_VERSION
RUN set -exu \
&& apt update \
&& apt install --no-install-recommends -y \
git \
ca-certificates \
automake \
cmake \
make \
gcc \
byacc \
flex \
libtool \
libpam0g-dev \
&& git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
&& INSTALL_DIR="/libcgroup-install" \
&& mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
&& cd libcgroup \
# extracted from bootstrap.sh, with modified flags:
&& (test -d m4 || mkdir m4) \
&& autoreconf -fi \
&& rm -rf autom4te.cache \
&& CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
# actually build the thing...
&& make install
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser, set inittab
USER root
RUN adduser vm-informant --disabled-password --no-create-home
RUN set -e \
&& rm -f /etc/inittab \
&& touch /etc/inittab
ADD vm-cgconfig.conf /etc/cgconfig.conf
RUN set -e \
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
&& CONNSTR="dbname=neondb user=cloud_admin sslmode=disable" \
&& ARGS="--auto-restart --cgroup=neon-postgres --pgconnstr=\"$CONNSTR\"" \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant $ARGS'" >> /etc/inittab
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart --cgroup=neon-postgres'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run apt update and adduser
USER root
RUN apt update && \
apt install --no-install-recommends -y \
cgroup-tools
RUN adduser vm-informant --disabled-password --no-create-home
USER postgres
ADD vm-cgconfig.conf /etc/cgconfig.conf
COPY --from=informant /etc/inittab /etc/inittab
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
COPY --from=libcgroup-builder /libcgroup-install/bin/* /usr/bin/
COPY --from=libcgroup-builder /libcgroup-install/lib/* /usr/lib/
COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]

View File

@@ -133,11 +133,6 @@ neon-pg-ext-%: postgres-%
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
.PHONY: neon-pg-ext-clean-%
neon-pg-ext-clean-%:
@@ -150,9 +145,6 @@ neon-pg-ext-clean-%:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \

View File

@@ -46,14 +46,11 @@ postgresql-libs cmake postgresql protobuf
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```
#### Installing dependencies on macOS (12.3.1)
#### Installing dependencies on OSX (12.3.1)
1. Install XCode and dependencies
```
xcode-select --install
brew install protobuf openssl flex bison
# add openssl to PATH, required for ed25519 keys generation in neon_local
echo 'export PATH="$(brew --prefix openssl)/bin:$PATH"' >> ~/.zshrc
```
2. [Install Rust](https://www.rust-lang.org/tools/install)

View File

@@ -11,7 +11,6 @@ clap.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true
regex.workspace = true

View File

@@ -133,7 +133,6 @@ fn main() -> Result<()> {
.settings
.find("neon.pageserver_connstring")
.expect("pageserver connstr should be provided");
let storage_auth_token = spec.storage_auth_token.clone();
let tenant = spec
.cluster
.settings
@@ -154,7 +153,6 @@ fn main() -> Result<()> {
tenant,
timeline,
pageserver_connstr,
storage_auth_token,
metrics: ComputeMetrics::default(),
state: RwLock::new(ComputeState::new()),
};

View File

@@ -18,7 +18,6 @@ use std::fs;
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::RwLock;
@@ -26,7 +25,6 @@ use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use serde::{Serialize, Serializer};
use tokio_postgres;
use tracing::{info, instrument, warn};
use crate::checker::create_writability_check_data;
@@ -45,7 +43,6 @@ pub struct ComputeNode {
pub tenant: String,
pub timeline: String,
pub pageserver_connstr: String,
pub storage_auth_token: Option<String>,
pub metrics: ComputeMetrics,
/// Volatile part of the `ComputeNode` so should be used under `RwLock`
/// to allow HTTP API server to serve status requests, while configuration
@@ -128,18 +125,7 @@ impl ComputeNode {
fn get_basebackup(&self, lsn: &str) -> Result<()> {
let start_time = Utc::now();
let mut config = postgres::Config::from_str(&self.pageserver_connstr)?;
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
if let Some(storage_auth_token) = &self.storage_auth_token {
info!("Got storage auth token from spec file");
config.password(storage_auth_token);
} else {
info!("Storage auth token not set");
}
let mut client = config.connect(NoTls)?;
let mut client = Client::connect(&self.pageserver_connstr, NoTls)?;
let basebackup_cmd = match lsn {
"0/0" => format!("basebackup {} {}", &self.tenant, &self.timeline), // First start of the compute
_ => format!("basebackup {} {} {}", &self.tenant, &self.timeline, lsn),
@@ -176,11 +162,6 @@ impl ComputeNode {
let sync_handle = Command::new(&self.pgbin)
.args(["--sync-safekeepers"])
.env("PGDATA", &self.pgdata) // we cannot use -D in this mode
.envs(if let Some(storage_auth_token) = &self.storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.stdout(Stdio::piped())
.spawn()
.expect("postgres --sync-safekeepers failed to start");
@@ -258,11 +239,6 @@ impl ComputeNode {
// Run postgres as a child process.
let mut pg = Command::new(&self.pgbin)
.args(["-D", &self.pgdata])
.envs(if let Some(storage_auth_token) = &self.storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.spawn()
.expect("cannot start postgres process");
@@ -308,7 +284,6 @@ impl ComputeNode {
handle_role_deletions(self, &mut client)?;
handle_grants(self, &mut client)?;
create_writability_check_data(&mut client)?;
handle_extensions(&self.spec, &mut client)?;
// 'Close' connection
drop(client);
@@ -425,43 +400,4 @@ impl ComputeNode {
Ok(())
}
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.simple_query(
"SELECT
row_to_json(pg_stat_statements)
FROM
pg_stat_statements
WHERE
userid != 'cloud_admin'::regrole::oid
ORDER BY
(mean_exec_time + mean_plan_time) DESC
LIMIT 100",
)
.await;
if let Ok(raw_rows) = result {
for message in raw_rows.iter() {
if let postgres::SimpleQueryMessage::Row(row) = message {
if let Some(json) = row.get(0) {
result_rows.push(json.to_string());
}
}
}
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
} else {
"{{\"pg_stat_statements\": []}}".to_string()
}
}
}

View File

@@ -7,7 +7,6 @@ use crate::compute::ComputeNode;
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use num_cpus;
use serde_json;
use tracing::{error, info};
use tracing_utils::http::OtelName;
@@ -34,13 +33,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
}
// Collect Postgres current usage insights
(&Method::GET, "/insights") => {
info!("serving /insights GET request");
let insights = compute.collect_insights().await;
Response::new(Body::from(insights))
}
(&Method::POST, "/check_writability") => {
info!("serving /check_writability POST request");
let res = crate::checker::check_writability(compute).await;
@@ -50,17 +42,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Response::new(Body::from(
serde_json::json!({
"num_cpus": num_cpus,
})
.to_string(),
))
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));

View File

@@ -10,12 +10,12 @@ paths:
/status:
get:
tags:
- Info
- "info"
summary: Get compute node internal status
description: ""
operationId: getComputeStatus
responses:
200:
"200":
description: ComputeState
content:
application/json:
@@ -25,58 +25,27 @@ paths:
/metrics.json:
get:
tags:
- Info
- "info"
summary: Get compute node startup metrics in JSON format
description: ""
operationId: getComputeMetricsJSON
responses:
200:
"200":
description: ComputeMetrics
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeMetrics"
/insights:
get:
tags:
- Info
summary: Get current compute insights in JSON format
description: |
Note, that this doesn't include any historical data
operationId: getComputeInsights
responses:
200:
description: Compute insights
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeInsights"
/info:
get:
tags:
- "info"
summary: Get info about the compute Pod/VM
description: ""
operationId: getInfo
responses:
"200":
description: Info
content:
application/json:
schema:
$ref: "#/components/schemas/Info"
/check_writability:
post:
tags:
- Check
- "check"
summary: Check that we can write new data on this compute
description: ""
operationId: checkComputeWritability
responses:
200:
"200":
description: Check result
content:
text/plain:
@@ -111,15 +80,6 @@ components:
total_startup_ms:
type: integer
Info:
type: object
description: Information about VM/Pod
required:
- num_cpus
properties:
num_cpus:
type: integer
ComputeState:
type: object
required:
@@ -136,15 +96,6 @@ components:
type: string
description: Text of the error during compute startup, if any
ComputeInsights:
type: object
properties:
pg_stat_statements:
description: Contains raw output from pg_stat_statements in JSON format
type: array
items:
type: object
ComputeStatus:
type: string
enum:

View File

@@ -47,23 +47,12 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Escape a string for including it in a SQL literal
fn escape_literal(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
/// Escape a string so that it can be used in postgresql.conf.
/// Same as escape_literal, currently.
fn escape_conf_value(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
impl GenericOption {
/// Represent `GenericOption` as SQL statement parameter.
pub fn to_pg_option(&self) -> String {
if let Some(val) = &self.value {
match self.vartype.as_ref() {
"string" => format!("{} '{}'", self.name, escape_literal(val)),
"string" => format!("{} '{}'", self.name, val),
_ => format!("{} {}", self.name, val),
}
} else {
@@ -74,8 +63,6 @@ impl GenericOption {
/// Represent `GenericOption` as configuration option.
pub fn to_pg_setting(&self) -> String {
if let Some(val) = &self.value {
// TODO: check in the console DB that we don't have these settings
// set for any non-deleted project and drop this override.
let name = match self.name.as_str() {
"safekeepers" => "neon.safekeepers",
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
@@ -84,7 +71,7 @@ impl GenericOption {
};
match self.vartype.as_ref() {
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
"string" => format!("{} = '{}'", name, val),
_ => format!("{} = {}", name, val),
}
} else {
@@ -120,7 +107,6 @@ impl PgOptionsSerialize for GenericOptions {
.map(|op| op.to_pg_setting())
.collect::<Vec<String>>()
.join("\n")
+ "\n" // newline after last setting
} else {
"".to_string()
}

View File

@@ -24,8 +24,6 @@ pub struct ComputeSpec {
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,
pub storage_auth_token: Option<String>,
pub startup_tracing_context: Option<HashMap<String, String>>,
}
@@ -517,18 +515,3 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
Ok(())
}
/// Create required system extensions
#[instrument(skip_all)]
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("pg_stat_statements") {
// Create extension only if this compute really needs it
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
info!("creating system extensions with query: {}", query);
client.simple_query(query)?;
}
}
Ok(())
}

View File

@@ -178,11 +178,6 @@
"name": "neon.pageserver_connstring",
"value": "host=127.0.0.1 port=6400",
"vartype": "string"
},
{
"name": "test.escaping",
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
"vartype": "string"
}
]
},

View File

@@ -28,30 +28,7 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
r#"fsync = off
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on
shared_buffers = 32768
port = 55432
max_connections = 100
max_wal_senders = 10
listen_addresses = '0.0.0.0'
wal_sender_timeout = 0
password_encryption = md5
maintenance_work_mem = 65536
max_parallel_workers = 8
max_worker_processes = 8
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
max_replication_slots = 10
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
shared_preload_libraries = 'neon'
synchronous_standby_names = 'walproposer'
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
"#
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
);
}

View File

@@ -24,7 +24,6 @@ url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true

View File

@@ -2,8 +2,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -3,8 +3,7 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -17,7 +17,6 @@ use pageserver_api::{
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
};
use postgres_backend::AuthType;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -31,6 +30,7 @@ use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
project_git_version,
};
@@ -53,15 +53,14 @@ listen_addr = '{DEFAULT_BROKER_ADDR}'
id = {DEFAULT_PAGESERVER_ID}
listen_pg_addr = '{DEFAULT_PAGESERVER_PG_ADDR}'
listen_http_addr = '{DEFAULT_PAGESERVER_HTTP_ADDR}'
pg_auth_type = '{trust_auth}'
http_auth_type = '{trust_auth}'
auth_type = '{pageserver_auth_type}'
[[safekeepers]]
id = {DEFAULT_SAFEKEEPER_ID}
pg_port = {DEFAULT_SAFEKEEPER_PG_PORT}
http_port = {DEFAULT_SAFEKEEPER_HTTP_PORT}
"#,
trust_auth = AuthType::Trust,
pageserver_auth_type = AuthType::Trust,
)
}
@@ -628,7 +627,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let node = cplane.nodes.get(&(tenant_id, node_name.to_string()));
let auth_token = if matches!(env.pageserver.pg_auth_type, AuthType::NeonJWT) {
let auth_token = if matches!(env.pageserver.auth_type, AuthType::NeonJWT) {
let claims = Claims::new(Some(tenant_id), Scope::Tenant);
Some(env.generate_auth_token(&claims)?)

View File

@@ -14,6 +14,7 @@ use anyhow::{Context, Result};
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
@@ -87,6 +88,7 @@ impl ComputeControlPlane {
address: SocketAddr::new("127.0.0.1".parse().unwrap(), port),
env: self.env.clone(),
pageserver: Arc::clone(&self.pageserver),
is_test: false,
timeline_id,
lsn,
tenant_id,
@@ -95,7 +97,7 @@ impl ComputeControlPlane {
});
node.create_pgdata()?;
node.setup_pg_conf()?;
node.setup_pg_conf(self.env.pageserver.auth_type)?;
self.nodes
.insert((tenant_id, node.name.clone()), Arc::clone(&node));
@@ -112,6 +114,7 @@ pub struct PostgresNode {
name: String,
pub env: LocalEnv,
pageserver: Arc<PageServerNode>,
is_test: bool,
pub timeline_id: TimelineId,
pub lsn: Option<Lsn>, // if it's a read-only node. None for primary
pub tenant_id: TenantId,
@@ -169,6 +172,7 @@ impl PostgresNode {
name,
env: env.clone(),
pageserver: Arc::clone(pageserver),
is_test: false,
timeline_id,
lsn: recovery_target_lsn,
tenant_id,
@@ -274,7 +278,7 @@ impl PostgresNode {
// Write postgresql.conf with default configuration
// and PG_VERSION file to the data directory of a new node.
fn setup_pg_conf(&self) -> Result<()> {
fn setup_pg_conf(&self, auth_type: AuthType) -> Result<()> {
let mut conf = PostgresConf::new();
conf.append("max_wal_senders", "10");
conf.append("wal_log_hints", "off");
@@ -298,12 +302,29 @@ impl PostgresNode {
let config = &self.pageserver.pg_connection_config;
let (host, port) = (config.host(), config.port());
// NOTE: avoid spaces in connection string, because it is less error prone if we forward it somewhere.
format!("postgresql://no_user@{host}:{port}")
// Set up authentication
//
// $NEON_AUTH_TOKEN will be replaced with value from environment
// variable during compute pg startup. It is done this way because
// otherwise user will be able to retrieve the value using SHOW
// command or pg_settings
let password = if let AuthType::NeonJWT = auth_type {
"$NEON_AUTH_TOKEN"
} else {
""
};
// NOTE avoiding spaces in connection string, because it is less error prone if we forward it somewhere.
// Also note that not all parameters are supported here. Because in compute we substitute $NEON_AUTH_TOKEN
// We parse this string and build it back with token from env var, and for simplicity rebuild
// uses only needed variables namely host, port, user, password.
format!("postgresql://no_user:{password}@{host}:{port}")
};
conf.append("shared_preload_libraries", "neon");
conf.append_line("");
conf.append("neon.pageserver_connstring", &pageserver_connstr);
if let AuthType::NeonJWT = auth_type {
conf.append("neon.safekeeper_token_env", "$NEON_AUTH_TOKEN");
}
conf.append("neon.tenant_id", &self.tenant_id.to_string());
conf.append("neon.timeline_id", &self.timeline_id.to_string());
if let Some(lsn) = self.lsn {
@@ -426,8 +447,6 @@ impl PostgresNode {
"DYLD_LIBRARY_PATH",
self.env.pg_lib_dir(self.pg_version)?.to_str().unwrap(),
);
// Pass authentication token used for the connections to pageserver and safekeepers
if let Some(token) = auth_token {
cmd.env("NEON_AUTH_TOKEN", token);
}
@@ -477,6 +496,10 @@ impl PostgresNode {
self.pg_ctl(&["start"], auth_token)
}
pub fn restart(&self, auth_token: &Option<String>) -> Result<()> {
self.pg_ctl(&["restart"], auth_token)
}
pub fn stop(&self, destroy: bool) -> Result<()> {
// If we are going to destroy data directory,
// use immediate shutdown mode, otherwise,
@@ -507,4 +530,26 @@ impl PostgresNode {
"postgres"
)
}
// XXX: cache that in control plane
pub fn whoami(&self) -> String {
let output = Command::new("whoami")
.output()
.expect("failed to execute whoami");
assert!(output.status.success(), "whoami failed");
String::from_utf8(output.stdout).unwrap().trim().to_string()
}
}
impl Drop for PostgresNode {
// destructor to clean up state after test is done
// XXX: we may detect failed test by setting some flag in catch_unwind()
// and checking it here. But let just clean datadirs on start.
fn drop(&mut self) {
if self.is_test {
let _ = self.stop(true);
}
}
}

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, ensure, Context};
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -18,8 +17,9 @@ use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims},
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
};
use crate::safekeeper::SafekeeperNode;
@@ -110,14 +110,15 @@ impl NeonBroker {
pub struct PageServerConf {
// node id
pub id: NodeId,
// Pageserver connection settings
pub listen_pg_addr: String,
pub listen_http_addr: String,
// auth type used for the PG and HTTP ports
pub pg_auth_type: AuthType,
pub http_auth_type: AuthType,
// used to determine which auth type is used
pub auth_type: AuthType,
// jwt auth token used for communication with pageserver
pub auth_token: String,
}
impl Default for PageServerConf {
@@ -126,8 +127,8 @@ impl Default for PageServerConf {
id: NodeId(0),
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
pg_auth_type: AuthType::Trust,
http_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_token: String::new(),
}
}
}
@@ -400,33 +401,48 @@ impl LocalEnv {
fs::create_dir(base_path)?;
// Generate keypair for JWT.
//
// The keypair is only needed if authentication is enabled in any of the
// components. For convenience, we generate the keypair even if authentication
// is not enabled, so that you can easily enable it after the initialization
// step. However, if the key generation fails, we treat it as non-fatal if
// authentication was not enabled.
// generate keys for jwt
// openssl genrsa -out private_key.pem 2048
let private_key_path;
if self.private_key_path == PathBuf::new() {
match generate_auth_keys(
base_path.join("auth_private_key.pem").as_path(),
base_path.join("auth_public_key.pem").as_path(),
) {
Ok(()) => {
self.private_key_path = PathBuf::from("auth_private_key.pem");
}
Err(e) => {
if !self.auth_keys_needed() {
eprintln!("Could not generate keypair for JWT authentication: {e}");
eprintln!("Continuing anyway because authentication was not enabled");
self.private_key_path = PathBuf::from("auth_private_key.pem");
} else {
return Err(e);
}
}
private_key_path = base_path.join("auth_private_key.pem");
let keygen_output = Command::new("openssl")
.arg("genrsa")
.args(["-out", private_key_path.to_str().unwrap()])
.arg("2048")
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
self.private_key_path = PathBuf::from("auth_private_key.pem");
let public_key_path = base_path.join("auth_public_key.pem");
// openssl rsa -in private_key.pem -pubout -outform PEM -out public_key.pem
let keygen_output = Command::new("openssl")
.arg("rsa")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-outform", "PEM"])
.args(["-out", public_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
}
self.pageserver.auth_token =
self.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
fs::create_dir_all(self.pg_data_dirs_path())?;
for safekeeper in &self.safekeepers {
@@ -435,12 +451,6 @@ impl LocalEnv {
self.persist_config(base_path)
}
fn auth_keys_needed(&self) -> bool {
self.pageserver.pg_auth_type == AuthType::NeonJWT
|| self.pageserver.http_auth_type == AuthType::NeonJWT
|| self.safekeepers.iter().any(|sk| sk.auth_enabled)
}
}
fn base_path() -> PathBuf {
@@ -450,43 +460,6 @@ fn base_path() -> PathBuf {
}
}
/// Generate a public/private key pair for JWT authentication
fn generate_auth_keys(private_key_path: &Path, public_key_path: &Path) -> anyhow::Result<()> {
// Generate the key pair
//
// openssl genpkey -algorithm ed25519 -out auth_private_key.pem
let keygen_output = Command::new("openssl")
.arg("genpkey")
.args(["-algorithm", "ed25519"])
.args(["-out", private_key_path.to_str().unwrap()])
.stdout(Stdio::null())
.output()
.context("failed to generate auth private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
// Extract the public key from the private key file
//
// openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
let keygen_output = Command::new("openssl")
.arg("pkey")
.args(["-in", private_key_path.to_str().unwrap()])
.arg("-pubout")
.args(["-out", public_key_path.to_str().unwrap()])
.output()
.context("failed to extract public key from private key")?;
if !keygen_output.status.success() {
bail!(
"openssl failed: '{}'",
String::from_utf8_lossy(&keygen_output.stderr)
);
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -11,7 +11,6 @@ use anyhow::{bail, Context};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
};
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
@@ -21,6 +20,7 @@ use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::{background_process, local_env::LocalEnv};
@@ -82,8 +82,15 @@ impl PageServerNode {
let (host, port) = parse_host_port(&env.pageserver.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
let port = port.unwrap_or(5432);
let password = if env.pageserver.auth_type == AuthType::NeonJWT {
Some(env.pageserver.auth_token.clone())
} else {
None
};
Self {
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
pg_connection_config: PgConnectionConfig::new_host_port(host, port)
.set_password(password),
env: env.clone(),
http_client: Client::new(),
http_base_url: format!("http://{}/v1", env.pageserver.listen_http_addr),
@@ -99,32 +106,25 @@ impl PageServerNode {
self.env.pg_distrib_dir_raw().display()
);
let http_auth_type_param =
format!("http_auth_type='{}'", self.env.pageserver.http_auth_type);
let authg_type_param = format!("auth_type='{}'", self.env.pageserver.auth_type);
let listen_http_addr_param = format!(
"listen_http_addr='{}'",
self.env.pageserver.listen_http_addr
);
let pg_auth_type_param = format!("pg_auth_type='{}'", self.env.pageserver.pg_auth_type);
let listen_pg_addr_param =
format!("listen_pg_addr='{}'", self.env.pageserver.listen_pg_addr);
let broker_endpoint_param = format!("broker_endpoint='{}'", self.env.broker.client_url());
let mut overrides = vec![
id,
pg_distrib_dir_param,
http_auth_type_param,
pg_auth_type_param,
authg_type_param,
listen_http_addr_param,
listen_pg_addr_param,
broker_endpoint_param,
];
if self.env.pageserver.http_auth_type != AuthType::Trust
|| self.env.pageserver.pg_auth_type != AuthType::Trust
{
if self.env.pageserver.auth_type != AuthType::Trust {
overrides.push("auth_validation_public_key_path='auth_public_key.pem'".to_owned());
}
overrides
@@ -247,10 +247,7 @@ impl PageServerNode {
}
fn pageserver_env_variables(&self) -> anyhow::Result<Vec<(String, String)>> {
// FIXME: why is this tied to pageserver's auth type? Whether or not the safekeeper
// needs a token, and how to generate that token, seems independent to whether
// the pageserver requires a token in incoming requests.
Ok(if self.env.pageserver.http_auth_type != AuthType::Trust {
Ok(if self.env.pageserver.auth_type != AuthType::Trust {
// Generate a token to connect from the pageserver to a safekeeper
let token = self
.env
@@ -273,30 +270,27 @@ impl PageServerNode {
background_process::stop_process(immediate, "pageserver", &self.pid_file())
}
pub fn page_server_psql_client(&self) -> anyhow::Result<postgres::Client> {
let mut config = self.pg_connection_config.clone();
if self.env.pageserver.pg_auth_type == AuthType::NeonJWT {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
config = config.set_password(Some(token));
}
Ok(config.connect_no_tls()?)
pub fn page_server_psql(&self, sql: &str) -> Vec<postgres::SimpleQueryMessage> {
let mut client = self.pg_connection_config.connect_no_tls().unwrap();
println!("Pageserver query: '{sql}'");
client.simple_query(sql).unwrap()
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> anyhow::Result<RequestBuilder> {
pub fn page_server_psql_client(&self) -> result::Result<postgres::Client, postgres::Error> {
self.pg_connection_config.connect_no_tls()
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
let mut builder = self.http_client.request(method, url);
if self.env.pageserver.http_auth_type == AuthType::NeonJWT {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
builder = builder.bearer_auth(token)
if self.env.pageserver.auth_type == AuthType::NeonJWT {
builder = builder.bearer_auth(&self.env.pageserver.auth_token)
}
Ok(builder)
builder
}
pub fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/status", self.http_base_url))?
self.http_request(Method::GET, format!("{}/status", self.http_base_url))
.send()?
.error_from_body()?;
Ok(())
@@ -304,7 +298,7 @@ impl PageServerNode {
pub fn tenant_list(&self) -> Result<Vec<TenantInfo>> {
Ok(self
.http_request(Method::GET, format!("{}/tenant", self.http_base_url))?
.http_request(Method::GET, format!("{}/tenant", self.http_base_url))
.send()?
.error_from_body()?
.json()?)
@@ -358,21 +352,11 @@ impl PageServerNode {
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'trace_read_requests' as bool")?,
eviction_policy: settings
.get("eviction_policy")
.map(|x| serde_json::from_str(x))
.transpose()
.context("Failed to parse 'eviction_policy' json")?,
min_resident_size_override: settings
.remove("min_resident_size_override")
.map(|x| x.parse::<u64>())
.transpose()
.context("Failed to parse 'min_resident_size_override' as integer")?,
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")
}
self.http_request(Method::POST, format!("{}/tenant", self.http_base_url))?
self.http_request(Method::POST, format!("{}/tenant", self.http_base_url))
.json(&request)
.send()?
.error_from_body()?
@@ -389,7 +373,7 @@ impl PageServerNode {
}
pub fn tenant_config(&self, tenant_id: TenantId, settings: HashMap<&str, &str>) -> Result<()> {
self.http_request(Method::PUT, format!("{}/tenant/config", self.http_base_url))?
self.http_request(Method::PUT, format!("{}/tenant/config", self.http_base_url))
.json(&TenantConfigRequest {
tenant_id,
checkpoint_distance: settings
@@ -440,11 +424,6 @@ impl PageServerNode {
.map(|x| serde_json::from_str(x))
.transpose()
.context("Failed to parse 'eviction_policy' json")?,
min_resident_size_override: settings
.get("min_resident_size_override")
.map(|x| x.parse::<u64>())
.transpose()
.context("Failed to parse 'min_resident_size_override' as an integer")?,
})
.send()?
.error_from_body()?;
@@ -457,7 +436,7 @@ impl PageServerNode {
.http_request(
Method::GET,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
)
.send()?
.error_from_body()?
.json()?;
@@ -476,7 +455,7 @@ impl PageServerNode {
self.http_request(
Method::POST,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
)
.json(&TimelineCreateRequest {
new_timeline_id,
ancestor_start_lsn,
@@ -513,7 +492,7 @@ impl PageServerNode {
pg_wal: Option<(Lsn, PathBuf)>,
pg_version: u32,
) -> anyhow::Result<()> {
let mut client = self.page_server_psql_client()?;
let mut client = self.pg_connection_config.connect_no_tls().unwrap();
// Init base reader
let (start_lsn, base_tarfile_path) = base;

View File

@@ -1,6 +1,7 @@
use std::io::Write;
use std::path::PathBuf;
use std::process::Child;
use std::sync::Arc;
use std::{io, result};
use anyhow::Context;
@@ -10,6 +11,7 @@ use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::{http::error::HttpErrorBody, id::NodeId};
use crate::pageserver::PageServerNode;
use crate::{
background_process,
local_env::{LocalEnv, SafekeeperConf},
@@ -63,10 +65,14 @@ pub struct SafekeeperNode {
pub env: LocalEnv,
pub http_client: Client,
pub http_base_url: String,
pub pageserver: Arc<PageServerNode>,
}
impl SafekeeperNode {
pub fn from_env(env: &LocalEnv, conf: &SafekeeperConf) -> SafekeeperNode {
let pageserver = Arc::new(PageServerNode::from_env(env));
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
@@ -74,6 +80,7 @@ impl SafekeeperNode {
env: env.clone(),
http_client: Client::new(),
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
pageserver,
}
}
@@ -108,10 +115,6 @@ impl SafekeeperNode {
let datadir = self.datadir_path();
let id_string = id.to_string();
// TODO: add availability_zone to the config.
// Right now we just specify any value here and use it to check metrics in tests.
let availability_zone = format!("sk-{}", id_string);
let mut args = vec![
"-D",
datadir.to_str().with_context(|| {
@@ -123,8 +126,6 @@ impl SafekeeperNode {
&listen_pg,
"--listen-http",
&listen_http,
"--availability-zone",
&availability_zone,
];
if !self.conf.sync {
args.push("--no-sync");

View File

@@ -160,7 +160,6 @@ services:
build:
context: ./compute_wrapper/
args:
- REPOSITORY=${REPOSITORY:-neondatabase}
- COMPUTE_IMAGE=compute-node-v${PG_VERSION:-14}
- TAG=${TAG:-latest}
- http_proxy=$http_proxy

View File

@@ -29,54 +29,12 @@ These components should not have access to the private key and may only get toke
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
There is currently no way to rotate the key without bringing down all components.
### Best practices
See [RFC 8725: JSON Web Token Best Current Practices](https://www.rfc-editor.org/rfc/rfc8725)
### Token format
The JWT tokens in Neon use "EdDSA" as the algorithm (defined in [RFC8037](https://www.rfc-editor.org/rfc/rfc8037)).
Example:
Header:
```
{
"alg": "EdDSA",
"typ": "JWT"
}
```
Payload:
```
{
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
}
```
Meanings of scope:
"tenant": Provides access to all data for a specific tenant
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
Should only be used e.g. for status check/tenant creation/list.
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper.
### CLI
CLI generates a key pair during call to `neon_local init` with the following commands:
```bash
openssl genpkey -algorithm ed25519 -out auth_private_key.pem
openssl pkey -in auth_private_key.pem -pubout -out auth_public_key.pem
openssl genrsa -out auth_private_key.pem 2048
openssl rsa -in auth_private_key.pem -pubout -outform PEM -out auth_public_key.pem
```
Configuration files for all components point to `public_key.pem` for JWT validation.
@@ -106,22 +64,20 @@ Their authentication is just plain PostgreSQL authentication and out of scope fo
There is no administrative API except those provided by PostgreSQL.
#### Outgoing connections
Compute connects to Pageserver for getting pages. The connection string is
configured by the `neon.pageserver_connstring` PostgreSQL GUC,
e.g. `postgresql://no_user@localhost:15028`. If the `$NEON_AUTH_TOKEN`
environment variable is set, it is used as the password for the connection. (The
pageserver uses JWT tokens for authentication, so the password is really a
token.)
Compute connects to Pageserver for getting pages.
The connection string is configured by the `neon.pageserver_connstring` PostgreSQL GUC, e.g. `postgresql://no_user:$NEON_AUTH_TOKEN@localhost:15028`.
The environment variable inside the connection string is substituted with
the JWT token.
Compute connects to Safekeepers to write and commit data. The list of safekeeper
addresses is given in the `neon.safekeepers` GUC. The connections to the
safekeepers take the password from the `$NEON_AUTH_TOKEN` environment
variable, if set.
Compute connects to Safekeepers to write and commit data.
The token is the same for all safekeepers.
It's stored in an environment variable, whose name is configured
by the `neon.safekeeper_token_env` PostgreSQL GUC.
If the GUC is unset, no token is passed.
The `compute_ctl` binary that runs before the PostgreSQL server, and launches
PostgreSQL, also makes a connection to the pageserver. It uses it to fetch the
initial "base backup" dump, to initialize the PostgreSQL data directory. It also
uses `$NEON_AUTH_TOKEN` as the password for the connection.
Note that both tokens can be (and typically are) the same;
the scope is the tenant and the token is usually passed through the
`$NEON_AUTH_TOKEN` environment variable.
### Pageserver
#### Overview
@@ -146,12 +102,10 @@ Each compute should present a token valid for the timeline's tenant.
Pageserver also has HTTP API: some parts are per-tenant,
some parts are server-wide, these are different scopes.
Authentication can be enabled separately for the HTTP mgmt API, and
for the libpq connections from compute. The `http_auth_type` and
`pg_auth_type` configuration variables in Pageserver's config may
have one of these values:
The `auth_type` configuration variable in Pageserver's config may have
either of three values:
* `Trust` removes all authentication.
* `Trust` removes all authentication. The outdated `MD5` value does likewise
* `NeonJWT` enables JWT validation.
Tokens are validated using the public key which lies in a PEM file
specified in the `auth_validation_public_key_path` config.

View File

@@ -37,9 +37,9 @@ You can specify version of neon cluster using following environment values.
- PG_VERSION: postgres version for compute (default is 14)
- TAG: the tag version of [docker image](https://registry.hub.docker.com/r/neondatabase/neon/tags) (default is latest), which is tagged in [CI test](/.github/workflows/build_and_test.yml)
```
$ cd docker-compose/
$ cd docker-compose/docker-compose.yml
$ docker-compose down # remove the conainers if exists
$ PG_VERSION=15 TAG=2937 docker-compose up --build -d # You can specify the postgres and image version
$ PG_VERSION=15 TAG=2221 docker-compose up --build -d # You can specify the postgres and image version
Creating network "dockercompose_default" with the default driver
Creating docker-compose_storage_broker_1 ... done
(...omit...)

View File

@@ -1,269 +0,0 @@
# Deleting pageserver part of tenants data from s3
Created on 08.03.23
## Motivation
Currently we dont delete pageserver part of the data from s3 when project is deleted. (The same is true for safekeepers, but this outside of the scope of this RFC).
This RFC aims to spin a discussion to come to a robust deletion solution that wont put us in into a corner for features like postponed deletion (when we keep data for user to be able to restore a project if it was deleted by accident)
## Summary
TLDR; There are two options, one based on control plane issuing actual delete requests to s3 and the other one that keeps s3 stuff bound to pageserver. Each one has its pros and cons.
The decision is to stick with pageserver centric approach. For motivation see [Decision](#decision).
## Components
pageserver, control-plane
## Requirements
Deletion should successfully finish (eventually) without leaving dangling files in presense of:
- component restarts
- component outage
- pageserver loss
## Proposed implementation
Before the options are discussed, note that deletion can be quite long process. For deletion from s3 the obvious choice is [DeleteObjects](https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html) API call. It allows to batch deletion of up to 1k objects in one API call. So deletion operation linearly depends on number of layer files.
Another design limitation is that there is no cheap `mv` operation available for s3. `mv` from `aws s3 mv` uses `copy(src, dst) + delete(src)`. So `mv`-like operation is not feasible as a building block because it actually amplifies the problem with both duration and resulting cost of the operation.
The case when there are multiple pageservers handling the same tenants is largely out of scope of the RFC. We still consider case with migration from one PS to another, but do not consider case when tenant exists on multiple pageservers for extended period of time. The case with multiple pageservers can be reduced to case with one pageservers by calling detach on all pageservers except the last one, for it actual delete needs to be called.
For simplicity lets look into deleting tenants. Differences in deletion process between tenants and timelines are mentioned in paragraph ["Differences between tenants and timelines"](#differences-between-tenants-and-timelines)
### 1. Pageserver owns deletion machinery
#### The sequence
TLDR; With this approach control plane needs to call delete on a tenant and poll for progress. As much as possible is handled on pageserver. Lets see the sequence.
Happy path:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>PS: Create deleted mark file locally
PS->>CP: Accepted
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
loop Poll for status
CP->>PS: Finished?
PS->>CP: True or False
end
```
Why two mark files?
Remote one is needed for cases when pageserver is lost during deletion so other pageserver can learn the deletion from s3 during attach.
Why local mark file is needed?
If we dont have one, we have two choices, delete local data before deleting the remote part or do that after.
If we delete local data before remote then during restart pageserver wont pick up remote tenant at all because nothing is available locally (pageserver looks for remote conuterparts of locally available tenants).
If we delete local data after remote then at the end of the sequence when remote mark file is deleted if pageserver restart happens then the state is the same to situation when pageserver just missing data on remote without knowing the fact that this data is intended to be deleted. In this case the current behavior is upload everything local-only to remote.
Thus we need local record of tenant being deleted as well.
##### Handle pageserver crashes
Lets explore sequences with various crash points.
Pageserver crashes before `deleted` mark file is persisted in s3:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
note over PS: Crash point 1.
CP->>PS: Retry delete request
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>PS: Create deleted mark file locally
PS->>CP: Accepted
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
CP->>PS: Finished?
PS->>CP: True
```
Pageserver crashed when deleted mark was about to be persisted in s3, before Control Plane gets a response:
```mermaid
sequenceDiagram
autonumber
participant CP as Control Plane
participant PS as Pageserver
participant S3
CP->>PS: Delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
note over PS: Crash point 2.
note over PS: During startup we reconcile <br> with remote and see <br> whether the remote mark exists
alt Remote mark exists
PS->>PS: create local mark if its missing
PS->>PS: delete local files other than deleted mark
loop Delete layers for each timeline
PS->>S3: delete(..)
end
note over CP: Eventually console should <br> retry delete request
CP->>PS: Retry delete tenant
PS->>CP: Not modified
else Mark is missing
note over PS: Continue to operate the tenant as if deletion didnt happen
note over CP: Eventually console should <br> retry delete request
CP->>PS: Retry delete tenant
PS->>S3: Create deleted mark file at <br> /tenant/meta/deleted
PS->>CP: Delete tenant
end
PS->>PS: Continue with layer file deletions
loop Delete layers for each timeline
PS->>S3: delete(..)
CP->>PS: Finished?
PS->>CP: False
end
PS->>S3: Delete mark file
PS->>PS: Delete local mark file
CP->>PS: Finished?
PS->>CP: True
```
Similar sequence applies when both local and remote marks were persisted but Control Plane still didnt receive a response.
If pageserver crashes after both mark files were deleted then it will reply to control plane status poll request with 404 which should be treated by control plane as success.
The same applies if pageserver crashes in the end, when remote mark is deleted but before local one gets deleted. In this case on restart pageserver moves forward with deletion of local mark and Control Plane will receive 404.
##### Differences between tenants and timelines
For timeline the sequence is the same with the following differences:
- remote delete mark file can be replaced with a boolean "deleted" flag in index_part.json
- local deletion mark is not needed, because whole tenant is kept locally so situation described in motivation for local mark is impossible
##### Handle pageserver loss
If pageseserver is lost then the deleted tenant should be attached to different pageserver and delete request needs to be retried against new pageserver. Then attach logic is shared with one described for pageserver restarts (local deletion mark wont be available so needs to be created).
##### Restrictions for tenant that is in progress of being deleted
I propose to add another state to tenant/timeline - PendingDelete. This state shouldnt allow executing any operations aside from polling the deletion status.
#### Summary
Pros:
- Storage is not dependent on control plane. Storage can be restarted even if control plane is not working.
- Allows for easier dogfooding, console can use Neon backed database as primary operational data store. If storage depends on control plane and control plane depends on storage we're stuck.
- No need to share inner s3 workings with control plane. Pageserver presents api contract and S3 paths are not part of this contract.
- No need to pass list of alive timelines to attach call. This will be solved by pageserver observing deleted flag. See
Cons:
- Logic is a tricky, needs good testing
- Anything else?
### 2. Control plane owns deletion machinery
In this case the only action performed on pageserver is removal of local files.
Everything else is done by control plane. The steps are as follows:
1. Control plane marks tenant as "delete pending" in its database
2. It lists the s3 for all the files and repeatedly calls delete until nothing is left behind
3. When no files are left marks deletion as completed
In case of restart it selects all tenants marked as "delete pending" and continues the deletion.
For tenants it is simple. For timelines there are caveats.
Assume that the same workflow is used for timelines.
If a tenant gets relocated during timeline deletion the attach call with its current logic will pick up deleted timeline in its half deleted state.
Available options:
- require list of alive timelines to be passed to attach call
- use the same schema with flag in index_part.json (again part of the caveats around pageserver restart applies). In this case nothing stops pageserver from implementing deletion inside if we already have these deletion marks.
With first option the following problem becomes apparent:
Who is the source of truth regarding timeline liveness?
Imagine:
PS1 fails.
PS2 gets assigned the tenant.
New branch gets created
PS1 starts up (is it possible or we just recycle it?)
PS1 is unaware of the new branch. It can either fall back to s3 ls, or ask control plane.
So here comes the dependency of storage on control plane. During restart storage needs to know which timelines are valid for operation. If there is nothing on s3 that can answer that question storage neeeds to ask control plane.
### Summary
Cons:
- Potential thundering herd-like problem during storage restart (requests to control plane)
- Potential increase in storage startup time (additional request to control plane)
- Storage startup starts to depend on console
- Erroneous attach call can attach tenant in half deleted state
Pros:
- Easier to reason about if you dont have to account for pageserver restarts
### Extra notes
There was a concern that having deletion code in pageserver is a littlebit scary, but we need to have this code somewhere. So to me it is equally scary to have that in whatever place it ends up at.
Delayed deletion can be done with both approaches. As discussed with Anna (@stepashka) this is only relevant for tenants (projects) not for timelines. For first approach detach can be called immediately and deletion can be done later with attach + delete. With second approach control plane needs to start the deletion whenever necessary.
## Decision
After discussion in comments I see that we settled on two options (though a bit different from ones described in rfc). First one is the same - pageserver owns as much as possible. The second option is that pageserver owns markers thing, but actual deletion happens in control plane by repeatedly calling ls + delete.
To my mind the only benefit of the latter approach is possible code reuse between safekeepers and pageservers. Otherwise poking around integrating s3 library into control plane, configuring shared knowledge abouth paths in s3 - are the downsides. Another downside of relying on control plane is the testing process. Control plane resides in different repository so it is quite hard to test pageserver related changes there. e2e test suite there doesnt support shutting down pageservers, which are separate docker containers there instead of just processes.
With pageserver owning everything we still give the retry logic to control plane but its easier to duplicate if needed compared to sharing inner s3 workings. We will have needed tests for retry logic in neon repo.
So the decision is to proceed with pageserver centric approach.

View File

@@ -129,12 +129,13 @@ Run `poetry shell` to activate the virtual environment.
Alternatively, use `poetry run` to run a single command in the venv, e.g. `poetry run pytest`.
### Obligatory checks
We force code formatting via `black`, `ruff`, and type hints via `mypy`.
We force code formatting via `black`, `isort` and type hints via `mypy`.
Run the following commands in the repository's root (next to `pyproject.toml`):
```bash
poetry run isort . # Imports are reformatted
poetry run black . # All code is reformatted
poetry run ruff . # Python linter
poetry run flake8 . # Python linter
poetry run mypy . # Ensure there are no typing errors
```

View File

@@ -115,12 +115,6 @@ pub struct TenantCreateRequest {
pub lagging_wal_timeout: Option<String>,
pub max_lsn_wal_lag: Option<NonZeroU64>,
pub trace_read_requests: Option<bool>,
// We defer the parsing of the eviction_policy field to the request handler.
// Otherwise we'd have to move the types for eviction policy into this package.
// We might do that once the eviction feature has stabilizied.
// For now, this field is not even documented in the openapi_spec.yml.
pub eviction_policy: Option<serde_json::Value>,
pub min_resident_size_override: Option<u64>,
}
#[serde_as]
@@ -166,7 +160,6 @@ pub struct TenantConfigRequest {
// We might do that once the eviction feature has stabilizied.
// For now, this field is not even documented in the openapi_spec.yml.
pub eviction_policy: Option<serde_json::Value>,
pub min_resident_size_override: Option<u64>,
}
impl TenantConfigRequest {
@@ -187,7 +180,6 @@ impl TenantConfigRequest {
max_lsn_wal_lag: None,
trace_read_requests: None,
eviction_policy: None,
min_resident_size_override: None,
}
}
}
@@ -349,7 +341,7 @@ pub enum InMemoryLayerInfo {
pub enum HistoricLayerInfo {
Delta {
layer_file_name: String,
layer_file_size: u64,
layer_file_size: Option<u64>,
#[serde_as(as = "DisplayFromStr")]
lsn_start: Lsn,
@@ -360,7 +352,7 @@ pub enum HistoricLayerInfo {
},
Image {
layer_file_name: String,
layer_file_size: u64,
layer_file_size: Option<u64>,
#[serde_as(as = "DisplayFromStr")]
lsn_start: Lsn,

View File

@@ -1,26 +0,0 @@
[package]
name = "postgres_backend"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -1,931 +0,0 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Io(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
pub fn new_from_io(
socket: IO,
peer_addr: SocketAddr,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
self.framed.flush().await
}
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
&mut self,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message_noflush(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message_noflush(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message_noflush(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message_noflush(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -1,140 +0,0 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend<IO>,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}

View File

@@ -63,7 +63,10 @@ fn main() -> anyhow::Result<()> {
pg_install_dir_versioned = cwd.join("..").join("..").join(pg_install_dir_versioned);
}
let pg_config_bin = pg_install_dir_versioned.join("bin").join("pg_config");
let pg_config_bin = pg_install_dir_versioned
.join(pg_version)
.join("bin")
.join("pg_config");
let inc_server_path: String = if pg_config_bin.exists() {
let output = Command::new(pg_config_bin)
.arg("--includedir-server")

View File

@@ -5,8 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
bytes.workspace = true
byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true

View File

@@ -1,244 +0,0 @@
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
//! the async stream based on (and buffered with) BytesMut. All functions are
//! cancellation safe.
//!
//! It is similar to what tokio_util::codec::Framed with appropriate codec
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
//! separately without using split from futures::stream::StreamExt (which
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
//! instead. Plus we customize error messages more than a single type for all io
//! calls.
//!
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl ConnectionError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Io(io) => io,
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
}
}
}
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
}
impl<S> Framed<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Deconstruct into the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
(self.stream, self.read_buf)
}
/// Return new Framed with stream type transformed by async f, for TLS
/// upgrade.
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
where
F: FnOnce(S) -> Fut,
Fut: Future<Output = Result<S2, E>>,
{
let stream = f(self.stream).await?;
Ok(Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
})
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
pub async fn read_startup_message(
&mut self,
) -> Result<Option<FeStartupPacket>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
}
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
impl<S: AsyncWrite + Unpin> Framed<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
/// Split into owned read and write parts. Beware of potential issues with
/// using halves in different tasks on TLS stream:
/// https://github.com/tokio-rs/tls/issues/40
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
let (read_half, write_half) = tokio::io::split(self.stream);
let reader = FramedReader {
stream: read_half,
read_buf: self.read_buf,
};
let writer = FramedWriter {
stream: write_half,
write_buf: self.write_buf,
};
(reader, writer)
}
/// Join read and write parts back.
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
Self {
stream: reader.stream.unsplit(writer.stream),
read_buf: reader.read_buf,
write_buf: writer.write_buf,
}
}
}
/// Read-only version of `Framed`.
pub struct FramedReader<S> {
stream: ReadHalf<S>,
read_buf: BytesMut,
}
impl<S: AsyncRead + Unpin> FramedReader<S> {
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
/// Write-only version of `Framed`.
pub struct FramedWriter<S> {
stream: WriteHalf<S>,
write_buf: BytesMut,
}
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
/// don't have remaining data in the buffer. This function is cancellation safe:
/// you can drop future which is not yet complete and finalize reading message
/// with the next call.
///
/// Parametrized to allow reading startup or usual message, having different
/// format.
async fn read_message<S: AsyncRead + Unpin, M, P>(
stream: &mut S,
read_buf: &mut BytesMut,
parse: P,
) -> Result<Option<M>, ConnectionError>
where
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
{
loop {
if let Some(msg) = parse(read_buf)? {
return Ok(Some(msg));
}
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
read_buf.reserve(1);
if stream.read_buf(read_buf).await? == 0 {
if read_buf.has_remaining() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"EOF with unprocessed data in the buffer",
)
.into());
} else {
return Ok(None); // clean EOF
}
}
}
}
async fn flush<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
while write_buf.has_remaining() {
let bytes_written = stream.write(write_buf.chunk()).await?;
if bytes_written == 0 {
return Err(io::Error::new(
ErrorKind::WriteZero,
"failed to write message",
));
}
// The advanced part will be garbage collected, likely during shifting
// data left on next attempt to write to buffer when free space is not
// enough.
write_buf.advance(bytes_written);
}
write_buf.clear();
stream.flush().await
}
async fn shutdown<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
flush(stream, write_buf).await?;
stream.shutdown().await
}

View File

@@ -2,18 +2,24 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
pub mod framed;
// Tools for calling certain async methods in sync contexts.
pub mod sync;
use byteorder::{BigEndian, ReadBytesExt};
use anyhow::{ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt, io, str,
fmt,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -25,6 +31,7 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
// Simple query.
Query(Bytes),
// Extended query protocol.
@@ -184,205 +191,260 @@ pub struct FeExecuteMessage {
#[derive(Debug)]
pub struct FeCloseMessage;
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
BadMessage(String),
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
res => break res,
}
}
};
}
impl ProtocolError {
/// Proxy stream.rs uses only io::Error; provide it.
/// An error occured during connection being open.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
/// IO error during writing to or reading from the connection socket.
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
}
impl From<anyhow::Error> for ConnectionError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
pub fn into_io_error(self) -> io::Error {
io::Error::new(io::ErrorKind::Other, self.to_string())
match self {
ConnectionError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
}
}
impl FeMessage {
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
///
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case to save the repeated calls.
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
///
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
/// Ok(())
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// length field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
))),
}
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
}
})
}
}
impl FeStartupPacket {
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// need at least 4 bytes with packet len
if buf.len() < 4 {
let to_read = 4 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
.to_owned(),
));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
"invalid message length {len}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
params.insert(name.to_owned(), value.to_owned());
// Parse params depending on request code
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
}
};
Ok(Some(message))
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
}
impl FeParseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
@@ -390,82 +452,55 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
return Err(ProtocolError::BadMessage(
"Parse message is malformed, nparams missing".to_string(),
));
}
let nparams = buf.get_i16();
if nparams != 0 {
return Err(ProtocolError::BadMessage(
"query params not implemented".to_string(),
));
}
ensure!(nparams == 0, "query params not implemented");
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if kind != b'S' {
return Err(ProtocolError::BadMessage(
"only prepared statemement Describe is implemented".to_string(),
));
}
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
return Err(ProtocolError::BadMessage(
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
));
}
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
if maxrows != 0 {
return Err(ProtocolError::BadMessage(
"row limit in Execute message not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -494,7 +529,6 @@ pub enum BeMessage<'a> {
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
// None errcode means internal_error will be sent.
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
@@ -525,11 +559,6 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -608,7 +637,7 @@ impl RowDescriptor<'_> {
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64, // current end of WAL on the server
pub wal_end: u64,
pub timestamp: i64,
pub data: &'a [u8],
}
@@ -648,11 +677,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(ProtocolError::BadMessage(
"string contains embedded null".to_owned(),
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(bytes);
@@ -660,27 +690,22 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolErr
Ok(())
}
/// Read cstring from buf, advancing it.
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
let pos = buf
.iter()
.position(|x| *x == 0)
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
let result = buf.split_to(pos);
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -725,7 +750,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -829,7 +854,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -852,7 +877,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -907,7 +932,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -974,7 +999,7 @@ impl ReplicationFeedback {
// null-terminated string - key,
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) {
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
@@ -999,6 +1024,7 @@ impl ReplicationFeedback {
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
}
// Deserialize ReplicationFeedback message
@@ -1066,7 +1092,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
let rf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(rf, rf_parsed);
@@ -1081,7 +1107,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
@@ -1123,6 +1149,15 @@ mod tests {
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {

179
libs/pq_proto/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use pq_proto::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -12,38 +12,41 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
regex.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
nix.workspace = true
signal-hook.workspace = true
rand.workspace = true
jsonwebtoken.workspace = true
hex = { workspace = true, features = ["serde"] }
rustls.workspace = true
rustls-split.workspace = true
git-version.workspace = true
serde_with.workspace = true
once_cell.workspace = true
strum.workspace = true
strum_macros.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
metrics.workspace = true
workspace_hack.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
[dev-dependencies]
byteorder.workspace = true
bytes.workspace = true
criterion.workspace = true
hex-literal.workspace = true
tempfile.workspace = true
criterion.workspace = true
rustls-pemfile.workspace = true
[[bench]]
name = "benchmarks"

View File

@@ -1,4 +1,7 @@
// For details about authentication see docs/authentication.md
//
// TODO: use ed25519 keys
// Relevant issue: https://github.com/Keats/jsonwebtoken/issues/162
use serde;
use std::fs;
@@ -13,10 +16,9 @@ use serde_with::{serde_as, DisplayFromStr};
use crate::id::TenantId;
/// Algorithm to use. We require EdDSA.
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
@@ -31,9 +33,8 @@ pub enum Scope {
SafekeeperData,
}
/// JWT payload. See docs/authentication.md for the format
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
@@ -54,8 +55,7 @@ pub struct JwtAuth {
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::default();
validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
let mut validation = Validation::new(JWT_ALGORITHM);
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
@@ -67,7 +67,7 @@ impl JwtAuth {
pub fn from_key_path(key_path: &Path) -> Result<Self> {
let public_key = fs::read(key_path)?;
Ok(Self::new(DecodingKey::from_ed_pem(&public_key)?))
Ok(Self::new(DecodingKey::from_rsa_pem(&public_key)?))
}
pub fn decode(&self, token: &str) -> Result<TokenData<Claims>> {
@@ -85,75 +85,6 @@ impl std::fmt::Debug for JwtAuth {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_ed_pem(key_data)?;
Ok(encode(&Header::new(STORAGE_TOKEN_ALGORITHM), claims, &key)?)
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// Generated with:
//
// openssl genpkey -algorithm ed25519 -out ed25519-priv.pem
// openssl pkey -in ed25519-priv.pem -pubout -out ed25519-pub.pem
const TEST_PUB_KEY_ED25519: &[u8] = br#"
-----BEGIN PUBLIC KEY-----
MCowBQYDK2VwAyEARYwaNBayR+eGI0iXB4s3QxE3Nl2g1iWbr6KtLWeVD/w=
-----END PUBLIC KEY-----
"#;
const TEST_PRIV_KEY_ED25519: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
-----END PRIVATE KEY-----
"#;
#[test]
fn test_decode() -> Result<(), anyhow::Error> {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
//
// ```
// {
// "scope": "tenant",
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
// "iss": "neon.controlplane",
// "exp": 1709200879,
// "iat": 1678442479
// }
// ```
//
let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.U3eA8j-uU-JnhzeO3EDHRuXLwkAUFCPxtGHEgw6p7Ccc3YRbFs2tmCdbD9PZEXP-XsxSeBQi1FY0YPcT3NXADw";
// Check it can be validated with the public key
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
let claims_from_token = auth.decode(encoded_eddsa)?.claims;
assert_eq!(claims_from_token, expected_claims);
Ok(())
}
#[test]
fn test_encode() -> Result<(), anyhow::Error> {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519)?;
// decode it back
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
let decoded = auth.decode(&encoded)?;
assert_eq!(decoded.claims, claims);
Ok(())
}
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
}

View File

@@ -11,7 +11,7 @@ where
P: AsRef<Path>,
{
fn is_empty_dir(&self) -> io::Result<bool> {
Ok(fs::read_dir(self)?.next().is_none())
Ok(fs::read_dir(self)?.into_iter().next().is_none())
}
}

View File

@@ -3,14 +3,15 @@ use crate::http::error;
use anyhow::{anyhow, Context};
use hyper::header::{HeaderName, AUTHORIZATION};
use hyper::http::HeaderValue;
use hyper::Method;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use hyper::{Method, StatusCode};
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
use once_cell::sync::Lazy;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
use routerify::RequestInfo;
use routerify::{Middleware, Router, RouterBuilder, RouterService};
use tokio::task::JoinError;
use tracing::{self, debug, info, info_span, warn, Instrument};
use tracing;
use std::future::Future;
use std::net::TcpListener;
@@ -26,83 +27,16 @@ static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
.expect("failed to define a metric")
});
static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
#[derive(Debug, Default, Clone)]
struct RequestId(String);
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// in this type will get request info logged in the wrapping span, including the unique request ID.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
/// and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
/// implemented for [`RouterBuilder`].
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
/// later, in a post-response middleware.
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
/// tries to achive with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub struct RequestSpan<E, R, H>(pub H)
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
impl<E, R, H> RequestSpan<E, R, H>
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
{
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
let request_id = request.context::<RequestId>().unwrap_or_default().0;
let method = request.method();
let path = request.uri().path();
let request_span = info_span!("request", %method, %path, %request_id);
let log_quietly = method == Method::GET;
async move {
if log_quietly {
debug!("Handling request");
} else {
info!("Handling request");
}
// Note that we reuse `error::handler` here and not returning and error at all,
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
//
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
match (self.0)(request).await {
Ok(response) => {
let response_status = response.status();
if log_quietly && response_status.is_success() {
debug!("Request handled, status: {response_status}");
} else {
info!("Request handled, status: {response_status}");
}
Ok(response)
}
Err(e) => Ok(error::handler(e.into()).await),
}
}
.instrument(request_span)
.await
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
// cannot factor out the Level to avoid the repetition
// because tracing can only work with const Level
// which is not the case here
if info.method() == Method::GET && res.status() == StatusCode::OK {
tracing::debug!("{} {} {}", info.method(), info.uri().path(), res.status());
} else {
tracing::info!("{} {} {}", info.method(), info.uri().path(), res.status());
}
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -131,46 +65,26 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
Some(request_id) => request_id
.to_str()
.expect("extract request id value")
.to_owned(),
None => {
let request_id = uuid::Uuid::new_v4();
request_id.to_string()
}
};
req.set_context(RequestId(request_id));
Middleware::pre(move |mut req| async move {
let headers = req.headers_mut();
let name = HeaderName::from_str("UUID").expect("created header name");
let request_id = uuid::Uuid::new_v4().to_string();
let value = HeaderValue::from_str(&request_id).unwrap();
headers.insert(name, value);
if req.method() == Method::GET {
tracing::debug!("{} {} {}", req.method(), req.uri().path(), request_id);
} else {
tracing::info!("{} {} {}", req.method(), req.uri().path(), request_id);
}
Ok(req)
})
}
async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
};
Ok(res)
}
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(add_request_id_middleware())
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| {
RequestSpan(prometheus_metrics_handler).handle(r)
})
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
@@ -180,43 +94,40 @@ pub fn attach_openapi_ui(
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder
.get(spec_mount_path, move |r| {
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
.handle(r)
})
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
}).handle(r))
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
@@ -278,7 +189,7 @@ where
async move {
let headers = response.headers_mut();
if headers.contains_key(&name) {
warn!(
tracing::warn!(
"{} response already contains header {:?}",
request_info.uri(),
&name,
@@ -318,7 +229,7 @@ pub fn serve_thread_main<S>(
where
S: Future<Output = ()> + Send + Sync,
{
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
tracing::info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
@@ -338,48 +249,3 @@ where
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use futures::future::poll_fn;
use hyper::service::Service;
use routerify::RequestServiceBuilder;
use std::net::{IpAddr, SocketAddr};
#[tokio::test]
async fn test_request_id_returned() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let mut req: Request<Body> = Request::default();
req.headers_mut()
.append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
assert!(header_val == "42", "response header mismatch");
}
#[tokio::test]
async fn test_request_id_empty() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let req: Request<Body> = Request::default();
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
assert_ne!(header_val, None, "response header should NOT be empty");
}
}

View File

@@ -1,9 +1,7 @@
use std::fmt::Display;
use anyhow::Context;
use bytes::Buf;
use hyper::{header, Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use super::error::ApiError;
@@ -33,12 +31,3 @@ pub fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
/// Serialize through Display trait.
pub fn display_serialize<S, F>(z: &F, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
F: Display,
{
s.serialize_str(&format!("{}", z))
}

View File

@@ -23,7 +23,7 @@ pub enum IdError {
struct Id([u8; 16]);
impl Id {
pub fn get_from_buf(buf: &mut impl bytes::Buf) -> Id {
pub fn get_from_buf(buf: &mut dyn bytes::Buf) -> Id {
let mut arr = [0u8; 16];
buf.copy_to_slice(&mut arr);
Id::from(arr)
@@ -112,7 +112,7 @@ impl fmt::Debug for Id {
macro_rules! id_newtype {
($t:ident) => {
impl $t {
pub fn get_from_buf(buf: &mut impl bytes::Buf) -> $t {
pub fn get_from_buf(buf: &mut dyn bytes::Buf) -> $t {
$t(Id::get_from_buf(buf))
}

View File

@@ -13,6 +13,8 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
pub mod crashsafe;
@@ -25,6 +27,9 @@ pub mod id;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;
@@ -49,11 +54,6 @@ pub mod fs_ext;
pub mod history_buffer;
pub mod measured_stream;
pub mod serde_percent;
pub mod serde_regex;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -1,77 +0,0 @@
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, R, W> {
#[pin]
stream: S,
write_count: usize,
inc_read_count: R,
inc_write_count: W,
}
}
impl<S, R, W> MeasuredStream<S, R, W> {
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_read_count,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
let filled = buf.filled().len();
this.stream.poll_read(context, buf).map_ok(|()| {
let cnt = buf.filled().len() - filled;
// Increment the read count.
(this.inc_read_count)(cnt);
})
}
}
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}

View File

@@ -0,0 +1,485 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
fn is_shutdown_requested(&self) -> bool {
false
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => anyhow::bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !handler.is_shutdown_requested() {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {msg:?}");
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
if let QueryError::Other(e) = &e {
if is_socket_read_timed_out(e) {
continue;
}
}
return Err(e);
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
anyhow::bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established
&& !matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
)
{
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}"
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}

View File

@@ -0,0 +1,634 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
}
.map_err(QueryError::from)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
// Handshake complete.
result?;
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
Ok(ProcessMsgResult::Continue)
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}

View File

@@ -1,83 +0,0 @@
//! A serde::Deserialize type for percentages.
//!
//! See [`Percent`] for details.
use serde::{Deserialize, Serialize};
/// If the value is not an integer between 0 and 100,
/// deserialization fails with a descriptive error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Percent(#[serde(deserialize_with = "deserialize_pct_0_to_100")] u8);
impl Percent {
pub fn get(&self) -> u8 {
self.0
}
}
fn deserialize_pct_0_to_100<'de, D>(deserializer: D) -> Result<u8, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let v: u8 = serde::de::Deserialize::deserialize(deserializer)?;
if v > 100 {
return Err(serde::de::Error::custom(
"must be an integer between 0 and 100",
));
}
Ok(v)
}
#[cfg(test)]
mod tests {
use super::Percent;
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
struct Foo {
bar: Percent,
}
#[test]
fn basics() {
let input = r#"{ "bar": 50 }"#;
let foo: Foo = serde_json::from_str(input).unwrap();
assert_eq!(foo.bar.get(), 50);
}
#[test]
fn null_handling() {
let input = r#"{ "bar": null }"#;
let res: Result<Foo, _> = serde_json::from_str(input);
assert!(res.is_err());
}
#[test]
fn zero() {
let input = r#"{ "bar": 0 }"#;
let foo: Foo = serde_json::from_str(input).unwrap();
assert_eq!(foo.bar.get(), 0);
}
#[test]
fn out_of_range_above() {
let input = r#"{ "bar": 101 }"#;
let res: Result<Foo, _> = serde_json::from_str(input);
assert!(res.is_err());
}
#[test]
fn out_of_range_below() {
let input = r#"{ "bar": -1 }"#;
let res: Result<Foo, _> = serde_json::from_str(input);
assert!(res.is_err());
}
#[test]
fn float() {
let input = r#"{ "bar": 50.5 }"#;
let res: Result<Foo, _> = serde_json::from_str(input);
assert!(res.is_err());
}
#[test]
fn string() {
let input = r#"{ "bar": "50 %" }"#;
let res: Result<Foo, _> = serde_json::from_str(input);
assert!(res.is_err());
}
}

View File

@@ -1,60 +0,0 @@
//! A `serde::{Deserialize,Serialize}` type for regexes.
use std::ops::Deref;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct Regex(
#[serde(
deserialize_with = "deserialize_regex",
serialize_with = "serialize_regex"
)]
regex::Regex,
);
fn deserialize_regex<'de, D>(deserializer: D) -> Result<regex::Regex, D::Error>
where
D: serde::de::Deserializer<'de>,
{
let s: String = serde::de::Deserialize::deserialize(deserializer)?;
let re = regex::Regex::new(&s).map_err(serde::de::Error::custom)?;
Ok(re)
}
fn serialize_regex<S>(re: &regex::Regex, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
serializer.collect_str(re.as_str())
}
impl Deref for Regex {
type Target = regex::Regex;
fn deref(&self) -> &regex::Regex {
&self.0
}
}
impl PartialEq for Regex {
fn eq(&self, other: &Regex) -> bool {
// comparing the automatons would be quite complicated
self.as_str() == other.as_str()
}
}
impl Eq for Regex {}
#[cfg(test)]
mod tests {
#[test]
fn roundtrip() {
let input = r#""foo.*bar""#;
let re: super::Regex = serde_json::from_str(input).unwrap();
assert!(re.is_match("foo123bar"));
assert!(!re.is_match("foo"));
let output = serde_json::to_string(&re).unwrap();
assert_eq!(output, input);
}
}

View File

@@ -0,0 +1,206 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

View File

@@ -0,0 +1,238 @@
use std::{
collections::HashMap,
io::{Cursor, Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use once_cell::sync::Lazy;
use utils::{
postgres_backend::{AuthType, Handler, PostgresBackend},
postgres_backend_async::QueryError,
};
fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).unwrap();
let (server_stream, _) = listener.accept().unwrap();
(server_stream, client_stream)
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
#[test]
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
// we resize the vector so doing some modifications after all
#[allow(clippy::read_zero_byte_vec)]
fn ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
const QUERY: &str = "hello world";
let client_jh = std::thread::spawn(move || {
// SSLRequest
client_sock.write_u32::<BigEndian>(8).unwrap();
client_sock.write_u32::<BigEndian>(80877103).unwrap();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'S', ssl_response);
let cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let client_config = Arc::new(cfg);
let dns_name = "localhost".try_into().unwrap();
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
conn.complete_io(&mut client_sock).unwrap();
assert!(!conn.is_handshaking());
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
// StartupMessage
stream.write_u32::<BigEndian>(9).unwrap();
stream.write_u32::<BigEndian>(196608).unwrap();
stream.write_u8(0).unwrap();
stream.flush().unwrap();
// wait for ReadyForQuery
let mut msg_buf = Vec::new();
loop {
let msg = stream.read_u8().unwrap();
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
msg_buf.resize(size as usize, 0);
stream.read_exact(&mut msg_buf).unwrap();
if msg == b'Z' {
// ReadyForQuery
break;
}
}
// Query
stream.write_u8(b'Q').unwrap();
stream
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
.unwrap();
stream.write_all(QUERY.as_ref()).unwrap();
stream.flush().unwrap();
// ReadyForQuery
let msg = stream.read_u8().unwrap();
assert_eq!(msg, b'Z');
});
struct TestHandler {
got_query: bool,
}
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
self.got_query = query_string == QUERY;
Ok(())
}
}
let mut handler = TestHandler { got_query: false };
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
pgb.run(&mut handler).unwrap();
assert!(handler.got_query);
client_jh.join().unwrap();
// TODO consider shutdown behavior
}
#[test]
fn no_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
let mut buf = BytesMut::new();
// SSLRequest
buf.put_u32(8);
buf.put_u32(80877103);
client_sock.write_all(&buf).unwrap();
buf.clear();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'N', ssl_response);
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
pgb.run(&mut handler).unwrap();
client_jh.join().unwrap();
}
#[test]
fn server_forces_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
// StartupMessage
client_sock.write_u32::<BigEndian>(9).unwrap();
client_sock.write_u32::<BigEndian>(196608).unwrap();
client_sock.write_u8(0).unwrap();
client_sock.flush().unwrap();
// ErrorResponse
assert_eq!(client_sock.read_u8().unwrap(), b'E');
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
let mut body = vec![0; len as usize];
client_sock.read_exact(&mut body).unwrap();
let mut body = Bytes::from(body);
let mut errors = HashMap::new();
loop {
let field_type = body.get_u8();
if field_type == 0u8 {
break;
}
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
let mut value = body.split_to(end_idx + 1);
assert_eq!(value[end_idx], 0u8);
value.truncate(end_idx);
let old = errors.insert(field_type, value);
assert!(old.is_none());
}
assert!(!body.has_remaining());
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
// TODO read failure
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
let res = pgb.run(&mut handler).unwrap_err();
assert_eq!("client did not connect with TLS", format!("{}", res));
client_jh.join().unwrap();
// TODO consider shutdown behavior
}

View File

@@ -37,7 +37,6 @@ num-traits.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
postgres.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
rand.workspace = true
@@ -48,7 +47,6 @@ serde_json = { workspace = true, features = ["raw_value"] }
serde_with.workspace = true
signal-hook.workspace = true
svg_fmt.workspace = true
sync_wrapper.workspace = true
tokio-tar.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }

View File

@@ -8,7 +8,6 @@ use anyhow::{anyhow, Context};
use clap::{Arg, ArgAction, Command};
use fail::FailScenario;
use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use remote_storage::GenericRemoteStorage;
use tracing::*;
@@ -24,10 +23,11 @@ use pageserver::{
tenant::mgr,
virtual_file,
};
use postgres_backend::AuthType;
use utils::{
auth::JwtAuth,
logging, project_git_version,
logging,
postgres_backend::AuthType,
project_git_version,
sentry_init::init_sentry,
signals::{self, Signal},
tcp_listener,
@@ -271,43 +271,43 @@ fn start_pageserver(
WALRECEIVER_RUNTIME.block_on(pageserver::broker_client::init_broker_client(conf))?;
// Initialize authentication for incoming connections
let http_auth;
let pg_auth;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!(
"Loading public key for verifying JWT tokens from {:#?}",
key_path
);
let auth: Arc<JwtAuth> = Arc::new(JwtAuth::from_key_path(key_path)?);
let auth = match &conf.auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => {
// unwrap is ok because check is performed when creating config, so path is set and file exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
Some(JwtAuth::from_key_path(key_path)?.into())
}
};
info!("Using auth: {:#?}", conf.auth_type);
http_auth = match &conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
};
pg_auth = match &conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
};
} else {
http_auth = None;
pg_auth = None;
}
info!("Using auth for http API: {:#?}", conf.http_auth_type);
info!("Using auth for pg connections: {:#?}", conf.pg_auth_type);
match var("NEON_AUTH_TOKEN") {
Ok(v) => {
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
(old, Ok(v)) => {
info!("Loaded JWT token for authentication with Safekeeper");
if let Ok(v_old) = old {
warn!(
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
);
if v_old != v {
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
}
}
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
Err(VarError::NotPresent) => {
(Ok(v), _) => {
info!("Loaded JWT token for authentication with Safekeeper");
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(_, Err(VarError::NotPresent)) => {
info!("No JWT token for authentication with Safekeeper detected");
}
Err(e) => {
(_, Err(e)) => {
return Err(e).with_context(|| {
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
})
@@ -320,34 +320,14 @@ fn start_pageserver(
// Scan the local 'tenants/' directory and start loading the tenants
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(conf, remote_storage.clone()))?;
// shared state between the disk-usage backed eviction background task and the http endpoint
// that allows triggering disk-usage based eviction manually. note that the http endpoint
// is still accessible even if background task is not configured as long as remote storage has
// been configured.
let disk_usage_eviction_state: Arc<disk_usage_eviction_task::State> = Arc::default();
if let Some(remote_storage) = &remote_storage {
launch_disk_usage_global_eviction_task(
conf,
remote_storage.clone(),
disk_usage_eviction_state.clone(),
)?;
}
// Start up the service to handle HTTP mgmt API request. We created the
// listener earlier already.
{
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
let router = http::make_router(
conf,
launch_ts,
http_auth,
remote_storage,
disk_usage_eviction_state,
)?
.build()
.map_err(|err| anyhow!(err))?;
let router = http::make_router(conf, launch_ts, auth.clone(), remote_storage)?
.build()
.map_err(|err| anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();
let server = hyper::Server::from_tcp(http_listener)?
.serve(service)
@@ -419,9 +399,9 @@ fn start_pageserver(
async move {
page_service::libpq_listener_main(
conf,
pg_auth,
auth,
pageserver_listener,
conf.pg_auth_type,
conf.auth_type,
libpq_ctx,
)
.await

View File

@@ -21,13 +21,12 @@ use std::time::Duration;
use toml_edit;
use toml_edit::{Document, Item};
use postgres_backend::AuthType;
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
};
use crate::disk_usage_eviction_task::DiskUsageEvictionTaskConfig;
use crate::tenant::config::TenantConf;
use crate::tenant::config::TenantConfOpt;
use crate::tenant::{TENANT_ATTACHING_MARKER_FILENAME, TIMELINES_SEGMENT_NAME};
@@ -62,7 +61,6 @@ pub mod defaults {
pub const DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL: &str = "1 hour";
pub const DEFAULT_METRIC_COLLECTION_ENDPOINT: Option<reqwest::Url> = None;
pub const DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL: &str = "10 min";
pub const DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD: &str = "24 hour";
///
/// Default built-in configuration file.
@@ -91,10 +89,6 @@ pub mod defaults {
#cached_metric_collection_interval = '{DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL}'
#synthetic_size_calculation_interval = '{DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL}'
#evictions_low_residence_duration_metric_threshold = '{DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD}'
#disk_usage_based_eviction = {{ max_usage_pct = .., min_avail_bytes = .., period = "10s"}}
# [tenant_config]
#checkpoint_distance = {DEFAULT_CHECKPOINT_DISTANCE} # in bytes
#checkpoint_timeout = {DEFAULT_CHECKPOINT_TIMEOUT}
@@ -107,8 +101,6 @@ pub mod defaults {
#image_creation_threshold = {DEFAULT_IMAGE_CREATION_THRESHOLD}
#pitr_interval = '{DEFAULT_PITR_INTERVAL}'
#min_resident_size_override = .. # in bytes
# [remote_storage]
"###
@@ -126,9 +118,6 @@ pub struct PageServerConf {
/// Example (default): 127.0.0.1:9898
pub listen_http_addr: String,
/// Current availability zone. Used for traffic metrics.
pub availability_zone: Option<String>,
// Timeout when waiting for WAL receiver to catch up to an LSN given in a GetPage@LSN call.
pub wait_lsn_timeout: Duration,
// How long to wait for WAL redo to complete.
@@ -149,15 +138,9 @@ pub struct PageServerConf {
pub pg_distrib_dir: PathBuf,
// Authentication
/// authentication method for the HTTP mgmt API
pub http_auth_type: AuthType,
/// authentication method for libpq connections from compute
pub pg_auth_type: AuthType,
/// Path to a file containing public key for verifying JWT tokens.
/// Used for both mgmt and compute auth, if enabled.
pub auth_validation_public_key_path: Option<PathBuf>,
pub auth_type: AuthType,
pub auth_validation_public_key_path: Option<PathBuf>,
pub remote_storage_config: Option<RemoteStorageConfig>,
pub default_tenant_conf: TenantConf,
@@ -178,11 +161,6 @@ pub struct PageServerConf {
pub metric_collection_endpoint: Option<Url>,
pub synthetic_size_calculation_interval: Duration,
// See the corresponding metric's help string.
pub evictions_low_residence_duration_metric_threshold: Duration,
pub disk_usage_based_eviction: Option<DiskUsageEvictionTaskConfig>,
pub test_remote_failures: u64,
pub ondemand_download_behavior_treat_error_as_warn: bool,
@@ -218,8 +196,6 @@ struct PageServerConfigBuilder {
listen_http_addr: BuilderValue<String>,
availability_zone: BuilderValue<Option<String>>,
wait_lsn_timeout: BuilderValue<Duration>,
wal_redo_timeout: BuilderValue<Duration>,
@@ -232,8 +208,7 @@ struct PageServerConfigBuilder {
pg_distrib_dir: BuilderValue<PathBuf>,
http_auth_type: BuilderValue<AuthType>,
pg_auth_type: BuilderValue<AuthType>,
auth_type: BuilderValue<AuthType>,
//
auth_validation_public_key_path: BuilderValue<Option<PathBuf>>,
@@ -253,10 +228,6 @@ struct PageServerConfigBuilder {
metric_collection_endpoint: BuilderValue<Option<Url>>,
synthetic_size_calculation_interval: BuilderValue<Duration>,
evictions_low_residence_duration_metric_threshold: BuilderValue<Duration>,
disk_usage_based_eviction: BuilderValue<Option<DiskUsageEvictionTaskConfig>>,
test_remote_failures: BuilderValue<u64>,
ondemand_download_behavior_treat_error_as_warn: BuilderValue<bool>,
@@ -269,7 +240,6 @@ impl Default for PageServerConfigBuilder {
Self {
listen_pg_addr: Set(DEFAULT_PG_LISTEN_ADDR.to_string()),
listen_http_addr: Set(DEFAULT_HTTP_LISTEN_ADDR.to_string()),
availability_zone: Set(None),
wait_lsn_timeout: Set(humantime::parse_duration(DEFAULT_WAIT_LSN_TIMEOUT)
.expect("cannot parse default wait lsn timeout")),
wal_redo_timeout: Set(humantime::parse_duration(DEFAULT_WAL_REDO_TIMEOUT)
@@ -281,8 +251,7 @@ impl Default for PageServerConfigBuilder {
pg_distrib_dir: Set(env::current_dir()
.expect("cannot access current directory")
.join("pg_install")),
http_auth_type: Set(AuthType::Trust),
pg_auth_type: Set(AuthType::Trust),
auth_type: Set(AuthType::Trust),
auth_validation_public_key_path: Set(None),
remote_storage_config: Set(None),
id: NotSet,
@@ -310,13 +279,6 @@ impl Default for PageServerConfigBuilder {
.expect("cannot parse default synthetic size calculation interval")),
metric_collection_endpoint: Set(DEFAULT_METRIC_COLLECTION_ENDPOINT),
evictions_low_residence_duration_metric_threshold: Set(humantime::parse_duration(
DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD,
)
.expect("cannot parse DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD")),
disk_usage_based_eviction: Set(None),
test_remote_failures: Set(0),
ondemand_download_behavior_treat_error_as_warn: Set(false),
@@ -333,10 +295,6 @@ impl PageServerConfigBuilder {
self.listen_http_addr = BuilderValue::Set(listen_http_addr)
}
pub fn availability_zone(&mut self, availability_zone: Option<String>) {
self.availability_zone = BuilderValue::Set(availability_zone)
}
pub fn wait_lsn_timeout(&mut self, wait_lsn_timeout: Duration) {
self.wait_lsn_timeout = BuilderValue::Set(wait_lsn_timeout)
}
@@ -365,12 +323,8 @@ impl PageServerConfigBuilder {
self.pg_distrib_dir = BuilderValue::Set(pg_distrib_dir)
}
pub fn http_auth_type(&mut self, auth_type: AuthType) {
self.http_auth_type = BuilderValue::Set(auth_type)
}
pub fn pg_auth_type(&mut self, auth_type: AuthType) {
self.pg_auth_type = BuilderValue::Set(auth_type)
pub fn auth_type(&mut self, auth_type: AuthType) {
self.auth_type = BuilderValue::Set(auth_type)
}
pub fn auth_validation_public_key_path(
@@ -432,14 +386,6 @@ impl PageServerConfigBuilder {
self.test_remote_failures = BuilderValue::Set(fail_first);
}
pub fn evictions_low_residence_duration_metric_threshold(&mut self, value: Duration) {
self.evictions_low_residence_duration_metric_threshold = BuilderValue::Set(value);
}
pub fn disk_usage_based_eviction(&mut self, value: Option<DiskUsageEvictionTaskConfig>) {
self.disk_usage_based_eviction = BuilderValue::Set(value);
}
pub fn ondemand_download_behavior_treat_error_as_warn(
&mut self,
ondemand_download_behavior_treat_error_as_warn: bool,
@@ -456,9 +402,6 @@ impl PageServerConfigBuilder {
listen_http_addr: self
.listen_http_addr
.ok_or(anyhow!("missing listen_http_addr"))?,
availability_zone: self
.availability_zone
.ok_or(anyhow!("missing availability_zone"))?,
wait_lsn_timeout: self
.wait_lsn_timeout
.ok_or(anyhow!("missing wait_lsn_timeout"))?,
@@ -476,10 +419,7 @@ impl PageServerConfigBuilder {
pg_distrib_dir: self
.pg_distrib_dir
.ok_or(anyhow!("missing pg_distrib_dir"))?,
http_auth_type: self
.http_auth_type
.ok_or(anyhow!("missing http_auth_type"))?,
pg_auth_type: self.pg_auth_type.ok_or(anyhow!("missing pg_auth_type"))?,
auth_type: self.auth_type.ok_or(anyhow!("missing auth_type"))?,
auth_validation_public_key_path: self
.auth_validation_public_key_path
.ok_or(anyhow!("missing auth_validation_public_key_path"))?,
@@ -513,14 +453,6 @@ impl PageServerConfigBuilder {
synthetic_size_calculation_interval: self
.synthetic_size_calculation_interval
.ok_or(anyhow!("missing synthetic_size_calculation_interval"))?,
evictions_low_residence_duration_metric_threshold: self
.evictions_low_residence_duration_metric_threshold
.ok_or(anyhow!(
"missing evictions_low_residence_duration_metric_threshold"
))?,
disk_usage_based_eviction: self
.disk_usage_based_eviction
.ok_or(anyhow!("missing disk_usage_based_eviction"))?,
test_remote_failures: self
.test_remote_failures
.ok_or(anyhow!("missing test_remote_failuers"))?,
@@ -667,7 +599,6 @@ impl PageServerConf {
match key {
"listen_pg_addr" => builder.listen_pg_addr(parse_toml_string(key, item)?),
"listen_http_addr" => builder.listen_http_addr(parse_toml_string(key, item)?),
"availability_zone" => builder.availability_zone(Some(parse_toml_string(key, item)?)),
"wait_lsn_timeout" => builder.wait_lsn_timeout(parse_toml_duration(key, item)?),
"wal_redo_timeout" => builder.wal_redo_timeout(parse_toml_duration(key, item)?),
"initial_superuser_name" => builder.superuser(parse_toml_string(key, item)?),
@@ -681,8 +612,7 @@ impl PageServerConf {
"auth_validation_public_key_path" => builder.auth_validation_public_key_path(Some(
PathBuf::from(parse_toml_string(key, item)?),
)),
"http_auth_type" => builder.http_auth_type(parse_toml_from_str(key, item)?),
"pg_auth_type" => builder.pg_auth_type(parse_toml_from_str(key, item)?),
"auth_type" => builder.auth_type(parse_toml_from_str(key, item)?),
"remote_storage" => {
builder.remote_storage_config(RemoteStorageConfig::from_toml(item)?)
}
@@ -710,13 +640,6 @@ impl PageServerConf {
"synthetic_size_calculation_interval" =>
builder.synthetic_size_calculation_interval(parse_toml_duration(key, item)?),
"test_remote_failures" => builder.test_remote_failures(parse_toml_u64(key, item)?),
"evictions_low_residence_duration_metric_threshold" => builder.evictions_low_residence_duration_metric_threshold(parse_toml_duration(key, item)?),
"disk_usage_based_eviction" => {
tracing::info!("disk_usage_based_eviction: {:#?}", &item);
builder.disk_usage_based_eviction(
toml_edit::de::from_item(item.clone())
.context("parse disk_usage_based_eviction")?)
},
"ondemand_download_behavior_treat_error_as_warn" => builder.ondemand_download_behavior_treat_error_as_warn(parse_toml_bool(key, item)?),
_ => bail!("unrecognized pageserver option '{key}'"),
}
@@ -724,7 +647,7 @@ impl PageServerConf {
let mut conf = builder.build().context("invalid config")?;
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
if conf.auth_type == AuthType::NeonJWT {
let auth_validation_public_key_path = conf
.auth_validation_public_key_path
.get_or_insert_with(|| workdir.join("auth_public_key.pem"));
@@ -775,12 +698,6 @@ impl PageServerConf {
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
}
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
t_conf.image_creation_threshold = Some(
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
);
}
if let Some(gc_horizon) = item.get("gc_horizon") {
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
}
@@ -821,13 +738,6 @@ impl PageServerConf {
);
}
if let Some(item) = item.get("min_resident_size_override") {
t_conf.min_resident_size_override = Some(
toml_edit::de::from_item(item.clone())
.context("parse min_resident_size_override")?,
);
}
Ok(t_conf)
}
@@ -847,12 +757,10 @@ impl PageServerConf {
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
superuser: "cloud_admin".to_string(),
workdir: repo_dir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -864,11 +772,6 @@ impl PageServerConf {
cached_metric_collection_interval: Duration::from_secs(60 * 60),
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
synthetic_size_calculation_interval: Duration::from_secs(60),
evictions_low_residence_duration_metric_threshold: humantime::parse_duration(
defaults::DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD,
)
.unwrap(),
disk_usage_based_eviction: None,
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
}
@@ -1010,9 +913,6 @@ metric_collection_interval = '222 s'
cached_metric_collection_interval = '22200 s'
metric_collection_endpoint = 'http://localhost:80/metrics'
synthetic_size_calculation_interval = '333 s'
evictions_low_residence_duration_metric_threshold = '444 s'
log_format = 'json'
"#;
@@ -1038,7 +938,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: defaults::DEFAULT_PG_LISTEN_ADDR.to_string(),
listen_http_addr: defaults::DEFAULT_HTTP_LISTEN_ADDR.to_string(),
availability_zone: None,
wait_lsn_timeout: humantime::parse_duration(defaults::DEFAULT_WAIT_LSN_TIMEOUT)?,
wal_redo_timeout: humantime::parse_duration(defaults::DEFAULT_WAL_REDO_TIMEOUT)?,
superuser: defaults::DEFAULT_SUPERUSER.to_string(),
@@ -1046,8 +945,7 @@ log_format = 'json'
max_file_descriptors: defaults::DEFAULT_MAX_FILE_DESCRIPTORS,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -1067,10 +965,6 @@ log_format = 'json'
synthetic_size_calculation_interval: humantime::parse_duration(
defaults::DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL
)?,
evictions_low_residence_duration_metric_threshold: humantime::parse_duration(
defaults::DEFAULT_EVICTIONS_LOW_RESIDENCE_DURATION_METRIC_THRESHOLD
)?,
disk_usage_based_eviction: None,
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
},
@@ -1101,7 +995,6 @@ log_format = 'json'
id: NodeId(10),
listen_pg_addr: "127.0.0.1:64000".to_string(),
listen_http_addr: "127.0.0.1:9898".to_string(),
availability_zone: None,
wait_lsn_timeout: Duration::from_secs(111),
wal_redo_timeout: Duration::from_secs(111),
superuser: "zzzz".to_string(),
@@ -1109,8 +1002,7 @@ log_format = 'json'
max_file_descriptors: 333,
workdir,
pg_distrib_dir,
http_auth_type: AuthType::Trust,
pg_auth_type: AuthType::Trust,
auth_type: AuthType::Trust,
auth_validation_public_key_path: None,
remote_storage_config: None,
default_tenant_conf: TenantConf::default(),
@@ -1122,8 +1014,6 @@ log_format = 'json'
cached_metric_collection_interval: Duration::from_secs(22200),
metric_collection_endpoint: Some(Url::parse("http://localhost:80/metrics")?),
synthetic_size_calculation_interval: Duration::from_secs(333),
evictions_low_residence_duration_metric_threshold: Duration::from_secs(444),
disk_usage_based_eviction: None,
test_remote_failures: 0,
ondemand_download_behavior_treat_error_as_warn: false,
},

View File

@@ -1,679 +0,0 @@
//! This module implements the pageserver-global disk-usage-based layer eviction task.
//!
//! # Mechanics
//!
//! Function `launch_disk_usage_global_eviction_task` starts a pageserver-global background
//! loop that evicts layers in response to a shortage of available bytes
//! in the $repo/tenants directory's filesystem.
//!
//! The loop runs periodically at a configurable `period`.
//!
//! Each loop iteration uses `statvfs` to determine filesystem-level space usage.
//! It compares the returned usage data against two different types of thresholds.
//! The iteration tries to evict layers until app-internal accounting says we should be below the thresholds.
//! We cross-check this internal accounting with the real world by making another `statvfs` at the end of the iteration.
//! We're good if that second statvfs shows that we're _actually_ below the configured thresholds.
//! If we're still above one or more thresholds, we emit a warning log message, leaving it to the operator to investigate further.
//!
//! # Eviction Policy
//!
//! There are two thresholds:
//! `max_usage_pct` is the relative available space, expressed in percent of the total filesystem space.
//! If the actual usage is higher, the threshold is exceeded.
//! `min_avail_bytes` is the absolute available space in bytes.
//! If the actual usage is lower, the threshold is exceeded.
//!
//! The iteration evicts layers in LRU fashion, but, with a weak reservation per tenant.
//! The reservation is to keep the most recently accessed X bytes per tenant resident.
//! If we cannot relieve pressure by evicting layers outside of the reservation, we
//! start evicting layers that are part of the reservation, LRU first.
//!
//! The value for the per-tenant reservation is referred to as `tenant_min_resident_size`
//! throughout the code, but, no actual variable carries that name.
//! The per-tenant default value is the `max(tenant's layer file sizes, regardless of local or remote)`.
//! The idea is to allow at least one layer to be resident per tenant, to ensure it can make forward progress
//! during page reconstruction.
//! An alternative default for all tenants can be specified in the `tenant_config` section of the config.
//! Lastly, each tenant can have an override in their respectice tenant config (`min_resident_size_override`).
use std::{
collections::HashMap,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use anyhow::Context;
use remote_storage::GenericRemoteStorage;
use serde::{Deserialize, Serialize};
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument, warn, Instrument};
use utils::serde_percent::Percent;
use crate::{
config::PageServerConf,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
tenant::{self, storage_layer::PersistentLayer, Timeline},
};
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiskUsageEvictionTaskConfig {
pub max_usage_pct: Percent,
pub min_avail_bytes: u64,
#[serde(with = "humantime_serde")]
pub period: Duration,
#[cfg(feature = "testing")]
pub mock_statvfs: Option<crate::statvfs::mock::Behavior>,
}
#[derive(Default)]
pub struct State {
/// Exclude http requests and background task from running at the same time.
mutex: tokio::sync::Mutex<()>,
}
pub fn launch_disk_usage_global_eviction_task(
conf: &'static PageServerConf,
storage: GenericRemoteStorage,
state: Arc<State>,
) -> anyhow::Result<()> {
let Some(task_config) = &conf.disk_usage_based_eviction else {
info!("disk usage based eviction task not configured");
return Ok(());
};
info!("launching disk usage based eviction task");
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::DiskUsageEviction,
None,
None,
"disk usage based eviction",
false,
async move {
disk_usage_eviction_task(
&state,
task_config,
storage,
&conf.tenants_path(),
task_mgr::shutdown_token(),
)
.await;
info!("disk usage based eviction task finishing");
Ok(())
},
);
Ok(())
}
#[instrument(skip_all)]
async fn disk_usage_eviction_task(
state: &State,
task_config: &DiskUsageEvictionTaskConfig,
storage: GenericRemoteStorage,
tenants_dir: &Path,
cancel: CancellationToken,
) {
use crate::tenant::tasks::random_init_delay;
{
if random_init_delay(task_config.period, &cancel)
.await
.is_err()
{
info!("shutting down");
return;
}
}
let mut iteration_no = 0;
loop {
iteration_no += 1;
let start = Instant::now();
async {
let res = disk_usage_eviction_task_iteration(
state,
task_config,
&storage,
tenants_dir,
&cancel,
)
.await;
match res {
Ok(()) => {}
Err(e) => {
// these stat failures are expected to be very rare
warn!("iteration failed, unexpected error: {e:#}");
}
}
}
.instrument(tracing::info_span!("iteration", iteration_no))
.await;
let sleep_until = start + task_config.period;
tokio::select! {
_ = tokio::time::sleep_until(sleep_until) => {},
_ = cancel.cancelled() => {
info!("shutting down");
break
}
}
}
}
pub trait Usage: Clone + Copy + std::fmt::Debug {
fn has_pressure(&self) -> bool;
fn add_available_bytes(&mut self, bytes: u64);
}
async fn disk_usage_eviction_task_iteration(
state: &State,
task_config: &DiskUsageEvictionTaskConfig,
storage: &GenericRemoteStorage,
tenants_dir: &Path,
cancel: &CancellationToken,
) -> anyhow::Result<()> {
let usage_pre = filesystem_level_usage::get(tenants_dir, task_config)
.context("get filesystem-level disk usage before evictions")?;
let res = disk_usage_eviction_task_iteration_impl(state, storage, usage_pre, cancel).await;
match res {
Ok(outcome) => {
debug!(?outcome, "disk_usage_eviction_iteration finished");
match outcome {
IterationOutcome::NoPressure | IterationOutcome::Cancelled => {
// nothing to do, select statement below will handle things
}
IterationOutcome::Finished(outcome) => {
// Verify with statvfs whether we made any real progress
let after = filesystem_level_usage::get(tenants_dir, task_config)
// It's quite unlikely to hit the error here. Keep the code simple and bail out.
.context("get filesystem-level disk usage after evictions")?;
debug!(?after, "disk usage");
if after.has_pressure() {
// Don't bother doing an out-of-order iteration here now.
// In practice, the task period is set to a value in the tens-of-seconds range,
// which will cause another iteration to happen soon enough.
// TODO: deltas between the three different usages would be helpful,
// consider MiB, GiB, TiB
warn!(?outcome, ?after, "disk usage still high");
} else {
info!(?outcome, ?after, "disk usage pressure relieved");
}
}
}
}
Err(e) => {
error!("disk_usage_eviction_iteration failed: {:#}", e);
}
}
Ok(())
}
#[derive(Debug, Serialize)]
#[allow(clippy::large_enum_variant)]
pub enum IterationOutcome<U> {
NoPressure,
Cancelled,
Finished(IterationOutcomeFinished<U>),
}
#[derive(Debug, Serialize)]
pub struct IterationOutcomeFinished<U> {
/// The actual usage observed before we started the iteration.
before: U,
/// The expected value for `after`, according to internal accounting, after phase 1.
planned: PlannedUsage<U>,
/// The outcome of phase 2, where we actually do the evictions.
///
/// If all layers that phase 1 planned to evict _can_ actually get evicted, this will
/// be the same as `planned`.
assumed: AssumedUsage<U>,
}
#[derive(Debug, Serialize)]
struct AssumedUsage<U> {
/// The expected value for `after`, after phase 2.
projected_after: U,
/// The layers we failed to evict during phase 2.
failed: LayerCount,
}
#[derive(Debug, Serialize)]
struct PlannedUsage<U> {
respecting_tenant_min_resident_size: U,
fallback_to_global_lru: Option<U>,
}
#[derive(Debug, Default, Serialize)]
struct LayerCount {
file_sizes: u64,
count: usize,
}
pub async fn disk_usage_eviction_task_iteration_impl<U: Usage>(
state: &State,
storage: &GenericRemoteStorage,
usage_pre: U,
cancel: &CancellationToken,
) -> anyhow::Result<IterationOutcome<U>> {
// use tokio's mutex to get a Sync guard (instead of std::sync::Mutex)
let _g = state
.mutex
.try_lock()
.map_err(|_| anyhow::anyhow!("iteration is already executing"))?;
debug!(?usage_pre, "disk usage");
if !usage_pre.has_pressure() {
return Ok(IterationOutcome::NoPressure);
}
warn!(
?usage_pre,
"running disk usage based eviction due to pressure"
);
let candidates = match collect_eviction_candidates(cancel).await? {
EvictionCandidates::Cancelled => {
return Ok(IterationOutcome::Cancelled);
}
EvictionCandidates::Finished(partitioned) => partitioned,
};
// Debug-log the list of candidates
let now = SystemTime::now();
for (i, (partition, candidate)) in candidates.iter().enumerate() {
debug!(
"cand {}/{}: size={}, no_access_for={}us, parition={:?}, tenant={} timeline={} layer={}",
i + 1,
candidates.len(),
candidate.layer.file_size(),
now.duration_since(candidate.last_activity_ts)
.unwrap()
.as_micros(),
partition,
candidate.layer.get_tenant_id(),
candidate.layer.get_timeline_id(),
candidate.layer.filename().file_name(),
);
}
// phase1: select victims to relieve pressure
//
// Walk through the list of candidates, until we have accumulated enough layers to get
// us back under the pressure threshold. 'usage_planned' is updated so that it tracks
// how much disk space would be used after evicting all the layers up to the current
// point in the list. The layers are collected in 'batched', grouped per timeline.
//
// If we get far enough in the list that we start to evict layers that are below
// the tenant's min-resident-size threshold, print a warning, and memorize the disk
// usage at that point, in 'usage_planned_min_resident_size_respecting'.
let mut batched: HashMap<_, Vec<Arc<dyn PersistentLayer>>> = HashMap::new();
let mut warned = None;
let mut usage_planned = usage_pre;
for (i, (partition, candidate)) in candidates.into_iter().enumerate() {
if !usage_planned.has_pressure() {
debug!(
no_candidates_evicted = i,
"took enough candidates for pressure to be relieved"
);
break;
}
if partition == MinResidentSizePartition::Below && warned.is_none() {
warn!(?usage_pre, ?usage_planned, candidate_no=i, "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy");
warned = Some(usage_planned);
}
usage_planned.add_available_bytes(candidate.layer.file_size());
batched
.entry(TimelineKey(candidate.timeline))
.or_default()
.push(candidate.layer);
}
let usage_planned = match warned {
Some(respecting_tenant_min_resident_size) => PlannedUsage {
respecting_tenant_min_resident_size,
fallback_to_global_lru: Some(usage_planned),
},
None => PlannedUsage {
respecting_tenant_min_resident_size: usage_planned,
fallback_to_global_lru: None,
},
};
debug!(?usage_planned, "usage planned");
// phase2: evict victims batched by timeline
// After the loop, `usage_assumed` is the post-eviction usage,
// according to internal accounting.
let mut usage_assumed = usage_pre;
let mut evictions_failed = LayerCount::default();
for (timeline, batch) in batched {
let tenant_id = timeline.tenant_id;
let timeline_id = timeline.timeline_id;
let batch_size = batch.len();
debug!(%timeline_id, "evicting batch for timeline");
async {
let results = timeline.evict_layers(storage, &batch, cancel.clone()).await;
match results {
Err(e) => {
warn!("failed to evict batch: {:#}", e);
}
Ok(results) => {
assert_eq!(results.len(), batch.len());
for (result, layer) in results.into_iter().zip(batch.iter()) {
match result {
Some(Ok(true)) => {
usage_assumed.add_available_bytes(layer.file_size());
}
Some(Ok(false)) => {
// this is:
// - Replacement::{NotFound, Unexpected}
// - it cannot be is_remote_layer, filtered already
evictions_failed.file_sizes += layer.file_size();
evictions_failed.count += 1;
}
None => {
assert!(cancel.is_cancelled());
return;
}
Some(Err(e)) => {
// we really shouldn't be getting this, precondition failure
error!("failed to evict layer: {:#}", e);
}
}
}
}
}
}
.instrument(tracing::info_span!("evict_batch", %tenant_id, %timeline_id, batch_size))
.await;
if cancel.is_cancelled() {
return Ok(IterationOutcome::Cancelled);
}
}
Ok(IterationOutcome::Finished(IterationOutcomeFinished {
before: usage_pre,
planned: usage_planned,
assumed: AssumedUsage {
projected_after: usage_assumed,
failed: evictions_failed,
},
}))
}
#[derive(Clone)]
struct EvictionCandidate {
timeline: Arc<Timeline>,
layer: Arc<dyn PersistentLayer>,
last_activity_ts: SystemTime,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum MinResidentSizePartition {
Above,
Below,
}
enum EvictionCandidates {
Cancelled,
Finished(Vec<(MinResidentSizePartition, EvictionCandidate)>),
}
/// Gather the eviction candidates.
///
/// The returned `Ok(EvictionCandidates::Finished(candidates))` is sorted in eviction
/// order. A caller that evicts in that order, until pressure is relieved, implements
/// the eviction policy outlined in the module comment.
///
/// # Example
///
/// Imagine that there are two tenants, A and B, with five layers each, a-e.
/// Each layer has size 100, and both tenant's min_resident_size is 150.
/// The eviction order would be
///
/// ```text
/// partition last_activity_ts tenant/layer
/// Above 18:30 A/c
/// Above 19:00 A/b
/// Above 18:29 B/c
/// Above 19:05 B/b
/// Above 20:00 B/a
/// Above 20:03 A/a
/// Below 20:30 A/d
/// Below 20:40 B/d
/// Below 20:45 B/e
/// Below 20:58 A/e
/// ```
///
/// Now, if we need to evict 300 bytes to relieve pressure, we'd evict `A/c, A/b, B/c`.
/// They are all in the `Above` partition, so, we respected each tenant's min_resident_size.
///
/// But, if we need to evict 900 bytes to relieve pressure, we'd evict
/// `A/c, A/b, B/c, B/b, B/a, A/a, A/d, B/d, B/e`, reaching into the `Below` partition
/// after exhauting the `Above` partition.
/// So, we did not respect each tenant's min_resident_size.
async fn collect_eviction_candidates(
cancel: &CancellationToken,
) -> anyhow::Result<EvictionCandidates> {
// get a snapshot of the list of tenants
let tenants = tenant::mgr::list_tenants()
.await
.context("get list of tenants")?;
let mut candidates = Vec::new();
for (tenant_id, _state) in &tenants {
if cancel.is_cancelled() {
return Ok(EvictionCandidates::Cancelled);
}
let tenant = match tenant::mgr::get_tenant(*tenant_id, true).await {
Ok(tenant) => tenant,
Err(e) => {
// this can happen if tenant has lifecycle transition after we fetched it
debug!("failed to get tenant: {e:#}");
continue;
}
};
// collect layers from all timelines in this tenant
//
// If one of the timelines becomes `!is_active()` during the iteration,
// for example because we're shutting down, then `max_layer_size` can be too small.
// That's OK. This code only runs under a disk pressure situation, and being
// a little unfair to tenants during shutdown in such a situation is tolerable.
let mut tenant_candidates = Vec::new();
let mut max_layer_size = 0;
for tl in tenant.list_timelines() {
if !tl.is_active() {
continue;
}
let info = tl.get_local_layers_for_disk_usage_eviction();
debug!(tenant_id=%tl.tenant_id, timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len());
tenant_candidates.extend(
info.resident_layers
.into_iter()
.map(|layer_infos| (tl.clone(), layer_infos)),
);
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
if cancel.is_cancelled() {
return Ok(EvictionCandidates::Cancelled);
}
}
// `min_resident_size` defaults to maximum layer file size of the tenant.
// This ensures that each tenant can have at least one layer resident at a given time,
// ensuring forward progress for a single Timeline::get in that tenant.
// It's a questionable heuristic since, usually, there are many Timeline::get
// requests going on for a tenant, and, at least in Neon prod, the median
// layer file size is much smaller than the compaction target size.
// We could be better here, e.g., sum of all L0 layers + most recent L1 layer.
// That's what's typically used by the various background loops.
//
// The default can be overriden with a fixed value in the tenant conf.
// A default override can be put in the default tenant conf in the pageserver.toml.
let min_resident_size = if let Some(s) = tenant.get_min_resident_size_override() {
debug!(
tenant_id=%tenant.tenant_id(),
overriden_size=s,
"using overridden min resident size for tenant"
);
s
} else {
debug!(
tenant_id=%tenant.tenant_id(),
max_layer_size,
"using max layer size as min_resident_size for tenant",
);
max_layer_size
};
// Sort layers most-recently-used first, then partition by
// cumsum above/below min_resident_size.
tenant_candidates
.sort_unstable_by_key(|(_, layer_info)| std::cmp::Reverse(layer_info.last_activity_ts));
let mut cumsum: i128 = 0;
for (timeline, layer_info) in tenant_candidates.into_iter() {
let file_size = layer_info.file_size();
let candidate = EvictionCandidate {
timeline,
last_activity_ts: layer_info.last_activity_ts,
layer: layer_info.layer,
};
let partition = if cumsum > min_resident_size as i128 {
MinResidentSizePartition::Above
} else {
MinResidentSizePartition::Below
};
candidates.push((partition, candidate));
cumsum += i128::from(file_size);
}
}
debug_assert!(MinResidentSizePartition::Above < MinResidentSizePartition::Below,
"as explained in the function's doc comment, layers that aren't in the tenant's min_resident_size are evicted first");
candidates
.sort_unstable_by_key(|(partition, candidate)| (*partition, candidate.last_activity_ts));
Ok(EvictionCandidates::Finished(candidates))
}
struct TimelineKey(Arc<Timeline>);
impl PartialEq for TimelineKey {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
}
}
impl Eq for TimelineKey {}
impl std::hash::Hash for TimelineKey {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
Arc::as_ptr(&self.0).hash(state);
}
}
impl std::ops::Deref for TimelineKey {
type Target = Timeline;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
mod filesystem_level_usage {
use std::path::Path;
use anyhow::Context;
use crate::statvfs::Statvfs;
use super::DiskUsageEvictionTaskConfig;
#[derive(Debug, Clone, Copy)]
pub struct Usage<'a> {
config: &'a DiskUsageEvictionTaskConfig,
/// Filesystem capacity
total_bytes: u64,
/// Free filesystem space
avail_bytes: u64,
}
impl super::Usage for Usage<'_> {
fn has_pressure(&self) -> bool {
let usage_pct =
(100.0 * (1.0 - ((self.avail_bytes as f64) / (self.total_bytes as f64)))) as u64;
let pressures = [
(
"min_avail_bytes",
self.avail_bytes < self.config.min_avail_bytes,
),
(
"max_usage_pct",
usage_pct > self.config.max_usage_pct.get() as u64,
),
];
pressures.into_iter().any(|(_, has_pressure)| has_pressure)
}
fn add_available_bytes(&mut self, bytes: u64) {
self.avail_bytes += bytes;
}
}
pub fn get<'a>(
tenants_dir: &Path,
config: &'a DiskUsageEvictionTaskConfig,
) -> anyhow::Result<Usage<'a>> {
let mock_config = {
#[cfg(feature = "testing")]
{
config.mock_statvfs.as_ref()
}
#[cfg(not(feature = "testing"))]
{
None
}
};
let stat = Statvfs::get(tenants_dir, mock_config)
.context("statvfs failed, presumably directory got unlinked")?;
// https://unix.stackexchange.com/a/703650
let blocksize = if stat.fragment_size() > 0 {
stat.fragment_size()
} else {
stat.block_size()
};
// use blocks_available (b_avail) since, pageserver runs as unprivileged user
let avail_bytes = stat.blocks_available() * blocksize;
let total_bytes = stat.blocks() * blocksize;
Ok(Usage {
config,
total_bytes,
avail_bytes,
})
}
}

View File

@@ -27,31 +27,6 @@ paths:
id:
type: integer
/v1/disk_usage_eviction/run:
put:
description: Do an iteration of disk-usage-based eviction to evict a given amount of disk space.
security: []
requestBody:
content:
application/json:
schema:
type: object
required:
- evict_bytes
properties:
evict_bytes:
type: integer
responses:
"200":
description: |
The run completed.
This does not necessarily mean that we actually evicted `evict_bytes`.
Examine the returned object for detail, or, just watch the actual effect of the call using `du` or `df`.
content:
application/json:
schema:
type: object
/v1/tenant/{tenant_id}:
parameters:
- name: tenant_id
@@ -208,12 +183,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Timeline not found
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
"500":
description: Generic operation error
content:
@@ -276,53 +245,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
description: Garbage collect given timeline
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/attach:
parameters:
- name: tenant_id
@@ -382,13 +304,6 @@ paths:
schema:
type: string
format: hex
- name: detach_ignored
in: query
required: false
schema:
type: boolean
description: |
When true, allow to detach a tenant which state is ignored.
post:
description: |
Remove tenant data (including all corresponding timelines) from pageserver's memory and file system.
@@ -414,12 +329,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"404":
description: Tenant not found
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
"500":
description: Generic operation error
content:

View File

@@ -1,16 +1,20 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use anyhow::{anyhow, Context, Result};
use hyper::StatusCode;
use futures::Future;
use hyper::body::HttpBody;
use hyper::header::HeaderName;
use hyper::http::HeaderValue;
use hyper::{Body, Request, Response, Uri};
use hyper::{Method, StatusCode};
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::models::DownloadRemoteLayersTaskSpawnRequest;
use remote_storage::GenericRemoteStorage;
use tenant_size_model::{SizeResult, StorageModel};
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::http::endpoint::RequestSpan;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use super::models::{
@@ -18,11 +22,10 @@ use super::models::{
TimelineCreateRequest, TimelineGcRequest, TimelineInfo,
};
use crate::context::{DownloadBehavior, RequestContext};
use crate::disk_usage_eviction_task;
use crate::pgdatadir_mapping::LsnForTimestamp;
use crate::task_mgr::TaskKind;
use crate::tenant::config::TenantConfOpt;
use crate::tenant::mgr::{TenantMapInsertError, TenantStateError};
use crate::tenant::mgr::TenantMapInsertError;
use crate::tenant::size::ModelInputs;
use crate::tenant::storage_layer::LayerAccessStatsReset;
use crate::tenant::{PageReconstructError, Timeline};
@@ -49,7 +52,6 @@ struct State {
auth: Option<Arc<JwtAuth>>,
allowlist_routes: Vec<Uri>,
remote_storage: Option<GenericRemoteStorage>,
disk_usage_eviction_state: Arc<disk_usage_eviction_task::State>,
}
impl State {
@@ -57,7 +59,6 @@ impl State {
conf: &'static PageServerConf,
auth: Option<Arc<JwtAuth>>,
remote_storage: Option<GenericRemoteStorage>,
disk_usage_eviction_state: Arc<disk_usage_eviction_task::State>,
) -> anyhow::Result<Self> {
let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml"]
.iter()
@@ -68,7 +69,6 @@ impl State {
auth,
allowlist_routes,
remote_storage,
disk_usage_eviction_state,
})
}
}
@@ -86,75 +86,38 @@ fn get_config(request: &Request<Body>) -> &'static PageServerConf {
get_state(request).conf
}
/// Check that the requester is authorized to operate on given tenant
fn check_permission(request: &Request<Body>, tenant_id: Option<TenantId>) -> Result<(), ApiError> {
check_permission_with(request, |claims| {
crate::auth::check_permission(claims, tenant_id)
})
}
impl From<PageReconstructError> for ApiError {
fn from(pre: PageReconstructError) -> ApiError {
match pre {
PageReconstructError::Other(pre) => ApiError::InternalServerError(pre),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(pre) => {
ApiError::InternalServerError(anyhow::Error::new(pre))
}
fn apierror_from_prerror(err: PageReconstructError) -> ApiError {
match err {
PageReconstructError::Other(err) => ApiError::InternalServerError(err),
PageReconstructError::NeedsDownload(_, _) => {
// This shouldn't happen, because we use a RequestContext that requests to
// download any missing layer files on-demand.
ApiError::InternalServerError(anyhow::anyhow!("need to download remote layer file"))
}
PageReconstructError::Cancelled => {
ApiError::InternalServerError(anyhow::anyhow!("request was cancelled"))
}
PageReconstructError::WalRedo(err) => {
ApiError::InternalServerError(anyhow::Error::new(err))
}
}
}
impl From<TenantMapInsertError> for ApiError {
fn from(tmie: TenantMapInsertError) -> ApiError {
match tmie {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(tmie))
}
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
fn apierror_from_tenant_map_insert_error(e: TenantMapInsertError) -> ApiError {
match e {
TenantMapInsertError::StillInitializing | TenantMapInsertError::ShuttingDown => {
ApiError::InternalServerError(anyhow::Error::new(e))
}
}
}
impl From<TenantStateError> for ApiError {
fn from(tse: TenantStateError) -> ApiError {
match tse {
TenantStateError::NotFound(tid) => ApiError::NotFound(anyhow!("tenant {}", tid)),
_ => ApiError::InternalServerError(anyhow::Error::new(tse)),
}
}
}
impl From<crate::tenant::DeleteTimelineError> for ApiError {
fn from(value: crate::tenant::DeleteTimelineError) -> Self {
use crate::tenant::DeleteTimelineError::*;
match value {
NotFound => ApiError::NotFound(anyhow::anyhow!("timeline not found")),
HasChildren => ApiError::BadRequest(anyhow::anyhow!(
"Cannot delete timeline which has child timelines"
)),
Other(e) => ApiError::InternalServerError(e),
}
}
}
impl From<crate::tenant::mgr::DeleteTimelineError> for ApiError {
fn from(value: crate::tenant::mgr::DeleteTimelineError) -> Self {
use crate::tenant::mgr::DeleteTimelineError::*;
match value {
Tenant(t) => ApiError::from(t),
Timeline(t) => ApiError::from(t),
TenantMapInsertError::TenantAlreadyExists(id, state) => {
ApiError::Conflict(format!("tenant {id} already exists, state: {state:?}"))
}
TenantMapInsertError::Closure(e) => ApiError::InternalServerError(e),
}
}
@@ -212,7 +175,7 @@ fn build_timeline_info_common(
None
}
};
let current_physical_size = Some(timeline.layer_size_sum());
let current_physical_size = Some(timeline.layer_size_sum().approximate_is_ok());
let state = timeline.current_state();
let remote_consistent_lsn = timeline.get_remote_consistent_lsn().unwrap_or(Lsn(0));
@@ -258,7 +221,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Error);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
match tenant.create_timeline(
new_timeline_id,
request_data.ancestor_timeline_id.map(TimelineId::from),
@@ -288,7 +253,9 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let response_data = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timelines = tenant.list_timelines();
let mut response_data = Vec::with_capacity(timelines.len());
@@ -304,7 +271,7 @@ async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>,
response_data.push(timeline_info);
}
Ok::<Vec<TimelineInfo>, ApiError>(response_data)
Ok(response_data)
}
.instrument(info_span!("timeline_list", tenant = %tenant_id))
.await?;
@@ -323,7 +290,9 @@ async fn timeline_detail_handler(request: Request<Body>) -> Result<Response<Body
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline_info = async {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
let timeline = tenant
.get_timeline(timeline_id, false)
@@ -359,7 +328,10 @@ async fn get_lsn_by_timestamp_handler(request: Request<Body>) -> Result<Response
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
let result = timeline.find_lsn_for_timestamp(timestamp_pg, &ctx).await?;
let result = timeline
.find_lsn_for_timestamp(timestamp_pg, &ctx)
.await
.map_err(apierror_from_prerror)?;
let result = match result {
LsnForTimestamp::Present(lsn) => format!("{lsn}"),
@@ -384,7 +356,8 @@ async fn tenant_attach_handler(request: Request<Body>) -> Result<Response<Body>,
if let Some(remote_storage) = &state.remote_storage {
mgr::attach_tenant(state.conf, tenant_id, remote_storage.clone(), &ctx)
.instrument(info_span!("tenant_attach", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
} else {
return Err(ApiError::BadRequest(anyhow!(
"attach_tenant is not possible because pageserver was configured without remote storage"
@@ -403,7 +376,11 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
mgr::delete_timeline(tenant_id, timeline_id, &ctx)
.instrument(info_span!("timeline_delete", tenant = %tenant_id, timeline = %timeline_id))
.await?;
.await
// FIXME: Errors from `delete_timeline` can occur for a number of reasons, incuding both
// user and internal errors. Replace this with better handling once the error type permits
// it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -411,13 +388,15 @@ async fn timeline_delete_handler(request: Request<Body>) -> Result<Response<Body
async fn tenant_detach_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let detach_ignored: Option<bool> = parse_query_param(&request, "detach_ignored")?;
let state = get_state(&request);
let conf = state.conf;
mgr::detach_tenant(conf, tenant_id, detach_ignored.unwrap_or(false))
mgr::detach_tenant(conf, tenant_id)
.instrument(info_span!("tenant_detach", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `detach_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -431,7 +410,8 @@ async fn tenant_load_handler(request: Request<Body>) -> Result<Response<Body>, A
let state = get_state(&request);
mgr::load_tenant(state.conf, tenant_id, state.remote_storage.clone(), &ctx)
.instrument(info_span!("load", tenant = %tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
json_response(StatusCode::ACCEPTED, ())
}
@@ -444,7 +424,10 @@ async fn tenant_ignore_handler(request: Request<Body>) -> Result<Response<Body>,
let conf = state.conf;
mgr::ignore_tenant(conf, tenant_id)
.instrument(info_span!("ignore_tenant", tenant = %tenant_id))
.await?;
.await
// FIXME: Errors from `ignore_tenant` can be caused by both both user and internal errors.
// Replace this with better handling once the error type permits it.
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -479,7 +462,7 @@ async fn tenant_status(request: Request<Body>) -> Result<Response<Body>, ApiErro
// Calculate total physical size of all timelines
let mut current_physical_size = 0;
for timeline in tenant.list_timelines().iter() {
current_physical_size += timeline.layer_size_sum();
current_physical_size += timeline.layer_size_sum().approximate_is_ok();
}
let state = tenant.current_state();
@@ -518,7 +501,9 @@ async fn tenant_size_handler(request: Request<Body>) -> Result<Response<Body>, A
let headers = request.headers();
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::InternalServerError)?;
// this can be long operation
let inputs = tenant
@@ -766,16 +751,6 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
);
}
if let Some(eviction_policy) = request_data.eviction_policy {
tenant_conf.eviction_policy = Some(
serde_json::from_value(eviction_policy)
.context("parse field `eviction_policy`")
.map_err(ApiError::BadRequest)?,
);
}
tenant_conf.min_resident_size_override = request_data.min_resident_size_override;
let target_tenant_id = request_data
.new_tenant_id
.map(TenantId::from)
@@ -791,7 +766,8 @@ async fn tenant_create_handler(mut request: Request<Body>) -> Result<Response<Bo
&ctx,
)
.instrument(info_span!("tenant_create", tenant = ?target_tenant_id))
.await?;
.await
.map_err(apierror_from_tenant_map_insert_error)?;
// We created the tenant. Existing API semantics are that the tenant
// is Active when this function returns.
@@ -815,7 +791,9 @@ async fn get_tenant_config_handler(request: Request<Body>) -> Result<Response<Bo
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
check_permission(&request, Some(tenant_id))?;
let tenant = mgr::get_tenant(tenant_id, false).await?;
let tenant = mgr::get_tenant(tenant_id, false)
.await
.map_err(ApiError::NotFound)?;
let response = HashMap::from([
(
@@ -907,26 +885,13 @@ async fn update_tenant_config_handler(
);
}
tenant_conf.min_resident_size_override = request_data.min_resident_size_override;
let state = get_state(&request);
mgr::set_new_tenant_config(state.conf, tenant_conf, tenant_id)
.instrument(info_span!("tenant_config", tenant = ?tenant_id))
.await?;
json_response(StatusCode::OK, ())
}
/// Testing helper to transition a tenant to [`crate::tenant::TenantState::Broken`].
#[cfg(feature = "testing")]
async fn handle_tenant_break(r: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&r, "tenant_id")?;
let tenant = crate::tenant::mgr::get_tenant(tenant_id, true)
.await
.map_err(|_| ApiError::Conflict(String::from("no active tenant found")))?;
tenant.set_broken("broken from test");
// FIXME: `update_tenant_config` can fail because of both user and internal errors.
// Replace this `map_err` with better error handling once the type permits it
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -1011,22 +976,19 @@ async fn timeline_checkpoint_handler(request: Request<Body>) -> Result<Response<
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
check_permission(&request, Some(tenant_id))?;
async {
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
.instrument(info_span!("manual_checkpoint", tenant_id = %tenant_id, timeline_id = %timeline_id))
.await
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
async fn timeline_download_remote_layers_handler_post(
@@ -1063,7 +1025,9 @@ async fn active_timeline_of_active_tenant(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<Arc<Timeline>, ApiError> {
let tenant = mgr::get_tenant(tenant_id, true).await?;
let tenant = mgr::get_tenant(tenant_id, true)
.await
.map_err(ApiError::NotFound)?;
tenant
.get_timeline(timeline_id, true)
.map_err(ApiError::NotFound)
@@ -1080,89 +1044,6 @@ async fn always_panic_handler(req: Request<Body>) -> Result<Response<Body>, ApiE
json_response(StatusCode::NO_CONTENT, ())
}
async fn disk_usage_eviction_run(mut r: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&r, None)?;
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
struct Config {
/// How many bytes to evict before reporting that pressure is relieved.
evict_bytes: u64,
}
#[derive(Debug, Clone, Copy, serde::Serialize)]
struct Usage {
// remains unchanged after instantiation of the struct
config: Config,
// updated by `add_available_bytes`
freed_bytes: u64,
}
impl crate::disk_usage_eviction_task::Usage for Usage {
fn has_pressure(&self) -> bool {
self.config.evict_bytes > self.freed_bytes
}
fn add_available_bytes(&mut self, bytes: u64) {
self.freed_bytes += bytes;
}
}
let config = json_request::<Config>(&mut r)
.await
.map_err(|_| ApiError::BadRequest(anyhow::anyhow!("invalid JSON body")))?;
let usage = Usage {
config,
freed_bytes: 0,
};
use crate::task_mgr::MGMT_REQUEST_RUNTIME;
let (tx, rx) = tokio::sync::oneshot::channel();
let state = get_state(&r);
let Some(storage) = state.remote_storage.clone() else {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"remote storage not configured, cannot run eviction iteration"
)))
};
let state = state.disk_usage_eviction_state.clone();
let cancel = CancellationToken::new();
let child_cancel = cancel.clone();
let _g = cancel.drop_guard();
crate::task_mgr::spawn(
MGMT_REQUEST_RUNTIME.handle(),
TaskKind::DiskUsageEviction,
None,
None,
"ondemand disk usage eviction",
false,
async move {
let res = crate::disk_usage_eviction_task::disk_usage_eviction_task_iteration_impl(
&state,
&storage,
usage,
&child_cancel,
)
.await;
info!(?res, "disk_usage_eviction_task_iteration_impl finished");
let _ = tx.send(res);
Ok(())
}
.in_current_span(),
);
let response = rx.await.unwrap().map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, response)
}
async fn handler_404(_: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(
StatusCode::NOT_FOUND,
@@ -1175,7 +1056,6 @@ pub fn make_router(
launch_ts: &'static LaunchTimestamp,
auth: Option<Arc<JwtAuth>>,
remote_storage: Option<GenericRemoteStorage>,
disk_usage_eviction_state: Arc<disk_usage_eviction_task::State>,
) -> anyhow::Result<RouterBuilder<hyper::Body, ApiError>> {
let spec = include_bytes!("openapi_spec.yml");
let mut router = attach_openapi_ui(endpoint::make_router(), spec, "/swagger.yml", "/v1/doc");
@@ -1213,65 +1093,43 @@ pub fn make_router(
let handler = $handler;
#[cfg(not(feature = "testing"))]
let handler = cfg_disabled;
move |r| RequestSpan(handler).handle(r)
handler
}};
}
Ok(router
.data(Arc::new(
State::new(conf, auth, remote_storage, disk_usage_eviction_state)
.context("Failed to initialize router state")?,
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
))
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
.get("/v1/status", status_handler)
.put(
"/v1/failpoints",
testing_api!("manage failpoints", failpoints_handler),
)
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
.post("/v1/tenant", |r| {
RequestSpan(tenant_create_handler).handle(r)
})
.get("/v1/tenant/:tenant_id", |r| {
RequestSpan(tenant_status).handle(r)
})
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
RequestSpan(tenant_size_handler).handle(r)
})
.put("/v1/tenant/config", |r| {
RequestSpan(update_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/config", |r| {
RequestSpan(get_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_list_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_create_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/attach", |r| {
RequestSpan(tenant_attach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/detach", |r| {
RequestSpan(tenant_detach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/load", |r| {
RequestSpan(tenant_load_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/ignore", |r| {
RequestSpan(tenant_ignore_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_detail_handler).handle(r)
})
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.get("/v1/tenant/:tenant_id", tenant_status)
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
.put("/v1/tenant/config", update_tenant_config_handler)
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_detail_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
get_lsn_by_timestamp_handler,
)
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
timeline_gc_handler,
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
RequestSpan(timeline_gc_handler).handle(r)
})
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
testing_api!("run timeline compaction", timeline_compact_handler),
@@ -1282,33 +1140,64 @@ pub fn make_router(
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
timeline_download_remote_layers_handler_post,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
timeline_download_remote_layers_handler_get,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
layer_map_info_handler,
)
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_delete_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
RequestSpan(layer_map_info_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(layer_download_handler).handle(r),
layer_download_handler,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
evict_timeline_layer_handler,
)
.put("/v1/disk_usage_eviction/run", |r| {
RequestSpan(disk_usage_eviction_run).handle(r)
})
.put(
"/v1/tenant/:tenant_id/break",
testing_api!("set tenant state to broken", handle_tenant_break),
.get(
"/v1/panic",
wrap_span("always_panic_handler", always_panic_handler),
)
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
.any(handler_404))
}
fn wrap_span<H, R, B, E>(
handler_name: &'static str,
handler: H,
) -> impl Fn(Request<hyper::Body>) -> R
where
H: Fn(Request<hyper::Body>) -> R + Send + Sync + 'static,
B: HttpBody + Send + Sync + 'static,
E: Into<Box<dyn std::error::Error + Send + Sync>>,
R: Future<Output = Result<Response<B>, E>> + Send + 'static,
{
move |r| -> R {
async {
let headers = r.headers_mut();
let name = HeaderName::from_str("UUID").expect("created header name");
let request_id = "foo";
let value = HeaderValue::from_str(&request_id).unwrap();
headers.insert(name, value);
if r.method() == Method::GET {
tracing::debug!("{} {} {}", r.method(), r.uri().path(), request_id);
} else {
tracing::info!("{} {} {}", r.method(), r.uri().path(), request_id);
}
handler(r)
.instrument(info_span!(
"request",
handler = handler_name,
request_id = request_id
))
.await
}
}
}

View File

@@ -4,7 +4,6 @@ pub mod broker_client;
pub mod config;
pub mod consumption_metrics;
pub mod context;
pub mod disk_usage_eviction_task;
pub mod http;
pub mod import_datadir;
pub mod keyspace;
@@ -13,7 +12,6 @@ pub mod page_cache;
pub mod page_service;
pub mod pgdatadir_mapping;
pub mod repository;
pub(crate) mod statvfs;
pub mod task_mgr;
pub mod tenant;
pub mod trace;

View File

@@ -9,18 +9,22 @@ use once_cell::sync::Lazy;
use pageserver_api::models::state;
use utils::id::{TenantId, TimelineId};
/// Prometheus histogram buckets (in seconds) for operations in the critical
/// path. In other words, operations that directly affect that latency of user
/// queries.
///
/// The buckets capture the majority of latencies in the microsecond and
/// millisecond range but also extend far enough up to distinguish "bad" from
/// "really bad".
const CRITICAL_OP_BUCKETS: &[f64] = &[
0.000_001, 0.000_010, 0.000_100, // 1 us, 10 us, 100 us
0.001_000, 0.010_000, 0.100_000, // 1 ms, 10 ms, 100 ms
1.0, 10.0, 100.0, // 1 s, 10 s, 100 s
];
/// Prometheus histogram buckets (in seconds) that capture the majority of
/// latencies in the microsecond range but also extend far enough up to distinguish
/// "bad" from "really bad".
fn get_buckets_for_critical_operations() -> Vec<f64> {
let buckets_per_digit = 5;
let min_exponent = -6;
let max_exponent = 2;
let mut buckets = vec![];
// Compute 10^(exp / buckets_per_digit) instead of 10^(1/buckets_per_digit)^exp
// because it's more numerically stable and doesn't result in numbers like 9.999999
for exp in (min_exponent * buckets_per_digit)..=(max_exponent * buckets_per_digit) {
buckets.push(10_f64.powf(exp as f64 / buckets_per_digit as f64))
}
buckets
}
// Metrics collected on operations on the storage repository.
const STORAGE_TIME_OPERATIONS: &[&str] = &[
@@ -51,15 +55,12 @@ pub static STORAGE_TIME_COUNT_PER_TIMELINE: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
// Buckets for background operations like compaction, GC, size calculation
const STORAGE_OP_BUCKETS: &[f64] = &[0.010, 0.100, 1.0, 10.0, 100.0, 1000.0];
pub static STORAGE_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_storage_operations_seconds_global",
"Time spent on storage operations",
&["operation"],
STORAGE_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -70,7 +71,7 @@ static RECONSTRUCT_TIME: Lazy<HistogramVec> = Lazy::new(|| {
"pageserver_getpage_reconstruct_seconds",
"Time spent in reconstruct_value",
&["tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -89,7 +90,7 @@ static WAIT_LSN_TIME: Lazy<HistogramVec> = Lazy::new(|| {
"pageserver_wait_lsn_seconds",
"Time spent waiting for WAL to arrive",
&["tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
get_buckets_for_critical_operations(),
)
.expect("failed to define a metric")
});
@@ -122,22 +123,6 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_layers_total",
"Total on-demand downloaded layers"
)
.unwrap()
});
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_bytes_total",
"Total bytes of layers on-demand downloaded",
)
.unwrap()
});
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_current_logical_size",
@@ -194,101 +179,15 @@ static PERSISTENT_BYTES_WRITTEN: Lazy<IntCounterVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
static EVICTIONS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_evictions",
"Number of layers evicted from the pageserver",
&["tenant_id", "timeline_id"]
)
.expect("failed to define a metric")
});
static EVICTIONS_WITH_LOW_RESIDENCE_DURATION: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_evictions_with_low_residence_duration",
"If a layer is evicted that was resident for less than `low_threshold`, it is counted to this counter. \
Residence duration is determined using the `residence_duration_data_source`.",
&["tenant_id", "timeline_id", "residence_duration_data_source", "low_threshold_secs"]
)
.expect("failed to define a metric")
});
/// Each [`Timeline`]'s [`EVICTIONS_WITH_LOW_RESIDENCE_DURATION`] metric.
#[derive(Debug)]
pub struct EvictionsWithLowResidenceDuration {
data_source: &'static str,
threshold: Duration,
counter: Option<IntCounter>,
}
pub struct EvictionsWithLowResidenceDurationBuilder {
data_source: &'static str,
threshold: Duration,
}
impl EvictionsWithLowResidenceDurationBuilder {
pub fn new(data_source: &'static str, threshold: Duration) -> Self {
Self {
data_source,
threshold,
}
}
fn build(&self, tenant_id: &str, timeline_id: &str) -> EvictionsWithLowResidenceDuration {
let counter = EVICTIONS_WITH_LOW_RESIDENCE_DURATION
.get_metric_with_label_values(&[
tenant_id,
timeline_id,
self.data_source,
&EvictionsWithLowResidenceDuration::threshold_label_value(self.threshold),
])
.unwrap();
EvictionsWithLowResidenceDuration {
data_source: self.data_source,
threshold: self.threshold,
counter: Some(counter),
}
}
}
impl EvictionsWithLowResidenceDuration {
fn threshold_label_value(threshold: Duration) -> String {
format!("{}", threshold.as_secs())
}
pub fn observe(&self, observed_value: Duration) {
if self.threshold < observed_value {
self.counter
.as_ref()
.expect("nobody calls this function after `remove_from_vec`")
.inc();
}
}
// This could be a `Drop` impl, but, we need the `tenant_id` and `timeline_id`.
fn remove(&mut self, tenant_id: &str, timeline_id: &str) {
let Some(_counter) = self.counter.take() else {
return;
};
EVICTIONS_WITH_LOW_RESIDENCE_DURATION
.remove_label_values(&[
tenant_id,
timeline_id,
self.data_source,
&Self::threshold_label_value(self.threshold),
])
.expect("we own the metric, no-one else should remove it");
}
}
// Metrics collected on disk IO operations
//
// Roughly logarithmic scale.
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
0.000030, // 30 usec
0.001000, // 1000 usec
0.030, // 30 ms
1.000, // 1000 ms
0.000001, // 1 usec
0.00001, // 10 usec
0.0001, // 100 usec
0.001, // 1 msec
0.01, // 10 msec
0.1, // 100 msec
1.0, // 1 sec
];
const STORAGE_IO_TIME_OPERATIONS: &[&str] = &[
@@ -323,12 +222,20 @@ const SMGR_QUERY_TIME_OPERATIONS: &[&str] = &[
"get_db_size",
];
const SMGR_QUERY_TIME_BUCKETS: &[f64] = &[
0.00001, // 1/100000 s
0.0001, 0.00015, 0.0002, 0.00025, 0.0003, 0.00035, 0.0005, 0.00075, // 1/10000 s
0.001, 0.0025, 0.005, 0.0075, // 1/1000 s
0.01, 0.0125, 0.015, 0.025, 0.05, // 1/100 s
0.1, // 1/10 s
];
pub static SMGR_QUERY_TIME: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_smgr_query_seconds",
"Time spent on smgr query handling",
&["smgr_query_type", "tenant_id", "timeline_id"],
CRITICAL_OP_BUCKETS.into(),
SMGR_QUERY_TIME_BUCKETS.into()
)
.expect("failed to define a metric")
});
@@ -586,6 +493,7 @@ pub struct TimelineMetrics {
pub flush_time_histo: StorageTimeMetrics,
pub compact_time_histo: StorageTimeMetrics,
pub create_images_time_histo: StorageTimeMetrics,
pub init_logical_size_histo: StorageTimeMetrics,
pub logical_size_histo: StorageTimeMetrics,
pub load_layer_map_histo: StorageTimeMetrics,
pub garbage_collect_histo: StorageTimeMetrics,
@@ -596,16 +504,10 @@ pub struct TimelineMetrics {
pub current_logical_size_gauge: UIntGauge,
pub num_persistent_files_created: IntCounter,
pub persistent_bytes_written: IntCounter,
pub evictions: IntCounter,
pub evictions_with_low_residence_duration: EvictionsWithLowResidenceDuration,
}
impl TimelineMetrics {
pub fn new(
tenant_id: &TenantId,
timeline_id: &TimelineId,
evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder,
) -> Self {
pub fn new(tenant_id: &TenantId, timeline_id: &TimelineId) -> Self {
let tenant_id = tenant_id.to_string();
let timeline_id = timeline_id.to_string();
let reconstruct_time_histo = RECONSTRUCT_TIME
@@ -618,6 +520,8 @@ impl TimelineMetrics {
let compact_time_histo = StorageTimeMetrics::new("compact", &tenant_id, &timeline_id);
let create_images_time_histo =
StorageTimeMetrics::new("create images", &tenant_id, &timeline_id);
let init_logical_size_histo =
StorageTimeMetrics::new("init logical size", &tenant_id, &timeline_id);
let logical_size_histo = StorageTimeMetrics::new("logical size", &tenant_id, &timeline_id);
let load_layer_map_histo =
StorageTimeMetrics::new("load layer map", &tenant_id, &timeline_id);
@@ -640,11 +544,6 @@ impl TimelineMetrics {
let persistent_bytes_written = PERSISTENT_BYTES_WRITTEN
.get_metric_with_label_values(&[&tenant_id, &timeline_id])
.unwrap();
let evictions = EVICTIONS
.get_metric_with_label_values(&[&tenant_id, &timeline_id])
.unwrap();
let evictions_with_low_residence_duration =
evictions_with_low_residence_duration_builder.build(&tenant_id, &timeline_id);
TimelineMetrics {
tenant_id,
@@ -654,6 +553,7 @@ impl TimelineMetrics {
flush_time_histo,
compact_time_histo,
create_images_time_histo,
init_logical_size_histo,
logical_size_histo,
garbage_collect_histo,
load_layer_map_histo,
@@ -663,8 +563,6 @@ impl TimelineMetrics {
current_logical_size_gauge,
num_persistent_files_created,
persistent_bytes_written,
evictions,
evictions_with_low_residence_duration,
}
}
}
@@ -681,9 +579,7 @@ impl Drop for TimelineMetrics {
let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, timeline_id]);
let _ = NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, timeline_id]);
let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, timeline_id]);
let _ = EVICTIONS.remove_label_values(&[tenant_id, timeline_id]);
self.evictions_with_low_residence_duration
.remove(tenant_id, timeline_id);
for op in STORAGE_TIME_OPERATIONS {
let _ =
STORAGE_TIME_SUM_PER_TIMELINE.remove_label_values(&[op, tenant_id, timeline_id]);
@@ -718,7 +614,7 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use std::time::Instant;
pub struct RemoteTimelineClientMetrics {
tenant_id: String,

View File

@@ -12,7 +12,7 @@
use anyhow::Context;
use bytes::Buf;
use bytes::Bytes;
use futures::Stream;
use futures::{Stream, StreamExt};
use pageserver_api::models::TenantState;
use pageserver_api::models::{
PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse,
@@ -20,9 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::PostgresBackendTCP;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -31,13 +29,14 @@ use std::str;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::io::StreamReader;
use tracing::*;
use utils::id::ConnectionId;
use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
simple_rcu::RcuReadGuard,
};
@@ -56,7 +55,7 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
async_stream::try_stream! {
loop {
let msg = tokio::select! {
@@ -65,11 +64,11 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = format!("pageserver is shutting down");
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
msg = pgb.read_message() => { msg }
};
match msg {
@@ -80,16 +79,14 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -99,66 +96,22 @@ fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
Err(io::Error::new(io::ErrorKind::Other, other))?;
}
};
}
}
}
/// Read the end of a tar archive.
///
/// A tar archive normally ends with two consecutive blocks of zeros, 512 bytes each.
/// `tokio_tar` already read the first such block. Read the second all-zeros block,
/// and check that there is no more data after the EOF marker.
///
/// XXX: Currently, any trailing data after the EOF marker prints a warning.
/// Perhaps it should be a hard error?
async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow::Result<()> {
use tokio::io::AsyncReadExt;
let mut buf = [0u8; 512];
// Read the all-zeros block, and verify it
let mut total_bytes = 0;
while total_bytes < 512 {
let nbytes = reader.read(&mut buf[total_bytes..]).await?;
total_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if total_bytes < 512 {
anyhow::bail!("incomplete or invalid tar EOF marker");
}
if !buf.iter().all(|&x| x == 0) {
anyhow::bail!("invalid tar EOF marker");
}
// Drain any data after the EOF marker
let mut trailing_bytes = 0;
loop {
let nbytes = reader.read(&mut buf).await?;
trailing_bytes += nbytes;
if nbytes == 0 {
break;
}
}
if trailing_bytes > 0 {
warn!("ignored {trailing_bytes} unexpected bytes after the tar archive");
}
Ok(())
}
///////////////////////////////////////////////////////////////////////////////
///
@@ -259,7 +212,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
if is_expected_io_error(&io_error) {
info!("Postgres client disconnected ({io_error})");
Ok(())
@@ -333,7 +286,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_pagerequests(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: RequestContext,
@@ -358,7 +311,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.write_message(&BeMessage::CopyBothResponse)?;
pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -427,7 +380,7 @@ impl PageServerHandler {
})
});
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?;
}
Ok(())
@@ -437,7 +390,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
base_lsn: Lsn,
@@ -463,17 +416,22 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
let mut copyin_stream = Box::pin(copyin_stream(pgb));
timeline
.import_basebackup_from_tar(&mut copyin_reader, base_lsn, &ctx)
.import_basebackup_from_tar(&mut copyin_stream, base_lsn, &ctx)
.await?;
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO check checksum
// Meanwhile you can verify client-side by taking fullbackup
@@ -488,7 +446,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_wal(
&self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
start_lsn: Lsn,
@@ -510,15 +468,21 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let copyin_reader = StreamReader::new(copyin_stream(pgb));
tokio::pin!(copyin_reader);
import_wal_from_tar(&timeline, &mut copyin_reader, start_lsn, end_lsn, &ctx).await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
import_wal_from_tar(&timeline, &mut reader, start_lsn, end_lsn, &ctx).await?;
info!("wal import complete");
// Read the end of the tar archive.
read_tar_eof(copyin_reader).await?;
// Drain the rest of the Copy data
let mut bytes_after_tar = 0;
while let Some(bytes) = copyin_stream.next().await {
bytes_after_tar += bytes?.len();
}
if bytes_after_tar > 0 {
warn!("ignored {bytes_after_tar} unexpected bytes after the tar archive");
}
// TODO Does it make sense to overshoot?
if timeline.get_last_record_lsn() < end_lsn {
@@ -693,7 +657,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_basebackup_request(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Option<Lsn>,
@@ -714,7 +678,7 @@ impl PageServerHandler {
}
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.write_message(&BeMessage::CopyOutResponse)?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline
@@ -731,7 +695,7 @@ impl PageServerHandler {
.await?;
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.write_message(&BeMessage::CopyDone)?;
pgb.flush().await?;
info!("basebackup complete");
@@ -757,10 +721,10 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
impl postgres_backend_async::Handler for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -788,7 +752,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
fn startup(
&mut self,
_pgb: &mut PostgresBackendTCP,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -796,7 +760,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackendTCP,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
let ctx = self.connection_ctx.attached_child();
@@ -848,7 +812,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -871,15 +835,15 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
@@ -920,7 +884,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
@@ -965,10 +929,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1001,10 +965,10 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -1013,7 +977,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") {
// show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len());
@@ -1029,7 +993,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1040,7 +1004,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some(
tenant
@@ -1063,7 +1027,7 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}"
@@ -1091,7 +1055,7 @@ impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
),
GetActiveTenantError::Other(e) => QueryError::Other(e),
}
@@ -1107,10 +1071,7 @@ async fn get_active_tenant_with_timeout(
tenant_id: TenantId,
_ctx: &RequestContext, /* require get a context to support cancellation in the future */
) -> Result<Arc<Tenant>, GetActiveTenantError> {
let tenant = match mgr::get_tenant(tenant_id, false).await {
Ok(tenant) => tenant,
Err(e) => return Err(GetActiveTenantError::Other(e.into())),
};
let tenant = mgr::get_tenant(tenant_id, false).await?;
let wait_time = Duration::from_secs(30);
match tokio::time::timeout(wait_time, tenant.wait_to_become_active()).await {
Ok(Ok(())) => Ok(tenant),

View File

@@ -1,150 +0,0 @@
//! Wrapper around nix::sys::statvfs::Statvfs that allows for mocking.
use std::path::Path;
pub enum Statvfs {
Real(nix::sys::statvfs::Statvfs),
Mock(mock::Statvfs),
}
// NB: on macOS, the block count type of struct statvfs is u32.
// The workaround seems to be to use the non-standard statfs64 call.
// Sincce it should only be a problem on > 2TiB disks, let's ignore
// the problem for now and upcast to u64.
impl Statvfs {
pub fn get(tenants_dir: &Path, mocked: Option<&mock::Behavior>) -> nix::Result<Self> {
if let Some(mocked) = mocked {
Ok(Statvfs::Mock(mock::get(tenants_dir, mocked)?))
} else {
Ok(Statvfs::Real(nix::sys::statvfs::statvfs(tenants_dir)?))
}
}
// NB: allow() because the block count type is u32 on macOS.
#[allow(clippy::useless_conversion)]
pub fn blocks(&self) -> u64 {
match self {
Statvfs::Real(stat) => u64::try_from(stat.blocks()).unwrap(),
Statvfs::Mock(stat) => stat.blocks,
}
}
// NB: allow() because the block count type is u32 on macOS.
#[allow(clippy::useless_conversion)]
pub fn blocks_available(&self) -> u64 {
match self {
Statvfs::Real(stat) => u64::try_from(stat.blocks_available()).unwrap(),
Statvfs::Mock(stat) => stat.blocks_available,
}
}
pub fn fragment_size(&self) -> u64 {
match self {
Statvfs::Real(stat) => stat.fragment_size(),
Statvfs::Mock(stat) => stat.fragment_size,
}
}
pub fn block_size(&self) -> u64 {
match self {
Statvfs::Real(stat) => stat.block_size(),
Statvfs::Mock(stat) => stat.block_size,
}
}
}
pub mod mock {
use anyhow::Context;
use regex::Regex;
use std::path::Path;
use tracing::log::info;
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type")]
pub enum Behavior {
Success {
blocksize: u64,
total_blocks: u64,
name_filter: Option<utils::serde_regex::Regex>,
},
Failure {
mocked_error: MockedError,
},
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[allow(clippy::upper_case_acronyms)]
pub enum MockedError {
EIO,
}
impl From<MockedError> for nix::Error {
fn from(e: MockedError) -> Self {
match e {
MockedError::EIO => nix::Error::EIO,
}
}
}
pub fn get(tenants_dir: &Path, behavior: &Behavior) -> nix::Result<Statvfs> {
info!("running mocked statvfs");
match behavior {
Behavior::Success {
blocksize,
total_blocks,
ref name_filter,
} => {
let used_bytes = walk_dir_disk_usage(tenants_dir, name_filter.as_deref()).unwrap();
// round it up to the nearest block multiple
let used_blocks = (used_bytes + (blocksize - 1)) / blocksize;
if used_blocks > *total_blocks {
panic!(
"mocking error: used_blocks > total_blocks: {used_blocks} > {total_blocks}"
);
}
let avail_blocks = total_blocks - used_blocks;
Ok(Statvfs {
blocks: *total_blocks,
blocks_available: avail_blocks,
fragment_size: *blocksize,
block_size: *blocksize,
})
}
Behavior::Failure { mocked_error } => Err((*mocked_error).into()),
}
}
fn walk_dir_disk_usage(path: &Path, name_filter: Option<&Regex>) -> anyhow::Result<u64> {
let mut total = 0;
for entry in walkdir::WalkDir::new(path) {
let entry = entry?;
if !entry.file_type().is_file() {
continue;
}
if !name_filter
.as_ref()
.map(|filter| filter.is_match(entry.file_name().to_str().unwrap()))
.unwrap_or(true)
{
continue;
}
total += entry
.metadata()
.with_context(|| format!("get metadata of {:?}", entry.path()))?
.len();
}
Ok(total)
}
pub struct Statvfs {
pub blocks: u64,
pub blocks_available: u64,
pub fragment_size: u64,
pub block_size: u64,
}
}

View File

@@ -234,9 +234,6 @@ pub enum TaskKind {
// Eviction. One per timeline.
Eviction,
/// See [`crate::disk_usage_eviction_task`].
DiskUsageEviction,
// Initial logical size calculation
InitialLogicalSizeCalculation,
@@ -484,25 +481,13 @@ pub async fn shutdown_tasks(
for task in victim_tasks {
let join_handle = {
let mut task_mut = task.mutable.lock().unwrap();
task_mut.join_handle.take()
info!("waiting for {} to shut down", task.name);
let join_handle = task_mut.join_handle.take();
drop(task_mut);
join_handle
};
if let Some(mut join_handle) = join_handle {
let completed = tokio::select! {
_ = &mut join_handle => { true },
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
// allow some time to elapse before logging to cut down the number of log
// lines.
info!("waiting for {} to shut down", task.name);
false
}
};
if !completed {
// we never handled this return value, but:
// - we don't deschedule which would lead to is_cancelled
// - panics are already logged (is_panicked)
// - task errors are already logged in the wrapper
let _ = join_handle.await;
}
if let Some(join_handle) = join_handle {
let _ = join_handle.await;
} else {
// Possibly one of:
// * The task had not even fully started yet.

View File

@@ -12,7 +12,9 @@
//!
use anyhow::{bail, Context};
use bytes::Bytes;
use futures::FutureExt;
use futures::Stream;
use pageserver_api::models::TimelineState;
use remote_storage::DownloadError;
use remote_storage::GenericRemoteStorage;
@@ -94,7 +96,7 @@ mod timeline;
pub mod size;
pub use timeline::{LocalLayerInfoForDiskUsageEviction, PageReconstructError, Timeline};
pub use timeline::{PageReconstructError, Timeline};
// re-export this function so that page_cache.rs can use it.
pub use crate::tenant::ephemeral_file::writeback as writeback_ephemeral_file;
@@ -237,13 +239,14 @@ impl UninitializedTimeline<'_> {
/// Prepares timeline data by loading it from the basebackup archive.
pub async fn import_basebackup_from_tar(
self,
copyin_read: &mut (impl tokio::io::AsyncRead + Send + Sync + Unpin),
copyin_stream: &mut (impl Stream<Item = io::Result<Bytes>> + Sync + Send + Unpin),
base_lsn: Lsn,
ctx: &RequestContext,
) -> anyhow::Result<Arc<Timeline>> {
let raw_timeline = self.raw_timeline()?;
import_datadir::import_basebackup_from_tar(raw_timeline, copyin_read, base_lsn, ctx)
let mut reader = tokio_util::io::StreamReader::new(copyin_stream);
import_datadir::import_basebackup_from_tar(raw_timeline, &mut reader, base_lsn, ctx)
.await
.context("Failed to import basebackup")?;
@@ -431,16 +434,6 @@ remote:
}
}
#[derive(Debug, thiserror::Error)]
pub enum DeleteTimelineError {
#[error("NotFound")]
NotFound,
#[error("HasChildren")]
HasChildren,
#[error(transparent)]
Other(#[from] anyhow::Error),
}
struct RemoteStartupData {
index_part: IndexPart,
remote_metadata: TimelineMetadata,
@@ -488,7 +481,7 @@ impl Tenant {
let dummy_timeline = self.create_timeline_data(
timeline_id,
up_to_date_metadata,
up_to_date_metadata.clone(),
ancestor.clone(),
remote_client,
)?;
@@ -513,7 +506,7 @@ impl Tenant {
let broken_timeline = self
.create_timeline_data(
timeline_id,
up_to_date_metadata,
up_to_date_metadata.clone(),
ancestor.clone(),
None,
)
@@ -1152,7 +1145,7 @@ impl Tenant {
);
self.prepare_timeline(
new_timeline_id,
&new_metadata,
new_metadata,
timeline_uninit_mark,
true,
None,
@@ -1250,8 +1243,11 @@ impl Tenant {
"Cannot run GC iteration on inactive tenant"
);
self.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await
let gc_result = self
.gc_iteration_internal(target_timeline_id, horizon, pitr, ctx)
.await;
gc_result
}
/// Perform one compaction iteration.
@@ -1317,7 +1313,7 @@ impl Tenant {
&self,
timeline_id: TimelineId,
_ctx: &RequestContext,
) -> Result<(), DeleteTimelineError> {
) -> anyhow::Result<()> {
// Transition the timeline into TimelineState::Stopping.
// This should prevent new operations from starting.
let timeline = {
@@ -1329,13 +1325,13 @@ impl Tenant {
.iter()
.any(|(_, entry)| entry.get_ancestor_timeline_id() == Some(timeline_id));
if children_exist {
return Err(DeleteTimelineError::HasChildren);
}
anyhow::ensure!(
!children_exist,
"Cannot delete timeline which has child timelines"
);
let timeline_entry = match timelines.entry(timeline_id) {
Entry::Occupied(e) => e,
Entry::Vacant(_) => return Err(DeleteTimelineError::NotFound),
Entry::Vacant(_) => bail!("timeline not found"),
};
let timeline = Arc::clone(timeline_entry.get());
@@ -1703,13 +1699,6 @@ impl Tenant {
.unwrap_or(self.conf.default_tenant_conf.trace_read_requests)
}
pub fn get_min_resident_size_override(&self) -> Option<u64> {
let tenant_conf = self.tenant_conf.read().unwrap();
tenant_conf
.min_resident_size_override
.or(self.conf.default_tenant_conf.min_resident_size_override)
}
pub fn set_new_tenant_config(&self, new_tenant_conf: TenantConfOpt) {
*self.tenant_conf.write().unwrap() = new_tenant_conf;
}
@@ -1717,7 +1706,7 @@ impl Tenant {
fn create_timeline_data(
&self,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
ancestor: Option<Arc<Timeline>>,
remote_client: Option<RemoteTimelineClient>,
) -> anyhow::Result<Arc<Timeline>> {
@@ -2177,25 +2166,13 @@ impl Tenant {
let new_timeline = self
.prepare_timeline(
dst_id,
&metadata,
metadata,
timeline_uninit_mark,
false,
Some(Arc::clone(src_timeline)),
)?
.initialize_with_lock(&mut timelines, true, true)?;
drop(timelines);
// Root timeline gets its layers during creation and uploads them along with the metadata.
// A branch timeline though, when created, can get no writes for some time, hence won't get any layers created.
// We still need to upload its metadata eagerly: if other nodes `attach` the tenant and miss this timeline, their GC
// could get incorrect information and remove more layers, than needed.
// See also https://github.com/neondatabase/neon/issues/3865
if let Some(remote_client) = new_timeline.remote_client.as_ref() {
remote_client
.schedule_index_upload_for_metadata_update(&metadata)
.context("branch initial metadata upload")?;
}
info!("branched timeline {dst_id} from {src_id} at {start_lsn}");
Ok(new_timeline)
@@ -2258,7 +2235,7 @@ impl Tenant {
pg_version,
);
let raw_timeline =
self.prepare_timeline(timeline_id, &new_metadata, timeline_uninit_mark, true, None)?;
self.prepare_timeline(timeline_id, new_metadata, timeline_uninit_mark, true, None)?;
let tenant_id = raw_timeline.owning_tenant.tenant_id;
let unfinished_timeline = raw_timeline.raw_timeline()?;
@@ -2312,7 +2289,7 @@ impl Tenant {
fn prepare_timeline(
&self,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
uninit_mark: TimelineUninitMark,
init_layers: bool,
ancestor: Option<Arc<Timeline>>,
@@ -2326,7 +2303,7 @@ impl Tenant {
tenant_id,
new_timeline_id,
);
remote_client.init_upload_queue_for_empty_remote(new_metadata)?;
remote_client.init_upload_queue_for_empty_remote(&new_metadata)?;
Some(remote_client)
} else {
None
@@ -2365,12 +2342,17 @@ impl Tenant {
&self,
timeline_path: &Path,
new_timeline_id: TimelineId,
new_metadata: &TimelineMetadata,
new_metadata: TimelineMetadata,
ancestor: Option<Arc<Timeline>>,
remote_client: Option<RemoteTimelineClient>,
) -> anyhow::Result<Arc<Timeline>> {
let timeline_data = self
.create_timeline_data(new_timeline_id, new_metadata, ancestor, remote_client)
.create_timeline_data(
new_timeline_id,
new_metadata.clone(),
ancestor,
remote_client,
)
.context("Failed to create timeline data structure")?;
crashsafe::create_dir_all(timeline_path).context("Failed to create timeline directory")?;
@@ -2382,7 +2364,7 @@ impl Tenant {
self.conf,
new_timeline_id,
self.tenant_id,
new_metadata,
&new_metadata,
true,
)
.context("Failed to create timeline metadata")?;
@@ -2786,7 +2768,6 @@ pub mod harness {
max_lsn_wal_lag: Some(tenant_conf.max_lsn_wal_lag),
trace_read_requests: Some(tenant_conf.trace_read_requests),
eviction_policy: Some(tenant_conf.eviction_policy),
min_resident_size_override: tenant_conf.min_resident_size_override,
}
}
}
@@ -3195,44 +3176,6 @@ mod tests {
}
*/
#[tokio::test]
async fn test_get_branchpoints_from_an_inactive_timeline() -> anyhow::Result<()> {
let (tenant, ctx) =
TenantHarness::create("test_get_branchpoints_from_an_inactive_timeline")?
.load()
.await;
let tline = tenant
.create_empty_timeline(TIMELINE_ID, Lsn(0), DEFAULT_PG_VERSION, &ctx)?
.initialize(&ctx)?;
make_some_layers(tline.as_ref(), Lsn(0x20)).await?;
tenant
.branch_timeline(&tline, NEW_TIMELINE_ID, Some(Lsn(0x40)), &ctx)
.await?;
let newtline = tenant
.get_timeline(NEW_TIMELINE_ID, true)
.expect("Should have a local timeline");
make_some_layers(newtline.as_ref(), Lsn(0x60)).await?;
tline.set_state(TimelineState::Broken);
tenant
.gc_iteration(Some(TIMELINE_ID), 0x10, Duration::ZERO, &ctx)
.await?;
assert_eq!(
newtline.get(*TEST_KEY, Lsn(0x50), &ctx).await?,
TEST_IMG(&format!("foo at {}", Lsn(0x40)))
);
let branchpoints = &tline.gc_info.read().unwrap().retain_lsns;
assert_eq!(branchpoints.len(), 1);
assert_eq!(branchpoints[0], Lsn(0x40));
Ok(())
}
#[tokio::test]
async fn test_retain_data_in_parent_which_is_needed_for_child() -> anyhow::Result<()> {
let (tenant, ctx) =

View File

@@ -51,6 +51,9 @@ where
///
/// A "cursor" for efficiently reading multiple pages from a BlockReader
///
/// A cursor caches the last accessed page, allowing for faster access if the
/// same block is accessed repeatedly.
///
/// You can access the last page with `*cursor`. 'read_blk' returns 'self', so
/// that in many cases you can use a BlockCursor as a drop-in replacement for
/// the underlying BlockReader. For example:
@@ -70,6 +73,8 @@ where
R: BlockReader,
{
reader: R,
/// last accessed page
cache: Option<(u32, R::BlockLease)>,
}
impl<R> BlockCursor<R>
@@ -77,13 +82,40 @@ where
R: BlockReader,
{
pub fn new(reader: R) -> Self {
BlockCursor { reader }
BlockCursor {
reader,
cache: None,
}
}
pub fn read_blk(&mut self, blknum: u32) -> Result<R::BlockLease, std::io::Error> {
self.reader.read_blk(blknum)
pub fn read_blk(&mut self, blknum: u32) -> Result<&Self, std::io::Error> {
// Fast return if this is the same block as before
if let Some((cached_blk, _buf)) = &self.cache {
if *cached_blk == blknum {
return Ok(self);
}
}
// Read the block from the underlying reader, and cache it
self.cache = None;
let buf = self.reader.read_blk(blknum)?;
self.cache = Some((blknum, buf));
Ok(self)
}
}
impl<R> Deref for BlockCursor<R>
where
R: BlockReader,
{
type Target = [u8; PAGE_SZ];
fn deref(&self) -> &<Self as Deref>::Target {
&self.cache.as_ref().unwrap().1
}
}
static NEXT_ID: AtomicU64 = AtomicU64::new(1);
/// An adapter for reading a (virtual) file using the page cache.

View File

@@ -92,7 +92,6 @@ pub struct TenantConf {
pub max_lsn_wal_lag: NonZeroU64,
pub trace_read_requests: bool,
pub eviction_policy: EvictionPolicy,
pub min_resident_size_override: Option<u64>,
}
/// Same as TenantConf, but this struct preserves the information about
@@ -104,7 +103,6 @@ pub struct TenantConfOpt {
pub checkpoint_distance: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "humantime_serde")]
#[serde(default)]
pub checkpoint_timeout: Option<Duration>,
@@ -160,10 +158,6 @@ pub struct TenantConfOpt {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub eviction_policy: Option<EvictionPolicy>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub min_resident_size_override: Option<u64>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@@ -225,9 +219,48 @@ impl TenantConfOpt {
.trace_read_requests
.unwrap_or(global_conf.trace_read_requests),
eviction_policy: self.eviction_policy.unwrap_or(global_conf.eviction_policy),
min_resident_size_override: self
.min_resident_size_override
.or(global_conf.min_resident_size_override),
}
}
pub fn update(&mut self, other: &TenantConfOpt) {
if let Some(checkpoint_distance) = other.checkpoint_distance {
self.checkpoint_distance = Some(checkpoint_distance);
}
if let Some(checkpoint_timeout) = other.checkpoint_timeout {
self.checkpoint_timeout = Some(checkpoint_timeout);
}
if let Some(compaction_target_size) = other.compaction_target_size {
self.compaction_target_size = Some(compaction_target_size);
}
if let Some(compaction_period) = other.compaction_period {
self.compaction_period = Some(compaction_period);
}
if let Some(compaction_threshold) = other.compaction_threshold {
self.compaction_threshold = Some(compaction_threshold);
}
if let Some(gc_horizon) = other.gc_horizon {
self.gc_horizon = Some(gc_horizon);
}
if let Some(gc_period) = other.gc_period {
self.gc_period = Some(gc_period);
}
if let Some(image_creation_threshold) = other.image_creation_threshold {
self.image_creation_threshold = Some(image_creation_threshold);
}
if let Some(pitr_interval) = other.pitr_interval {
self.pitr_interval = Some(pitr_interval);
}
if let Some(walreceiver_connect_timeout) = other.walreceiver_connect_timeout {
self.walreceiver_connect_timeout = Some(walreceiver_connect_timeout);
}
if let Some(lagging_wal_timeout) = other.lagging_wal_timeout {
self.lagging_wal_timeout = Some(lagging_wal_timeout);
}
if let Some(max_lsn_wal_lag) = other.max_lsn_wal_lag {
self.max_lsn_wal_lag = Some(max_lsn_wal_lag);
}
if let Some(trace_read_requests) = other.trace_read_requests {
self.trace_read_requests = Some(trace_read_requests);
}
}
}
@@ -259,7 +292,6 @@ impl Default for TenantConf {
.expect("cannot parse default max walreceiver Lsn wal lag"),
trace_read_requests: false,
eviction_policy: EvictionPolicy::NoEviction,
min_resident_size_override: None,
}
}
}

View File

@@ -2,7 +2,9 @@
//! used to keep in-memory layers spilled on disk.
use crate::config::PageServerConf;
use crate::page_cache::{self, ReadBufResult, WriteBufResult, PAGE_SZ};
use crate::page_cache;
use crate::page_cache::PAGE_SZ;
use crate::page_cache::{ReadBufResult, WriteBufResult};
use crate::tenant::blob_io::BlobWriter;
use crate::tenant::block_io::BlockReader;
use crate::virtual_file::VirtualFile;
@@ -425,6 +427,7 @@ mod tests {
let actual = cursor.read_blob(pos)?;
assert_eq!(actual, expected);
}
drop(cursor);
// Test a large blob that spans multiple pages
let mut large_data = Vec::new();

View File

@@ -154,7 +154,11 @@ where
expected: &Arc<L>,
new: Arc<L>,
) -> anyhow::Result<Replacement<Arc<L>>> {
fail::fail_point!("layermap-replace-notfound", |_| Ok(Replacement::NotFound));
fail::fail_point!("layermap-replace-notfound", |_| Ok(
// this is not what happens if an L0 layer was not found a anyhow error but perhaps
// that should be changed. this is good enough to show a replacement failure.
Replacement::NotFound
));
self.layer_map.replace_historic_noflush(expected, new)
}
@@ -336,15 +340,12 @@ where
let l0_index = if expected_l0 {
// find the index in case replace worked, we need to replace that as well
let pos = self
.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected));
if pos.is_none() {
return Ok(Replacement::NotFound);
}
pos
Some(
self.l0_delta_layers
.iter()
.position(|slot| Self::compare_arced_layers(slot, expected))
.ok_or_else(|| anyhow::anyhow!("existing l0 delta layer was not found"))?,
)
} else {
None
};
@@ -803,26 +804,6 @@ mod tests {
)
}
#[test]
fn replacing_missing_l0_is_notfound() {
// original impl had an oversight, and L0 was an anyhow::Error. anyhow::Error should
// however only happen for precondition failures.
let layer = "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000053423C21-0000000053424D69";
let layer = LayerFileName::from_str(layer).unwrap();
let layer = LayerDescriptor::from(layer);
// same skeletan construction; see scenario below
let not_found: Arc<dyn Layer> = Arc::new(layer.clone());
let new_version: Arc<dyn Layer> = Arc::new(layer);
let mut map = LayerMap::default();
let res = map.batch_update().replace_historic(&not_found, new_version);
assert!(matches!(res, Ok(Replacement::NotFound)), "{res:?}");
}
fn l0_delta_layers_updated_scenario(layer_name: &str, expected_l0: bool) {
let name = LayerFileName::from_str(layer_name).unwrap();
let skeleton = LayerDescriptor::from(name);
@@ -832,8 +813,7 @@ mod tests {
let mut map = LayerMap::default();
// two disjoint Arcs in different lifecycle phases. even if it seems they must be the
// same layer, we use LayerMap::compare_arced_layers as the identity of layers.
// two disjoint Arcs in different lifecycle phases.
assert!(!LayerMap::compare_arced_layers(&remote, &downloaded));
let expected_in_counts = (1, usize::from(expected_l0));

View File

@@ -289,7 +289,7 @@ pub async fn set_new_tenant_config(
conf: &'static PageServerConf,
new_tenant_conf: TenantConfOpt,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
info!("configuring tenant {tenant_id}");
let tenant = get_tenant(tenant_id, true).await?;
@@ -306,84 +306,50 @@ pub async fn set_new_tenant_config(
/// Gets the tenant from the in-memory data, erroring if it's absent or is not fitting to the query.
/// `active_only = true` allows to query only tenants that are ready for operations, erroring on other kinds of tenants.
pub async fn get_tenant(
tenant_id: TenantId,
active_only: bool,
) -> Result<Arc<Tenant>, TenantStateError> {
pub async fn get_tenant(tenant_id: TenantId, active_only: bool) -> anyhow::Result<Arc<Tenant>> {
let m = TENANTS.read().await;
let tenant = m
.get(&tenant_id)
.ok_or(TenantStateError::NotFound(tenant_id))?;
.with_context(|| format!("Tenant {tenant_id} not found in the local state"))?;
if active_only && !tenant.is_active() {
Err(TenantStateError::NotActive(tenant_id))
anyhow::bail!(
"Tenant {tenant_id} is not active. Current state: {:?}",
tenant.current_state()
)
} else {
Ok(Arc::clone(tenant))
}
}
#[derive(Debug, thiserror::Error)]
pub enum DeleteTimelineError {
#[error("Tenant {0}")]
Tenant(#[from] TenantStateError),
#[error("Timeline {0}")]
Timeline(#[from] crate::tenant::DeleteTimelineError),
}
pub async fn delete_timeline(
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: &RequestContext,
) -> Result<(), DeleteTimelineError> {
let tenant = get_tenant(tenant_id, true).await?;
tenant.delete_timeline(timeline_id, ctx).await?;
Ok(())
}
) -> anyhow::Result<()> {
match get_tenant(tenant_id, true).await {
Ok(tenant) => {
tenant.delete_timeline(timeline_id, ctx).await?;
}
Err(e) => anyhow::bail!("Cannot access tenant {tenant_id} in local tenant state: {e:?}"),
}
#[derive(Debug, thiserror::Error)]
pub enum TenantStateError {
#[error("Tenant {0} not found")]
NotFound(TenantId),
#[error("Tenant {0} is stopping")]
IsStopping(TenantId),
#[error("Tenant {0} is not active")]
NotActive(TenantId),
#[error(transparent)]
Other(#[from] anyhow::Error),
Ok(())
}
pub async fn detach_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
detach_ignored: bool,
) -> Result<(), TenantStateError> {
let local_files_cleanup_operation = |tenant_id_to_clean| async move {
let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean);
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let local_tenant_directory = conf.tenant_path(&tenant_id);
fs::remove_dir_all(&local_tenant_directory)
.await
.with_context(|| {
format!("local tenant directory {local_tenant_directory:?} removal")
format!("Failed to remove local tenant directory {local_tenant_directory:?}")
})?;
Ok(())
};
let removal_result =
remove_tenant_from_memory(tenant_id, local_files_cleanup_operation(tenant_id)).await;
// Ignored tenants are not present in memory and will bail the removal from memory operation.
// Before returning the error, check for ignored tenant removal case — we only need to clean its local files then.
if detach_ignored && matches!(removal_result, Err(TenantStateError::NotFound(_))) {
let tenant_ignore_mark = conf.tenant_ignore_mark_file_path(tenant_id);
if tenant_ignore_mark.exists() {
info!("Detaching an ignored tenant");
local_files_cleanup_operation(tenant_id)
.await
.with_context(|| format!("Ignored tenant {tenant_id} local files cleanup"))?;
return Ok(());
}
}
removal_result
})
.await
}
pub async fn load_tenant(
@@ -413,7 +379,7 @@ pub async fn load_tenant(
pub async fn ignore_tenant(
conf: &'static PageServerConf,
tenant_id: TenantId,
) -> Result<(), TenantStateError> {
) -> anyhow::Result<()> {
remove_tenant_from_memory(tenant_id, async {
let ignore_mark_file = conf.tenant_ignore_mark_file_path(tenant_id);
fs::File::create(&ignore_mark_file)
@@ -523,7 +489,7 @@ where
async fn remove_tenant_from_memory<V, F>(
tenant_id: TenantId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
) -> anyhow::Result<V>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
@@ -539,9 +505,11 @@ where
| TenantState::Loading
| TenantState::Broken
| TenantState::Active => tenant.set_stopping(),
TenantState::Stopping => return Err(TenantStateError::IsStopping(tenant_id)),
TenantState::Stopping => {
anyhow::bail!("Tenant {tenant_id} is stopping already")
}
},
None => return Err(TenantStateError::NotFound(tenant_id)),
None => anyhow::bail!("Tenant not found for id {tenant_id}"),
}
}
@@ -564,15 +532,10 @@ where
Err(e) => {
let tenants_accessor = TENANTS.read().await;
match tenants_accessor.get(&tenant_id) {
Some(tenant) => {
tenant.set_broken(&e.to_string());
}
None => {
warn!("Tenant {tenant_id} got removed from memory");
return Err(TenantStateError::NotFound(tenant_id));
}
Some(tenant) => tenant.set_broken(&e.to_string()),
None => warn!("Tenant {tenant_id} got removed from memory"),
}
Err(TenantStateError::Other(e))
Err(e)
}
}
}
@@ -592,7 +555,7 @@ pub async fn immediate_gc(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let gc_horizon = gc_req.gc_horizon.unwrap_or_else(|| tenant.get_gc_horizon());
@@ -642,7 +605,7 @@ pub async fn immediate_compact(
let tenant = guard
.get(&tenant_id)
.map(Arc::clone)
.with_context(|| format!("tenant {tenant_id}"))
.with_context(|| format!("Tenant {tenant_id} not found"))
.map_err(ApiError::NotFound)?;
let timeline = tenant

View File

@@ -210,6 +210,7 @@ pub use download::{is_temp_download_file, list_remote_timelines};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use anyhow::ensure;
use remote_storage::{DownloadError, GenericRemoteStorage};
use std::ops::DerefMut;
use tokio::runtime::Runtime;
@@ -217,10 +218,9 @@ use tracing::{debug, info, warn};
use tracing::{info_span, Instrument};
use utils::lsn::Lsn;
use crate::metrics::{
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
};
use crate::metrics::RemoteOpFileKind;
use crate::metrics::RemoteOpKind;
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
use crate::{
config::PageServerConf,
@@ -346,7 +346,7 @@ impl RemoteTimelineClient {
.layer_metadata
.values()
// If we don't have the file size for the layer, don't account for it in the metric.
.map(|ilmd| ilmd.file_size)
.map(|ilmd| ilmd.file_size.unwrap_or(0))
.sum()
} else {
0
@@ -419,9 +419,33 @@ impl RemoteTimelineClient {
.await?
};
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
// Update the metadata for given layer file. The remote index file
// might be missing some information for the file; this allows us
// to fill in the missing details.
if layer_metadata.file_size().is_none() {
let new_metadata = LayerFileMetadata::new(downloaded_size);
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
if let Some(upgraded) = upload_queue.latest_files.get_mut(layer_file_name) {
if upgraded.merge(&new_metadata) {
upload_queue.latest_files_changes_since_metadata_upload_scheduled += 1;
}
// If we don't do an index file upload inbetween here and restart,
// the value will go back down after pageserver restart, since we will
// have lost this data point.
// But, we upload index part fairly frequently, and restart pageserver rarely.
// So, by accounting eagerly, we present a most-of-the-time-more-accurate value sooner.
self.metrics
.remote_physical_size_gauge()
.add(downloaded_size);
} else {
// The file should exist, since we just downloaded it.
warn!(
"downloaded file {:?} not found in local copy of the index file",
layer_file_name
);
}
}
Ok(downloaded_size)
}
@@ -521,6 +545,13 @@ impl RemoteTimelineClient {
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
// The file size can be missing for files that were created before we tracked that
// in the metadata, but it should be present for any new files we create.
ensure!(
layer_metadata.file_size().is_some(),
"file size not initialized in metadata"
);
upload_queue
.latest_files
.insert(layer_file_name.clone(), layer_metadata.clone());

View File

@@ -6,13 +6,11 @@
use std::collections::HashSet;
use std::future::Future;
use std::path::Path;
use std::time::Duration;
use anyhow::{anyhow, Context};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use tracing::{error, info, warn};
use crate::config::PageServerConf;
use crate::tenant::storage_layer::LayerFileName;
@@ -21,15 +19,13 @@ use remote_storage::{DownloadError, GenericRemoteStorage};
use utils::crashsafe::path_with_suffix_extension;
use utils::id::{TenantId, TimelineId};
use super::index::{IndexPart, LayerFileMetadata};
use super::index::{IndexPart, IndexPartUnclean, LayerFileMetadata};
use super::{FAILED_DOWNLOAD_RETRIES, FAILED_DOWNLOAD_WARN_THRESHOLD};
async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Error> {
fs::File::open(path).await?.sync_all().await
}
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
@@ -68,28 +64,22 @@ pub async fn download_layer_file<'a>(
// TODO: this doesn't use the cached fd for some reason?
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
format!(
"create a destination file for layer '{}'",
"Failed to create a destination file for layer '{}'",
temp_file_path.display()
)
})
.map_err(DownloadError::Other)?;
let mut download = storage.download(&remote_path).await.with_context(|| {
format!(
"open a download stream for layer with remote storage path '{remote_path:?}'"
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
)
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
.await
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
Ok((destination_file, bytes_amount))
},
&format!("download {remote_path:?}"),
).await?;
@@ -113,11 +103,16 @@ pub async fn download_layer_file<'a>(
})
.map_err(DownloadError::Other)?;
let expected = layer_metadata.file_size();
if expected != bytes_amount {
return Err(DownloadError::Other(anyhow!(
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file {temp_file_path:?}",
)));
match layer_metadata.file_size() {
Some(expected) if expected != bytes_amount => {
return Err(DownloadError::Other(anyhow!(
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file '{}'",
temp_file_path.display()
)));
}
Some(_) | None => {
// matches, or upgrading from an earlier IndexPart version
}
}
// not using sync_data because it can lose file size update
@@ -256,12 +251,14 @@ pub(super) async fn download_index_part(
)
.await?;
let index_part: IndexPart = serde_json::from_slice(&index_part_bytes)
let index_part: IndexPartUnclean = serde_json::from_slice(&index_part_bytes)
.with_context(|| {
format!("Failed to deserialize index part file into file {index_part_path:?}")
})
.map_err(DownloadError::Other)?;
let index_part = index_part.remove_unclean_layer_file_names();
Ok(index_part)
}
@@ -303,7 +300,7 @@ where
}
Err(DownloadError::Other(ref err)) => {
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
return result;
}
}

View File

@@ -6,6 +6,7 @@ use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use tracing::warn;
use crate::tenant::metadata::TimelineMetadata;
use crate::tenant::storage_layer::LayerFileName;
@@ -19,7 +20,7 @@ use utils::lsn::Lsn;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(test, derive(Default))]
pub struct LayerFileMetadata {
file_size: u64,
file_size: Option<u64>,
}
impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
@@ -32,16 +33,36 @@ impl From<&'_ IndexLayerMetadata> for LayerFileMetadata {
impl LayerFileMetadata {
pub fn new(file_size: u64) -> Self {
LayerFileMetadata { file_size }
LayerFileMetadata {
file_size: Some(file_size),
}
}
pub fn file_size(&self) -> u64 {
/// This is used to initialize the metadata for remote layers, for which
/// the metadata was missing from the index part file.
pub const MISSING: Self = LayerFileMetadata { file_size: None };
pub fn file_size(&self) -> Option<u64> {
self.file_size
}
/// Metadata has holes due to version upgrades. This method is called to upgrade self with the
/// other value.
///
/// This is called on the possibly outdated version. Returns true if any changes
/// were made.
pub fn merge(&mut self, other: &Self) -> bool {
let mut changed = false;
if self.file_size != other.file_size {
self.file_size = other.file_size.or(self.file_size);
changed = true;
}
changed
}
}
// TODO seems like another part of the remote storage file format
// compatibility issue, see https://github.com/neondatabase/neon/issues/3072
/// In-memory representation of an `index_part.json` file
///
/// Contains the data about all files in the timeline, present remotely and its metadata.
@@ -50,7 +71,10 @@ impl LayerFileMetadata {
/// remember to add a test case for the changed version.
#[serde_as]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct IndexPart {
pub struct IndexPartImpl<L>
where
L: std::hash::Hash + PartialEq + Eq,
{
/// Debugging aid describing the version of this type.
#[serde(default)]
version: usize,
@@ -58,13 +82,14 @@ pub struct IndexPart {
/// Layer names, which are stored on the remote storage.
///
/// Additional metadata can might exist in `layer_metadata`.
pub timeline_layers: HashSet<LayerFileName>,
pub timeline_layers: HashSet<L>,
/// Per layer file name metadata, which can be present for a present or missing layer file.
///
/// Older versions of `IndexPart` will not have this property or have only a part of metadata
/// that latest version stores.
pub layer_metadata: HashMap<LayerFileName, IndexLayerMetadata>,
#[serde(default = "HashMap::default")]
pub layer_metadata: HashMap<L, IndexLayerMetadata>,
// 'disk_consistent_lsn' is a copy of the 'disk_consistent_lsn' in the metadata.
// It's duplicated here for convenience.
@@ -73,6 +98,101 @@ pub struct IndexPart {
metadata_bytes: Vec<u8>,
}
// TODO seems like another part of the remote storage file format
// compatibility issue, see https://github.com/neondatabase/neon/issues/3072
pub type IndexPart = IndexPartImpl<LayerFileName>;
pub type IndexPartUnclean = IndexPartImpl<UncleanLayerFileName>;
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub enum UncleanLayerFileName {
Clean(LayerFileName),
BackupFile(String),
}
impl<'de> serde::Deserialize<'de> for UncleanLayerFileName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_string(UncleanLayerFileNameVisitor)
}
}
struct UncleanLayerFileNameVisitor;
impl<'de> serde::de::Visitor<'de> for UncleanLayerFileNameVisitor {
type Value = UncleanLayerFileName;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
formatter,
"a string that is a valid LayerFileName or '.old' backup file name"
)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
let maybe_clean: Result<LayerFileName, _> = v.parse();
match maybe_clean {
Ok(clean) => Ok(UncleanLayerFileName::Clean(clean)),
Err(e) => {
if v.ends_with(".old") || v == "metadata_backup" {
Ok(UncleanLayerFileName::BackupFile(v.to_owned()))
} else {
Err(E::custom(e))
}
}
}
}
}
impl UncleanLayerFileName {
fn into_clean(self) -> Option<LayerFileName> {
match self {
UncleanLayerFileName::Clean(clean) => Some(clean),
UncleanLayerFileName::BackupFile(_) => None,
}
}
}
impl IndexPartUnclean {
pub fn remove_unclean_layer_file_names(self) -> IndexPart {
let IndexPartUnclean {
version,
timeline_layers,
layer_metadata,
disk_consistent_lsn,
metadata_bytes,
} = self;
IndexPart {
version,
timeline_layers: timeline_layers
.into_iter()
.filter_map(|unclean_file_name| match unclean_file_name {
UncleanLayerFileName::Clean(clean_name) => Some(clean_name),
UncleanLayerFileName::BackupFile(backup_file_name) => {
// For details see https://github.com/neondatabase/neon/issues/3024
warn!(
"got backup file on the remote storage, ignoring it {backup_file_name}"
);
None
}
})
.collect(),
layer_metadata: layer_metadata
.into_iter()
.filter_map(|(l, m)| l.into_clean().map(|l| (l, m)))
.collect(),
disk_consistent_lsn,
metadata_bytes,
}
}
}
impl IndexPart {
/// When adding or modifying any parts of `IndexPart`, increment the version so that it can be
/// used to understand later versions.
@@ -112,7 +232,7 @@ impl IndexPart {
/// Serialized form of [`LayerFileMetadata`].
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, Default)]
pub struct IndexLayerMetadata {
pub(super) file_size: u64,
pub(super) file_size: Option<u64>,
}
impl From<&'_ LayerFileMetadata> for IndexLayerMetadata {
@@ -127,6 +247,27 @@ impl From<&'_ LayerFileMetadata> for IndexLayerMetadata {
mod tests {
use super::*;
#[test]
fn v0_indexpart_is_parsed() {
let example = r#"{
"timeline_layers":["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9"],
"disk_consistent_lsn":"0/16960E8",
"metadata_bytes":[113,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
}"#;
let expected = IndexPart {
version: 0,
timeline_layers: HashSet::from(["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap()]),
layer_metadata: HashMap::default(),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
metadata_bytes: [113,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0].to_vec(),
};
let part: IndexPartUnclean = serde_json::from_str(example).unwrap();
let part = part.remove_unclean_layer_file_names();
assert_eq!(part, expected);
}
#[test]
fn v1_indexpart_is_parsed() {
let example = r#"{
@@ -146,19 +287,21 @@ mod tests {
timeline_layers: HashSet::from(["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap()]),
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
file_size: Some(25600000),
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
file_size: Some(9007199254741001),
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
metadata_bytes: [113,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0].to_vec(),
};
let part = serde_json::from_str::<IndexPart>(example).unwrap();
let part = serde_json::from_str::<IndexPartUnclean>(example)
.unwrap()
.remove_unclean_layer_file_names();
assert_eq!(part, expected);
}
@@ -182,64 +325,20 @@ mod tests {
timeline_layers: HashSet::from(["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap()]),
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
file_size: Some(25600000),
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
file_size: Some(9007199254741001),
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
metadata_bytes: [112,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0].to_vec(),
};
let part = serde_json::from_str::<IndexPart>(example).unwrap();
let part = serde_json::from_str::<IndexPartUnclean>(example).unwrap();
let part = part.remove_unclean_layer_file_names();
assert_eq!(part, expected);
}
#[test]
fn empty_layers_are_parsed() {
let empty_layers_json = r#"{
"version":1,
"timeline_layers":[],
"layer_metadata":{},
"disk_consistent_lsn":"0/2532648",
"metadata_bytes":[136,151,49,208,0,70,0,4,0,0,0,0,2,83,38,72,1,0,0,0,0,2,83,38,32,1,87,198,240,135,97,119,45,125,38,29,155,161,140,141,255,210,0,0,0,0,2,83,38,72,0,0,0,0,1,73,240,192,0,0,0,0,1,73,240,192,0,0,0,15,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
}"#;
let expected = IndexPart {
version: 1,
timeline_layers: HashSet::new(),
layer_metadata: HashMap::new(),
disk_consistent_lsn: "0/2532648".parse::<Lsn>().unwrap(),
metadata_bytes: [
136, 151, 49, 208, 0, 70, 0, 4, 0, 0, 0, 0, 2, 83, 38, 72, 1, 0, 0, 0, 0, 2, 83,
38, 32, 1, 87, 198, 240, 135, 97, 119, 45, 125, 38, 29, 155, 161, 140, 141, 255,
210, 0, 0, 0, 0, 2, 83, 38, 72, 0, 0, 0, 0, 1, 73, 240, 192, 0, 0, 0, 0, 1, 73,
240, 192, 0, 0, 0, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0,
]
.to_vec(),
};
let empty_layers_parsed = serde_json::from_str::<IndexPart>(empty_layers_json).unwrap();
assert_eq!(empty_layers_parsed, expected);
}
}

View File

@@ -64,9 +64,13 @@ pub(super) async fn upload_timeline_layer<'a>(
})?
.len();
let metadata_size = known_metadata.file_size();
if metadata_size != fs_size {
bail!("File {source_path:?} has its current FS size {fs_size} diferent from initially determined {metadata_size}");
// FIXME: this looks bad
if let Some(metadata_size) = known_metadata.file_size() {
if metadata_size != fs_size {
bail!("File {source_path:?} has its current FS size {fs_size} diferent from initially determined {metadata_size}");
}
} else {
// this is a silly state we would like to avoid
}
let fs_size = usize::try_from(fs_size).with_context(|| {

View File

@@ -121,10 +121,10 @@ struct LayerAccessStatsInner {
}
#[derive(Debug, Clone, Copy)]
pub(crate) struct LayerAccessStatFullDetails {
pub(crate) when: SystemTime,
pub(crate) task_kind: TaskKind,
pub(crate) access_kind: LayerAccessKind,
pub(super) struct LayerAccessStatFullDetails {
pub(super) when: SystemTime,
pub(super) task_kind: TaskKind,
pub(super) access_kind: LayerAccessKind,
}
#[derive(Clone, Copy, strum_macros::EnumString)]
@@ -255,7 +255,7 @@ impl LayerAccessStats {
ret
}
fn most_recent_access_or_residence_event(
pub(super) fn most_recent_access_or_residence_event(
&self,
) -> Either<LayerAccessStatFullDetails, LayerResidenceEvent> {
let locked = self.0.lock().unwrap();
@@ -268,13 +268,6 @@ impl LayerAccessStats {
}
}
}
pub(crate) fn latest_activity(&self) -> SystemTime {
match self.most_recent_access_or_residence_event() {
Either::Left(mra) => mra.when,
Either::Right(re) => re.timestamp,
}
}
}
/// Supertrait of the [`Layer`] trait that captures the bare minimum interface
@@ -371,7 +364,7 @@ pub trait PersistentLayer: Layer {
}
/// Permanently remove this layer from disk.
fn delete_resident_layer_file(&self) -> Result<()>;
fn delete(&self) -> Result<()>;
fn downcast_remote_layer(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {
None
@@ -385,7 +378,7 @@ pub trait PersistentLayer: Layer {
///
/// Should not change over the lifetime of the layer object because
/// current_physical_size is computed as the som of this value.
fn file_size(&self) -> u64;
fn file_size(&self) -> Option<u64>;
fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo;

View File

@@ -438,14 +438,14 @@ impl PersistentLayer for DeltaLayer {
))
}
fn delete_resident_layer_file(&self) -> Result<()> {
fn delete(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())
}
fn file_size(&self) -> u64 {
self.file_size
fn file_size(&self) -> Option<u64> {
Some(self.file_size)
}
fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
@@ -456,7 +456,7 @@ impl PersistentLayer for DeltaLayer {
HistoricLayerInfo::Delta {
layer_file_name,
layer_file_size: self.file_size,
layer_file_size: Some(self.file_size),
lsn_start: lsn_range.start,
lsn_end: lsn_range.end,
remote: false,

View File

@@ -258,15 +258,6 @@ impl serde::Serialize for LayerFileName {
}
}
impl<'de> serde::Deserialize<'de> for LayerFileName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_string(LayerFileNameVisitor)
}
}
struct LayerFileNameVisitor;
impl<'de> serde::de::Visitor<'de> for LayerFileNameVisitor {

View File

@@ -252,14 +252,14 @@ impl PersistentLayer for ImageLayer {
unimplemented!();
}
fn delete_resident_layer_file(&self) -> Result<()> {
fn delete(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())
}
fn file_size(&self) -> u64 {
self.file_size
fn file_size(&self) -> Option<u64> {
Some(self.file_size)
}
fn info(&self, reset: LayerAccessStatsReset) -> HistoricLayerInfo {
@@ -268,7 +268,7 @@ impl PersistentLayer for ImageLayer {
HistoricLayerInfo::Image {
layer_file_name,
layer_file_size: self.file_size,
layer_file_size: Some(self.file_size),
lsn_start: lsn_range.start,
remote: false,
access_stats: self.access_stats.as_api_model(reset),

View File

@@ -155,8 +155,8 @@ impl PersistentLayer for RemoteLayer {
bail!("cannot iterate a remote layer");
}
fn delete_resident_layer_file(&self) -> Result<()> {
bail!("remote layer has no layer file");
fn delete(&self) -> Result<()> {
Ok(())
}
fn downcast_remote_layer<'a>(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {
@@ -167,7 +167,7 @@ impl PersistentLayer for RemoteLayer {
true
}
fn file_size(&self) -> u64 {
fn file_size(&self) -> Option<u64> {
self.layer_metadata.file_size()
}

View File

@@ -244,12 +244,14 @@ pub(crate) async fn random_init_delay(
) -> Result<(), Cancelled> {
use rand::Rng;
if period == Duration::ZERO {
return Ok(());
}
let d = {
let mut rng = rand::thread_rng();
// gen_range asserts that the range cannot be empty, which it could be because period can
// be set to zero to disable gc or compaction, so lets set it to be at least 10s.
let period = std::cmp::max(period, Duration::from_secs(10));
// semi-ok default as the source of jitter
rng.gen_range(Duration::ZERO..=period)
};

View File

@@ -13,7 +13,6 @@ use pageserver_api::models::{
DownloadRemoteLayersTaskInfo, DownloadRemoteLayersTaskSpawnRequest,
DownloadRemoteLayersTaskState, LayerMapInfo, LayerResidenceStatus, TimelineState,
};
use remote_storage::GenericRemoteStorage;
use tokio::sync::{oneshot, watch, Semaphore, TryAcquireError};
use tokio_util::sync::CancellationToken;
use tracing::*;
@@ -72,8 +71,6 @@ use crate::ZERO_PAGE;
use crate::{is_temporary, task_mgr};
use walreceiver::spawn_connection_manager_task;
use self::eviction_task::EvictionTaskTimelineState;
use super::layer_map::BatchedUpdates;
use super::remote_timeline_client::index::IndexPart;
use super::remote_timeline_client::RemoteTimelineClient;
@@ -219,8 +216,6 @@ pub struct Timeline {
download_all_remote_layers_task_info: RwLock<Option<DownloadRemoteLayersTaskInfo>>,
state: watch::Sender<TimelineState>,
eviction_task_timeline_state: tokio::sync::Mutex<EvictionTaskTimelineState>,
}
/// Internal structure to hold all data needed for logical size calculation.
@@ -333,12 +328,27 @@ impl LogicalSize {
.fetch_add(delta, AtomicOrdering::SeqCst);
}
/// Make the value computed by initial logical size computation
/// available for re-use. This doesn't contain the incremental part.
fn initialized_size(&self, lsn: Lsn) -> Option<u64> {
match self.initial_part_end {
Some(v) if v == lsn => self.initial_logical_size.get().copied(),
_ => None,
/// Returns the initialized (already calculated) value, if any.
fn initialized_size(&self) -> Option<u64> {
self.initial_logical_size.get().copied()
}
}
/// Returned by [`Timeline::layer_size_sum`]
pub enum LayerSizeSum {
/// The result is accurate.
Accurate(u64),
// We don't know the layer file size of one or more layers.
// They contribute to the sum with a value of 0.
// Hence, the sum is a lower bound for the actualy layer file size sum.
ApproximateLowerBound(u64),
}
impl LayerSizeSum {
pub fn approximate_is_ok(self) -> u64 {
match self {
LayerSizeSum::Accurate(v) => v,
LayerSizeSum::ApproximateLowerBound(v) => v,
}
}
}
@@ -540,13 +550,20 @@ impl Timeline {
/// The sum of the file size of all historic layers in the layer map.
/// This method makes no distinction between local and remote layers.
/// Hence, the result **does not represent local filesystem usage**.
pub fn layer_size_sum(&self) -> u64 {
pub fn layer_size_sum(&self) -> LayerSizeSum {
let layer_map = self.layers.read().unwrap();
let mut size = 0;
let mut no_size_cnt = 0;
for l in layer_map.iter_historic_layers() {
size += l.file_size();
let (l_size, l_no_size) = l.file_size().map(|s| (s, 0)).unwrap_or((0, 1));
size += l_size;
no_size_cnt += l_no_size;
}
if no_size_cnt == 0 {
LayerSizeSum::Accurate(size)
} else {
LayerSizeSum::ApproximateLowerBound(size)
}
size
}
pub fn get_resident_physical_size(&self) -> u64 {
@@ -645,8 +662,8 @@ impl Timeline {
// update the index file on next flush iteration too. But it
// could take a while until that happens.
//
// Additionally, only do this once before we return from this function.
if last_round || res.is_ok() {
// Additionally, only do this on the terminal round before sleeping.
if last_round {
if let Some(remote_client) = &self.remote_client {
remote_client.schedule_index_upload_for_file_changes()?;
}
@@ -815,11 +832,11 @@ impl Timeline {
let mut is_exact = true;
let size = current_size.size();
if let (CurrentLogicalSize::Approximate(_), Some(initial_part_end)) =
if let (CurrentLogicalSize::Approximate(_), Some(init_lsn)) =
(current_size, self.current_logical_size.initial_part_end)
{
is_exact = false;
self.try_spawn_size_init_task(initial_part_end, ctx);
self.try_spawn_size_init_task(init_lsn, ctx);
}
Ok((size, is_exact))
@@ -957,25 +974,6 @@ impl Timeline {
}
}
/// Evict a batch of layers.
///
/// GenericRemoteStorage reference is required as a witness[^witness_article] for "remote storage is configured."
///
/// [^witness_article]: https://willcrichton.net/rust-api-type-patterns/witnesses.html
pub async fn evict_layers(
&self,
_: &GenericRemoteStorage,
layers_to_evict: &[Arc<dyn PersistentLayer>],
cancel: CancellationToken,
) -> anyhow::Result<Vec<Option<anyhow::Result<bool>>>> {
let remote_client = self.remote_client.clone().expect(
"GenericRemoteStorage is configured, so timeline must have RemoteTimelineClient",
);
self.evict_layer_batch(&remote_client, layers_to_evict, cancel)
.await
}
/// Evict multiple layers at once, continuing through errors.
///
/// Try to evict the given `layers_to_evict` by
@@ -1013,15 +1011,6 @@ impl Timeline {
// now lock out layer removal (compaction, gc, timeline deletion)
let layer_removal_guard = self.layer_removal_cs.lock().await;
{
// to avoid racing with detach and delete_timeline
let state = self.current_state();
anyhow::ensure!(
state == TimelineState::Active,
"timeline is not active but {state:?}"
);
}
// start the batch update
let mut layer_map = self.layers.write().unwrap();
let mut batch_updates = layer_map.batch_update();
@@ -1055,31 +1044,14 @@ impl Timeline {
use super::layer_map::Replacement;
if local_layer.is_remote_layer() {
// TODO(issue #3851): consider returning an err here instead of false,
// which is the same out the match later
return Ok(false);
}
let layer_file_size = local_layer.file_size();
let local_layer_mtime = local_layer
.local_path()
.expect("local layer should have a local path")
.metadata()
.context("get local layer file stat")?
.modified()
.context("get mtime of layer file")?;
let local_layer_residence_duration =
match SystemTime::now().duration_since(local_layer_mtime) {
Err(e) => {
warn!("layer mtime is in the future: {}", e);
None
}
Ok(delta) => Some(delta),
};
let layer_metadata = LayerFileMetadata::new(layer_file_size);
let layer_metadata = LayerFileMetadata::new(
local_layer
.file_size()
.expect("Local layer should have a file size"),
);
let new_remote_layer = Arc::new(match local_layer.filename() {
LayerFileName::Image(image_name) => RemoteLayer::new_img(
self.tenant_id,
@@ -1103,29 +1075,14 @@ impl Timeline {
let replaced = match batch_updates.replace_historic(local_layer, new_remote_layer)? {
Replacement::Replaced { .. } => {
if let Err(e) = local_layer.delete_resident_layer_file() {
let layer_size = local_layer.file_size();
if let Err(e) = local_layer.delete() {
error!("failed to remove layer file on evict after replacement: {e:#?}");
}
// Always decrement the physical size gauge, even if we failed to delete the file.
// Rationale: we already replaced the layer with a remote layer in the layer map,
// and any subsequent download_remote_layer will
// 1. overwrite the file on disk and
// 2. add the downloaded size to the resident size gauge.
//
// If there is no re-download, and we restart the pageserver, then load_layer_map
// will treat the file as a local layer again, count it towards resident size,
// and it'll be like the layer removal never happened.
// The bump in resident size is perhaps unexpected but overall a robust behavior.
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
self.metrics.evictions.inc();
if let Some(delta) = local_layer_residence_duration {
self.metrics
.evictions_with_low_residence_duration
.observe(delta);
if let Some(layer_size) = layer_size {
self.metrics.resident_physical_size_gauge.sub(layer_size);
}
true
@@ -1202,7 +1159,7 @@ impl Timeline {
pub(super) fn new(
conf: &'static PageServerConf,
tenant_conf: Arc<RwLock<TenantConfOpt>>,
metadata: &TimelineMetadata,
metadata: TimelineMetadata,
ancestor: Option<Arc<Timeline>>,
timeline_id: TimelineId,
tenant_id: TenantId,
@@ -1243,14 +1200,7 @@ impl Timeline {
ancestor_timeline: ancestor,
ancestor_lsn: metadata.ancestor_lsn(),
metrics: TimelineMetrics::new(
&tenant_id,
&timeline_id,
crate::metrics::EvictionsWithLowResidenceDurationBuilder::new(
"mtime",
conf.evictions_low_residence_duration_metric_threshold,
),
),
metrics: TimelineMetrics::new(&tenant_id, &timeline_id),
flush_loop_state: Mutex::new(FlushLoopState::NotStarted),
@@ -1287,10 +1237,6 @@ impl Timeline {
download_all_remote_layers_task_info: RwLock::new(None),
state,
eviction_task_timeline_state: tokio::sync::Mutex::new(
EvictionTaskTimelineState::default(),
),
};
result.repartition_threshold = result.get_checkpoint_distance() / 10;
result
@@ -1381,7 +1327,6 @@ impl Timeline {
lagging_wal_timeout,
max_lsn_wal_lag,
crate::config::SAFEKEEPER_AUTH_TOKEN.get().cloned(),
self.conf.availability_zone.clone(),
background_ctx,
);
}
@@ -1529,12 +1474,7 @@ impl Timeline {
.layer_metadata
.get(remote_layer_name)
.map(LayerFileMetadata::from)
.with_context(|| {
format!(
"No remote layer metadata found for layer {}",
remote_layer_name.file_name()
)
})?;
.unwrap_or(LayerFileMetadata::MISSING);
// Is the local layer's size different from the size stored in the
// remote index file?
@@ -1550,27 +1490,34 @@ impl Timeline {
local_layer_path.display()
);
let remote_size = remote_layer_metadata.file_size();
let metadata = local_layer_path.metadata().with_context(|| {
format!(
"get file size of local layer {}",
local_layer_path.display()
)
})?;
let local_size = metadata.len();
if local_size != remote_size {
warn!("removing local file {local_layer_path:?} because it has unexpected length {local_size}; length in remote index is {remote_size}");
if let Err(err) = rename_to_backup(&local_layer_path) {
assert!(local_layer_path.exists(), "we would leave the local_layer without a file if this does not hold: {}", local_layer_path.display());
anyhow::bail!("could not rename file {local_layer_path:?}: {err:?}");
if let Some(remote_size) = remote_layer_metadata.file_size() {
let metadata = local_layer_path.metadata().with_context(|| {
format!(
"get file size of local layer {}",
local_layer_path.display()
)
})?;
let local_size = metadata.len();
if local_size != remote_size {
warn!("removing local file {local_layer_path:?} because it has unexpected length {local_size}; length in remote index is {remote_size}");
if let Err(err) = rename_to_backup(&local_layer_path) {
assert!(local_layer_path.exists(), "we would leave the local_layer without a file if this does not hold: {}", local_layer_path.display());
anyhow::bail!("could not rename file {local_layer_path:?}: {err:?}");
} else {
self.metrics.resident_physical_size_gauge.sub(local_size);
updates.remove_historic(local_layer);
// fall-through to adding the remote layer
}
} else {
self.metrics.resident_physical_size_gauge.sub(local_size);
updates.remove_historic(local_layer);
// fall-through to adding the remote layer
debug!(
"layer is present locally and file size matches remote, using it: {}",
local_layer_path.display()
);
continue;
}
} else {
debug!(
"layer is present locally and file size matches remote, using it: {}",
"layer is present locally and remote does not have file size, using it: {}",
local_layer_path.display()
);
continue;
@@ -1672,8 +1619,6 @@ impl Timeline {
.map(|l| (l.filename(), l))
.collect::<HashMap<_, _>>();
// If no writes happen, new branches do not have any layers, only the metadata file.
let has_local_layers = !local_layers.is_empty();
let local_only_layers = match index_part {
Some(index_part) => {
info!(
@@ -1691,47 +1636,28 @@ impl Timeline {
}
};
if has_local_layers {
// Are there local files that don't exist remotely? Schedule uploads for them.
// Local timeline metadata will get uploaded to remove along witht he layers.
for (layer_name, layer) in &local_only_layers {
// XXX solve this in the type system
let layer_path = layer
.local_path()
.expect("local_only_layers only contains local layers");
let layer_size = layer_path
.metadata()
.with_context(|| format!("failed to get file {layer_path:?} metadata"))?
.len();
info!("scheduling {layer_path:?} for upload");
remote_client
.schedule_layer_file_upload(layer_name, &LayerFileMetadata::new(layer_size))?;
}
remote_client.schedule_index_upload_for_file_changes()?;
} else if index_part.is_none() {
// No data on the remote storage, no local layers, local metadata file.
//
// TODO https://github.com/neondatabase/neon/issues/3865
// Currently, console does not wait for the timeline data upload to the remote storage
// and considers the timeline created, expecting other pageserver nodes to work with it.
// Branch metadata upload could get interrupted (e.g pageserver got killed),
// hence any locally existing branch metadata with no remote counterpart should be uploaded,
// otherwise any other pageserver won't see the branch on `attach`.
//
// After the issue gets implemented, pageserver should rather remove the branch,
// since absence on S3 means we did not acknowledge the branch creation and console will have to retry,
// no need to keep the old files.
remote_client.schedule_index_upload_for_metadata_update(up_to_date_metadata)?;
} else {
// Local timeline has a metadata file, remote one too, both have no layers to sync.
// Are there local files that don't exist remotely? Schedule uploads for them
for (layer_name, layer) in &local_only_layers {
// XXX solve this in the type system
let layer_path = layer
.local_path()
.expect("local_only_layers only contains local layers");
let layer_size = layer_path
.metadata()
.with_context(|| format!("failed to get file {layer_path:?} metadata"))?
.len();
info!("scheduling {layer_path:?} for upload");
remote_client
.schedule_layer_file_upload(layer_name, &LayerFileMetadata::new(layer_size))?;
}
remote_client.schedule_index_upload_for_file_changes()?;
info!("Done");
Ok(())
}
fn try_spawn_size_init_task(self: &Arc<Self>, lsn: Lsn, ctx: &RequestContext) {
fn try_spawn_size_init_task(self: &Arc<Self>, init_lsn: Lsn, ctx: &RequestContext) {
let permit = match Arc::clone(&self.current_logical_size.initial_size_computation)
.try_acquire_owned()
{
@@ -1769,7 +1695,7 @@ impl Timeline {
// NB: don't log errors here, task_mgr will do that.
async move {
let calculated_size = match self_clone
.logical_size_calculation_task(lsn, &background_ctx)
.logical_size_calculation_task(init_lsn, &background_ctx)
.await
{
Ok(s) => s,
@@ -1854,7 +1780,7 @@ impl Timeline {
#[instrument(skip_all, fields(tenant = %self.tenant_id, timeline = %self.timeline_id))]
async fn logical_size_calculation_task(
self: &Arc<Self>,
lsn: Lsn,
init_lsn: Lsn,
ctx: &RequestContext,
) -> Result<u64, CalculateLogicalSizeError> {
let mut timeline_state_updates = self.subscribe_for_state_updates();
@@ -1865,7 +1791,7 @@ impl Timeline {
let cancel = cancel.child_token();
let ctx = ctx.attached_child();
self_calculation
.calculate_logical_size(lsn, cancel, &ctx)
.calculate_logical_size(init_lsn, cancel, &ctx)
.await
};
let timeline_state_cancellation = async {
@@ -1949,12 +1875,21 @@ impl Timeline {
// need to return something
Ok(0)
});
// See if we've already done the work for initial size calculation.
// This is a short-cut for timelines that are mostly unused.
if let Some(size) = self.current_logical_size.initialized_size(up_to_lsn) {
return Ok(size);
}
let timer = self.metrics.logical_size_histo.start_timer();
let timer = if up_to_lsn == self.initdb_lsn {
if let Some(size) = self.current_logical_size.initialized_size() {
if size != 0 {
// non-zero size means that the size has already been calculated by this method
// after startup. if the logical size is for a new timeline without layers the
// size will be zero, and we cannot use that, or this caching strategy until
// pageserver restart.
return Ok(size);
}
}
self.metrics.init_logical_size_histo.start_timer()
} else {
self.metrics.logical_size_histo.start_timer()
};
let logical_size = self
.get_current_logical_size_non_incremental(up_to_lsn, cancel, ctx)
.await?;
@@ -2007,12 +1942,11 @@ impl Timeline {
layer: Arc<dyn PersistentLayer>,
updates: &mut BatchedUpdates<'_, dyn PersistentLayer>,
) -> anyhow::Result<()> {
if !layer.is_remote_layer() {
layer.delete_resident_layer_file()?;
let layer_file_size = layer.file_size();
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
let layer_size = layer.file_size();
layer.delete()?;
if let Some(layer_size) = layer_size {
self.metrics.resident_physical_size_gauge.sub(layer_size);
}
// TODO Removing from the bottom of the layer map is expensive.
@@ -2770,22 +2704,10 @@ impl Timeline {
) -> Result<HashMap<LayerFileName, LayerFileMetadata>, PageReconstructError> {
let timer = self.metrics.create_images_time_histo.start_timer();
let mut image_layers: Vec<ImageLayer> = Vec::new();
// We need to avoid holes between generated image layers.
// Otherwise LayerMap::image_layer_exists will return false if key range of some layer is covered by more than one
// image layer with hole between them. In this case such layer can not be utilized by GC.
//
// How such hole between partitions can appear?
// if we have relation with relid=1 and size 100 and relation with relid=2 with size 200 then result of
// KeySpace::partition may contain partitions <100000000..100000099> and <200000000..200000199>.
// If there is delta layer <100000000..300000000> then it never be garbage collected because
// image layers <100000000..100000099> and <200000000..200000199> are not completely covering it.
let mut start = Key::MIN;
for partition in partitioning.parts.iter() {
let img_range = start..partition.ranges.last().unwrap().end;
start = img_range.end;
if force || self.time_for_new_image_layer(partition, lsn)? {
let img_range =
partition.ranges.first().unwrap().start..partition.ranges.last().unwrap().end;
let mut image_layer_writer = ImageLayerWriter::new(
self.conf,
self.timeline_id,
@@ -2799,6 +2721,7 @@ impl Timeline {
"failpoint image-layer-writer-fail-before-finish"
)))
});
for range in &partition.ranges {
let mut key = range.start;
while key < range.end {
@@ -3213,7 +3136,9 @@ impl Timeline {
}
fail_point!("delta-layer-writer-fail-before-finish", |_| {
Err(anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into())
return Err(
anyhow::anyhow!("failpoint delta-layer-writer-fail-before-finish").into(),
);
});
writer.as_mut().unwrap().put_value(key, lsn, value)?;
@@ -3883,7 +3808,7 @@ impl Timeline {
remote_layer.ongoing_download.close();
} else {
// Keep semaphore open. We'll drop the permit at the end of the function.
error!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
info!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
}
// Don't treat it as an error if the task that triggered the download
@@ -4037,67 +3962,6 @@ impl Timeline {
}
}
pub struct DiskUsageEvictionInfo {
/// Timeline's largest layer (remote or resident)
pub max_layer_size: Option<u64>,
/// Timeline's resident layers
pub resident_layers: Vec<LocalLayerInfoForDiskUsageEviction>,
}
pub struct LocalLayerInfoForDiskUsageEviction {
pub layer: Arc<dyn PersistentLayer>,
pub last_activity_ts: SystemTime,
}
impl std::fmt::Debug for LocalLayerInfoForDiskUsageEviction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// format the tv_sec, tv_nsec into rfc3339 in case someone is looking at it
// having to allocate a string to this is bad, but it will rarely be formatted
let ts = chrono::DateTime::<chrono::Utc>::from(self.last_activity_ts);
let ts = ts.to_rfc3339_opts(chrono::SecondsFormat::Nanos, true);
f.debug_struct("LocalLayerInfoForDiskUsageEviction")
.field("layer", &self.layer)
.field("last_activity", &ts)
.finish()
}
}
impl LocalLayerInfoForDiskUsageEviction {
pub fn file_size(&self) -> u64 {
self.layer.file_size()
}
}
impl Timeline {
pub(crate) fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
let layers = self.layers.read().unwrap();
let mut max_layer_size: Option<u64> = None;
let mut resident_layers = Vec::new();
for l in layers.iter_historic_layers() {
let file_size = l.file_size();
max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size)));
if l.is_remote_layer() {
continue;
}
let last_activity_ts = l.access_stats().latest_activity();
resident_layers.push(LocalLayerInfoForDiskUsageEviction {
layer: l,
last_activity_ts,
});
}
DiskUsageEvictionInfo {
max_layer_size,
resident_layers,
}
}
}
type TraversalPathItem = (
ValueReconstructResult,
Lsn,

View File

@@ -1,30 +1,17 @@
//! The per-timeline layer eviction task, which evicts data which has not been accessed for more
//! than a given threshold.
//!
//! Data includes all kinds of caches, namely:
//! - (in-memory layers)
//! - on-demand downloaded layer files on disk
//! - (cached layer file pages)
//! - derived data from layer file contents, namely:
//! - initial logical size
//! - partitioning
//! - (other currently missing unknowns)
//!
//! Items with parentheses are not (yet) touched by this task.
//!
//! See write-up on restart on-demand download spike: <https://gist.github.com/problame/2265bf7b8dc398be834abfead36c76b5>
//! The per-timeline layer eviction task.
use std::{
ops::ControlFlow,
sync::Arc,
time::{Duration, SystemTime},
};
use either::Either;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, instrument, warn};
use crate::{
context::{DownloadBehavior, RequestContext},
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
tenant::{
config::{EvictionPolicy, EvictionPolicyLayerAccessThreshold},
@@ -34,11 +21,6 @@ use crate::{
use super::Timeline;
#[derive(Default)]
pub struct EvictionTaskTimelineState {
last_refresh_required_in_restart: Option<tokio::time::Instant>,
}
impl Timeline {
pub(super) fn launch_eviction_task(self: &Arc<Self>) {
let self_clone = Arc::clone(self);
@@ -72,10 +54,9 @@ impl Timeline {
}
}
let ctx = RequestContext::new(TaskKind::Eviction, DownloadBehavior::Warn);
loop {
let policy = self.get_eviction_policy();
let cf = self.eviction_iteration(&policy, &cancel, &ctx).await;
let cf = self.eviction_iteration(&policy, cancel.clone()).await;
match cf {
ControlFlow::Break(()) => break,
@@ -96,8 +77,7 @@ impl Timeline {
async fn eviction_iteration(
self: &Arc<Self>,
policy: &EvictionPolicy,
cancel: &CancellationToken,
ctx: &RequestContext,
cancel: CancellationToken,
) -> ControlFlow<(), Instant> {
debug!("eviction iteration: {policy:?}");
match policy {
@@ -107,7 +87,7 @@ impl Timeline {
}
EvictionPolicy::LayerAccessThreshold(p) => {
let start = Instant::now();
match self.eviction_iteration_threshold(p, cancel, ctx).await {
match self.eviction_iteration_threshold(p, cancel).await {
ControlFlow::Break(()) => return ControlFlow::Break(()),
ControlFlow::Continue(()) => (),
}
@@ -121,8 +101,7 @@ impl Timeline {
async fn eviction_iteration_threshold(
self: &Arc<Self>,
p: &EvictionPolicyLayerAccessThreshold,
cancel: &CancellationToken,
ctx: &RequestContext,
cancel: CancellationToken,
) -> ControlFlow<()> {
let now = SystemTime::now();
@@ -135,28 +114,6 @@ impl Timeline {
not_evictable: usize,
skipped_for_shutdown: usize,
}
// what we want is to invalidate any caches which haven't been accessed for `p.threshold`,
// but we cannot actually do it for current limitations except by restarting pageserver. we
// just recompute the values which would be recomputed on startup.
//
// for active tenants this will likely materialized page cache or in-memory layers. for
// inactive tenants it will refresh the last_access timestamps so that we will not evict
// and re-download on restart these layers.
let mut state = self.eviction_task_timeline_state.lock().await;
match state.last_refresh_required_in_restart {
Some(ts) if ts.elapsed() < p.threshold => { /* no need to run */ }
_ => {
self.refresh_layers_required_in_restart(cancel, ctx).await;
state.last_refresh_required_in_restart = Some(tokio::time::Instant::now())
}
}
drop(state);
if cancel.is_cancelled() {
return ControlFlow::Break(());
}
let mut stats = EvictionStats::default();
// Gather layers for eviction.
// NB: all the checks can be invalidated as soon as we release the layer map lock.
@@ -169,7 +126,13 @@ impl Timeline {
if hist_layer.is_remote_layer() {
continue;
}
let last_activity_ts = hist_layer.access_stats().latest_activity();
let last_activity_ts = match hist_layer
.access_stats()
.most_recent_access_or_residence_event()
{
Either::Left(mra) => mra.when,
Either::Right(re) => re.timestamp,
};
let no_activity_for = match now.duration_since(last_activity_ts) {
Ok(d) => d,
Err(_e) => {
@@ -211,7 +174,7 @@ impl Timeline {
};
let results = match self
.evict_layer_batch(remote_client, &candidates[..], cancel.clone())
.evict_layer_batch(remote_client, &candidates[..], cancel)
.await
{
Err(pre_err) => {
@@ -253,40 +216,4 @@ impl Timeline {
}
ControlFlow::Continue(())
}
/// Recompute the values which would cause on-demand downloads during restart.
async fn refresh_layers_required_in_restart(
&self,
cancel: &CancellationToken,
ctx: &RequestContext,
) {
let lsn = self.get_last_record_lsn();
// imitiate on-restart initial logical size
let size = self.calculate_logical_size(lsn, cancel.clone(), ctx).await;
match &size {
Ok(_size) => {
// good, don't log it to avoid confusion
}
Err(_) => {
// we have known issues for which we already log this on consumption metrics,
// gc, and compaction. leave logging out for now.
//
// https://github.com/neondatabase/neon/issues/2539
}
}
// imitiate repartiting on first compactation
if let Err(e) = self.collect_keyspace(lsn, ctx).await {
// if this failed, we probably failed logical size because these use the same keys
if size.is_err() {
// ignore, see above comment
} else {
warn!(
"failed to collect keyspace but succeeded in calculating logical size: {e:#}"
);
}
}
}
}

View File

@@ -45,7 +45,6 @@ pub fn spawn_connection_manager_task(
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
ctx: RequestContext,
) {
let mut broker_client = get_broker_client().clone();
@@ -68,7 +67,6 @@ pub fn spawn_connection_manager_task(
lagging_wal_timeout,
max_lsn_wal_lag,
auth_token,
availability_zone,
);
loop {
select! {
@@ -336,7 +334,6 @@ struct WalreceiverState {
/// Data about all timelines, available for connection, fetched from storage broker, grouped by their corresponding safekeeper node id.
wal_stream_candidates: HashMap<NodeId, BrokerSkTimeline>,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
}
/// Current connection data.
@@ -384,7 +381,6 @@ impl WalreceiverState {
lagging_wal_timeout: Duration,
max_lsn_wal_lag: NonZeroU64,
auth_token: Option<Arc<String>>,
availability_zone: Option<String>,
) -> Self {
let id = TenantTimelineId {
tenant_id: timeline.tenant_id,
@@ -400,7 +396,6 @@ impl WalreceiverState {
wal_stream_candidates: HashMap::new(),
wal_connection_retries: HashMap::new(),
auth_token,
availability_zone,
}
}
@@ -745,7 +740,6 @@ impl WalreceiverState {
None => None,
Some(x) => Some(x),
},
self.availability_zone.as_deref(),
) {
Ok(connstr) => Some((*sk_id, info, connstr)),
Err(e) => {
@@ -830,24 +824,17 @@ fn wal_stream_connection_config(
}: TenantTimelineId,
listen_pg_addr_str: &str,
auth_token: Option<&str>,
availability_zone: Option<&str>,
) -> anyhow::Result<PgConnectionConfig> {
let (host, port) =
parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
let port = port.unwrap_or(5432);
let mut connstr = PgConnectionConfig::new_host_port(host, port)
Ok(PgConnectionConfig::new_host_port(host, port)
.extend_options([
"-c".to_owned(),
format!("timeline_id={}", timeline_id),
format!("tenant_id={}", tenant_id),
])
.set_password(auth_token.map(|s| s.to_owned()));
if let Some(availability_zone) = availability_zone {
connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
}
Ok(connstr)
.set_password(auth_token.map(|s| s.to_owned())))
}
#[cfg(test)]
@@ -1286,7 +1273,6 @@ mod tests {
wal_stream_candidates: HashMap::new(),
wal_connection_retries: HashMap::new(),
auth_token: None,
availability_zone: None,
}
}
}

View File

@@ -33,11 +33,10 @@ use crate::{
walingest::WalIngest,
walrecord::DecodedWALRecord,
};
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use pq_proto::ReplicationFeedback;
use utils::lsn::Lsn;
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
/// Status of the connection.
#[derive(Debug, Clone, Copy)]
@@ -354,7 +353,7 @@ pub async fn handle_walreceiver_connection(
debug!("neon_status_update {status_update:?}");
let mut data = BytesMut::new();
status_update.serialize(&mut data);
status_update.serialize(&mut data)?;
physical_stream
.as_mut()
.zenith_status_update(data.len() as u64, &data)
@@ -435,8 +434,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
{
return Ok(pg_error);
} else if let Some(db_error) = pg_error.as_db_error() {
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
&& db_error.message().contains("ending streaming")
if db_error.code() == &SqlState::CONNECTION_FAILURE
&& db_error.message().contains("end streaming")
{
return Ok(pg_error);
}

View File

@@ -127,21 +127,12 @@ impl UploadQueue {
let mut files = HashMap::with_capacity(index_part.timeline_layers.len());
for layer_name in &index_part.timeline_layers {
match index_part
let layer_metadata = index_part
.layer_metadata
.get(layer_name)
.map(LayerFileMetadata::from)
{
Some(layer_metadata) => {
files.insert(layer_name.to_owned(), layer_metadata);
}
None => {
anyhow::bail!(
"No remote layer metadata found for layer {}",
layer_name.file_name()
);
}
}
.unwrap_or(LayerFileMetadata::MISSING);
files.insert(layer_name.to_owned(), layer_metadata);
}
let index_part_metadata = index_part.parse_metadata()?;

View File

@@ -23,11 +23,13 @@ use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::prelude::CommandExt;
use std::path::PathBuf;
use std::process::Stdio;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
@@ -254,53 +256,52 @@ impl PostgresRedoManager {
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
loop {
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
@@ -309,28 +310,24 @@ impl PostgresRedoManager {
base_img_lsn,
lsn
);
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
n_attempts += 1;
if n_attempts > MAX_RETRY_ATTEMPTS || result.is_ok() {
return result;
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
result
}
///
@@ -637,26 +634,26 @@ impl PostgresRedoManager {
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> Result<(), Error> {
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
// FIXME: We need a dummy Postgres cluster to run the process in. Currently, we
// just create one with constant name. That fails if you try to launch more than
// one WAL redo manager concurrently.
let datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
// Create empty data directory for wal-redo postgres, deleting old one first.
if datadir.exists() {
info!("old temporary datadir {datadir:?} exists, removing");
fs::remove_dir_all(&datadir).map_err(|e| {
Error::new(
e.kind(),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
format!("Old temporary dir {datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
@@ -666,6 +663,35 @@ impl PostgresRedoManager {
.pg_lib_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
info!("running initdb in {}", datadir.display());
let initdb = Command::new(pg_bin_dir_path.join("initdb"))
.args(["-D", &datadir.to_string_lossy()])
.arg("-N")
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) // macOS
.close_fds()
.output()
.map_err(|e| Error::new(e.kind(), format!("failed to execute initdb: {e}")))?;
if !initdb.status.success() {
return Err(Error::new(
ErrorKind::Other,
format!(
"initdb failed\nstdout: {}\nstderr:\n{}",
String::from_utf8_lossy(&initdb.stdout),
String::from_utf8_lossy(&initdb.stderr)
),
));
} else {
// Limit shared cache for wal-redo-postgres
let mut config = OpenOptions::new()
.append(true)
.open(PathBuf::from(&datadir).join("postgresql.conf"))?;
config.write_all(b"shared_buffers=128kB\n")?;
config.write_all(b"fsync=off\n")?;
}
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
.arg("--wal-redo")
@@ -675,6 +701,7 @@ impl PostgresRedoManager {
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
.env("PGDATA", &datadir)
// The redo process is not trusted, and runs in seccomp mode that
// doesn't allow it to open any files. We have to also make sure it
// doesn't inherit any file descriptors from the pageserver, that
@@ -744,7 +771,7 @@ impl PostgresRedoManager {
&self,
mut input: MutexGuard<Option<ProcessInput>>,
tag: BufferTag,
base_img: &Option<Bytes>,
base_img: Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
@@ -760,7 +787,7 @@ impl PostgresRedoManager {
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
build_push_page_msg(tag, img, &mut writebuf);
build_push_page_msg(tag, &img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {

View File

@@ -32,9 +32,6 @@
#define PageStoreTrace DEBUG5
#define MAX_RECONNECT_ATTEMPTS 5
#define RECONNECT_INTERVAL_USEC 1000000
bool connected = false;
PGconn *pageserver_conn = NULL;
@@ -46,12 +43,8 @@ PGconn *pageserver_conn = NULL;
*/
WaitEventSet *pageserver_conn_wes = NULL;
/* GUCs */
char *neon_timeline;
char *neon_tenant;
int32 max_cluster_size;
char *page_server_connstring;
char *neon_auth_token;
char *page_server_connstring_raw;
char *safekeeper_token_env;
int n_unflushed_requests = 0;
int flush_every_n_requests = 8;
@@ -59,42 +52,15 @@ int readahead_buffer_size = 128;
static void pageserver_flush(void);
static bool
pageserver_connect(int elevel)
static void
pageserver_connect()
{
char *query;
int ret;
const char *keywords[3];
const char *values[3];
int n;
Assert(!connected);
/*
* Connect using the connection string we got from the
* neon.pageserver_connstring GUC. If the NEON_AUTH_TOKEN environment
* variable was set, use that as the password.
*
* The connection options are parsed in the order they're given, so
* when we set the password before the connection string, the
* connection string can override the password from the env variable.
* Seems useful, although we don't currently use that capability
* anywhere.
*/
n = 0;
if (neon_auth_token)
{
keywords[n] = "password";
values[n] = neon_auth_token;
n++;
}
keywords[n] = "dbname";
values[n] = page_server_connstring;
n++;
keywords[n] = NULL;
values[n] = NULL;
n++;
pageserver_conn = PQconnectdbParams(keywords, values, 1);
pageserver_conn = PQconnectdb(page_server_connstring);
if (PQstatus(pageserver_conn) == CONNECTION_BAD)
{
@@ -103,11 +69,10 @@ pageserver_connect(int elevel)
PQfinish(pageserver_conn);
pageserver_conn = NULL;
ereport(elevel,
ereport(ERROR,
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
errmsg(NEON_TAG "could not establish connection to pageserver"),
errdetail_internal("%s", msg)));
return false;
}
query = psprintf("pagestream %s %s", neon_tenant, neon_timeline);
@@ -116,8 +81,7 @@ pageserver_connect(int elevel)
{
PQfinish(pageserver_conn);
pageserver_conn = NULL;
neon_log(elevel, "could not send pagestream command to pageserver");
return false;
neon_log(ERROR, "could not send pagestream command to pageserver");
}
pageserver_conn_wes = CreateWaitEventSet(TopMemoryContext, 3);
@@ -149,17 +113,15 @@ pageserver_connect(int elevel)
FreeWaitEventSet(pageserver_conn_wes);
pageserver_conn_wes = NULL;
neon_log(elevel, "could not complete handshake with pageserver: %s",
neon_log(ERROR, "could not complete handshake with pageserver: %s",
msg);
return false;
}
}
}
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring);
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring_raw);
connected = true;
return true;
}
/*
@@ -187,11 +149,8 @@ retry:
if (event.events & WL_SOCKET_READABLE)
{
if (!PQconsumeInput(pageserver_conn))
{
neon_log(LOG, "could not get response from pageserver: %s",
neon_log(ERROR, "could not get response from pageserver: %s",
PQerrorMessage(pageserver_conn));
return -1;
}
}
goto retry;
@@ -231,62 +190,31 @@ static void
pageserver_send(NeonRequest * request)
{
StringInfoData req_buff;
int n_reconnect_attempts = 0;
/* If the connection was lost for some reason, reconnect */
if (connected && PQstatus(pageserver_conn) == CONNECTION_BAD)
pageserver_disconnect();
if (!connected)
pageserver_connect();
req_buff = nm_pack_request(request);
/*
* If pageserver is stopped, the connections from compute node are broken.
* The compute node doesn't notice that immediately, but it will cause the next request to fail, usually on the next query.
* That causes user-visible errors if pageserver is restarted, or the tenant is moved from one pageserver to another.
* See https://github.com/neondatabase/neon/issues/1138
* So try to reestablish connection in case of failure.
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
while (true)
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
if (!connected)
{
if (!pageserver_connect(n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS ? LOG : ERROR))
{
n_reconnect_attempts += 1;
pg_usleep(RECONNECT_INTERVAL_USEC);
continue;
}
}
char *msg = pchomp(PQerrorMessage(pageserver_conn));
/*
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
if (n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS)
{
neon_log(LOG, "failed to send page request (try to reconnect): %s", msg);
if (n_reconnect_attempts != 0) /* do not sleep before first reconnect attempt, assuming that pageserver is already restarted */
pg_usleep(RECONNECT_INTERVAL_USEC);
n_reconnect_attempts += 1;
continue;
}
else
{
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
}
break;
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
pfree(req_buff.data);
n_unflushed_requests++;
@@ -385,6 +313,105 @@ check_neon_id(char **newval, void **extra, GucSource source)
return **newval == '\0' || HexDecodeString(id, *newval, 16);
}
static char *
substitute_pageserver_password(const char *page_server_connstring_raw)
{
char *host = NULL;
char *port = NULL;
char *user = NULL;
char *auth_token = NULL;
char *err = NULL;
char *page_server_connstring = NULL;
PQconninfoOption *conn_options;
PQconninfoOption *conn_option;
MemoryContext oldcontext;
/*
* Here we substitute password in connection string with an environment
* variable. To simplify things we construct a connection string back with
* only known options. In particular: host port user and password. We do
* not currently use other options and constructing full connstring in an
* URI shape is quite messy.
*/
if (page_server_connstring_raw == NULL || page_server_connstring_raw[0] == '\0')
return NULL;
/* extract the auth token from the connection string */
conn_options = PQconninfoParse(page_server_connstring_raw, &err);
if (conn_options == NULL)
{
/* The error string is malloc'd, so we must free it explicitly */
char *errcopy = err ? pstrdup(err) : "out of memory";
PQfreemem(err);
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR),
errmsg("invalid connection string syntax: %s", errcopy)));
}
/*
* Trying to populate pageserver connection string with auth token from
* environment. We are looking for password in with placeholder value like
* $ENV_VAR_NAME, so if password field is present and starts with $ we try
* to fetch environment variable value and fail loudly if it is not set.
*/
for (conn_option = conn_options; conn_option->keyword != NULL; conn_option++)
{
if (strcmp(conn_option->keyword, "host") == 0)
{
if (conn_option->val != NULL && conn_option->val[0] != '\0')
host = conn_option->val;
}
else if (strcmp(conn_option->keyword, "port") == 0)
{
if (conn_option->val != NULL && conn_option->val[0] != '\0')
port = conn_option->val;
}
else if (strcmp(conn_option->keyword, "user") == 0)
{
if (conn_option->val != NULL && conn_option->val[0] != '\0')
user = conn_option->val;
}
else if (strcmp(conn_option->keyword, "password") == 0)
{
if (conn_option->val != NULL && conn_option->val[0] != '\0')
{
/* ensure that this is a template */
if (strncmp(conn_option->val, "$", 1) != 0)
ereport(ERROR,
(errcode(ERRCODE_CONNECTION_EXCEPTION),
errmsg("expected placeholder value in pageserver password starting from $ but found: %s", &conn_option->val[1])));
neon_log(LOG, "found auth token placeholder in pageserver conn string '%s'", &conn_option->val[1]);
auth_token = getenv(&conn_option->val[1]);
if (!auth_token)
{
ereport(ERROR,
(errcode(ERRCODE_CONNECTION_EXCEPTION),
errmsg("cannot get auth token, environment variable %s is not set", &conn_option->val[1])));
}
else
{
neon_log(LOG, "using auth token from environment passed via env");
}
}
}
}
/*
* allocate connection string in TopMemoryContext to make sure it is not
* freed
*/
oldcontext = CurrentMemoryContext;
MemoryContextSwitchTo(TopMemoryContext);
page_server_connstring = psprintf("postgresql://%s:%s@%s:%s", user, auth_token ? auth_token : "", host, port);
MemoryContextSwitchTo(oldcontext);
PQconninfoFree(conn_options);
return page_server_connstring;
}
/*
* Module initialization function
*/
@@ -394,12 +421,21 @@ pg_init_libpagestore(void)
DefineCustomStringVariable("neon.pageserver_connstring",
"connection string to the page server",
NULL,
&page_server_connstring,
&page_server_connstring_raw,
"",
PGC_POSTMASTER,
0, /* no flags required */
NULL, NULL, NULL);
DefineCustomStringVariable("neon.safekeeper_token_env",
"the environment variable containing JWT token for authentication with Safekeepers, the convention is to either unset or set to $NEON_AUTH_TOKEN",
NULL,
&safekeeper_token_env,
NULL,
PGC_POSTMASTER,
0, /* no flags required */
NULL, NULL, NULL);
DefineCustomStringVariable("neon.timeline_id",
"Neon timeline_id the server is running on",
NULL,
@@ -456,10 +492,30 @@ pg_init_libpagestore(void)
neon_log(PageStoreTrace, "libpagestore already loaded");
page_server = &api;
/* Retrieve the auth token to use when connecting to pageserver and safekeepers */
neon_auth_token = getenv("NEON_AUTH_TOKEN");
if (neon_auth_token)
neon_log(LOG, "using storage auth token from NEON_AUTH_TOKEN environment variable");
/* substitute password in pageserver_connstring */
page_server_connstring = substitute_pageserver_password(page_server_connstring_raw);
/* Is there more correct way to pass CustomGUC to postgres code? */
neon_timeline_walproposer = neon_timeline;
neon_tenant_walproposer = neon_tenant;
/* retrieve the token for Safekeeper, if present */
if (safekeeper_token_env != NULL) {
if (safekeeper_token_env[0] != '$') {
ereport(ERROR,
(errcode(ERRCODE_CONNECTION_EXCEPTION),
errmsg("expected safekeeper auth token environment variable's name starting with $ but found: %s",
safekeeper_token_env)));
}
neon_safekeeper_token_walproposer = getenv(&safekeeper_token_env[1]);
if (!neon_safekeeper_token_walproposer) {
ereport(ERROR,
(errcode(ERRCODE_CONNECTION_EXCEPTION),
errmsg("cannot get safekeeper auth token, environment variable %s is not set",
&safekeeper_token_env[1])));
}
neon_log(LOG, "using safekeeper auth token from environment variable");
}
if (page_server_connstring && page_server_connstring[0])
{

View File

@@ -51,39 +51,12 @@ walprop_status(WalProposerConn *conn)
}
WalProposerConn *
walprop_connect_start(char *conninfo, char *password)
walprop_connect_start(char *conninfo)
{
WalProposerConn *conn;
PGconn *pg_conn;
const char *keywords[3];
const char *values[3];
int n;
/*
* Connect using the given connection string. If the
* NEON_AUTH_TOKEN environment variable was set, use that as
* the password.
*
* The connection options are parsed in the order they're given, so
* when we set the password before the connection string, the
* connection string can override the password from the env variable.
* Seems useful, although we don't currently use that capability
* anywhere.
*/
n = 0;
if (password)
{
keywords[n] = "password";
values[n] = neon_auth_token;
n++;
}
keywords[n] = "dbname";
values[n] = conninfo;
n++;
keywords[n] = NULL;
values[n] = NULL;
n++;
pg_conn = PQconnectStartParams(keywords, values, 1);
pg_conn = PQconnectStart(conninfo);
/*
* Allocation of a PQconn can fail, and will return NULL. We want to fully

Some files were not shown because too many files have changed in this diff Show More