Compare commits

..

1 Commits

Author SHA1 Message Date
Vlad Lazar
0aa8bfa995 storcon: skip spurious detach on node activation 2024-09-23 17:33:18 +01:00
121 changed files with 1450 additions and 9066 deletions

View File

@@ -13,7 +13,6 @@
# Directories
!.cargo/
!.config/
!compute/
!compute_tools/
!control_plane/
!libs/

View File

@@ -257,15 +257,7 @@ jobs:
${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(remote_storage)' -E 'test(test_real_azure)'
- name: Install postgres binaries
run: |
# Use tar to copy files matching the pattern, preserving the paths in the destionation
tar c \
pg_install/v* \
pg_install/build/*/src/test/regress/*.so \
pg_install/build/*/src/test/regress/pg_regress \
pg_install/build/*/src/test/isolation/isolationtester \
pg_install/build/*/src/test/isolation/pg_isolation_regress \
| tar x -C /tmp/neon
run: cp -a pg_install /tmp/neon/pg_install
- name: Upload Neon artifact
uses: ./.github/actions/upload

View File

@@ -602,20 +602,7 @@ jobs:
strategy:
fail-fast: false
matrix:
version:
# Much data was already generated on old PG versions with bullseye's
# libraries, the locales of which can cause data incompatibilities.
# However, new PG versions should check if they can be built on newer
# images, as that reduces the support burden of old and ancient
# distros.
- pg: v14
debian: bullseye-slim
- pg: v15
debian: bullseye-slim
- pg: v16
debian: bullseye-slim
- pg: v17
debian: bookworm-slim
version: [ v14, v15, v16, v17 ]
arch: [ x64, arm64 ]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }}
@@ -658,46 +645,41 @@ jobs:
context: .
build-args: |
GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
PG_VERSION=${{ matrix.version.pg }}
PG_VERSION=${{ matrix.version }}
BUILD_TAG=${{ needs.tag.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}
DEBIAN_FLAVOR=${{ matrix.version.debian }}
provenance: false
push: true
pull: true
file: compute/Dockerfile.compute-node
cache-from: type=registry,ref=cache.neon.build/compute-node-${{ matrix.version.pg }}:cache-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/compute-node-{0}:cache-{1},mode=max', matrix.version.pg, matrix.arch) || '' }}
file: Dockerfile.compute-node
cache-from: type=registry,ref=cache.neon.build/compute-node-${{ matrix.version }}:cache-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/compute-node-{0}:cache-{1},mode=max', matrix.version, matrix.arch) || '' }}
tags: |
neondatabase/compute-node-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }}-${{ matrix.arch }}
neondatabase/compute-node-${{ matrix.version }}:${{ needs.tag.outputs.build-tag }}-${{ matrix.arch }}
- name: Build neon extensions test image
if: matrix.version.pg == 'v16'
if: matrix.version == 'v16'
uses: docker/build-push-action@v6
with:
context: .
build-args: |
GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
PG_VERSION=${{ matrix.version.pg }}
PG_VERSION=${{ matrix.version }}
BUILD_TAG=${{ needs.tag.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}
DEBIAN_FLAVOR=${{ matrix.version.debian }}
provenance: false
push: true
pull: true
file: compute/Dockerfile.compute-node
file: Dockerfile.compute-node
target: neon-pg-ext-test
cache-from: type=registry,ref=cache.neon.build/neon-test-extensions-${{ matrix.version.pg }}:cache-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/neon-test-extensions-{0}:cache-{1},mode=max', matrix.version.pg, matrix.arch) || '' }}
cache-from: type=registry,ref=cache.neon.build/neon-test-extensions-${{ matrix.version }}:cache-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/neon-test-extensions-{0}:cache-{1},mode=max', matrix.version, matrix.arch) || '' }}
tags: |
neondatabase/neon-test-extensions-${{ matrix.version.pg }}:${{needs.tag.outputs.build-tag}}-${{ matrix.arch }}
neondatabase/neon-test-extensions-${{ matrix.version }}:${{needs.tag.outputs.build-tag}}-${{ matrix.arch }}
- name: Build compute-tools image
# compute-tools are Postgres independent, so build it only once
# We pick 16, because that builds on debian 11 with older glibc (and is
# thus compatible with newer glibc), rather than 17 on Debian 12, as
# that isn't guaranteed to be compatible with Debian 11
if: matrix.version.pg == 'v16'
if: matrix.version == 'v17'
uses: docker/build-push-action@v6
with:
target: compute-tools-image
@@ -706,11 +688,10 @@ jobs:
GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
BUILD_TAG=${{ needs.tag.outputs.build-tag }}
TAG=${{ needs.build-build-tools-image.outputs.image-tag }}
DEBIAN_FLAVOR=${{ matrix.version.debian }}
provenance: false
push: true
pull: true
file: compute/Dockerfile.compute-node
file: Dockerfile.compute-node
tags: |
neondatabase/compute-tools:${{ needs.tag.outputs.build-tag }}-${{ matrix.arch }}
@@ -798,7 +779,7 @@ jobs:
- name: Build vm image
run: |
./vm-builder \
-spec=compute/vm-image-spec.yaml \
-spec=vm-image-spec.yaml \
-src=neondatabase/compute-node-${{ matrix.version }}:${{ needs.tag.outputs.build-tag }} \
-dst=neondatabase/vm-compute-node-${{ matrix.version }}:${{ needs.tag.outputs.build-tag }}
@@ -862,9 +843,6 @@ jobs:
needs: [ check-permissions, tag, test-images, vm-compute-node-image ]
runs-on: ubuntu-22.04
permissions:
id-token: write # for `aws-actions/configure-aws-credentials`
env:
VERSIONS: v14 v15 v16 v17
@@ -909,19 +887,13 @@ jobs:
docker buildx imagetools create -t neondatabase/neon-test-extensions-v16:latest \
neondatabase/neon-test-extensions-v16:${{ needs.tag.outputs.build-tag }}
- name: Configure AWS-prod credentials
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
mask-aws-account-id: true
role-to-assume: ${{ secrets.PROD_GHA_OIDC_ROLE }}
- name: Login to prod ECR
uses: docker/login-action@v3
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
with:
registry: 093970136003.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.PROD_GHA_RUNNER_LIMITED_AWS_ACCESS_KEY_ID }}
password: ${{ secrets.PROD_GHA_RUNNER_LIMITED_AWS_SECRET_ACCESS_KEY }}
- name: Copy all images to prod ECR
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
@@ -1190,9 +1162,10 @@ jobs:
files_to_promote+=("s3://${BUCKET}/${s3_key}")
for pg_version in v14 v15 v16 v17; do
# TODO Add v17
for pg_version in v14 v15 v16; do
# We run less tests for debug builds, so we don't need to promote them
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v17" ] ; }; then
if [ "${build_type}" == "debug" ] && { [ "${arch}" == "ARM64" ] || [ "${pg_version}" != "v16" ] ; }; then
continue
fi
@@ -1234,6 +1207,7 @@ jobs:
# Usually we do `needs: [...]`
needs:
- build-and-test-locally
- check-submodules
- check-codestyle-python
- check-codestyle-rust
- promote-images

View File

@@ -1,102 +0,0 @@
name: Cloud Regression Test
on:
schedule:
# * is a special character in YAML so you have to quote this string
# ┌───────────── minute (0 - 59)
# │ ┌───────────── hour (0 - 23)
# │ │ ┌───────────── day of the month (1 - 31)
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '45 1 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually
defaults:
run:
shell: bash -euxo pipefail {0}
concurrency:
# Allow only one workflow
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
regress:
env:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 16
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
runs-on: us-east-2
container:
image: neondatabase/build-tools:pinned
options: --init
steps:
- uses: actions/checkout@v4
with:
submodules: true
- name: Patch the test
run: |
cd "vendor/postgres-v${DEFAULT_PG_VERSION}"
patch -p1 < "../../compute/patches/cloud_regress_pg${DEFAULT_PG_VERSION}.patch"
- name: Generate a random password
id: pwgen
run: |
set +x
DBPASS=$(dd if=/dev/random bs=48 count=1 2>/dev/null | base64)
echo "::add-mask::${DBPASS//\//}"
echo DBPASS="${DBPASS//\//}" >> "${GITHUB_OUTPUT}"
- name: Change tests according to the generated password
env:
DBPASS: ${{ steps.pwgen.outputs.DBPASS }}
run: |
cd vendor/postgres-v"${DEFAULT_PG_VERSION}"/src/test/regress
for fname in sql/*.sql expected/*.out; do
sed -i.bak s/NEON_PASSWORD_PLACEHOLDER/"'${DBPASS}'"/ "${fname}"
done
for ph in $(grep NEON_MD5_PLACEHOLDER expected/password.out | awk '{print $3;}' | sort | uniq); do
USER=$(echo "${ph}" | cut -c 22-)
MD5=md5$(echo -n "${DBPASS}${USER}" | md5sum | awk '{print $1;}')
sed -i.bak "s/${ph}/${MD5}/" expected/password.out
done
- name: Download Neon artifact
uses: ./.github/actions/download
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
- name: Run the regression tests
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}
test_selection: cloud_regress
pg_version: ${{ env.DEFAULT_PG_VERSION }}
extra_params: -m remote_cluster
env:
BENCHMARK_CONNSTR: ${{ secrets.PG_REGRESS_CONNSTR }}
- name: Create Allure report
id: create-allure-report
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
- name: Post to a Slack channel
if: ${{ github.event.schedule && failure() }}
uses: slackapi/slack-github-action@v1
with:
channel-id: "C033QLM5P7D" # on-call-staging-stream
slack-message: |
Periodic pg_regress on staging: ${{ job.status }}
<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|GitHub Run>
<${{ steps.create-allure-report.outputs.report-url }}|Allure report>
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}

View File

@@ -102,12 +102,12 @@ jobs:
# Default set of platforms to run e2e tests on
platforms='["docker", "k8s"]'
# If the PR changes vendor/, pgxn/ or libs/vm_monitor/ directories, or compute/Dockerfile.compute-node, add k8s-neonvm to the list of platforms.
# If the PR changes vendor/, pgxn/ or libs/vm_monitor/ directories, or Dockerfile.compute-node, add k8s-neonvm to the list of platforms.
# If the workflow run is not a pull request, add k8s-neonvm to the list.
if [ "$GITHUB_EVENT_NAME" == "pull_request" ]; then
for f in $(gh api "/repos/${GITHUB_REPOSITORY}/pulls/${PR_NUMBER}/files" --paginate --jq '.[].filename'); do
case "$f" in
vendor/*|pgxn/*|libs/vm_monitor/*|compute/Dockerfile.compute-node)
vendor/*|pgxn/*|libs/vm_monitor/*|Dockerfile.compute-node)
platforms=$(echo "${platforms}" | jq --compact-output '. += ["k8s-neonvm"] | unique')
;;
*)

11
Cargo.lock generated
View File

@@ -1321,6 +1321,7 @@ dependencies = [
"clap",
"comfy-table",
"compute_api",
"git-version",
"humantime",
"humantime-serde",
"hyper 0.14.30",
@@ -3577,6 +3578,7 @@ dependencies = [
"anyhow",
"camino",
"clap",
"git-version",
"humantime",
"pageserver",
"pageserver_api",
@@ -3615,6 +3617,7 @@ dependencies = [
"enumset",
"fail",
"futures",
"git-version",
"hex",
"hex-literal",
"humantime",
@@ -3734,6 +3737,7 @@ dependencies = [
"clap",
"criterion",
"futures",
"git-version",
"hex-literal",
"itertools 0.10.5",
"once_cell",
@@ -4296,7 +4300,6 @@ dependencies = [
"camino-tempfile",
"chrono",
"clap",
"compute_api",
"consumption_metrics",
"dashmap",
"ecdsa 0.16.9",
@@ -4304,6 +4307,7 @@ dependencies = [
"fallible-iterator",
"framed-websockets",
"futures",
"git-version",
"hashbrown 0.14.5",
"hashlink",
"hex",
@@ -5135,6 +5139,7 @@ dependencies = [
"desim",
"fail",
"futures",
"git-version",
"hex",
"humantime",
"hyper 0.14.30",
@@ -5697,6 +5702,7 @@ dependencies = [
"futures",
"futures-core",
"futures-util",
"git-version",
"humantime",
"hyper 0.14.30",
"metrics",
@@ -5724,6 +5730,7 @@ dependencies = [
"diesel_migrations",
"fail",
"futures",
"git-version",
"hex",
"humantime",
"hyper 0.14.30",
@@ -5776,6 +5783,7 @@ dependencies = [
"either",
"futures",
"futures-util",
"git-version",
"hex",
"humantime",
"itertools 0.10.5",
@@ -6707,7 +6715,6 @@ dependencies = [
"criterion",
"fail",
"futures",
"git-version",
"hex",
"hex-literal",
"humantime",

View File

@@ -3,15 +3,13 @@ ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG TAG=pinned
ARG BUILD_TAG
ARG DEBIAN_FLAVOR=bullseye-slim
#########################################################################################
#
# Layer "build-deps"
#
#########################################################################################
FROM debian:$DEBIAN_FLAVOR AS build-deps
ARG DEBIAN_FLAVOR
FROM debian:bullseye-slim AS build-deps
RUN apt update && \
apt install -y git autoconf automake libtool build-essential bison flex libreadline-dev \
zlib1g-dev libxml2-dev libcurl4-openssl-dev libossp-uuid-dev wget pkg-config libssl-dev \
@@ -282,7 +280,7 @@ FROM build-deps AS vector-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY compute/patches/pgvector.patch /pgvector.patch
COPY patches/pgvector.patch /pgvector.patch
# By default, pgvector Makefile uses `-march=native`. We don't want that,
# because we build the images on different machines than where we run them.
@@ -368,7 +366,7 @@ FROM build-deps AS rum-pg-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY compute/patches/rum.patch /rum.patch
COPY patches/rum.patch /rum.patch
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
@@ -1029,47 +1027,10 @@ RUN cd compute_tools && mold -run cargo build --locked --profile release-line-de
#
#########################################################################################
FROM debian:$DEBIAN_FLAVOR AS compute-tools-image
ARG DEBIAN_FLAVOR
FROM debian:bullseye-slim AS compute-tools-image
COPY --from=compute-tools /home/nonroot/target/release-line-debug-size-lto/compute_ctl /usr/local/bin/compute_ctl
#########################################################################################
#
# Layer "pgbouncer"
#
#########################################################################################
FROM debian:$DEBIAN_FLAVOR AS pgbouncer
ARG DEBIAN_FLAVOR
RUN set -e \
&& apt-get update \
&& apt-get install -y \
build-essential \
git \
libevent-dev \
libtool \
pkg-config
# Use `dist_man_MANS=` to skip manpage generation (which requires python3/pandoc)
ENV PGBOUNCER_TAG=pgbouncer_1_22_1
RUN set -e \
&& git clone --recurse-submodules --depth 1 --branch ${PGBOUNCER_TAG} https://github.com/pgbouncer/pgbouncer.git pgbouncer \
&& cd pgbouncer \
&& ./autogen.sh \
&& LDFLAGS=-static ./configure --prefix=/usr/local/pgbouncer --without-openssl \
&& make -j $(nproc) dist_man_MANS= \
&& make install dist_man_MANS=
#########################################################################################
#
# Layers "postgres-exporter" and "sql-exporter"
#
#########################################################################################
FROM quay.io/prometheuscommunity/postgres-exporter:v0.12.1 AS postgres-exporter
FROM burningalchemist/sql_exporter:0.13 AS sql-exporter
#########################################################################################
#
# Clean up postgres folder before inclusion
@@ -1117,7 +1078,7 @@ COPY --from=pgjwt-pg-build /pgjwt.tar.gz /ext-src
COPY --from=hypopg-pg-build /hypopg.tar.gz /ext-src
COPY --from=pg-hashids-pg-build /pg_hashids.tar.gz /ext-src
COPY --from=rum-pg-build /rum.tar.gz /ext-src
COPY compute/patches/rum.patch /ext-src
COPY patches/rum.patch /ext-src
#COPY --from=pgtap-pg-build /pgtap.tar.gz /ext-src
COPY --from=ip4r-pg-build /ip4r.tar.gz /ext-src
COPY --from=prefix-pg-build /prefix.tar.gz /ext-src
@@ -1125,9 +1086,9 @@ COPY --from=hll-pg-build /hll.tar.gz /ext-src
COPY --from=plpgsql-check-pg-build /plpgsql_check.tar.gz /ext-src
#COPY --from=timescaledb-pg-build /timescaledb.tar.gz /ext-src
COPY --from=pg-hint-plan-pg-build /pg_hint_plan.tar.gz /ext-src
COPY compute/patches/pg_hint_plan.patch /ext-src
COPY patches/pg_hint_plan.patch /ext-src
COPY --from=pg-cron-pg-build /pg_cron.tar.gz /ext-src
COPY compute/patches/pg_cron.patch /ext-src
COPY patches/pg_cron.patch /ext-src
#COPY --from=pg-pgx-ulid-build /home/nonroot/pgx_ulid.tar.gz /ext-src
#COPY --from=rdkit-pg-build /rdkit.tar.gz /ext-src
COPY --from=pg-uuidv7-pg-build /pg_uuidv7.tar.gz /ext-src
@@ -1136,7 +1097,7 @@ COPY --from=pg-semver-pg-build /pg_semver.tar.gz /ext-src
#COPY --from=pg-embedding-pg-build /home/nonroot/pg_embedding-src/ /ext-src
#COPY --from=wal2json-pg-build /wal2json_2_5.tar.gz /ext-src
COPY --from=pg-anon-pg-build /pg_anon.tar.gz /ext-src
COPY compute/patches/pg_anon.patch /ext-src
COPY patches/pg_anon.patch /ext-src
COPY --from=pg-ivm-build /pg_ivm.tar.gz /ext-src
COPY --from=pg-partman-build /pg_partman.tar.gz /ext-src
RUN case "${PG_VERSION}" in "v17") \
@@ -1183,9 +1144,7 @@ ENV PGDATABASE=postgres
# Put it all together into the final image
#
#########################################################################################
FROM debian:$DEBIAN_FLAVOR
ARG DEBIAN_FLAVOR
ENV DEBIAN_FLAVOR=$DEBIAN_FLAVOR
FROM debian:bullseye-slim
# Add user postgres
RUN mkdir /var/db && useradd -m -d /var/db/postgres postgres && \
echo "postgres:test_console_pass" | chpasswd && \
@@ -1201,50 +1160,23 @@ RUN mkdir /var/db && useradd -m -d /var/db/postgres postgres && \
COPY --from=postgres-cleanup-layer --chown=postgres /usr/local/pgsql /usr/local
COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/compute_ctl /usr/local/bin/compute_ctl
# pgbouncer and its config
COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/pgbouncer
COPY --chmod=0666 --chown=postgres compute/etc/pgbouncer.ini /etc/pgbouncer.ini
# Metrics exporter binaries and configuration files
COPY --from=postgres-exporter /bin/postgres_exporter /bin/postgres_exporter
COPY --from=sql-exporter /bin/sql_exporter /bin/sql_exporter
COPY --chmod=0644 compute/etc/sql_exporter.yml /etc/sql_exporter.yml
COPY --chmod=0644 compute/etc/neon_collector.yml /etc/neon_collector.yml
COPY --chmod=0644 compute/etc/sql_exporter_autoscaling.yml /etc/sql_exporter_autoscaling.yml
COPY --chmod=0644 compute/etc/neon_collector_autoscaling.yml /etc/neon_collector_autoscaling.yml
# Create remote extension download directory
RUN mkdir /usr/local/download_extensions && chown -R postgres:postgres /usr/local/download_extensions
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU and plpgsql_check)
# liblz4-1 for lz4
# libossp-uuid16 for extension ossp-uuid
# libgeos, libsfcgal1, and libprotobuf-c1 for PostGIS
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2
# libzstd1 for zstd
# libboost* for rdkit
# ca-certificates for communicating with s3 by compute_ctl
RUN apt update && \
case $DEBIAN_FLAVOR in \
# Version-specific installs for Bullseye (PG14-PG16):
# libicu67, locales for collations (including ICU and plpgsql_check)
# libgdal28, libproj19 for PostGIS
bullseye*) \
VERSION_INSTALLS="libicu67 libgdal28 libproj19"; \
;; \
# Version-specific installs for Bookworm (PG17):
# libicu72, locales for collations (including ICU and plpgsql_check)
# libgdal32, libproj25 for PostGIS
bookworm*) \
VERSION_INSTALLS="libicu72 libgdal32 libproj25"; \
;; \
esac && \
RUN apt update && \
apt install --no-install-recommends -y \
gdb \
libicu67 \
liblz4-1 \
libreadline8 \
libboost-iostreams1.74.0 \
@@ -1253,6 +1185,8 @@ RUN apt update && \
libboost-system1.74.0 \
libossp-uuid16 \
libgeos-c1v5 \
libgdal28 \
libproj19 \
libprotobuf-c1 \
libsfcgal1 \
libxml2 \
@@ -1261,8 +1195,7 @@ RUN apt update && \
libcurl4-openssl-dev \
locales \
procps \
ca-certificates \
$VERSION_INSTALLS && \
ca-certificates && \
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \
localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8

View File

@@ -1,21 +0,0 @@
This directory contains files that are needed to build the compute
images, or included in the compute images.
Dockerfile.compute-node
To build the compute image
vm-image-spec.yaml
Instructions for vm-builder, to turn the compute-node image into
corresponding vm-compute-node image.
etc/
Configuration files included in /etc in the compute image
patches/
Some extensions need to be patched to work with Neon. This
directory contains such patches. They are applied to the extension
sources in Dockerfile.compute-node
In addition to these, postgres itself, the neon postgres extension,
and compute_ctl are built and copied into the compute image by
Dockerfile.compute-node.

View File

@@ -1,246 +0,0 @@
collector_name: neon_collector
metrics:
- metric_name: lfc_misses
type: gauge
help: 'lfc_misses'
key_labels:
values: [lfc_misses]
query: |
select lfc_value as lfc_misses from neon.neon_lfc_stats where lfc_key='file_cache_misses';
- metric_name: lfc_used
type: gauge
help: 'LFC chunks used (chunk = 1MB)'
key_labels:
values: [lfc_used]
query: |
select lfc_value as lfc_used from neon.neon_lfc_stats where lfc_key='file_cache_used';
- metric_name: lfc_hits
type: gauge
help: 'lfc_hits'
key_labels:
values: [lfc_hits]
query: |
select lfc_value as lfc_hits from neon.neon_lfc_stats where lfc_key='file_cache_hits';
- metric_name: lfc_writes
type: gauge
help: 'lfc_writes'
key_labels:
values: [lfc_writes]
query: |
select lfc_value as lfc_writes from neon.neon_lfc_stats where lfc_key='file_cache_writes';
- metric_name: lfc_cache_size_limit
type: gauge
help: 'LFC cache size limit in bytes'
key_labels:
values: [lfc_cache_size_limit]
query: |
select pg_size_bytes(current_setting('neon.file_cache_size_limit')) as lfc_cache_size_limit;
- metric_name: connection_counts
type: gauge
help: 'Connection counts'
key_labels:
- datname
- state
values: [count]
query: |
select datname, state, count(*) as count from pg_stat_activity where state <> '' group by datname, state;
- metric_name: pg_stats_userdb
type: gauge
help: 'Stats for several oldest non-system dbs'
key_labels:
- datname
value_label: kind
values:
- db_size
- deadlocks
# Rows
- inserted
- updated
- deleted
# We export stats for 10 non-system database. Without this limit
# it is too easy to abuse the system by creating lots of databases.
query: |
select pg_database_size(datname) as db_size, deadlocks,
tup_inserted as inserted, tup_updated as updated, tup_deleted as deleted,
datname
from pg_stat_database
where datname IN (
select datname
from pg_database
where datname <> 'postgres' and not datistemplate
order by oid
limit 10
);
- metric_name: max_cluster_size
type: gauge
help: 'neon.max_cluster_size setting'
key_labels:
values: [max_cluster_size]
query: |
select setting::int as max_cluster_size from pg_settings where name = 'neon.max_cluster_size';
- metric_name: db_total_size
type: gauge
help: 'Size of all databases'
key_labels:
values: [total]
query: |
select sum(pg_database_size(datname)) as total from pg_database;
# DEPRECATED
- metric_name: lfc_approximate_working_set_size
type: gauge
help: 'Approximate working set size in pages of 8192 bytes'
key_labels:
values: [approximate_working_set_size]
query: |
select neon.approximate_working_set_size(false) as approximate_working_set_size;
- metric_name: lfc_approximate_working_set_size_windows
type: gauge
help: 'Approximate working set size in pages of 8192 bytes'
key_labels: [duration]
values: [size]
# NOTE: This is the "public" / "human-readable" version. Here, we supply a small selection
# of durations in a pretty-printed form.
query: |
select
x as duration,
neon.approximate_working_set_size_seconds(extract('epoch' from x::interval)::int) as size
from
(values ('5m'),('15m'),('1h')) as t (x);
- metric_name: compute_current_lsn
type: gauge
help: 'Current LSN of the database'
key_labels:
values: [lsn]
query: |
select
case
when pg_catalog.pg_is_in_recovery()
then (pg_last_wal_replay_lsn() - '0/0')::FLOAT8
else (pg_current_wal_lsn() - '0/0')::FLOAT8
end as lsn;
- metric_name: compute_receive_lsn
type: gauge
help: 'Returns the last write-ahead log location that has been received and synced to disk by streaming replication'
key_labels:
values: [lsn]
query: |
SELECT
CASE
WHEN pg_catalog.pg_is_in_recovery()
THEN (pg_last_wal_receive_lsn() - '0/0')::FLOAT8
ELSE 0
END AS lsn;
- metric_name: replication_delay_bytes
type: gauge
help: 'Bytes between received and replayed LSN'
key_labels:
values: [replication_delay_bytes]
# We use a GREATEST call here because this calculation can be negative.
# The calculation is not atomic, meaning after we've gotten the receive
# LSN, the replay LSN may have advanced past the receive LSN we
# are using for the calculation.
query: |
SELECT GREATEST(0, pg_wal_lsn_diff(pg_last_wal_receive_lsn(), pg_last_wal_replay_lsn())) AS replication_delay_bytes;
- metric_name: replication_delay_seconds
type: gauge
help: 'Time since last LSN was replayed'
key_labels:
values: [replication_delay_seconds]
query: |
SELECT
CASE
WHEN pg_last_wal_receive_lsn() = pg_last_wal_replay_lsn() THEN 0
ELSE GREATEST (0, EXTRACT (EPOCH FROM now() - pg_last_xact_replay_timestamp()))
END AS replication_delay_seconds;
- metric_name: checkpoints_req
type: gauge
help: 'Number of requested checkpoints'
key_labels:
values: [checkpoints_req]
query: |
SELECT checkpoints_req FROM pg_stat_bgwriter;
- metric_name: checkpoints_timed
type: gauge
help: 'Number of scheduled checkpoints'
key_labels:
values: [checkpoints_timed]
query: |
SELECT checkpoints_timed FROM pg_stat_bgwriter;
- metric_name: compute_logical_snapshot_files
type: gauge
help: 'Number of snapshot files in pg_logical/snapshot'
key_labels:
- timeline_id
values: [num_logical_snapshot_files]
query: |
SELECT
(SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id,
-- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. These
-- temporary snapshot files are renamed to the actual snapshot files after they are
-- completely built. We only WAL-log the completely built snapshot files.
(SELECT COUNT(*) FROM pg_ls_dir('pg_logical/snapshots') AS name WHERE name LIKE '%.snap') AS num_logical_snapshot_files;
# In all the below metrics, we cast LSNs to floats because Prometheus only supports floats.
# It's probably fine because float64 can store integers from -2^53 to +2^53 exactly.
# Number of slots is limited by max_replication_slots, so collecting position for all of them shouldn't be bad.
- metric_name: logical_slot_restart_lsn
type: gauge
help: 'restart_lsn of logical slots'
key_labels:
- slot_name
values: [restart_lsn]
query: |
select slot_name, (restart_lsn - '0/0')::FLOAT8 as restart_lsn
from pg_replication_slots
where slot_type = 'logical';
- metric_name: compute_subscriptions_count
type: gauge
help: 'Number of logical replication subscriptions grouped by enabled/disabled'
key_labels:
- enabled
values: [subscriptions_count]
query: |
select subenabled::text as enabled, count(*) as subscriptions_count
from pg_subscription
group by subenabled;
- metric_name: retained_wal
type: gauge
help: 'Retained WAL in inactive replication slots'
key_labels:
- slot_name
values: [retained_wal]
query: |
SELECT slot_name, pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn)::FLOAT8 AS retained_wal
FROM pg_replication_slots
WHERE active = false;
- metric_name: wal_is_lost
type: gauge
help: 'Whether or not the replication slot wal_status is lost'
key_labels:
- slot_name
values: [wal_is_lost]
query: |
SELECT slot_name,
CASE WHEN wal_status = 'lost' THEN 1 ELSE 0 END AS wal_is_lost
FROM pg_replication_slots;

View File

@@ -1,55 +0,0 @@
collector_name: neon_collector_autoscaling
metrics:
- metric_name: lfc_misses
type: gauge
help: 'lfc_misses'
key_labels:
values: [lfc_misses]
query: |
select lfc_value as lfc_misses from neon.neon_lfc_stats where lfc_key='file_cache_misses';
- metric_name: lfc_used
type: gauge
help: 'LFC chunks used (chunk = 1MB)'
key_labels:
values: [lfc_used]
query: |
select lfc_value as lfc_used from neon.neon_lfc_stats where lfc_key='file_cache_used';
- metric_name: lfc_hits
type: gauge
help: 'lfc_hits'
key_labels:
values: [lfc_hits]
query: |
select lfc_value as lfc_hits from neon.neon_lfc_stats where lfc_key='file_cache_hits';
- metric_name: lfc_writes
type: gauge
help: 'lfc_writes'
key_labels:
values: [lfc_writes]
query: |
select lfc_value as lfc_writes from neon.neon_lfc_stats where lfc_key='file_cache_writes';
- metric_name: lfc_cache_size_limit
type: gauge
help: 'LFC cache size limit in bytes'
key_labels:
values: [lfc_cache_size_limit]
query: |
select pg_size_bytes(current_setting('neon.file_cache_size_limit')) as lfc_cache_size_limit;
- metric_name: lfc_approximate_working_set_size_windows
type: gauge
help: 'Approximate working set size in pages of 8192 bytes'
key_labels: [duration_seconds]
values: [size]
# NOTE: This is the "internal" / "machine-readable" version. This outputs the working set
# size looking back 1..60 minutes, labeled with the number of minutes.
query: |
select
x::text as duration_seconds,
neon.approximate_working_set_size_seconds(x) as size
from
(select generate_series * 60 as x from generate_series(1, 60)) as t (x);

View File

@@ -1,17 +0,0 @@
[databases]
*=host=localhost port=5432 auth_user=cloud_admin
[pgbouncer]
listen_port=6432
listen_addr=0.0.0.0
auth_type=scram-sha-256
auth_user=cloud_admin
auth_dbname=postgres
client_tls_sslmode=disable
server_tls_sslmode=disable
pool_mode=transaction
max_client_conn=10000
default_pool_size=64
max_prepared_statements=0
admin_users=postgres
unix_socket_dir=/tmp/
unix_socket_mode=0777

View File

@@ -1,33 +0,0 @@
# Configuration for sql_exporter
# Global defaults.
global:
# If scrape_timeout <= 0, no timeout is set unless Prometheus provides one. The default is 10s.
scrape_timeout: 10s
# Subtracted from Prometheus' scrape_timeout to give us some headroom and prevent Prometheus from timing out first.
scrape_timeout_offset: 500ms
# Minimum interval between collector runs: by default (0s) collectors are executed on every scrape.
min_interval: 0s
# Maximum number of open connections to any one target. Metric queries will run concurrently on multiple connections,
# as will concurrent scrapes.
max_connections: 1
# Maximum number of idle connections to any one target. Unless you use very long collection intervals, this should
# always be the same as max_connections.
max_idle_connections: 1
# Maximum number of maximum amount of time a connection may be reused. Expired connections may be closed lazily before reuse.
# If 0, connections are not closed due to a connection's age.
max_connection_lifetime: 5m
# The target to monitor and the collectors to execute on it.
target:
# Data source name always has a URI schema that matches the driver name. In some cases (e.g. MySQL)
# the schema gets dropped or replaced to match the driver expected DSN format.
data_source_name: 'postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter'
# Collectors (referenced by name) to execute on the target.
# Glob patterns are supported (see <https://pkg.go.dev/path/filepath#Match> for syntax).
collectors: [neon_collector]
# Collector files specifies a list of globs. One collector definition is read from each matching file.
# Glob patterns are supported (see <https://pkg.go.dev/path/filepath#Match> for syntax).
collector_files:
- "neon_collector.yml"

View File

@@ -1,33 +0,0 @@
# Configuration for sql_exporter for autoscaling-agent
# Global defaults.
global:
# If scrape_timeout <= 0, no timeout is set unless Prometheus provides one. The default is 10s.
scrape_timeout: 10s
# Subtracted from Prometheus' scrape_timeout to give us some headroom and prevent Prometheus from timing out first.
scrape_timeout_offset: 500ms
# Minimum interval between collector runs: by default (0s) collectors are executed on every scrape.
min_interval: 0s
# Maximum number of open connections to any one target. Metric queries will run concurrently on multiple connections,
# as will concurrent scrapes.
max_connections: 1
# Maximum number of idle connections to any one target. Unless you use very long collection intervals, this should
# always be the same as max_connections.
max_idle_connections: 1
# Maximum number of maximum amount of time a connection may be reused. Expired connections may be closed lazily before reuse.
# If 0, connections are not closed due to a connection's age.
max_connection_lifetime: 5m
# The target to monitor and the collectors to execute on it.
target:
# Data source name always has a URI schema that matches the driver name. In some cases (e.g. MySQL)
# the schema gets dropped or replaced to match the driver expected DSN format.
data_source_name: 'postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter_autoscaling'
# Collectors (referenced by name) to execute on the target.
# Glob patterns are supported (see <https://pkg.go.dev/path/filepath#Match> for syntax).
collectors: [neon_collector_autoscaling]
# Collector files specifies a list of globs. One collector definition is read from each matching file.
# Glob patterns are supported (see <https://pkg.go.dev/path/filepath#Match> for syntax).
collector_files:
- "neon_collector_autoscaling.yml"

File diff suppressed because it is too large Load Diff

View File

@@ -1,112 +0,0 @@
# Supplemental file for neondatabase/autoscaling's vm-builder, for producing the VM compute image.
---
commands:
- name: cgconfigparser
user: root
sysvInitAction: sysinit
shell: 'cgconfigparser -l /etc/cgconfig.conf -s 1664'
# restrict permissions on /neonvm/bin/resize-swap, because we grant access to compute_ctl for
# running it as root.
- name: chmod-resize-swap
user: root
sysvInitAction: sysinit
shell: 'chmod 711 /neonvm/bin/resize-swap'
- name: pgbouncer
user: postgres
sysvInitAction: respawn
shell: '/usr/local/bin/pgbouncer /etc/pgbouncer.ini'
- name: postgres-exporter
user: nobody
sysvInitAction: respawn
shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter" /bin/postgres_exporter'
- name: sql-exporter
user: nobody
sysvInitAction: respawn
shell: '/bin/sql_exporter -config.file=/etc/sql_exporter.yml -web.listen-address=:9399'
- name: sql-exporter-autoscaling
user: nobody
sysvInitAction: respawn
shell: '/bin/sql_exporter -config.file=/etc/sql_exporter_autoscaling.yml -web.listen-address=:9499'
shutdownHook: |
su -p postgres --session-command '/usr/local/bin/pg_ctl stop -D /var/db/postgres/compute/pgdata -m fast --wait -t 10'
files:
- filename: compute_ctl-resize-swap
content: |
# Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap
# as root without requiring entering a password (NOPASSWD), regardless of hostname (ALL)
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
group neon-postgres {
perm {
admin {
uid = postgres;
}
task {
gid = users;
}
}
memory {}
}
build: |
# Build cgroup-tools
#
# At time of writing (2023-03-14), debian bullseye has a version of cgroup-tools (technically
# libcgroup) that doesn't support cgroup v2 (version 0.41-11). Unfortunately, the vm-monitor
# requires cgroup v2, so we'll build cgroup-tools ourselves.
FROM debian:bullseye-slim as libcgroup-builder
ENV LIBCGROUP_VERSION=v2.0.3
RUN set -exu \
&& apt update \
&& apt install --no-install-recommends -y \
git \
ca-certificates \
automake \
cmake \
make \
gcc \
byacc \
flex \
libtool \
libpam0g-dev \
&& git clone --depth 1 -b $LIBCGROUP_VERSION https://github.com/libcgroup/libcgroup \
&& INSTALL_DIR="/libcgroup-install" \
&& mkdir -p "$INSTALL_DIR/bin" "$INSTALL_DIR/include" \
&& cd libcgroup \
# extracted from bootstrap.sh, with modified flags:
&& (test -d m4 || mkdir m4) \
&& autoreconf -fi \
&& rm -rf autom4te.cache \
&& CFLAGS="-O3" ./configure --prefix="$INSTALL_DIR" --sysconfdir=/etc --localstatedir=/var --enable-opaque-hierarchy="name=systemd" \
# actually build the thing...
&& make install
merge: |
# tweak nofile limits
RUN set -e \
&& echo 'fs.file-max = 1048576' >>/etc/sysctl.conf \
&& test ! -e /etc/security || ( \
echo '* - nofile 1048576' >>/etc/security/limits.conf \
&& echo 'root - nofile 1048576' >>/etc/security/limits.conf \
)
# Allow postgres user (compute_ctl) to run swap resizer.
# Need to install sudo in order to allow this.
#
# Also, remove the 'read' permission from group/other on /neonvm/bin/resize-swap, just to be safe.
RUN set -e \
&& apt update \
&& apt install --no-install-recommends -y \
sudo \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
COPY compute_ctl-resize-swap /etc/sudoers.d/compute_ctl-resize-swap
COPY cgconfig.conf /etc/cgconfig.conf
RUN set -e \
&& chmod 0644 /etc/cgconfig.conf
COPY --from=libcgroup-builder /libcgroup-install/bin/* /usr/bin/
COPY --from=libcgroup-builder /libcgroup-install/lib/* /usr/lib/
COPY --from=libcgroup-builder /libcgroup-install/sbin/* /usr/sbin/

View File

@@ -11,17 +11,9 @@ use crate::compute::ComputeNode;
fn configurator_main_loop(compute: &Arc<ComputeNode>) {
info!("waiting for reconfiguration requests");
loop {
let mut state = compute.state.lock().unwrap();
let state = compute.state.lock().unwrap();
let mut state = compute.state_changed.wait(state).unwrap();
// We have to re-check the status after re-acquiring the lock because it could be that
// the status has changed while we were waiting for the lock, and we might not need to
// wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e.
// we are waiting for a condition variable that will never be signaled.
if state.status != ComputeStatus::ConfigurationPending {
state = compute.state_changed.wait(state).unwrap();
}
// Re-check the status after waking up
if state.status == ComputeStatus::ConfigurationPending {
info!("got configuration request");
state.status = ComputeStatus::Configuration;

View File

@@ -9,6 +9,7 @@ anyhow.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
git-version.workspace = true
humantime.workspace = true
nix.workspace = true
once_cell.workspace = true

View File

@@ -346,14 +346,7 @@ impl StorageController {
let pg_log_path = pg_data_path.join("postgres.log");
if !tokio::fs::try_exists(&pg_data_path).await? {
let initdb_args = [
"-D",
pg_data_path.as_ref(),
"--username",
&username(),
"--no-sync",
"--no-instructions",
];
let initdb_args = ["-D", pg_data_path.as_ref(), "--username", &username()];
tracing::info!(
"Initializing storage controller database with args: {:?}",
initdb_args

View File

@@ -4,8 +4,8 @@ use std::{str::FromStr, time::Duration};
use clap::{Parser, Subcommand};
use pageserver_api::{
controller_api::{
AvailabilityZone, NodeAvailabilityWrapper, NodeDescribeResponse, NodeShardResponse,
ShardSchedulingPolicy, TenantCreateRequest, TenantDescribeResponse, TenantPolicyRequest,
NodeAvailabilityWrapper, NodeDescribeResponse, NodeShardResponse, ShardSchedulingPolicy,
TenantCreateRequest, TenantDescribeResponse, TenantPolicyRequest,
},
models::{
EvictionPolicy, EvictionPolicyLayerAccessThreshold, LocationConfigSecondary,
@@ -339,7 +339,7 @@ async fn main() -> anyhow::Result<()> {
listen_pg_port,
listen_http_addr,
listen_http_port,
availability_zone_id: AvailabilityZone(availability_zone_id),
availability_zone_id,
}),
)
.await?;

View File

@@ -2,8 +2,8 @@
# Example docker compose configuration
The configuration in this directory is used for testing Neon docker images: it is
not intended for deploying a usable system. To run a development environment where
you can experiment with a miniature Neon system, use `cargo neon` rather than container images.
not intended for deploying a usable system. To run a development environment where
you can experiment with a minature Neon system, use `cargo neon` rather than container images.
This configuration does not start the storage controller, because the controller
needs a way to reconfigure running computes, and no such thing exists in this setup.

View File

@@ -1,343 +0,0 @@
# Independent compute release
Created at: 2024-08-30. Author: Alexey Kondratov (@ololobus)
## Summary
This document proposes an approach to fully independent compute release flow. It attempts to
cover the following features:
- Process is automated as much as possible to minimize human errors.
- Compute<->storage protocol compatibility is ensured.
- A transparent release history is available with an easy rollback strategy.
- Although not in the scope of this document, there is a viable way to extend the proposed release
flow to achieve the canary and/or blue-green deployment strategies.
## Motivation
Previously, the compute release was tightly coupled to the storage release. This meant that once
some storage nodes got restarted with a newer version, all new compute starts using these nodes
automatically got a new version. Thus, two releases happen in parallel, which increases the blast
radius and makes ownership fuzzy.
Now, we practice a manual v0 independent compute release flow -- after getting a new compute release
image and tag, we pin it region by region using Admin UI. It's better, but it still has its own flaws:
1. It's a simple but fairly manual process, as you need to click through a few pages.
2. It's prone to human errors, e.g., you could mistype or copy the wrong compute tag.
3. We now require an additional approval in the Admin UI, which partially solves the 2.,
but also makes the whole process pretty annoying, as you constantly need to go back
and forth between two people.
## Non-goals
It's not the goal of this document to propose a design for some general-purpose release tool like Helm.
The document considers how the current compute fleet is orchestrated at Neon. Even if we later
decide to split the control plane further (e.g., introduce a separate compute controller), the proposed
release process shouldn't change much, i.e., the releases table and API will reside in
one of the parts.
Achieving the canary and/or blue-green deploy strategies is out of the scope of this document. They
were kept in mind, though, so it's expected that the proposed approach will lay down the foundation
for implementing them in future iterations.
## Impacted components
Compute, control plane, CI, observability (some Grafana dashboards may require changes).
## Prior art
One of the very close examples is how Helm tracks [releases history](https://helm.sh/docs/helm/helm_history/).
In the code:
- [Release](https://github.com/helm/helm/blob/2b30cf4b61d587d3f7594102bb202b787b9918db/pkg/release/release.go#L20-L43)
- [Release info](https://github.com/helm/helm/blob/2b30cf4b61d587d3f7594102bb202b787b9918db/pkg/release/info.go#L24-L40)
- [Release status](https://github.com/helm/helm/blob/2b30cf4b61d587d3f7594102bb202b787b9918db/pkg/release/status.go#L18-L42)
TL;DR it has several important attributes:
- Revision -- unique release ID/primary key. It is not the same as the application version,
because the same version can be deployed several times, e.g., after a newer version rollback.
- App version -- version of the application chart/code.
- Config -- set of overrides to the default config of the application.
- Status -- current status of the release in the history.
- Timestamps -- tracks when a release was created and deployed.
## Proposed implementation
### Separate release branch
We will use a separate release branch, `release-compute`, to have a clean history for releases and commits.
In order to avoid confusion with storage releases, we will use a different prefix for compute [git release
tags](https://github.com/neondatabase/neon/releases) -- `release-compute-XXXX`. We will use the same tag for
Docker images as well. The `neondatabase/compute-node-v16:release-compute-XXXX` looks longer and a bit redundant,
but it's better to have image and git tags in sync.
Currently, control plane relies on the numeric compute and storage release versions to decide on compute->storage
compatibility. Once we implement this proposal, we should drop this code as release numbers will be completely
independent. The only constraint we want is that it must monotonically increase within the same release branch.
### Compute config/settings manifest
We will create a new sub-directory `compute` and file `compute/manifest.yaml` with a structure:
```yaml
pg_settings:
# Common settings for primaries and secondaries of all versions.
common:
wal_log_hints: "off"
max_wal_size: "1024"
per_version:
14:
# Common settings for both replica and primary of version PG 14
common:
shared_preload_libraries: "neon,pg_stat_statements,extension_x"
15:
common:
shared_preload_libraries: "neon,pg_stat_statements,extension_x"
# Settings that should be applied only to
replica:
# Available only starting Postgres 15th
recovery_prefetch: "off"
# ...
17:
common:
# For example, if third-party `extension_x` is not yet available for PG 17
shared_preload_libraries: "neon,pg_stat_statements"
replica:
recovery_prefetch: "off"
```
**N.B.** Setting value should be a string with `on|off` for booleans and a number (as a string)
without units for all numeric settings. That's how the control plane currently operates.
The priority of settings will be (a higher number is a higher priority):
1. Any static and hard-coded settings in the control plane
2. `pg_settings->common`
3. Per-version `common`
4. Per-version `replica`
5. Any per-user/project/endpoint overrides in the control plane
6. Any dynamic setting calculated based on the compute size
**N.B.** For simplicity, we do not do any custom logic for `shared_preload_libraries`, so it's completely
overridden if specified on some level. Make sure that you include all necessary extensions in it when you
do any overrides.
**N.B.** There is a tricky question about what to do with custom compute image pinning we sometimes
do for particular projects and customers. That's usually some ad-hoc work and images are based on
the latest compute image, so it's relatively safe to assume that we could use settings from the latest compute
release. If for some reason that's not true, and further overrides are needed, it's also possible to do
on the project level together with pinning the image, so it's on-call/engineer/support responsibility to
ensure that compute starts with the specified custom image. The only real risk is that compute image will get
stale and settings from new releases will drift away, so eventually it will get something incompatible,
but i) this is some operational issue, as we do not want stale images anyway, and ii) base settings
receive something really new so rarely that the chance of this happening is very low. If we want to solve it completely,
then together with pinning the image we could also pin the matching release revision in the control plane.
The compute team will own the content of `compute/manifest.yaml`.
### Control plane: releases table
In order to store information about releases, the control plane will use a table `compute_releases` with the following
schema:
```sql
CREATE TABLE compute_releases (
-- Unique release ID
-- N.B. Revision won't by synchronized across all regions, because all control planes are technically independent
-- services. We have the same situation with Helm releases as well because they could be deployed and rolled back
-- independently in different clusters.
revision BIGSERIAL PRIMARY KEY,
-- Numeric version of the compute image, e.g. 9057
version BIGINT NOT NULL,
-- Compute image tag, e.g. `release-9057`
tag TEXT NOT NULL,
-- Current release status. Currently, it will be a simple enum
-- * `deployed` -- release is deployed and used for new compute starts.
-- Exactly one release can have this status at a time.
-- * `superseded` -- release has been replaced by a newer one.
-- But we can always extend it in the future when we need more statuses
-- for more complex deployment strategies.
status TEXT NOT NULL,
-- Any additional metadata for compute in the corresponding release
manifest JSONB NOT NULL,
-- Timestamp when release record was created in the control plane database
created_at TIMESTAMP NOT NULL DEFAULT now(),
-- Timestamp when release deployment was finished
deployed_at TIMESTAMP
);
```
We keep track of the old releases not only for the sake of audit, but also because we usually have ~30% of
old computes started using the image from one of the previous releases. Yet, when users want to reconfigure
them without restarting, the control plane needs to know what settings are applicable to them, so we also need
information about the previous releases that are readily available. There could be some other auxiliary info
needed as well: supported extensions, compute flags, etc.
**N.B.** Here, we can end up in an ambiguous situation when the same compute image is deployed twice, e.g.,
it was deployed once, then rolled back, and then deployed again, potentially with a different manifest. Yet,
we could've started some computes with the first deployment and some with the second. Thus, when we need to
look up the manifest for the compute by its image tag, we will see two records in the table with the same tag,
but different revision numbers. We can assume that this could happen only in case of rollbacks, so we
can just take the latest revision for the given tag.
### Control plane: management API
The control plane will implement new API methods to manage releases:
1. `POST /management/api/v2/compute_releases` to create a new release. With payload
```json
{
"version": 9057,
"tag": "release-9057",
"manifest": {}
}
```
and response
```json
{
"revision": 53,
"version": 9057,
"tag": "release-9057",
"status": "deployed",
"manifest": {},
"created_at": "2024-08-15T15:52:01.0000Z",
"deployed_at": "2024-08-15T15:52:01.0000Z",
}
```
Here, we can actually mix-in custom (remote) extensions metadata into the `manifest`, so that the control plane
will get information about all available extensions not bundled into compute image. The corresponding
workflow in `neondatabase/build-custom-extensions` should produce it as an artifact and make
it accessible to the workflow in the `neondatabase/infra`. See the complete release flow below. Doing that,
we put a constraint that new custom extension requires new compute release, which is good for the safety,
but is not exactly what we want operational-wise (we want to be able to deploy new extensions without new
images). Yet, it can be solved incrementally: v0 -- do not do anything with extensions at all;
v1 -- put them into the same manifest; v2 -- make them separate entities with their own lifecycle.
**N.B.** This method is intended to be used in CI workflows, and CI/network can be flaky. It's reasonable
to assume that we could retry the request several times, even though it's already succeeded. Although it's
not a big deal to create several identical releases one-by-one, it's better to avoid it, so the control plane
should check if the latest release is identical and just return `304 Not Modified` in this case.
2. `POST /management/api/v2/compute_releases/rollback` to rollback to any previously deployed release. With payload
including the revision of the release to rollback to:
```json
{
"revision": 52
}
```
Rollback marks the current release as `superseded` and creates a new release with all the same data as the
requested revision, but with a new revision number.
This rollback API is not strictly needed, as we can just use `infra` repo workflow to deploy any
available tag. It's still nice to have for on-call and any urgent matters, for example, if we need
to rollback and GitHub is down. It's much easier to specify only the revision number vs. crafting
all the necessary data for the new release payload.
### Compute->storage compatibility tests
In order to safely release new compute versions independently from storage, we need to ensure that the currently
deployed storage is compatible with the new compute version. Currently, we maintain backward compatibility
in storage, but newer computes may require a newer storage version.
Remote end-to-end (e2e) tests [already accept](https://github.com/neondatabase/cloud/blob/e3468d433e0d73d02b7d7e738d027f509b522408/.github/workflows/testing.yml#L43-L48)
`storage_image_tag` and `compute_image_tag` as separate inputs. That means that we could reuse e2e tests to ensure
compatibility between storage and compute:
1. Pick the latest storage release tag and use it as `storage_image_tag`.
2. Pick a new compute tag built in the current compute release PR and use it as `compute_image_tag`.
Here, we should use a temporary ECR image tag, because the final tag will be known only after the release PR is merged.
3. Trigger e2e tests as usual.
### Release flow
```mermaid
sequenceDiagram
actor oncall as Compute on-call person
participant neon as neondatabase/neon
box private
participant cloud as neondatabase/cloud
participant exts as neondatabase/build-custom-extensions
participant infra as neondatabase/infra
end
box cloud
participant preprod as Pre-prod control plane
participant prod as Production control plane
participant k8s as Compute k8s
end
oncall ->> neon: Open release PR into release-compute
activate neon
neon ->> cloud: CI: trigger e2e compatibility tests
activate cloud
cloud -->> neon: CI: e2e tests pass
deactivate cloud
neon ->> neon: CI: pass PR checks, get approvals
deactivate neon
oncall ->> neon: Merge release PR into release-compute
activate neon
neon ->> neon: CI: pass checks, build and push images
neon ->> exts: CI: trigger extensions build
activate exts
exts -->> neon: CI: extensions are ready
deactivate exts
neon ->> neon: CI: create release tag
neon ->> infra: Trigger release workflow using the produced tag
deactivate neon
activate infra
infra ->> infra: CI: pass checks
infra ->> preprod: Release new compute image to pre-prod automatically <br/> POST /management/api/v2/compute_releases
activate preprod
preprod -->> infra: 200 OK
deactivate preprod
infra ->> infra: CI: wait for per-region production deploy approvals
oncall ->> infra: CI: approve deploys region by region
infra ->> k8s: Prewarm new compute image
infra ->> prod: POST /management/api/v2/compute_releases
activate prod
prod -->> infra: 200 OK
deactivate prod
deactivate infra
```
## Further work
As briefly mentioned in other sections, eventually, we would like to use more complex deployment strategies.
For example, we can pass a fraction of the total compute starts that should use the new release. Then we can
mark the release as `partial` or `canary` and monitor its performance. If everything is fine, we can promote it
to `deployed` status. If not, we can roll back to the previous one.
## Alternatives
In theory, we can try using Helm as-is:
1. Write a compute Helm chart. That will actually have only some config map, which the control plane can access and read.
N.B. We could reuse the control plane chart as well, but then it's not a fully independent release again and even more fuzzy.
2. The control plane will read it and start using the new compute version for new starts.
Drawbacks:
1. Helm releases work best if the workload is controlled by the Helm chart itself. Then you can have different
deployment strategies like rolling update or canary or blue/green deployments. At Neon, the compute starts are controlled
by control plane, so it makes it much more tricky.
2. Releases visibility will suffer, i.e. instead of a nice table in the control plane and Admin UI, we would need to use
`helm` cli and/or K8s UIs like K8sLens.
3. We do not restart all computes shortly after the new version release. This means that for some features and compatibility
purpose (see above) control plane may need some auxiliary info from the previous releases.

View File

@@ -268,22 +268,6 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Configured the local-proxy application with the relevant JWKS and roles it should
/// use for authorizing connect requests using JWT.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LocalProxySpec {
pub jwks: Vec<JwksSettings>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwksSettings {
pub id: String,
pub role_names: Vec<String>,
pub jwks_url: String,
pub provider_name: String,
pub jwt_audience: Option<String>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,5 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::str::FromStr;
use std::time::{Duration, Instant};
@@ -58,7 +57,7 @@ pub struct NodeRegisterRequest {
pub listen_http_addr: String,
pub listen_http_port: u16,
pub availability_zone_id: AvailabilityZone,
pub availability_zone_id: String,
}
#[derive(Serialize, Deserialize)]
@@ -75,19 +74,10 @@ pub struct TenantPolicyRequest {
pub scheduling: Option<ShardSchedulingPolicy>,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AvailabilityZone(pub String);
impl Display for AvailabilityZone {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Serialize, Deserialize)]
pub struct ShardsPreferredAzsRequest {
#[serde(flatten)]
pub preferred_az_ids: HashMap<TenantShardId, AvailabilityZone>,
pub preferred_az_ids: HashMap<TenantShardId, String>,
}
#[derive(Serialize, Deserialize)]

View File

@@ -37,11 +37,14 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
/// ```mermaid
/// stateDiagram-v2
///
/// [*] --> Loading: spawn_load()
/// [*] --> Attaching: spawn_attach()
///
/// Loading --> Activating: activate()
/// Attaching --> Activating: activate()
/// Activating --> Active: infallible
///
/// Loading --> Broken: load() failure
/// Attaching --> Broken: attach() failure
///
/// Active --> Stopping: set_stopping(), part of shutdown & detach
@@ -65,6 +68,10 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
)]
#[serde(tag = "slug", content = "data")]
pub enum TenantState {
/// This tenant is being loaded from local disk.
///
/// `set_stopping()` and `set_broken()` do not work in this state and wait for it to pass.
Loading,
/// This tenant is being attached to the pageserver.
///
/// `set_stopping()` and `set_broken()` do not work in this state and wait for it to pass.
@@ -114,6 +121,8 @@ impl TenantState {
// But, our attach task might still be fetching the remote timelines, etc.
// So, return `Maybe` while Attaching, making Console wait for the attach task to finish.
Self::Attaching | Self::Activating(ActivatingFrom::Attaching) => Maybe,
// tenant mgr startup distinguishes attaching from loading via marker file.
Self::Loading | Self::Activating(ActivatingFrom::Loading) => Attached,
// We only reach Active after successful load / attach.
// So, call atttachment status Attached.
Self::Active => Attached,
@@ -182,11 +191,10 @@ impl LsnLease {
}
/// The only [`TenantState`] variants we could be `TenantState::Activating` from.
///
/// XXX: We used to have more variants here, but now it's just one, which makes this rather
/// useless. Remove, once we've checked that there's no client code left that looks at this.
#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ActivatingFrom {
/// Arrived to [`TenantState::Activating`] from [`TenantState::Loading`]
Loading,
/// Arrived to [`TenantState::Activating`] from [`TenantState::Attaching`]
Attaching,
}
@@ -1554,8 +1562,11 @@ mod tests {
#[test]
fn tenantstatus_activating_serde() {
let states = [TenantState::Activating(ActivatingFrom::Attaching)];
let expected = "[{\"slug\":\"Activating\",\"data\":\"Attaching\"}]";
let states = [
TenantState::Activating(ActivatingFrom::Loading),
TenantState::Activating(ActivatingFrom::Attaching),
];
let expected = "[{\"slug\":\"Activating\",\"data\":\"Loading\"},{\"slug\":\"Activating\",\"data\":\"Attaching\"}]";
let actual = serde_json::to_string(&states).unwrap();
@@ -1570,7 +1581,13 @@ mod tests {
fn tenantstatus_activating_strum() {
// tests added, because we use these for metrics
let examples = [
(line!(), TenantState::Loading, "Loading"),
(line!(), TenantState::Attaching, "Attaching"),
(
line!(),
TenantState::Activating(ActivatingFrom::Loading),
"Activating",
),
(
line!(),
TenantState::Activating(ActivatingFrom::Attaching),

View File

@@ -984,7 +984,6 @@ pub fn short_error(e: &QueryError) -> String {
}
fn log_query_error(query: &str, e: &QueryError) {
// If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {

View File

@@ -19,7 +19,6 @@ bincode.workspace = true
bytes.workspace = true
camino.workspace = true
chrono.workspace = true
git-version.workspace = true
hex = { workspace = true, features = ["serde"] }
humantime.workspace = true
hyper = { workspace = true, features = ["full"] }

View File

@@ -92,10 +92,6 @@ pub mod toml_edit_ext;
pub mod circuit_breaker;
// Re-export used in macro. Avoids adding git-version as dep in target crates.
#[doc(hidden)]
pub use git_version;
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
///
/// we have several cases:
@@ -135,7 +131,7 @@ macro_rules! project_git_version {
($const_identifier:ident) => {
// this should try GIT_VERSION first only then git_version::git_version!
const $const_identifier: &::core::primitive::str = {
const __COMMIT_FROM_GIT: &::core::primitive::str = $crate::git_version::git_version! {
const __COMMIT_FROM_GIT: &::core::primitive::str = git_version::git_version! {
prefix = "",
fallback = "unknown",
args = ["--abbrev=40", "--always", "--dirty=-modified"] // always use full sha

View File

@@ -27,6 +27,7 @@ crc32c.workspace = true
either.workspace = true
fail.workspace = true
futures.workspace = true
git-version.workspace = true
hex.workspace = true
humantime.workspace = true
humantime-serde.workspace = true

View File

@@ -12,6 +12,7 @@ anyhow.workspace = true
async-stream.workspace = true
clap = { workspace = true, features = ["string"] }
futures.workspace = true
git-version.workspace = true
itertools.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true

View File

@@ -10,6 +10,7 @@ license.workspace = true
anyhow.workspace = true
camino.workspace = true
clap = { workspace = true, features = ["string"] }
git-version.workspace = true
humantime.workspace = true
pageserver = { path = ".." }
pageserver_api.workspace = true

View File

@@ -15,7 +15,7 @@ use clap::{Arg, ArgAction, Command};
use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::config::PageserverIdentity;
use pageserver::controller_upcall_client::ControllerUpcallClient;
use pageserver::control_plane_client::ControlPlaneClient;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::{COMPUTE_REQUEST_RUNTIME, WALRECEIVER_RUNTIME};
@@ -396,7 +396,7 @@ fn start_pageserver(
// Set up deletion queue
let (deletion_queue, deletion_workers) = DeletionQueue::new(
remote_storage.clone(),
ControllerUpcallClient::new(conf, &shutdown_pageserver),
ControlPlaneClient::new(conf, &shutdown_pageserver),
conf,
);
if let Some(deletion_workers) = deletion_workers {

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use futures::Future;
use pageserver_api::{
controller_api::{AvailabilityZone, NodeRegisterRequest},
controller_api::NodeRegisterRequest,
shard::TenantShardId,
upcall_api::{
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest,
@@ -17,12 +17,9 @@ use utils::{backoff, failpoint_support, generation::Generation, id::NodeId};
use crate::{config::PageServerConf, virtual_file::on_fatal_io_error};
use pageserver_api::config::NodeMetadata;
/// The Pageserver's client for using the storage controller upcall API: this is a small API
/// for dealing with generations (see docs/rfcs/025-generation-numbers.md).
///
/// The server presenting this API may either be the storage controller or some other
/// service (such as the Neon control plane) providing a store of generation numbers.
pub struct ControllerUpcallClient {
/// The Pageserver's client for using the control plane API: this is a small subset
/// of the overall control plane API, for dealing with generations (see docs/rfcs/025-generation-numbers.md)
pub struct ControlPlaneClient {
http_client: reqwest::Client,
base_url: Url,
node_id: NodeId,
@@ -48,7 +45,7 @@ pub trait ControlPlaneGenerationsApi {
) -> impl Future<Output = Result<HashMap<TenantShardId, bool>, RetryForeverError>> + Send;
}
impl ControllerUpcallClient {
impl ControlPlaneClient {
/// A None return value indicates that the input `conf` object does not have control
/// plane API enabled.
pub fn new(conf: &'static PageServerConf, cancel: &CancellationToken) -> Option<Self> {
@@ -117,7 +114,7 @@ impl ControllerUpcallClient {
}
}
impl ControlPlaneGenerationsApi for ControllerUpcallClient {
impl ControlPlaneGenerationsApi for ControlPlaneClient {
/// Block until we get a successful response, or error out if we are shut down
async fn re_attach(
&self,
@@ -151,10 +148,10 @@ impl ControlPlaneGenerationsApi for ControllerUpcallClient {
.and_then(|jv| jv.as_str().map(|str| str.to_owned()));
match az_id_from_metadata {
Some(az_id) => Some(AvailabilityZone(az_id)),
Some(az_id) => Some(az_id),
None => {
tracing::warn!("metadata.json does not contain an 'availability_zone_id' field");
conf.availability_zone.clone().map(AvailabilityZone)
conf.availability_zone.clone()
}
}
};
@@ -219,38 +216,29 @@ impl ControlPlaneGenerationsApi for ControllerUpcallClient {
.join("validate")
.expect("Failed to build validate path");
// When sending validate requests, break them up into chunks so that we
// avoid possible edge cases of generating any HTTP requests that
// require database I/O across many thousands of tenants.
let mut result: HashMap<TenantShardId, bool> = HashMap::with_capacity(tenants.len());
for tenant_chunk in (tenants).chunks(128) {
let request = ValidateRequest {
tenants: tenant_chunk
.iter()
.map(|(id, generation)| ValidateRequestTenant {
id: *id,
gen: (*generation).into().expect(
"Generation should always be valid for a Tenant doing deletions",
),
})
.collect(),
};
let request = ValidateRequest {
tenants: tenants
.into_iter()
.map(|(id, gen)| ValidateRequestTenant {
id,
gen: gen
.into()
.expect("Generation should always be valid for a Tenant doing deletions"),
})
.collect(),
};
failpoint_support::sleep_millis_async!(
"control-plane-client-validate-sleep",
&self.cancel
);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
}
let response: ValidateResponse =
self.retry_http_forever(&re_attach_path, request).await?;
for rt in response.tenants {
result.insert(rt.id, rt.valid);
}
failpoint_support::sleep_millis_async!("control-plane-client-validate-sleep", &self.cancel);
if self.cancel.is_cancelled() {
return Err(RetryForeverError::ShuttingDown);
}
Ok(result.into_iter().collect())
let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?;
Ok(response
.tenants
.into_iter()
.map(|rt| (rt.id, rt.valid))
.collect())
}
}

View File

@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::metrics;
use crate::tenant::remote_timeline_client::remote_layer_path;
use crate::tenant::remote_timeline_client::remote_timeline_path;
@@ -622,7 +622,7 @@ impl DeletionQueue {
/// If remote_storage is None, then the returned workers will also be None.
pub fn new<C>(
remote_storage: GenericRemoteStorage,
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
conf: &'static PageServerConf,
) -> (Self, Option<DeletionQueueWorkers<C>>)
where
@@ -662,7 +662,7 @@ impl DeletionQueue {
conf,
backend_rx,
executor_tx,
controller_upcall_client,
control_plane_client,
lsn_table.clone(),
cancel.clone(),
),
@@ -704,7 +704,7 @@ mod test {
use tokio::task::JoinHandle;
use crate::{
controller_upcall_client::RetryForeverError,
control_plane_client::RetryForeverError,
repository::Key,
tenant::{harness::TenantHarness, storage_layer::DeltaLayerName},
};

View File

@@ -25,8 +25,8 @@ use tracing::info;
use tracing::warn;
use crate::config::PageServerConf;
use crate::controller_upcall_client::ControlPlaneGenerationsApi;
use crate::controller_upcall_client::RetryForeverError;
use crate::control_plane_client::ControlPlaneGenerationsApi;
use crate::control_plane_client::RetryForeverError;
use crate::metrics;
use crate::virtual_file::MaybeFatalIo;
@@ -61,7 +61,7 @@ where
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
// Client for calling into control plane API for validation of deletes
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
// DeletionLists which are waiting generation validation. Not safe to
// execute until [`validate`] has processed them.
@@ -94,7 +94,7 @@ where
conf: &'static PageServerConf,
rx: tokio::sync::mpsc::Receiver<ValidatorQueueMessage>,
tx: tokio::sync::mpsc::Sender<DeleterMessage>,
controller_upcall_client: Option<C>,
control_plane_client: Option<C>,
lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
cancel: CancellationToken,
) -> Self {
@@ -102,7 +102,7 @@ where
conf,
rx,
tx,
controller_upcall_client,
control_plane_client,
lsn_table,
pending_lists: Vec::new(),
validated_lists: Vec::new(),
@@ -145,8 +145,8 @@ where
return Ok(());
}
let tenants_valid = if let Some(controller_upcall_client) = &self.controller_upcall_client {
match controller_upcall_client
let tenants_valid = if let Some(control_plane_client) = &self.control_plane_client {
match control_plane_client
.validate(tenant_generations.iter().map(|(k, v)| (*k, *v)).collect())
.await
{

View File

@@ -589,10 +589,6 @@ async fn timeline_create_handler(
StatusCode::SERVICE_UNAVAILABLE,
HttpErrorBody::from_msg(e.to_string()),
),
Err(e @ tenant::CreateTimelineError::AncestorArchived) => json_response(
StatusCode::NOT_ACCEPTABLE,
HttpErrorBody::from_msg(e.to_string()),
),
Err(tenant::CreateTimelineError::ShuttingDown) => json_response(
StatusCode::SERVICE_UNAVAILABLE,
HttpErrorBody::from_msg("tenant shutting down".to_string()),

View File

@@ -6,7 +6,7 @@ pub mod basebackup;
pub mod config;
pub mod consumption_metrics;
pub mod context;
pub mod controller_upcall_client;
pub mod control_plane_client;
pub mod deletion_queue;
pub mod disk_usage_eviction_task;
pub mod http;

View File

@@ -8,8 +8,6 @@ use metrics::{
};
use once_cell::sync::Lazy;
use pageserver_api::shard::TenantShardId;
use postgres_backend::{is_expected_io_error, QueryError};
use pq_proto::framed::ConnectionError;
use strum::{EnumCount, VariantNames};
use strum_macros::{IntoStaticStr, VariantNames};
use tracing::warn;
@@ -1385,7 +1383,7 @@ impl SmgrQueryTimePerTimeline {
&'a self,
op: SmgrQueryType,
ctx: &'c RequestContext,
) -> Option<impl Drop + 'a> {
) -> Option<impl Drop + '_> {
let start = Instant::now();
self.global_started[op as usize].inc();
@@ -1510,7 +1508,6 @@ static COMPUTE_STARTUP_BUCKETS: Lazy<[f64; 28]> = Lazy::new(|| {
pub(crate) struct BasebackupQueryTime {
ok: Histogram,
error: Histogram,
client_error: Histogram,
}
pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|| {
@@ -1524,7 +1521,6 @@ pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|
BasebackupQueryTime {
ok: vec.get_metric_with_label_values(&["ok"]).unwrap(),
error: vec.get_metric_with_label_values(&["error"]).unwrap(),
client_error: vec.get_metric_with_label_values(&["client_error"]).unwrap(),
}
});
@@ -1538,7 +1534,7 @@ impl BasebackupQueryTime {
pub(crate) fn start_recording<'c: 'a, 'a>(
&'a self,
ctx: &'c RequestContext,
) -> BasebackupQueryTimeOngoingRecording<'a, 'a> {
) -> BasebackupQueryTimeOngoingRecording<'_, '_> {
let start = Instant::now();
match ctx.micros_spent_throttled.open() {
Ok(()) => (),
@@ -1561,7 +1557,7 @@ impl BasebackupQueryTime {
}
impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
pub(crate) fn observe<T>(self, res: &Result<T, QueryError>) {
pub(crate) fn observe<T, E>(self, res: &Result<T, E>) {
let elapsed = self.start.elapsed();
let ex_throttled = self
.ctx
@@ -1580,15 +1576,10 @@ impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
elapsed
}
};
// If you want to change categorize of a specific error, also change it in `log_query_error`.
let metric = match res {
Ok(_) => &self.parent.ok,
Err(QueryError::Disconnected(ConnectionError::Io(io_error)))
if is_expected_io_error(io_error) =>
{
&self.parent.client_error
}
Err(_) => &self.parent.error,
let metric = if res.is_ok() {
&self.parent.ok
} else {
&self.parent.error
};
metric.observe(ex_throttled.as_secs_f64());
}
@@ -3217,38 +3208,45 @@ pub(crate) mod tenant_throttling {
impl TimelineGet {
pub(crate) fn new(tenant_shard_id: &TenantShardId) -> Self {
let per_tenant_label_values = &[
KIND,
&tenant_shard_id.tenant_id.to_string(),
&tenant_shard_id.shard_slug().to_string(),
];
TimelineGet {
count_accounted_start: {
GlobalAndPerTenantIntCounter {
global: COUNT_ACCOUNTED_START.with_label_values(&[KIND]),
per_tenant: COUNT_ACCOUNTED_START_PER_TENANT
.with_label_values(per_tenant_label_values),
per_tenant: COUNT_ACCOUNTED_START_PER_TENANT.with_label_values(&[
KIND,
&tenant_shard_id.tenant_id.to_string(),
&tenant_shard_id.shard_slug().to_string(),
]),
}
},
count_accounted_finish: {
GlobalAndPerTenantIntCounter {
global: COUNT_ACCOUNTED_FINISH.with_label_values(&[KIND]),
per_tenant: COUNT_ACCOUNTED_FINISH_PER_TENANT
.with_label_values(per_tenant_label_values),
per_tenant: COUNT_ACCOUNTED_FINISH_PER_TENANT.with_label_values(&[
KIND,
&tenant_shard_id.tenant_id.to_string(),
&tenant_shard_id.shard_slug().to_string(),
]),
}
},
wait_time: {
GlobalAndPerTenantIntCounter {
global: WAIT_USECS.with_label_values(&[KIND]),
per_tenant: WAIT_USECS_PER_TENANT
.with_label_values(per_tenant_label_values),
per_tenant: WAIT_USECS_PER_TENANT.with_label_values(&[
KIND,
&tenant_shard_id.tenant_id.to_string(),
&tenant_shard_id.shard_slug().to_string(),
]),
}
},
count_throttled: {
GlobalAndPerTenantIntCounter {
global: WAIT_COUNT.with_label_values(&[KIND]),
per_tenant: WAIT_COUNT_PER_TENANT
.with_label_values(per_tenant_label_values),
per_tenant: WAIT_COUNT_PER_TENANT.with_label_values(&[
KIND,
&tenant_shard_id.tenant_id.to_string(),
&tenant_shard_id.shard_slug().to_string(),
]),
}
},
}

View File

@@ -563,8 +563,6 @@ pub enum CreateTimelineError {
AncestorLsn(anyhow::Error),
#[error("ancestor timeline is not active")]
AncestorNotActive,
#[error("ancestor timeline is archived")]
AncestorArchived,
#[error("tenant shutting down")]
ShuttingDown,
#[error(transparent)]
@@ -1700,11 +1698,6 @@ impl Tenant {
return Err(CreateTimelineError::AncestorNotActive);
}
if ancestor_timeline.is_archived() == Some(true) {
info!("tried to branch archived timeline");
return Err(CreateTimelineError::AncestorArchived);
}
if let Some(lsn) = ancestor_start_lsn.as_mut() {
*lsn = lsn.align();
@@ -1975,6 +1968,9 @@ impl Tenant {
TenantState::Activating(_) | TenantState::Active | TenantState::Broken { .. } | TenantState::Stopping { .. } => {
panic!("caller is responsible for calling activate() only on Loading / Attaching tenants, got {state:?}", state = current_state);
}
TenantState::Loading => {
*current_state = TenantState::Activating(ActivatingFrom::Loading);
}
TenantState::Attaching => {
*current_state = TenantState::Activating(ActivatingFrom::Attaching);
}
@@ -2155,7 +2151,7 @@ impl Tenant {
async fn set_stopping(
&self,
progress: completion::Barrier,
_allow_transition_from_loading: bool,
allow_transition_from_loading: bool,
allow_transition_from_attaching: bool,
) -> Result<(), SetStoppingError> {
let mut rx = self.state.subscribe();
@@ -2170,6 +2166,7 @@ impl Tenant {
);
false
}
TenantState::Loading => allow_transition_from_loading,
TenantState::Active | TenantState::Broken { .. } | TenantState::Stopping { .. } => true,
})
.await
@@ -2188,6 +2185,13 @@ impl Tenant {
*current_state = TenantState::Stopping { progress };
true
}
TenantState::Loading => {
if !allow_transition_from_loading {
unreachable!("3we ensured above that we're done with activation, and, there is no re-activation")
};
*current_state = TenantState::Stopping { progress };
true
}
TenantState::Active => {
// FIXME: due to time-of-check vs time-of-use issues, it can happen that new timelines
// are created after the transition to Stopping. That's harmless, as the Timelines
@@ -2243,7 +2247,7 @@ impl Tenant {
// The load & attach routines own the tenant state until it has reached `Active`.
// So, wait until it's done.
rx.wait_for(|state| match state {
TenantState::Activating(_) | TenantState::Attaching => {
TenantState::Activating(_) | TenantState::Loading | TenantState::Attaching => {
info!(
"waiting for {} to turn Active|Broken|Stopping",
<&'static str>::from(state)
@@ -2263,7 +2267,7 @@ impl Tenant {
let reason = reason.to_string();
self.state.send_modify(|current_state| {
match *current_state {
TenantState::Activating(_) | TenantState::Attaching => {
TenantState::Activating(_) | TenantState::Loading | TenantState::Attaching => {
unreachable!("we ensured above that we're done with activation, and, there is no re-activation")
}
TenantState::Active => {
@@ -2307,7 +2311,7 @@ impl Tenant {
loop {
let current_state = receiver.borrow_and_update().clone();
match current_state {
TenantState::Attaching | TenantState::Activating(_) => {
TenantState::Loading | TenantState::Attaching | TenantState::Activating(_) => {
// in these states, there's a chance that we can reach ::Active
self.activate_now();
match timeout_cancellable(timeout, &self.cancel, receiver.changed()).await {
@@ -3623,7 +3627,7 @@ impl Tenant {
start_lsn: Lsn,
ancestor: Option<Arc<Timeline>>,
last_aux_file_policy: Option<AuxFilePolicy>,
) -> anyhow::Result<UninitializedTimeline<'a>> {
) -> anyhow::Result<UninitializedTimeline> {
let tenant_shard_id = self.tenant_shard_id;
let resources = self.build_timeline_resources(new_timeline_id);
@@ -4140,7 +4144,7 @@ pub(crate) mod harness {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let tenant = Arc::new(Tenant::new(
TenantState::Attaching,
TenantState::Loading,
self.conf,
AttachedTenantConf::try_from(LocationConf::attached_single(
TenantConfOpt::from(self.tenant_conf.clone()),

View File

@@ -5,7 +5,6 @@ use itertools::Itertools;
use super::storage_layer::LayerName;
/// Checks whether a layer map is valid (i.e., is a valid result of the current compaction algorithm if nothing goes wrong).
///
/// The function checks if we can split the LSN range of a delta layer only at the LSNs of the delta layers. For example,
///
/// ```plain

View File

@@ -30,8 +30,8 @@ use utils::{backoff, completion, crashsafe};
use crate::config::PageServerConf;
use crate::context::{DownloadBehavior, RequestContext};
use crate::controller_upcall_client::{
ControlPlaneGenerationsApi, ControllerUpcallClient, RetryForeverError,
use crate::control_plane_client::{
ControlPlaneClient, ControlPlaneGenerationsApi, RetryForeverError,
};
use crate::deletion_queue::DeletionQueueClient;
use crate::http::routes::ACTIVE_TENANT_TIMEOUT;
@@ -122,7 +122,7 @@ pub(crate) enum ShardSelector {
Known(ShardIndex),
}
/// A convenience for use with the re_attach ControllerUpcallClient function: rather
/// A convenience for use with the re_attach ControlPlaneClient function: rather
/// than the serializable struct, we build this enum that encapsulates
/// the invariant that attached tenants always have generations.
///
@@ -341,7 +341,7 @@ async fn init_load_generations(
"Emergency mode! Tenants will be attached unsafely using their last known generation"
);
emergency_generations(tenant_confs)
} else if let Some(client) = ControllerUpcallClient::new(conf, cancel) {
} else if let Some(client) = ControlPlaneClient::new(conf, cancel) {
info!("Calling control plane API to re-attach tenants");
// If we are configured to use the control plane API, then it is the source of truth for what tenants to load.
match client.re_attach(conf).await {

View File

@@ -39,7 +39,7 @@ use crate::tenant::disk_btree::{
use crate::tenant::storage_layer::layer::S3_UPLOAD_LIMIT;
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
BlobFlag, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadCoalesceMode, VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
@@ -1021,30 +1021,13 @@ impl DeltaLayerInner {
continue;
}
};
let view = BufView::new_slice(&blobs_buf.buf);
for meta in blobs_buf.blobs.iter().rev() {
if Some(meta.meta.key) == ignore_key_with_err {
continue;
}
let blob_read = meta.read(&view).await;
let blob_read = match blob_read {
Ok(buf) => buf,
Err(e) => {
reconstruct_state.on_key_error(
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path,
))),
);
ignore_key_with_err = Some(meta.meta.key);
continue;
}
};
let value = Value::des(&blob_read);
let value = Value::des(&blobs_buf.buf[meta.start..meta.end]);
let value = match value {
Ok(v) => v,
Err(e) => {
@@ -1260,21 +1243,21 @@ impl DeltaLayerInner {
buf.reserve(read.size());
let res = reader.read_blobs(&read, buf, ctx).await?;
let view = BufView::new_slice(&res.buf);
for blob in res.blobs {
let key = blob.meta.key;
let lsn = blob.meta.lsn;
let data = blob.read(&view).await?;
let data = &res.buf[blob.start..blob.end];
#[cfg(debug_assertions)]
Value::des(&data)
Value::des(data)
.with_context(|| {
format!(
"blob failed to deserialize for {}: {:?}",
blob,
utils::Hex(&data)
"blob failed to deserialize for {}@{}, {}..{}: {:?}",
blob.meta.key,
blob.meta.lsn,
blob.start,
blob.end,
utils::Hex(data)
)
})
.unwrap();
@@ -1282,15 +1265,15 @@ impl DeltaLayerInner {
// is it an image or will_init walrecord?
// FIXME: this could be handled by threading the BlobRef to the
// VectoredReadBuilder
let will_init = crate::repository::ValueBytes::will_init(&data)
let will_init = crate::repository::ValueBytes::will_init(data)
.inspect_err(|_e| {
#[cfg(feature = "testing")]
tracing::error!(data=?utils::Hex(&data), err=?_e, %key, %lsn, "failed to parse will_init out of serialized value");
tracing::error!(data=?utils::Hex(data), err=?_e, %key, %lsn, "failed to parse will_init out of serialized value");
})
.unwrap_or(false);
per_blob_copy.clear();
per_blob_copy.extend_from_slice(&data);
per_blob_copy.extend_from_slice(data);
let (tmp, res) = writer
.put_value_bytes(
@@ -1555,11 +1538,8 @@ impl<'a> DeltaLayerIterator<'a> {
.read_blobs(&plan, buf, self.ctx)
.await?;
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let blob_read = meta.read(&view).await?;
let value = Value::des(&blob_read)?;
let value = Value::des(&frozen_buf[meta.start..meta.end])?;
next_batch.push_back((meta.meta.key, meta.meta.lsn, value));
}
self.key_values_batch = next_batch;
@@ -1936,13 +1916,9 @@ pub(crate) mod test {
let blobs_buf = vectored_blob_reader
.read_blobs(&read, buf.take().expect("Should have a buffer"), &ctx)
.await?;
let view = BufView::new_slice(&blobs_buf.buf);
for meta in blobs_buf.blobs.iter() {
let value = meta.read(&view).await?;
assert_eq!(
&value[..],
&entries_meta.index[&(meta.meta.key, meta.meta.lsn)]
);
let value = &blobs_buf.buf[meta.start..meta.end];
assert_eq!(value, entries_meta.index[&(meta.meta.key, meta.meta.lsn)]);
}
buf = Some(blobs_buf.buf);

View File

@@ -36,8 +36,7 @@ use crate::tenant::disk_btree::{
};
use crate::tenant::timeline::GetVectoredError;
use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadPlanner,
BlobFlag, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead, VectoredReadPlanner,
};
use crate::tenant::PageReconstructError;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
@@ -548,15 +547,15 @@ impl ImageLayerInner {
let buf = BytesMut::with_capacity(buf_size);
let blobs_buf = vectored_blob_reader.read_blobs(&read, buf, ctx).await?;
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await?;
let img_buf = frozen_buf.slice(meta.start..meta.end);
key_count += 1;
writer
.put_image(meta.meta.key, img_buf.into_bytes(), ctx)
.put_image(meta.meta.key, img_buf, ctx)
.await
.context(format!("Storing key {}", meta.meta.key))?;
}
@@ -603,28 +602,13 @@ impl ImageLayerInner {
match res {
Ok(blobs_buf) => {
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await;
let img_buf = match img_buf {
Ok(img_buf) => img_buf,
Err(e) => {
reconstruct_state.on_key_error(
meta.meta.key,
PageReconstructError::Other(anyhow!(e).context(format!(
"Failed to decompress blob from virtual file {}",
self.file.path,
))),
);
continue;
}
};
let img_buf = frozen_buf.slice(meta.start..meta.end);
reconstruct_state.update_key(
&meta.meta.key,
self.lsn,
Value::Image(img_buf.into_bytes()),
Value::Image(img_buf),
);
}
}
@@ -1041,15 +1025,10 @@ impl<'a> ImageLayerIterator<'a> {
let blobs_buf = vectored_blob_reader
.read_blobs(&plan, buf, self.ctx)
.await?;
let frozen_buf = blobs_buf.buf.freeze();
let view = BufView::new_bytes(frozen_buf);
let frozen_buf: Bytes = blobs_buf.buf.freeze();
for meta in blobs_buf.blobs.iter() {
let img_buf = meta.read(&view).await?;
next_batch.push_back((
meta.meta.key,
self.image_layer.lsn,
Value::Image(img_buf.into_bytes()),
));
let img_buf = frozen_buf.slice(meta.start..meta.end);
next_batch.push_back((meta.meta.key, self.image_layer.lsn, Value::Image(img_buf)));
}
self.key_values_batch = next_batch;
Ok(())

View File

@@ -481,7 +481,8 @@ async fn ingest_housekeeping_loop(tenant: Arc<Tenant>, cancel: CancellationToken
let allowed_rps = tenant.timeline_get_throttle.steady_rps();
let delta = now - prev;
info!(
n_seconds=%format_args!("{:.3}", delta.as_secs_f64()),
n_seconds=%format_args!("{:.3}",
delta.as_secs_f64()),
count_accounted = count_accounted_finish, // don't break existing log scraping
count_throttled,
sum_throttled_usecs,

View File

@@ -112,7 +112,7 @@ use pageserver_api::reltag::RelTag;
use pageserver_api::shard::ShardIndex;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::{to_pg_timestamp, v14::xlog_utils, WAL_SEGMENT_SIZE};
use postgres_ffi::to_pg_timestamp;
use utils::{
completion,
generation::Generation,
@@ -1337,10 +1337,6 @@ impl Timeline {
_ctx: &RequestContext,
) -> anyhow::Result<LsnLease> {
let lease = {
// Normalize the requested LSN to be aligned, and move to the first record
// if it points to the beginning of the page (header).
let lsn = xlog_utils::normalize_lsn(lsn, WAL_SEGMENT_SIZE);
let mut gc_info = self.gc_info.write().unwrap();
let valid_until = SystemTime::now() + length;
@@ -3601,7 +3597,7 @@ impl Timeline {
ctx,
)
.await
.map_err(|e| FlushLayerError::from_anyhow(self, e.into()))?;
.map_err(|e| FlushLayerError::from_anyhow(self, e))?;
if self.cancel.is_cancelled() {
return Err(FlushLayerError::Cancelled);
@@ -3840,20 +3836,16 @@ impl Timeline {
partition_size: u64,
flags: EnumSet<CompactFlags>,
ctx: &RequestContext,
) -> Result<((KeyPartitioning, SparseKeyPartitioning), Lsn), CompactionError> {
) -> anyhow::Result<((KeyPartitioning, SparseKeyPartitioning), Lsn)> {
let Ok(mut partitioning_guard) = self.partitioning.try_lock() else {
// NB: there are two callers, one is the compaction task, of which there is only one per struct Tenant and hence Timeline.
// The other is the initdb optimization in flush_frozen_layer, used by `boostrap_timeline`, which runs before `.activate()`
// and hence before the compaction task starts.
return Err(CompactionError::Other(anyhow!(
"repartition() called concurrently, this should not happen"
)));
anyhow::bail!("repartition() called concurrently, this should not happen");
};
let ((dense_partition, sparse_partition), partition_lsn) = &*partitioning_guard;
if lsn < *partition_lsn {
return Err(CompactionError::Other(anyhow!(
"repartition() called with LSN going backwards, this should not happen"
)));
anyhow::bail!("repartition() called with LSN going backwards, this should not happen");
}
let distance = lsn.0 - partition_lsn.0;
@@ -4455,12 +4447,6 @@ pub(crate) enum CompactionError {
Other(anyhow::Error),
}
impl CompactionError {
pub fn is_cancelled(&self) -> bool {
matches!(self, CompactionError::ShuttingDown)
}
}
impl From<CollectKeySpaceError> for CompactionError {
fn from(err: CollectKeySpaceError) -> Self {
match err {

View File

@@ -390,7 +390,7 @@ impl Timeline {
// error but continue.
//
// Suppress error when it's due to cancellation
if !self.cancel.is_cancelled() && !err.is_cancelled() {
if !self.cancel.is_cancelled() {
tracing::error!("could not compact, repartitioning keyspace failed: {err:?}");
}
(1, false)

View File

@@ -30,8 +30,8 @@ use crate::{
pgdatadir_mapping::CollectKeySpaceError,
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
tenant::{
size::CalculateSyntheticSizeError, storage_layer::LayerVisibilityHint,
tasks::BackgroundLoopKind, timeline::EvictionError, LogicalSizeCalculationCause, Tenant,
storage_layer::LayerVisibilityHint, tasks::BackgroundLoopKind, timeline::EvictionError,
LogicalSizeCalculationCause, Tenant,
},
};
@@ -557,8 +557,6 @@ impl Timeline {
gather_result = gather => {
match gather_result {
Ok(_) => {},
// It can happen sometimes that we hit this instead of the cancellation token firing above
Err(CalculateSyntheticSizeError::Cancelled) => {}
Err(e) => {
// We don't care about the result, but, if it failed, we should log it,
// since consumption metric might be hitting the cached value and

View File

@@ -16,9 +16,8 @@
//! Note that the vectored blob api does *not* go through the page cache.
use std::collections::BTreeMap;
use std::ops::Deref;
use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use pageserver_api::key::Key;
use tokio::io::AsyncWriteExt;
use tokio_epoll_uring::BoundedBuf;
@@ -36,123 +35,11 @@ pub struct BlobMeta {
pub lsn: Lsn,
}
/// A view into the vectored blobs read buffer.
#[derive(Clone, Debug)]
pub(crate) enum BufView<'a> {
Slice(&'a [u8]),
Bytes(bytes::Bytes),
}
impl<'a> BufView<'a> {
/// Creates a new slice-based view on the blob.
pub fn new_slice(slice: &'a [u8]) -> Self {
Self::Slice(slice)
}
/// Creates a new [`bytes::Bytes`]-based view on the blob.
pub fn new_bytes(bytes: bytes::Bytes) -> Self {
Self::Bytes(bytes)
}
/// Convert the view into `Bytes`.
///
/// If using slice as the underlying storage, the copy will be an O(n) operation.
pub fn into_bytes(self) -> Bytes {
match self {
BufView::Slice(slice) => Bytes::copy_from_slice(slice),
BufView::Bytes(bytes) => bytes,
}
}
/// Creates a sub-view of the blob based on the range.
fn view(&self, range: std::ops::Range<usize>) -> Self {
match self {
BufView::Slice(slice) => BufView::Slice(&slice[range]),
BufView::Bytes(bytes) => BufView::Bytes(bytes.slice(range)),
}
}
}
impl<'a> Deref for BufView<'a> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
match self {
BufView::Slice(slice) => slice,
BufView::Bytes(bytes) => bytes,
}
}
}
impl<'a> AsRef<[u8]> for BufView<'a> {
fn as_ref(&self) -> &[u8] {
match self {
BufView::Slice(slice) => slice,
BufView::Bytes(bytes) => bytes.as_ref(),
}
}
}
impl<'a> From<&'a [u8]> for BufView<'a> {
fn from(value: &'a [u8]) -> Self {
Self::new_slice(value)
}
}
impl From<Bytes> for BufView<'_> {
fn from(value: Bytes) -> Self {
Self::new_bytes(value)
}
}
/// Blob offsets into [`VectoredBlobsBuf::buf`]. The byte ranges is potentially compressed,
/// subject to [`VectoredBlob::compression_bits`].
/// Blob offsets into [`VectoredBlobsBuf::buf`]
pub struct VectoredBlob {
/// Blob metadata.
pub start: usize,
pub end: usize,
pub meta: BlobMeta,
/// Start offset.
start: usize,
/// End offset.
end: usize,
/// Compression used on the the blob.
compression_bits: u8,
}
impl VectoredBlob {
/// Reads a decompressed view of the blob.
pub(crate) async fn read<'a>(&self, buf: &BufView<'a>) -> Result<BufView<'a>, std::io::Error> {
let view = buf.view(self.start..self.end);
match self.compression_bits {
BYTE_UNCOMPRESSED => Ok(view),
BYTE_ZSTD => {
let mut decompressed_vec = Vec::new();
let mut decoder =
async_compression::tokio::write::ZstdDecoder::new(&mut decompressed_vec);
decoder.write_all(&view).await?;
decoder.flush().await?;
// Zero-copy conversion from `Vec` to `Bytes`
Ok(BufView::new_bytes(Bytes::from(decompressed_vec)))
}
bits => {
let error = std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Failed to decompress blob for {}@{}, {}..{}: invalid compression byte {bits:x}", self.meta.key, self.meta.lsn, self.start, self.end),
);
Err(error)
}
}
}
}
impl std::fmt::Display for VectoredBlob {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}@{}, {}..{}",
self.meta.key, self.meta.lsn, self.start, self.end
)
}
}
/// Return type of [`VectoredBlobReader::read_blobs`]
@@ -627,7 +514,7 @@ impl<'a> VectoredBlobReader<'a> {
);
}
let buf = self
let mut buf = self
.file
.read_exact_at(buf.slice(0..read.size()), read.start, ctx)
.await?
@@ -642,6 +529,9 @@ impl<'a> VectoredBlobReader<'a> {
// of a blob is implicit: the start of the next blob if one exists
// or the end of the read.
// Some scratch space, put here for reusing the allocation
let mut decompressed_vec = Vec::new();
for (blob_start, meta) in blobs_at {
let blob_start_in_buf = blob_start - start_offset;
let first_len_byte = buf[blob_start_in_buf as usize];
@@ -667,14 +557,35 @@ impl<'a> VectoredBlobReader<'a> {
)
};
let start = (blob_start_in_buf + size_length) as usize;
let end = start + blob_size as usize;
let start_raw = blob_start_in_buf + size_length;
let end_raw = start_raw + blob_size;
let (start, end);
if compression_bits == BYTE_UNCOMPRESSED {
start = start_raw as usize;
end = end_raw as usize;
} else if compression_bits == BYTE_ZSTD {
let mut decoder =
async_compression::tokio::write::ZstdDecoder::new(&mut decompressed_vec);
decoder
.write_all(&buf[start_raw as usize..end_raw as usize])
.await?;
decoder.flush().await?;
start = buf.len();
buf.extend_from_slice(&decompressed_vec);
end = buf.len();
decompressed_vec.clear();
} else {
let error = std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("invalid compression byte {compression_bits:x}"),
);
return Err(error);
}
metas.push(VectoredBlob {
start,
end,
meta: *meta,
compression_bits,
});
}
@@ -1109,13 +1020,8 @@ mod tests {
let result = vectored_blob_reader.read_blobs(&read, buf, &ctx).await?;
assert_eq!(result.blobs.len(), 1);
let read_blob = &result.blobs[0];
let view = BufView::new_slice(&result.buf);
let read_buf = read_blob.read(&view).await?;
assert_eq!(
&blob[..],
&read_buf[..],
"mismatch for idx={idx} at offset={offset}"
);
let read_buf = &result.buf[read_blob.start..read_blob.end];
assert_eq!(blob, read_buf, "mismatch for idx={idx} at offset={offset}");
buf = result.buf;
}
Ok(())

View File

@@ -9,8 +9,6 @@ OBJS = \
hll.o \
libpagestore.o \
neon.o \
neon_pgversioncompat.o \
neon_perf_counters.o \
neon_utils.o \
neon_walreader.o \
pagestore_smgr.o \
@@ -25,18 +23,7 @@ SHLIB_LINK_INTERNAL = $(libpq)
SHLIB_LINK = -lcurl
EXTENSION = neon
DATA = \
neon--1.0.sql \
neon--1.0--1.1.sql \
neon--1.1--1.2.sql \
neon--1.2--1.3.sql \
neon--1.3--1.4.sql \
neon--1.4--1.5.sql \
neon--1.5--1.4.sql \
neon--1.4--1.3.sql \
neon--1.3--1.2.sql \
neon--1.2--1.1.sql \
neon--1.1--1.0.sql
DATA = neon--1.0.sql neon--1.0--1.1.sql neon--1.1--1.2.sql neon--1.2--1.3.sql neon--1.3--1.2.sql neon--1.2--1.1.sql neon--1.1--1.0.sql neon--1.3--1.4.sql neon--1.4--1.3.sql
PGFILEDESC = "neon - cloud storage for PostgreSQL"
EXTRA_CLEAN = \

View File

@@ -109,7 +109,6 @@ typedef struct FileCacheControl
* reenabling */
uint32 size; /* size of cache file in chunks */
uint32 used; /* number of used chunks */
uint32 used_pages; /* number of used pages */
uint32 limit; /* shared copy of lfc_size_limit */
uint64 hits;
uint64 misses;
@@ -906,10 +905,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
/* Cache overflow: evict least recently used chunk */
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
for (int i = 0; i < BLOCKS_PER_CHUNK; i++)
{
lfc_ctl->used_pages -= (victim->bitmap[i >> 5] >> (i & 31)) & 1;
}
CriticalAssert(victim->access_count == 0);
entry->offset = victim->offset; /* grab victim's chunk */
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
@@ -964,7 +959,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
for (int i = 0; i < blocks_in_chunk; i++)
{
lfc_ctl->used_pages += 1 - ((entry->bitmap[(chunk_offs + i) >> 5] >> ((chunk_offs + i) & 31)) & 1);
entry->bitmap[(chunk_offs + i) >> 5] |=
(1 << ((chunk_offs + i) & 31));
}
@@ -1057,11 +1051,6 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
if (lfc_ctl)
value = lfc_ctl->size;
break;
case 5:
key = "file_cache_used_pages";
if (lfc_ctl)
value = lfc_ctl->used_pages;
break;
default:
SRF_RETURN_DONE(funcctx);
}

View File

@@ -30,7 +30,6 @@
#include "utils/guc.h"
#include "neon.h"
#include "neon_perf_counters.h"
#include "neon_utils.h"
#include "pagestore_client.h"
#include "walproposer.h"
@@ -332,7 +331,6 @@ CLEANUP_AND_DISCONNECT(PageServer *shard)
}
if (shard->conn)
{
MyNeonCounters->pageserver_disconnects_total++;
PQfinish(shard->conn);
shard->conn = NULL;
}
@@ -739,8 +737,6 @@ pageserver_send(shardno_t shard_no, NeonRequest *request)
PageServer *shard = &page_servers[shard_no];
PGconn *pageserver_conn;
MyNeonCounters->pageserver_requests_sent_total++;
/* If the connection was lost for some reason, reconnect */
if (shard->state == PS_Connected && PQstatus(shard->conn) == CONNECTION_BAD)
{
@@ -893,7 +889,6 @@ pageserver_flush(shardno_t shard_no)
}
else
{
MyNeonCounters->pageserver_send_flushes_total++;
if (PQflush(pageserver_conn))
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
@@ -927,7 +922,7 @@ check_neon_id(char **newval, void **extra, GucSource source)
static Size
PagestoreShmemSize(void)
{
return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize());
return sizeof(PagestoreShmemState);
}
static bool
@@ -946,9 +941,6 @@ PagestoreShmemInit(void)
memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
AssignPageserverConnstring(page_server_connstring, NULL);
}
NeonPerfCountersShmemInit();
LWLockRelease(AddinShmemInitLock);
return found;
}

View File

@@ -1,39 +0,0 @@
\echo Use "ALTER EXTENSION neon UPDATE TO '1.5'" to load this file. \quit
CREATE FUNCTION get_backend_perf_counters()
RETURNS SETOF RECORD
AS 'MODULE_PATHNAME', 'neon_get_backend_perf_counters'
LANGUAGE C PARALLEL SAFE;
CREATE FUNCTION get_perf_counters()
RETURNS SETOF RECORD
AS 'MODULE_PATHNAME', 'neon_get_perf_counters'
LANGUAGE C PARALLEL SAFE;
-- Show various metrics, for each backend. Note that the values are not reset
-- when a backend exits. When a new backend starts with the backend ID, it will
-- continue accumulating the values from where the old backend left. If you are
-- only interested in the changes from your own session, store the values at the
-- beginning of the session somewhere, and subtract them on subsequent calls.
--
-- For histograms, 'bucket_le' is the upper bound of the histogram bucket.
CREATE VIEW neon_backend_perf_counters AS
SELECT P.procno, P.pid, P.metric, P.bucket_le, P.value
FROM get_backend_perf_counters() AS P (
procno integer,
pid integer,
metric text,
bucket_le float8,
value float8
);
-- Summary across all backends. (This could also be implemented with
-- an aggregate query over neon_backend_perf_counters view.)
CREATE VIEW neon_perf_counters AS
SELECT P.metric, P.bucket_le, P.value
FROM get_perf_counters() AS P (
metric text,
bucket_le float8,
value float8
);

View File

@@ -1,4 +0,0 @@
DROP VIEW IF EXISTS neon_perf_counters;
DROP VIEW IF EXISTS neon_backend_perf_counters;
DROP FUNCTION IF EXISTS get_perf_counters();
DROP FUNCTION IF EXISTS get_backend_perf_counters();

View File

@@ -1,7 +1,5 @@
# neon extension
comment = 'cloud storage for PostgreSQL'
# TODO: bump default version to 1.5, after we are certain that we don't
# need to rollback the compute image
default_version = '1.4'
module_pathname = '$libdir/neon'
relocatable = true

View File

@@ -1,261 +0,0 @@
/*-------------------------------------------------------------------------
*
* neon_perf_counters.c
* Collect statistics about Neon I/O
*
* Each backend has its own set of counters in shared memory.
*
*-------------------------------------------------------------------------
*/
#include "postgres.h"
#include <math.h>
#include "funcapi.h"
#include "miscadmin.h"
#include "storage/proc.h"
#include "storage/shmem.h"
#include "utils/builtins.h"
#include "neon_perf_counters.h"
#include "neon_pgversioncompat.h"
neon_per_backend_counters *neon_per_backend_counters_shared;
Size
NeonPerfCountersShmemSize(void)
{
Size size = 0;
size = add_size(size, mul_size(MaxBackends, sizeof(neon_per_backend_counters)));
return size;
}
void
NeonPerfCountersShmemInit(void)
{
bool found;
neon_per_backend_counters_shared =
ShmemInitStruct("Neon perf counters",
mul_size(MaxBackends,
sizeof(neon_per_backend_counters)),
&found);
Assert(found == IsUnderPostmaster);
if (!found)
{
/* shared memory is initialized to zeros, so nothing to do here */
}
}
/*
* Count a GetPage wait operation.
*/
void
inc_getpage_wait(uint64 latency_us)
{
int lo = 0;
int hi = NUM_GETPAGE_WAIT_BUCKETS - 1;
/* Find the right bucket with binary search */
while (lo < hi)
{
int mid = (lo + hi) / 2;
if (latency_us < getpage_wait_bucket_thresholds[mid])
hi = mid;
else
lo = mid + 1;
}
MyNeonCounters->getpage_wait_us_bucket[lo]++;
MyNeonCounters->getpage_wait_us_sum += latency_us;
MyNeonCounters->getpage_wait_us_count++;
}
/*
* Support functions for the views, neon_backend_perf_counters and
* neon_perf_counters.
*/
typedef struct
{
char *name;
bool is_bucket;
double bucket_le;
double value;
} metric_t;
static metric_t *
neon_perf_counters_to_metrics(neon_per_backend_counters *counters)
{
#define NUM_METRICS (2 + NUM_GETPAGE_WAIT_BUCKETS + 8)
metric_t *metrics = palloc((NUM_METRICS + 1) * sizeof(metric_t));
uint64 bucket_accum;
int i = 0;
Datum getpage_wait_str;
metrics[i].name = "getpage_wait_seconds_count";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->getpage_wait_us_count;
i++;
metrics[i].name = "getpage_wait_seconds_sum";
metrics[i].is_bucket = false;
metrics[i].value = ((double) counters->getpage_wait_us_sum) / 1000000.0;
i++;
bucket_accum = 0;
for (int bucketno = 0; bucketno < NUM_GETPAGE_WAIT_BUCKETS; bucketno++)
{
uint64 threshold = getpage_wait_bucket_thresholds[bucketno];
bucket_accum += counters->getpage_wait_us_bucket[bucketno];
metrics[i].name = "getpage_wait_seconds_bucket";
metrics[i].is_bucket = true;
metrics[i].bucket_le = (threshold == UINT64_MAX) ? INFINITY : ((double) threshold) / 1000000.0;
metrics[i].value = (double) bucket_accum;
i++;
}
metrics[i].name = "getpage_prefetch_requests_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->getpage_prefetch_requests_total;
i++;
metrics[i].name = "getpage_sync_requests_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->getpage_sync_requests_total;
i++;
metrics[i].name = "getpage_prefetch_misses_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->getpage_prefetch_misses_total;
i++;
metrics[i].name = "getpage_prefetch_discards_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->getpage_prefetch_discards_total;
i++;
metrics[i].name = "pageserver_requests_sent_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->pageserver_requests_sent_total;
i++;
metrics[i].name = "pageserver_requests_disconnects_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->pageserver_disconnects_total;
i++;
metrics[i].name = "pageserver_send_flushes_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->pageserver_send_flushes_total;
i++;
metrics[i].name = "file_cache_hits_total";
metrics[i].is_bucket = false;
metrics[i].value = (double) counters->file_cache_hits_total;
i++;
Assert(i == NUM_METRICS);
/* NULL entry marks end of array */
metrics[i].name = NULL;
metrics[i].value = 0;
return metrics;
}
/*
* Write metric to three output Datums
*/
static void
metric_to_datums(metric_t *m, Datum *values, bool *nulls)
{
values[0] = CStringGetTextDatum(m->name);
nulls[0] = false;
if (m->is_bucket)
{
values[1] = Float8GetDatum(m->bucket_le);
nulls[1] = false;
}
else
{
values[1] = (Datum) 0;
nulls[1] = true;
}
values[2] = Float8GetDatum(m->value);
nulls[2] = false;
}
PG_FUNCTION_INFO_V1(neon_get_backend_perf_counters);
Datum
neon_get_backend_perf_counters(PG_FUNCTION_ARGS)
{
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
Datum values[5];
bool nulls[5];
/* We put all the tuples into a tuplestore in one go. */
InitMaterializedSRF(fcinfo, 0);
for (int procno = 0; procno < MaxBackends; procno++)
{
PGPROC *proc = GetPGProcByNumber(procno);
int pid = proc->pid;
neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno];
metric_t *metrics = neon_perf_counters_to_metrics(counters);
values[0] = Int32GetDatum(procno);
nulls[0] = false;
values[1] = Int32GetDatum(pid);
nulls[1] = false;
for (int i = 0; metrics[i].name != NULL; i++)
{
metric_to_datums(&metrics[i], &values[2], &nulls[2]);
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
pfree(metrics);
}
return (Datum) 0;
}
PG_FUNCTION_INFO_V1(neon_get_perf_counters);
Datum
neon_get_perf_counters(PG_FUNCTION_ARGS)
{
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
Datum values[3];
bool nulls[3];
Datum getpage_wait_str;
neon_per_backend_counters totals = {0};
metric_t *metrics;
/* We put all the tuples into a tuplestore in one go. */
InitMaterializedSRF(fcinfo, 0);
/* Aggregate the counters across all backends */
for (int procno = 0; procno < MaxBackends; procno++)
{
neon_per_backend_counters *counters = &neon_per_backend_counters_shared[procno];
totals.getpage_wait_us_count += counters->getpage_wait_us_count;
totals.getpage_wait_us_sum += counters->getpage_wait_us_sum;
for (int bucketno = 0; bucketno < NUM_GETPAGE_WAIT_BUCKETS; bucketno++)
totals.getpage_wait_us_bucket[bucketno] += counters->getpage_wait_us_bucket[bucketno];
totals.getpage_prefetch_requests_total += counters->getpage_prefetch_requests_total;
totals.getpage_sync_requests_total += counters->getpage_sync_requests_total;
totals.getpage_prefetch_misses_total += counters->getpage_prefetch_misses_total;
totals.getpage_prefetch_discards_total += counters->getpage_prefetch_discards_total;
totals.pageserver_requests_sent_total += counters->pageserver_requests_sent_total;
totals.pageserver_disconnects_total += counters->pageserver_disconnects_total;
totals.pageserver_send_flushes_total += counters->pageserver_send_flushes_total;
totals.file_cache_hits_total += counters->file_cache_hits_total;
}
metrics = neon_perf_counters_to_metrics(&totals);
for (int i = 0; metrics[i].name != NULL; i++)
{
metric_to_datums(&metrics[i], &values[0], &nulls[0]);
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
pfree(metrics);
return (Datum) 0;
}

View File

@@ -1,111 +0,0 @@
/*-------------------------------------------------------------------------
*
* neon_perf_counters.h
* Performance counters for neon storage requests
*-------------------------------------------------------------------------
*/
#ifndef NEON_PERF_COUNTERS_H
#define NEON_PERF_COUNTERS_H
#if PG_VERSION_NUM >= 170000
#include "storage/procnumber.h"
#else
#include "storage/backendid.h"
#include "storage/proc.h"
#endif
static const uint64 getpage_wait_bucket_thresholds[] = {
20, 30, 60, 100, /* 0 - 100 us */
200, 300, 600, 1000, /* 100 us - 1 ms */
2000, 3000, 6000, 10000, /* 1 ms - 10 ms */
20000, 30000, 60000, 100000, /* 10 ms - 100 ms */
200000, 300000, 600000, 1000000, /* 100 ms - 1 s */
2000000, 3000000, 6000000, 10000000, /* 1 s - 10 s */
20000000, 30000000, 60000000, 100000000, /* 10 s - 100 s */
UINT64_MAX,
};
#define NUM_GETPAGE_WAIT_BUCKETS (lengthof(getpage_wait_bucket_thresholds))
typedef struct
{
/*
* Histogram for how long an smgrread() request needs to wait for response
* from pageserver. When prefetching is effective, these wait times can be
* lower than the network latency to the pageserver, even zero, if the
* page is already readily prefetched whenever we need to read a page.
*
* Note: we accumulate these in microseconds, because that's convenient in
* the backend, but the 'neon_backend_perf_counters' view will convert
* them to seconds, to make them more idiomatic as prometheus metrics.
*/
uint64 getpage_wait_us_count;
uint64 getpage_wait_us_sum;
uint64 getpage_wait_us_bucket[NUM_GETPAGE_WAIT_BUCKETS];
/*
* Total number of speculative prefetch Getpage requests and synchronous
* GetPage requests sent.
*/
uint64 getpage_prefetch_requests_total;
uint64 getpage_sync_requests_total;
/* XXX: It's not clear to me when these misses happen. */
uint64 getpage_prefetch_misses_total;
/*
* Number of prefetched responses that were discarded becuase the
* prefetched page was not needed or because it was concurrently fetched /
* modified by another backend.
*/
uint64 getpage_prefetch_discards_total;
/*
* Total number of requests send to pageserver. (prefetch_requests_total
* and sync_request_total count only GetPage requests, this counts all
* request types.)
*/
uint64 pageserver_requests_sent_total;
/*
* Number of times the connection to the pageserver was lost and the
* backend had to reconnect. Note that this doesn't count the first
* connection in each backend, only reconnects.
*/
uint64 pageserver_disconnects_total;
/*
* Number of network flushes to the pageserver. Synchronous requests are
* flushed immediately, but when prefetching requests are sent in batches,
* this can be smaller than pageserver_requests_sent_total.
*/
uint64 pageserver_send_flushes_total;
/*
* Number of requests satisfied from the LFC.
*
* This is redundant with the server-wide file_cache_hits, but this gives
* per-backend granularity, and it's handy to have this in the same place
* as counters for requests that went to the pageserver. Maybe move all
* the LFC stats to this struct in the future?
*/
uint64 file_cache_hits_total;
} neon_per_backend_counters;
/* Pointer to the shared memory array of neon_per_backend_counters structs */
extern neon_per_backend_counters *neon_per_backend_counters_shared;
#if PG_VERSION_NUM >= 170000
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
#else
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno])
#endif
extern void inc_getpage_wait(uint64 latency);
extern Size NeonPerfCountersShmemSize(void);
extern void NeonPerfCountersShmemInit(void);
#endif /* NEON_PERF_COUNTERS_H */

View File

@@ -1,44 +0,0 @@
/*
* Support functions for the compatibility macros in neon_pgversioncompat.h
*/
#include "postgres.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "utils/tuplestore.h"
#include "neon_pgversioncompat.h"
#if PG_MAJORVERSION_NUM < 15
void
InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags)
{
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
Tuplestorestate *tupstore;
MemoryContext old_context,
per_query_ctx;
TupleDesc stored_tupdesc;
/* check to see if caller supports returning a tuplestore */
if (rsinfo == NULL || !IsA(rsinfo, ReturnSetInfo))
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("set-valued function called in context that cannot accept a set")));
/*
* Store the tuplestore and the tuple descriptor in ReturnSetInfo. This
* must be done in the per-query memory context.
*/
per_query_ctx = rsinfo->econtext->ecxt_per_query_memory;
old_context = MemoryContextSwitchTo(per_query_ctx);
if (get_call_result_type(fcinfo, NULL, &stored_tupdesc) != TYPEFUNC_COMPOSITE)
elog(ERROR, "return type must be a row type");
tupstore = tuplestore_begin_heap(false, false, work_mem);
rsinfo->returnMode = SFRM_Materialize;
rsinfo->setResult = tupstore;
rsinfo->setDesc = stored_tupdesc;
MemoryContextSwitchTo(old_context);
}
#endif

View File

@@ -6,8 +6,6 @@
#ifndef NEON_PGVERSIONCOMPAT_H
#define NEON_PGVERSIONCOMPAT_H
#include "fmgr.h"
#if PG_MAJORVERSION_NUM < 17
#define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId)
#else
@@ -125,8 +123,4 @@
#define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess())
#endif
#if PG_MAJORVERSION_NUM < 15
extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags);
#endif
#endif /* NEON_PGVERSIONCOMPAT_H */

View File

@@ -66,7 +66,6 @@
#include "storage/md.h"
#include "storage/smgr.h"
#include "neon_perf_counters.h"
#include "pagestore_client.h"
#include "bitmap.h"
@@ -290,6 +289,7 @@ static PrefetchState *MyPState;
static bool compact_prefetch_buffers(void);
static void consume_prefetch_responses(void);
static uint64 prefetch_register_buffer(BufferTag tag, neon_request_lsns *force_request_lsns);
static bool prefetch_read(PrefetchRequest *slot);
static void prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns);
static bool prefetch_wait_for(uint64 ring_index);
@@ -780,27 +780,21 @@ prefetch_do_request(PrefetchRequest *slot, neon_request_lsns *force_request_lsns
}
/*
* prefetch_register_bufferv() - register and prefetch buffers
* prefetch_register_buffer() - register and prefetch buffer
*
* Register that we may want the contents of BufferTag in the near future.
* This is used when issuing a speculative prefetch request, but also when
* performing a synchronous request and need the buffer right now.
*
* If force_request_lsns is not NULL, those values are sent to the
* pageserver. If NULL, we utilize the lastWrittenLsn -infrastructure
* to calculate the LSNs to send.
*
* When performing a prefetch rather than a synchronous request,
* is_prefetch==true. Currently, it only affects how the request is accounted
* in the perf counters.
*
* NOTE: this function may indirectly update MyPState->pfs_hash; which
* invalidates any active pointers into the hash table.
*/
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
BlockNumber nblocks, const bits8 *mask)
{
uint64 min_ring_index;
PrefetchRequest req;
@@ -821,7 +815,6 @@ Retry:
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && !BITMAP_ISSET(mask, i))
continue;
@@ -865,7 +858,6 @@ Retry:
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
MyNeonCounters->getpage_prefetch_discards_total++;
}
}
@@ -980,11 +972,6 @@ Retry:
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
MyNeonCounters->getpage_sync_requests_total++;
prefetch_do_request(slot, lsns);
}
@@ -1013,6 +1000,13 @@ Retry:
}
static uint64
prefetch_register_buffer(BufferTag tag, neon_request_lsns *force_request_lsns)
{
return prefetch_register_bufferv(tag, force_request_lsns, 1, NULL);
}
/*
* Note: this function can get canceled and use a long jump to the next catch
* context. Take care.
@@ -2618,7 +2612,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
lfc_present[i] = ~(lfc_present[i]);
ring_index = prefetch_register_bufferv(tag, NULL, iterblocks,
lfc_present, true);
lfc_present);
nblocks -= iterblocks;
blocknum += iterblocks;
@@ -2662,7 +2656,7 @@ neon_prefetch(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum)
CopyNRelFileInfoToBufTag(tag, InfoFromSMgrRel(reln));
ring_index = prefetch_register_bufferv(tag, NULL, 1, NULL, true);
ring_index = prefetch_register_buffer(tag, NULL);
Assert(ring_index < MyPState->ring_unused &&
MyPState->ring_last <= ring_index);
@@ -2753,20 +2747,17 @@ neon_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber base_block
* weren't for the behaviour of the LwLsn cache that uses the highest
* value of the LwLsn cache when the entry is not found.
*/
prefetch_register_bufferv(buftag, request_lsns, nblocks, mask, false);
prefetch_register_bufferv(buftag, request_lsns, nblocks, mask);
for (int i = 0; i < nblocks; i++)
{
void *buffer = buffers[i];
BlockNumber blockno = base_blockno + i;
neon_request_lsns *reqlsns = &request_lsns[i];
TimestampTz start_ts, end_ts;
if (PointerIsValid(mask) && !BITMAP_ISSET(mask, i))
continue;
start_ts = GetCurrentTimestamp();
if (RecoveryInProgress() && MyBackendType != B_STARTUP)
XLogWaitForReplayOf(reqlsns[0].request_lsn);
@@ -2803,7 +2794,6 @@ Retry:
/* drop caches */
prefetch_set_unused(slot->my_ring_index);
pgBufferUsage.prefetch.expired += 1;
MyNeonCounters->getpage_prefetch_discards_total++;
/* make it look like a prefetch cache miss */
entry = NULL;
}
@@ -2814,9 +2804,8 @@ Retry:
if (entry == NULL)
{
pgBufferUsage.prefetch.misses += 1;
MyNeonCounters->getpage_prefetch_misses_total++;
ring_index = prefetch_register_bufferv(buftag, reqlsns, 1, NULL, false);
ring_index = prefetch_register_bufferv(buftag, reqlsns, 1, NULL);
Assert(ring_index != UINT64_MAX);
slot = GetPrfSlot(ring_index);
}
@@ -2871,9 +2860,6 @@ Retry:
/* buffer was used, clean up for later reuse */
prefetch_set_unused(ring_index);
prefetch_cleanup_trailing_unused();
end_ts = GetCurrentTimestamp();
inc_getpage_wait(end_ts >= start_ts ? (end_ts - start_ts) : 0);
}
}
@@ -2927,7 +2913,6 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
/* Try to read from local file cache */
if (lfc_read(InfoFromSMgrRel(reln), forkNum, blkno, buffer))
{
MyNeonCounters->file_cache_hits_total++;
return;
}
@@ -3112,7 +3097,7 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
/* assume heap */
RmgrTable[RM_HEAP_ID].rm_mask(mdbuf_masked, blkno);
RmgrTable[RM_HEAP_ID].rm_mask(pageserver_masked, blkno);
if (memcmp(mdbuf_masked, pageserver_masked, BLCKSZ) != 0)
{
neon_log(PANIC, "heap buffers differ at blk %u in rel %u/%u/%u fork %u (request LSN %X/%08X):\n------ MD ------\n%s\n------ Page Server ------\n%s\n",

View File

@@ -24,12 +24,12 @@ bytes = { workspace = true, features = ["serde"] }
camino.workspace = true
chrono.workspace = true
clap.workspace = true
compute_api.workspace = true
consumption_metrics.workspace = true
dashmap.workspace = true
env_logger.workspace = true
framed-websockets.workspace = true
futures.workspace = true
git-version.workspace = true
hashbrown.workspace = true
hashlink.workspace = true
hex.workspace = true

View File

@@ -80,14 +80,6 @@ pub(crate) trait TestBackend: Send + Sync + 'static {
fn get_allowed_ips_and_secret(
&self,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestBackend>;
}
#[cfg(test)]
impl Clone for Box<dyn TestBackend> {
fn clone(&self) -> Self {
TestBackend::dyn_clone(&**self)
}
}
impl std::fmt::Display for Backend<'_, (), ()> {
@@ -452,7 +444,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
Self::Web(url, ()) => {
info!("performing web authentication");
let info = web::authenticate(ctx, config, &url, client).await?;
let info = web::authenticate(ctx, &url, client).await?;
Backend::Web(url, info)
}
@@ -565,7 +557,7 @@ mod tests {
stream::{PqStream, Stream},
};
use super::{auth_quirks, jwt::JwkCache, AuthRateLimiter};
use super::{auth_quirks, AuthRateLimiter};
struct Auth {
ips: Vec<IpPattern>,
@@ -593,14 +585,6 @@ mod tests {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
_endpoint: crate::EndpointId,
) -> anyhow::Result<Vec<super::jwt::AuthRule>> {
unimplemented!()
}
async fn wake_compute(
&self,
_ctx: &RequestMonitoring,
@@ -611,15 +595,12 @@ mod tests {
}
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: false,
});
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {

View File

@@ -1,5 +1,4 @@
use std::{
borrow::Cow,
future::Future,
sync::Arc,
time::{Duration, SystemTime},
@@ -9,17 +8,11 @@ use anyhow::{bail, ensure, Context};
use arc_swap::ArcSwapOption;
use dashmap::DashMap;
use jose_jwk::crypto::KeyInfo;
use serde::{
de::{DeserializeSeed, IgnoredAny, Visitor},
Deserializer,
};
use serde::{Deserialize, Deserializer};
use signature::Verifier;
use tokio::time::Instant;
use crate::{
context::RequestMonitoring, http::parse_json_body_with_limit, intern::RoleNameInt, EndpointId,
RoleName,
};
use crate::{context::RequestMonitoring, http::parse_json_body_with_limit, EndpointId, RoleName};
// TODO(conrad): make these configurable.
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
@@ -34,19 +27,18 @@ pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: RoleName,
) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
}
#[derive(Debug, Clone)]
pub(crate) struct AuthRule {
pub(crate) id: String,
pub(crate) jwks_url: url::Url,
pub(crate) audience: Option<String>,
pub(crate) role_names: Vec<RoleNameInt>,
}
#[derive(Default)]
pub struct JwkCache {
pub(crate) struct JwkCache {
client: reqwest::Client,
map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
@@ -62,28 +54,18 @@ pub(crate) struct JwkCacheEntry {
}
impl JwkCacheEntry {
fn find_jwk_and_audience(
&self,
key_id: &str,
role_name: &RoleName,
) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets
.values()
// make sure our requested role has access to the key set
.filter(|key_set| key_set.role_names.iter().any(|role| **role == **role_name))
// try and find the requested key-id in the key set
.find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
fn find_jwk_and_audience(&self, key_id: &str) -> Option<(&jose_jwk::Jwk, Option<&str>)> {
self.key_sets.values().find_map(|key_set| {
key_set
.find_key(key_id)
.map(|jwk| (jwk, key_set.audience.as_deref()))
})
}
}
struct KeySet {
jwks: jose_jwk::JwkSet,
audience: Option<String>,
role_names: Vec<RoleNameInt>,
}
impl KeySet {
@@ -124,6 +106,7 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
auth_rules: &F,
) -> anyhow::Result<Arc<JwkCacheEntry>> {
// double check that no one beat us to updating the cache.
@@ -136,10 +119,11 @@ impl JwkCacheEntryLock {
}
}
let rules = auth_rules.fetch_auth_rules(ctx, endpoint).await?;
let rules = auth_rules
.fetch_auth_rules(ctx, endpoint, role_name)
.await?;
let mut key_sets =
ahash::HashMap::with_capacity_and_hasher(rules.len(), ahash::RandomState::new());
// TODO(conrad): run concurrently
// TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284)
for rule in rules {
@@ -167,7 +151,6 @@ impl JwkCacheEntryLock {
KeySet {
jwks,
audience: rule.audience,
role_names: rule.role_names,
},
);
}
@@ -190,6 +173,7 @@ impl JwkCacheEntryLock {
ctx: &RequestMonitoring,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: RoleName,
fetch: &F,
) -> Result<Arc<JwkCacheEntry>, anyhow::Error> {
let now = Instant::now();
@@ -199,7 +183,9 @@ impl JwkCacheEntryLock {
let Some(cached) = guard else {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
};
let last_update = now.duration_since(cached.last_retrieved);
@@ -210,7 +196,9 @@ impl JwkCacheEntryLock {
let permit = self.acquire_permit().await;
// it's been too long since we checked the keys. wait for them to update.
return self.renew_jwks(permit, ctx, client, endpoint, fetch).await;
return self
.renew_jwks(permit, ctx, client, endpoint, role_name, fetch)
.await;
}
// every 5 minutes we should spawn a job to eagerly update the token.
@@ -224,7 +212,7 @@ impl JwkCacheEntryLock {
let ctx = ctx.clone();
tokio::spawn(async move {
if let Err(e) = entry
.renew_jwks(permit, &ctx, &client, endpoint, &fetch)
.renew_jwks(permit, &ctx, &client, endpoint, role_name, &fetch)
.await
{
tracing::warn!(error=?e, "could not fetch JWKs in background job");
@@ -244,7 +232,7 @@ impl JwkCacheEntryLock {
jwt: &str,
client: &reqwest::Client,
endpoint: EndpointId,
role_name: &RoleName,
role_name: RoleName,
fetch: &F,
) -> Result<(), anyhow::Error> {
// JWT compact form is defined to be
@@ -266,26 +254,30 @@ impl JwkCacheEntryLock {
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
ensure!(
header.typ == "JWT",
"Provided authentication token is not a valid JWT encoding"
);
ensure!(header.typ == "JWT");
let kid = header.key_id.context("missing key id")?;
let mut guard = self
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), fetch)
.get_or_update_jwk_cache(ctx, client, endpoint.clone(), role_name.clone(), fetch)
.await?;
// get the key from the JWKs if possible. If not, wait for the keys to update.
let (jwk, expected_audience) = loop {
match guard.find_jwk_and_audience(kid, role_name) {
match guard.find_jwk_and_audience(kid) {
Some(jwk) => break jwk,
None if guard.last_retrieved.elapsed() > MIN_RENEW => {
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let permit = self.acquire_permit().await;
guard = self
.renew_jwks(permit, ctx, client, endpoint.clone(), fetch)
.renew_jwks(
permit,
ctx,
client,
endpoint.clone(),
role_name.clone(),
fetch,
)
.await?;
}
_ => {
@@ -308,21 +300,32 @@ impl JwkCacheEntryLock {
}
key => bail!("unsupported key type {key:?}"),
};
tracing::debug!("JWT signature valid");
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
.context("Provided authentication token is not a valid JWT encoding")?;
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
.context("Provided authentication token is not a valid JWT encoding")?;
let validator = JwtValidator {
expected_audience,
current_time: SystemTime::now(),
clock_skew_leeway: CLOCK_SKEW_LEEWAY,
};
tracing::debug!(?payload, "JWT signature valid with claims");
let payload = validator
.deserialize(&mut serde_json::Deserializer::from_slice(&payload))?;
match (expected_audience, payload.audience) {
// check the audience matches
(Some(aud1), Some(aud2)) => ensure!(aud1 == aud2, "invalid JWT token audience"),
// the audience is expected but is missing
(Some(_), None) => bail!("invalid JWT token audience"),
// we don't care for the audience field
(None, _) => {}
}
tracing::debug!(?payload, "JWT claims valid");
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
ensure!(now < exp + CLOCK_SKEW_LEEWAY);
}
if let Some(nbf) = payload.not_before {
ensure!(nbf < now + CLOCK_SKEW_LEEWAY);
}
Ok(())
}
@@ -333,7 +336,7 @@ impl JwkCache {
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
role_name: &RoleName,
role_name: RoleName,
fetch: &F,
jwt: &str,
) -> Result<(), anyhow::Error> {
@@ -410,184 +413,37 @@ struct JwtHeader<'a> {
key_id: Option<&'a str>,
}
struct JwtValidator<'a> {
expected_audience: Option<&'a str>,
current_time: SystemTime,
clock_skew_leeway: Duration,
}
impl<'de> DeserializeSeed<'de> for JwtValidator<'_> {
type Value = JwtPayload<'de>;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for JwtValidator<'_> {
type Value = JwtPayload<'de>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a JWT payload")
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut payload = JwtPayload {
issuer: None,
subject: None,
jwt_id: None,
session_id: None,
};
let mut aud = false;
while let Some(key) = map.next_key()? {
match key {
"iss" if payload.issuer.is_none() => {
payload.issuer = Some(map.next_value()?);
}
"sub" if payload.subject.is_none() => {
payload.subject = Some(map.next_value()?);
}
"jit" if payload.jwt_id.is_none() => {
payload.jwt_id = Some(map.next_value()?);
}
"sid" if payload.session_id.is_none() => {
payload.session_id = Some(map.next_value()?);
}
"exp" => {
let exp = map.next_value::<u64>()?;
let exp = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
if self.current_time > exp + self.clock_skew_leeway {
return Err(serde::de::Error::custom("JWT token has expired"));
}
}
"nbf" => {
let nbf = map.next_value::<u64>()?;
let nbf = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
if self.current_time + self.clock_skew_leeway < nbf {
return Err(serde::de::Error::custom(
"JWT token is not yet ready to use",
));
}
}
"aud" => {
if let Some(expected_audience) = self.expected_audience {
map.next_value_seed(AudienceValidator { expected_audience })?;
aud = true;
} else {
map.next_value::<IgnoredAny>()?;
}
}
_ => map.next_value::<IgnoredAny>().map(|IgnoredAny| ())?,
}
}
if self.expected_audience.is_some() && !aud {
return Err(serde::de::Error::custom("invalid JWT token audience"));
}
Ok(payload)
}
}
deserializer.deserialize_map(self)
}
}
struct AudienceValidator<'a> {
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for AudienceValidator<'_> {
type Value = ();
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for AudienceValidator<'_> {
type Value = ();
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single string or an array of strings")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if self.expected_audience == v {
Ok(())
} else {
Err(E::custom("invalid JWT token audience"))
}
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
while let Some(v) = seq.next_element_seed(SingleAudienceValidator {
expected_audience: self.expected_audience,
})? {
if v {
return Ok(());
}
}
Err(serde::de::Error::custom("invalid JWT token audience"))
}
}
deserializer.deserialize_any(self)
}
}
struct SingleAudienceValidator<'a> {
expected_audience: &'a str,
}
impl<'de> DeserializeSeed<'de> for SingleAudienceValidator<'_> {
type Value = bool;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
impl<'de> Visitor<'de> for SingleAudienceValidator<'_> {
type Value = bool;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a single audience string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(self.expected_audience == v)
}
}
deserializer.deserialize_any(self)
}
}
/// <https://datatracker.ietf.org/doc/html/rfc7519#section-4.1>
// the following entries are only extracted for the sake of debug logging.
#[derive(Debug)]
#[allow(dead_code)]
#[derive(serde::Deserialize, serde::Serialize, Debug)]
struct JwtPayload<'a> {
/// Audience - Recipient for which the JWT is intended
#[serde(rename = "aud")]
audience: Option<&'a str>,
/// Expiration - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "exp", default)]
expiration: Option<SystemTime>,
/// Not before - Time after which the JWT expires
#[serde(deserialize_with = "numeric_date_opt", rename = "nbf", default)]
not_before: Option<SystemTime>,
// the following entries are only extracted for the sake of debug logging.
/// Issuer of the JWT
issuer: Option<Cow<'a, str>>,
#[serde(rename = "iss")]
issuer: Option<&'a str>,
/// Subject of the JWT (the user)
subject: Option<Cow<'a, str>>,
#[serde(rename = "sub")]
subject: Option<&'a str>,
/// Unique token identifier
jwt_id: Option<Cow<'a, str>>,
#[serde(rename = "jti")]
jwt_id: Option<&'a str>,
/// Unique session identifier
session_id: Option<Cow<'a, str>>,
#[serde(rename = "sid")]
session_id: Option<&'a str>,
}
fn numeric_date_opt<'de, D: Deserializer<'de>>(d: D) -> Result<Option<SystemTime>, D::Error> {
let d = <Option<u64>>::deserialize(d)?;
Ok(d.map(|n| SystemTime::UNIX_EPOCH + Duration::from_secs(n)))
}
struct JwkRenewalPermit<'a> {
@@ -668,8 +524,6 @@ mod tests {
use hyper_util::rt::TokioIo;
use rand::rngs::OsRng;
use rsa::pkcs8::DecodePrivateKey;
use serde::Serialize;
use serde_json::json;
use signature::Signer;
use tokio::net::TcpListener;
@@ -702,41 +556,23 @@ mod tests {
}
fn build_jwt_payload(kid: String, sig: jose_jwa::Signing) -> String {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let body = typed_json::json! {{
"exp": now + 3600,
"nbf": now,
"aud": ["audience1", "neon", "audience2"],
"sub": "user1",
"sid": "session1",
"jti": "token1",
"iss": "neon-testing",
}};
build_custom_jwt_payload(kid, body, sig)
}
fn build_custom_jwt_payload(
kid: String,
body: impl Serialize,
sig: jose_jwa::Signing,
) -> String {
let header = JwtHeader {
typ: "JWT",
algorithm: jose_jwa::Algorithm::Signing(sig),
key_id: Some(&kid),
};
let body = typed_json::json! {{
"exp": SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 3600,
}};
let header =
base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD);
let body = base64::encode_config(body.to_string(), URL_SAFE_NO_PAD);
format!("{header}.{body}")
}
fn new_ec_jwt(kid: String, key: &p256::SecretKey) -> String {
fn new_ec_jwt(kid: String, key: p256::SecretKey) -> String {
use p256::ecdsa::{Signature, SigningKey};
let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256);
@@ -746,16 +582,6 @@ mod tests {
format!("{payload}.{sig}")
}
fn new_custom_ec_jwt(kid: String, key: &p256::SecretKey, body: impl Serialize) -> String {
use p256::ecdsa::{Signature, SigningKey};
let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256);
let sig: Signature = SigningKey::from(key).sign(payload.as_bytes());
let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD);
format!("{payload}.{sig}")
}
fn new_rsa_jwt(kid: String, key: rsa::RsaPrivateKey) -> String {
use rsa::pkcs1v15::SigningKey;
use rsa::signature::SignatureEncoding;
@@ -827,34 +653,42 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
-----END PRIVATE KEY-----
";
#[derive(Clone)]
struct Fetch(Vec<AuthRule>);
#[tokio::test]
async fn renew() {
let (rs1, jwk1) = new_rsa_jwk(RS1, "1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "2".into());
let (ec1, jwk3) = new_ec_jwk("3".into());
let (ec2, jwk4) = new_ec_jwk("4".into());
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(self.0.clone())
}
}
let jwt1 = new_rsa_jwt("1".into(), rs1);
let jwt2 = new_rsa_jwt("2".into(), rs2);
let jwt3 = new_ec_jwt("3".into(), ec1);
let jwt4 = new_ec_jwt("4".into(), ec2);
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
let bar_jwks = jose_jwk::JwkSet {
keys: vec![jwk2, jwk4],
};
async fn jwks_server(
router: impl for<'a> Fn(&'a str) -> Option<Vec<u8>> + Send + Sync + 'static,
) -> SocketAddr {
let router = Arc::new(router);
let service = service_fn(move |req| {
let router = Arc::clone(&router);
let foo_jwks = foo_jwks.clone();
let bar_jwks = bar_jwks.clone();
async move {
match router(req.uri().path()) {
Some(body) => Response::builder()
.status(200)
.body(Full::new(Bytes::from(body))),
None => Response::builder()
.status(404)
.body(Full::new(Bytes::new())),
}
let jwks = match req.uri().path() {
"/foo" => &foo_jwks,
"/bar" => &bar_jwks,
_ => {
return Response::builder()
.status(404)
.body(Full::new(Bytes::new()));
}
};
let body = serde_json::to_vec(jwks).unwrap();
Response::builder()
.status(200)
.body(Full::new(Bytes::from(body)))
}
});
@@ -869,257 +703,50 @@ X0n5X2/pBLJzxZc62ccvZYVnctBiFs6HbSnxpuMQCfkt/BcR/ttIepBQQIW86wHL
}
});
addr
}
let client = reqwest::Client::new();
#[tokio::test]
async fn check_jwt_happy_path() {
let (rs1, jwk1) = new_rsa_jwk(RS1, "rs1".into());
let (rs2, jwk2) = new_rsa_jwk(RS2, "rs2".into());
let (ec1, jwk3) = new_ec_jwk("ec1".into());
let (ec2, jwk4) = new_ec_jwk("ec2".into());
#[derive(Clone)]
struct Fetch(SocketAddr);
let foo_jwks = jose_jwk::JwkSet {
keys: vec![jwk1, jwk3],
};
let bar_jwks = jose_jwk::JwkSet {
keys: vec![jwk2, jwk4],
};
let jwks_addr = jwks_server(move |path| match path {
"/foo" => Some(serde_json::to_vec(&foo_jwks).unwrap()),
"/bar" => Some(serde_json::to_vec(&bar_jwks).unwrap()),
_ => None,
})
.await;
let role_name1 = RoleName::from("anonymous");
let role_name2 = RoleName::from("authenticated");
let roles = vec![
RoleNameInt::from(&role_name1),
RoleNameInt::from(&role_name2),
];
let rules = vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{jwks_addr}/foo").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{jwks_addr}/bar").parse().unwrap(),
audience: None,
role_names: roles.clone(),
},
];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
impl FetchAuthRules for Fetch {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
_role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
Ok(vec![
AuthRule {
id: "foo".to_owned(),
jwks_url: format!("http://{}/foo", self.0).parse().unwrap(),
audience: None,
},
AuthRule {
id: "bar".to_owned(),
jwks_url: format!("http://{}/bar", self.0).parse().unwrap(),
audience: None,
},
])
}
}
let role_name = RoleName::from("user");
let endpoint = EndpointId::from("ep");
let jwt1 = new_rsa_jwt("rs1".into(), rs1);
let jwt2 = new_rsa_jwt("rs2".into(), rs2);
let jwt3 = new_ec_jwt("ec1".into(), &ec1);
let jwt4 = new_ec_jwt("ec2".into(), &ec2);
let jwk_cache = Arc::new(JwkCacheEntryLock::default());
let tokens = [jwt1, jwt2, jwt3, jwt4];
let role_names = [role_name1, role_name2];
for role in &role_names {
for token in &tokens {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
endpoint.clone(),
role,
&fetch,
token,
)
.await
.unwrap();
}
}
}
#[tokio::test]
async fn check_jwt_invalid_signature() {
let (_, jwk) = new_ec_jwk("1".into());
let (key, _) = new_ec_jwk("1".into());
// has a matching kid, but signed by the wrong key
let bad_jwt = new_ec_jwt("1".into(), &key);
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: None,
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
let ctx = RequestMonitoring::test();
let err = jwk_cache
.check_jwt(&ctx, ep, &role, &fetch, &bad_jwt)
.await
.unwrap_err();
assert!(
err.to_string().contains("signature error"),
"expected \"signature error\", got {err:?}"
);
}
#[tokio::test]
async fn check_jwt_unknown_role() {
let (key, jwk) = new_rsa_jwk(RS1, "1".into());
let jwt = new_rsa_jwt("1".into(), key);
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: None,
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
// this role_name is not accepted
let bad_role_name = RoleName::from("cloud_admin");
let ctx = RequestMonitoring::test();
let err = jwk_cache
.check_jwt(&ctx, ep, &bad_role_name, &fetch, &jwt)
.await
.unwrap_err();
assert!(
err.to_string().contains("jwk not found"),
"expected \"jwk not found\", got {err:?}"
);
}
#[tokio::test]
async fn check_jwt_invalid_claims() {
let (key, jwk) = new_ec_jwk("1".into());
let jwks = jose_jwk::JwkSet { keys: vec![jwk] };
let jwks_addr = jwks_server(move |path| match path {
"/" => Some(serde_json::to_vec(&jwks).unwrap()),
_ => None,
})
.await;
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
struct Test {
body: serde_json::Value,
error: &'static str,
}
let table = vec![
Test {
body: json! {{
"nbf": now + 60,
"aud": "neon",
}},
error: "JWT token is not yet ready to use",
},
Test {
body: json! {{
"exp": now - 60,
"aud": ["neon"],
}},
error: "JWT token has expired",
},
Test {
body: json! {{
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": [],
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": "foo",
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": ["foo"],
}},
error: "invalid JWT token audience",
},
Test {
body: json! {{
"aud": ["foo", "bar"],
}},
error: "invalid JWT token audience",
},
];
let role = RoleName::from("authenticated");
let rules = vec![AuthRule {
id: String::new(),
jwks_url: format!("http://{jwks_addr}/").parse().unwrap(),
audience: Some("neon".to_string()),
role_names: vec![RoleNameInt::from(&role)],
}];
let fetch = Fetch(rules);
let jwk_cache = JwkCache::default();
let ep = EndpointId::from("ep");
let ctx = RequestMonitoring::test();
for test in table {
let jwt = new_custom_ec_jwt("1".into(), &key, test.body);
match jwk_cache
.check_jwt(&ctx, ep.clone(), &role, &fetch, &jwt)
for token in [jwt1, jwt2, jwt3, jwt4] {
jwk_cache
.check_jwt(
&RequestMonitoring::test(),
&token,
&client,
endpoint.clone(),
role_name.clone(),
&Fetch(addr),
)
.await
{
Err(err) if err.to_string().contains(test.error) => {}
Err(err) => {
panic!("expected {:?}, got {err:?}", test.error)
}
Ok(()) => {
panic!("expected {:?}, got ok", test.error)
}
}
.unwrap();
}
}
}

View File

@@ -1,4 +1,4 @@
use std::net::SocketAddr;
use std::{collections::HashMap, net::SocketAddr};
use anyhow::Context;
use arc_swap::ArcSwapOption;
@@ -10,19 +10,21 @@ use crate::{
NodeInfo,
},
context::RequestMonitoring,
intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag},
EndpointId,
intern::{BranchIdInt, BranchIdTag, EndpointIdTag, InternId, ProjectIdInt, ProjectIdTag},
EndpointId, RoleName,
};
use super::jwt::{AuthRule, FetchAuthRules};
use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
pub struct LocalBackend {
pub(crate) jwks_cache: JwkCache,
pub(crate) node_info: NodeInfo,
}
impl LocalBackend {
pub fn new(postgres_addr: SocketAddr) -> Self {
LocalBackend {
jwks_cache: JwkCache::default(),
node_info: NodeInfo {
config: {
let mut cfg = ConnCfg::new();
@@ -46,17 +48,26 @@ impl LocalBackend {
#[derive(Clone, Copy)]
pub(crate) struct StaticAuthRules;
pub static JWKS_ROLE_MAP: ArcSwapOption<EndpointJwksResponse> = ArcSwapOption::const_empty();
pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
#[derive(Debug, Clone)]
pub struct JwksRoleSettings {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
}
impl FetchAuthRules for StaticAuthRules {
async fn fetch_auth_rules(
&self,
_ctx: &RequestMonitoring,
_endpoint: EndpointId,
role_name: RoleName,
) -> anyhow::Result<Vec<AuthRule>> {
let mappings = JWKS_ROLE_MAP.load();
let role_mappings = mappings
.as_deref()
.and_then(|m| m.roles.get(&role_name))
.context("JWKs settings for this role were not configured")?;
let mut rules = vec![];
for setting in &role_mappings.jwks {
@@ -64,7 +75,6 @@ impl FetchAuthRules for StaticAuthRules {
id: setting.id.clone(),
jwks_url: setting.jwks_url.clone(),
audience: setting.jwt_audience.clone(),
role_names: setting.role_names.clone(),
});
}

View File

@@ -1,6 +1,5 @@
use crate::{
auth, compute,
config::AuthenticationConfig,
console::{self, provider::NodeInfo},
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
@@ -59,7 +58,6 @@ pub(crate) fn new_psql_session_id() -> String {
pub(super) async fn authenticate(
ctx: &RequestMonitoring,
auth_config: &'static AuthenticationConfig,
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
@@ -91,14 +89,6 @@ pub(super) async fn authenticate(
info!(parent: &span, "waiting for console's reply...");
let db_info = waiter.await.map_err(WebAuthError::from)?;
if auth_config.ip_allowlist_check_enabled {
if let Some(allowed_ips) = &db_info.allowed_ips {
if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
}
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't

View File

@@ -1,38 +1,34 @@
use std::{net::SocketAddr, pin::pin, str::FromStr, sync::Arc, time::Duration};
use std::{
net::SocketAddr,
path::{Path, PathBuf},
pin::pin,
sync::Arc,
time::Duration,
};
use anyhow::{bail, ensure, Context};
use camino::{Utf8Path, Utf8PathBuf};
use compute_api::spec::LocalProxySpec;
use anyhow::{bail, ensure};
use dashmap::DashMap;
use futures::future::Either;
use futures::{future::Either, FutureExt};
use proxy::{
auth::backend::{
jwt::JwkCache,
local::{LocalBackend, JWKS_ROLE_MAP},
},
auth::backend::local::{JwksRoleSettings, LocalBackend, JWKS_ROLE_MAP},
cancellation::CancellationHandlerMain,
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
console::{
locks::ApiLocks,
messages::{EndpointJwksResponse, JwksSettings},
},
console::{locks::ApiLocks, messages::JwksRoleMapping},
http::health_server::AppMetrics,
intern::RoleNameInt,
metrics::{Metrics, ThreadPoolMetrics},
rate_limiter::{BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo},
scram::threadpool::ThreadPool,
serverless::{self, cancel_set::CancelSet, GlobalConnPoolOptions},
RoleName,
};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
use clap::Parser;
use tokio::{net::TcpListener, sync::Notify, task::JoinSet};
use tokio::{net::TcpListener, task::JoinSet};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn};
use utils::{pid_file, project_build_tag, project_git_version, sentry_init::init_sentry};
use utils::{project_build_tag, project_git_version, sentry_init::init_sentry};
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
@@ -76,12 +72,9 @@ struct LocalProxyCliArgs {
/// Address of the postgres server
#[clap(long, default_value = "127.0.0.1:5432")]
compute: SocketAddr,
/// Path of the local proxy config file
/// File address of the local proxy config file
#[clap(long, default_value = "./localproxy.json")]
config_path: Utf8PathBuf,
/// Path of the local proxy PID file
#[clap(long, default_value = "./localproxy.pid")]
pid_path: Utf8PathBuf,
config_path: PathBuf,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -133,24 +126,6 @@ async fn main() -> anyhow::Result<()> {
let args = LocalProxyCliArgs::parse();
let config = build_config(&args)?;
// before we bind to any ports, write the process ID to a file
// so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
//
// This also claims a "lock" that makes sure only one instance
// of local-proxy runs at a time.
let _process_guard = loop {
match pid_file::claim_for_current_process(&args.pid_path) {
Ok(guard) => break guard,
Err(e) => {
// compute-ctl might have tried to read the pid-file to let us
// know about some config change. We should try again.
error!(path=?args.pid_path, "could not claim PID file guard: {e:?}");
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
};
let metrics_listener = TcpListener::bind(args.metrics).await?.into_std()?;
let http_listener = TcpListener::bind(args.http).await?;
let shutdown = CancellationToken::new();
@@ -164,30 +139,12 @@ async fn main() -> anyhow::Result<()> {
16,
));
// write the process ID to a file so that compute-ctl can find our process later
// in order to trigger the appropriate SIGHUP on config change.
let pid = std::process::id();
info!("process running in PID {pid}");
std::fs::write(args.pid_path, format!("{pid}\n")).context("writing PID to file")?;
refresh_config(args.config_path.clone()).await;
let mut maintenance_tasks = JoinSet::new();
let refresh_config_notify = Arc::new(Notify::new());
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), {
let refresh_config_notify = Arc::clone(&refresh_config_notify);
move || {
refresh_config_notify.notify_one();
}
maintenance_tasks.spawn(proxy::handle_signals(shutdown.clone(), move || {
refresh_config(args.config_path.clone()).map(Ok)
}));
// trigger the first config load **after** setting up the signal hook
// to avoid the race condition where:
// 1. No config file registered when local-proxy starts up
// 2. The config file is written but the signal hook is not yet received
// 3. local-proxy completes startup but has no config loaded, despite there being a registerd config.
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(args.config_path, refresh_config_notify));
maintenance_tasks.spawn(proxy::http::health_server::task_main(
metrics_listener,
AppMetrics {
@@ -270,15 +227,12 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
allow_self_signed_compute: false,
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_auth_broker: false,
accept_jwts: true,
},
require_client_ip: false,
handshake_timeout: Duration::from_secs(10),
@@ -291,84 +245,81 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
})))
}
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
loop {
rx.notified().await;
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
async fn refresh_config(path: PathBuf) {
match refresh_config_inner(&path).await {
Ok(()) => {}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
}
async fn refresh_config_inner(path: &Utf8Path) -> anyhow::Result<()> {
async fn refresh_config_inner(path: &Path) -> anyhow::Result<()> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut data: JwksRoleMapping = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
let mut settings = None;
for jwks in data.jwks {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
for mapping in data.roles.values_mut() {
for jwks in &mut mapping.jwks {
ensure!(
jwks.jwks_url.has_authority()
&& (jwks.jwks_url.scheme() == "http" || jwks.jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks.jwks_url
.host()
.is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
// clear username, password and ports
jwks.jwks_url.set_username("").expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
jwks.jwks_url.set_password(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
jwks.jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
}
jwks_set.push(JwksSettings {
id: jwks.id,
jwks_url,
provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
// clear query params
jwks.jwks_url.set_fragment(None);
jwks.jwks_url.query_pairs_mut().clear().finish();
if jwks.jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks.jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks.jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
let (pr, br) = settings.get_or_insert((jwks.project_id, jwks.branch_id));
ensure!(
*pr == jwks.project_id,
"inconsistent project IDs configured"
);
ensure!(*br == jwks.branch_id, "inconsistent branch IDs configured");
}
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some((project_id, branch_id)) = settings {
JWKS_ROLE_MAP.store(Some(Arc::new(JwksRoleSettings {
roles: data.roles,
project_id,
branch_id,
})));
}
Ok(())
}

View File

@@ -133,7 +133,9 @@ async fn main() -> anyhow::Result<()> {
proxy_listener,
cancellation_token.clone(),
));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || {}));
let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async {
Ok(())
}));
// the signal task cant ever succeed.
// the main task can error, or can succeed on cancellation.

View File

@@ -8,7 +8,6 @@ use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
use aws_config::Region;
use futures::future::Either;
use proxy::auth;
use proxy::auth::backend::jwt::JwkCache;
use proxy::auth::backend::AuthRateLimiter;
use proxy::auth::backend::MaybeOwned;
use proxy::cancellation::CancelMap;
@@ -103,9 +102,6 @@ struct ProxyCliArgs {
default_value = "http://localhost:3000/authenticate_proxy_request/"
)]
auth_endpoint: String,
/// if this is not local proxy, this toggles whether we accept jwt or passwords for http
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_auth_broker: bool,
/// path to TLS key for client postgres connections
///
/// tls-key and tls-cert are for backwards compatibility, we can put all certs in one dir
@@ -386,27 +382,9 @@ async fn main() -> anyhow::Result<()> {
info!("Starting mgmt on {mgmt_address}");
let mgmt_listener = TcpListener::bind(mgmt_address).await?;
let proxy_listener = if !args.is_auth_broker {
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
Some(TcpListener::bind(proxy_address).await?)
} else {
None
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let proxy_address: SocketAddr = args.proxy.parse()?;
info!("Starting proxy on {proxy_address}");
let proxy_listener = TcpListener::bind(proxy_address).await?;
let cancellation_token = CancellationToken::new();
let cancel_map = CancelMap::default();
@@ -452,17 +430,21 @@ async fn main() -> anyhow::Result<()> {
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
client_tasks.spawn(proxy::proxy::task_main(
config,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
if let Some(serverless_address) = args.wss {
let serverless_address: SocketAddr = serverless_address.parse()?;
info!("Starting wss on {serverless_address}");
let serverless_listener = TcpListener::bind(serverless_address).await?;
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
serverless_listener,
@@ -479,7 +461,10 @@ async fn main() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(proxy::handle_signals(
cancellation_token.clone(),
|| async { Ok(()) },
));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -692,7 +677,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
)?;
let http_config = HttpConfig {
accept_websockets: !args.is_auth_broker,
accept_websockets: true,
pool_options: GlobalConnPoolOptions {
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
@@ -707,15 +692,12 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,
accept_jwts: args.is_auth_broker,
};
let config = Box::leak(Box::new(ProxyConfig {

View File

@@ -1,8 +1,5 @@
use crate::{
auth::{
self,
backend::{jwt::JwkCache, AuthRateLimiter},
},
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
@@ -70,9 +67,6 @@ pub struct AuthenticationConfig {
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub jwks_cache: JwkCache,
pub is_auth_broker: bool,
pub accept_jwts: bool,
}
impl TlsConfig {
@@ -256,26 +250,18 @@ impl CertResolver {
let common_name = pem.subject().to_string();
// We need to get the canonical name for this certificate so we can match them against any domain names
// seen within the proxy codebase.
//
// In scram-proxy we use wildcard certificates only, with the database endpoint as the wildcard subdomain, taken from SNI.
// We need to remove the wildcard prefix for the purposes of certificate selection.
//
// auth-broker does not use SNI and instead uses the Neon-Connection-String header.
// Auth broker has the subdomain `apiauth` we need to remove for the purposes of validating the Neon-Connection-String.
//
// Console Web proxy does not use any wildcard domains and does not need any certificate selection or conn string
// validation, so let's we can continue with any common-name
let common_name = if let Some(s) = common_name.strip_prefix("CN=*.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=apiauth.") {
s.to_string()
} else if let Some(s) = common_name.strip_prefix("CN=") {
s.to_string()
// We only use non-wildcard certificates in web auth proxy so it seems okay to treat them the same as
// wildcard ones as we don't use SNI there. That treatment only affects certificate selection, so
// verify-full will still check wildcard match. Old coding here just ignored non-wildcard common names
// and passed None instead, which blows up number of cases downstream code should handle. Proper coding
// here should better avoid Option for common_names, and do wildcard-based certificate selection instead
// of cutting off '*.' parts.
let common_name = if common_name.starts_with("CN=*.") {
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
} else {
bail!("Failed to parse common name from certificate")
};
common_name.strip_prefix("CN=").map(|s| s.to_string())
}
.context("Failed to parse common name from certificate")?;
let cert = Arc::new(rustls::sign::CertifiedKey::new(cert_chain, key));

View File

@@ -1,11 +1,13 @@
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{self, Display};
use crate::auth::IpPattern;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::proxy::retry::CouldRetry;
use crate::RoleName;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
@@ -282,8 +284,6 @@ pub(crate) struct DatabaseInfo {
/// be inconvenient for debug with local PG instance.
pub(crate) password: Option<Box<str>>,
pub(crate) aux: MetricsAuxInfo,
#[serde(default)]
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
}
// Manually implement debug to omit sensitive info.
@@ -294,7 +294,6 @@ impl fmt::Debug for DatabaseInfo {
.field("port", &self.port)
.field("dbname", &self.dbname)
.field("user", &self.user)
.field("allowed_ips", &self.allowed_ips)
.finish_non_exhaustive()
}
}
@@ -346,6 +345,11 @@ impl ColdStartInfo {
}
}
#[derive(Debug, Deserialize, Clone)]
pub struct JwksRoleMapping {
pub roles: HashMap<RoleName, EndpointJwksResponse>,
}
#[derive(Debug, Deserialize, Clone)]
pub struct EndpointJwksResponse {
pub jwks: Vec<JwksSettings>,
@@ -354,10 +358,11 @@ pub struct EndpointJwksResponse {
#[derive(Debug, Deserialize, Clone)]
pub struct JwksSettings {
pub id: String,
pub project_id: ProjectIdInt,
pub branch_id: BranchIdInt,
pub jwks_url: url::Url,
pub provider_name: String,
pub jwt_audience: Option<String>,
pub role_names: Vec<RoleNameInt>,
}
#[cfg(test)]
@@ -427,22 +432,6 @@ mod tests {
"aux": dummy_aux(),
}))?;
// with allowed_ips
let dbinfo = serde_json::from_value::<DatabaseInfo>(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
"aux": dummy_aux(),
"allowed_ips": ["127.0.0.1"],
}))?;
assert_eq!(
dbinfo.allowed_ips,
Some(vec![IpPattern::Single("127.0.0.1".parse()?)])
);
Ok(())
}

View File

@@ -5,10 +5,7 @@ pub mod neon;
use super::messages::{ConsoleError, MetricsAuxInfo};
use crate::{
auth::{
backend::{
jwt::{AuthRule, FetchAuthRules},
ComputeCredentialKeys, ComputeUserInfo,
},
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
@@ -19,7 +16,7 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey, EndpointId,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{hash::Hash, sync::Arc, time::Duration};
@@ -337,12 +334,6 @@ pub(crate) trait Api {
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
@@ -352,7 +343,6 @@ pub(crate) trait Api {
}
#[non_exhaustive]
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
@@ -396,20 +386,6 @@ impl Api for ConsoleBackend {
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
match self {
Self::Console(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::Postgres(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
Self::Test(_api) => Ok(vec![]),
}
}
async fn wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -576,13 +552,3 @@ impl WakeComputePermit {
res
}
}
impl FetchAuthRules for ConsoleBackend {
async fn fetch_auth_rules(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.get_endpoint_jwks(ctx, endpoint).await
}
}

View File

@@ -4,9 +4,7 @@ use super::{
errors::{ApiError, GetAuthInfoError, WakeComputeError},
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::{
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
@@ -120,39 +118,6 @@ impl Api {
})
}
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
let connection = tokio::spawn(connection);
let res = client.query(
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
&[&endpoint.as_str()],
)
.await?;
let mut rows = vec![];
for row in res {
rows.push(AuthRule {
id: row.get("id"),
jwks_url: url::Url::parse(row.get("jwks_url"))?,
audience: row.get("audience"),
role_names: row
.get::<_, Vec<String>>("role_names")
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
});
}
drop(client);
connection.await??;
Ok(rows)
}
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
@@ -220,14 +185,6 @@ impl super::Api for Api {
))
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -7,33 +7,27 @@ use super::{
NodeInfo,
};
use crate::{
auth::backend::{jwt::AuthRule, ComputeUserInfo},
auth::backend::ComputeUserInfo,
compute,
console::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
console::messages::{ColdStartInfo, Reason},
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey, EndpointId,
scram, EndpointCacheKey,
};
use crate::{cache::Cached, context::RequestMonitoring};
use ::http::{header::AUTHORIZATION, HeaderName};
use anyhow::bail;
use futures::TryFutureExt;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;
use tracing::{debug, error, info, info_span, warn, Instrument};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
// put in a shared ref so we don't copy secrets all over in memory
jwt: Arc<str>,
jwt: String,
}
impl Api {
@@ -44,9 +38,7 @@ impl Api {
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
.unwrap_or_default()
.into();
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN").unwrap_or_default();
Self {
endpoint,
caches,
@@ -79,9 +71,9 @@ impl Api {
async {
let request = self
.endpoint
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
@@ -133,61 +125,6 @@ impl Api {
.await
}
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
.await
{
bail!("endpoint not found");
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.map_err(crate::error::log_error)
.instrument(info_span!("http", id = request_id))
.await
}
async fn do_wake_compute(
&self,
ctx: &RequestMonitoring,
@@ -198,7 +135,7 @@ impl Api {
async {
let mut request_builder = self
.endpoint
.get_path("proxy_wake_compute")
.get("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
@@ -325,15 +262,6 @@ impl super::Api for Api {
))
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestMonitoring,
endpoint: EndpointId,
) -> anyhow::Result<Vec<AuthRule>> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,

View File

@@ -86,17 +86,9 @@ impl Endpoint {
/// Return a [builder](RequestBuilder) for a `GET` request,
/// appending a single `path` segment to the base endpoint URL.
pub(crate) fn get_path(&self, path: &str) -> RequestBuilder {
self.get_with_url(|u| {
u.path_segments_mut().push(path);
})
}
/// Return a [builder](RequestBuilder) for a `GET` request,
/// accepting a closure to modify the url path segments for more complex paths queries.
pub(crate) fn get_with_url(&self, f: impl for<'a> FnOnce(&'a mut ApiUrl)) -> RequestBuilder {
pub(crate) fn get(&self, path: &str) -> RequestBuilder {
let mut url = self.endpoint.clone();
f(&mut url);
url.path_segments_mut().push(path);
self.client.get(url.into_inner())
}
@@ -152,7 +144,7 @@ mod tests {
// Validate that this pattern makes sense.
let req = endpoint
.get_path("frobnicate")
.get("frobnicate")
.query(&[
("foo", Some("10")), // should be just `foo=10`
("bar", None), // shouldn't be passed at all
@@ -170,7 +162,7 @@ mod tests {
let endpoint = Endpoint::new(url, Client::new());
let req = endpoint
.get_path("frobnicate")
.get("frobnicate")
.query(&[("session_id", uuid::Uuid::nil())])
.build()?;

View File

@@ -1,6 +1,5 @@
use std::{
any::type_name, hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index,
sync::OnceLock,
hash::BuildHasherDefault, marker::PhantomData, num::NonZeroUsize, ops::Index, sync::OnceLock,
};
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
@@ -17,21 +16,12 @@ pub struct StringInterner<Id> {
_id: PhantomData<Id>,
}
#[derive(PartialEq, Clone, Copy, Eq, Hash)]
#[derive(PartialEq, Debug, Clone, Copy, Eq, Hash)]
pub struct InternedString<Id> {
inner: Spur,
_id: PhantomData<Id>,
}
impl<Id: InternId> std::fmt::Debug for InternedString<Id> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("InternedString")
.field(&type_name::<Id>())
.field(&self.as_str())
.finish()
}
}
impl<Id: InternId> std::fmt::Display for InternedString<Id> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.as_str().fmt(f)
@@ -140,14 +130,14 @@ impl<Id: InternId> Default for StringInterner<Id> {
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct RoleNameTag;
pub(crate) struct RoleNameTag;
impl InternId for RoleNameTag {
fn get_interner() -> &'static StringInterner<Self> {
static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
ROLE_NAMES.get_or_init(Default::default)
}
}
pub type RoleNameInt = InternedString<RoleNameTag>;
pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
impl From<&RoleName> for RoleNameInt {
fn from(value: &RoleName) -> Self {
RoleNameTag::get_interner().get_or_intern(value)

View File

@@ -82,7 +82,7 @@
impl_trait_overcaptures,
)]
use std::convert::Infallible;
use std::{convert::Infallible, future::Future};
use anyhow::{bail, Context};
use intern::{EndpointIdInt, EndpointIdTag, InternId};
@@ -117,12 +117,13 @@ pub mod usage_metrics;
pub mod waiters;
/// Handle unix signals appropriately.
pub async fn handle_signals<F>(
pub async fn handle_signals<F, Fut>(
token: CancellationToken,
mut refresh_config: F,
) -> anyhow::Result<Infallible>
where
F: FnMut(),
F: FnMut() -> Fut,
Fut: Future<Output = anyhow::Result<()>>,
{
use tokio::signal::unix::{signal, SignalKind};
@@ -135,7 +136,7 @@ where
// Hangup is commonly used for config reload.
_ = hangup.recv() => {
warn!("received SIGHUP");
refresh_config();
refresh_config().await?;
}
// Shut down the whole application.
_ = interrupt.recv() => {

View File

@@ -525,10 +525,6 @@ impl TestBackend for TestConnectMechanism {
{
unimplemented!("not used in tests")
}
fn dyn_clone(&self) -> Box<dyn TestBackend> {
Box::new(self.clone())
}
}
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {

View File

@@ -43,13 +43,6 @@ impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
// rayon would be nice here, but yielding in rayon does not work well afaict.
if n_workers == 0 {
return Arc::new(Self {
runtime: None,
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
}
Arc::new_cyclic(|pool| {
let pool = pool.clone();
let worker_id = AtomicUsize::new(0);

View File

@@ -5,7 +5,6 @@
mod backend;
pub mod cancel_set;
mod conn_pool;
mod http_conn_pool;
mod http_util;
mod json;
mod sql_over_http;
@@ -20,8 +19,7 @@ use anyhow::Context;
use futures::future::{select, Either};
use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
@@ -83,28 +81,7 @@ pub async fn task_main(
}
});
let http_conn_pool = http_conn_pool::GlobalConnPool::new(&config.http_config);
{
let http_conn_pool = Arc::clone(&http_conn_pool);
tokio::spawn(async move {
http_conn_pool.gc_worker(StdRng::from_entropy()).await;
});
}
// shutdown the connection pool
tokio::spawn({
let cancellation_token = cancellation_token.clone();
let http_conn_pool = http_conn_pool.clone();
async move {
cancellation_token.cancelled().await;
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
.await
.unwrap();
}
});
let backend = Arc::new(PoolingBackend {
http_conn_pool: Arc::clone(&http_conn_pool),
pool: Arc::clone(&conn_pool),
config,
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
@@ -365,7 +342,7 @@ async fn request_handler(
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let host = request
.headers()
.get("host")
@@ -409,7 +386,7 @@ async fn request_handler(
);
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
Ok(response.map(|_: http_body_util::Empty<Bytes>| Full::new(Bytes::new())))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestMonitoring::new(
session_id,
@@ -432,7 +409,7 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Empty::new().map_err(|x| match x {}).boxed())
.body(Full::new(Bytes::new()))
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")

View File

@@ -1,8 +1,6 @@
use std::{io, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use tokio::net::{lookup_host, TcpStream};
use tracing::{field::display, info};
use crate::{
@@ -29,13 +27,9 @@ use crate::{
Host,
};
use super::{
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
http_conn_pool::{self, poll_http2_client},
};
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
@@ -109,44 +103,32 @@ impl PoolingBackend {
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
user_info: &ComputeUserInfo,
jwt: String,
) -> Result<(), AuthError> {
jwt: &str,
) -> Result<ComputeCredentials, AuthError> {
match &self.config.auth_backend {
crate::auth::Backend::Console(console, ()) => {
config
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
&**console,
&jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
Ok(())
crate::auth::Backend::Console(_, ()) => {
Err(AuthError::auth_failed("JWT login is not yet supported"))
}
crate::auth::Backend::Web(_, ()) => Err(AuthError::auth_failed(
"JWT login over web auth proxy is not supported",
)),
crate::auth::Backend::Local(_) => {
config
crate::auth::Backend::Local(cache) => {
cache
.jwks_cache
.check_jwt(
ctx,
user_info.endpoint.clone(),
&user_info.user,
user_info.user.clone(),
&StaticAuthRules,
&jwt,
jwt,
)
.await
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
Ok(())
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
}
}
@@ -192,55 +174,14 @@ impl PoolingBackend {
)
.await
}
// Wake up the destination if needed
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<http_conn_pool::Client, HttpConnError> {
info!("pool: looking for an existing connection");
if let Some(client) = self.http_conn_pool.get(ctx, &conn_info) {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self
.config
.auth_backend
.as_ref()
.map(|()| ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
false, // do not allow self signed compute for http flow
self.config.wake_compute_retry_config,
self.config.connect_to_compute_retry_config,
)
.await
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connection to postgres in compute")]
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connection to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not connection to compute")]
ConnectionError(#[from] tokio_postgres::Error),
#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -252,20 +193,11 @@ pub(crate) enum HttpConnError {
TooManyConnectionAttempts(#[from] ApiLockError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper1::Error),
}
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ConnectionError(p) => p.get_error_kind(),
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -278,8 +210,7 @@ impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::ConnectionError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -293,8 +224,7 @@ impl UserFacingError for HttpConnError {
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
@@ -306,7 +236,7 @@ impl CouldRetry for HttpConnError {
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
HttpConnError::ConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
@@ -314,38 +244,6 @@ impl ShouldRetryWakeCompute for HttpConnError {
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
}
impl UserFacingError for LocalProxyConnError {
fn to_string_client(&self) -> String {
"Could not establish HTTP connection to the database".to_string()
}
}
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
struct TokioMechanism {
pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
conn_info: ConnInfo,
@@ -395,99 +293,3 @@ impl ConnectMechanism for TokioMechanism {
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
struct HyperMechanism {
pool: Arc<http_conn_pool::GlobalConnPool>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestMonitoring,
node_info: &CachedNodeInfo,
timeout: Duration,
) -> Result<Self::Connection, Self::ConnectError> {
let host = node_info.config.get_host()?;
let permit = self.locks.get_permit(&host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
}
async fn connect_http2(
host: &str,
port: u16,
timeout: Duration,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
// assumption: host is an ip address so this should not actually perform any requests.
// todo: add that assumption as a guarantee in the control-plane API.
let mut addrs = lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?;
let mut last_err = None;
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
};
};
let (client, connection) = hyper1::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await?;
Ok((client, connection))
}

View File

@@ -1,335 +0,0 @@
use dashmap::DashMap;
use hyper1::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use rand::Rng;
use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::{sync::Arc, sync::Weak};
use tokio::net::TcpStream;
use crate::console::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::{context::RequestMonitoring, EndpointCacheKey};
use tracing::{debug, error};
use tracing::{info, info_span, Instrument};
use super::conn_pool::ConnInfo;
pub(crate) type Send = http2::SendRequest<hyper1::body::Incoming>;
pub(crate) type Connect =
http2::Connection<TokioIo<TcpStream>, hyper1::body::Incoming, TokioExecutor>;
#[derive(Clone)]
struct ConnPoolEntry {
conn: Send,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
}
// Per-endpoint connection pool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub(crate) struct EndpointConnPool {
conns: VecDeque<ConnPoolEntry>,
_guard: HttpEndpointPoolsGuard<'static>,
global_connections_count: Arc<AtomicUsize>,
}
impl EndpointConnPool {
fn get_conn_entry(&mut self) -> Option<ConnPoolEntry> {
let Self { conns, .. } = self;
let conn = conns.pop_front()?;
conns.push_back(conn.clone());
Some(conn)
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
let Self {
conns,
global_connections_count,
..
} = self;
let old_len = conns.len();
conns.retain(|conn| conn.conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
global_connections_count.fetch_sub(removed, atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(removed as i64);
}
removed > 0
}
}
impl Drop for EndpointConnPool {
fn drop(&mut self) {
if !self.conns.is_empty() {
self.global_connections_count
.fetch_sub(self.conns.len(), atomic::Ordering::Relaxed);
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(self.conns.len() as i64);
}
}
}
pub(crate) struct GlobalConnPool {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
global_pool: DashMap<EndpointCacheKey, Arc<RwLock<EndpointConnPool>>>,
/// Number of endpoint-connection pools
///
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
/// It's only used for diagnostics.
global_pool_size: AtomicUsize,
/// Total number of connections in the pool
global_connections_count: Arc<AtomicUsize>,
config: &'static crate::config::HttpConfig,
}
impl GlobalConnPool {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
global_pool_size: AtomicUsize::new(0),
config,
global_connections_count: Arc::new(AtomicUsize::new(0)),
})
}
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
interval.tick().await;
let shard = rng.gen_range(0..self.global_pool.shards().len());
self.gc(shard);
}
}
fn gc(&self, shard: usize) {
debug!(shard, "pool: performing epoch reclamation");
// acquire a random shard lock
let mut shard = self.global_pool.shards()[shard].write();
let timer = Metrics::get()
.proxy
.http_pool_reclaimation_lag_seconds
.start_timer();
let current_len = shard.len();
let mut clients_removed = 0;
shard.retain(|endpoint, x| {
// if the current endpoint pool is unique (no other strong or weak references)
// then it is currently not in use by any connections.
if let Some(pool) = Arc::get_mut(x.get_mut()) {
let EndpointConnPool { conns, .. } = pool.get_mut();
let old_len = conns.len();
conns.retain(|conn| !conn.conn.is_closed());
let new_len = conns.len();
let removed = old_len - new_len;
clients_removed += removed;
// we only remove this pool if it has no active connections
if conns.is_empty() {
info!("pool: discarding pool for endpoint {endpoint}");
return false;
}
}
true
});
let new_len = shard.len();
drop(shard);
timer.observe();
// Do logging outside of the lock.
if clients_removed > 0 {
let size = self
.global_connections_count
.fetch_sub(clients_removed, atomic::Ordering::Relaxed)
- clients_removed;
Metrics::get()
.proxy
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
}
let removed = current_len - new_len;
if removed > 0 {
let global_pool_size = self
.global_pool_size
.fetch_sub(removed, atomic::Ordering::Relaxed)
- removed;
info!("pool: performed global pool gc. size now {global_pool_size}");
}
}
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
) -> Option<Client> {
let endpoint = conn_info.endpoint_cache_key()?;
let endpoint_pool = self.get_or_create_endpoint_pool(&endpoint);
let client = endpoint_pool.write().get_conn_entry()?;
if client.conn.is_closed() {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return None;
}
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
info!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
Some(Client::new(client.conn, client.aux))
}
fn get_or_create_endpoint_pool(
self: &Arc<Self>,
endpoint: &EndpointCacheKey,
) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
}
// slow path
let new_pool = Arc::new(RwLock::new(EndpointConnPool {
conns: VecDeque::new(),
_guard: Metrics::get().proxy.http_endpoint_pools.guard(),
global_connections_count: self.global_connections_count.clone(),
}));
// find or create a pool for this endpoint
let mut created = false;
let pool = self
.global_pool
.entry(endpoint.clone())
.or_insert_with(|| {
created = true;
new_pool
})
.clone();
// log new global pool size
if created {
let global_pool_size = self
.global_pool_size
.fetch_add(1, atomic::Ordering::Relaxed)
+ 1;
info!(
"pool: created new pool for '{endpoint}', global pool size now {global_pool_size}"
);
}
pool
}
}
pub(crate) fn poll_http2_client(
global_pool: Arc<GlobalConnPool>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
aux: MetricsAuxInfo,
) -> Client {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
});
let pool = match conn_info.endpoint_cache_key() {
Some(endpoint) => {
let pool = global_pool.get_or_create_endpoint_pool(&endpoint);
pool.write().conns.push_back(ConnPoolEntry {
conn: client.clone(),
conn_id,
aux: aux.clone(),
});
Arc::downgrade(&pool)
}
None => Weak::new(),
};
// let idle = global_pool.get_idle_timeout();
tokio::spawn(
async move {
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {
Ok(()) => info!("connection closed"),
Err(e) => error!(%session_id, "connection error: {}", e),
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),
);
Client::new(client, aux)
}
pub(crate) struct Client {
pub(crate) inner: Send,
aux: MetricsAuxInfo,
}
impl Client {
pub(self) fn new(inner: Send, aux: MetricsAuxInfo) -> Self {
Self { inner, aux }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
USAGE_METRICS.register(Ids {
endpoint_id: self.aux.endpoint_id,
branch_id: self.aux.branch_id,
})
}
}

View File

@@ -5,13 +5,13 @@ use bytes::Bytes;
use anyhow::Context;
use http::{Response, StatusCode};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use http_body_util::Full;
use serde::Serialize;
use utils::http::error::ApiError;
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper1::Error>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -64,24 +64,17 @@ struct HttpErrorBody {
impl HttpErrorBody {
/// Same as [`utils::http::error::HttpErrorBody::response_from_msg_and_status`]
fn response_from_msg_and_status(
msg: String,
status: StatusCode,
) -> Response<BoxBody<Bytes, hyper1::Error>> {
fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
}
/// Same as [`utils::http::error::HttpErrorBody::to_response`]
fn to_response(&self, status: StatusCode) -> Response<BoxBody<Bytes, hyper1::Error>> {
fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
.map_err(|x| match x {})
.boxed(),
)
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
}
}
@@ -90,14 +83,14 @@ impl HttpErrorBody {
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)).map_err(|x| match x {}).boxed())
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}

View File

@@ -8,8 +8,6 @@ use futures::future::Either;
use futures::StreamExt;
use futures::TryFutureExt;
use http::header::AUTHORIZATION;
use http::Method;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
@@ -40,11 +38,9 @@ use url::Url;
use urlencoding;
use utils::http::error::ApiError;
use crate::auth::backend::ComputeCredentials;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
use crate::auth::ComputeUserInfoParseError;
use crate::config::AuthenticationConfig;
use crate::config::ProxyConfig;
use crate::config::TlsConfig;
use crate::context::RequestMonitoring;
@@ -60,7 +56,6 @@ use crate::usage_metrics::MetricCounterRecorder;
use crate::DbName;
use crate::RoleName;
use super::backend::LocalProxyConnError;
use super::backend::PoolingBackend;
use super::conn_pool::AuthData;
use super::conn_pool::Client;
@@ -128,8 +123,8 @@ pub(crate) enum ConnInfoError {
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing password")]
MissingPassword,
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
@@ -138,14 +133,6 @@ pub(crate) enum ConnInfoError {
MalformedEndpoint,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
@@ -159,7 +146,6 @@ impl UserFacingError for ConnInfoError {
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestMonitoring,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
@@ -195,32 +181,21 @@ fn get_conn_info(
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.ok_or(ConnInfoError::MissingPassword)?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
return Err(ConnInfoError::MissingPassword);
};
let endpoint = match connection_url.host() {
@@ -272,7 +247,7 @@ pub(crate) async fn handle(
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
cancel: CancellationToken,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, ApiError> {
) -> Result<Response<Full<Bytes>>, ApiError> {
let result = handle_inner(cancel, config, &ctx, request, backend).await;
let mut response = match result {
@@ -304,7 +279,7 @@ pub(crate) async fn handle(
let mut message = e.to_string_client();
let db_error = match &e {
SqlOverHttpError::ConnectCompute(HttpConnError::PostgresConnectionError(e))
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};
@@ -529,7 +504,7 @@ async fn handle_inner(
ctx: &RequestMonitoring,
request: Request<Incoming>,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
) -> Result<Response<Full<Bytes>>, SqlOverHttpError> {
let _requeset_gauge = Metrics::get()
.proxy
.connection_requests
@@ -539,50 +514,18 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
config.tls_config.as_ref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
match conn_info.auth {
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
}
auth => {
handle_db_inner(
cancel,
config,
ctx,
request,
conn_info.conn_info,
auth,
backend,
)
.await
}
}
}
async fn handle_db_inner(
cancel: CancellationToken,
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
auth: AuthData,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
//
// Determine the destination and connection params
//
let headers = request.headers();
// TLS config should be there.
let conn_info = get_conn_info(ctx, headers, config.tls_config.as_ref())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
);
// Allow connection pooling only if explicitly requested
// or if we have decided that http pool is no longer opt-in
let allow_pool = !config.http_config.pool_options.opt_in
@@ -620,36 +563,26 @@ async fn handle_db_inner(
let authenticate_and_connect = Box::pin(
async {
let keys = match auth {
let keys = match &conn_info.auth {
AuthData::Password(pw) => {
backend
.authenticate_with_password(
ctx,
&config.authentication_config,
&conn_info.user_info,
&pw,
&conn_info.conn_info.user_info,
pw,
)
.await?
}
AuthData::Jwt(jwt) => {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await?;
ComputeCredentials {
info: conn_info.user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
}
.authenticate_with_jwt(ctx, &conn_info.conn_info.user_info, jwt)
.await?
}
};
let client = backend
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
.connect_to_compute(ctx, conn_info.conn_info, keys, !allow_pool)
.await?;
// not strictly necessary to mark success here,
// but it's just insurance for if we forget it somewhere else
@@ -707,11 +640,7 @@ async fn handle_db_inner(
let len = json_output.len();
let response = response
.body(
Full::new(Bytes::from(json_output))
.map_err(|x| match x {})
.boxed(),
)
.body(Full::new(Bytes::from(json_output)))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");
@@ -727,65 +656,6 @@ async fn handle_db_inner(
Ok(response)
}
static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&AUTHORIZATION,
&CONN_STRING,
&RAW_TEXT_OUTPUT,
&ARRAY_MODE,
&TXN_ISOLATION_LEVEL,
&TXN_READ_ONLY,
&TXN_DEFERRABLE,
];
async fn handle_auth_broker_inner(
config: &'static ProxyConfig,
ctx: &RequestMonitoring,
request: Request<Incoming>,
conn_info: ConnInfo,
jwt: String,
backend: Arc<PoolingBackend>,
) -> Result<Response<BoxBody<Bytes, hyper1::Error>>, SqlOverHttpError> {
backend
.authenticate_with_jwt(
ctx,
&config.authentication_config,
&conn_info.user_info,
jwt,
)
.await
.map_err(HttpConnError::from)?;
let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
let (mut parts, body) = request.into_parts();
let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
// todo(conradludgate): maybe auth-broker should parse these and re-serialize
// these instead just to ensure they remain normalised.
for &h in HEADERS_TO_FORWARD {
if let Some(hv) = parts.headers.remove(h) {
req = req.header(h, hv);
}
}
let req = req
.body(body)
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
let _metrics = client.metrics();
Ok(client
.inner
.send_request(req)
.await
.map_err(LocalProxyConnError::from)
.map_err(HttpConnError::from)?
.map(|b| b.boxed()))
}
impl QueryData {
async fn process(
self,
@@ -835,9 +705,7 @@ impl QueryData {
// query failed or was cancelled.
Ok(Err(error)) => {
let db_error = match &error {
SqlOverHttpError::ConnectCompute(
HttpConnError::PostgresConnectionError(e),
)
SqlOverHttpError::ConnectCompute(HttpConnError::ConnectionError(e))
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
_ => None,
};

View File

@@ -21,6 +21,7 @@ chrono.workspace = true
clap = { workspace = true, features = ["derive"] }
crc32c.workspace = true
fail.workspace = true
git-version.workspace = true
hex.workspace = true
humantime.workspace = true
hyper.workspace = true

View File

@@ -374,16 +374,14 @@ type JoinTaskRes = Result<anyhow::Result<()>, JoinError>;
async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
// fsync the datadir to make sure we have a consistent state on disk.
if !conf.no_sync {
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
}
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
let started = Instant::now();
utils::crashsafe::syncfs(dfd)?;
let elapsed = started.elapsed();
info!(
elapsed_ms = elapsed.as_millis(),
"syncfs data directory done"
);
info!("starting safekeeper WAL service on {}", conf.listen_pg_addr);
let pg_listener = tcp_listener::bind(conf.listen_pg_addr.clone()).map_err(|e| {

View File

@@ -15,6 +15,7 @@ const_format.workspace = true
futures.workspace = true
futures-core.workspace = true
futures-util.workspace = true
git-version.workspace = true
humantime.workspace = true
hyper = { workspace = true, features = ["full"] }
once_cell.workspace = true

View File

@@ -20,6 +20,7 @@ chrono.workspace = true
clap.workspace = true
fail.workspace = true
futures.workspace = true
git-version.workspace = true
hex.workspace = true
hyper.workspace = true
humantime.workspace = true

View File

@@ -515,7 +515,7 @@ async fn handle_tenant_timeline_passthrough(
tracing::info!("Proxying request for tenant {} ({})", tenant_id, path);
// Find the node that holds shard zero
let (node, tenant_shard_id) = service.tenant_shard0_node(tenant_id).await?;
let (node, tenant_shard_id) = service.tenant_shard0_node(tenant_id)?;
// Callers will always pass an unsharded tenant ID. Before proxying, we must
// rewrite this to a shard-aware shard zero ID.
@@ -545,10 +545,10 @@ async fn handle_tenant_timeline_passthrough(
let _timer = latency.start_timer(labels.clone());
let client = mgmt_api::Client::new(node.base_url(), service.get_config().jwt_token.as_deref());
let resp = client.get_raw(path).await.map_err(|e|
// We return 503 here because if we can't successfully send a request to the pageserver,
// either we aren't available or the pageserver is unavailable.
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
let resp = client.get_raw(path).await.map_err(|_e|
// FIXME: give APiError a proper Unavailable variant. We return 503 here because
// if we can't successfully send a request to the pageserver, we aren't available.
ApiError::ShuttingDown)?;
if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
@@ -557,19 +557,6 @@ async fn handle_tenant_timeline_passthrough(
error_counter.inc(labels);
}
// Transform 404 into 503 if we raced with a migration
if resp.status() == reqwest::StatusCode::NOT_FOUND {
// Look up node again: if we migrated it will be different
let (new_node, _tenant_shard_id) = service.tenant_shard0_node(tenant_id).await?;
if new_node.get_id() != node.get_id() {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!("Pageserver {node} returned 404, was migrated to {new_node}").into(),
));
}
}
// We have a reqest::Response, would like a http::Response
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp.status())?);
for (k, v) in resp.headers() {

View File

@@ -2,8 +2,8 @@ use std::{str::FromStr, time::Duration};
use pageserver_api::{
controller_api::{
AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest,
NodeSchedulingPolicy, TenantLocateResponseShard,
NodeAvailability, NodeDescribeResponse, NodeRegisterRequest, NodeSchedulingPolicy,
TenantLocateResponseShard,
},
shard::TenantShardId,
};
@@ -36,7 +36,7 @@ pub(crate) struct Node {
listen_pg_addr: String,
listen_pg_port: u16,
availability_zone_id: AvailabilityZone,
availability_zone_id: String,
// This cancellation token means "stop any RPCs in flight to this node, and don't start
// any more". It is not related to process shutdown.
@@ -64,8 +64,8 @@ impl Node {
}
#[allow(unused)]
pub(crate) fn get_availability_zone_id(&self) -> &AvailabilityZone {
&self.availability_zone_id
pub(crate) fn get_availability_zone_id(&self) -> &str {
self.availability_zone_id.as_str()
}
pub(crate) fn get_scheduling(&self) -> NodeSchedulingPolicy {
@@ -181,7 +181,7 @@ impl Node {
listen_http_port: u16,
listen_pg_addr: String,
listen_pg_port: u16,
availability_zone_id: AvailabilityZone,
availability_zone_id: String,
) -> Self {
Self {
id,
@@ -204,7 +204,7 @@ impl Node {
listen_http_port: self.listen_http_port as i32,
listen_pg_addr: self.listen_pg_addr.clone(),
listen_pg_port: self.listen_pg_port as i32,
availability_zone_id: self.availability_zone_id.0.clone(),
availability_zone_id: self.availability_zone_id.clone(),
}
}
@@ -219,7 +219,7 @@ impl Node {
listen_http_port: np.listen_http_port as u16,
listen_pg_addr: np.listen_pg_addr,
listen_pg_port: np.listen_pg_port as u16,
availability_zone_id: AvailabilityZone(np.availability_zone_id),
availability_zone_id: np.availability_zone_id,
cancel: CancellationToken::new(),
}
}

View File

@@ -9,7 +9,6 @@ use diesel::pg::PgConnection;
use diesel::prelude::*;
use diesel::Connection;
use itertools::Itertools;
use pageserver_api::controller_api::AvailabilityZone;
use pageserver_api::controller_api::MetadataHealthRecord;
use pageserver_api::controller_api::ShardSchedulingPolicy;
use pageserver_api::controller_api::{NodeSchedulingPolicy, PlacementPolicy};
@@ -668,8 +667,8 @@ impl Persistence {
pub(crate) async fn set_tenant_shard_preferred_azs(
&self,
preferred_azs: Vec<(TenantShardId, AvailabilityZone)>,
) -> DatabaseResult<Vec<(TenantShardId, AvailabilityZone)>> {
preferred_azs: Vec<(TenantShardId, String)>,
) -> DatabaseResult<Vec<(TenantShardId, String)>> {
use crate::schema::tenant_shards::dsl::*;
self.with_measured_conn(DatabaseOperation::SetPreferredAzs, move |conn| {
@@ -680,7 +679,7 @@ impl Persistence {
.filter(tenant_id.eq(tenant_shard_id.tenant_id.to_string()))
.filter(shard_number.eq(tenant_shard_id.shard_number.0 as i32))
.filter(shard_count.eq(tenant_shard_id.shard_count.literal() as i32))
.set(preferred_az_id.eq(preferred_az.0.clone()))
.set(preferred_az_id.eq(preferred_az))
.execute(conn)?;
if updated == 1 {

View File

@@ -463,7 +463,7 @@ impl Reconciler {
for (timeline_id, baseline_lsn) in &baseline {
match latest.get(timeline_id) {
Some(latest_lsn) => {
tracing::info!(timeline_id = %timeline_id, "🕑 LSN origin {baseline_lsn} vs destination {latest_lsn}");
tracing::info!("🕑 LSN origin {baseline_lsn} vs destination {latest_lsn}");
if latest_lsn < baseline_lsn {
any_behind = true;
}
@@ -541,8 +541,6 @@ impl Reconciler {
}
}
pausable_failpoint!("reconciler-live-migrate-pre-generation-inc");
// Increment generation before attaching to new pageserver
self.generation = Some(
self.persistence
@@ -619,8 +617,6 @@ impl Reconciler {
},
);
pausable_failpoint!("reconciler-live-migrate-post-detach");
tracing::info!("🔁 Switching to AttachedSingle mode on node {dest_ps}",);
let dest_final_conf = build_location_config(
&self.shard,

View File

@@ -1,8 +1,8 @@
use crate::{node::Node, tenant_shard::TenantShard};
use itertools::Itertools;
use pageserver_api::{controller_api::AvailabilityZone, models::PageserverUtilization};
use pageserver_api::models::PageserverUtilization;
use serde::Serialize;
use std::{collections::HashMap, fmt::Debug};
use std::collections::HashMap;
use utils::{http::error::ApiError, id::NodeId};
/// Scenarios in which we cannot find a suitable location for a tenant shard
@@ -27,230 +27,17 @@ pub enum MaySchedule {
}
#[derive(Serialize)]
pub(crate) struct SchedulerNode {
struct SchedulerNode {
/// How many shards are currently scheduled on this node, via their [`crate::tenant_shard::IntentState`].
shard_count: usize,
/// How many shards are currently attached on this node, via their [`crate::tenant_shard::IntentState`].
attached_shard_count: usize,
/// Availability zone id in which the node resides
az: AvailabilityZone,
/// Whether this node is currently elegible to have new shards scheduled (this is derived
/// from a node's availability state and scheduling policy).
may_schedule: MaySchedule,
}
pub(crate) trait NodeSchedulingScore: Debug + Ord + Copy + Sized {
fn generate(
node_id: &NodeId,
node: &mut SchedulerNode,
preferred_az: &Option<AvailabilityZone>,
context: &ScheduleContext,
) -> Option<Self>;
fn is_overloaded(&self) -> bool;
fn node_id(&self) -> NodeId;
}
pub(crate) trait ShardTag {
type Score: NodeSchedulingScore;
}
pub(crate) struct AttachedShardTag {}
impl ShardTag for AttachedShardTag {
type Score = NodeAttachmentSchedulingScore;
}
pub(crate) struct SecondaryShardTag {}
impl ShardTag for SecondaryShardTag {
type Score = NodeSecondarySchedulingScore;
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
enum AzMatch {
Yes,
No,
Unknown,
}
impl AzMatch {
fn new(node_az: &AvailabilityZone, shard_preferred_az: Option<&AvailabilityZone>) -> Self {
match shard_preferred_az {
Some(preferred_az) if preferred_az == node_az => Self::Yes,
Some(_preferred_az) => Self::No,
None => Self::Unknown,
}
}
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
struct AttachmentAzMatch(AzMatch);
impl Ord for AttachmentAzMatch {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// Lower scores indicate a more suitable node.
// Note that we prefer a node for which we don't have
// info to a node which we are certain doesn't match the
// preferred AZ of the shard.
let az_match_score = |az_match: &AzMatch| match az_match {
AzMatch::Yes => 0,
AzMatch::Unknown => 1,
AzMatch::No => 2,
};
az_match_score(&self.0).cmp(&az_match_score(&other.0))
}
}
impl PartialOrd for AttachmentAzMatch {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
struct SecondaryAzMatch(AzMatch);
impl Ord for SecondaryAzMatch {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// Lower scores indicate a more suitable node.
// For secondary locations we wish to avoid the preferred AZ
// of the shard.
let az_match_score = |az_match: &AzMatch| match az_match {
AzMatch::No => 0,
AzMatch::Unknown => 1,
AzMatch::Yes => 2,
};
az_match_score(&self.0).cmp(&az_match_score(&other.0))
}
}
impl PartialOrd for SecondaryAzMatch {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
/// Scheduling score of a given node for shard attachments.
/// Lower scores indicate more suitable nodes.
/// Ordering is given by member declaration order (top to bottom).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(crate) struct NodeAttachmentSchedulingScore {
/// The number of shards belonging to the tenant currently being
/// scheduled that are attached to this node.
affinity_score: AffinityScore,
/// Flag indicating whether this node matches the preferred AZ
/// of the shard. For equal affinity scores, nodes in the matching AZ
/// are considered first.
az_match: AttachmentAzMatch,
/// Size of [`ScheduleContext::attached_nodes`] for the current node.
/// This normally tracks the number of attached shards belonging to the
/// tenant being scheduled that are already on this node.
attached_shards_in_context: usize,
/// Utilisation score that combines shard count and disk utilisation
utilization_score: u64,
/// Total number of shards attached to this node. When nodes have identical utilisation, this
/// acts as an anti-affinity between attached shards.
total_attached_shard_count: usize,
/// Convenience to make selection deterministic in tests and empty systems
node_id: NodeId,
}
impl NodeSchedulingScore for NodeAttachmentSchedulingScore {
fn generate(
node_id: &NodeId,
node: &mut SchedulerNode,
preferred_az: &Option<AvailabilityZone>,
context: &ScheduleContext,
) -> Option<Self> {
let utilization = match &mut node.may_schedule {
MaySchedule::Yes(u) => u,
MaySchedule::No => {
return None;
}
};
Some(Self {
affinity_score: context
.nodes
.get(node_id)
.copied()
.unwrap_or(AffinityScore::FREE),
az_match: AttachmentAzMatch(AzMatch::new(&node.az, preferred_az.as_ref())),
attached_shards_in_context: context.attached_nodes.get(node_id).copied().unwrap_or(0),
utilization_score: utilization.cached_score(),
total_attached_shard_count: node.attached_shard_count,
node_id: *node_id,
})
}
fn is_overloaded(&self) -> bool {
PageserverUtilization::is_overloaded(self.utilization_score)
}
fn node_id(&self) -> NodeId {
self.node_id
}
}
/// Scheduling score of a given node for shard secondaries.
/// Lower scores indicate more suitable nodes.
/// Ordering is given by member declaration order (top to bottom).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(crate) struct NodeSecondarySchedulingScore {
/// Flag indicating whether this node matches the preferred AZ
/// of the shard. For secondary locations we wish to avoid nodes in.
/// the preferred AZ of the shard, since that's where the attached location
/// should be scheduled and having the secondary in the same AZ is bad for HA.
az_match: SecondaryAzMatch,
/// The number of shards belonging to the tenant currently being
/// scheduled that are attached to this node.
affinity_score: AffinityScore,
/// Utilisation score that combines shard count and disk utilisation
utilization_score: u64,
/// Total number of shards attached to this node. When nodes have identical utilisation, this
/// acts as an anti-affinity between attached shards.
total_attached_shard_count: usize,
/// Convenience to make selection deterministic in tests and empty systems
node_id: NodeId,
}
impl NodeSchedulingScore for NodeSecondarySchedulingScore {
fn generate(
node_id: &NodeId,
node: &mut SchedulerNode,
preferred_az: &Option<AvailabilityZone>,
context: &ScheduleContext,
) -> Option<Self> {
let utilization = match &mut node.may_schedule {
MaySchedule::Yes(u) => u,
MaySchedule::No => {
return None;
}
};
Some(Self {
az_match: SecondaryAzMatch(AzMatch::new(&node.az, preferred_az.as_ref())),
affinity_score: context
.nodes
.get(node_id)
.copied()
.unwrap_or(AffinityScore::FREE),
utilization_score: utilization.cached_score(),
total_attached_shard_count: node.attached_shard_count,
node_id: *node_id,
})
}
fn is_overloaded(&self) -> bool {
PageserverUtilization::is_overloaded(self.utilization_score)
}
fn node_id(&self) -> NodeId {
self.node_id
}
}
impl PartialEq for SchedulerNode {
fn eq(&self, other: &Self) -> bool {
let may_schedule_matches = matches!(
@@ -261,7 +48,6 @@ impl PartialEq for SchedulerNode {
may_schedule_matches
&& self.shard_count == other.shard_count
&& self.attached_shard_count == other.attached_shard_count
&& self.az == other.az
}
}
@@ -376,7 +162,6 @@ impl Scheduler {
shard_count: 0,
attached_shard_count: 0,
may_schedule: node.may_schedule(),
az: node.get_availability_zone_id().clone(),
},
);
}
@@ -403,7 +188,6 @@ impl Scheduler {
shard_count: 0,
attached_shard_count: 0,
may_schedule: node.may_schedule(),
az: node.get_availability_zone_id().clone(),
},
);
}
@@ -582,7 +366,6 @@ impl Scheduler {
shard_count: 0,
attached_shard_count: 0,
may_schedule: node.may_schedule(),
az: node.get_availability_zone_id().clone(),
});
}
}
@@ -623,29 +406,6 @@ impl Scheduler {
node.and_then(|(node_id, may_schedule)| if may_schedule { Some(node_id) } else { None })
}
/// Compute a schedulling score for each node that the scheduler knows of
/// minus a set of hard excluded nodes.
fn compute_node_scores<Score>(
&mut self,
hard_exclude: &[NodeId],
preferred_az: &Option<AvailabilityZone>,
context: &ScheduleContext,
) -> Vec<Score>
where
Score: NodeSchedulingScore,
{
self.nodes
.iter_mut()
.filter_map(|(k, v)| {
if hard_exclude.contains(k) {
None
} else {
Score::generate(k, v, preferred_az, context)
}
})
.collect()
}
/// hard_exclude: it is forbidden to use nodes in this list, typically becacuse they
/// are already in use by this shard -- we use this to avoid picking the same node
/// as both attached and secondary location. This is a hard constraint: if we cannot
@@ -655,18 +415,29 @@ impl Scheduler {
/// to their anti-affinity score. We use this to prefeer to avoid placing shards in
/// the same tenant on the same node. This is a soft constraint: the context will never
/// cause us to fail to schedule a shard.
pub(crate) fn schedule_shard<Tag: ShardTag>(
pub(crate) fn schedule_shard(
&mut self,
hard_exclude: &[NodeId],
preferred_az: &Option<AvailabilityZone>,
context: &ScheduleContext,
) -> Result<NodeId, ScheduleError> {
if self.nodes.is_empty() {
return Err(ScheduleError::NoPageservers);
}
let mut scores =
self.compute_node_scores::<Tag::Score>(hard_exclude, preferred_az, context);
let mut scores: Vec<(NodeId, AffinityScore, u64, usize)> = self
.nodes
.iter_mut()
.filter_map(|(k, v)| match &mut v.may_schedule {
MaySchedule::No => None,
MaySchedule::Yes(_) if hard_exclude.contains(k) => None,
MaySchedule::Yes(utilization) => Some((
*k,
context.nodes.get(k).copied().unwrap_or(AffinityScore::FREE),
utilization.cached_score(),
v.attached_shard_count,
)),
})
.collect();
// Exclude nodes whose utilization is critically high, if there are alternatives available. This will
// cause us to violate affinity rules if it is necessary to avoid critically overloading nodes: for example
@@ -674,18 +445,20 @@ impl Scheduler {
// overloaded.
let non_overloaded_scores = scores
.iter()
.filter(|i| !i.is_overloaded())
.filter(|i| !PageserverUtilization::is_overloaded(i.2))
.copied()
.collect::<Vec<_>>();
if !non_overloaded_scores.is_empty() {
scores = non_overloaded_scores;
}
// Sort the nodes by score. The one with the lowest scores will be the preferred node.
// Refer to [`NodeAttachmentSchedulingScore`] for attached locations and
// [`NodeSecondarySchedulingScore`] for secondary locations to understand how the nodes
// are ranked.
scores.sort();
// Sort by, in order of precedence:
// 1st: Affinity score. We should never pick a higher-score node if a lower-score node is available
// 2nd: Utilization score (this combines shard count and disk utilization)
// 3rd: Attached shard count. When nodes have identical utilization (e.g. when populating some
// empty nodes), this acts as an anti-affinity between attached shards.
// 4th: Node ID. This is a convenience to make selection deterministic in tests and empty systems.
scores.sort_by_key(|i| (i.1, i.2, i.3, i.0));
if scores.is_empty() {
// After applying constraints, no pageservers were left.
@@ -708,12 +481,12 @@ impl Scheduler {
}
// Lowest score wins
let node_id = scores.first().unwrap().node_id();
let node_id = scores.first().unwrap().0;
if !matches!(context.mode, ScheduleMode::Speculative) {
tracing::info!(
"scheduler selected node {node_id} (elegible nodes {:?}, hard exclude: {hard_exclude:?}, soft exclude: {context:?})",
scores.iter().map(|i| i.node_id().0).collect::<Vec<_>>()
scores.iter().map(|i| i.0 .0).collect::<Vec<_>>()
);
}
@@ -723,12 +496,6 @@ impl Scheduler {
Ok(node_id)
}
/// Selects any available node. This is suitable for performing background work (e.g. S3
/// deletions).
pub(crate) fn any_available_node(&mut self) -> Result<NodeId, ScheduleError> {
self.schedule_shard::<AttachedShardTag>(&[], &None, &ScheduleContext::default())
}
/// Unit test access to internal state
#[cfg(test)]
pub(crate) fn get_node_shard_count(&self, node_id: NodeId) -> usize {
@@ -745,22 +512,13 @@ impl Scheduler {
pub(crate) mod test_utils {
use crate::node::Node;
use pageserver_api::{
controller_api::{AvailabilityZone, NodeAvailability},
models::utilization::test_utilization,
};
use pageserver_api::{controller_api::NodeAvailability, models::utilization::test_utilization};
use std::collections::HashMap;
use utils::id::NodeId;
/// Test helper: synthesize the requested number of nodes, all in active state.
///
/// Node IDs start at one.
///
/// The `azs` argument specifies the list of availability zones which will be assigned
/// to nodes in round-robin fashion. If empy, a default AZ is assigned.
pub(crate) fn make_test_nodes(n: u64, azs: &[AvailabilityZone]) -> HashMap<NodeId, Node> {
let mut az_iter = azs.iter().cycle();
pub(crate) fn make_test_nodes(n: u64) -> HashMap<NodeId, Node> {
(1..n + 1)
.map(|i| {
(NodeId(i), {
@@ -770,10 +528,7 @@ pub(crate) mod test_utils {
80 + i as u16,
format!("pghost-{i}"),
5432 + i as u16,
az_iter
.next()
.cloned()
.unwrap_or(AvailabilityZone("test-az".to_string())),
"test-az".to_string(),
);
node.set_availability(NodeAvailability::Active(test_utilization::simple(0, 0)));
assert!(node.is_available());
@@ -793,7 +548,7 @@ mod tests {
use crate::tenant_shard::IntentState;
#[test]
fn scheduler_basic() -> anyhow::Result<()> {
let nodes = test_utils::make_test_nodes(2, &[]);
let nodes = test_utils::make_test_nodes(2);
let mut scheduler = Scheduler::new(nodes.values());
let mut t1_intent = IntentState::new();
@@ -801,9 +556,9 @@ mod tests {
let context = ScheduleContext::default();
let scheduled = scheduler.schedule_shard::<AttachedShardTag>(&[], &None, &context)?;
let scheduled = scheduler.schedule_shard(&[], &context)?;
t1_intent.set_attached(&mut scheduler, Some(scheduled));
let scheduled = scheduler.schedule_shard::<AttachedShardTag>(&[], &None, &context)?;
let scheduled = scheduler.schedule_shard(&[], &context)?;
t2_intent.set_attached(&mut scheduler, Some(scheduled));
assert_eq!(scheduler.get_node_shard_count(NodeId(1)), 1);
@@ -812,11 +567,7 @@ mod tests {
assert_eq!(scheduler.get_node_shard_count(NodeId(2)), 1);
assert_eq!(scheduler.get_node_attached_shard_count(NodeId(2)), 1);
let scheduled = scheduler.schedule_shard::<AttachedShardTag>(
&t1_intent.all_pageservers(),
&None,
&context,
)?;
let scheduled = scheduler.schedule_shard(&t1_intent.all_pageservers(), &context)?;
t1_intent.push_secondary(&mut scheduler, scheduled);
assert_eq!(scheduler.get_node_shard_count(NodeId(1)), 1);
@@ -856,7 +607,7 @@ mod tests {
#[test]
/// Test the PageserverUtilization's contribution to scheduling algorithm
fn scheduler_utilization() {
let mut nodes = test_utils::make_test_nodes(3, &[]);
let mut nodes = test_utils::make_test_nodes(3);
let mut scheduler = Scheduler::new(nodes.values());
// Need to keep these alive because they contribute to shard counts via RAII
@@ -870,9 +621,7 @@ mod tests {
scheduler: &mut Scheduler,
context: &ScheduleContext,
) {
let scheduled = scheduler
.schedule_shard::<AttachedShardTag>(&[], &None, context)
.unwrap();
let scheduled = scheduler.schedule_shard(&[], context).unwrap();
let mut intent = IntentState::new();
intent.set_attached(scheduler, Some(scheduled));
scheduled_intents.push(intent);
@@ -980,98 +729,4 @@ mod tests {
intent.clear(&mut scheduler);
}
}
#[test]
/// A simple test that showcases AZ-aware scheduling and its interaction with
/// affinity scores.
fn az_scheduling() {
let az_a_tag = AvailabilityZone("az-a".to_string());
let az_b_tag = AvailabilityZone("az-b".to_string());
let nodes = test_utils::make_test_nodes(3, &[az_a_tag.clone(), az_b_tag.clone()]);
let mut scheduler = Scheduler::new(nodes.values());
// Need to keep these alive because they contribute to shard counts via RAII
let mut scheduled_intents = Vec::new();
let mut context = ScheduleContext::default();
fn assert_scheduler_chooses<Tag: ShardTag>(
expect_node: NodeId,
preferred_az: Option<AvailabilityZone>,
scheduled_intents: &mut Vec<IntentState>,
scheduler: &mut Scheduler,
context: &mut ScheduleContext,
) {
let scheduled = scheduler
.schedule_shard::<Tag>(&[], &preferred_az, context)
.unwrap();
let mut intent = IntentState::new();
intent.set_attached(scheduler, Some(scheduled));
scheduled_intents.push(intent);
assert_eq!(scheduled, expect_node);
context.avoid(&[scheduled]);
}
assert_scheduler_chooses::<AttachedShardTag>(
NodeId(1),
Some(az_a_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
// Node 2 and 3 have affinity score equal to 0, but node 3
// is in "az-a" so we prefer that.
assert_scheduler_chooses::<AttachedShardTag>(
NodeId(3),
Some(az_a_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
// Node 2 is not in "az-a", but it has the lowest affinity so we prefer that.
assert_scheduler_chooses::<AttachedShardTag>(
NodeId(2),
Some(az_a_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
// Avoid nodes in "az-a" for the secondary location.
assert_scheduler_chooses::<SecondaryShardTag>(
NodeId(2),
Some(az_a_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
// Avoid nodes in "az-b" for the secondary location.
// Nodes 1 and 3 are identically loaded, so prefer the lowest node id.
assert_scheduler_chooses::<SecondaryShardTag>(
NodeId(1),
Some(az_b_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
// Avoid nodes in "az-b" for the secondary location.
// Node 3 has lower affinity score than 1, so prefer that.
assert_scheduler_chooses::<SecondaryShardTag>(
NodeId(3),
Some(az_b_tag.clone()),
&mut scheduled_intents,
&mut scheduler,
&mut context,
);
for mut intent in scheduled_intents {
intent.clear(&mut scheduler);
}
}
}

View File

@@ -1265,8 +1265,6 @@ impl Service {
#[cfg(feature = "testing")]
{
use pageserver_api::controller_api::AvailabilityZone;
// Hack: insert scheduler state for all nodes referenced by shards, as compatibility
// tests only store the shards, not the nodes. The nodes will be loaded shortly
// after when pageservers start up and register.
@@ -1284,7 +1282,7 @@ impl Service {
123,
"".to_string(),
123,
AvailabilityZone("test_az".to_string()),
"test_az".to_string(),
);
scheduler.node_upsert(&node);
@@ -2101,7 +2099,7 @@ impl Service {
let az_id = locked
.nodes
.get(&resp.node_id)
.map(|n| n.get_availability_zone_id().clone())?;
.map(|n| n.get_availability_zone_id().to_string())?;
Some((resp.shard_id, az_id))
})
@@ -2631,7 +2629,7 @@ impl Service {
let scheduler = &mut locked.scheduler;
// Right now we only perform the operation on a single node without parallelization
// TODO fan out the operation to multiple nodes for better performance
let node_id = scheduler.any_available_node()?;
let node_id = scheduler.schedule_shard(&[], &ScheduleContext::default())?;
let node = locked
.nodes
.get(&node_id)
@@ -2817,7 +2815,7 @@ impl Service {
// Pick an arbitrary node to use for remote deletions (does not have to be where the tenant
// was attached, just has to be able to see the S3 content)
let node_id = scheduler.any_available_node()?;
let node_id = scheduler.schedule_shard(&[], &ScheduleContext::default())?;
let node = nodes
.get(&node_id)
.expect("Pageservers may not be deleted while lock is active");
@@ -3508,66 +3506,34 @@ impl Service {
/// When you need to send an HTTP request to the pageserver that holds shard0 of a tenant, this
/// function looks up and returns node. If the tenant isn't found, returns Err(ApiError::NotFound)
pub(crate) async fn tenant_shard0_node(
pub(crate) fn tenant_shard0_node(
&self,
tenant_id: TenantId,
) -> Result<(Node, TenantShardId), ApiError> {
// Look up in-memory state and maybe use the node from there.
{
let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, shard)) = locked
.tenants
.range(TenantShardId::tenant_range(tenant_id))
.next()
else {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant {tenant_id} not found").into(),
));
};
let Some(intent_node_id) = shard.intent.get_attached() else {
tracing::warn!(
tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),
"Shard not scheduled (policy {:?}), cannot generate pass-through URL",
shard.policy
);
return Err(ApiError::Conflict(
"Cannot call timeline API on non-attached tenant".to_string(),
));
};
if shard.reconciler.is_none() {
// Optimization: while no reconcile is in flight, we may trust our in-memory state
// to tell us which pageserver to use. Otherwise we will fall through and hit the database
let Some(node) = locked.nodes.get(intent_node_id) else {
// This should never happen
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Shard refers to nonexistent node"
)));
};
return Ok((node.clone(), *tenant_shard_id));
}
};
// Look up the latest attached pageserver location from the database
// generation state: this will reflect the progress of any ongoing migration.
// Note that it is not guaranteed to _stay_ here, our caller must still handle
// the case where they call through to the pageserver and get a 404.
let db_result = self.persistence.tenant_generations(tenant_id).await?;
let Some(ShardGenerationState {
tenant_shard_id,
generation: _,
generation_pageserver: Some(node_id),
}) = db_result.first()
let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, shard)) = locked
.tenants
.range(TenantShardId::tenant_range(tenant_id))
.next()
else {
// This can happen if we raced with a tenant deletion or a shard split. On a retry
// the caller will either succeed (shard split case), get a proper 404 (deletion case),
// or a conflict response (case where tenant was detached in background)
return Err(ApiError::ResourceUnavailable(
"Shard {} not found in database, or is not attached".into(),
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant {tenant_id} not found").into(),
));
};
let locked = self.inner.read().unwrap();
// TODO: should use the ID last published to compute_hook, rather than the intent: the intent might
// point to somewhere we haven't attached yet.
let Some(node_id) = shard.intent.get_attached() else {
tracing::warn!(
tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),
"Shard not scheduled (policy {:?}), cannot generate pass-through URL",
shard.policy
);
return Err(ApiError::Conflict(
"Cannot call timeline API on non-attached tenant".to_string(),
));
};
let Some(node) = locked.nodes.get(node_id) else {
// This should never happen
return Err(ApiError::InternalServerError(anyhow::anyhow!(
@@ -4513,7 +4479,7 @@ impl Service {
let az_id = locked
.nodes
.get(node_id)
.map(|n| n.get_availability_zone_id().clone())?;
.map(|n| n.get_availability_zone_id().to_string())?;
Some((*tid, az_id))
})
@@ -5306,7 +5272,7 @@ impl Service {
}
AvailabilityTransition::ToActive => {
tracing::info!("Node {} transition to active", node_id);
// When a node comes back online, we must reconcile any tenant that has a None observed
// When a node comes back online, we must reconcile any non-detached tenant that has a None observed
// location on the node.
for tenant_shard in locked.tenants.values_mut() {
// If a reconciliation is already in progress, rely on the previous scheduling
@@ -5316,7 +5282,9 @@ impl Service {
}
if let Some(observed_loc) = tenant_shard.observed.locations.get_mut(&node_id) {
if observed_loc.conf.is_none() {
if observed_loc.conf.is_none()
&& !matches!(tenant_shard.policy, PlacementPolicy::Detached)
{
self.maybe_reconcile_shard(tenant_shard, &new_nodes);
}
}

View File

@@ -8,14 +8,11 @@ use crate::{
metrics::{self, ReconcileCompleteLabelGroup, ReconcileOutcome},
persistence::TenantShardPersistence,
reconciler::{ReconcileUnits, ReconcilerConfig},
scheduler::{
AffinityScore, AttachedShardTag, MaySchedule, RefCountUpdate, ScheduleContext,
SecondaryShardTag,
},
scheduler::{AffinityScore, MaySchedule, RefCountUpdate, ScheduleContext},
service::ReconcileResultRequest,
};
use pageserver_api::controller_api::{
AvailabilityZone, NodeSchedulingPolicy, PlacementPolicy, ShardSchedulingPolicy,
NodeSchedulingPolicy, PlacementPolicy, ShardSchedulingPolicy,
};
use pageserver_api::{
models::{LocationConfig, LocationConfigMode, TenantConfig},
@@ -146,7 +143,7 @@ pub(crate) struct TenantShard {
// We should attempt to schedule this shard in the provided AZ to
// decrease chances of cross-AZ compute.
preferred_az_id: Option<AvailabilityZone>,
preferred_az_id: Option<String>,
}
#[derive(Default, Clone, Debug, Serialize)]
@@ -338,19 +335,19 @@ pub(crate) enum ReconcileWaitError {
Failed(TenantShardId, Arc<ReconcileError>),
}
#[derive(Eq, PartialEq, Debug, Clone)]
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct ReplaceSecondary {
old_node_id: NodeId,
new_node_id: NodeId,
}
#[derive(Eq, PartialEq, Debug, Clone)]
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct MigrateAttachment {
pub(crate) old_attached_node_id: NodeId,
pub(crate) new_attached_node_id: NodeId,
}
#[derive(Eq, PartialEq, Debug, Clone)]
#[derive(Eq, PartialEq, Debug)]
pub(crate) enum ScheduleOptimizationAction {
// Replace one of our secondary locations with a different node
ReplaceSecondary(ReplaceSecondary),
@@ -358,7 +355,7 @@ pub(crate) enum ScheduleOptimizationAction {
MigrateAttachment(MigrateAttachment),
}
#[derive(Eq, PartialEq, Debug, Clone)]
#[derive(Eq, PartialEq, Debug)]
pub(crate) struct ScheduleOptimization {
// What was the reconcile sequence when we generated this optimization? The optimization
// should only be applied if the shard's sequence is still at this value, in case other changes
@@ -540,22 +537,13 @@ impl TenantShard {
Ok((true, promote_secondary))
} else {
// Pick a fresh node: either we had no secondaries or none were schedulable
let node_id = scheduler.schedule_shard::<AttachedShardTag>(
&self.intent.secondary,
&self.preferred_az_id,
context,
)?;
let node_id = scheduler.schedule_shard(&self.intent.secondary, context)?;
tracing::debug!("Selected {} as attached", node_id);
self.intent.set_attached(scheduler, Some(node_id));
Ok((true, node_id))
}
}
#[instrument(skip_all, fields(
tenant_id=%self.tenant_shard_id.tenant_id,
shard_id=%self.tenant_shard_id.shard_slug(),
sequence=%self.sequence
))]
pub(crate) fn schedule(
&mut self,
scheduler: &mut Scheduler,
@@ -625,11 +613,7 @@ impl TenantShard {
let mut used_pageservers = vec![attached_node_id];
while self.intent.secondary.len() < secondary_count {
let node_id = scheduler.schedule_shard::<SecondaryShardTag>(
&used_pageservers,
&self.preferred_az_id,
context,
)?;
let node_id = scheduler.schedule_shard(&used_pageservers, context)?;
self.intent.push_secondary(scheduler, node_id);
used_pageservers.push(node_id);
modified = true;
@@ -642,11 +626,7 @@ impl TenantShard {
modified = true;
} else if self.intent.secondary.is_empty() {
// Populate secondary by scheduling a fresh node
let node_id = scheduler.schedule_shard::<SecondaryShardTag>(
&[],
&self.preferred_az_id,
context,
)?;
let node_id = scheduler.schedule_shard(&[], context)?;
self.intent.push_secondary(scheduler, node_id);
modified = true;
}
@@ -823,11 +803,9 @@ impl TenantShard {
// Let the scheduler suggest a node, where it would put us if we were scheduling afresh
// This implicitly limits the choice to nodes that are available, and prefers nodes
// with lower utilization.
let Ok(candidate_node) = scheduler.schedule_shard::<SecondaryShardTag>(
&self.intent.all_pageservers(),
&self.preferred_az_id,
schedule_context,
) else {
let Ok(candidate_node) =
scheduler.schedule_shard(&self.intent.all_pageservers(), schedule_context)
else {
// A scheduling error means we have no possible candidate replacements
continue;
};
@@ -1324,7 +1302,7 @@ impl TenantShard {
pending_compute_notification: false,
delayed_reconcile: false,
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),
preferred_az_id: tsp.preferred_az_id.map(AvailabilityZone),
preferred_az_id: tsp.preferred_az_id,
})
}
@@ -1340,28 +1318,25 @@ impl TenantShard {
config: serde_json::to_string(&self.config).unwrap(),
splitting: SplitState::default(),
scheduling_policy: serde_json::to_string(&self.scheduling_policy).unwrap(),
preferred_az_id: self.preferred_az_id.as_ref().map(|az| az.0.clone()),
preferred_az_id: self.preferred_az_id.clone(),
}
}
pub(crate) fn preferred_az(&self) -> Option<&AvailabilityZone> {
self.preferred_az_id.as_ref()
pub(crate) fn preferred_az(&self) -> Option<&str> {
self.preferred_az_id.as_deref()
}
pub(crate) fn set_preferred_az(&mut self, preferred_az_id: AvailabilityZone) {
pub(crate) fn set_preferred_az(&mut self, preferred_az_id: String) {
self.preferred_az_id = Some(preferred_az_id);
}
}
#[cfg(test)]
pub(crate) mod tests {
use std::{cell::RefCell, rc::Rc};
use pageserver_api::{
controller_api::NodeAvailability,
shard::{ShardCount, ShardNumber},
};
use rand::{rngs::StdRng, SeedableRng};
use utils::id::TenantId;
use crate::scheduler::test_utils::make_test_nodes;
@@ -1390,11 +1365,7 @@ pub(crate) mod tests {
)
}
fn make_test_tenant(
policy: PlacementPolicy,
shard_count: ShardCount,
preferred_az: Option<AvailabilityZone>,
) -> Vec<TenantShard> {
fn make_test_tenant(policy: PlacementPolicy, shard_count: ShardCount) -> Vec<TenantShard> {
let tenant_id = TenantId::generate();
(0..shard_count.count())
@@ -1406,7 +1377,7 @@ pub(crate) mod tests {
shard_number,
shard_count,
};
let mut ts = TenantShard::new(
TenantShard::new(
tenant_shard_id,
ShardIdentity::new(
shard_number,
@@ -1415,13 +1386,7 @@ pub(crate) mod tests {
)
.unwrap(),
policy.clone(),
);
if let Some(az) = &preferred_az {
ts.set_preferred_az(az.clone());
}
ts
)
})
.collect()
}
@@ -1432,7 +1397,7 @@ pub(crate) mod tests {
fn tenant_ha_scheduling() -> anyhow::Result<()> {
// Start with three nodes. Our tenant will only use two. The third one is
// expected to remain unused.
let mut nodes = make_test_nodes(3, &[]);
let mut nodes = make_test_nodes(3);
let mut scheduler = Scheduler::new(nodes.values());
let mut context = ScheduleContext::default();
@@ -1484,7 +1449,7 @@ pub(crate) mod tests {
#[test]
fn intent_from_observed() -> anyhow::Result<()> {
let nodes = make_test_nodes(3, &[]);
let nodes = make_test_nodes(3);
let mut scheduler = Scheduler::new(nodes.values());
let mut tenant_shard = make_test_tenant_shard(PlacementPolicy::Attached(1));
@@ -1534,7 +1499,7 @@ pub(crate) mod tests {
#[test]
fn scheduling_mode() -> anyhow::Result<()> {
let nodes = make_test_nodes(3, &[]);
let nodes = make_test_nodes(3);
let mut scheduler = Scheduler::new(nodes.values());
let mut tenant_shard = make_test_tenant_shard(PlacementPolicy::Attached(1));
@@ -1559,7 +1524,7 @@ pub(crate) mod tests {
#[test]
fn optimize_attachment() -> anyhow::Result<()> {
let nodes = make_test_nodes(3, &[]);
let nodes = make_test_nodes(3);
let mut scheduler = Scheduler::new(nodes.values());
let mut shard_a = make_test_tenant_shard(PlacementPolicy::Attached(1));
@@ -1626,7 +1591,7 @@ pub(crate) mod tests {
#[test]
fn optimize_secondary() -> anyhow::Result<()> {
let nodes = make_test_nodes(4, &[]);
let nodes = make_test_nodes(4);
let mut scheduler = Scheduler::new(nodes.values());
let mut shard_a = make_test_tenant_shard(PlacementPolicy::Attached(1));
@@ -1672,14 +1637,12 @@ pub(crate) mod tests {
// Optimize til quiescent: this emulates what Service::optimize_all does, when
// called repeatedly in the background.
// Returns the applied optimizations
fn optimize_til_idle(
nodes: &HashMap<NodeId, Node>,
scheduler: &mut Scheduler,
shards: &mut [TenantShard],
) -> Vec<ScheduleOptimization> {
) {
let mut loop_n = 0;
let mut optimizations = Vec::default();
loop {
let mut schedule_context = ScheduleContext::default();
let mut any_changed = false;
@@ -1694,7 +1657,6 @@ pub(crate) mod tests {
for shard in shards.iter_mut() {
let optimization = shard.optimize_attachment(nodes, &schedule_context);
if let Some(optimization) = optimization {
optimizations.push(optimization.clone());
shard.apply_optimization(scheduler, optimization);
any_changed = true;
break;
@@ -1702,7 +1664,6 @@ pub(crate) mod tests {
let optimization = shard.optimize_secondary(scheduler, &schedule_context);
if let Some(optimization) = optimization {
optimizations.push(optimization.clone());
shard.apply_optimization(scheduler, optimization);
any_changed = true;
break;
@@ -1717,22 +1678,20 @@ pub(crate) mod tests {
loop_n += 1;
assert!(loop_n < 1000);
}
optimizations
}
/// Test the balancing behavior of shard scheduling: that it achieves a balance, and
/// that it converges.
#[test]
fn optimize_add_nodes() -> anyhow::Result<()> {
let nodes = make_test_nodes(4, &[]);
let nodes = make_test_nodes(4);
// Only show the scheduler a couple of nodes
let mut scheduler = Scheduler::new([].iter());
scheduler.node_upsert(nodes.get(&NodeId(1)).unwrap());
scheduler.node_upsert(nodes.get(&NodeId(2)).unwrap());
let mut shards = make_test_tenant(PlacementPolicy::Attached(1), ShardCount::new(4), None);
let mut shards = make_test_tenant(PlacementPolicy::Attached(1), ShardCount::new(4));
let mut schedule_context = ScheduleContext::default();
for shard in &mut shards {
assert!(shard
@@ -1771,191 +1730,4 @@ pub(crate) mod tests {
Ok(())
}
/// Test that initial shard scheduling is optimal. By optimal we mean
/// that the optimizer cannot find a way to improve it.
///
/// This test is an example of the scheduling issue described in
/// https://github.com/neondatabase/neon/issues/8969
#[test]
fn initial_scheduling_is_optimal() -> anyhow::Result<()> {
use itertools::Itertools;
let nodes = make_test_nodes(2, &[]);
let mut scheduler = Scheduler::new([].iter());
scheduler.node_upsert(nodes.get(&NodeId(1)).unwrap());
scheduler.node_upsert(nodes.get(&NodeId(2)).unwrap());
let mut a = make_test_tenant(PlacementPolicy::Attached(1), ShardCount::new(4), None);
let a_context = Rc::new(RefCell::new(ScheduleContext::default()));
let mut b = make_test_tenant(PlacementPolicy::Attached(1), ShardCount::new(4), None);
let b_context = Rc::new(RefCell::new(ScheduleContext::default()));
let a_shards_with_context = a.iter_mut().map(|shard| (shard, a_context.clone()));
let b_shards_with_context = b.iter_mut().map(|shard| (shard, b_context.clone()));
let schedule_order = a_shards_with_context.interleave(b_shards_with_context);
for (shard, context) in schedule_order {
let context = &mut *context.borrow_mut();
shard.schedule(&mut scheduler, context).unwrap();
}
let applied_to_a = optimize_til_idle(&nodes, &mut scheduler, &mut a);
assert_eq!(applied_to_a, vec![]);
let applied_to_b = optimize_til_idle(&nodes, &mut scheduler, &mut b);
assert_eq!(applied_to_b, vec![]);
for shard in a.iter_mut().chain(b.iter_mut()) {
shard.intent.clear(&mut scheduler);
}
Ok(())
}
#[test]
fn random_az_shard_scheduling() -> anyhow::Result<()> {
use rand::seq::SliceRandom;
for seed in 0..50 {
eprintln!("Running test with seed {seed}");
let mut rng = StdRng::seed_from_u64(seed);
let az_a_tag = AvailabilityZone("az-a".to_string());
let az_b_tag = AvailabilityZone("az-b".to_string());
let azs = [az_a_tag, az_b_tag];
let nodes = make_test_nodes(4, &azs);
let mut shards_per_az: HashMap<AvailabilityZone, u32> = HashMap::new();
let mut scheduler = Scheduler::new([].iter());
for node in nodes.values() {
scheduler.node_upsert(node);
}
let mut shards = Vec::default();
let mut contexts = Vec::default();
let mut az_picker = azs.iter().cycle().cloned();
for i in 0..100 {
let az = az_picker.next().unwrap();
let shard_count = i % 4 + 1;
*shards_per_az.entry(az.clone()).or_default() += shard_count;
let tenant_shards = make_test_tenant(
PlacementPolicy::Attached(1),
ShardCount::new(shard_count.try_into().unwrap()),
Some(az),
);
let context = Rc::new(RefCell::new(ScheduleContext::default()));
contexts.push(context.clone());
let with_ctx = tenant_shards
.into_iter()
.map(|shard| (shard, context.clone()));
for shard_with_ctx in with_ctx {
shards.push(shard_with_ctx);
}
}
shards.shuffle(&mut rng);
#[derive(Default, Debug)]
struct NodeStats {
attachments: u32,
secondaries: u32,
}
let mut node_stats: HashMap<NodeId, NodeStats> = HashMap::default();
let mut attachments_in_wrong_az = 0;
let mut secondaries_in_wrong_az = 0;
for (shard, context) in &mut shards {
let context = &mut *context.borrow_mut();
shard.schedule(&mut scheduler, context).unwrap();
let attached_node = shard.intent.get_attached().unwrap();
let stats = node_stats.entry(attached_node).or_default();
stats.attachments += 1;
let secondary_node = *shard.intent.get_secondary().first().unwrap();
let stats = node_stats.entry(secondary_node).or_default();
stats.secondaries += 1;
let attached_node_az = nodes
.get(&attached_node)
.unwrap()
.get_availability_zone_id();
let secondary_node_az = nodes
.get(&secondary_node)
.unwrap()
.get_availability_zone_id();
let preferred_az = shard.preferred_az().unwrap();
if attached_node_az != preferred_az {
eprintln!(
"{} attachment was scheduled in AZ {} but preferred AZ {}",
shard.tenant_shard_id, attached_node_az, preferred_az
);
attachments_in_wrong_az += 1;
}
if secondary_node_az == preferred_az {
eprintln!(
"{} secondary was scheduled in AZ {} which matches preference",
shard.tenant_shard_id, attached_node_az
);
secondaries_in_wrong_az += 1;
}
}
let mut violations = Vec::default();
if attachments_in_wrong_az > 0 {
violations.push(format!(
"{} attachments scheduled to the incorrect AZ",
attachments_in_wrong_az
));
}
if secondaries_in_wrong_az > 0 {
violations.push(format!(
"{} secondaries scheduled to the incorrect AZ",
secondaries_in_wrong_az
));
}
eprintln!(
"attachments_in_wrong_az={} secondaries_in_wrong_az={}",
attachments_in_wrong_az, secondaries_in_wrong_az
);
for (node_id, stats) in &node_stats {
let node_az = nodes.get(node_id).unwrap().get_availability_zone_id();
let ideal_attachment_load = shards_per_az.get(node_az).unwrap() / 2;
let allowed_attachment_load =
(ideal_attachment_load - 1)..(ideal_attachment_load + 2);
if !allowed_attachment_load.contains(&stats.attachments) {
violations.push(format!(
"Found {} attachments on node {}, but expected {}",
stats.attachments, node_id, ideal_attachment_load
));
}
eprintln!(
"{}: attachments={} secondaries={} ideal_attachment_load={}",
node_id, stats.attachments, stats.secondaries, ideal_attachment_load
);
}
assert!(violations.is_empty(), "{violations:?}");
for (mut shard, _ctx) in shards {
shard.intent.clear(&mut scheduler);
}
}
Ok(())
}
}

View File

@@ -8,6 +8,7 @@ license.workspace = true
aws-sdk-s3.workspace = true
either.workspace = true
anyhow.workspace = true
git-version.workspace = true
hex.workspace = true
humantime.workspace = true
serde.workspace = true

View File

@@ -1,12 +1,13 @@
use std::collections::{HashMap, HashSet};
use anyhow::Context;
use itertools::Itertools;
use pageserver::tenant::checks::check_valid_layermap;
use pageserver::tenant::layer_map::LayerMap;
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver_api::shard::ShardIndex;
use tokio_util::sync::CancellationToken;
use tracing::{info, warn};
use tracing::{error, info, warn};
use utils::generation::Generation;
use utils::id::TimelineId;
@@ -28,8 +29,9 @@ pub(crate) struct TimelineAnalysis {
/// yet.
pub(crate) warnings: Vec<String>,
/// Objects whose keys were not recognized at all, i.e. not layer files, not indices, and not initdb archive.
pub(crate) unknown_keys: Vec<String>,
/// Keys not referenced in metadata: candidates for removal, but NOT NECESSARILY: beware
/// of races between reading the metadata and reading the objects.
pub(crate) garbage_keys: Vec<String>,
}
impl TimelineAnalysis {
@@ -37,7 +39,7 @@ impl TimelineAnalysis {
Self {
errors: Vec::new(),
warnings: Vec::new(),
unknown_keys: Vec::new(),
garbage_keys: Vec::new(),
}
}
@@ -57,7 +59,7 @@ pub(crate) async fn branch_cleanup_and_check_errors(
) -> TimelineAnalysis {
let mut result = TimelineAnalysis::new();
info!("Checking timeline");
info!("Checking timeline {id}");
if let Some(s3_active_branch) = s3_active_branch {
info!(
@@ -78,7 +80,7 @@ pub(crate) async fn branch_cleanup_and_check_errors(
match s3_data {
Some(s3_data) => {
result
.unknown_keys
.garbage_keys
.extend(s3_data.unknown_keys.into_iter().map(|k| k.key.to_string()));
match s3_data.blob_data {
@@ -202,10 +204,10 @@ pub(crate) async fn branch_cleanup_and_check_errors(
warn!("Timeline metadata warnings: {0:?}", result.warnings);
}
if !result.unknown_keys.is_empty() {
warn!(
"The following keys are not recognized: {0:?}",
result.unknown_keys
if !result.garbage_keys.is_empty() {
error!(
"The following keys should be removed from S3: {0:?}",
result.garbage_keys
)
}
@@ -292,10 +294,10 @@ impl TenantObjectListing {
pub(crate) struct RemoteTimelineBlobData {
pub(crate) blob_data: BlobDataParseResult,
/// Index objects that were not used when loading `blob_data`, e.g. those from old generations
// Index objects that were not used when loading `blob_data`, e.g. those from old generations
pub(crate) unused_index_keys: Vec<ListingObject>,
/// Objects whose keys were not recognized at all, i.e. not layer files, not indices
// Objects whose keys were not recognized at all, i.e. not layer files, not indices
pub(crate) unknown_keys: Vec<ListingObject>,
}
@@ -327,54 +329,11 @@ pub(crate) fn parse_layer_object_name(name: &str) -> Result<(LayerName, Generati
}
}
/// Note (<https://github.com/neondatabase/neon/issues/8872>):
/// Since we do not gurantee the order of the listing, we could list layer keys right before
/// pageserver `RemoteTimelineClient` deletes the layer files and then the index.
/// In the rare case, this would give back a transient error where the index key is missing.
///
/// To avoid generating false positive, we try streaming the listing for a second time.
pub(crate) async fn list_timeline_blobs(
remote_client: &GenericRemoteStorage,
id: TenantShardTimelineId,
root_target: &RootTarget,
) -> anyhow::Result<RemoteTimelineBlobData> {
let res = list_timeline_blobs_impl(remote_client, id, root_target).await?;
match res {
ListTimelineBlobsResult::Ready(data) => Ok(data),
ListTimelineBlobsResult::MissingIndexPart(_) => {
// Retry if index is missing.
let data = list_timeline_blobs_impl(remote_client, id, root_target)
.await?
.into_data();
Ok(data)
}
}
}
enum ListTimelineBlobsResult {
/// Blob data is ready to be intepreted.
Ready(RemoteTimelineBlobData),
/// List timeline blobs has layer files but is missing [`IndexPart`].
MissingIndexPart(RemoteTimelineBlobData),
}
impl ListTimelineBlobsResult {
/// Get the inner blob data regardless the status.
pub fn into_data(self) -> RemoteTimelineBlobData {
match self {
ListTimelineBlobsResult::Ready(data) => data,
ListTimelineBlobsResult::MissingIndexPart(data) => data,
}
}
}
/// Returns [`ListTimelineBlobsResult::MissingIndexPart`] if blob data has layer files
/// but is missing [`IndexPart`], otherwise returns [`ListTimelineBlobsResult::Ready`].
async fn list_timeline_blobs_impl(
remote_client: &GenericRemoteStorage,
id: TenantShardTimelineId,
root_target: &RootTarget,
) -> anyhow::Result<ListTimelineBlobsResult> {
let mut s3_layers = HashSet::new();
let mut errors = Vec::new();
@@ -416,28 +375,30 @@ async fn list_timeline_blobs_impl(
s3_layers.insert((new_layer, gen));
}
Err(e) => {
tracing::info!("Error parsing {maybe_layer_name} as layer name: {e}");
tracing::info!("Error parsing key {maybe_layer_name}");
errors.push(
format!("S3 list response got an object with key {key} that is not a layer name: {e}"),
);
unknown_keys.push(obj);
}
},
None => {
tracing::info!("S3 listed an unknown key: {key}");
tracing::warn!("Unknown key {key}");
errors.push(format!("S3 list response got an object with odd key {key}"));
unknown_keys.push(obj);
}
}
}
if index_part_keys.is_empty() && s3_layers.is_empty() {
tracing::debug!("Timeline is empty: expected post-deletion state.");
if initdb_archive {
tracing::info!("Timeline is post deletion but initdb archive is still present.");
}
return Ok(ListTimelineBlobsResult::Ready(RemoteTimelineBlobData {
if index_part_keys.is_empty() && s3_layers.is_empty() && initdb_archive {
tracing::debug!(
"Timeline is empty apart from initdb archive: expected post-deletion state."
);
return Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Relic,
unused_index_keys: index_part_keys,
unknown_keys,
}));
unknown_keys: Vec::new(),
});
}
// Choose the index_part with the highest generation
@@ -463,43 +424,19 @@ async fn list_timeline_blobs_impl(
match index_part_object.as_ref() {
Some(selected) => index_part_keys.retain(|k| k != selected),
None => {
// It is possible that the branch gets deleted after we got some layer files listed
// and we no longer have the index file in the listing.
errors.push(
"S3 list response got no index_part.json file but still has layer files"
.to_string(),
);
return Ok(ListTimelineBlobsResult::MissingIndexPart(
RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Incorrect { errors, s3_layers },
unused_index_keys: index_part_keys,
unknown_keys,
},
));
errors.push("S3 list response got no index_part.json file".to_string());
}
}
if let Some(index_part_object_key) = index_part_object.as_ref() {
let index_part_bytes =
match download_object_with_retries(remote_client, &index_part_object_key.key).await {
Ok(index_part_bytes) => index_part_bytes,
Err(e) => {
// It is possible that the branch gets deleted in-between we list the objects
// and we download the index part file.
errors.push(format!("failed to download index_part.json: {e}"));
return Ok(ListTimelineBlobsResult::MissingIndexPart(
RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Incorrect { errors, s3_layers },
unused_index_keys: index_part_keys,
unknown_keys,
},
));
}
};
download_object_with_retries(remote_client, &index_part_object_key.key)
.await
.context("index_part.json download")?;
match serde_json::from_slice(&index_part_bytes) {
Ok(index_part) => {
return Ok(ListTimelineBlobsResult::Ready(RemoteTimelineBlobData {
return Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Parsed {
index_part: Box::new(index_part),
index_part_generation,
@@ -507,7 +444,7 @@ async fn list_timeline_blobs_impl(
},
unused_index_keys: index_part_keys,
unknown_keys,
}))
})
}
Err(index_parse_error) => errors.push(format!(
"index_part.json body parsing error: {index_parse_error}"
@@ -521,9 +458,9 @@ async fn list_timeline_blobs_impl(
);
}
Ok(ListTimelineBlobsResult::Ready(RemoteTimelineBlobData {
Ok(RemoteTimelineBlobData {
blob_data: BlobDataParseResult::Incorrect { errors, s3_layers },
unused_index_keys: index_part_keys,
unknown_keys,
}))
})
}

View File

@@ -41,10 +41,6 @@ struct Cli {
#[arg(long)]
/// JWT token for authenticating with storage controller. Requires scope 'scrubber' or 'admin'.
controller_jwt: Option<String>,
/// If set to true, the scrubber will exit with error code on fatal error.
#[arg(long, default_value_t = false)]
exit_code: bool,
}
#[derive(Subcommand, Debug)]
@@ -207,7 +203,6 @@ async fn main() -> anyhow::Result<()> {
tenant_ids,
json,
post_to_storcon,
cli.exit_code,
)
.await
}
@@ -274,7 +269,6 @@ async fn main() -> anyhow::Result<()> {
gc_min_age,
gc_mode,
post_to_storcon,
cli.exit_code,
)
.await
}
@@ -290,7 +284,6 @@ pub async fn run_cron_job(
gc_min_age: humantime::Duration,
gc_mode: GcMode,
post_to_storcon: bool,
exit_code: bool,
) -> anyhow::Result<()> {
tracing::info!(%gc_min_age, %gc_mode, "Running pageserver-physical-gc");
pageserver_physical_gc_cmd(
@@ -308,7 +301,6 @@ pub async fn run_cron_job(
Vec::new(),
true,
post_to_storcon,
exit_code,
)
.await?;
@@ -357,7 +349,6 @@ pub async fn scan_pageserver_metadata_cmd(
tenant_shard_ids: Vec<TenantShardId>,
json: bool,
post_to_storcon: bool,
exit_code: bool,
) -> anyhow::Result<()> {
if controller_client.is_none() && post_to_storcon {
return Err(anyhow!("Posting pageserver scan health status to storage controller requires `--controller-api` and `--controller-jwt` to run"));
@@ -389,9 +380,6 @@ pub async fn scan_pageserver_metadata_cmd(
if summary.is_fatal() {
tracing::error!("Fatal scrub errors detected");
if exit_code {
std::process::exit(1);
}
} else if summary.is_empty() {
// Strictly speaking an empty bucket is a valid bucket, but if someone ran the
// scrubber they were likely expecting to scan something, and if we see no timelines
@@ -403,9 +391,6 @@ pub async fn scan_pageserver_metadata_cmd(
.prefix_in_bucket
.unwrap_or("<none>".to_string())
);
if exit_code {
std::process::exit(1);
}
}
Ok(())

Some files were not shown because too many files have changed in this diff Show More