mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-03 18:50:38 +00:00
Compare commits
27 Commits
installed_
...
proxy-http
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe8b93ab9d | ||
|
|
7e3e7f1cca | ||
|
|
0b0ed662d9 | ||
|
|
50bd65769f | ||
|
|
90534b1745 | ||
|
|
99d52df475 | ||
|
|
ab5bbb445b | ||
|
|
5ef805e12c | ||
|
|
091a175a3e | ||
|
|
326cd80f0d | ||
|
|
6baf1aae33 | ||
|
|
184935619e | ||
|
|
b2ecbf3e80 | ||
|
|
53147b51f9 | ||
|
|
006d9dfb6b | ||
|
|
1f7904c917 | ||
|
|
07c714343f | ||
|
|
264c34dfb7 | ||
|
|
9dd80b9b4c | ||
|
|
c2623ffef4 | ||
|
|
426b1c5f08 | ||
|
|
306094a87d | ||
|
|
d3464584a6 | ||
|
|
878135fe9c | ||
|
|
75434060a5 | ||
|
|
721803a0e7 | ||
|
|
108a211917 |
@@ -218,6 +218,9 @@ runs:
|
||||
name: compatibility-snapshot-${{ runner.arch }}-${{ inputs.build_type }}-pg${{ inputs.pg_version }}
|
||||
# Directory is created by test_compatibility.py::test_create_snapshot, keep the path in sync with the test
|
||||
path: /tmp/test_output/compatibility_snapshot_pg${{ inputs.pg_version }}/
|
||||
# The lack of compatibility snapshot shouldn't fail the job
|
||||
# (for example if we didn't run the test for non build-and-test workflow)
|
||||
skip-if-does-not-exist: true
|
||||
|
||||
- name: Upload test results
|
||||
if: ${{ !cancelled() }}
|
||||
|
||||
18
.github/actions/upload/action.yml
vendored
18
.github/actions/upload/action.yml
vendored
@@ -7,6 +7,10 @@ inputs:
|
||||
path:
|
||||
description: "A directory or file to upload"
|
||||
required: true
|
||||
skip-if-does-not-exist:
|
||||
description: "Allow to skip if path doesn't exist, fail otherwise"
|
||||
default: false
|
||||
required: false
|
||||
prefix:
|
||||
description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'"
|
||||
required: false
|
||||
@@ -15,10 +19,12 @@ runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Prepare artifact
|
||||
id: prepare-artifact
|
||||
shell: bash -euxo pipefail {0}
|
||||
env:
|
||||
SOURCE: ${{ inputs.path }}
|
||||
ARCHIVE: /tmp/uploads/${{ inputs.name }}.tar.zst
|
||||
SKIP_IF_DOES_NOT_EXIST: ${{ inputs.skip-if-does-not-exist }}
|
||||
run: |
|
||||
mkdir -p $(dirname $ARCHIVE)
|
||||
|
||||
@@ -33,14 +39,22 @@ runs:
|
||||
elif [ -f ${SOURCE} ]; then
|
||||
time tar -cf ${ARCHIVE} --zstd ${SOURCE}
|
||||
elif ! ls ${SOURCE} > /dev/null 2>&1; then
|
||||
echo >&2 "${SOURCE} does not exist"
|
||||
exit 2
|
||||
if [ "${SKIP_IF_DOES_NOT_EXIST}" = "true" ]; then
|
||||
echo 'SKIPPED=true' >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
else
|
||||
echo >&2 "${SOURCE} does not exist"
|
||||
exit 2
|
||||
fi
|
||||
else
|
||||
echo >&2 "${SOURCE} is neither a directory nor a file, do not know how to handle it"
|
||||
exit 3
|
||||
fi
|
||||
|
||||
echo 'SKIPPED=false' >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Upload artifact
|
||||
if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }}
|
||||
shell: bash -euxo pipefail {0}
|
||||
env:
|
||||
SOURCE: ${{ inputs.path }}
|
||||
|
||||
19
.github/workflows/build_and_test.yml
vendored
19
.github/workflows/build_and_test.yml
vendored
@@ -193,16 +193,15 @@ jobs:
|
||||
with:
|
||||
submodules: true
|
||||
|
||||
# Disabled for now
|
||||
# - name: Restore cargo deps cache
|
||||
# id: cache_cargo
|
||||
# uses: actions/cache@v4
|
||||
# with:
|
||||
# path: |
|
||||
# !~/.cargo/registry/src
|
||||
# ~/.cargo/git/
|
||||
# target/
|
||||
# key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-clippy-${{ hashFiles('rust-toolchain.toml') }}-${{ hashFiles('Cargo.lock') }}
|
||||
- name: Cache cargo deps
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
!~/.cargo/registry/src
|
||||
~/.cargo/git
|
||||
target
|
||||
key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust
|
||||
|
||||
# Some of our rust modules use FFI and need those to be checked
|
||||
- name: Get postgres headers
|
||||
|
||||
41
.github/workflows/report-workflow-stats.yml
vendored
Normal file
41
.github/workflows/report-workflow-stats.yml
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
name: Report Workflow Stats
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows:
|
||||
- Add `external` label to issues and PRs created by external users
|
||||
- Benchmarking
|
||||
- Build and Test
|
||||
- Build and Test Locally
|
||||
- Build build-tools image
|
||||
- Check Permissions
|
||||
- Check build-tools image
|
||||
- Check neon with extra platform builds
|
||||
- Cloud Regression Test
|
||||
- Create Release Branch
|
||||
- Handle `approved-for-ci-run` label
|
||||
- Lint GitHub Workflows
|
||||
- Notify Slack channel about upcoming release
|
||||
- Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region
|
||||
- Pin build-tools image
|
||||
- Prepare benchmarking databases by restoring dumps
|
||||
- Push images to ACR
|
||||
- Test Postgres client libraries
|
||||
- Trigger E2E Tests
|
||||
- cleanup caches by a branch
|
||||
types: [completed]
|
||||
|
||||
jobs:
|
||||
gh-workflow-stats:
|
||||
name: Github Workflow Stats
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
actions: read
|
||||
steps:
|
||||
- name: Export GH Workflow Stats
|
||||
uses: neondatabase/gh-workflow-stats-action@v0.1.4
|
||||
with:
|
||||
DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }}
|
||||
DB_TABLE: "gh_workflow_stats_neon"
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GH_RUN_ID: ${{ github.event.workflow_run.id }}
|
||||
@@ -1,5 +1,6 @@
|
||||
/compute_tools/ @neondatabase/control-plane @neondatabase/compute
|
||||
/storage_controller @neondatabase/storage
|
||||
/storage_scrubber @neondatabase/storage
|
||||
/libs/pageserver_api/ @neondatabase/storage
|
||||
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
|
||||
/libs/remote_storage/ @neondatabase/storage
|
||||
|
||||
21
Cargo.lock
generated
21
Cargo.lock
generated
@@ -1228,15 +1228,12 @@ dependencies = [
|
||||
"flate2",
|
||||
"futures",
|
||||
"hyper 0.14.30",
|
||||
"metrics",
|
||||
"nix 0.27.1",
|
||||
"notify",
|
||||
"num_cpus",
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"opentelemetry_sdk",
|
||||
"postgres",
|
||||
"prometheus",
|
||||
"regex",
|
||||
"remote_storage",
|
||||
"reqwest 0.12.4",
|
||||
@@ -1823,6 +1820,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
|
||||
dependencies = [
|
||||
"base16ct 0.2.0",
|
||||
"base64ct",
|
||||
"crypto-bigint 0.5.5",
|
||||
"digest",
|
||||
"ff 0.13.0",
|
||||
@@ -1832,6 +1830,8 @@ dependencies = [
|
||||
"pkcs8 0.10.2",
|
||||
"rand_core 0.6.4",
|
||||
"sec1 0.7.3",
|
||||
"serde_json",
|
||||
"serdect",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
@@ -4040,6 +4040,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"fallible-iterator",
|
||||
"postgres-protocol",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5259,6 +5261,7 @@ dependencies = [
|
||||
"der 0.7.8",
|
||||
"generic-array",
|
||||
"pkcs8 0.10.2",
|
||||
"serdect",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
@@ -5513,6 +5516,16 @@ dependencies = [
|
||||
"syn 2.0.52",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serdect"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a84f14a19e9a014bb9f4512488d9829a68e04ecabffb0f9904cd1ace94598177"
|
||||
dependencies = [
|
||||
"base16ct 0.2.0",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha1"
|
||||
version = "0.10.5"
|
||||
@@ -7305,6 +7318,7 @@ dependencies = [
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"parquet",
|
||||
"postgres-types",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"prost",
|
||||
@@ -7329,6 +7343,7 @@ dependencies = [
|
||||
"time",
|
||||
"time-macros",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"toml_edit",
|
||||
|
||||
@@ -109,13 +109,30 @@ RUN apt update && \
|
||||
libcgal-dev libgdal-dev libgmp-dev libmpfr-dev libopenscenegraph-dev libprotobuf-c-dev \
|
||||
protobuf-c-compiler xsltproc
|
||||
|
||||
|
||||
# Postgis 3.5.0 requires SFCGAL 1.4+
|
||||
#
|
||||
# It would be nice to update all versions together, but we must solve the SFCGAL dependency first.
|
||||
# SFCGAL > 1.3 requires CGAL > 5.2, Bullseye's libcgal-dev is 5.2
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
mkdir -p /sfcgal && \
|
||||
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
|
||||
# and also we must check backward compatibility with older versions of PostGIS.
|
||||
#
|
||||
# Use new version only for v17
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v17") \
|
||||
export SFCGAL_VERSION=1.4.1 \
|
||||
export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \
|
||||
;; \
|
||||
"v14" | "v15" | "v16") \
|
||||
export SFCGAL_VERSION=1.3.10 \
|
||||
export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://gitlab.com/Oslandia/SFCGAL/-/archive/v1.3.10/SFCGAL-v1.3.10.tar.gz -O SFCGAL.tar.gz && \
|
||||
echo "4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 SFCGAL.tar.gz" | sha256sum --check && \
|
||||
mkdir -p /sfcgal && \
|
||||
wget https://gitlab.com/sfcgal/SFCGAL/-/archive/v${SFCGAL_VERSION}/SFCGAL-v${SFCGAL_VERSION}.tar.gz -O SFCGAL.tar.gz && \
|
||||
echo "${SFCGAL_CHECKSUM} SFCGAL.tar.gz" | sha256sum --check && \
|
||||
mkdir sfcgal-src && cd sfcgal-src && tar xzf ../SFCGAL.tar.gz --strip-components=1 -C . && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release . && make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
DESTDIR=/sfcgal make install -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
@@ -123,15 +140,27 @@ RUN case "${PG_VERSION}" in "v17") \
|
||||
|
||||
ENV PATH="/usr/local/pgsql/bin:$PATH"
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "Postgis doensn't yet support PG17 (needs 3.4.3, if not higher)" && exit 0;; \
|
||||
# Postgis 3.5.0 supports v17
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v17") \
|
||||
export POSTGIS_VERSION=3.5.0 \
|
||||
export POSTGIS_CHECKSUM=ca698a22cc2b2b3467ac4e063b43a28413f3004ddd505bdccdd74c56a647f510 \
|
||||
;; \
|
||||
"v14" | "v15" | "v16") \
|
||||
export POSTGIS_VERSION=3.3.3 \
|
||||
export POSTGIS_CHECKSUM=74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://download.osgeo.org/postgis/source/postgis-3.3.3.tar.gz -O postgis.tar.gz && \
|
||||
echo "74eb356e3f85f14233791013360881b6748f78081cc688ff9d6f0f673a762d13 postgis.tar.gz" | sha256sum --check && \
|
||||
wget https://download.osgeo.org/postgis/source/postgis-${POSTGIS_VERSION}.tar.gz -O postgis.tar.gz && \
|
||||
echo "${POSTGIS_CHECKSUM} postgis.tar.gz" | sha256sum --check && \
|
||||
mkdir postgis-src && cd postgis-src && tar xzf ../postgis.tar.gz --strip-components=1 -C . && \
|
||||
find /usr/local/pgsql -type f | sed 's|^/usr/local/pgsql/||' > /before.txt &&\
|
||||
./autogen.sh && \
|
||||
./configure --with-sfcgal=/usr/local/bin/sfcgal-config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install && \
|
||||
cd extensions/postgis && \
|
||||
make clean && \
|
||||
@@ -152,11 +181,27 @@ RUN case "${PG_VERSION}" in "v17") \
|
||||
cp /usr/local/pgsql/share/extension/address_standardizer.control /extensions/postgis && \
|
||||
cp /usr/local/pgsql/share/extension/address_standardizer_data_us.control /extensions/postgis
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
# Uses versioned libraries, i.e. libpgrouting-3.4
|
||||
# and may introduce function signature changes between releases
|
||||
# i.e. release 3.5.0 has new signature for pg_dijkstra function
|
||||
#
|
||||
# Use new version only for v17
|
||||
# last release v3.6.2 - Mar 30, 2024
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v17") \
|
||||
export PGROUTING_VERSION=3.6.2 \
|
||||
export PGROUTING_CHECKSUM=f4a1ed79d6f714e52548eca3bb8e5593c6745f1bde92eb5fb858efd8984dffa2 \
|
||||
;; \
|
||||
"v14" | "v15" | "v16") \
|
||||
export PGROUTING_VERSION=3.4.2 \
|
||||
export PGROUTING_CHECKSUM=cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouting.tar.gz && \
|
||||
echo "cac297c07d34460887c4f3b522b35c470138760fe358e351ad1db4edb6ee306e pgrouting.tar.gz" | sha256sum --check && \
|
||||
wget https://github.com/pgRouting/pgrouting/archive/v${PGROUTING_VERSION}.tar.gz -O pgrouting.tar.gz && \
|
||||
echo "${PGROUTING_CHECKSUM} pgrouting.tar.gz" | sha256sum --check && \
|
||||
mkdir pgrouting-src && cd pgrouting-src && tar xzf ../pgrouting.tar.gz --strip-components=1 -C . && \
|
||||
mkdir build && cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release .. && \
|
||||
@@ -215,10 +260,9 @@ FROM build-deps AS h3-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
mkdir -p /h3/usr/ && \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
# not version-specific
|
||||
# last release v4.1.0 - Jan 18, 2023
|
||||
RUN mkdir -p /h3/usr/ && \
|
||||
wget https://github.com/uber/h3/archive/refs/tags/v4.1.0.tar.gz -O h3.tar.gz && \
|
||||
echo "ec99f1f5974846bde64f4513cf8d2ea1b8d172d2218ab41803bf6a63532272bc h3.tar.gz" | sha256sum --check && \
|
||||
mkdir h3-src && cd h3-src && tar xzf ../h3.tar.gz --strip-components=1 -C . && \
|
||||
@@ -229,10 +273,9 @@ RUN case "${PG_VERSION}" in "v17") \
|
||||
cp -R /h3/usr / && \
|
||||
rm -rf build
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
|
||||
# not version-specific
|
||||
# last release v4.1.3 - Jul 26, 2023
|
||||
RUN wget https://github.com/zachasme/h3-pg/archive/refs/tags/v4.1.3.tar.gz -O h3-pg.tar.gz && \
|
||||
echo "5c17f09a820859ffe949f847bebf1be98511fb8f1bd86f94932512c00479e324 h3-pg.tar.gz" | sha256sum --check && \
|
||||
mkdir h3-pg-src && cd h3-pg-src && tar xzf ../h3-pg.tar.gz --strip-components=1 -C . && \
|
||||
export PATH="/usr/local/pgsql/bin:$PATH" && \
|
||||
@@ -251,11 +294,10 @@ FROM build-deps AS unit-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -O postgresql-unit.tar.gz && \
|
||||
echo "411d05beeb97e5a4abf17572bfcfbb5a68d98d1018918feff995f6ee3bb03e79 postgresql-unit.tar.gz" | sha256sum --check && \
|
||||
# not version-specific
|
||||
# last release 7.9 - Sep 15, 2024
|
||||
RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.9.tar.gz -O postgresql-unit.tar.gz && \
|
||||
echo "e46de6245dcc8b2c2ecf29873dbd43b2b346773f31dd5ce4b8315895a052b456 postgresql-unit.tar.gz" | sha256sum --check && \
|
||||
mkdir postgresql-unit-src && cd postgresql-unit-src && tar xzf ../postgresql-unit.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -302,12 +344,10 @@ FROM build-deps AS pgjwt-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
# 9742dab1b2f297ad3811120db7b21451bca2d3c9 made on 13/11/2021
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/michelp/pgjwt/archive/9742dab1b2f297ad3811120db7b21451bca2d3c9.tar.gz -O pgjwt.tar.gz && \
|
||||
echo "cfdefb15007286f67d3d45510f04a6a7a495004be5b3aecb12cda667e774203f pgjwt.tar.gz" | sha256sum --check && \
|
||||
# not version-specific
|
||||
# doesn't use releases, last commit f3d82fd - Mar 2, 2023
|
||||
RUN wget https://github.com/michelp/pgjwt/archive/f3d82fd30151e754e19ce5d6a06c71c20689ce3d.tar.gz -O pgjwt.tar.gz && \
|
||||
echo "dae8ed99eebb7593b43013f6532d772b12dfecd55548d2673f2dfd0163f6d2b9 pgjwt.tar.gz" | sha256sum --check && \
|
||||
mkdir pgjwt-src && cd pgjwt-src && tar xzf ../pgjwt.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgjwt.control
|
||||
@@ -342,10 +382,9 @@ FROM build-deps AS pg-hashids-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
|
||||
# not version-specific
|
||||
# last release v1.2.1 -Jan 12, 2018
|
||||
RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz -O pg_hashids.tar.gz && \
|
||||
echo "74576b992d9277c92196dd8d816baa2cc2d8046fe102f3dcd7f3c3febed6822a pg_hashids.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_hashids-src && cd pg_hashids-src && tar xzf ../pg_hashids.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
@@ -405,10 +444,9 @@ FROM build-deps AS ip4r-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
|
||||
# not version-specific
|
||||
# last release v2.4.2 - Jul 29, 2023
|
||||
RUN wget https://github.com/RhodiumToad/ip4r/archive/refs/tags/2.4.2.tar.gz -O ip4r.tar.gz && \
|
||||
echo "0f7b1f159974f49a47842a8ab6751aecca1ed1142b6d5e38d81b064b2ead1b4b ip4r.tar.gz" | sha256sum --check && \
|
||||
mkdir ip4r-src && cd ip4r-src && tar xzf ../ip4r.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -425,10 +463,9 @@ FROM build-deps AS prefix-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
|
||||
# not version-specific
|
||||
# last release v1.2.10 - Jul 5, 2023
|
||||
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.10.tar.gz -O prefix.tar.gz && \
|
||||
echo "4342f251432a5f6fb05b8597139d3ccde8dcf87e8ca1498e7ee931ca057a8575 prefix.tar.gz" | sha256sum --check && \
|
||||
mkdir prefix-src && cd prefix-src && tar xzf ../prefix.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -445,10 +482,9 @@ FROM build-deps AS hll-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
|
||||
# not version-specific
|
||||
# last release v2.18 - Aug 29, 2023
|
||||
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar.gz -O hll.tar.gz && \
|
||||
echo "e2f55a6f4c4ab95ee4f1b4a2b73280258c5136b161fe9d059559556079694f0e hll.tar.gz" | sha256sum --check && \
|
||||
mkdir hll-src && cd hll-src && tar xzf ../hll.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -659,11 +695,10 @@ FROM build-deps AS pg-roaringbitmap-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
# not version-specific
|
||||
# last release v0.5.4 - Jun 28, 2022
|
||||
ENV PATH="/usr/local/pgsql/bin/:$PATH"
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 extensions is not supported yet by pg_roaringbitmap. Quit" && exit 0;; \
|
||||
esac && \
|
||||
wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
|
||||
RUN wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4.tar.gz -O pg_roaringbitmap.tar.gz && \
|
||||
echo "b75201efcb1c2d1b014ec4ae6a22769cc7a224e6e406a587f5784a37b6b5a2aa pg_roaringbitmap.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_roaringbitmap-src && cd pg_roaringbitmap-src && tar xzf ../pg_roaringbitmap.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
@@ -680,12 +715,27 @@ FROM build-deps AS pg-semver-pg-build
|
||||
ARG PG_VERSION
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
# Release 0.40.0 breaks backward compatibility with previous versions
|
||||
# see release note https://github.com/theory/pg-semver/releases/tag/v0.40.0
|
||||
# Use new version only for v17
|
||||
#
|
||||
# last release v0.40.0 - Jul 22, 2024
|
||||
ENV PATH="/usr/local/pgsql/bin/:$PATH"
|
||||
RUN case "${PG_VERSION}" in "v17") \
|
||||
echo "v17 is not supported yet by pg_semver. Quit" && exit 0;; \
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v17") \
|
||||
export SEMVER_VERSION=0.40.0 \
|
||||
export SEMVER_CHECKSUM=3e50bcc29a0e2e481e7b6d2bc937cadc5f5869f55d983b5a1aafeb49f5425cfc \
|
||||
;; \
|
||||
"v14" | "v15" | "v16") \
|
||||
export SEMVER_VERSION=0.32.1 \
|
||||
export SEMVER_CHECKSUM=fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O pg_semver.tar.gz && \
|
||||
echo "fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 pg_semver.tar.gz" | sha256sum --check && \
|
||||
wget https://github.com/theory/pg-semver/archive/refs/tags/v${SEMVER_VERSION}.tar.gz -O pg_semver.tar.gz && \
|
||||
echo "${SEMVER_CHECKSUM} pg_semver.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_semver-src && cd pg_semver-src && tar xzf ../pg_semver.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install && \
|
||||
|
||||
@@ -18,11 +18,9 @@ clap.workspace = true
|
||||
flate2.workspace = true
|
||||
futures.workspace = true
|
||||
hyper0 = { workspace = true, features = ["full"] }
|
||||
metrics.workspace = true
|
||||
nix.workspace = true
|
||||
notify.workspace = true
|
||||
num_cpus.workspace = true
|
||||
once_cell.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
opentelemetry_sdk.workspace = true
|
||||
postgres.workspace = true
|
||||
@@ -41,7 +39,6 @@ tracing-subscriber.workspace = true
|
||||
tracing-utils.workspace = true
|
||||
thiserror.workspace = true
|
||||
url.workspace = true
|
||||
prometheus.workspace = true
|
||||
|
||||
compute_api.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
@@ -1121,7 +1121,6 @@ impl ComputeNode {
|
||||
self.pg_reload_conf()?;
|
||||
}
|
||||
self.post_apply_config()?;
|
||||
self.get_installed_extensions()?;
|
||||
}
|
||||
|
||||
let startup_end_time = Utc::now();
|
||||
@@ -1490,22 +1489,20 @@ LIMIT 100",
|
||||
pub fn get_installed_extensions(&self) -> Result<()> {
|
||||
let connstr = self.connstr.clone();
|
||||
|
||||
thread::spawn(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to create runtime");
|
||||
let result = rt
|
||||
.block_on(crate::installed_extensions::get_installed_extensions(
|
||||
connstr,
|
||||
))
|
||||
.expect("failed to get installed extensions");
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to create runtime");
|
||||
let result = rt
|
||||
.block_on(crate::installed_extensions::get_installed_extensions(
|
||||
connstr,
|
||||
))
|
||||
.expect("failed to get installed extensions");
|
||||
|
||||
info!(
|
||||
"{}",
|
||||
serde_json::to_string(&result).expect("failed to serialize extensions list")
|
||||
);
|
||||
});
|
||||
info!(
|
||||
"{}",
|
||||
serde_json::to_string(&result).expect("failed to serialize extensions list")
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ use crate::catalog::SchemaDumpError;
|
||||
use crate::catalog::{get_database_schema, get_dbs_and_roles};
|
||||
use crate::compute::forward_termination_signal;
|
||||
use crate::compute::{ComputeNode, ComputeState, ParsedSpec};
|
||||
use crate::installed_extensions;
|
||||
use compute_api::requests::ConfigurationRequest;
|
||||
use compute_api::responses::{ComputeStatus, ComputeStatusResponse, GenericAPIError};
|
||||
|
||||
@@ -17,8 +16,6 @@ use anyhow::Result;
|
||||
use hyper::header::CONTENT_TYPE;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, Request, Response, Server, StatusCode};
|
||||
use metrics::Encoder;
|
||||
use metrics::TextEncoder;
|
||||
use tokio::task;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing_utils::http::OtelName;
|
||||
@@ -65,28 +62,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
|
||||
Response::new(Body::from(serde_json::to_string(&metrics).unwrap()))
|
||||
}
|
||||
|
||||
// Prometheus metrics
|
||||
(&Method::GET, "/metrics") => {
|
||||
debug!("serving /metrics GET request");
|
||||
|
||||
let mut buffer = vec![];
|
||||
let metrics = installed_extensions::collect();
|
||||
let encoder = TextEncoder::new();
|
||||
encoder.encode(&metrics, &mut buffer).unwrap();
|
||||
|
||||
match Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, encoder.format_type())
|
||||
.body(Body::from(buffer))
|
||||
{
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
let msg = format!("error handling /metrics request: {err}");
|
||||
error!(msg);
|
||||
render_json_error(&msg, StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Collect Postgres current usage insights
|
||||
(&Method::GET, "/insights") => {
|
||||
info!("serving /insights GET request");
|
||||
|
||||
@@ -37,21 +37,6 @@ paths:
|
||||
schema:
|
||||
$ref: "#/components/schemas/ComputeMetrics"
|
||||
|
||||
/metrics
|
||||
get:
|
||||
tags:
|
||||
- Info
|
||||
summary: Get compute node metrics in
|
||||
description: ""
|
||||
operationId: getComputeMetrics
|
||||
responses:
|
||||
200:
|
||||
description: ComputeMetrics
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
description: Metrics in text format.
|
||||
/insights:
|
||||
get:
|
||||
tags:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use compute_api::responses::{InstalledExtension, InstalledExtensions};
|
||||
use metrics::proto::MetricFamily;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use url::Url;
|
||||
@@ -8,10 +7,6 @@ use anyhow::Result;
|
||||
use postgres::{Client, NoTls};
|
||||
use tokio::task;
|
||||
|
||||
use metrics::core::Collector;
|
||||
use metrics::{register_uint_gauge_vec, UIntGaugeVec};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
/// We don't reuse get_existing_dbs() just for code clarity
|
||||
/// and to make database listing query here more explicit.
|
||||
///
|
||||
@@ -77,40 +72,9 @@ pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtension
|
||||
}
|
||||
}
|
||||
|
||||
let res = InstalledExtensions {
|
||||
Ok(InstalledExtensions {
|
||||
extensions: extensions_map.values().cloned().collect(),
|
||||
};
|
||||
|
||||
// set the prometheus metrics
|
||||
for ext in res.extensions.iter() {
|
||||
let versions = {
|
||||
let mut vec: Vec<_> = ext.versions.iter().cloned().collect();
|
||||
vec.sort();
|
||||
vec.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
};
|
||||
|
||||
INSTALLED_EXTENSIONS
|
||||
.with_label_values(&[&ext.extname, &versions])
|
||||
.set(ext.n_databases as u64);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
})
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"installed_extensions",
|
||||
"Number of databases where extension is installed, versions passed as label",
|
||||
&["extension_name", "versions"]
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
pub fn collect() -> Vec<MetricFamily> {
|
||||
INSTALLED_EXTENSIONS.collect()
|
||||
}
|
||||
|
||||
@@ -31,9 +31,12 @@ pub enum Scope {
|
||||
/// The scope used by pageservers in upcalls to storage controller and cloud control plane
|
||||
#[serde(rename = "generations_api")]
|
||||
GenerationsApi,
|
||||
/// Allows access to control plane managment API and some storage controller endpoints.
|
||||
/// Allows access to control plane managment API and all storage controller endpoints.
|
||||
Admin,
|
||||
|
||||
/// Allows access to control plane & storage controller endpoints used in infrastructure automation (e.g. node registration)
|
||||
Infra,
|
||||
|
||||
/// Allows access to storage controller APIs used by the scrubber, to interrogate the state
|
||||
/// of a tenant & post scrub results.
|
||||
Scrubber,
|
||||
|
||||
@@ -14,14 +14,19 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
|
||||
}
|
||||
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
|
||||
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
|
||||
(Scope::Admin | Scope::SafekeeperData | Scope::GenerationsApi | Scope::Scrubber, _) => {
|
||||
Err(AuthError(
|
||||
format!(
|
||||
"JWT scope '{:?}' is ineligible for Pageserver auth",
|
||||
claims.scope
|
||||
)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
(
|
||||
Scope::Admin
|
||||
| Scope::SafekeeperData
|
||||
| Scope::GenerationsApi
|
||||
| Scope::Infra
|
||||
| Scope::Scrubber,
|
||||
_,
|
||||
) => Err(AuthError(
|
||||
format!(
|
||||
"JWT scope '{:?}' is ineligible for Pageserver auth",
|
||||
claims.scope
|
||||
)
|
||||
.into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ subtle.workspace = true
|
||||
thiserror.workspace = true
|
||||
tikv-jemallocator.workspace = true
|
||||
tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
tokio-postgres.workspace = true
|
||||
tokio-postgres = { workspace = true, features = ["with-serde_json-1"] }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
@@ -101,7 +101,7 @@ jose-jwa = "0.1.2"
|
||||
jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
|
||||
signature = "2"
|
||||
ecdsa = "0.16"
|
||||
p256 = "0.13"
|
||||
p256 = { version = "0.13", features = ["jwk"] }
|
||||
rsa = "0.9"
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
@@ -25,6 +25,10 @@ pub(crate) enum WebAuthError {
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
pub struct ConsoleRedirectBackend {
|
||||
console_uri: reqwest::Url,
|
||||
}
|
||||
|
||||
impl UserFacingError for WebAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
"Internal error".to_string()
|
||||
@@ -57,7 +61,26 @@ pub(crate) fn new_psql_session_id() -> String {
|
||||
hex::encode(rand::random::<[u8; 8]>())
|
||||
}
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
impl ConsoleRedirectBackend {
|
||||
pub fn new(console_uri: reqwest::Url) -> Self {
|
||||
Self { console_uri }
|
||||
}
|
||||
|
||||
pub(super) fn url(&self) -> &reqwest::Url {
|
||||
&self.console_uri
|
||||
}
|
||||
|
||||
pub(crate) async fn authenticate(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
authenticate(ctx, auth_config, &self.console_uri, client).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
ctx: &RequestMonitoring,
|
||||
auth_config: &'static AuthenticationConfig,
|
||||
link_uri: &reqwest::Url,
|
||||
|
||||
@@ -17,6 +17,8 @@ use crate::{
|
||||
RoleName,
|
||||
};
|
||||
|
||||
use super::ComputeCredentialKeys;
|
||||
|
||||
// TODO(conrad): make these configurable.
|
||||
const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30);
|
||||
const MIN_RENEW: Duration = Duration::from_secs(30);
|
||||
@@ -241,7 +243,7 @@ impl JwkCacheEntryLock {
|
||||
endpoint: EndpointId,
|
||||
role_name: &RoleName,
|
||||
fetch: &F,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<ComputeCredentialKeys, anyhow::Error> {
|
||||
// JWT compact form is defined to be
|
||||
// <B64(Header)> || . || <B64(Payload)> || . || <B64(Signature)>
|
||||
// where Signature = alg(<B64(Header)> || . || <B64(Payload)>);
|
||||
@@ -300,9 +302,9 @@ impl JwkCacheEntryLock {
|
||||
key => bail!("unsupported key type {key:?}"),
|
||||
};
|
||||
|
||||
let payload = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payload)
|
||||
let payload = serde_json::from_slice::<JwtPayload<'_>>(&payloadb)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
|
||||
tracing::debug!(?payload, "JWT signature valid with claims");
|
||||
@@ -327,7 +329,7 @@ impl JwkCacheEntryLock {
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(ComputeCredentialKeys::JwtPayload(payloadb))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,7 +341,7 @@ impl JwkCache {
|
||||
role_name: &RoleName,
|
||||
fetch: &F,
|
||||
jwt: &str,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
) -> Result<ComputeCredentialKeys, anyhow::Error> {
|
||||
// try with just a read lock first
|
||||
let key = (endpoint.clone(), role_name.clone());
|
||||
let entry = self.map.get(&key).as_deref().map(Arc::clone);
|
||||
|
||||
@@ -8,6 +8,7 @@ use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
pub use console_redirect::ConsoleRedirectBackend;
|
||||
pub(crate) use console_redirect::WebAuthError;
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
use local::LocalBackend;
|
||||
@@ -36,7 +37,7 @@ use crate::{
|
||||
provider::{CachedAllowedIps, CachedNodeInfo},
|
||||
Api,
|
||||
},
|
||||
stream, url,
|
||||
stream,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
|
||||
@@ -69,7 +70,7 @@ pub enum Backend<'a, T, D> {
|
||||
/// Cloud API (V2).
|
||||
ControlPlane(MaybeOwned<'a, ControlPlaneBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
ConsoleRedirect(MaybeOwned<'a, url::ApiUrl>, D),
|
||||
ConsoleRedirect(MaybeOwned<'a, ConsoleRedirectBackend>, D),
|
||||
/// Local proxy uses configured auth credentials and does not wake compute
|
||||
Local(MaybeOwned<'a, LocalBackend>),
|
||||
}
|
||||
@@ -106,9 +107,9 @@ impl std::fmt::Display for Backend<'_, (), ()> {
|
||||
#[cfg(test)]
|
||||
ControlPlaneBackend::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
|
||||
},
|
||||
Self::ConsoleRedirect(url, ()) => fmt
|
||||
Self::ConsoleRedirect(backend, ()) => fmt
|
||||
.debug_tuple("ConsoleRedirect")
|
||||
.field(&url.as_str())
|
||||
.field(&backend.url().as_str())
|
||||
.finish(),
|
||||
Self::Local(_) => fmt.debug_tuple("Local").finish(),
|
||||
}
|
||||
@@ -175,10 +176,12 @@ impl ComputeUserInfo {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(test, derive(Debug))]
|
||||
pub(crate) enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
JwtPayload(Vec<u8>),
|
||||
None,
|
||||
}
|
||||
|
||||
@@ -239,7 +242,6 @@ impl AuthenticationConfig {
|
||||
pub(crate) fn check_rate_limit(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
endpoint: &EndpointId,
|
||||
is_cleartext: bool,
|
||||
@@ -263,7 +265,7 @@ impl AuthenticationConfig {
|
||||
let limit_not_exceeded = self.rate_limiter.check(
|
||||
(
|
||||
endpoint_int,
|
||||
MaskedIp::new(ctx.peer_addr(), config.rate_limit_ip_subnet),
|
||||
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
|
||||
),
|
||||
password_weight,
|
||||
);
|
||||
@@ -337,7 +339,6 @@ async fn auth_quirks(
|
||||
let secret = if let Some(secret) = secret {
|
||||
config.check_rate_limit(
|
||||
ctx,
|
||||
config,
|
||||
secret,
|
||||
&info.endpoint,
|
||||
unauthenticated_password.is_some() || allow_cleartext,
|
||||
@@ -454,12 +455,12 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
Backend::ControlPlane(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Self::ConsoleRedirect(url, ()) => {
|
||||
Self::ConsoleRedirect(backend, ()) => {
|
||||
info!("performing web authentication");
|
||||
|
||||
let info = console_redirect::authenticate(ctx, config, &url, client).await?;
|
||||
let info = backend.authenticate(ctx, config, client).await?;
|
||||
|
||||
Backend::ConsoleRedirect(url, info)
|
||||
Backend::ConsoleRedirect(backend, info)
|
||||
}
|
||||
Self::Local(_) => {
|
||||
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
|
||||
|
||||
@@ -6,9 +6,12 @@ use compute_api::spec::LocalProxySpec;
|
||||
use dashmap::DashMap;
|
||||
use futures::future::Either;
|
||||
use proxy::{
|
||||
auth::backend::{
|
||||
jwt::JwkCache,
|
||||
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
auth::{
|
||||
self,
|
||||
backend::{
|
||||
jwt::JwkCache,
|
||||
local::{LocalBackend, JWKS_ROLE_MAP},
|
||||
},
|
||||
},
|
||||
cancellation::CancellationHandlerMain,
|
||||
config::{self, AuthenticationConfig, HttpConfig, ProxyConfig, RetryConfig},
|
||||
@@ -132,6 +135,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let args = LocalProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let auth_backend = build_auth_backend(&args)?;
|
||||
|
||||
// before we bind to any ports, write the process ID to a file
|
||||
// so that compute-ctl can find our process later
|
||||
@@ -193,6 +197,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let task = serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
http_listener,
|
||||
shutdown.clone(),
|
||||
Arc::new(CancellationHandlerMain::new(
|
||||
@@ -257,9 +262,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
|
||||
Ok(Box::leak(Box::new(ProxyConfig {
|
||||
tls_config: None,
|
||||
auth_backend: proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
|
||||
LocalBackend::new(args.compute),
|
||||
)),
|
||||
metric_collection: None,
|
||||
allow_self_signed_compute: false,
|
||||
http_config,
|
||||
@@ -286,6 +288,17 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
})))
|
||||
}
|
||||
|
||||
/// auth::Backend is created at proxy startup, and lives forever.
|
||||
fn build_auth_backend(
|
||||
args: &LocalProxyCliArgs,
|
||||
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
|
||||
let auth_backend = proxy::auth::Backend::Local(proxy::auth::backend::MaybeOwned::Owned(
|
||||
LocalBackend::new(args.compute),
|
||||
));
|
||||
|
||||
Ok(Box::leak(Box::new(auth_backend)))
|
||||
}
|
||||
|
||||
async fn refresh_config_loop(path: Utf8PathBuf, rx: Arc<Notify>) {
|
||||
loop {
|
||||
rx.notified().await;
|
||||
|
||||
@@ -10,6 +10,7 @@ use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::jwt::JwkCache;
|
||||
use proxy::auth::backend::AuthRateLimiter;
|
||||
use proxy::auth::backend::ConsoleRedirectBackend;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
@@ -311,8 +312,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let auth_backend = build_auth_backend(&args)?;
|
||||
|
||||
info!("Authentication backend: {}", config.auth_backend);
|
||||
info!("Authentication backend: {}", auth_backend);
|
||||
info!("Using region: {}", args.aws_region);
|
||||
|
||||
let region_provider =
|
||||
@@ -462,6 +464,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(proxy::proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
@@ -472,6 +475,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
@@ -506,7 +510,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
));
|
||||
}
|
||||
|
||||
if let auth::Backend::ControlPlane(api, _) = &config.auth_backend {
|
||||
if let auth::Backend::ControlPlane(api, _) = auth_backend {
|
||||
if let proxy::control_plane::provider::ControlPlaneBackend::Management(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
@@ -610,6 +614,80 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
bail!("dynamic rate limiter should be disabled");
|
||||
}
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.connect_compute_lock.parse()?;
|
||||
info!(
|
||||
?limiter,
|
||||
shards,
|
||||
?epoch,
|
||||
"Using NodeLocks (connect_compute)"
|
||||
);
|
||||
let connect_compute_locks = control_plane::locks::ApiLocks::new(
|
||||
"connect_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().proxy.connect_compute_lock,
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
accept_websockets: !args.is_auth_broker,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
|
||||
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
|
||||
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
|
||||
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
|
||||
},
|
||||
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
|
||||
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
|
||||
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
authentication_config,
|
||||
proxy_protocol_v2: args.proxy_protocol_v2,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
}));
|
||||
|
||||
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// auth::Backend is created at proxy startup, and lives forever.
|
||||
fn build_auth_backend(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<&'static auth::Backend<'static, (), ()>> {
|
||||
let auth_backend = match &args.auth_backend {
|
||||
AuthBackendType::Console => {
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
@@ -665,7 +743,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
AuthBackendType::Web => {
|
||||
let url = args.uri.parse()?;
|
||||
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(url), ())
|
||||
auth::Backend::ConsoleRedirect(MaybeOwned::Owned(ConsoleRedirectBackend::new(url)), ())
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
@@ -677,75 +755,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
}
|
||||
};
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
shards,
|
||||
limiter,
|
||||
epoch,
|
||||
timeout,
|
||||
} = args.connect_compute_lock.parse()?;
|
||||
info!(
|
||||
?limiter,
|
||||
shards,
|
||||
?epoch,
|
||||
"Using NodeLocks (connect_compute)"
|
||||
);
|
||||
let connect_compute_locks = control_plane::locks::ApiLocks::new(
|
||||
"connect_compute_lock",
|
||||
limiter,
|
||||
shards,
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().proxy.connect_compute_lock,
|
||||
)?;
|
||||
|
||||
let http_config = HttpConfig {
|
||||
accept_websockets: !args.is_auth_broker,
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
max_conns_per_endpoint: args.sql_over_http.sql_over_http_pool_max_conns_per_endpoint,
|
||||
gc_epoch: args.sql_over_http.sql_over_http_pool_gc_epoch,
|
||||
pool_shards: args.sql_over_http.sql_over_http_pool_shards,
|
||||
idle_timeout: args.sql_over_http.sql_over_http_idle_timeout,
|
||||
opt_in: args.sql_over_http.sql_over_http_pool_opt_in,
|
||||
max_total_conns: args.sql_over_http.sql_over_http_pool_max_total_conns,
|
||||
},
|
||||
cancel_set: CancelSet::new(args.sql_over_http.sql_over_http_cancel_set_shards),
|
||||
client_conn_threshold: args.sql_over_http.sql_over_http_client_conn_threshold,
|
||||
max_request_size_bytes: args.sql_over_http.sql_over_http_max_request_size_bytes,
|
||||
max_response_size_bytes: args.sql_over_http.sql_over_http_max_response_size_bytes,
|
||||
};
|
||||
let authentication_config = AuthenticationConfig {
|
||||
jwks_cache: JwkCache::default(),
|
||||
thread_pool,
|
||||
scram_protocol_timeout: args.scram_protocol_timeout,
|
||||
rate_limiter_enabled: args.auth_rate_limit_enabled,
|
||||
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
|
||||
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
|
||||
ip_allowlist_check_enabled: !args.is_private_access_proxy,
|
||||
is_auth_broker: args.is_auth_broker,
|
||||
accept_jwts: args.is_auth_broker,
|
||||
webauth_confirmation_timeout: args.webauth_confirmation_timeout,
|
||||
};
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
auth_backend,
|
||||
metric_collection,
|
||||
allow_self_signed_compute: args.allow_self_signed_compute,
|
||||
http_config,
|
||||
authentication_config,
|
||||
proxy_protocol_v2: args.proxy_protocol_v2,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute_retry_config: config::RetryConfig::parse(
|
||||
&args.connect_to_compute_retry,
|
||||
)?,
|
||||
}));
|
||||
|
||||
tokio::spawn(config.connect_compute_locks.garbage_collect_worker());
|
||||
|
||||
Ok(config)
|
||||
Ok(Box::leak(Box::new(auth_backend)))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::{jwt::JwkCache, AuthRateLimiter},
|
||||
},
|
||||
auth::backend::{jwt::JwkCache, AuthRateLimiter},
|
||||
control_plane::locks::ApiLocks,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
@@ -29,7 +26,6 @@ use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::Backend<'static, (), ()>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
|
||||
@@ -81,12 +81,12 @@ pub(crate) mod errors {
|
||||
Reason::EndpointNotFound => ErrorKind::User,
|
||||
Reason::BranchNotFound => ErrorKind::User,
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::User,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::Quota,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::Quota,
|
||||
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
|
||||
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
|
||||
Reason::RunningOperations => ErrorKind::ControlPlane,
|
||||
@@ -103,7 +103,7 @@ pub(crate) mod errors {
|
||||
} if error
|
||||
.contains("compute time quota of non-primary branches is exceeded") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
crate::error::ErrorKind::Quota
|
||||
}
|
||||
ControlPlaneError {
|
||||
http_status_code: http::StatusCode::LOCKED,
|
||||
@@ -112,7 +112,7 @@ pub(crate) mod errors {
|
||||
} if error.contains("quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
crate::error::ErrorKind::Quota
|
||||
}
|
||||
ControlPlaneError {
|
||||
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
|
||||
@@ -309,7 +309,7 @@ impl NodeInfo {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::None => &mut self.config,
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ use futures::TryFutureExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
@@ -456,7 +456,7 @@ async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
});
|
||||
body.http_status_code = status;
|
||||
|
||||
error!("console responded with an error ({status}): {body:?}");
|
||||
warn!("console responded with an error ({status}): {body:?}");
|
||||
Err(ApiError::ControlPlane(body))
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,10 @@ pub enum ErrorKind {
|
||||
#[label(rename = "serviceratelimit")]
|
||||
ServiceRateLimit,
|
||||
|
||||
/// Proxy quota limit violation
|
||||
#[label(rename = "quota")]
|
||||
Quota,
|
||||
|
||||
/// internal errors
|
||||
Service,
|
||||
|
||||
@@ -70,6 +74,7 @@ impl ErrorKind {
|
||||
ErrorKind::ClientDisconnect => "clientdisconnect",
|
||||
ErrorKind::RateLimit => "ratelimit",
|
||||
ErrorKind::ServiceRateLimit => "serviceratelimit",
|
||||
ErrorKind::Quota => "quota",
|
||||
ErrorKind::Service => "service",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "postgres",
|
||||
|
||||
@@ -35,7 +35,7 @@ use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, Instrument};
use tracing::{error, info, warn, Instrument};

use self::{
connect_compute::{connect_to_compute, TcpMechanism},
@@ -61,6 +61,7 @@ pub async fn run_until_cancelled<F: std::future::Future>(

pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandlerMain>,
@@ -95,15 +96,15 @@ pub async fn task_main(
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
warn!("per-client task finished with an error: {e:#}");
return;
}
Ok((_socket, None)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
error!("missing required proxy protocol header");
warn!("missing required proxy protocol header");
return;
}
Ok((_socket, Some(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
error!("proxy protocol header not supported");
warn!("proxy protocol header not supported");
return;
}
Ok((socket, Some(addr))) => (socket, addr.ip()),
@@ -129,6 +130,7 @@ pub async fn task_main(
let startup = Box::pin(
handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
@@ -144,7 +146,7 @@ pub async fn task_main(
Err(e) => {
// todo: log and push to ctx the error kind
ctx.set_error_kind(e.get_error_kind());
error!(parent: &span, "per-client task finished with an error: {e:#}");
warn!(parent: &span, "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
@@ -155,7 +157,7 @@ pub async fn task_main(
match p.proxy_pass().instrument(span.clone()).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}");
}
Err(ErrorSource::Compute(e)) => {
error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}");
@@ -243,8 +245,10 @@ impl ReportableError for ClientRequestError {
}
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, (), ()>,
ctx: &RequestMonitoring,
cancellation_handler: Arc<CancellationHandlerMain>,
stream: S,
@@ -285,8 +289,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let common_names = tls.map(|tls| &tls.common_names);

// Extract credentials which we're going to use for auth.
let result = config
.auth_backend
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();

@@ -71,7 +71,7 @@ impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
tracing::error!(?err, "could not cancel the query in the database");
tracing::warn!(?err, "could not cancel the query in the database");
}
res
}

@@ -6,7 +6,7 @@ use redis::{
ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult,
};
use tokio::task::JoinHandle;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};

use super::elasticache::CredentialsProvider;

@@ -89,7 +89,7 @@ impl ConnectionWithCredentialsProvider {
return Ok(());
}
Err(e) => {
error!("Error during PING: {e:?}");
warn!("Error during PING: {e:?}");
}
}
} else {
@@ -121,7 +121,7 @@ impl ConnectionWithCredentialsProvider {
info!("Connection successfully established");
}
Err(e) => {
error!("Connection is broken. Error during PING: {e:?}");
warn!("Connection is broken. Error during PING: {e:?}");
}
}
self.con = Some(con);

@@ -146,7 +146,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
{
Ok(()) => {}
Err(e) => {
tracing::error!("failed to cancel session: {e}");
tracing::warn!("failed to cancel session: {e}");
}
}
}

@@ -3,15 +3,17 @@ use std::{io, sync::Arc, time::Duration};
|
||||
use async_trait::async_trait;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use tokio::net::{lookup_host, TcpStream};
|
||||
use tracing::{field::display, info};
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tracing::{debug, field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::{local::StaticAuthRules, ComputeCredentials, ComputeUserInfo},
|
||||
check_peer_addr_is_in_list, AuthError,
|
||||
},
|
||||
compute,
|
||||
config::{AuthenticationConfig, ProxyConfig},
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
control_plane::{
|
||||
errors::{GetAuthInfoError, WakeComputeError},
|
||||
@@ -26,18 +28,21 @@ use crate::{
|
||||
retry::{CouldRetry, ShouldRetryWakeCompute},
|
||||
},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
Host,
|
||||
EndpointId, Host,
|
||||
};
|
||||
|
||||
use super::{
|
||||
conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool},
|
||||
http_conn_pool::{self, poll_http2_client},
|
||||
local_conn_pool::{self, LocalClient, LocalConnPool},
|
||||
};
|
||||
|
||||
pub(crate) struct PoolingBackend {
|
||||
pub(crate) http_conn_pool: Arc<super::http_conn_pool::GlobalConnPool>,
|
||||
pub(crate) local_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
pub(crate) config: &'static ProxyConfig,
|
||||
pub(crate) auth_backend: &'static crate::auth::Backend<'static, (), ()>,
|
||||
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
}
|
||||
|
||||
@@ -45,18 +50,13 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_password(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
password: &[u8],
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = user_info.clone();
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|()| user_info.clone());
|
||||
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
if config.ip_allowlist_check_enabled
|
||||
if self.config.authentication_config.ip_allowlist_check_enabled
|
||||
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
|
||||
{
|
||||
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
@@ -75,7 +75,6 @@ impl PoolingBackend {
|
||||
let secret = match cached_secret.value.clone() {
|
||||
Some(secret) => self.config.authentication_config.check_rate_limit(
|
||||
ctx,
|
||||
config,
|
||||
secret,
|
||||
&user_info.endpoint,
|
||||
true,
|
||||
@@ -87,9 +86,13 @@ impl PoolingBackend {
|
||||
}
|
||||
};
|
||||
let ep = EndpointIdInt::from(&user_info.endpoint);
|
||||
let auth_outcome =
|
||||
crate::auth::validate_password_and_exchange(&config.thread_pool, ep, password, secret)
|
||||
.await?;
|
||||
let auth_outcome = crate::auth::validate_password_and_exchange(
|
||||
&self.config.authentication_config.thread_pool,
|
||||
ep,
|
||||
password,
|
||||
secret,
|
||||
)
|
||||
.await?;
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => {
|
||||
info!("user successfully authenticated");
|
||||
@@ -109,13 +112,13 @@ impl PoolingBackend {
|
||||
pub(crate) async fn authenticate_with_jwt(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
config: &AuthenticationConfig,
|
||||
user_info: &ComputeUserInfo,
|
||||
jwt: String,
|
||||
) -> Result<(), AuthError> {
|
||||
match &self.config.auth_backend {
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
match &self.auth_backend {
|
||||
crate::auth::Backend::ControlPlane(console, ()) => {
|
||||
config
|
||||
self.config
|
||||
.authentication_config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
@@ -127,13 +130,18 @@ impl PoolingBackend {
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
Ok(())
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
})
|
||||
}
|
||||
crate::auth::Backend::ConsoleRedirect(_, ()) => Err(AuthError::auth_failed(
|
||||
"JWT login over web auth proxy is not supported",
|
||||
)),
|
||||
crate::auth::Backend::Local(_) => {
|
||||
config
|
||||
let keys = self
|
||||
.config
|
||||
.authentication_config
|
||||
.jwks_cache
|
||||
.check_jwt(
|
||||
ctx,
|
||||
@@ -145,8 +153,10 @@ impl PoolingBackend {
|
||||
.await
|
||||
.map_err(|e| AuthError::auth_failed(e.to_string()))?;
|
||||
|
||||
// todo: rewrite JWT signature with key shared somehow between local proxy and postgres
|
||||
Ok(())
|
||||
Ok(ComputeCredentials {
|
||||
info: user_info.clone(),
|
||||
keys,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -176,7 +186,7 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.config.auth_backend.as_ref().map(|()| keys);
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
@@ -208,14 +218,14 @@ impl PoolingBackend {
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|()| ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
});
|
||||
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!("{}-local-proxy", conn_info.user_info.endpoint)),
|
||||
options: conn_info.user_info.options.clone(),
|
||||
},
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
@@ -231,6 +241,77 @@ impl PoolingBackend {
)
.await
}

/// Connect to postgres over localhost.
///
/// We expect postgres to be started here, so we won't do any retries.
///
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
) -> Result<LocalClient<tokio_postgres::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
}

let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");

let mut node_info = match &self.auth_backend {
auth::Backend::ControlPlane(_, ()) | auth::Backend::ConsoleRedirect(_, ()) => {
unreachable!("only local_proxy can connect to local postgres")
}
auth::Backend::Local(local) => local.node_info.clone(),
};

let config = node_info
.config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
drop(pause);

tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));

let handle = local_conn_pool::poll_client(
self.local_pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
node_info.aux.clone(),
);

let kid = handle.get_client().get_process_id() as i64;
let jwk = p256::PublicKey::from(handle.key().verifying_key()).to_jwk();

debug!(kid, ?jwk, "setting up backend session state");

// initiates the auth session
handle
.get_client()
.query(
"select auth.init($1, $2);",
&[
&kid as &(dyn ToSql + Sync),
&tokio_postgres::types::Json(jwk),
],
)
.await?;

info!(?kid, "backend session state init");

Ok(handle)
}
}
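
// Illustrative sketch, not part of the diff: for the Local backend above, the
// node_info config boils down to a plain, retry-free tokio_postgres connection
// to the Postgres running next to local-proxy. Host, port, user and dbname are
// example values, not the ones local-proxy actually uses.
async fn connect_local_sketch() -> Result<tokio_postgres::Client, tokio_postgres::Error> {
    let (client, connection) = tokio_postgres::Config::new()
        .host("127.0.0.1") // assumption: compute-local postgres
        .port(5432)        // assumption: default port
        .user("db_user")
        .dbname("db_name")
        .connect(tokio_postgres::NoTls)
        .await?;

    // The connection future must be driven for the client to make progress;
    // in the real code this is done by local_conn_pool::poll_client above.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            tracing::warn!("local postgres connection error: {e}");
        }
    });

    Ok(client)
}
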

#[derive(Debug, thiserror::Error)]
@@ -241,6 +322,8 @@ pub(crate) enum HttpConnError {
PostgresConnectionError(#[from] tokio_postgres::Error),
#[error("could not connect to local-proxy in compute")]
LocalProxyConnectionError(#[from] LocalProxyConnError),
#[error("could not parse JWT payload")]
JwtPayloadError(serde_json::Error),

#[error("could not get auth info")]
GetAuthInfo(#[from] GetAuthInfoError),
@@ -266,6 +349,7 @@ impl ReportableError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,
HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
HttpConnError::AuthError(a) => a.get_error_kind(),
HttpConnError::WakeCompute(w) => w.get_error_kind(),
@@ -280,6 +364,7 @@ impl UserFacingError for HttpConnError {
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
HttpConnError::JwtPayloadError(p) => p.to_string(),
HttpConnError::GetAuthInfo(c) => c.to_string_client(),
HttpConnError::AuthError(c) => c.to_string_client(),
HttpConnError::WakeCompute(c) => c.to_string_client(),
@@ -296,6 +381,7 @@ impl CouldRetry for HttpConnError {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
@@ -422,8 +508,12 @@ impl ConnectMechanism for HyperMechanism {

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);

// let port = node_info.config.get_ports().first().unwrap_or_else(10432);
let res = connect_http2(&host, 10432, timeout).await;
let port = *node_info.config.get_ports().first().ok_or_else(|| {
HttpConnError::WakeCompute(WakeComputeError::BadComputeAddress(
"local-proxy port missing on compute address".into(),
))
})?;
let res = connect_http2(&host, port, timeout).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;

@@ -1,30 +1,40 @@
use itertools::Itertools;
use serde_json::value::RawValue;
use serde_json::Map;
use serde_json::Value;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::Row;
use typed_json::json;

use super::json_raw_value::LazyValue;

//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
pub(crate) fn json_to_pg_text(
json: &[&RawValue],
) -> Result<Vec<Option<String>>, serde_json::Error> {
json.iter().copied().map(json_value_to_pg_text).try_collect()
}

fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
fn json_value_to_pg_text(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
let lazy_value = serde_json::from_str(value.get())?;
match lazy_value {
// special care for nulls
Value::Null => None,
LazyValue::Null => Ok(None),

// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
LazyValue::Bool | LazyValue::Number | LazyValue::Object => {
Ok(Some(value.get().to_string()))
}

// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.to_string()),
LazyValue::String(s) => Ok(Some(s.into_owned())),

// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
}
}

@@ -36,27 +46,42 @@ fn json_value_to_pg_text(value: &Value) -> Option<String> {
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
fn json_array_to_pg_array(arr: Vec<&RawValue>) -> Result<String, serde_json::Error> {
let mut output = String::new();
let mut first = true;

output.push('{');

for value in arr {
if !first {
output.push(',');
}
first = false;

let value = json_array_to_pg_array_inner(value)?;
output.push_str(value.as_deref().unwrap_or("NULL"));
}

output.push('}');

Ok(output)
}

fn json_array_to_pg_array_inner(value: &RawValue) -> Result<Option<String>, serde_json::Error> {
let lazy_value = serde_json::from_str(value.get())?;
match lazy_value {
// special care for nulls
Value::Null => None,
LazyValue::Null => Ok(None),

// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
LazyValue::Bool | LazyValue::Number | LazyValue::String(_) => {
Ok(Some(value.get().to_string()))
}
LazyValue::Object => Ok(Some(json!(value.get().to_string()).to_string())),

// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");

Some(format!("{{{vals}}}"))
}
LazyValue::Array(arr) => Ok(Some(json_array_to_pg_array(arr)?)),
}
}

@@ -259,25 +284,31 @@ mod tests {
use super::*;
use serde_json::json;

fn json_to_pg_text_test(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = serde_json::Value::Array(json).to_string();
let json: Vec<&RawValue> = serde_json::from_str(&json).unwrap();
json_to_pg_text(&json).unwrap()
}

#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(
pg_params,
vec![Some("true".to_owned()), Some("false".to_owned())]
);

let json = vec![Value::Number(serde_json::Number::from(42))];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![Some("42".to_owned())]);

let json = vec![Value::String("foo\"".to_string())];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![Some("foo\"".to_owned())]);

let json = vec![Value::Null];
let pg_params = json_to_pg_text(json);
let pg_params = json_to_pg_text_test(json);
assert_eq!(pg_params, vec![None]);
}

@@ -286,7 +317,7 @@ mod tests {
// atoms and escaping
let json = "[true, false, null, \"NULL\", 42, \"foo\", \"bar\\\"-\\\\\"]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(
@@ -297,7 +328,7 @@ mod tests {
// nested arrays
let json = "[[true, false], [null, 42], [\"foo\", \"bar\\\"-\\\\\"]]";
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(
@@ -307,7 +338,7 @@ mod tests {
// array of objects
let json = r#"[{"foo": 1},{"bar": 2}]"#;
let json: Value = serde_json::from_str(json).unwrap();
let pg_params = json_to_pg_text(vec![json]);
let pg_params = json_to_pg_text_test(vec![json]);
assert_eq!(
pg_params,
vec![Some(r#"{"{\"foo\":1}","{\"bar\":2}"}"#.to_owned())]

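// A minimal sketch (not part of the diff) of the conversion contract that
// json_to_pg_text implements after this change: parameters arrive as borrowed
// &RawValue slices and are turned into the text form tokio-postgres can send.
// Scalars keep their JSON text, strings lose their quotes, nulls become None,
// and arrays become Postgres array literals with NULL elements and JSON-style
// escaping, as the tests above assert.
fn conversion_contract_sketch() {
    use serde_json::value::RawValue;

    let params: Vec<&RawValue> =
        serde_json::from_str(r#"[42, "foo", null, [1, null, "a"]]"#).unwrap();

    // Expected mapping (mirrors the unit tests above):
    //   42             -> Some("42")
    //   "foo"          -> Some("foo")             (sent as a plain parameter)
    //   null           -> None                    (SQL NULL)
    //   [1, null, "a"] -> Some(r#"{1,NULL,"a"}"#) (Postgres array literal)
    assert_eq!(params.len(), 4);
    assert_eq!(params[3].get(), r#"[1, null, "a"]"#);
}
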
193
proxy/src/serverless/json_raw_value.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! [`serde_json::Value`] but uses RawValue internally
|
||||
//!
|
||||
//! This code forks from the serde_json code, but replaces internal Value with RawValue where possible.
|
||||
//!
|
||||
//! Taken from <https://github.com/serde-rs/json/blob/faab2e8d2fcf781a3f77f329df836ffb3aaacfba/src/value/de.rs>
|
||||
//! Licensed from serde-rs under MIT or APACHE-2.0, with modifications by Conrad Ludgate
|
||||
|
||||
use core::fmt;
|
||||
use std::borrow::Cow;
|
||||
|
||||
use serde::{
|
||||
de::{IgnoredAny, MapAccess, SeqAccess, Visitor},
|
||||
Deserialize,
|
||||
};
|
||||
use serde_json::value::RawValue;
|
||||
|
||||
pub enum LazyValue<'de> {
|
||||
Null,
|
||||
Bool,
|
||||
Number,
|
||||
String(Cow<'de, str>),
|
||||
Array(Vec<&'de RawValue>),
|
||||
Object,
|
||||
}
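
// Quick sketch of the intent, not part of the diff: unlike serde_json::Value,
// deserializing into LazyValue never builds a nested tree. Array elements stay
// as unparsed &RawValue slices borrowed from the input, and objects are merely
// validated, which is all the json-to-pg-text conversion needs. Shown here with
// RawValue directly, since LazyValue itself is private to this crate.
fn lazy_vs_eager_sketch() {
    let input = r#"[{"a":1},[2,3],"s"]"#;

    // Eager: allocates a full Value tree for every element.
    let eager: serde_json::Value = serde_json::from_str(input).unwrap();
    assert!(eager.is_array());

    // Lazy: each element is just a borrowed slice of the original text.
    let lazy: Vec<&serde_json::value::RawValue> = serde_json::from_str(input).unwrap();
    assert_eq!(lazy[0].get(), r#"{"a":1}"#);
    assert_eq!(lazy[1].get(), r#"[2,3]"#);
}
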
|
||||
|
||||
impl<'de> Deserialize<'de> for LazyValue<'de> {
|
||||
#[inline]
|
||||
fn deserialize<D>(deserializer: D) -> Result<LazyValue<'de>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct ValueVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for ValueVisitor {
|
||||
type Value = LazyValue<'de>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str("any valid JSON value")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_bool<E>(self, _value: bool) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Bool)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_i64<E>(self, _value: i64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_u64<E>(self, _value: u64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_f64<E>(self, _value: f64) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Number)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_str<E>(self, value: &str) -> Result<LazyValue<'de>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
self.visit_string(String::from(value))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_borrowed_str<E>(self, value: &'de str) -> Result<LazyValue<'de>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(LazyValue::String(Cow::Borrowed(value)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_string<E>(self, value: String) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::String(Cow::Owned(value)))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_none<E>(self) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Null)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_some<D>(self, deserializer: D) -> Result<LazyValue<'de>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
Deserialize::deserialize(deserializer)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_unit<E>(self) -> Result<LazyValue<'de>, E> {
|
||||
Ok(LazyValue::Null)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_seq<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
|
||||
where
|
||||
V: SeqAccess<'de>,
|
||||
{
|
||||
let mut vec = Vec::new();
|
||||
|
||||
while let Some(elem) = visitor.next_element()? {
|
||||
vec.push(elem);
|
||||
}
|
||||
|
||||
Ok(LazyValue::Array(vec))
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut visitor: V) -> Result<LazyValue<'de>, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
while visitor.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
|
||||
Ok(LazyValue::Object)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_any(ValueVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::borrow::Cow;
|
||||
|
||||
use typed_json::json;
|
||||
|
||||
use super::LazyValue;
|
||||
|
||||
#[test]
|
||||
fn object() {
|
||||
let json = json! {{
|
||||
"foo": {
|
||||
"bar": 1
|
||||
},
|
||||
"baz": [2, 3],
|
||||
}}
|
||||
.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::Object = lazy else {
|
||||
panic!("expected object")
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array() {
|
||||
let json = json! {[
|
||||
{
|
||||
"bar": 1
|
||||
},
|
||||
[2, 3],
|
||||
]}
|
||||
.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::Array(array) = lazy else {
|
||||
panic!("expected array")
|
||||
};
|
||||
assert_eq!(array.len(), 2);
|
||||
|
||||
assert_eq!(array[0].get(), r#"{"bar":1}"#);
|
||||
assert_eq!(array[1].get(), r#"[2,3]"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn string() {
|
||||
let json = json! { "hello world" }.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::String(Cow::Borrowed(string)) = lazy else {
|
||||
panic!("expected borrowed string")
|
||||
};
|
||||
assert_eq!(string, "hello world");
|
||||
|
||||
let json = json! { "hello \n world" }.to_string();
|
||||
|
||||
let lazy: LazyValue = serde_json::from_str(&json).unwrap();
|
||||
|
||||
let LazyValue::String(Cow::Owned(string)) = lazy else {
|
||||
panic!("expected owned string")
|
||||
};
|
||||
assert_eq!(string, "hello \n world");
|
||||
}
|
||||
}
|
||||
544
proxy/src/serverless/local_conn_pool.rs
Normal file
@@ -0,0 +1,544 @@
|
||||
use futures::{future::poll_fn, Future};
|
||||
use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use p256::ecdsa::{Signature, SigningKey};
|
||||
use parking_lot::RwLock;
|
||||
use rand::rngs::OsRng;
|
||||
use serde_json::Value;
|
||||
use signature::Signer;
|
||||
use std::task::{ready, Poll};
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::types::ToSql;
|
||||
use tokio_postgres::{AsyncMessage, ReadyForQueryStatus, Socket};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use typed_json::json;
|
||||
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
|
||||
use crate::{context::RequestMonitoring, DbName, RoleName};
|
||||
|
||||
use tracing::{debug, error, warn, Span};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
use super::conn_pool::{ClientInnerExt, ConnInfo};
|
||||
|
||||
struct ConnPoolEntry<C: ClientInnerExt> {
|
||||
conn: ClientInner<C>,
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
// /// key id for the pg_session_jwt state
|
||||
// static PG_SESSION_JWT_KID: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
|
||||
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
|
||||
total_conns: usize,
|
||||
max_conns: usize,
|
||||
global_pool_size_max_conns: usize,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
fn get_conn_entry(&mut self, db_user: (DbName, RoleName)) -> Option<ConnPoolEntry<C>> {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
pools
|
||||
.get_mut(&db_user)
|
||||
.and_then(|pool_entries| pool_entries.get_conn_entry(total_conns))
|
||||
}
|
||||
|
||||
fn remove_client(&mut self, db_user: (DbName, RoleName), conn_id: uuid::Uuid) -> bool {
|
||||
let Self {
|
||||
pools, total_conns, ..
|
||||
} = self;
|
||||
if let Some(pool) = pools.get_mut(&db_user) {
|
||||
let old_len = pool.conns.len();
|
||||
pool.conns.retain(|conn| conn.conn.conn_id != conn_id);
|
||||
let new_len = pool.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
}
|
||||
*total_conns -= removed;
|
||||
removed > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInner<C>) {
|
||||
let conn_id = client.conn_id;
|
||||
|
||||
if client.is_closed() {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because connection is closed");
|
||||
return;
|
||||
}
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool.read().total_conns >= global_max_conn {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full");
|
||||
return;
|
||||
}
|
||||
|
||||
// return connection to the pool
|
||||
let mut returned = false;
|
||||
let mut per_db_size = 0;
|
||||
let total_conns = {
|
||||
let mut pool = pool.write();
|
||||
|
||||
if pool.total_conns < pool.max_conns {
|
||||
let pool_entries = pool.pools.entry(conn_info.db_and_user()).or_default();
|
||||
pool_entries.conns.push(ConnPoolEntry {
|
||||
conn: client,
|
||||
_last_access: std::time::Instant::now(),
|
||||
});
|
||||
|
||||
returned = true;
|
||||
per_db_size = pool_entries.conns.len();
|
||||
|
||||
pool.total_conns += 1;
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.inc();
|
||||
}
|
||||
|
||||
pool.total_conns
|
||||
};
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
info!(%conn_id, "local_pool: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "local_pool: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
|
||||
fn drop(&mut self) {
|
||||
if self.total_conns > 0 {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(self.total_conns as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
|
||||
conns: Vec<ConnPoolEntry<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Default for DbUserConnPool<C> {
|
||||
fn default() -> Self {
|
||||
Self { conns: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> DbUserConnPool<C> {
|
||||
fn clear_closed_clients(&mut self, conns: &mut usize) -> usize {
|
||||
let old_len = self.conns.len();
|
||||
|
||||
self.conns.retain(|conn| !conn.conn.is_closed());
|
||||
|
||||
let new_len = self.conns.len();
|
||||
let removed = old_len - new_len;
|
||||
*conns -= removed;
|
||||
removed
|
||||
}
|
||||
|
||||
fn get_conn_entry(&mut self, conns: &mut usize) -> Option<ConnPoolEntry<C>> {
|
||||
let mut removed = self.clear_closed_clients(conns);
|
||||
let conn = self.conns.pop();
|
||||
if conn.is_some() {
|
||||
*conns -= 1;
|
||||
removed += 1;
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.http_pool_opened_connections
|
||||
.get_metric()
|
||||
.dec_by(removed as i64);
|
||||
conn
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalConnPool<C: ClientInnerExt> {
|
||||
global_pool: RwLock<EndpointConnPool<C>>,
|
||||
|
||||
config: &'static crate::config::HttpConfig,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalConnPool<C> {
|
||||
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
global_pool: RwLock::new(EndpointConnPool {
|
||||
pools: HashMap::new(),
|
||||
total_conns: 0,
|
||||
max_conns: config.pool_options.max_conns_per_endpoint,
|
||||
global_pool_size_max_conns: config.pool_options.max_total_conns,
|
||||
}),
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn get_idle_timeout(&self) -> Duration {
|
||||
self.config.pool_options.idle_timeout
|
||||
}
|
||||
|
||||
// pub(crate) fn shutdown(&self) {
|
||||
// let mut pool = self.global_pool.write();
|
||||
// pool.pools.clear();
|
||||
// pool.total_conns = 0;
|
||||
// }
|
||||
|
||||
pub(crate) fn get(
|
||||
self: &Arc<Self>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<Option<LocalClient<C>>, HttpConnError> {
|
||||
let mut client: Option<ClientInner<C>> = None;
|
||||
if let Some(entry) = self
|
||||
.global_pool
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
|
||||
// ok return cached connection if found and establish a new one otherwise
|
||||
if let Some(client) = client {
|
||||
if client.is_closed() {
|
||||
info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
}
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"local_pool: reusing connection '{conn_info}'"
|
||||
);
|
||||
client.session.send(ctx.session_id())?;
|
||||
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
|
||||
ctx.success();
|
||||
return Ok(Some(LocalClient::new(
|
||||
client,
|
||||
conn_info.clone(),
|
||||
Arc::downgrade(self),
|
||||
)));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn poll_client(
|
||||
global_pool: Arc<LocalConnPool<tokio_postgres::Client>>,
|
||||
ctx: &RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
mut connection: tokio_postgres::Connection<Socket, NoTlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> LocalClient<tokio_postgres::Client> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
let pool = Arc::downgrade(&global_pool);
|
||||
let pool_clone = pool.clone();
|
||||
|
||||
let db_user = conn_info.db_and_user();
|
||||
let idle = global_pool.get_idle_timeout();
|
||||
let cancel = CancellationToken::new();
|
||||
let cancelled = cancel.clone().cancelled_owned();
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let _conn_gauge = conn_gauge;
|
||||
let mut idle_timeout = pin!(tokio::time::sleep(idle));
|
||||
let mut cancelled = pin!(cancelled);
|
||||
|
||||
poll_fn(move |cx| {
|
||||
if cancelled.as_mut().poll(cx).is_ready() {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
|
||||
match rx.has_changed() {
|
||||
Ok(true) => {
|
||||
session_id = *rx.borrow_and_update();
|
||||
info!(%session_id, "changed session");
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
}
|
||||
Err(_) => {
|
||||
info!("connection dropped");
|
||||
return Poll::Ready(())
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// 5 minute idle connection timeout
|
||||
if idle_timeout.as_mut().poll(cx).is_ready() {
|
||||
idle_timeout.as_mut().reset(Instant::now() + idle);
|
||||
info!("connection idle");
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
// remove client from pool - should close the connection if it's idle.
|
||||
// does nothing if the client is currently checked-out and in-use
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("idle connection removed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let message = ready!(connection.poll_message(cx));
|
||||
|
||||
match message {
|
||||
Some(Ok(AsyncMessage::Notice(notice))) => {
|
||||
info!(%session_id, "notice: {}", notice);
|
||||
}
|
||||
Some(Ok(AsyncMessage::Notification(notif))) => {
|
||||
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
|
||||
}
|
||||
Some(Ok(_)) => {
|
||||
warn!(%session_id, "unknown message");
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!(%session_id, "connection error: {}", e);
|
||||
break
|
||||
}
|
||||
None => {
|
||||
info!("connection closed");
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(())
|
||||
}).await;
|
||||
|
||||
}
|
||||
.instrument(span));
|
||||
|
||||
let key = SigningKey::random(&mut OsRng);
|
||||
|
||||
let inner = ClientInner {
|
||||
inner: client,
|
||||
session: tx,
|
||||
cancel,
|
||||
aux,
|
||||
conn_id,
|
||||
key,
|
||||
jti: 0,
|
||||
};
|
||||
LocalClient::new(inner, conn_info, pool_clone)
|
||||
}
|
||||
|
||||
struct ClientInner<C: ClientInnerExt> {
|
||||
inner: C,
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
cancel: CancellationToken,
|
||||
aux: MetricsAuxInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
|
||||
// needed for pg_session_jwt state
|
||||
key: SigningKey,
|
||||
jti: u64,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for ClientInner<C> {
|
||||
fn drop(&mut self) {
|
||||
// on client drop, tell the conn to shut down
|
||||
self.cancel.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInner<C> {
|
||||
pub(crate) fn is_closed(&self) -> bool {
|
||||
self.inner.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct LocalClient<C: ClientInnerExt> {
|
||||
span: Span,
|
||||
inner: Option<ClientInner<C>>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
pub(crate) struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<LocalConnPool<C>>,
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub(self) fn new(
|
||||
inner: ClientInner<C>,
|
||||
conn_info: ConnInfo,
|
||||
pool: Weak<LocalConnPool<C>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Some(inner),
|
||||
span: Span::current(),
|
||||
conn_info,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
|
||||
let Self {
|
||||
inner,
|
||||
pool,
|
||||
conn_info,
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(&mut inner.inner, Discard { conn_info, pool })
|
||||
}
|
||||
pub(crate) fn key(&self) -> &SigningKey {
|
||||
let inner = &self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed");
|
||||
&inner.key
|
||||
}
|
||||
}
|
||||
|
||||
impl LocalClient<tokio_postgres::Client> {
|
||||
pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), HttpConnError> {
|
||||
let inner = self
|
||||
.inner
|
||||
.as_mut()
|
||||
.expect("client inner should not be removed");
|
||||
inner.jti += 1;
|
||||
|
||||
let kid = inner.inner.get_process_id();
|
||||
let header = json!({"kid":kid}).to_string();
|
||||
|
||||
let mut payload = serde_json::from_slice::<serde_json::Map<String, Value>>(payload)
|
||||
.map_err(HttpConnError::JwtPayloadError)?;
|
||||
payload.insert("jti".to_string(), Value::Number(inner.jti.into()));
|
||||
let payload = Value::Object(payload).to_string();
|
||||
|
||||
debug!(
|
||||
kid,
|
||||
jti = inner.jti,
|
||||
?header,
|
||||
?payload,
|
||||
"signing new ephemeral JWT"
|
||||
);
|
||||
|
||||
let token = sign_jwt(&inner.key, header, payload);
|
||||
|
||||
// initiates the auth session
|
||||
inner.inner.simple_query("discard all").await?;
|
||||
inner
|
||||
.inner
|
||||
.query(
|
||||
"select auth.jwt_session_init($1)",
|
||||
&[&token as &(dyn ToSql + Sync)],
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(kid, jti = inner.jti, "user session state init");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn sign_jwt(sk: &SigningKey, header: String, payload: String) -> String {
|
||||
let header = Base64UrlUnpadded::encode_string(header.as_bytes());
|
||||
let payload = Base64UrlUnpadded::encode_string(payload.as_bytes());
|
||||
|
||||
let message = format!("{header}.{payload}");
|
||||
let sig: Signature = sk.sign(message.as_bytes());
|
||||
let base64_sig = Base64UrlUnpadded::encode_string(&sig.to_bytes());
|
||||
format!("{message}.{base64_sig}")
|
||||
}
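
// Illustrative sketch, not part of the diff: sign_jwt above produces a compact
// JWS, i.e. base64url(header) "." base64url(payload) "." base64url(signature),
// where the header is the {"kid": <backend pid>} object built in
// set_jwt_session and the signature is a P-256 ECDSA signature over the first
// two segments. The helper below only checks that shape; names are examples.
fn assert_compact_jwt_shape(token: &str) {
    let parts: Vec<&str> = token.split('.').collect();
    assert_eq!(parts.len(), 3, "header.payload.signature");

    // base64url without padding never contains '=', '+' or '/'.
    for part in &parts {
        assert!(!part.is_empty());
        assert!(part
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }
}
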
|
||||
|
||||
impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!(
|
||||
"local_pool: throwing away connection '{conn_info}' because connection is not idle"
|
||||
);
|
||||
}
|
||||
}
|
||||
pub(crate) fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("local_pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> LocalClient<C> {
|
||||
pub fn get_client(&self) -> &C {
|
||||
&self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.inner
|
||||
}
|
||||
|
||||
fn do_drop(&mut self) -> Option<impl FnOnce()> {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("client inner should not be removed");
|
||||
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
|
||||
let current_span = self.span.clone();
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool.global_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for LocalClient<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
tokio::task::spawn_blocking(drop);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ mod conn_pool;
|
||||
mod http_conn_pool;
|
||||
mod http_util;
|
||||
mod json;
|
||||
mod json_raw_value;
|
||||
mod local_conn_pool;
|
||||
mod sql_over_http;
|
||||
mod websocket;
|
||||
|
||||
@@ -47,13 +49,14 @@ use std::pin::{pin, Pin};
|
||||
use std::sync::Arc;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn, Instrument};
|
||||
use tracing::{info, warn, Instrument};
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, (), ()>,
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
cancellation_handler: Arc<CancellationHandlerMain>,
|
||||
@@ -63,6 +66,7 @@ pub async fn task_main(
|
||||
info!("websocket server has shut down");
|
||||
}
|
||||
|
||||
let local_pool = local_conn_pool::LocalConnPool::new(&config.http_config);
|
||||
let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
|
||||
{
|
||||
let conn_pool = Arc::clone(&conn_pool);
|
||||
@@ -105,8 +109,10 @@ pub async fn task_main(
|
||||
|
||||
let backend = Arc::new(PoolingBackend {
|
||||
http_conn_pool: Arc::clone(&http_conn_pool),
|
||||
local_pool,
|
||||
pool: Arc::clone(&conn_pool),
|
||||
config,
|
||||
auth_backend,
|
||||
endpoint_rate_limiter: Arc::clone(&endpoint_rate_limiter),
|
||||
});
|
||||
let tls_acceptor: Arc<dyn MaybeTlsAcceptor> = match config.tls_config.as_ref() {
|
||||
@@ -238,7 +244,7 @@ async fn connection_startup(
|
||||
let (conn, peer) = match read_proxy_protocol(conn).await {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::error!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
tracing::warn!(?session_id, %peer_addr, "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
@@ -394,6 +400,7 @@ async fn request_handler(
|
||||
async move {
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
backend.auth_backend,
|
||||
ctx,
|
||||
websocket,
|
||||
cancellation_handler,
|
||||
@@ -402,7 +409,7 @@ async fn request_handler(
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("error in websocket connection: {e:#}");
|
||||
warn!("error in websocket connection: {e:#}");
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::borrow::Cow;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -22,7 +23,7 @@ use hyper::StatusCode;
|
||||
use hyper::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use tokio::time;
|
||||
use tokio_postgres::error::DbError;
|
||||
use tokio_postgres::error::ErrorPosition;
|
||||
@@ -40,11 +41,12 @@ use url::Url;
|
||||
use urlencoding;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
use crate::auth::backend::ComputeCredentials;
|
||||
use crate::auth::backend::ComputeCredentialKeys;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::endpoint_sni;
|
||||
use crate::auth::ComputeUserInfoParseError;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::config::HttpConfig;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::TlsConfig;
|
||||
use crate::context::RequestMonitoring;
|
||||
@@ -56,41 +58,47 @@ use crate::metrics::Metrics;
|
||||
use crate::proxy::run_until_cancelled;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::usage_metrics::MetricCounter;
|
||||
use crate::usage_metrics::MetricCounterRecorder;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::LocalProxyConnError;
|
||||
use super::backend::PoolingBackend;
|
||||
use super::conn_pool;
|
||||
use super::conn_pool::AuthData;
|
||||
use super::conn_pool::Client;
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::ConnInfoWithAuth;
|
||||
use super::http_util::json_response;
|
||||
use super::json::json_to_pg_text;
|
||||
use super::json::pg_text_row_to_json;
|
||||
use super::json::JsonConversionError;
|
||||
use super::local_conn_pool;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
params: Vec<Option<String>>,
|
||||
#[serde(bound = "'de: 'a")]
|
||||
struct QueryData<'a> {
|
||||
#[serde(borrow)]
|
||||
query: Cow<'a, str>,
|
||||
|
||||
#[serde(borrow)]
|
||||
params: Vec<&'a RawValue>,
|
||||
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct BatchQueryData {
|
||||
queries: Vec<QueryData>,
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(bound = "'de: 'a")]
|
||||
struct BatchQueryData<'a> {
|
||||
queries: Vec<QueryData<'a>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Payload {
|
||||
Single(QueryData),
|
||||
Batch(BatchQueryData),
|
||||
enum Payload<'a> {
|
||||
Batch(BatchQueryData<'a>),
|
||||
Single(QueryData<'a>),
|
||||
}
|
||||
|
||||
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
|
||||
@@ -103,13 +111,18 @@ static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrab
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
fn parse_pg_params(params: &[&RawValue]) -> Result<Vec<Option<String>>, ReadPayloadError> {
|
||||
json_to_pg_text(params).map_err(ReadPayloadError::Parse)
|
||||
}
|
||||
|
||||
fn parse_payload(body: &[u8]) -> Result<Payload<'_>, ReadPayloadError> {
|
||||
// RawValue doesn't work via untagged enums
|
||||
// so instead we try parse each individually
|
||||
if let Ok(batch) = serde_json::from_slice(body) {
|
||||
Ok(Payload::Batch(batch))
|
||||
} else {
|
||||
Ok(Payload::Single(serde_json::from_slice(body)?))
|
||||
}
|
||||
}
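
// Sketch of why parse_payload tries the two shapes separately, not part of the
// diff: #[serde(untagged)] buffers the input into serde's private Content type
// before picking a variant, and &RawValue cannot be borrowed back out of that
// buffer, so an untagged Payload<'_> fails at runtime. Parsing the batch shape
// first and falling back to the single shape keeps the zero-copy &RawValue
// params. The struct below is a trimmed stand-in for QueryData.
#[derive(serde::Deserialize)]
#[serde(bound = "'de: 'a")]
struct SingleSketch<'a> {
    #[serde(borrow)]
    query: std::borrow::Cow<'a, str>,
    #[serde(borrow)]
    params: Vec<&'a serde_json::value::RawValue>,
}

fn parse_single_sketch(body: &[u8]) -> Result<SingleSketch<'_>, serde_json::Error> {
    // from_slice borrows straight from `body`, so the params stay zero-copy.
    serde_json::from_slice(body)
}
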
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -552,7 +565,7 @@ async fn handle_inner(
|
||||
|
||||
match conn_info.auth {
|
||||
AuthData::Jwt(jwt) if config.authentication_config.is_auth_broker => {
|
||||
handle_auth_broker_inner(config, ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
handle_auth_broker_inner(ctx, request, conn_info.conn_info, jwt, backend).await
|
||||
}
|
||||
auth => {
|
||||
handle_db_inner(
|
||||
@@ -612,45 +625,42 @@ async fn handle_db_inner(
|
||||
async {
|
||||
let body = request.into_body().collect().await?.to_bytes();
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
Ok::<Bytes, ReadPayloadError>(body)
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let authenticate_and_connect = Box::pin(
|
||||
async {
|
||||
let is_local_proxy = matches!(backend.auth_backend, crate::auth::Backend::Local(_));
|
||||
|
||||
let keys = match auth {
|
||||
AuthData::Password(pw) => {
|
||||
backend
|
||||
.authenticate_with_password(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
&pw,
|
||||
)
|
||||
.authenticate_with_password(ctx, &conn_info.user_info, &pw)
|
||||
.await?
|
||||
}
|
||||
AuthData::Jwt(jwt) => {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
||||
jwt,
|
||||
)
|
||||
.await?;
|
||||
|
||||
ComputeCredentials {
|
||||
info: conn_info.user_info.clone(),
|
||||
keys: crate::auth::backend::ComputeCredentialKeys::None,
|
||||
}
|
||||
.authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
let client = match keys.keys {
|
||||
ComputeCredentialKeys::JwtPayload(payload) if is_local_proxy => {
|
||||
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
|
||||
client.set_jwt_session(&payload).await?;
|
||||
Client::Local(client)
|
||||
}
|
||||
_ => {
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
Client::Remote(client)
|
||||
}
|
||||
};
|
||||
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.success();
|
||||
@@ -659,7 +669,7 @@ async fn handle_db_inner(
|
||||
.map_err(SqlOverHttpError::from),
|
||||
);
|
||||
|
||||
let (payload, mut client) = match run_until_cancelled(
|
||||
let (body, mut client) = match run_until_cancelled(
|
||||
// Run both operations in parallel
|
||||
try_join(
|
||||
pin!(fetch_and_process_request),
|
||||
@@ -673,6 +683,8 @@ async fn handle_db_inner(
|
||||
None => return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Connect)),
|
||||
};
|
||||
|
||||
let payload = parse_payload(&body)?;
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
@@ -680,7 +692,7 @@ async fn handle_db_inner(
|
||||
// Now execute the query and return the result.
|
||||
let json_output = match payload {
|
||||
Payload::Single(stmt) => {
|
||||
stmt.process(config, cancel, &mut client, parsed_headers)
|
||||
stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
|
||||
.await?
|
||||
}
|
||||
Payload::Batch(statements) => {
|
||||
@@ -698,7 +710,7 @@ async fn handle_db_inner(
|
||||
}
|
||||
|
||||
statements
|
||||
.process(config, cancel, &mut client, parsed_headers)
|
||||
.process(&config.http_config, cancel, &mut client, parsed_headers)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
@@ -738,7 +750,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
|
||||
];
|
||||
|
||||
async fn handle_auth_broker_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestMonitoring,
|
||||
request: Request<Incoming>,
|
||||
conn_info: ConnInfo,
|
||||
@@ -746,12 +757,7 @@ async fn handle_auth_broker_inner(
|
||||
backend: Arc<PoolingBackend>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, SqlOverHttpError> {
|
||||
backend
|
||||
.authenticate_with_jwt(
|
||||
ctx,
|
||||
&config.authentication_config,
|
||||
&conn_info.user_info,
|
            jwt,
        )
        .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
        .await
        .map_err(HttpConnError::from)?;

@@ -786,12 +792,12 @@ async fn handle_auth_broker_inner(
        .map(|b| b.boxed()))
}

impl QueryData {
impl QueryData<'_> {
    async fn process(
        self,
        config: &'static ProxyConfig,
        config: &'static HttpConfig,
        cancel: CancellationToken,
        client: &mut Client<tokio_postgres::Client>,
        client: &mut Client,
        parsed_headers: HttpHeaders,
    ) -> Result<String, SqlOverHttpError> {
        let (inner, mut discard) = client.inner();
@@ -820,7 +826,7 @@ impl QueryData {
            Either::Right((_cancelled, query)) => {
                tracing::info!("cancelling query");
                if let Err(err) = cancel_token.cancel_query(NoTls).await {
                    tracing::error!(?err, "could not cancel query");
                    tracing::warn!(?err, "could not cancel query");
                }
                // wait for the query cancellation
                match time::timeout(time::Duration::from_millis(100), query).await {
@@ -860,12 +866,12 @@ impl QueryData {
    }
}

impl BatchQueryData {
impl BatchQueryData<'_> {
    async fn process(
        self,
        config: &'static ProxyConfig,
        config: &'static HttpConfig,
        cancel: CancellationToken,
        client: &mut Client<tokio_postgres::Client>,
        client: &mut Client,
        parsed_headers: HttpHeaders,
    ) -> Result<String, SqlOverHttpError> {
        info!("starting transaction");
@@ -909,7 +915,7 @@ impl BatchQueryData {
            }
            Err(SqlOverHttpError::Cancelled(_)) => {
                if let Err(err) = cancel_token.cancel_query(NoTls).await {
                    tracing::error!(?err, "could not cancel query");
                    tracing::warn!(?err, "could not cancel query");
                }
                // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
                discard.discard();
@@ -933,10 +939,10 @@ impl BatchQueryData {
}

async fn query_batch(
    config: &'static ProxyConfig,
    config: &'static HttpConfig,
    cancel: CancellationToken,
    transaction: &Transaction<'_>,
    queries: BatchQueryData,
    queries: BatchQueryData<'_>,
    parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
    let mut results = Vec::with_capacity(queries.queries.len());
@@ -972,14 +978,14 @@ async fn query_batch(
}

async fn query_to_json<T: GenericClient>(
    config: &'static ProxyConfig,
    config: &'static HttpConfig,
    client: &T,
    data: QueryData,
    data: QueryData<'_>,
    current_size: &mut usize,
    parsed_headers: HttpHeaders,
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
    info!("executing query");
    let query_params = data.params;
    let query_params = parse_pg_params(&data.params)?;
    let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
    info!("finished executing query");

@@ -993,9 +999,9 @@ async fn query_to_json<T: GenericClient>(
        rows.push(row);
        // we don't have a streaming response support yet so this is to prevent OOM
        // from a malicious query (eg a cross join)
        if *current_size > config.http_config.max_response_size_bytes {
        if *current_size > config.max_response_size_bytes {
            return Err(SqlOverHttpError::ResponseTooLarge(
                config.http_config.max_response_size_bytes,
                config.max_response_size_bytes,
            ));
        }
    }
@@ -1058,3 +1064,88 @@ async fn query_to_json<T: GenericClient>(

    Ok((ready, results))
}

enum Client {
    Remote(conn_pool::Client<tokio_postgres::Client>),
    Local(local_conn_pool::LocalClient<tokio_postgres::Client>),
}

enum Discard<'a> {
    Remote(conn_pool::Discard<'a, tokio_postgres::Client>),
    Local(local_conn_pool::Discard<'a, tokio_postgres::Client>),
}

impl Client {
    fn metrics(&self) -> Arc<MetricCounter> {
        match self {
            Client::Remote(client) => client.metrics(),
            Client::Local(local_client) => local_client.metrics(),
        }
    }

    fn inner(&mut self) -> (&mut tokio_postgres::Client, Discard<'_>) {
        match self {
            Client::Remote(client) => {
                let (c, d) = client.inner();
                (c, Discard::Remote(d))
            }
            Client::Local(local_client) => {
                let (c, d) = local_client.inner();
                (c, Discard::Local(d))
            }
        }
    }
}

impl Discard<'_> {
    fn check_idle(&mut self, status: ReadyForQueryStatus) {
        match self {
            Discard::Remote(discard) => discard.check_idle(status),
            Discard::Local(discard) => discard.check_idle(status),
        }
    }
    fn discard(&mut self) {
        match self {
            Discard::Remote(discard) => discard.discard(),
            Discard::Local(discard) => discard.discard(),
        }
    }
}

#[cfg(test)]
mod tests {
    use typed_json::json;

    use super::parse_payload;
    use super::Payload;

    #[test]
    fn raw_single_payload() {
        let body = json! {
            {"query":"select $1","params":["1"]}
        }
        .to_string();

        let Payload::Single(query) = parse_payload(body.as_bytes()).unwrap() else {
            panic!("expected single")
        };
        assert_eq!(&*query.query, "select $1");
        assert_eq!(query.params[0].get(), "\"1\"");
    }

    #[test]
    fn raw_batch_payload() {
        let body = json! {{
            "queries": [
                {"query":"select $1","params":["1"]},
                {"query":"select $1","params":["2"]},
            ]
        }}
        .to_string();

        let Payload::Batch(query) = parse_payload(body.as_bytes()).unwrap() else {
            panic!("expected batch")
        };
        assert_eq!(query.queries.len(), 2);
    }
}
@@ -129,6 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {

pub(crate) async fn serve_websocket(
    config: &'static ProxyConfig,
    auth_backend: &'static crate::auth::Backend<'static, (), ()>,
    ctx: RequestMonitoring,
    websocket: OnUpgrade,
    cancellation_handler: Arc<CancellationHandlerMain>,
@@ -145,6 +146,7 @@ pub(crate) async fn serve_websocket(

    let res = Box::pin(handle_client(
        config,
        auth_backend,
        &ctx,
        cancellation_handler,
        WebSocketRw::new(websocket),

@@ -27,7 +27,7 @@ use std::{
};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace};
use tracing::{error, info, instrument, trace, warn};
use utils::backoff;
use uuid::{NoContext, Timestamp};

@@ -346,7 +346,7 @@ async fn collect_metrics_iteration(
            error!("metrics endpoint refused the sent metrics: {:?}", res);
            for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) {
                // Report if the metric value is suspiciously large
                error!("potentially abnormal metric value: {:?}", metric);
                warn!("potentially abnormal metric value: {:?}", metric);
            }
        }
    }

@@ -15,15 +15,20 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
            }
            Ok(())
        }
        (Scope::Admin | Scope::PageServerApi | Scope::GenerationsApi | Scope::Scrubber, _) => {
            Err(AuthError(
                format!(
                    "JWT scope '{:?}' is ineligible for Safekeeper auth",
                    claims.scope
                )
                .into(),
            ))
        }
        (
            Scope::Admin
            | Scope::PageServerApi
            | Scope::GenerationsApi
            | Scope::Infra
            | Scope::Scrubber,
            _,
        ) => Err(AuthError(
            format!(
                "JWT scope '{:?}' is ineligible for Safekeeper auth",
                claims.scope
            )
            .into(),
        )),
        (Scope::SafekeeperData, _) => Ok(()),
    }
}

@@ -636,7 +636,7 @@ async fn handle_tenant_list(
}

async fn handle_node_register(req: Request<Body>) -> Result<Response<Body>, ApiError> {
    check_permissions(&req, Scope::Admin)?;
    check_permissions(&req, Scope::Infra)?;

    let mut req = match maybe_forward(req).await {
        ForwardOutcome::Forwarded(res) => {
@@ -1182,7 +1182,7 @@ async fn handle_get_safekeeper(req: Request<Body>) -> Result<Response<Body>, Api
/// Assumes information is only relayed to storage controller after first selecting an unique id on
/// control plane database, which means we have an id field in the request and payload.
async fn handle_upsert_safekeeper(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
    check_permissions(&req, Scope::Admin)?;
    check_permissions(&req, Scope::Infra)?;

    let body = json_request::<SafekeeperPersistence>(&mut req).await?;
    let id = parse_request_param::<i64>(&req, "id")?;

@@ -317,9 +317,8 @@ pub async fn scan_pageserver_metadata(
        tenant_timeline_results.push((ttid, data));
    }

    let tenant_id = tenant_id.expect("Must be set if results are present");

    if !tenant_timeline_results.is_empty() {
        let tenant_id = tenant_id.expect("Must be set if results are present");
        analyze_tenant(
            &remote_client,
            tenant_id,

@@ -64,10 +64,12 @@ By default performance tests are excluded. To run them explicitly pass performan
Useful environment variables:

`NEON_BIN`: The directory where neon binaries can be found.
`COMPATIBILITY_NEON_BIN`: The directory where the previous version of the Neon binaries can be found.
`POSTGRES_DISTRIB_DIR`: The directory where the postgres distribution can be found.
Since pageserver supports several postgres versions, `POSTGRES_DISTRIB_DIR` must contain
a subdirectory for each version with naming convention `v{PG_VERSION}/`.
Inside that dir, a `bin/postgres` binary should be present.
`COMPATIBILITY_POSTGRES_DISTRIB_DIR`: The directory where the previous version of the postgres distribution can be found.
`DEFAULT_PG_VERSION`: The version of Postgres to use.
This is used to construct the full path to the postgres binaries.
The format is the 2-digit major version number, i.e. `DEFAULT_PG_VERSION=16`
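
For illustration, the layout described above can be resolved into a concrete binary path roughly like this (a minimal sketch, not the actual fixture code; the helper name is hypothetical):

```python
# Minimal sketch (hypothetical helper, not the repo's fixture code):
# resolve the postgres binary from the env vars documented above.
import os
from pathlib import Path


def postgres_bin_path() -> Path:
    # POSTGRES_DISTRIB_DIR holds one subdirectory per major version, e.g. v16/
    distrib_dir = Path(os.environ["POSTGRES_DISTRIB_DIR"])
    # DEFAULT_PG_VERSION is the 2-digit major version, e.g. "16"
    pg_version = os.environ.get("DEFAULT_PG_VERSION", "16")
    return distrib_dir / f"v{pg_version}" / "bin" / "postgres"
```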
@@ -294,6 +296,16 @@ def test_foobar2(neon_env_builder: NeonEnvBuilder):
    client.timeline_detail(tenant_id=tenant_id, timeline_id=timeline_id)
```

All the tests that rely on NeonEnvBuilder can check the various version combinations of the components.
To do this you may want to add the parametrize decorator with the function `fixtures.utils.allpairs_versions()`.
E.g.

```python
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_something(
    ...
```
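
A fuller, hypothetical sketch of such a test (names other than `allpairs_versions()` follow the surrounding fixtures but are assumptions, not verified API):

```python
# Hypothetical sketch: exact signatures may differ in the real test suite.
import pytest

import fixtures.utils
from fixtures.neon_fixtures import NeonEnvBuilder


@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_cross_version_smoke(neon_env_builder: NeonEnvBuilder, combination):
    # `combination` picks old or new binaries per component; NeonEnvBuilder
    # reads it and mixes the versions before the environment starts.
    env = neon_env_builder.init_start()
    assert env is not None
```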

For more information about pytest fixtures, see https://docs.pytest.org/en/stable/fixture.html

At the end of a test, all the nodes in the environment are automatically stopped, so you
||||
@@ -6,6 +6,7 @@ pytest_plugins = (
|
||||
"fixtures.httpserver",
|
||||
"fixtures.compute_reconfigure",
|
||||
"fixtures.storage_controller_proxy",
|
||||
"fixtures.paths",
|
||||
"fixtures.neon_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.pg_stats",
|
||||
|
||||
@@ -7,7 +7,6 @@ import json
|
||||
import os
|
||||
import re
|
||||
import timeit
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -25,7 +24,8 @@ from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import NeonPageserver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, ClassVar, Optional
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import Callable, Optional
|
||||
|
||||
|
||||
"""
|
||||
@@ -141,6 +141,28 @@ class PgBenchRunResult:
|
||||
)
|
||||
|
||||
|
||||
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
|
||||
#
|
||||
# This used to be a class variable on PgBenchInitResult. However later versions
|
||||
# of Python complain:
|
||||
#
|
||||
# ValueError: mutable default <class 'dict'> for field EXTRACTORS is not allowed: use default_factory
|
||||
#
|
||||
# When you do what the error tells you to do, it seems to fail our Python 3.9
|
||||
# test environment. So let's just move it to a private module constant, and move
|
||||
# on.
|
||||
_PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = {
|
||||
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
|
||||
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
|
||||
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
|
||||
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
|
||||
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
|
||||
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
|
||||
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
|
||||
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
|
||||
}
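
To make the workaround above concrete, here is a self-contained sketch (an assumption, not repo code) of the module-constant approach:

```python
# Self-contained sketch of the pattern the comment above describes: keep the
# regex table in a module-level constant instead of a class-level default.
import dataclasses
import re
from typing import Optional

# Module constant: a shared, effectively read-only lookup table.
_EXTRACTORS = {
    "total": re.compile(r"done in (\d+\.\d+) s"),
}

# The previous version kept this table as a class attribute on the dataclass,
# which newer Python versions rejected with:
#   ValueError: mutable default <class 'dict'> for field EXTRACTORS is not
#   allowed: use default_factory


@dataclasses.dataclass
class InitResult:
    total: Optional[float]

    @classmethod
    def parse(cls, line: str) -> "InitResult":
        m = _EXTRACTORS["total"].match(line.strip())
        return cls(total=float(m.group(1)) if m else None)
```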
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PgBenchInitResult:
|
||||
total: Optional[float]
|
||||
@@ -155,20 +177,6 @@ class PgBenchInitResult:
|
||||
start_timestamp: int
|
||||
end_timestamp: int
|
||||
|
||||
# Taken from https://github.com/postgres/postgres/blob/REL_15_1/src/bin/pgbench/pgbench.c#L5144-L5171
|
||||
EXTRACTORS: ClassVar[dict[str, re.Pattern[str]]] = dataclasses.field(
|
||||
default_factory=lambda: {
|
||||
"drop_tables": re.compile(r"drop tables (\d+\.\d+) s"),
|
||||
"create_tables": re.compile(r"create tables (\d+\.\d+) s"),
|
||||
"client_side_generate": re.compile(r"client-side generate (\d+\.\d+) s"),
|
||||
"server_side_generate": re.compile(r"server-side generate (\d+\.\d+) s"),
|
||||
"vacuum": re.compile(r"vacuum (\d+\.\d+) s"),
|
||||
"primary_keys": re.compile(r"primary keys (\d+\.\d+) s"),
|
||||
"foreign_keys": re.compile(r"foreign keys (\d+\.\d+) s"),
|
||||
"total": re.compile(r"done in (\d+\.\d+) s"), # Total time printed by pgbench
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_from_stderr(
|
||||
cls,
|
||||
@@ -185,7 +193,7 @@ class PgBenchInitResult:
|
||||
timings: dict[str, Optional[float]] = {}
|
||||
last_line_items = re.split(r"\(|\)|,", last_line)
|
||||
for item in last_line_items:
|
||||
for key, regex in cls.EXTRACTORS.items():
|
||||
for key, regex in _PGBENCH_INIT_EXTRACTORS.items():
|
||||
if (m := regex.match(item.strip())) is not None:
|
||||
if key in timings:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -6,6 +6,8 @@ from enum import Enum
|
||||
from functools import total_ordering
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -31,33 +33,36 @@ class Lsn:
|
||||
self.lsn_int = (int(left, 16) << 32) + int(right, 16)
|
||||
assert 0 <= self.lsn_int <= 0xFFFFFFFF_FFFFFFFF
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
"""Convert lsn from int to standard hex notation."""
|
||||
return f"{(self.lsn_int >> 32):X}/{(self.lsn_int & 0xFFFFFFFF):X}"
|
||||
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f'Lsn("{str(self)}")'
|
||||
|
||||
def __int__(self) -> int:
|
||||
return self.lsn_int
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, Lsn):
|
||||
return NotImplemented
|
||||
return self.lsn_int < other.lsn_int
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
def __gt__(self, other: object) -> bool:
|
||||
if not isinstance(other, Lsn):
|
||||
raise NotImplementedError
|
||||
return self.lsn_int > other.lsn_int
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Lsn):
|
||||
return NotImplemented
|
||||
return self.lsn_int == other.lsn_int
|
||||
|
||||
# Returns the difference between two Lsns, in bytes
|
||||
def __sub__(self, other: Any) -> int:
|
||||
def __sub__(self, other: object) -> int:
|
||||
if not isinstance(other, Lsn):
|
||||
return NotImplemented
|
||||
return self.lsn_int - other.lsn_int
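
As a small illustration of the parsing, formatting, and subtraction defined above (assuming the class lives in `fixtures.common_types`, as the imports elsewhere in this diff suggest; the values are arbitrary):

```python
# Illustration of the Lsn behaviour defined above.
from fixtures.common_types import Lsn

a = Lsn("0/2000000")
b = Lsn("0/1000000")

assert int(a) == 0x2000000      # (hi << 32) + lo
assert str(a) == "0/2000000"    # round-trips through the X/Y hex form
assert a - b == 0x1000000       # the difference is in bytes
assert b < a
```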
|
||||
@@ -70,6 +75,7 @@ class Lsn:
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.lsn_int)
|
||||
|
||||
@@ -116,19 +122,22 @@ class Id:
|
||||
self.id = bytearray.fromhex(x)
|
||||
assert len(self.id) == 16
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.id.hex()
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return self.id < other.id
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return self.id == other.id
|
||||
|
||||
@override
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self.id))
|
||||
|
||||
@@ -139,25 +148,31 @@ class Id:
|
||||
|
||||
|
||||
class TenantId(Id):
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f'`TenantId("{self.id.hex()}")'
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.id.hex()
|
||||
|
||||
|
||||
class NodeId(Id):
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f'`NodeId("{self.id.hex()}")'
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.id.hex()
|
||||
|
||||
|
||||
class TimelineId(Id):
|
||||
@override
|
||||
def __repr__(self) -> str:
|
||||
return f'TimelineId("{self.id.hex()}")'
|
||||
|
||||
@override
|
||||
def __str__(self) -> str:
|
||||
return self.id.hex()
|
||||
|
||||
@@ -187,7 +202,7 @@ class TenantShardId:
|
||||
assert self.shard_number < self.shard_count or self.shard_count == 0
|
||||
|
||||
@classmethod
|
||||
def parse(cls: type[TTenantShardId], input) -> TTenantShardId:
|
||||
def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId:
|
||||
if len(input) == 32:
|
||||
return cls(
|
||||
tenant_id=TenantId(input),
|
||||
@@ -203,6 +218,7 @@ class TenantShardId:
|
||||
else:
|
||||
raise ValueError(f"Invalid TenantShardId '{input}'")
|
||||
|
||||
@override
|
||||
def __str__(self):
|
||||
if self.shard_count > 0:
|
||||
return f"{self.tenant_id}-{self.shard_number:02x}{self.shard_count:02x}"
|
||||
@@ -210,22 +226,25 @@ class TenantShardId:
|
||||
# Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id)
|
||||
return str(self.tenant_id)
|
||||
|
||||
@override
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def _tuple(self) -> tuple[TenantId, int, int]:
|
||||
return (self.tenant_id, self.shard_number, self.shard_count)
|
||||
|
||||
def __lt__(self, other) -> bool:
|
||||
def __lt__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return self._tuple() < other._tuple()
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
@override
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, type(self)):
|
||||
return NotImplemented
|
||||
return self._tuple() == other._tuple()
|
||||
|
||||
@override
|
||||
def __hash__(self) -> int:
|
||||
return hash(self._tuple())
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ from contextlib import _GeneratorContextManager, contextmanager
|
||||
|
||||
# Type-related stuff
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from _pytest.fixtures import FixtureRequest
|
||||
from typing_extensions import override
|
||||
|
||||
from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker
|
||||
from fixtures.log_helper import log
|
||||
@@ -24,6 +26,9 @@ from fixtures.neon_fixtures import (
|
||||
)
|
||||
from fixtures.pg_stats import PgStatTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
class PgCompare(ABC):
|
||||
"""Common interface of all postgres implementations, useful for benchmarks.
|
||||
@@ -65,12 +70,12 @@ class PgCompare(ABC):
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_pageserver_writes(self, out_name):
|
||||
def record_pageserver_writes(self, out_name: str):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@abstractmethod
|
||||
def record_duration(self, out_name):
|
||||
def record_duration(self, out_name: str):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
@@ -122,28 +127,34 @@ class NeonCompare(PgCompare):
|
||||
self._pg = self.env.endpoints.create_start("main", "main", self.tenant)
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> PgProtocol:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg_bin
|
||||
|
||||
@override
|
||||
def flush(self, compact: bool = True, gc: bool = True):
|
||||
wait_for_last_flush_lsn(self.env, self._pg, self.tenant, self.timeline)
|
||||
self.pageserver_http_client.timeline_checkpoint(self.tenant, self.timeline, compact=compact)
|
||||
if gc:
|
||||
self.pageserver_http_client.timeline_gc(self.tenant, self.timeline, 0)
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
self.pageserver_http_client.timeline_compact(
|
||||
self.tenant, self.timeline, wait_until_uploaded=True
|
||||
)
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
self.zenbenchmark.record(
|
||||
"peak_mem",
|
||||
@@ -152,6 +163,7 @@ class NeonCompare(PgCompare):
|
||||
report=MetricReport.LOWER_IS_BETTER,
|
||||
)
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
timeline_size = self.zenbenchmark.get_timeline_size(
|
||||
self.env.repo_dir, self.tenant, self.timeline
|
||||
@@ -185,9 +197,11 @@ class NeonCompare(PgCompare):
|
||||
"num_files_uploaded", total_files, "", report=MetricReport.LOWER_IS_BETTER
|
||||
)
|
||||
|
||||
@override
|
||||
def record_pageserver_writes(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_pageserver_writes(self.env.pageserver, out_name)
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
@@ -211,26 +225,33 @@ class VanillaCompare(PgCompare):
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> VanillaPostgres:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg.pg_bin
|
||||
|
||||
@override
|
||||
def flush(self, compact: bool = False, gc: bool = False):
|
||||
self.cur.execute("checkpoint")
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
pass # TODO find something
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
data_size = self.pg.get_subdir_size(Path("base"))
|
||||
self.zenbenchmark.record(
|
||||
@@ -245,6 +266,7 @@ class VanillaCompare(PgCompare):
|
||||
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
|
||||
yield # Do nothing
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
@@ -261,28 +283,35 @@ class RemoteCompare(PgCompare):
|
||||
self.cur = self.conn.cursor()
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg(self) -> PgProtocol:
|
||||
return self._pg
|
||||
|
||||
@property
|
||||
@override
|
||||
def zenbenchmark(self) -> NeonBenchmarker:
|
||||
return self._zenbenchmark
|
||||
|
||||
@property
|
||||
@override
|
||||
def pg_bin(self) -> PgBin:
|
||||
return self._pg.pg_bin
|
||||
|
||||
def flush(self):
|
||||
@override
|
||||
def flush(self, compact: bool = False, gc: bool = False):
|
||||
# TODO: flush the remote pageserver
|
||||
pass
|
||||
|
||||
@override
|
||||
def compact(self):
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_peak_memory_use(self):
|
||||
# TODO: get memory usage from remote pageserver
|
||||
pass
|
||||
|
||||
@override
|
||||
def report_size(self):
|
||||
# TODO: get storage size from remote pageserver
|
||||
pass
|
||||
@@ -291,6 +320,7 @@ class RemoteCompare(PgCompare):
|
||||
def record_pageserver_writes(self, out_name: str) -> Iterator[None]:
|
||||
yield # Do nothing
|
||||
|
||||
@override
|
||||
def record_duration(self, out_name: str) -> _GeneratorContextManager[None]:
|
||||
return self.zenbenchmark.record_duration(out_name)
|
||||
|
||||
|
||||
@@ -1,27 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_httpserver import HTTPServer
|
||||
from werkzeug.wrappers.request import Request
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
from fixtures.common_types import TenantId
|
||||
from fixtures.log_helper import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
class ComputeReconfigure:
|
||||
def __init__(self, server):
|
||||
def __init__(self, server: HTTPServer):
|
||||
self.server = server
|
||||
self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach"
|
||||
self.workloads = {}
|
||||
self.on_notify = None
|
||||
self.workloads: dict[TenantId, Any] = {}
|
||||
self.on_notify: Optional[Callable[[Any], None]] = None
|
||||
|
||||
def register_workload(self, workload):
|
||||
def register_workload(self, workload: Any):
|
||||
self.workloads[workload.tenant_id] = workload
|
||||
|
||||
def register_on_notify(self, fn):
|
||||
def register_on_notify(self, fn: Optional[Callable[[Any], None]]):
|
||||
"""
|
||||
Add some extra work during a notification, like sleeping to slow things down, or
|
||||
logging what was notified.
|
||||
@@ -30,7 +34,7 @@ class ComputeReconfigure:
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def compute_reconfigure_listener(make_httpserver):
|
||||
def compute_reconfigure_listener(make_httpserver: HTTPServer):
|
||||
"""
|
||||
This fixture exposes an HTTP listener for the storage controller to submit
|
||||
compute notifications to us, instead of updating neon_local endpoints itself.
|
||||
@@ -48,7 +52,7 @@ def compute_reconfigure_listener(make_httpserver):
|
||||
# accept a healthy rate of calls into notify-attach.
|
||||
reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
def handler(request: Request):
|
||||
def handler(request: Request) -> Response:
|
||||
assert request.json is not None
|
||||
body: dict[str, Any] = request.json
|
||||
log.info(f"notify-attach request: {body}")
|
||||
|
||||
@@ -28,8 +28,3 @@ class EndpointHttpClient(requests.Session):
|
||||
res = self.get(f"http://localhost:{self.port}/installed_extensions")
|
||||
res.raise_for_status()
|
||||
return res.json()
|
||||
|
||||
def metrics(self):
|
||||
res = self.get(f"http://localhost:{self.port}/metrics")
|
||||
res.raise_for_status()
|
||||
return res.text
|
||||
|
||||
@@ -14,8 +14,10 @@ from allure_pytest.utils import allure_name, allure_suite_labels
|
||||
from fixtures.log_helper import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
|
||||
"""
|
||||
The plugin reruns flaky tests.
|
||||
It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py`
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest_httpserver import HTTPServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from fixtures.port_distributor import PortDistributor
|
||||
|
||||
# TODO: mypy fails with:
|
||||
# Module "fixtures.neon_fixtures" does not explicitly export attribute "PortDistributor" [attr-defined]
|
||||
# from fixtures.neon_fixtures import PortDistributor
|
||||
@@ -17,7 +24,7 @@ def httpserver_ssl_context():
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
|
||||
def make_httpserver(httpserver_listen_address, httpserver_ssl_context) -> Iterator[HTTPServer]:
|
||||
host, port = httpserver_listen_address
|
||||
if not host:
|
||||
host = HTTPServer.DEFAULT_LISTEN_HOST
|
||||
@@ -33,13 +40,13 @@ def make_httpserver(httpserver_listen_address, httpserver_ssl_context):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def httpserver(make_httpserver):
|
||||
def httpserver(make_httpserver: HTTPServer) -> Iterator[HTTPServer]:
|
||||
server = make_httpserver
|
||||
yield server
|
||||
server.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def httpserver_listen_address(port_distributor) -> tuple[str, int]:
|
||||
def httpserver_listen_address(port_distributor: PortDistributor) -> tuple[str, int]:
|
||||
port = port_distributor.get_port()
|
||||
return ("localhost", port)
|
||||
|
||||
@@ -31,7 +31,7 @@ LOGGING = {
|
||||
}
|
||||
|
||||
|
||||
def getLogger(name="root") -> logging.Logger:
|
||||
def getLogger(name: str = "root") -> logging.Logger:
|
||||
"""Method to get logger for tests.
|
||||
|
||||
Should be used to get correctly initialized logger."""
|
||||
|
||||
@@ -22,7 +22,7 @@ class Metrics:
|
||||
|
||||
def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]:
|
||||
filter = filter or {}
|
||||
res = []
|
||||
res: list[Sample] = []
|
||||
|
||||
for sample in self.metrics[name]:
|
||||
try:
|
||||
@@ -59,7 +59,7 @@ class MetricsGetter:
|
||||
return results[0].value
|
||||
|
||||
def get_metrics_values(
|
||||
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok=False
|
||||
self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
When fetching multiple named metrics, it is more efficient to use this
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, cast
|
||||
import requests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from fixtures.pg_version import PgVersion
|
||||
|
||||
@@ -25,9 +25,7 @@ class NeonAPI:
|
||||
self.__neon_api_key = neon_api_key
|
||||
self.__neon_api_base_url = neon_api_base_url.strip("/")
|
||||
|
||||
def __request(
|
||||
self, method: Union[str, bytes], endpoint: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
def __request(self, method: str | bytes, endpoint: str, **kwargs: Any) -> requests.Response:
|
||||
if "headers" not in kwargs:
|
||||
kwargs["headers"] = {}
|
||||
kwargs["headers"]["Authorization"] = f"Bearer {self.__neon_api_key}"
|
||||
|
||||
@@ -18,7 +18,6 @@ from contextlib import closing, contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from fcntl import LOCK_EX, LOCK_UN, flock
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
@@ -59,6 +58,7 @@ from fixtures.pageserver.http import PageserverHttpClient
|
||||
from fixtures.pageserver.utils import (
|
||||
wait_for_last_record_lsn,
|
||||
)
|
||||
from fixtures.paths import get_test_repo_dir, shared_snapshot_dir
|
||||
from fixtures.pg_version import PgVersion
|
||||
from fixtures.port_distributor import PortDistributor
|
||||
from fixtures.remote_storage import (
|
||||
@@ -75,8 +75,8 @@ from fixtures.safekeeper.http import SafekeeperHttpClient
|
||||
from fixtures.safekeeper.utils import wait_walreceivers_absent
|
||||
from fixtures.utils import (
|
||||
ATTACHMENT_NAME_REGEX,
|
||||
COMPONENT_BINARIES,
|
||||
allure_add_grafana_links,
|
||||
allure_attach_from_dir,
|
||||
assert_no_errors,
|
||||
get_dir_size,
|
||||
print_gc_result,
|
||||
@@ -96,6 +96,8 @@ if TYPE_CHECKING:
|
||||
Union,
|
||||
)
|
||||
|
||||
from fixtures.paths import SnapshotDirLocked
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -118,65 +120,11 @@ put directly-importable functions into utils.py or another separate file.
|
||||
|
||||
Env = dict[str, str]
|
||||
|
||||
DEFAULT_OUTPUT_DIR: str = "test_output"
|
||||
DEFAULT_BRANCH_NAME: str = "main"
|
||||
|
||||
BASE_PORT: int = 15000
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def base_dir() -> Iterator[Path]:
|
||||
# find the base directory (currently this is the git root)
|
||||
base_dir = Path(__file__).parents[2]
|
||||
log.info(f"base_dir is {base_dir}")
|
||||
|
||||
yield base_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
# we are in remote env and do not have neon binaries locally
|
||||
# this is the case for benchmarks run on self-hosted runner
|
||||
return
|
||||
|
||||
# Find the neon binaries.
|
||||
if env_neon_bin := os.environ.get("NEON_BIN"):
|
||||
binpath = Path(env_neon_bin)
|
||||
else:
|
||||
binpath = base_dir / "target" / build_type
|
||||
log.info(f"neon_binpath is {binpath}")
|
||||
|
||||
if not (binpath / "pageserver").exists():
|
||||
raise Exception(f"neon binaries not found at '{binpath}'")
|
||||
|
||||
yield binpath
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
|
||||
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
|
||||
distrib_dir = Path(env_postgres_bin).resolve()
|
||||
else:
|
||||
distrib_dir = base_dir / "pg_install"
|
||||
|
||||
log.info(f"pg_distrib_dir is {distrib_dir}")
|
||||
yield distrib_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def top_output_dir(base_dir: Path) -> Iterator[Path]:
|
||||
# Compute the top-level directory for all tests.
|
||||
if env_test_output := os.environ.get("TEST_OUTPUT"):
|
||||
output_dir = Path(env_test_output).resolve()
|
||||
else:
|
||||
output_dir = base_dir / DEFAULT_OUTPUT_DIR
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
log.info(f"top_output_dir is {output_dir}")
|
||||
yield output_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def neon_api_key() -> str:
|
||||
api_key = os.getenv("NEON_API_KEY")
|
||||
@@ -369,11 +317,14 @@ class NeonEnvBuilder:
|
||||
run_id: uuid.UUID,
|
||||
mock_s3_server: MockS3Server,
|
||||
neon_binpath: Path,
|
||||
compatibility_neon_binpath: Path,
|
||||
pg_distrib_dir: Path,
|
||||
compatibility_pg_distrib_dir: Path,
|
||||
pg_version: PgVersion,
|
||||
test_name: str,
|
||||
top_output_dir: Path,
|
||||
test_output_dir: Path,
|
||||
combination,
|
||||
test_overlay_dir: Optional[Path] = None,
|
||||
pageserver_remote_storage: Optional[RemoteStorage] = None,
|
||||
# toml that will be decomposed into `--config-override` flags during `pageserver --init`
|
||||
@@ -455,6 +406,19 @@ class NeonEnvBuilder:
|
||||
"test_"
|
||||
), "Unexpectedly instantiated from outside a test function"
|
||||
self.test_name = test_name
|
||||
self.compatibility_neon_binpath = compatibility_neon_binpath
|
||||
self.compatibility_pg_distrib_dir = compatibility_pg_distrib_dir
|
||||
self.version_combination = combination
|
||||
self.mixdir = self.test_output_dir / "mixdir_neon"
|
||||
if self.version_combination is not None:
|
||||
assert (
|
||||
self.compatibility_neon_binpath is not None
|
||||
), "the environment variable COMPATIBILITY_NEON_BIN is required when using mixed versions"
|
||||
assert (
|
||||
self.compatibility_pg_distrib_dir is not None
|
||||
), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required when using mixed versions"
|
||||
self.mixdir.mkdir(mode=0o755, exist_ok=True)
|
||||
self._mix_versions()
|
||||
|
||||
def init_configs(self, default_remote_storage_if_missing: bool = True) -> NeonEnv:
|
||||
# Cannot create more than one environment from one builder
|
||||
@@ -655,6 +619,21 @@ class NeonEnvBuilder:
|
||||
|
||||
return self.env
|
||||
|
||||
def _mix_versions(self):
|
||||
assert self.version_combination is not None, "version combination must be set"
|
||||
for component, paths in COMPONENT_BINARIES.items():
|
||||
directory = (
|
||||
self.neon_binpath
|
||||
if self.version_combination[component] == "new"
|
||||
else self.compatibility_neon_binpath
|
||||
)
|
||||
for filename in paths:
|
||||
destination = self.mixdir / filename
|
||||
destination.symlink_to(directory / filename)
|
||||
if self.version_combination["compute"] == "old":
|
||||
self.pg_distrib_dir = self.compatibility_pg_distrib_dir
|
||||
self.neon_binpath = self.mixdir
|
||||
|
||||
def overlay_mount(self, ident: str, srcdir: Path, dstdir: Path):
|
||||
"""
|
||||
Mount `srcdir` as an overlayfs mount at `dstdir`.
|
||||
@@ -1403,7 +1382,9 @@ def neon_simple_env(
|
||||
top_output_dir: Path,
|
||||
test_output_dir: Path,
|
||||
neon_binpath: Path,
|
||||
compatibility_neon_binpath: Path,
|
||||
pg_distrib_dir: Path,
|
||||
compatibility_pg_distrib_dir: Path,
|
||||
pg_version: PgVersion,
|
||||
pageserver_virtual_file_io_engine: str,
|
||||
pageserver_aux_file_policy: Optional[AuxFileStore],
|
||||
@@ -1418,6 +1399,11 @@ def neon_simple_env(
|
||||
|
||||
# Create the environment in the per-test output directory
|
||||
repo_dir = get_test_repo_dir(request, top_output_dir)
|
||||
combination = (
|
||||
request._pyfuncitem.callspec.params["combination"]
|
||||
if "combination" in request._pyfuncitem.callspec.params
|
||||
else None
|
||||
)
|
||||
|
||||
with NeonEnvBuilder(
|
||||
top_output_dir=top_output_dir,
|
||||
@@ -1425,7 +1411,9 @@ def neon_simple_env(
|
||||
port_distributor=port_distributor,
|
||||
mock_s3_server=mock_s3_server,
|
||||
neon_binpath=neon_binpath,
|
||||
compatibility_neon_binpath=compatibility_neon_binpath,
|
||||
pg_distrib_dir=pg_distrib_dir,
|
||||
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
|
||||
pg_version=pg_version,
|
||||
run_id=run_id,
|
||||
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
|
||||
@@ -1435,6 +1423,7 @@ def neon_simple_env(
|
||||
pageserver_aux_file_policy=pageserver_aux_file_policy,
|
||||
pageserver_default_tenant_config_compaction_algorithm=pageserver_default_tenant_config_compaction_algorithm,
|
||||
pageserver_virtual_file_io_mode=pageserver_virtual_file_io_mode,
|
||||
combination=combination,
|
||||
) as builder:
|
||||
env = builder.init_start()
|
||||
|
||||
@@ -1448,7 +1437,9 @@ def neon_env_builder(
|
||||
port_distributor: PortDistributor,
|
||||
mock_s3_server: MockS3Server,
|
||||
neon_binpath: Path,
|
||||
compatibility_neon_binpath: Path,
|
||||
pg_distrib_dir: Path,
|
||||
compatibility_pg_distrib_dir: Path,
|
||||
pg_version: PgVersion,
|
||||
run_id: uuid.UUID,
|
||||
request: FixtureRequest,
|
||||
@@ -1475,6 +1466,11 @@ def neon_env_builder(
|
||||
|
||||
# Create the environment in the test-specific output dir
|
||||
repo_dir = os.path.join(test_output_dir, "repo")
|
||||
combination = (
|
||||
request._pyfuncitem.callspec.params["combination"]
|
||||
if "combination" in request._pyfuncitem.callspec.params
|
||||
else None
|
||||
)
|
||||
|
||||
# Return the builder to the caller
|
||||
with NeonEnvBuilder(
|
||||
@@ -1483,7 +1479,10 @@ def neon_env_builder(
|
||||
port_distributor=port_distributor,
|
||||
mock_s3_server=mock_s3_server,
|
||||
neon_binpath=neon_binpath,
|
||||
compatibility_neon_binpath=compatibility_neon_binpath,
|
||||
pg_distrib_dir=pg_distrib_dir,
|
||||
compatibility_pg_distrib_dir=compatibility_pg_distrib_dir,
|
||||
combination=combination,
|
||||
pg_version=pg_version,
|
||||
run_id=run_id,
|
||||
preserve_database_files=cast(bool, pytestconfig.getoption("--preserve-database-files")),
|
||||
@@ -3657,7 +3656,7 @@ class Endpoint(PgProtocol, LogUtils):
|
||||
config_lines: Optional[list[str]] = None,
|
||||
remote_ext_config: Optional[str] = None,
|
||||
pageserver_id: Optional[int] = None,
|
||||
allow_multiple=False,
|
||||
allow_multiple: bool = False,
|
||||
basebackup_request_tries: Optional[int] = None,
|
||||
) -> Endpoint:
|
||||
"""
|
||||
@@ -3998,7 +3997,7 @@ class Safekeeper(LogUtils):
|
||||
def timeline_dir(self, tenant_id, timeline_id) -> Path:
|
||||
return self.data_dir / str(tenant_id) / str(timeline_id)
|
||||
|
||||
# List partial uploaded segments of this safekeeper. Works only for
|
||||
# list partial uploaded segments of this safekeeper. Works only for
|
||||
# RemoteStorageKind.LOCAL_FS.
|
||||
def list_uploaded_segments(self, tenant_id: TenantId, timeline_id: TimelineId):
|
||||
tline_path = (
|
||||
@@ -4246,44 +4245,6 @@ class StorageScrubber:
|
||||
raise
|
||||
|
||||
|
||||
def _get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str) -> Path:
|
||||
"""Compute the path to a working directory for an individual test."""
|
||||
test_name = request.node.name
|
||||
test_dir = top_output_dir / f"{prefix}{test_name.replace('/', '-')}"
|
||||
|
||||
# We rerun flaky tests multiple times, use a separate directory for each run.
|
||||
if (suffix := getattr(request.node, "execution_count", None)) is not None:
|
||||
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
|
||||
|
||||
log.info(f"get_test_output_dir is {test_dir}")
|
||||
# make mypy happy
|
||||
assert isinstance(test_dir, Path)
|
||||
return test_dir
|
||||
|
||||
|
||||
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
"""
|
||||
The working directory for a test.
|
||||
"""
|
||||
return _get_test_dir(request, top_output_dir, "")
|
||||
|
||||
|
||||
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
"""
|
||||
Directory that contains `upperdir` and `workdir` for overlayfs mounts
|
||||
that a test creates. See `NeonEnvBuilder.overlay_mount`.
|
||||
"""
|
||||
return _get_test_dir(request, top_output_dir, "overlay-")
|
||||
|
||||
|
||||
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
|
||||
return top_output_dir / "shared-snapshots" / snapshot_name
|
||||
|
||||
|
||||
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
return get_test_output_dir(request, top_output_dir) / "repo"
|
||||
|
||||
|
||||
def pytest_addoption(parser: Parser):
|
||||
parser.addoption(
|
||||
"--preserve-database-files",
|
||||
@@ -4293,154 +4254,11 @@ def pytest_addoption(parser: Parser):
|
||||
)
|
||||
|
||||
|
||||
SMALL_DB_FILE_NAME_REGEX: re.Pattern = re.compile( # type: ignore[type-arg]
|
||||
SMALL_DB_FILE_NAME_REGEX: re.Pattern[str] = re.compile(
|
||||
r"config-v1|heatmap-v1|metadata|.+\.(?:toml|pid|json|sql|conf)"
|
||||
)
|
||||
|
||||
|
||||
# This is autouse, so the test output directory always gets created, even
|
||||
# if a test doesn't put anything there.
|
||||
#
|
||||
# NB: we request the overlay dir fixture so the fixture does its cleanups
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def test_output_dir(
|
||||
request: FixtureRequest, top_output_dir: Path, test_overlay_dir: Path
|
||||
) -> Iterator[Path]:
|
||||
"""Create the working directory for an individual test."""
|
||||
|
||||
# one directory per test
|
||||
test_dir = get_test_output_dir(request, top_output_dir)
|
||||
log.info(f"test_output_dir is {test_dir}")
|
||||
shutil.rmtree(test_dir, ignore_errors=True)
|
||||
test_dir.mkdir()
|
||||
|
||||
yield test_dir
|
||||
|
||||
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
|
||||
# which aren't going to be used if Allure results collection is not enabled
|
||||
# (i.e. --alluredir is not set).
|
||||
# Skip `allure_attach_from_dir` in this case
|
||||
if not request.config.getoption("--alluredir"):
|
||||
return
|
||||
|
||||
preserve_database_files = False
|
||||
for k, v in request.node.user_properties:
|
||||
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
|
||||
# So, neon_env_builder's cleanup runs before here.
|
||||
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
|
||||
if k == "preserve_database_files":
|
||||
assert isinstance(v, bool)
|
||||
preserve_database_files = v
|
||||
|
||||
allure_attach_from_dir(test_dir, preserve_database_files)
|
||||
|
||||
|
||||
class FileAndThreadLock:
|
||||
def __init__(self, path: Path):
|
||||
self.path = path
|
||||
self.thread_lock = threading.Lock()
|
||||
self.fd: Optional[int] = None
|
||||
|
||||
def __enter__(self):
|
||||
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
|
||||
# lock thread lock before file lock so that there's no race
|
||||
# around flocking / funlocking the file lock
|
||||
self.thread_lock.acquire()
|
||||
flock(self.fd, LOCK_EX)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
assert self.fd is not None
|
||||
assert self.thread_lock.locked() # ... by us
|
||||
flock(self.fd, LOCK_UN)
|
||||
self.thread_lock.release()
|
||||
os.close(self.fd)
|
||||
self.fd = None
|
||||
|
||||
|
||||
class SnapshotDirLocked:
|
||||
def __init__(self, parent: SnapshotDir):
|
||||
self._parent = parent
|
||||
|
||||
def is_initialized(self):
|
||||
# TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
|
||||
# Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
|
||||
return self._parent._marker_file_path.exists()
|
||||
|
||||
def set_initialized(self):
|
||||
self._parent._marker_file_path.write_text("")
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return self._parent._path / "snapshot"
|
||||
|
||||
|
||||
class SnapshotDir:
|
||||
_path: Path
|
||||
|
||||
def __init__(self, path: Path):
|
||||
self._path = path
|
||||
assert self._path.is_dir()
|
||||
self._lock = FileAndThreadLock(self._lock_file_path)
|
||||
|
||||
@property
|
||||
def _lock_file_path(self) -> Path:
|
||||
return self._path / "initializing.flock"
|
||||
|
||||
@property
|
||||
def _marker_file_path(self) -> Path:
|
||||
return self._path / "initialized.marker"
|
||||
|
||||
def __enter__(self) -> SnapshotDirLocked:
|
||||
self._lock.__enter__()
|
||||
return SnapshotDirLocked(self)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, exc_traceback):
|
||||
self._lock.__exit__(exc_type, exc_value, exc_traceback)
|
||||
|
||||
|
||||
def shared_snapshot_dir(top_output_dir, ident: str) -> SnapshotDir:
|
||||
snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
|
||||
snapshot_dir_path.mkdir(exist_ok=True, parents=True)
|
||||
return SnapshotDir(snapshot_dir_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
|
||||
"""
|
||||
Idempotently create a test's overlayfs mount state directory.
|
||||
If the functionality isn't enabled via env var, returns None.
|
||||
|
||||
The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
|
||||
"""
|
||||
|
||||
if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
|
||||
return None
|
||||
|
||||
overlay_dir = get_test_overlay_dir(request, top_output_dir)
|
||||
log.info(f"test_overlay_dir is {overlay_dir}")
|
||||
|
||||
overlay_dir.mkdir(exist_ok=True)
|
||||
# unmount stale overlayfs mounts which subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
|
||||
for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
|
||||
cmd = ["sudo", "umount", str(mountpoint)]
|
||||
log.info(
|
||||
f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
|
||||
)
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
# the overlayfs `workdir`` is owned by `root`, shutil.rmtree won't work.
|
||||
cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
|
||||
subprocess.run(cmd, capture_output=True, check=True)
|
||||
|
||||
overlay_dir.mkdir()
|
||||
|
||||
return overlay_dir
|
||||
|
||||
# no need to clean up anything: on clean shutdown,
|
||||
# NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
|
||||
# and on unclean shutdown, this function will take care of it
|
||||
# on the next test run
|
||||
|
||||
|
||||
SKIP_DIRS = frozenset(
|
||||
(
|
||||
"pg_wal",
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import psutil
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
def iter_mounts_beneath(topdir: Path) -> Iterator[Path]:
|
||||
"""
|
||||
|
||||
@@ -886,7 +886,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
self,
|
||||
tenant_id: Union[TenantId, TenantShardId],
|
||||
timeline_id: TimelineId,
|
||||
batch_size: int | None = None,
|
||||
batch_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> set[TimelineId]:
|
||||
params = {}
|
||||
|
||||
@@ -9,7 +9,12 @@ import toml
|
||||
from _pytest.python import Metafunc
|
||||
|
||||
from fixtures.pg_version import PgVersion
|
||||
from fixtures.utils import AuxFileStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
from fixtures.utils import AuxFileStore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Optional
|
||||
|
||||
312 test_runner/fixtures/paths.py Normal file
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
from fcntl import LOCK_EX, LOCK_UN, flock
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from pytest import FixtureRequest
|
||||
|
||||
from fixtures import overlayfs
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.utils import allure_attach_from_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
from typing import Optional
|
||||
|
||||
|
||||
DEFAULT_OUTPUT_DIR: str = "test_output"
|
||||
|
||||
|
||||
def get_test_dir(
|
||||
request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None
|
||||
) -> Path:
|
||||
"""Compute the path to a working directory for an individual test."""
|
||||
test_name = request.node.name
|
||||
test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}"
|
||||
|
||||
# We rerun flaky tests multiple times, use a separate directory for each run.
|
||||
if (suffix := getattr(request.node, "execution_count", None)) is not None:
|
||||
test_dir = test_dir.parent / f"{test_dir.name}-{suffix}"
|
||||
|
||||
return test_dir
|
||||
|
||||
|
||||
def get_test_output_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
"""
|
||||
The working directory for a test.
|
||||
"""
|
||||
return get_test_dir(request, top_output_dir)
|
||||
|
||||
|
||||
def get_test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
"""
|
||||
Directory that contains `upperdir` and `workdir` for overlayfs mounts
|
||||
that a test creates. See `NeonEnvBuilder.overlay_mount`.
|
||||
"""
|
||||
return get_test_dir(request, top_output_dir, "overlay-")
|
||||
|
||||
|
||||
def get_shared_snapshot_dir_path(top_output_dir: Path, snapshot_name: str) -> Path:
|
||||
return top_output_dir / "shared-snapshots" / snapshot_name
|
||||
|
||||
|
||||
def get_test_repo_dir(request: FixtureRequest, top_output_dir: Path) -> Path:
|
||||
return get_test_output_dir(request, top_output_dir) / "repo"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def base_dir() -> Iterator[Path]:
|
||||
# find the base directory (currently this is the git root)
|
||||
base_dir = Path(__file__).parents[2]
|
||||
log.info(f"base_dir is {base_dir}")
|
||||
|
||||
yield base_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def compute_config_dir(base_dir: Path) -> Iterator[Path]:
|
||||
"""
|
||||
Retrieve the path to the compute configuration directory.
|
||||
"""
|
||||
yield base_dir / "compute" / "etc"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def neon_binpath(base_dir: Path, build_type: str) -> Iterator[Path]:
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
# we are in remote env and do not have neon binaries locally
|
||||
# this is the case for benchmarks run on self-hosted runner
|
||||
return
|
||||
|
||||
# Find the neon binaries.
|
||||
if env_neon_bin := os.environ.get("NEON_BIN"):
|
||||
binpath = Path(env_neon_bin)
|
||||
else:
|
||||
binpath = base_dir / "target" / build_type
|
||||
log.info(f"neon_binpath is {binpath}")
|
||||
|
||||
if not (binpath / "pageserver").exists():
|
||||
raise Exception(f"neon binaries not found at '{binpath}'")
|
||||
|
||||
yield binpath.absolute()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def compatibility_snapshot_dir() -> Iterator[Path]:
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
return
|
||||
compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
|
||||
assert (
|
||||
compatibility_snapshot_dir_env is not None
|
||||
), "COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg(PG_VERSION)` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
|
||||
compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()
|
||||
yield compatibility_snapshot_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def compatibility_neon_binpath() -> Optional[Iterator[Path]]:
|
||||
if os.getenv("REMOTE_ENV"):
|
||||
return
|
||||
comp_binpath = None
|
||||
if env_compatibility_neon_binpath := os.environ.get("COMPATIBILITY_NEON_BIN"):
|
||||
comp_binpath = Path(env_compatibility_neon_binpath).resolve().absolute()
|
||||
yield comp_binpath
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pg_distrib_dir(base_dir: Path) -> Iterator[Path]:
|
||||
if env_postgres_bin := os.environ.get("POSTGRES_DISTRIB_DIR"):
|
||||
distrib_dir = Path(env_postgres_bin).resolve()
|
||||
else:
|
||||
distrib_dir = base_dir / "pg_install"
|
||||
|
||||
log.info(f"pg_distrib_dir is {distrib_dir}")
|
||||
yield distrib_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def compatibility_pg_distrib_dir() -> Optional[Iterator[Path]]:
|
||||
compat_distrib_dir = None
|
||||
if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"):
|
||||
compat_distrib_dir = Path(env_compat_postgres_bin).resolve()
|
||||
if not compat_distrib_dir.exists():
|
||||
raise Exception(f"compatibility postgres directory not found at {compat_distrib_dir}")
|
||||
|
||||
if compat_distrib_dir:
|
||||
log.info(f"compatibility_pg_distrib_dir is {compat_distrib_dir}")
|
||||
yield compat_distrib_dir
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def top_output_dir(base_dir: Path) -> Iterator[Path]:
|
||||
# Compute the top-level directory for all tests.
|
||||
if env_test_output := os.environ.get("TEST_OUTPUT"):
|
||||
output_dir = Path(env_test_output).resolve()
|
||||
else:
|
||||
output_dir = base_dir / DEFAULT_OUTPUT_DIR
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
log.info(f"top_output_dir is {output_dir}")
|
||||
yield output_dir
|
||||
|
||||
|
||||
# This is autouse, so the test output directory always gets created, even
|
||||
# if a test doesn't put anything there.
|
||||
#
|
||||
# NB: we request the overlay dir fixture so the fixture does its cleanups
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def test_output_dir(request: pytest.FixtureRequest, top_output_dir: Path) -> Iterator[Path]:
|
||||
"""Create the working directory for an individual test."""
|
||||
|
||||
# one directory per test
|
||||
test_dir = get_test_output_dir(request, top_output_dir)
|
||||
log.info(f"test_output_dir is {test_dir}")
|
||||
shutil.rmtree(test_dir, ignore_errors=True)
|
||||
test_dir.mkdir()
|
||||
|
||||
yield test_dir
|
||||
|
||||
# Allure artifacts creation might involve the creation of `.tar.zst` archives,
|
||||
# which aren't going to be used if Allure results collection is not enabled
|
||||
# (i.e. --alluredir is not set).
|
||||
# Skip `allure_attach_from_dir` in this case
|
||||
if not request.config.getoption("--alluredir"):
|
||||
return
|
||||
|
||||
preserve_database_files = False
|
||||
for k, v in request.node.user_properties:
|
||||
# NB: the neon_env_builder fixture uses this fixture (test_output_dir).
|
||||
# So, neon_env_builder's cleanup runs before here.
|
||||
# The cleanup propagates NeonEnvBuilder.preserve_database_files into this user property.
|
||||
if k == "preserve_database_files":
|
||||
assert isinstance(v, bool)
|
||||
preserve_database_files = v
|
||||
|
||||
allure_attach_from_dir(test_dir, preserve_database_files)
|
||||
|
||||
|
||||
class FileAndThreadLock:
|
||||
def __init__(self, path: Path):
|
||||
self.path = path
|
||||
self.thread_lock = threading.Lock()
|
||||
self.fd: Optional[int] = None
|
||||
|
||||
def __enter__(self):
|
||||
self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY)
|
||||
# lock thread lock before file lock so that there's no race
|
||||
# around flocking / funlocking the file lock
|
||||
self.thread_lock.acquire()
|
||||
flock(self.fd, LOCK_EX)
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
exc_traceback: Optional[TracebackType],
|
||||
):
|
||||
assert self.fd is not None
|
||||
assert self.thread_lock.locked() # ... by us
|
||||
flock(self.fd, LOCK_UN)
|
||||
self.thread_lock.release()
|
||||
os.close(self.fd)
|
||||
self.fd = None
|
||||
|
||||
|
||||
class SnapshotDirLocked:
    def __init__(self, parent: SnapshotDir):
        self._parent = parent

    def is_initialized(self):
        # TODO: in the future, take a `tag` as argument and store it in the marker in set_initialized.
        # Then, in this function, compare marker file contents with the tag to invalidate the snapshot if the tag changed.
        return self._parent.marker_file_path.exists()

    def set_initialized(self):
        self._parent.marker_file_path.write_text("")

    @property
    def path(self) -> Path:
        return self._parent.path / "snapshot"


class SnapshotDir:
    _path: Path

    def __init__(self, path: Path):
        self._path = path
        assert self._path.is_dir()
        self._lock = FileAndThreadLock(self.lock_file_path)

    @property
    def path(self) -> Path:
        return self._path

    @property
    def lock_file_path(self) -> Path:
        return self._path / "initializing.flock"

    @property
    def marker_file_path(self) -> Path:
        return self._path / "initialized.marker"

    def __enter__(self) -> SnapshotDirLocked:
        self._lock.__enter__()
        return SnapshotDirLocked(self)

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ):
        self._lock.__exit__(exc_type, exc_value, exc_traceback)


def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir:
    snapshot_dir_path = get_shared_snapshot_dir_path(top_output_dir, ident)
    snapshot_dir_path.mkdir(exist_ok=True, parents=True)
    return SnapshotDir(snapshot_dir_path)
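For illustration (not part of the diff): SnapshotDir and SnapshotDirLocked implement an initialize-once protocol: take the file+thread lock, check the marker file, populate the snapshot only if the marker is missing, then write the marker. A hedged sketch of the intended call pattern; `populate` and the "my-snapshot" ident are placeholders, and `top_output_dir` would normally come from the pytest fixture:

    from pathlib import Path

    top_output_dir = Path("/tmp/test_output")  # placeholder for the fixture value

    def populate(dst: Path) -> None:
        (dst / "example.txt").write_text("snapshot contents")

    snap = shared_snapshot_dir(top_output_dir, "my-snapshot")
    with snap as locked:
        if not locked.is_initialized():
            locked.path.mkdir(exist_ok=True)
            populate(locked.path)
            locked.set_initialized()
        # From here on, locked.path can be read by this and by other test processes.
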
@pytest.fixture(scope="function")
def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]:
    """
    Idempotently create a test's overlayfs mount state directory.
    If the functionality isn't enabled via env var, returns None.

    The procedure cleans up after previous runs that were aborted (e.g. due to Ctrl-C, OOM kills, etc).
    """

    if os.getenv("NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS") is None:
        return None

    overlay_dir = get_test_overlay_dir(request, top_output_dir)
    log.info(f"test_overlay_dir is {overlay_dir}")

    overlay_dir.mkdir(exist_ok=True)
    # unmount stale overlayfs mounts that use subdirectories of `overlay_dir/*` as the overlayfs `upperdir` and `workdir`
    for mountpoint in overlayfs.iter_mounts_beneath(get_test_output_dir(request, top_output_dir)):
        cmd = ["sudo", "umount", str(mountpoint)]
        log.info(
            f"Unmounting stale overlayfs mount probably created during earlier test run: {cmd}"
        )
        subprocess.run(cmd, capture_output=True, check=True)
    # the overlayfs `workdir` is owned by `root`, shutil.rmtree won't work.
    cmd = ["sudo", "rm", "-rf", str(overlay_dir)]
    subprocess.run(cmd, capture_output=True, check=True)

    overlay_dir.mkdir()

    return overlay_dir

    # no need to clean up anything: on clean shutdown,
    # NeonEnvBuilder.overlay_cleanup_teardown takes care of cleanup
    # and on unclean shutdown, this function will take care of it
    # on the next test run
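For illustration (not part of the diff): `overlayfs.iter_mounts_beneath` is a project helper whose implementation is not shown here. On Linux, one plausible way to enumerate overlayfs mounts beneath a directory is to scan `/proc/mounts`; the sketch below is an assumption about how such a helper could look, not the fixture's actual code:

    from pathlib import Path
    from collections.abc import Iterator

    def iter_overlayfs_mounts_beneath(root: Path) -> Iterator[Path]:
        # Each /proc/mounts line looks like: "<source> <mountpoint> <fstype> <options> 0 0"
        with open("/proc/mounts", "r") as f:
            for line in f:
                fields = line.split()
                if len(fields) < 3:
                    continue
                mountpoint, fstype = Path(fields[1]), fields[2]
                if fstype == "overlay" and root in mountpoint.parents:
                    yield mountpoint
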
@@ -2,9 +2,14 @@ from __future__ import annotations

import enum
import os
from typing import Optional
from typing import TYPE_CHECKING

import pytest
from typing_extensions import override

if TYPE_CHECKING:
    from typing import Optional


"""
This fixture is used to determine which version of Postgres to use for tests.
@@ -24,10 +29,12 @@ class PgVersion(str, enum.Enum):
    NOT_SET = "<-POSTRGRES VERSION IS NOT SET->"

    # Make it less confusing in logs
    @override
    def __repr__(self) -> str:
        return f"'{self.value}'"

    # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums
    @override
    def __str__(self) -> str:
        return self.value

@@ -38,7 +45,8 @@ class PgVersion(str, enum.Enum):
        return f"v{self.value}"

    @classmethod
    def _missing_(cls, value) -> Optional[PgVersion]:
    @override
    def _missing_(cls, value: object) -> Optional[PgVersion]:
        known_values = {v.value for _, v in cls.__members__.items()}

        # Allow passing version as a string with "v" prefix (e.g. "v14")
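For illustration (not part of the diff): the hunk above only shows the start of `PgVersion._missing_`; the rest is cut off by the diff context. As an assumption about what the "v"-prefix handling can look like in general, here is a hedged sketch on a toy enum, not the real PgVersion implementation:

    import enum
    from typing import Optional

    class ToyVersion(str, enum.Enum):
        V14 = "14"
        V15 = "15"

        @classmethod
        def _missing_(cls, value: object) -> Optional["ToyVersion"]:
            known_values = {v.value for v in cls}
            # Accept "v14" as an alias for "14".
            if isinstance(value, str) and value.lower().lstrip("v") in known_values:
                return cls(value.lower().lstrip("v"))
            return None

    assert ToyVersion("v14") is ToyVersion.V14
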
@@ -59,10 +59,7 @@ class PortDistributor:
        if isinstance(value, int):
            return self._replace_port_int(value)

        if isinstance(value, str):
            return self._replace_port_str(value)

        raise TypeError(f"unsupported type {type(value)} of {value=}")
        return self._replace_port_str(value)

    def _replace_port_int(self, value: int) -> int:
        known_port = self.port_map.get(value)
@@ -75,7 +72,7 @@ class PortDistributor:
        # Use regex to find port in a string
        # urllib.parse.urlparse produces inconvenient results for cases without scheme like "localhost:5432"
        # See https://bugs.python.org/issue27657
        ports = re.findall(r":(\d+)(?:/|$)", value)
        ports: list[str] = re.findall(r":(\d+)(?:/|$)", value)
        assert len(ports) == 1, f"can't find port in {value}"
        port_int = int(ports[0])
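For illustration (not part of the diff): the regex above anchors the port to either a trailing `/` or the end of the string, which is why it also works for scheme-less values like `localhost:5432` where `urllib.parse.urlparse` is unhelpful. A small, self-contained demonstration with made-up example strings:

    import re

    PORT_RE = r":(\d+)(?:/|$)"

    assert re.findall(PORT_RE, "localhost:5432") == ["5432"]
    assert re.findall(PORT_RE, "postgres://cloud_admin@localhost:15432/postgres") == ["15432"]
    # No match if the colon-number is not followed by "/" or end-of-string:
    assert re.findall(PORT_RE, "host:1234abc") == []
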
@@ -13,6 +13,7 @@ import boto3
import toml
from moto.server import ThreadedMotoServer
from mypy_boto3_s3 import S3Client
from typing_extensions import override

from fixtures.common_types import TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
@@ -36,6 +37,7 @@ class RemoteStorageUser(str, enum.Enum):
    EXTENSIONS = "ext"
    SAFEKEEPER = "safekeeper"

    @override
    def __str__(self) -> str:
        return self.value

@@ -81,11 +83,13 @@ class LocalFsStorage:
    def timeline_path(self, tenant_id: TenantId, timeline_id: TimelineId) -> Path:
        return self.tenant_path(tenant_id) / "timelines" / str(timeline_id)

    def timeline_latest_generation(self, tenant_id, timeline_id):
    def timeline_latest_generation(
        self, tenant_id: TenantId, timeline_id: TimelineId
    ) -> Optional[int]:
        timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id))
        index_parts = [f for f in timeline_files if f.startswith("index_part")]

        def parse_gen(filename):
        def parse_gen(filename: str) -> Optional[int]:
            log.info(f"parsing index_part '{filename}'")
            parts = filename.split("-")
            if len(parts) == 2:
@@ -93,7 +97,7 @@ class LocalFsStorage:
            else:
                return None

        generations = sorted([parse_gen(f) for f in index_parts])
        generations = sorted([parse_gen(f) for f in index_parts])  # type: ignore
        if len(generations) == 0:
            raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}")
        return generations[-1]
@@ -122,14 +126,14 @@ class LocalFsStorage:
        filename = f"{local_name}-{generation:08x}"
        return self.timeline_path(tenant_id, timeline_id) / filename

    def index_content(self, tenant_id: TenantId, timeline_id: TimelineId):
    def index_content(self, tenant_id: TenantId, timeline_id: TimelineId) -> Any:
        with self.index_path(tenant_id, timeline_id).open("r") as f:
            return json.load(f)

    def heatmap_path(self, tenant_id: TenantId) -> Path:
        return self.tenant_path(tenant_id) / TENANT_HEATMAP_FILE_NAME

    def heatmap_content(self, tenant_id):
    def heatmap_content(self, tenant_id: TenantId) -> Any:
        with self.heatmap_path(tenant_id).open("r") as f:
            return json.load(f)

@@ -297,7 +301,7 @@ class S3Storage:
    def heatmap_key(self, tenant_id: TenantId) -> str:
        return f"{self.tenant_path(tenant_id)}/{TENANT_HEATMAP_FILE_NAME}"

    def heatmap_content(self, tenant_id: TenantId):
    def heatmap_content(self, tenant_id: TenantId) -> Any:
        r = self.client.get_object(Bucket=self.bucket_name, Key=self.heatmap_key(tenant_id))
        return json.loads(r["Body"].read().decode("utf-8"))

@@ -317,7 +321,7 @@ class RemoteStorageKind(str, enum.Enum):
    def configure(
        self,
        repo_dir: Path,
        mock_s3_server,
        mock_s3_server: MockS3Server,
        run_id: str,
        test_name: str,
        user: RemoteStorageUser,
@@ -451,15 +455,9 @@ def default_remote_storage() -> RemoteStorageKind:


def remote_storage_to_toml_dict(remote_storage: RemoteStorage) -> dict[str, Any]:
    if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
        raise Exception("invalid remote storage type")

    return remote_storage.to_toml_dict()


# serialize as toml inline table
def remote_storage_to_toml_inline_table(remote_storage: RemoteStorage) -> str:
    if not isinstance(remote_storage, (LocalFsStorage, S3Storage)):
        raise Exception("invalid remote storage type")

    return remote_storage.to_toml_inline_table()
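For illustration (not part of the diff): the hunks above show both directions of the generation-suffix convention: filenames are built as `f"{local_name}-{generation:08x}"`, and `parse_gen` recovers the generation by splitting on `-`. A small worked round-trip under the assumption, consistent with the `:08x` format above, that the suffix is zero-padded hexadecimal:

    local_name = "index_part.json"
    generation = 11

    filename = f"{local_name}-{generation:08x}"
    assert filename == "index_part.json-0000000b"

    parts = filename.split("-")
    assert len(parts) == 2
    assert int(parts[1], 16) == 11
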
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import Any, Optional
from typing import TYPE_CHECKING

import pytest
import requests
@@ -12,6 +12,9 @@ from werkzeug.wrappers.response import Response

from fixtures.log_helper import log

if TYPE_CHECKING:
    from typing import Any, Optional


class StorageControllerProxy:
    def __init__(self, server: HTTPServer):
@@ -34,7 +37,7 @@ def proxy_request(method: str, url: str, **kwargs) -> requests.Response:


@pytest.fixture(scope="function")
def storage_controller_proxy(make_httpserver):
def storage_controller_proxy(make_httpserver: HTTPServer):
    """
    Proxies requests into the storage controller to the currently
    selected storage controller instance via `StorageControllerProxy.route_to`.
@@ -48,7 +51,7 @@ def storage_controller_proxy(make_httpserver):

    log.info(f"Storage controller proxy listening on {self.listen}")

    def handler(request: Request):
    def handler(request: Request) -> Response:
        if self.route_to is None:
            log.info(f"Storage controller proxy has no routing configured for {request.url}")
            return Response("Routing not configured", status=503)
@@ -18,6 +18,7 @@ from urllib.parse import urlencode
import allure
import zstandard
from psycopg2.extensions import cursor
from typing_extensions import override

from fixtures.log_helper import log
from fixtures.pageserver.common_types import (
@@ -26,28 +27,45 @@ from fixtures.pageserver.common_types import (
)

if TYPE_CHECKING:
    from typing import (
        IO,
        Optional,
        Union,
    )
    from collections.abc import Iterable
    from typing import IO, Optional

    from fixtures.common_types import TimelineId
    from fixtures.neon_fixtures import PgBin
    from fixtures.common_types import TimelineId

WaitUntilRet = TypeVar("WaitUntilRet")


Fn = TypeVar("Fn", bound=Callable[..., Any])
COMPONENT_BINARIES = {
    "storage_controller": ("storage_controller",),
    "storage_broker": ("storage_broker",),
    "compute": ("compute_ctl",),
    "safekeeper": ("safekeeper",),
    "pageserver": ("pageserver", "pagectl"),
}
# Disable auto-formatting for better readability
# fmt: off
VERSIONS_COMBINATIONS = (
    {"storage_controller": "new", "storage_broker": "new", "compute": "new", "safekeeper": "new", "pageserver": "new"},
    {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "old"},
    {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "old", "pageserver": "new"},
    {"storage_controller": "new", "storage_broker": "new", "compute": "old", "safekeeper": "new", "pageserver": "new"},
    {"storage_controller": "old", "storage_broker": "old", "compute": "new", "safekeeper": "new", "pageserver": "new"},
)
# fmt: on


def subprocess_capture(
    capture_dir: Path,
    cmd: list[str],
    *,
    check=False,
    echo_stderr=False,
    echo_stdout=False,
    capture_stdout=False,
    timeout=None,
    with_command_header=True,
    check: bool = False,
    echo_stderr: bool = False,
    echo_stdout: bool = False,
    capture_stdout: bool = False,
    timeout: Optional[float] = None,
    with_command_header: bool = True,
    **popen_kwargs: Any,
) -> tuple[str, Optional[str], int]:
    """Run a process and bifurcate its output to files and the `log` logger
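For illustration (not part of the diff): given the signature above, which returns a `tuple[str, Optional[str], int]`, a hedged usage sketch follows. The command and directory are examples only, and the interpretation of the return values as (base path of the capture files, captured stdout, exit status) is an assumption based on the type annotations:

    from pathlib import Path

    # Sketch: capture "echo hello" into files under /tmp/test_output and also
    # get stdout back as a string.
    basepath, stdout, status_code = subprocess_capture(
        Path("/tmp/test_output"),
        ["echo", "hello"],
        capture_stdout=True,
        check=True,
    )
    assert status_code == 0
    assert stdout is not None and stdout.strip() == "hello"
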
@@ -84,6 +102,7 @@ def subprocess_capture(
            self.capture = capture
            self.captured = ""

        @override
        def run(self):
            first = with_command_header
            for line in self.in_file:
@@ -165,10 +184,10 @@ def global_counter() -> int:
def print_gc_result(row: dict[str, Any]):
    log.info("GC duration {elapsed} ms".format_map(row))
    log.info(
        " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
        " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}".format_map(
            row
        )
        (
            " total: {layers_total}, needed_by_cutoff {layers_needed_by_cutoff}, needed_by_pitr {layers_needed_by_pitr}"
            " needed_by_branches: {layers_needed_by_branches}, not_updated: {layers_not_updated}, removed: {layers_removed}"
        ).format_map(row)
    )


@@ -226,7 +245,7 @@ def get_scale_for_db(size_mb: int) -> int:
    return round(0.06689 * size_mb - 0.5)


ATTACHMENT_NAME_REGEX: re.Pattern = re.compile(  # type: ignore[type-arg]
ATTACHMENT_NAME_REGEX: re.Pattern[str] = re.compile(
    r"regression\.(diffs|out)|.+\.(?:log|stderr|stdout|filediff|metrics|html|walredo)"
)

@@ -289,7 +308,7 @@ LOGS_STAGING_DATASOURCE_ID = "xHHYY0dVz"

def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int, end_ms: int):
    """Add links to server logs in Grafana to Allure report"""
    links = {}
    links: dict[str, str] = {}
    # We expect host to be in format like ep-divine-night-159320.us-east-2.aws.neon.build
    endpoint_id, region_id, _ = host.split(".", 2)

@@ -341,7 +360,7 @@ def allure_add_grafana_links(host: str, timeline_id: TimelineId, start_ms: int,


def start_in_background(
    command: list[str], cwd: Path, log_file_name: str, is_started: Fn
    command: list[str], cwd: Path, log_file_name: str, is_started: Callable[[], WaitUntilRet]
) -> subprocess.Popen[bytes]:
    """Starts a process, creates the logfile and redirects stderr and stdout there. Runs the start checks before the process is started, or errors."""

@@ -376,14 +395,11 @@ def start_in_background(
    return spawned_process


WaitUntilRet = TypeVar("WaitUntilRet")


def wait_until(
    number_of_iterations: int,
    interval: float,
    func: Callable[[], WaitUntilRet],
    show_intermediate_error=False,
    show_intermediate_error: bool = False,
) -> WaitUntilRet:
    """
    Wait until 'func' returns successfully, without exception. Returns the
@@ -464,7 +480,7 @@ def humantime_to_ms(humantime: str) -> float:
def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list[tuple[int, str]]:
    # FIXME: this duplicates test_runner/fixtures/pageserver/allowed_errors.py
    error_or_warn = re.compile(r"\s(ERROR|WARN)")
    errors = []
    errors: list[tuple[int, str]] = []
    for lineno, line in enumerate(input, start=1):
        if len(line) == 0:
            continue
@@ -484,7 +500,7 @@ def scan_log_for_errors(input: Iterable[str], allowed_errors: list[str]) -> list
    return errors


def assert_no_errors(log_file, service, allowed_errors):
def assert_no_errors(log_file: Path, service: str, allowed_errors: list[str]):
    if not log_file.exists():
        log.warning(f"Skipping {service} log check: {log_file} does not exist")
        return
@@ -504,9 +520,11 @@ class AuxFileStore(str, enum.Enum):
    V2 = "v2"
    CrossValidation = "cross-validation"

    @override
    def __repr__(self) -> str:
        return f"'aux-{self.value}'"

    @override
    def __str__(self) -> str:
        return f"'aux-{self.value}'"

@@ -525,7 +543,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
    """
    started_at = time.time()

    def hash_extracted(reader: Union[IO[bytes], None]) -> bytes:
    def hash_extracted(reader: Optional[IO[bytes]]) -> bytes:
        assert reader is not None
        digest = sha256(usedforsecurity=False)
        while True:
@@ -550,7 +568,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str
        right_list
    ), f"unexpected number of files on tar files, {len(left_list)} != {len(right_list)}"

    mismatching = set()
    mismatching: set[str] = set()

    for left_tuple, right_tuple in zip(left_list, right_list):
        left_path, left_hash = left_tuple
@@ -575,6 +593,7 @@ class PropagatingThread(threading.Thread):
    Simple Thread wrapper with join() propagating the possible exception in the thread.
    """

    @override
    def run(self):
        self.exc = None
        try:
@@ -582,7 +601,8 @@ class PropagatingThread(threading.Thread):
        except BaseException as e:
            self.exc = e

    def join(self, timeout=None):
    @override
    def join(self, timeout: Optional[float] = None) -> Any:
        super().join(timeout)
        if self.exc:
            raise self.exc
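For illustration (not part of the diff): the point of PropagatingThread is that an exception raised inside the thread body resurfaces on join(). A minimal hedged usage sketch, assuming the standard threading.Thread(target=...) constructor still applies since the class only overrides run() and join(); the failing function is made up:

    import pytest

    def boom() -> None:
        raise ValueError("failure inside the worker thread")

    t = PropagatingThread(target=boom)
    t.start()
    with pytest.raises(ValueError):
        t.join()  # re-raises the exception captured in run()
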
@@ -604,3 +624,19 @@ def human_bytes(amt: float) -> str:
        amt = amt / 1024

    raise RuntimeError("unreachable")


def allpairs_versions():
    """
    Returns a dictionary with arguments for pytest parametrize
    to test the compatibility with the previous version of Neon components;
    combinations were pre-computed to test all the pairs of the components with
    the different versions.
    """
    ids = []
    for pair in VERSIONS_COMBINATIONS:
        cur_id = []
        for component in sorted(pair.keys()):
            cur_id.append(pair[component][0])
        ids.append(f"combination_{''.join(cur_id)}")
    return {"argnames": "combination", "argvalues": VERSIONS_COMBINATIONS, "ids": ids}
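For illustration (not part of the diff): the dictionary returned above is meant to be splatted straight into `pytest.mark.parametrize`, which is exactly how the tests later in this diff use it. A hedged sketch of that pattern; the test name here is hypothetical:

    import pytest
    import fixtures.utils

    @pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
    def test_version_combination(combination):
        # `combination` is one dict from VERSIONS_COMBINATIONS; the generated test id
        # concatenates the first letter of each component's value, e.g.
        # "combination_nnnnn" for the all-"new" combination.
        assert set(combination.values()) <= {"new", "old"}
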
@@ -1,7 +1,5 @@
from __future__ import annotations

import threading
from typing import Any, Optional
from typing import TYPE_CHECKING

from fixtures.common_types import TenantId, TimelineId
from fixtures.log_helper import log
@@ -14,6 +14,9 @@ from fixtures.neon_fixtures import (
)
from fixtures.pageserver.utils import wait_for_last_record_lsn

if TYPE_CHECKING:
    from typing import Any, Optional

# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
ENDPOINT_LOCK = threading.Lock()
@@ -100,7 +103,7 @@ class Workload:
            self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
        )

    def write_rows(self, n, pageserver_id: Optional[int] = None, upload: bool = True):
    def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True):
        endpoint = self.endpoint(pageserver_id)
        start = self.expect_rows
        end = start + n - 1
@@ -121,7 +124,9 @@ class Workload:
        else:
            return False

    def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
    def churn_rows(
        self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True
    ):
        assert self.expect_rows >= n

        max_iters = 10
@@ -4,7 +4,7 @@ import enum
import json
import os
import time
from typing import Optional
from typing import TYPE_CHECKING

import pytest
from fixtures.log_helper import log
@@ -16,6 +16,10 @@ from fixtures.pageserver.http import PageserverApiException
from fixtures.utils import wait_until
from fixtures.workload import Workload

if TYPE_CHECKING:
    from typing import Optional


AGGRESIVE_COMPACTION_TENANT_CONF = {
    # Disable gc and compaction. The test runs compaction manually.
    "gc_period": "0s",
@@ -9,6 +9,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import fixtures.utils
import pytest
import toml
from fixtures.common_types import TenantId, TimelineId
@@ -93,6 +94,34 @@ if TYPE_CHECKING:
# # Run forward compatibility test
# ./scripts/pytest -k test_forward_compatibility
#
#
# How to run `test_version_mismatch` locally:
#
# export DEFAULT_PG_VERSION=16
# export BUILD_TYPE=release
# export CHECK_ONDISK_DATA_COMPATIBILITY=true
# export COMPATIBILITY_NEON_BIN=neon_previous/target/${BUILD_TYPE}
# export COMPATIBILITY_POSTGRES_DISTRIB_DIR=neon_previous/pg_install
# export NEON_BIN=target/release
# export POSTGRES_DISTRIB_DIR=pg_install
#
# # Build previous version of binaries and store them somewhere:
# rm -rf pg_install target
# git checkout <previous version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# mkdir -p neon_previous/target
# cp -a target/${BUILD_TYPE} ./neon_previous/target/${BUILD_TYPE}
# cp -a pg_install ./neon_previous/pg_install
#
# # Build current version of binaries and create a data snapshot:
# rm -rf pg_install target
# git checkout <current version>
# CARGO_BUILD_FLAGS="--features=testing" make -s -j`nproc`
# ./scripts/pytest -k test_create_snapshot
#
# # Run the version mismatch test
# ./scripts/pytest -k test_version_mismatch


check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif(
    os.environ.get("CHECK_ONDISK_DATA_COMPATIBILITY") is None,
@@ -166,16 +195,11 @@ def test_backward_compatibility(
    neon_env_builder: NeonEnvBuilder,
    test_output_dir: Path,
    pg_version: PgVersion,
    compatibility_snapshot_dir: Path,
):
    """
    Test that the new binaries can read old data
    """
    compatibility_snapshot_dir_env = os.environ.get("COMPATIBILITY_SNAPSHOT_DIR")
    assert (
        compatibility_snapshot_dir_env is not None
    ), f"COMPATIBILITY_SNAPSHOT_DIR is not set. It should be set to `compatibility_snapshot_pg{pg_version.v_prefixed}` path generateted by test_create_snapshot (ideally generated by the previous version of Neon)"
    compatibility_snapshot_dir = Path(compatibility_snapshot_dir_env).resolve()

    breaking_changes_allowed = (
        os.environ.get("ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
    )
@@ -214,27 +238,11 @@ def test_forward_compatibility(
    test_output_dir: Path,
    top_output_dir: Path,
    pg_version: PgVersion,
    compatibility_snapshot_dir: Path,
):
    """
    Test that the old binaries can read new data
    """
    compatibility_neon_bin_env = os.environ.get("COMPATIBILITY_NEON_BIN")
    assert compatibility_neon_bin_env is not None, (
        "COMPATIBILITY_NEON_BIN is not set. It should be set to a path with Neon binaries "
        "(ideally generated by the previous version of Neon)"
    )
    compatibility_neon_bin = Path(compatibility_neon_bin_env).resolve()

    compatibility_postgres_distrib_dir_env = os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR")
    assert (
        compatibility_postgres_distrib_dir_env is not None
    ), "COMPATIBILITY_POSTGRES_DISTRIB_DIR is not set. It should be set to a pg_install directrory (ideally generated by the previous version of Neon)"
    compatibility_postgres_distrib_dir = Path(compatibility_postgres_distrib_dir_env).resolve()

    compatibility_snapshot_dir = (
        top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
    )

    breaking_changes_allowed = (
        os.environ.get("ALLOW_FORWARD_COMPATIBILITY_BREAKAGE", "false").lower() == "true"
    )
@@ -245,9 +253,14 @@
    # Use previous version's production binaries (pageserver, safekeeper, pg_distrib_dir, etc.).
    # But always use the current version's neon_local binary.
    # This is because we want to test the compatibility of the data format, not the compatibility of the neon_local CLI.
    neon_env_builder.neon_binpath = compatibility_neon_bin
    neon_env_builder.pg_distrib_dir = compatibility_postgres_distrib_dir
    neon_env_builder.neon_local_binpath = neon_env_builder.neon_local_binpath
    assert (
        neon_env_builder.compatibility_neon_binpath is not None
    ), "the environment variable COMPATIBILITY_NEON_BIN is required"
    assert (
        neon_env_builder.compatibility_pg_distrib_dir is not None
    ), "the environment variable COMPATIBILITY_POSTGRES_DISTRIB_DIR is required"
    neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
    neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir

    env = neon_env_builder.from_repo_dir(
        compatibility_snapshot_dir / "repo",
@@ -558,3 +571,29 @@ def test_historic_storage_formats(
    env.pageserver.http_client().timeline_compact(
        dataset.tenant_id, existing_timeline_id, force_image_layer_creation=True
    )


@check_ondisk_data_compatibility_if_enabled
@pytest.mark.xdist_group("compatibility")
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_versions_mismatch(
    neon_env_builder: NeonEnvBuilder,
    test_output_dir: Path,
    pg_version: PgVersion,
    compatibility_snapshot_dir,
    combination,
):
    """
    Checks compatibility of different combinations of versions of the components
    """
    neon_env_builder.num_safekeepers = 3
    env = neon_env_builder.from_repo_dir(
        compatibility_snapshot_dir / "repo",
    )
    env.pageserver.allowed_errors.extend(
        [".*ingesting record with timestamp lagging more than wait_lsn_timeout.+"]
    )
    env.start()
    check_neon_works(
        env, test_output_dir, compatibility_snapshot_dir / "dump.sql", test_output_dir / "repo"
    )
@@ -1,7 +1,5 @@
import time
from logging import info

from fixtures.metrics import parse_metrics
from fixtures.neon_fixtures import NeonEnv


@@ -87,36 +85,3 @@ def test_installed_extensions(neon_simple_env: NeonEnv):
    assert ext["n_databases"] == 2
    ext["versions"].sort()
    assert ext["versions"] == ["1.2", "1.3"]

    # check that /metrics endpoint is available
    # ensure that we see the metric before and after restart
    res = client.metrics()
    info("Metrics: %s", res)
    m = parse_metrics(res)
    neon_m = m.query_all("installed_extensions", {"extension_name": "neon", "versions": "1.2,1.3"})
    assert len(neon_m) == 1
    for sample in neon_m:
        assert sample.value == 2

    endpoint.stop()
    endpoint.start()

    timeout = 5
    while timeout > 0:
        try:
            res = client.metrics()
            timeout = -1
        except Exception as e:
            info("failed to get metrics, assume they are not collected yet: %s", e)
            time.sleep(1)
            timeout -= 1
            continue

    info("After restart metrics: %s", res)
    m = parse_metrics(res)
    neon_m = m.query_all(
        "installed_extensions", {"extension_name": "neon", "versions": "1.2,1.3"}
    )
    assert len(neon_m) == 1
    for sample in neon_m:
        assert sample.value == 2
@@ -162,6 +162,11 @@ def test_cli_start_stop_multi(neon_env_builder: NeonEnvBuilder):
    env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID)
    env.neon_cli.pageserver_stop(env.BASE_PAGESERVER_ID + 1)

    # We will stop the storage controller while it may have requests in
    # flight, and the pageserver complains when requests are abandoned.
    for ps in env.pageservers:
        ps.allowed_errors.append(".*request was dropped before completing.*")

    # Keep NeonEnv state up to date, it usually owns starting/stopping services
    env.pageservers[0].running = False
    env.pageservers[1].running = False
@@ -15,7 +15,7 @@ import enum
import os
import re
import time
from typing import Optional
from typing import TYPE_CHECKING

import pytest
from fixtures.common_types import TenantId, TimelineId
@@ -40,6 +40,10 @@ from fixtures.remote_storage import (
from fixtures.utils import wait_until
from fixtures.workload import Workload

if TYPE_CHECKING:
    from typing import Optional


# A tenant configuration that is convenient for generating uploads and deletions
# without a large amount of postgres traffic.
TENANT_CONF = {
@@ -23,6 +23,7 @@ from fixtures.remote_storage import s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload
from pytest_httpserver import HTTPServer
from typing_extensions import override
from werkzeug.wrappers.request import Request
from werkzeug.wrappers.response import Response

@@ -954,6 +955,7 @@ class PageserverFailpoint(Failure):
        self.pageserver_id = pageserver_id
        self._mitigate = mitigate

    @override
    def apply(self, env: NeonEnv):
        pageserver = env.get_pageserver(self.pageserver_id)
        pageserver.allowed_errors.extend(
@@ -961,19 +963,23 @@
        )
        pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))

    @override
    def clear(self, env: NeonEnv):
        pageserver = env.get_pageserver(self.pageserver_id)
        pageserver.http_client().configure_failpoints((self.failpoint, "off"))
        if self._mitigate:
            env.storage_controller.node_configure(self.pageserver_id, {"availability": "Active"})

    @override
    def expect_available(self):
        return True

    @override
    def can_mitigate(self):
        return self._mitigate

    def mitigate(self, env):
    @override
    def mitigate(self, env: NeonEnv):
        env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})


@@ -983,9 +989,11 @@ class StorageControllerFailpoint(Failure):
        self.pageserver_id = None
        self.action = action

    @override
    def apply(self, env: NeonEnv):
        env.storage_controller.configure_failpoints((self.failpoint, self.action))

    @override
    def clear(self, env: NeonEnv):
        if "panic" in self.action:
            log.info("Restarting storage controller after panic")
@@ -994,16 +1002,19 @@
        else:
            env.storage_controller.configure_failpoints((self.failpoint, "off"))

    @override
    def expect_available(self):
        # Controller panics _do_ leave pageservers available, but our test code relies
        # on using the locate API to update configurations in Workload, so we must skip
        # these actions when the controller has been panicked.
        return "panic" not in self.action

    @override
    def can_mitigate(self):
        return False

    def fails_forward(self, env):
    @override
    def fails_forward(self, env: NeonEnv):
        # Edge case: the very last failpoint that simulates a DB connection error, where
        # the abort path will fail-forward and result in a complete split.
        fail_forward = self.failpoint == "shard-split-post-complete"
@@ -1017,6 +1028,7 @@ class StorageControllerFailpoint(Failure):

        return fail_forward

    @override
    def expect_exception(self):
        if "panic" in self.action:
            return requests.exceptions.ConnectionError
@@ -1029,18 +1041,22 @@ class NodeKill(Failure):
        self.pageserver_id = pageserver_id
        self._mitigate = mitigate

    @override
    def apply(self, env: NeonEnv):
        pageserver = env.get_pageserver(self.pageserver_id)
        pageserver.stop(immediate=True)

    @override
    def clear(self, env: NeonEnv):
        pageserver = env.get_pageserver(self.pageserver_id)
        pageserver.start()

    @override
    def expect_available(self):
        return False

    def mitigate(self, env):
    @override
    def mitigate(self, env: NeonEnv):
        env.storage_controller.node_configure(self.pageserver_id, {"availability": "Offline"})


@@ -1059,21 +1075,26 @@ class CompositeFailure(Failure):
                self.pageserver_id = f.pageserver_id
                break

    @override
    def apply(self, env: NeonEnv):
        for f in self.failures:
            f.apply(env)

    def clear(self, env):
    @override
    def clear(self, env: NeonEnv):
        for f in self.failures:
            f.clear(env)

    @override
    def expect_available(self):
        return all(f.expect_available() for f in self.failures)

    def mitigate(self, env):
    @override
    def mitigate(self, env: NeonEnv):
        for f in self.failures:
            f.mitigate(env)

    @override
    def expect_exception(self):
        expect = set(f.expect_exception() for f in self.failures)

@@ -1211,7 +1232,7 @@ def test_sharding_split_failures(

        assert attached_count == initial_shard_count

    def assert_split_done(exclude_ps_id=None) -> None:
    def assert_split_done(exclude_ps_id: Optional[int] = None) -> None:
        secondary_count = 0
        attached_count = 0
        for ps in env.pageservers:
@@ -9,6 +9,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING

import fixtures.utils
import pytest
from fixtures.auth_tokens import TokenScope
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -38,7 +39,11 @@ from fixtures.pg_version import PgVersion, run_only_on_default_postgres
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import RemoteStorageKind, s3_storage
from fixtures.storage_controller_proxy import StorageControllerProxy
from fixtures.utils import run_pg_bench_small, subprocess_capture, wait_until
from fixtures.utils import (
    run_pg_bench_small,
    subprocess_capture,
    wait_until,
)
from fixtures.workload import Workload
from mypy_boto3_s3.type_defs import (
    ObjectTypeDef,
@@ -60,9 +65,8 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
    return counts


def test_storage_controller_smoke(
    neon_env_builder: NeonEnvBuilder,
):
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_storage_controller_smoke(neon_env_builder: NeonEnvBuilder, combination):
    """
    Test the basic lifecycle of a storage controller:
    - Restarting
@@ -1038,7 +1042,7 @@ def test_storage_controller_tenant_deletion(
    )

    # Break the compute hook: we are checking that deletion does not depend on the compute hook being available
    def break_hook():
    def break_hook(_body: Any):
        raise RuntimeError("Unexpected call to compute hook")

    compute_reconfigure_listener.register_on_notify(break_hook)
@@ -1300,11 +1304,11 @@ def test_storage_controller_heartbeats(
        node_to_tenants = build_node_to_tenants_map(env)
        log.info(f"Back online: {node_to_tenants=}")

    # ... expecting the storage controller to reach a consistent state
    def storage_controller_consistent():
        env.storage_controller.consistency_check()
        # ... background reconciliation may need to run to clean up the location on the node that was offline
        env.storage_controller.reconcile_until_idle()

    wait_until(30, 1, storage_controller_consistent)
    # ... expecting the storage controller to reach a consistent state
    env.storage_controller.consistency_check()


def test_storage_controller_re_attach(neon_env_builder: NeonEnvBuilder):
@@ -6,7 +6,7 @@ import shutil
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from typing import TYPE_CHECKING

import pytest
from fixtures.common_types import TenantId, TenantShardId, TimelineId
@@ -20,6 +20,9 @@ from fixtures.remote_storage import S3Storage, s3_storage
from fixtures.utils import wait_until
from fixtures.workload import Workload

if TYPE_CHECKING:
    from typing import Optional


@pytest.mark.parametrize("shard_count", [None, 4])
def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]):
@@ -58,6 +58,7 @@ num-integer = { version = "0.1", features = ["i128"] }
num-traits = { version = "0.2", features = ["i128", "libm"] }
once_cell = { version = "1" }
parquet = { version = "53", default-features = false, features = ["zstd"] }
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", default-features = false, features = ["with-serde_json-1"] }
prost = { version = "0.13", features = ["prost-derive"] }
rand = { version = "0.8", features = ["small_rng"] }
regex = { version = "1" }
@@ -66,7 +67,7 @@ regex-syntax = { version = "0.8" }
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "stream"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
serde_json = { version = "1", features = ["alloc", "raw_value"] }
sha2 = { version = "0.10", features = ["asm", "oid"] }
signature = { version = "2", default-features = false, features = ["digest", "rand_core", "std"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
@@ -76,6 +77,7 @@ sync_wrapper = { version = "0.1", default-features = false, features = ["futures
tikv-jemalloc-sys = { version = "0.5" }
time = { version = "0.3", features = ["macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev = "20031d7a9ee1addeae6e0968e3899ae6bf01cee2", features = ["with-serde_json-1"] }
tokio-stream = { version = "0.1", features = ["net"] }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
toml_edit = { version = "0.22", features = ["serde"] }